Bladeren bron

update track

leon 1 maand geleden
bovenliggende
commit
efa23819be
3 gewijzigde bestanden met toevoegingen van 17 en 40 verwijderingen
  1. 0 1
      src/3rd/ByteTrack/src/BYTETracker.cpp
  2. 8 36
      src/nodes/track/trackNode.cpp
  3. 9 3
      src/nodes/track/trackNode.hpp

+ 0 - 1
src/3rd/ByteTrack/src/BYTETracker.cpp

@@ -18,7 +18,6 @@ BYTETracker::~BYTETracker()
 
 vector<STrack> BYTETracker::update(const vector<Object>& objects)
 {
-
 	////////////////// Step 1: Get detections //////////////////
 	this->frame_id++;
 	vector<STrack> activated_stracks;

+ 8 - 36
src/nodes/track/trackNode.cpp

@@ -6,6 +6,10 @@ namespace GNode
 void TrackNode::work()
 {
     printf("TrackNode %s\n", name_.c_str());
+    for (const auto& input_buffer : input_buffers_)
+    {
+        track_map_[input_buffer.first] = std::make_shared<BYTETracker>(frame_rate_, track_buffer_);
+    }
     while (running_)
     {
         bool has_data = false;
@@ -20,7 +24,7 @@ void TrackNode::work()
             // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
 
             // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
-            if (!tracker_)
+            if (!track_map_[input_buffer.first] )
             {
                 printf("track is nullptr\n");
                 continue;
@@ -28,7 +32,8 @@ void TrackNode::work()
             std::vector<Object> objects;
             for (const auto& box : metaData->boxes) 
             {
-                if (box.label == track_label_) { // 只处理需要的 label
+                // 只处理需要的 label
+                if (box.label == track_label_) { 
                     Object obj;
                     obj.rect.x = box.left;
                     obj.rect.y = box.top;
@@ -42,45 +47,12 @@ void TrackNode::work()
                     }
                 }
             }
-            // ***** 关键调试打印:检查输入给 tracker 的 objects *****
-            // printf("节点 %s: tracker_->update() 的输入 objects (共 %zu 个):\n", name_.c_str(), objects.size());
-            // for (size_t i = 0; i < objects.size(); ++i) {
-            //     const auto& obj = objects[i];
-            //     printf("  Object %zu: Prob=%.2f, Rect=[%.1f, %.1f, %.1f, %.1f]\n",
-            //         i,
-            //         obj.prob, // 确保 Object 结构里有 score/prob
-            //         obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height); // 确保 Object 结构里有 rect
-            // }
-            // ***********************************************
-            std::vector<STrack> output_stracks = tracker_->update(objects);
-            // ***** 详细打印 tracker_->update() 的输出 *****
-            // printf("节点 %s: tracker_->update() 返回了 %zu 个 STrack:\n", name_.c_str(), output_stracks.size());
-            // for (size_t i = 0; i < output_stracks.size(); ++i) {
-            //     const auto& track = output_stracks[i];
-            //     const std::vector<float>& tlwh = track.tlwh;
-            //     printf("  Track %zu: ID=%d, Score=%.2f, TLWH=[%.1f, %.1f, %.1f, %.1f]\n",
-            //         i,
-            //         track.track_id,
-            //         track.score,
-            //         tlwh[0], tlwh[1], tlwh[2], tlwh[3]);
-            // }
-            // ***********************************************
-
-            // 然后才开始添加到 metaData->track_boxes
-            metaData->track_boxes.clear(); // 确保清空 (或者你已确认不需要)
-            if (output_stracks.size() > objects.size())
-            {
-                printf("output_stracks size: %d, objects size: %d\n", output_stracks.size(), objects.size());
-            }
+            std::vector<STrack> output_stracks = track_map_[input_buffer.first] ->update(objects);
 
             for (const auto& track : output_stracks) {
                 const std::vector<float>& tlwh = track.tlwh;
                 metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
             }
-            if (output_stracks.size() != metaData->track_boxes.size())
-            {
-                printf("output_stracks size: %d, metaData->track_boxes size: %d\n", output_stracks.size(), metaData->track_boxes.size());
-            }
 
             for (auto& output_buffer : output_buffers_)
             {

+ 9 - 3
src/nodes/track/trackNode.hpp

@@ -15,18 +15,24 @@ class TrackNode : public BaseNode
 {
 public:
     TrackNode() = delete;
-    TrackNode(const std::string& name, const std::string& track_label, int frame_rate, int track_buffer) : BaseNode(name, NODE_TYPE::MID_NODE) 
+    TrackNode(const std::string& name, const std::string& track_label, int frame_rate=30, int track_buffer=30) : BaseNode(name, NODE_TYPE::MID_NODE) 
     {
         track_label_ = track_label;
-        tracker_ = std::make_shared<BYTETracker>(frame_rate, track_buffer);
+        frame_rate_ = frame_rate;
+        track_buffer_ = track_buffer;
+        // tracker_ = std::make_shared<BYTETracker>(frame_rate, track_buffer);
     }
     virtual ~TrackNode()  { stop(); };
     void work() override;
 
 private:
-    std::shared_ptr<BYTETracker> tracker_;
+    // std::shared_ptr<BYTETracker> tracker_;
+    std::unordered_map<std::string, std::shared_ptr<BYTETracker>> track_map_;
     std::string track_label_;
 
+    int frame_rate_;
+    int track_buffer_;
+
 };
 
 }