interface.cpp 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>

#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "opencv2/opencv.hpp"

#include "infer.hpp"
#include "resnet.hpp"
#include "cpm.hpp"

using namespace std;
namespace py=pybind11;
  11. namespace pybind11::detail {
  12. template <> struct type_caster<cv::Mat> {
  13. public:
  14. PYBIND11_TYPE_CASTER(cv::Mat, _("cv::Mat"));
  15. // Python -> C++
  16. bool load(handle src, bool) {
  17. py::array array = py::array::ensure(src);
  18. if (!array) return false;
  19. py::buffer_info buf_info = array.request();
  20. int rows = buf_info.shape[0];
  21. int cols = buf_info.shape[1];
  22. int channels = buf_info.ndim == 3 ? buf_info.shape[2] : 1;
  23. // 假设数据类型为uint8,可以根据需求调整
  24. mat = cv::Mat(rows, cols, CV_8UC(channels), buf_info.ptr).clone();
  25. return true;
  26. }
  27. // C++ -> Python
  28. static handle cast(const cv::Mat &mat, return_value_policy, handle) {
  29. return py::array_t<unsigned char>(
  30. {mat.rows, mat.cols, mat.channels()},
  31. {mat.step[0], mat.step[1], sizeof(unsigned char)},
  32. mat.data
  33. ).release();
  34. }
  35. private:
  36. cv::Mat mat;
  37. };
  38. }
  39. class TrtResnetInfer{
  40. public:
  41. TrtResnetInfer(std::string model_path)
  42. {
  43. instance_ = resnet::load(model_path);
  44. }
  45. resnet::Image cvimg(const cv::Mat &image)
  46. {
  47. return resnet::Image(image.data, image.cols, image.rows);
  48. }
  49. resnet::Attribute forward(const cv::Mat& image)
  50. {
  51. cout << image.size << std::endl;
  52. return instance_->forward(cvimg(image));
  53. }
  54. resnet::Attribute forward_path(const std::string& image_path)
  55. {
  56. cv::Mat image = cv::imread(image_path);
  57. return instance_->forward(cvimg(image));
  58. }
  59. bool valid(){
  60. return instance_ != nullptr;
  61. }
  62. private:
  63. std::shared_ptr<resnet::Infer> instance_;
  64. };
  65. PYBIND11_MODULE(trtresnet, m){
  66. py::class_<resnet::Attribute>(m, "Attribute")
  67. .def_readwrite("confidence", &resnet::Attribute::confidence)
  68. .def_readwrite("class_label", &resnet::Attribute::class_label)
  69. .def("__repr__", [](const resnet::Attribute &attr) {
  70. std::ostringstream oss;
  71. oss << "Attribute(class_label: " << attr.class_label << ", confidence: " << attr.confidence << ")";
  72. return oss.str();
  73. });;
  74. py::class_<TrtResnetInfer>(m, "TrtResnetInfer")
  75. .def(py::init<string>(), py::arg("model_path"))
  76. .def_property_readonly("valid", &TrtResnetInfer::valid)
  77. .def("forward_path", &TrtResnetInfer::forward_path, py::arg("image_path"))
  78. .def("forward", &TrtResnetInfer::forward, py::arg("image"));
  79. };