leon 1 ay önce
ebeveyn
işleme
da7c8d91d5

+ 3 - 2
src/infer/infer.hpp

@@ -13,11 +13,12 @@ enum class ModelType : int{
     DEPTH_ANYTHING = 4
 };
 
+using Result = std::variant<data::BoxArray, cv::Mat>;
 
 class Infer{
 public:
-    virtual data::BoxArray forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream = nullptr) = 0;
-    virtual data::BoxArray forward(const tensor::Image &image, void *stream = nullptr) = 0;
+    virtual Result forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream = nullptr) = 0;
+    virtual Result forward(const tensor::Image &image, void *stream = nullptr) = 0;
 
 protected:
     std::mutex mutex_;

+ 1 - 1
src/infer/trt/depth/depth.cu

@@ -79,7 +79,7 @@ void DepthModelImpl::postprocess(int width, int height, void *stream)
 }
 
 
-cv::Mat DepthModelImpl::forward(const tensor::Image &image, void *stream)
+Result DepthModelImpl::forward(const tensor::Image &image, void *stream)
 {
     int num_image = 1;
     if (num_image == 0) return {};

+ 1 - 1
src/infer/trt/depth/depth.hpp

@@ -66,7 +66,7 @@ namespace depth
     
         bool load(const std::string &engine_file);
     
-        virtual cv::Mat forward(const tensor::Image &image, void *stream = nullptr);
+        virtual Result forward(const tensor::Image &image, void *stream = nullptr);
     
 };
 

+ 3 - 3
src/infer/trt/yolo/yolo.cu

@@ -398,21 +398,21 @@ bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, c
     return true;
 }
 
-data::BoxArray YoloModelImpl::forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream)
+Result YoloModelImpl::forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream)
 {
     std::lock_guard<std::mutex> lock(mutex_); // 自动加锁/解锁
     slice_->slice(image, slice_width, slice_height, overlap_width_ratio, overlap_height_ratio, stream);
     return forwards(stream);
 }
 
-data::BoxArray YoloModelImpl::forward(const tensor::Image &image, void *stream)
+Result YoloModelImpl::forward(const tensor::Image &image, void *stream)
 {
     std::lock_guard<std::mutex> lock(mutex_); // 自动加锁/解锁
     slice_->autoSlice(image, stream);
     return forwards(stream);
 }
 
-data::BoxArray YoloModelImpl::forwards(void *stream)
+Result YoloModelImpl::forwards(void *stream)
 {
     int num_image = slice_->slice_num_h_ * slice_->slice_num_v_;
     if (num_image == 0) return {};

+ 3 - 3
src/infer/trt/yolo/yolo.hpp

@@ -58,11 +58,11 @@ namespace yolo
     
         bool load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold, float nms_threshold);
     
-        virtual data::BoxArray forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream = nullptr);
+        virtual Result forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream = nullptr);
     
-        virtual data::BoxArray forward(const tensor::Image &image, void *stream = nullptr);
+        virtual Result forward(const tensor::Image &image, void *stream = nullptr);
     
-        virtual data::BoxArray forwards(void *stream = nullptr);
+        virtual Result forwards(void *stream = nullptr);
 };
 
 Infer *loadraw(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold,