leon 4 nedēļas atpakaļ
vecāks
revīzija
c1f0f04862
2 mainītis faili ar 76 papildinājumiem un 5 dzēšanām
  1. 34 5
      src/main.cpp
  2. 42 0
      src/nodes/draw/drawNode.cpp

+ 34 - 5
src/main.cpp

@@ -42,16 +42,44 @@ void test_yolo()
     OverflowStrategy stage = OverflowStrategy::BlockTimeout;
     int max_size = 100;
 
+    std::vector<std::string> names = {
+        "person",        "bicycle",      "car",
+        "motorcycle",    "airplane",     "bus",
+        "train",         "truck",        "boat",
+        "traffic light", "fire hydrant", "stop sign",
+        "parking meter", "bench",        "bird",
+        "cat",           "dog",          "horse",
+        "sheep",         "cow",          "elephant",
+        "bear",          "zebra",        "giraffe",
+        "backpack",      "umbrella",     "handbag",
+        "tie",           "suitcase",     "frisbee",
+        "skis",          "snowboard",    "sports ball",
+        "kite",          "baseball bat", "baseball glove",
+        "skateboard",    "surfboard",    "tennis racket",
+        "bottle",        "wine glass",   "cup",
+        "fork",          "knife",        "spoon",
+        "bowl",          "banana",       "apple",
+        "sandwich",      "orange",       "broccoli",
+        "carrot",        "hot dog",      "pizza",
+        "donut",         "cake",         "chair",
+        "couch",         "potted plant", "bed",
+        "dining table",  "toilet",       "tv",
+        "laptop",        "mouse",        "remote",
+        "keyboard",      "cell phone",   "microwave",
+        "oven",          "toaster",      "sink",
+        "refrigerator",  "book",         "clock",
+        "vase",          "scissors",     "teddy bear",
+        "hair drier",    "toothbrush"};
 
     // std::vector<std::string> names = { "person", "clothes", "vest" };
-    std::vector<std::string> names = { "person", "car", "close", "open" };
+    // std::vector<std::string> names = { "person", "car", "close", "open" };
     // std::shared_ptr<GNode::StreamNode> src_node0   = std::make_shared<GNode::StreamNode>("src0", "rtsp://admin:lww123456@172.16.22.16:554/Streaming/Channels/201", 0, GNode::DecodeType::GPU);
     std::shared_ptr<GNode::StreamNode> src_node0   = std::make_shared<GNode::StreamNode>("src0", "carperson.mp4", 1, GNode::DecodeType::GPU);
     src_node0->set_skip_frame(1);
 
-    std::shared_ptr<Infer> yolo_model = load("model/carperson.engine", ModelType::YOLO11, names, 1, 0.25, 0.45);
-    std::shared_ptr<GNode::InferNode> infer_node   = std::make_shared<GNode::InferNode>("carperson_model");
-    infer_node->set_model_instance(yolo_model, ModelType::YOLO11, 1);
+    std::shared_ptr<Infer> yolo_model = load("model/model1.engine", ModelType::YOLO11SEG, names, 1, 0.25, 0.45);
+    std::shared_ptr<GNode::InferNode> infer_node   = std::make_shared<GNode::InferNode>("seg");
+    infer_node->set_model_instance(yolo_model, ModelType::YOLO11SEG, 1);
 
     std::shared_ptr<GNode::TrackNode> track_node     = std::make_shared<GNode::TrackNode>("tracker", "person", 30, 30);
 
@@ -133,10 +161,11 @@ int main()
 // 画图节点         完成
 // 推送节点         基本完成
 // 日志             完成
+// YOLO11 seg       完成
 
 
 // 分析节点         
-// YOLO11 seg 
+
 // 通过配置文件创建 pipeline
 // 置信度阈值修改
 // 设置电子围栏

+ 42 - 0
src/nodes/draw/drawNode.cpp

@@ -17,6 +17,42 @@ static std::tuple<int, int, int> getFontSize(const std::string& text)
     return std::make_tuple(textSize.width, textSize.height, baseline);
 }
 
+cv::Mat overlay_mask(
+    cv::Mat& image, const cv::Mat& smallMask,
+    int roiX, int roiY,
+    const cv::Scalar& color, double alpha) 
+{
+    if (image.empty() || smallMask.empty() ||
+        image.type() != CV_8UC3 || smallMask.type() != CV_8UC1) {
+        return;
+    }
+    alpha = std::max(0.0, std::min(1.0, alpha));
+    
+    cv::Rect roiRect(roiX, roiY, smallMask.cols, smallMask.rows);
+    
+    cv::Rect imageRect(0, 0, image.cols, image.rows);
+    cv::Rect intersectionRect = roiRect & imageRect; // 使用 & 操作符计算交集
+    
+    if (intersectionRect.width <= 0 || intersectionRect.height <= 0) {
+        return;
+    }
+    
+    cv::Mat originalROI = image(intersectionRect); // ROI 指向 image 的数据
+    
+    int maskStartX = intersectionRect.x - roiX;
+    int maskStartY = intersectionRect.y - roiY;
+    cv::Rect maskIntersectionRect(maskStartX, maskStartY, intersectionRect.width, intersectionRect.height);
+    cv::Mat smallMaskROI = smallMask(maskIntersectionRect);
+    
+    cv::Mat colorPatchROI(intersectionRect.size(), image.type(), color);
+    
+    cv::Mat tempColoredROI = originalROI.clone(); // 需要一个临时区域进行覆盖
+    colorPatchROI.copyTo(tempColoredROI, smallMaskROI);
+    
+    cv::addWeighted(originalROI, 1.0 - alpha, tempColoredROI, alpha, 0.0, originalROI);
+}
+
+
 static std::tuple<uint8_t, uint8_t, uint8_t> hsv2bgr(float h, float s, float v) 
 {
     const int h_i = static_cast<int>(h * 6);
@@ -85,6 +121,12 @@ void DrawNode::handle_data(std::shared_ptr<meta::MetaData>& meta_data)
         std::tie(x, y) = pm.selectOptimalPosition(pbox, image_width, image_height, text);
         
         cv::putText(image, text, cv::Point(x, y), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2);
+
+        if (box.seg)
+        {
+            cv::Mat mask(box.seg->height, box.seg->width, CV_8UC1, box.seg->data);
+            overlay_mask(image, mask, obj.left, obj.top, cv::Scalar(b, g, r), 0.6);
+        }
     }
     meta_data->draw_image = image;
 }