leon 5 miesięcy temu
rodzic
commit
4747f53575
1 zmienionych plików z 15 dodań i 9 usunięć
  1. 15 9
      src/interface.cpp

+ 15 - 9
src/interface.cpp

@@ -23,22 +23,28 @@ namespace pybind11::detail {
             if (!array) return false;
 
             py::buffer_info buf_info = array.request();
-            int rows = buf_info.shape[0];
-            int cols = buf_info.shape[1];
-            int channels = buf_info.ndim == 3 ? buf_info.shape[2] : 1;
 
-            // 假设数据类型为uint8,可以根据需求调整
-            mat = cv::Mat(rows, cols, CV_8UC(channels), buf_info.ptr).clone();
+            int dtype;
+            if (buf_info.format == py::format_descriptor<uint8_t>::format()) {
+                dtype = CV_8UC(buf_info.ndim == 3 ? buf_info.shape[2] : 1);
+            } else if (buf_info.format == py::format_descriptor<float>::format()) {
+                dtype = CV_32FC(buf_info.ndim == 3 ? buf_info.shape[2] : 1);
+            } else {
+                throw std::runtime_error("Unsupported data type in input array");
+            }
+
+            mat = cv::Mat(buf_info.shape[0], buf_info.shape[1], dtype, buf_info.ptr).clone();
             return true;
         }
 
         // C++ -> Python
         static handle cast(const cv::Mat &mat, return_value_policy, handle) {
-            return py::array_t<unsigned char>(
+            py::array_t<unsigned char> array(
                 {mat.rows, mat.cols, mat.channels()},
-                {mat.step[0], mat.step[1], sizeof(unsigned char)},
-                mat.data
-            ).release();
+                {mat.step[0], mat.step[1], sizeof(unsigned char)}
+            );
+            std::memcpy(array.mutable_data(), mat.data, mat.total() * mat.elemSize());
+            return array.release();
         }
     private:
         cv::Mat mat;