leon 4 周之前
父节点
当前提交
93b6fd5027
共有 3 个文件被更改,包括 7 次插入6 次删除
  1. 2 1
      src/common/data.hpp
  2. 2 1
      src/infer/trt/yolo/yolo.cu
  3. 3 4
      src/nodes/draw/drawNode.cpp

+ 2 - 1
src/common/data.hpp

@@ -60,7 +60,7 @@ struct Box
     int class_id;
     int class_id;
     std::string label;
     std::string label;
     std::vector<Point> keypoints;
     std::vector<Point> keypoints;
-    std::shared_ptr<InstanceSegmentMap> seg;
+    cv::Mat seg_mask;
     Box() : left(0), top(0), right(0), bottom(0), score(0), class_id(0), label("") {}
     Box() : left(0), top(0), right(0), bottom(0), score(0), class_id(0), label("") {}
     Box(float left, float top, float right, float bottom, float score, int class_id) 
     Box(float left, float top, float right, float bottom, float score, int class_id) 
         : left(left), top(top), right(right), bottom(bottom), score(score), class_id(class_id), label("") {}
         : left(left), top(top), right(right), bottom(bottom), score(score), class_id(class_id), label("") {}
@@ -82,6 +82,7 @@ struct Box
         class_id = b.class_id;
         class_id = b.class_id;
         label = b.label;
         label = b.label;
         keypoints = b.keypoints;
         keypoints = b.keypoints;
+        seg_mask  = b.seg_mask.clone();
         return *this;
         return *this;
     }
     }
 };
 };

+ 2 - 1
src/infer/trt/yolo/yolo.cu

@@ -763,7 +763,8 @@ Result YoloModelImpl::forwards(void *stream)
             }
             }
             else if (model_type_ == ModelType::YOLO11SEG)
             else if (model_type_ == ModelType::YOLO11SEG)
             {
             {
-                result_object_box.seg = decode_segment(imemory, pbox, stream);
+                auto seg = decode_segment(imemory, pbox, stream);
+                result_object_box.seg_mask = cv::Mat mask(seg->height, seg->width, CV_8UC1, seg->data);
             }
             }
             
             
             result.emplace_back(result_object_box);
             result.emplace_back(result_object_box);

+ 3 - 4
src/nodes/draw/drawNode.cpp

@@ -17,7 +17,7 @@ static std::tuple<int, int, int> getFontSize(const std::string& text)
     return std::make_tuple(textSize.width, textSize.height, baseline);
     return std::make_tuple(textSize.width, textSize.height, baseline);
 }
 }
 
 
-cv::Mat overlay_mask(
+static void overlay_mask(
     cv::Mat& image, const cv::Mat& smallMask,
     cv::Mat& image, const cv::Mat& smallMask,
     int roiX, int roiY,
     int roiX, int roiY,
     const cv::Scalar& color, double alpha) 
     const cv::Scalar& color, double alpha) 
@@ -122,10 +122,9 @@ void DrawNode::handle_data(std::shared_ptr<meta::MetaData>& meta_data)
         
         
         cv::putText(image, text, cv::Point(x, y), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2);
         cv::putText(image, text, cv::Point(x, y), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2);
 
 
-        if (box.seg)
+        if (box.seg_mask)
         {
         {
-            cv::Mat mask(box.seg->height, box.seg->width, CV_8UC1, box.seg->data);
-            overlay_mask(image, mask, obj.left, obj.top, cv::Scalar(b, g, r), 0.6);
+            overlay_mask(image, box.seg_mask, box.left, box.top, cv::Scalar(b, g, r), 0.6);
         }
         }
     }
     }
     meta_data->draw_image = image;
     meta_data->draw_image = image;