Browse Source

修改追踪

leon 2 months ago
parent
commit
a36297b005
4 changed files with 23 additions and 21 deletions
  1. 6 5
      src/main.cpp
  2. 1 1
      src/nodes/record/recordNode.hpp
  3. 12 14
      src/nodes/track/trackNode.cpp
  4. 4 1
      src/nodes/track/trackNode.hpp

+ 6 - 5
src/main.cpp

@@ -30,21 +30,22 @@ void test_depth()
 
 void test_yolo()
 {
-    std::vector<std::string> names = { "person", "clothes", "vest" };
+    // std::vector<std::string> names = { "person", "clothes", "vest" };
+    std::vector<std::string> names = { "person", "car", "close", "open" };
     std::shared_ptr<GNode::StreamNode> src_node0   = std::make_shared<GNode::StreamNode>("src0", "rtsp://admin:lww123456@172.16.22.16:554/Streaming/Channels/201", 0, GNode::DecodeType::GPU);
     src_node0->set_skip_frame(1);
 
-    std::shared_ptr<Infer> yolo_model = load("model/yolo11s.engine", ModelType::YOLO11, names, 0, 0.25, 0.45);
-    std::shared_ptr<GNode::InferNode> infer_node   = std::make_shared<GNode::InferNode>("yolo11");
+    std::shared_ptr<Infer> yolo_model = load("model/carperson.engine", ModelType::YOLOV5, names, 0, 0.25, 0.45);
+    std::shared_ptr<GNode::InferNode> infer_node   = std::make_shared<GNode::InferNode>("yolov5");
     infer_node->set_model_instance(yolo_model, ModelType::YOLO11);
 
-    std::shared_ptr<GNode::TrackNode> track_node     = std::make_shared<GNode::TrackNode>("tracker", 30, 30);
+    std::shared_ptr<GNode::TrackNode> track_node     = std::make_shared<GNode::TrackNode>("tracker", {"person"}, 30, 30);
 
     std::shared_ptr<GNode::DrawNode> draw_node     = std::make_shared<GNode::DrawNode>("draw");
     std::shared_ptr<GNode::RecordNode> record_node = std::make_shared<GNode::RecordNode>("record");
     record_node->set_record_path("result/result.mp4");
     record_node->set_fps(25);
-    record_node->set_fourcc(cv::VideoWriter::fourcc('M', 'J', 'P', 'G'));
+    record_node->set_fourcc(cv::VideoWriter::fourcc('X', '2', '6', '4'));
     
     GNode::LinkNode(src_node0, infer_node);
     GNode::LinkNode(infer_node, track_node);

+ 1 - 1
src/nodes/record/recordNode.hpp

@@ -27,7 +27,7 @@ public:
 private:
     std::string record_path_;
     int fps_ = 25;
-    int fourcc_ = cv::VideoWriter::fourcc('M', 'J', 'P', 'G');
+    int fourcc_ = cv::VideoWriter::fourcc('X', '2', '6', '4');
     cv::VideoWriter writer_;
 };
     

+ 12 - 14
src/nodes/track/trackNode.cpp

@@ -3,7 +3,7 @@
 namespace GNode
 {
 
-const float IOU_THRESHOLD = 0.9; // 可调阈值
+const float IOU_THRESHOLD = 0.7; // 可调阈值
 
 float calculate_iou(const std::vector<float>& tlwh, const data::Box& box) {
     float track_x1 = tlwh[0];
@@ -59,14 +59,16 @@ void TrackNode::work()
             }
             std::vector<Object> objects;
             std::transform(metaData->boxes.begin(), metaData->boxes.end(), std::back_inserter(objects), [](data::Box& box) {
-                Object obj;
-                obj.rect.x = box.left;
-                obj.rect.y = box.top;
-                obj.rect.width = box.right - box.left;
-                obj.rect.height = box.bottom - box.top;
-                obj.label = box.class_id;
-                obj.prob = box.score;
-                return obj;
+                if (std::find(track_labels_.begin(), track_labels_.end(), box.label) != track_labels_.end()) {
+                    Object obj;
+                    obj.rect.x = box.left;
+                    obj.rect.y = box.top;
+                    obj.rect.width = box.right - box.left;
+                    obj.rect.height = box.bottom - box.top;
+                    obj.label = box.class_id;
+                    obj.prob = box.score;
+                    return obj;
+                }
             });
             std::vector<STrack> output_stracks = tracker_->update(objects);
             std::vector<bool> box_matched(metaData->boxes.size(), false); // 标记 box 是否已被匹配
@@ -93,13 +95,9 @@ void TrackNode::work()
                 if (best_match_idx != -1 && max_iou >= IOU_THRESHOLD) {
                     metaData->boxes[best_match_idx].track_id = track_id;
                     box_matched[best_match_idx] = true; // 标记此 box 已匹配
-                    printf("Track ID %d matched to box %d with IoU %.2f\n", track_id, best_match_idx, max_iou);
-                } else {
-                    printf("Track ID %d (TLWH: %.1f,%.1f,%.1f,%.1f) could not find a suitable match (max IoU: %.2f)\n",
-                            track_id, tlwh[0], tlwh[1], tlwh[2], tlwh[3], max_iou);
                 }
             }
-            
+
             for (auto& output_buffer : output_buffers_)
             {
                 // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());

+ 4 - 1
src/nodes/track/trackNode.hpp

@@ -15,8 +15,9 @@ class TrackNode : public BaseNode
 {
 public:
     TrackNode() = delete;
-    TrackNode(const std::string& name, int frame_rate, int track_buffer) : BaseNode(name, NODE_TYPE::MID_NODE) 
+    TrackNode(const std::string& name, std::vector<std::string> track_labels, int frame_rate, int track_buffer) : BaseNode(name, NODE_TYPE::MID_NODE) 
     {
+        track_labels_ = track_labels;
         tracker_ = std::make_shared<BYTETracker>(frame_rate, track_buffer);
     }
     virtual ~TrackNode()  { stop(); };
@@ -24,6 +25,8 @@ public:
 
 private:
     std::shared_ptr<BYTETracker> tracker_ = nullptr;
+    std::vector<std::string> track_labels_;
+
 };
 
 }