浏览代码

update tracknode

leon 2 月之前
父节点
当前提交
125c0c5d29
共有 3 个文件被更改,包括 60 次插入26 次删除
  1. 1 5
      src/3rd/ByteTrack/src/bytetrack.cpp0
  2. 1 1
      src/main.cpp
  3. 58 20
      src/nodes/track/trackNode.cpp

+ 1 - 5
src/3rd/ByteTrack/src/bytetrack.cpp0

@@ -388,7 +388,6 @@ void doInference(IExecutionContext& context, float* input, float* output, const
     CHECK(cudaFree(buffers[outputIndex]));
 }
 
-/*
 
 int main(int argc, char** argv) {
     cudaSetDevice(DEVICE);
@@ -505,7 +504,4 @@ int main(int argc, char** argv) {
     engine->destroy();
     runtime->destroy();
     return 0;
-}
-
-
-*/
+}

+ 1 - 1
src/main.cpp

@@ -9,7 +9,7 @@
 void test_depth()
 {
     std::shared_ptr<GNode::StreamNode> src_node0   = std::make_shared<GNode::StreamNode>("src0", "rtsp://admin:lww123456@172.16.22.16:554/Streaming/Channels/201", 0, GNode::DecodeType::GPU);
-    src_node0->set_skip_frame(10);
+    src_node0->set_skip_frame(1);
 
     std::shared_ptr<Infer> depth_model = load("model/depth.engine", ModelType::DEPTH_ANYTHING, {}, 0, 0.25, 0.45);
     std::shared_ptr<GNode::InferNode> infer_node   = std::make_shared<GNode::InferNode>("depth");

+ 58 - 20
src/nodes/track/trackNode.cpp

@@ -3,8 +3,36 @@
 namespace GNode
 {
 
+const float IOU_THRESHOLD = 0.9; // 可调阈值
 
-constexpr float EPSILON = 1e-3; // 根据实际情况调整
+float calculate_iou(const std::vector<float>& tlwh, const data::Box& box) {
+    float track_x1 = tlwh[0];
+    float track_y1 = tlwh[1];
+    float track_x2 = tlwh[0] + tlwh[2];
+    float track_y2 = tlwh[1] + tlwh[3];
+
+    float box_x1 = box.left;
+    float box_y1 = box.top;
+    float box_x2 = box.right;
+    float box_y2 = box.bottom;
+
+    float inter_x1 = std::max(track_x1, box_x1);
+    float inter_y1 = std::max(track_y1, box_y1);
+    float inter_x2 = std::min(track_x2, box_x2);
+    float inter_y2 = std::min(track_y2, box_y2);
+
+    float inter_area = std::max(0.0f, inter_x2 - inter_x1) * std::max(0.0f, inter_y2 - inter_y1);
+
+    float track_area = tlwh[2] * tlwh[3];
+    float box_area = (box.right - box.left) * (box.bottom - box.top);
+
+    float union_area = track_area + box_area - inter_area;
+
+    if (union_area <= 0.0f) {
+        return 0.0f;
+    }
+    return inter_area / union_area;
+}
 
 
 void TrackNode::work()
@@ -41,27 +69,37 @@ void TrackNode::work()
                 return obj;
             });
             std::vector<STrack> output_stracks = tracker_->update(objects);
-            for (int i = 0; i < output_stracks.size(); i++)
-            {
-                vector<float> tlwh = output_stracks[i].tlwh;
-                int track_id = output_stracks[i].track_id;
-                printf("track id : %d\n", track_id);
-
-                std::for_each(metaData->boxes.begin(), metaData->boxes.end(),
-                    [track_id, &tlwh](data::Box& box) { // 注意这里改为引用捕获 tlwh
-                        bool width_match  = std::abs((box.right - box.left) - tlwh[2]) < EPSILON;
-                        bool height_match = std::abs((box.bottom - box.top) - tlwh[3]) < EPSILON;
-                        printf("width_match: %d, height_match: %d\n", width_match, height_match);
-                        if (std::abs(box.left - tlwh[0]) < EPSILON &&
-                            std::abs(box.top - tlwh[1]) < EPSILON &&
-                            width_match &&
-                            height_match) 
-                        {
-                            box.track_id = track_id;
-                            // 如果只需修改第一个匹配项,可在此抛出异常或记录状态
+            std::vector<bool> box_matched(metaData->boxes.size(), false); // 标记 box 是否已被匹配
+
+            for (const auto& track : output_stracks) {
+                int best_match_idx = -1;
+                float max_iou = -1.0f;
+
+                const std::vector<float>& tlwh = track.tlwh;
+                int track_id = track.track_id;
+
+                for (size_t i = 0; i < metaData->boxes.size(); ++i) {
+                    // 只考虑尚未匹配的 box
+                    if (!box_matched[i]) {
+                        float iou = calculate_iou(tlwh, metaData->boxes[i]);
+                        if (iou > max_iou) {
+                            max_iou = iou;
+                            best_match_idx = i;
                         }
-                    });
+                    }
+                }
+
+                // 如果找到足够好的匹配
+                if (best_match_idx != -1 && max_iou >= IOU_THRESHOLD) {
+                    metaData->boxes[best_match_idx].track_id = track_id;
+                    box_matched[best_match_idx] = true; // 标记此 box 已匹配
+                    printf("Track ID %d matched to box %d with IoU %.2f\n", track_id, best_match_idx, max_iou);
+                } else {
+                    printf("Track ID %d (TLWH: %.1f,%.1f,%.1f,%.1f) could not find a suitable match (max IoU: %.2f)\n",
+                            track_id, tlwh[0], tlwh[1], tlwh[2], tlwh[3], max_iou);
+                }
             }
+            
             for (auto& output_buffer : output_buffers_)
             {
                 // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());