123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "opencv2/opencv.hpp"

#include "infer.hpp"
#include "resnet.hpp"
#include "cpm.hpp"
- using namespace std;
- namespace py=pybind11;
- namespace pybind11::detail {
- template <> struct type_caster<cv::Mat> {
- public:
- PYBIND11_TYPE_CASTER(cv::Mat, _("cv::Mat"));
- // Python -> C++
- bool load(handle src, bool) {
- py::array array = py::array::ensure(src);
- if (!array) return false;
- py::buffer_info buf_info = array.request();
- int rows = buf_info.shape[0];
- int cols = buf_info.shape[1];
- int channels = buf_info.ndim == 3 ? buf_info.shape[2] : 1;
- // 假设数据类型为uint8,可以根据需求调整
- mat = cv::Mat(rows, cols, CV_8UC(channels), buf_info.ptr).clone();
- return true;
- }
- // C++ -> Python
- static handle cast(const cv::Mat &mat, return_value_policy, handle) {
- return py::array_t<unsigned char>(
- {mat.rows, mat.cols, mat.channels()},
- {mat.step[0], mat.step[1], sizeof(unsigned char)},
- mat.data
- ).release();
- }
- private:
- cv::Mat mat;
- };
- }
- class TrtResnetInfer{
- public:
- TrtResnetInfer(std::string model_path)
- {
- instance_ = resnet::load(model_path);
- }
- resnet::Image cvimg(const cv::Mat &image)
- {
- return resnet::Image(image.data, image.cols, image.rows);
- }
- resnet::Attribute forward(const cv::Mat& image)
- {
- cout << image.size << std::endl;
- return instance_->forward(cvimg(image));
- }
- resnet::Attribute forward_path(const std::string& image_path)
- {
- cv::Mat image = cv::imread(image_path);
- return instance_->forward(cvimg(image));
- }
- bool valid(){
- return instance_ != nullptr;
- }
- private:
- std::shared_ptr<resnet::Infer> instance_;
- };
- PYBIND11_MODULE(trtresnet, m){
- py::class_<resnet::Attribute>(m, "Attribute")
- .def_readwrite("confidence", &resnet::Attribute::confidence)
- .def_readwrite("class_label", &resnet::Attribute::class_label)
- .def("__repr__", [](const resnet::Attribute &attr) {
- std::ostringstream oss;
- oss << "Attribute(class_label: " << attr.class_label << ", confidence: " << attr.confidence << ")";
- return oss.str();
- });;
- py::class_<TrtResnetInfer>(m, "TrtResnetInfer")
- .def(py::init<string>(), py::arg("model_path"))
- .def_property_readonly("valid", &TrtResnetInfer::valid)
- .def("forward_path", &TrtResnetInfer::forward_path, py::arg("image_path"))
- .def("forward", &TrtResnetInfer::forward, py::arg("image"));
- };
|