leon преди 1 месец
родител
ревизия
a5f7ccfa4b
променени са 6 файла, в които са добавени 12 реда и са изтрити 31 реда
  1. 3 3
      src/common/data.hpp
  2. 1 0
      src/common/meta.hpp
  3. 2 3
      src/main.cpp
  4. 1 1
      src/nodes/draw/drawNode.cpp
  5. 2 21
      src/nodes/track/trackNode.cpp
  6. 3 3
      src/nodes/track/trackNode.hpp

+ 3 - 3
src/common/data.hpp

@@ -33,8 +33,8 @@ struct Point
 struct Box
 {
     float left, top, right, bottom, score;
+    // 在目标追踪中,将class_id赋值表示追踪的目标id
     int class_id;
-    int track_id = -1;
     std::string label;
     std::vector<Point> keypoints;
     Box() : left(0), top(0), right(0), bottom(0), score(0), class_id(0), label("") {}
@@ -44,10 +44,10 @@ struct Box
         : left(left), top(top), right(right), bottom(bottom), score(score), class_id(class_id), label(label) {}
     Box(const Box& b) : 
         left(b.left), top(b.top), right(b.right), bottom(b.bottom), score(b.score), 
-        class_id(b.class_id), label(b.label), track_id(b.track_id), keypoints(b.keypoints) {}
+        class_id(b.class_id), label(b.label), keypoints(b.keypoints) {}
     Box(const Box&& b) : 
         left(b.left), top(b.top), right(b.right), bottom(b.bottom), score(b.score), 
-        class_id(b.class_id), label(b.label), track_id(b.track_id), keypoints(b.keypoints) {}
+        class_id(b.class_id), label(b.label), keypoints(b.keypoints) {}
     Box& operator=(const Box& b)
     {
         left = b.left;

+ 1 - 0
src/common/meta.hpp

@@ -14,6 +14,7 @@ struct MetaData{
     cv::Mat draw_image; // 画框图
     cv::Mat depth; // 深度图
     data::BoxArray boxes; // 目标检测识别结果
+    data::BoxArray track_boxes; // 跟踪结果
     data::BoxArray result; // 分析结果
 };
 

+ 2 - 3
src/main.cpp

@@ -40,10 +40,9 @@ void test_yolo()
     std::shared_ptr<GNode::InferNode> infer_node   = std::make_shared<GNode::InferNode>("yolov5");
     infer_node->set_model_instance(yolo_model, ModelType::YOLO11);
 
-    std::vector<std::string> tracker_label = { "person"};
-    std::shared_ptr<GNode::TrackNode> track_node     = std::make_shared<GNode::TrackNode>("tracker", tracker_label, 30, 30);
+    std::shared_ptr<GNode::TrackNode> track_node     = std::make_shared<GNode::TrackNode>("tracker", "person", 30, 30);
 
-    std::shared_ptr<GNode::DrawNode> draw_node     = std::make_shared<GNode::DrawNode>("draw");
+    std::shared_ptr<GNode::DrawNode> draw_node     = std::make_shared<GNode::DrawNode>("draw_track");
     std::shared_ptr<GNode::RecordNode> record_node = std::make_shared<GNode::RecordNode>("record");
     record_node->set_record_path("result/result.mp4");
     record_node->set_fps(25);

+ 1 - 1
src/nodes/draw/drawNode.cpp

@@ -78,7 +78,7 @@ void DrawNode::work()
             int image_width = image.cols;
             int image_height = image.rows;
             PositionManager<float> pm(getFontSize);
-            for (auto& box : metaData->boxes)
+            for (auto& box : metaData->track_boxes)
             {
                 uint8_t b, g, r;
                 std::tie(b, g, r) = random_color(box.class_id);

+ 2 - 21
src/nodes/track/trackNode.cpp

@@ -59,7 +59,7 @@ void TrackNode::work()
             }
             std::vector<Object> objects;
             std::transform(metaData->boxes.begin(), metaData->boxes.end(), std::back_inserter(objects), [this](data::Box& box) {
-                if (std::find(track_labels_.begin(), track_labels_.end(), box.label) != track_labels_.end()) {
+                if (box.label == track_label_) {
                     Object obj;
                     obj.rect.x = box.left;
                     obj.rect.y = box.top;
@@ -74,28 +74,9 @@ void TrackNode::work()
             std::vector<bool> box_matched(metaData->boxes.size(), false); // 标记 box 是否已被匹配
 
             for (const auto& track : output_stracks) {
-                int best_match_idx = -1;
-                float max_iou = -1.0f;
-
                 const std::vector<float>& tlwh = track.tlwh;
                 int track_id = track.track_id;
-
-                for (size_t i = 0; i < metaData->boxes.size(); ++i) {
-                    // 只考虑尚未匹配的 box
-                    if (!box_matched[i]) {
-                        float iou = calculate_iou(tlwh, metaData->boxes[i]);
-                        if (iou > max_iou) {
-                            max_iou = iou;
-                            best_match_idx = i;
-                        }
-                    }
-                }
-
-                // 如果找到足够好的匹配
-                if (best_match_idx != -1 && max_iou >= IOU_THRESHOLD) {
-                    metaData->boxes[best_match_idx].track_id = track_id;
-                    box_matched[best_match_idx] = true; // 标记此 box 已匹配
-                }
+                metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], 1.0f, track_id, track_label_);
             }
 
             for (auto& output_buffer : output_buffers_)

+ 3 - 3
src/nodes/track/trackNode.hpp

@@ -15,9 +15,9 @@ class TrackNode : public BaseNode
 {
 public:
     TrackNode() = delete;
-    TrackNode(const std::string& name, const std::vector<std::string>& track_labels, int frame_rate, int track_buffer) : BaseNode(name, NODE_TYPE::MID_NODE) 
+    TrackNode(const std::string& name, const std::string& track_label, int frame_rate, int track_buffer) : BaseNode(name, NODE_TYPE::MID_NODE) 
     {
-        track_labels_ = track_labels;
+        track_label_ = track_label;
         tracker_ = std::make_shared<BYTETracker>(frame_rate, track_buffer);
     }
     virtual ~TrackNode()  { stop(); };
@@ -25,7 +25,7 @@ public:
 
 private:
     std::shared_ptr<BYTETracker> tracker_ = nullptr;
-    std::vector<std::string> track_labels_;
+    std::string track_label_;
 
 };