Przeglądaj źródła

update track node

leon 4 tygodni temu
rodzic
commit
02e0fe3c65
2 zmienionych plików z 24 dodań i 53 usunięć
  1. 23 52
      src/nodes/track/trackNode.cpp
  2. 1 1
      src/nodes/track/trackNode.hpp

+ 23 - 52
src/nodes/track/trackNode.cpp

@@ -3,66 +3,37 @@
 namespace GNode
 {
 
-void TrackNode::work()
+void TrackNode::handle_data(std::shared_ptr<meta::MetaData>& meta_data)
 {
-    PLOGI.printf("TrackNode : [%s] start", name_.c_str());
-    for (const auto& input_buffer : input_buffers_)
+    if (!track_map_[input_buffer.first] )
     {
-        track_map_[input_buffer.first] = std::make_shared<BYTETracker>(frame_rate_, track_buffer_);
+        return;
     }
-    while (running_)
+    std::vector<Object> objects;
+    for (const auto& box : metaData->boxes) 
     {
-        bool has_data = false;
-        for (auto& input_buffer : input_buffers_)
-        {
-            std::shared_ptr<meta::MetaData> metaData;
-            if (!input_buffer.second->try_pop(metaData))
-            {
-                continue;
-            }
-            has_data = true;
-            // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
-
-            // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
-            if (!track_map_[input_buffer.first] )
-            {
-                PLOGE.printf("TrackNode : [%s] track is nullptr", name_.c_str());
-                continue;
-            }
-            std::vector<Object> objects;
-            for (const auto& box : metaData->boxes) 
-            {
-                // 只处理需要的 label
-                if (box.label == track_label_) { 
-                    Object obj;
-                    obj.rect.x = box.left;
-                    obj.rect.y = box.top;
-                    obj.rect.width = box.right - box.left;
-                    obj.rect.height = box.bottom - box.top;
-                    obj.label = box.class_id; // 假设 Object::label 存的是 int 类型的 class_id
-                    obj.prob = box.score;
-
-                    if (obj.rect.width > 0 && obj.rect.height > 0 && obj.prob > 0) { // 至少prob > 0
-                         objects.push_back(obj); // 只添加有效的对象
-                    }
-                }
-            }
-            std::vector<STrack> output_stracks = track_map_[input_buffer.first] ->update(objects);
+        // 只处理需要的 label
+        if (box.label == track_label_) { 
+            Object obj;
+            obj.rect.x = box.left;
+            obj.rect.y = box.top;
+            obj.rect.width = box.right - box.left;
+            obj.rect.height = box.bottom - box.top;
+            obj.label = box.class_id; // 假设 Object::label 存的是 int 类型的 class_id
+            obj.prob = box.score;
 
-            for (const auto& track : output_stracks) {
-                const std::vector<float>& tlwh = track.tlwh;
-                metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
+            if (obj.rect.width > 0 && obj.rect.height > 0 && obj.prob > 0) { // 至少prob > 0
+                    objects.push_back(obj); // 只添加有效的对象
             }
-            send_output_data(metaData);    
-        }
-        if (!has_data)
-        {
-            std::unique_lock<std::mutex> lock(mutex_);
-            cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
-                return !running_;  // 等待时检查退出条件
-            });
         }
     }
+    std::vector<STrack> output_stracks = track_map_[input_buffer.first] ->update(objects);
+
+    for (const auto& track : output_stracks) {
+        const std::vector<float>& tlwh = track.tlwh;
+        metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
+    }
+
 };
 
 }   // namespace Node

+ 1 - 1
src/nodes/track/trackNode.hpp

@@ -26,7 +26,7 @@ public:
         PLOGI.printf("TrackNode : [%s] Init. track label is %s, rate is %d, buffer is %d", name_.c_str(), track_label_.c_str(), frame_rate_, track_buffer_);
     }
     virtual ~TrackNode()  { stop(); };
-    void work() override;
+    void handle_data(std::shared_ptr<meta::MetaData>& meta_data) override;
 
 private:
     std::unordered_map<std::string, std::shared_ptr<BYTETracker>> track_map_;