leon 1 月之前
父节点
当前提交
7a76a1ca8d
共有 2 个文件被更改,包括 12 次插入4 次删除
  1. 10 2
      src/nodes/track/trackNode.cpp
  2. 2 2
      src/nodes/track/trackNode.hpp

+ 10 - 2
src/nodes/track/trackNode.cpp

@@ -6,6 +6,10 @@ namespace GNode
 void TrackNode::work()
 {
     printf("TrackNode %s\n", name_.c_str());
+    for (const auto& input_buffers : input_buffers_)
+    {
+        tracker_[input_buffers.first] = std::make_shared<BYTETracker>(30, 30);
+    }
     while (running_)
     {
         bool has_data = false;
@@ -20,7 +24,7 @@ void TrackNode::work()
             // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
 
             // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
-            if (!tracker_)
+            if (!tracker_[input_buffers.first])
             {
                 printf("track is nullptr\n");
                 continue;
@@ -38,7 +42,7 @@ void TrackNode::work()
                     return obj;
                 }
             });
-            std::vector<STrack> output_stracks = tracker_->update(objects);
+            std::vector<STrack> output_stracks = tracker_[input_buffers.first]->update(objects);
             if (output_stracks.size() > objects.size())
             {
                 printf("output_stracks size: %d, objects size: %d\n", output_stracks.size(), objects.size());
@@ -48,6 +52,10 @@ void TrackNode::work()
                 const std::vector<float>& tlwh = track.tlwh;
                 metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
             }
+            if (output_stracks.size() != metaData->track_boxes.size())
+            {
+                printf("output_stracks size: %d, metaData->track_boxes size: %d\n", output_stracks.size(), metaData->track_boxes.size());
+            }
 
             for (auto& output_buffer : output_buffers_)
             {

+ 2 - 2
src/nodes/track/trackNode.hpp

@@ -18,13 +18,13 @@ public:
     TrackNode(const std::string& name, const std::string& track_label, int frame_rate, int track_buffer) : BaseNode(name, NODE_TYPE::MID_NODE) 
     {
         track_label_ = track_label;
-        tracker_ = std::make_shared<BYTETracker>(frame_rate, track_buffer);
+        // tracker_ = std::make_shared<BYTETracker>(frame_rate, track_buffer);
     }
     virtual ~TrackNode()  { stop(); };
     void work() override;
 
 private:
-    std::shared_ptr<BYTETracker> tracker_ = nullptr;
+    std::unordered_map<std::string, std::shared_ptr<BYTETracker>> tracker_;
     std::string track_label_;
 
 };