#include #include #include "pybind11/pybind11.h" #include "pybind11/numpy.h" #include "opencv2/opencv.hpp" #include "infer.hpp" #include "resnet.hpp" #include "cpm.hpp" using namespace std; namespace py=pybind11; namespace pybind11::detail { template <> struct type_caster { public: PYBIND11_TYPE_CASTER(cv::Mat, _("cv::Mat")); // Python -> C++ bool load(handle src, bool) { py::array array = py::array::ensure(src); if (!array) return false; py::buffer_info buf_info = array.request(); int rows = buf_info.shape[0]; int cols = buf_info.shape[1]; int channels = buf_info.ndim == 3 ? buf_info.shape[2] : 1; // 假设数据类型为uint8,可以根据需求调整 mat = cv::Mat(rows, cols, CV_8UC(channels), buf_info.ptr).clone(); return true; } // C++ -> Python static handle cast(const cv::Mat &mat, return_value_policy, handle) { return py::array_t( {mat.rows, mat.cols, mat.channels()}, {mat.step[0], mat.step[1], sizeof(unsigned char)}, mat.data ).release(); } private: cv::Mat mat; }; } class TrtResnetInfer{ public: TrtResnetInfer(std::string model_path) { instance_ = resnet::load(model_path); } resnet::Image cvimg(const cv::Mat &image) { return resnet::Image(image.data, image.cols, image.rows); } resnet::Attribute forward(const cv::Mat& image) { cout << image.size << std::endl; return instance_->forward(cvimg(image)); } resnet::Attribute forward_path(const std::string& image_path) { cv::Mat image = cv::imread(image_path); return instance_->forward(cvimg(image)); } bool valid(){ return instance_ != nullptr; } private: std::shared_ptr instance_; }; PYBIND11_MODULE(trtresnet, m){ py::class_(m, "Attribute") .def_readwrite("confidence", &resnet::Attribute::confidence) .def_readwrite("class_label", &resnet::Attribute::class_label) .def("__repr__", [](const resnet::Attribute &attr) { std::ostringstream oss; oss << "Attribute(class_label: " << attr.class_label << ", confidence: " << attr.confidence << ")"; return oss.str(); });; py::class_(m, "TrtResnetInfer") .def(py::init(), py::arg("model_path")) .def_property_readonly("valid", &TrtResnetInfer::valid) .def("forward_path", &TrtResnetInfer::forward_path, py::arg("image_path")) .def("forward", &TrtResnetInfer::forward, py::arg("image")); };