leon пре 5 месеци
родитељ
комит
2e5212fc2d
1 измењених фајлова са 82 додато и 35 уклоњено
  1. 82 35
      src/interface.cpp

+ 82 - 35
src/interface.cpp

@@ -12,44 +12,91 @@ using namespace std;
 
 namespace py=pybind11;
 
-namespace pybind11::detail {
-    template <> struct type_caster<cv::Mat> {
-    public:
-        PYBIND11_TYPE_CASTER(cv::Mat, _("cv::Mat"));
-
-        // Python -> C++
-        bool load(handle src, bool) {
-            py::array array = py::array::ensure(src);
-            if (!array) return false;
-
-            py::buffer_info buf_info = array.request();
-
-            int dtype;
-            if (buf_info.format == py::format_descriptor<uint8_t>::format()) {
-                dtype = CV_8UC(buf_info.ndim == 3 ? buf_info.shape[2] : 1);
-            } else if (buf_info.format == py::format_descriptor<float>::format()) {
-                dtype = CV_32FC(buf_info.ndim == 3 ? buf_info.shape[2] : 1);
-            } else {
-                throw std::runtime_error("Unsupported data type in input array");
-            }
-
-            mat = cv::Mat(buf_info.shape[0], buf_info.shape[1], dtype, buf_info.ptr).clone();
-            return true;
+namespace pybind11 { namespace detail{
+template<>
+struct type_caster<cv::Mat>{
+public:
+    PYBIND11_TYPE_CASTER(cv::Mat, _("numpy.ndarray"));
+
+    //! 1. cast numpy.ndarray to cv::Mat
+    bool load(handle obj, bool){
+        array b = reinterpret_borrow<array>(obj);
+        buffer_info info = b.request();
+
+        //const int ndims = (int)info.ndim;
+        int nh = 1;
+        int nw = 1;
+        int nc = 1;
+        int ndims = info.ndim;
+        if(ndims == 2){
+            nh = info.shape[0];
+            nw = info.shape[1];
+        } else if(ndims == 3){
+            nh = info.shape[0];
+            nw = info.shape[1];
+            nc = info.shape[2];
+        }else{
+            char msg[64];
+            std::sprintf(msg, "Unsupported dim %d, only support 2d, or 3-d", ndims);
+            throw std::logic_error(msg);
+            return false;
+        }
+
+        int dtype;
+        if(info.format == format_descriptor<unsigned char>::format()){
+            dtype = CV_8UC(nc);
+        }else if (info.format == format_descriptor<int>::format()){
+            dtype = CV_32SC(nc);
+        }else if (info.format == format_descriptor<float>::format()){
+            dtype = CV_32FC(nc);
+        }else{
+            throw std::logic_error("Unsupported type, only support uchar, int32, float");
+            return false;
         }
 
-        // C++ -> Python
-        static handle cast(const cv::Mat &mat, return_value_policy, handle) {
-            py::array_t<unsigned char> array(
-                {mat.rows, mat.cols, mat.channels()},
-                {mat.step[0], mat.step[1], sizeof(unsigned char)}
-            );
-            std::memcpy(array.mutable_data(), mat.data, mat.total() * mat.elemSize());
-            return array.release();
+        value = cv::Mat(nh, nw, dtype, info.ptr);
+        return true;
+    }
+
+    //! 2. cast cv::Mat to numpy.ndarray
+    static handle cast(const cv::Mat& mat, return_value_policy, handle defval){
+        UNUSED(defval);
+
+        std::string format = format_descriptor<unsigned char>::format();
+        size_t elemsize = sizeof(unsigned char);
+        int nw = mat.cols;
+        int nh = mat.rows;
+        int nc = mat.channels();
+        int depth = mat.depth();
+        int type = mat.type();
+        int dim = (depth == type)? 2 : 3;
+
+        if(depth == CV_8U){
+            format = format_descriptor<unsigned char>::format();
+            elemsize = sizeof(unsigned char);
+        }else if(depth == CV_32S){
+            format = format_descriptor<int>::format();
+            elemsize = sizeof(int);
+        }else if(depth == CV_32F){
+            format = format_descriptor<float>::format();
+            elemsize = sizeof(float);
+        }else{
+            throw std::logic_error("Unsupport type, only support uchar, int32, float");
         }
-    private:
-        cv::Mat mat;
-    };
-}
+
+        std::vector<size_t> bufferdim;
+        std::vector<size_t> strides;
+        if (dim == 2) {
+            bufferdim = {(size_t) nh, (size_t) nw};
+            strides = {elemsize * (size_t) nw, elemsize};
+        } else if (dim == 3) {
+            bufferdim = {(size_t) nh, (size_t) nw, (size_t) nc};
+            strides = {(size_t) elemsize * nw * nc, (size_t) elemsize * nc, (size_t) elemsize};
+        }
+        return array(buffer_info( mat.data,  elemsize,  format, dim, bufferdim, strides )).release();
+    }
+};
+}}//! end namespace pybind11::detail
 
 class TrtResnetInfer{
 public: