trackNode.cpp 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #include "nodes/track/trackNode.hpp"
  2. namespace GNode
  3. {
  4. void TrackNode::work()
  5. {
  6. PLOGI.printf("TrackNode : [%s] start", name_.c_str());
  7. for (const auto& input_buffer : input_buffers_)
  8. {
  9. track_map_[input_buffer.first] = std::make_shared<BYTETracker>(frame_rate_, track_buffer_);
  10. }
  11. while (running_)
  12. {
  13. bool has_data = false;
  14. for (auto& input_buffer : input_buffers_)
  15. {
  16. std::shared_ptr<meta::MetaData> metaData;
  17. if (!input_buffer.second->try_pop(metaData))
  18. {
  19. continue;
  20. }
  21. has_data = true;
  22. // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
  23. // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
  24. if (!track_map_[input_buffer.first] )
  25. {
  26. PLOGE.printf("TrackNode : [%s] track is nullptr", name_.c_str());
  27. continue;
  28. }
  29. std::vector<Object> objects;
  30. for (const auto& box : metaData->boxes)
  31. {
  32. // 只处理需要的 label
  33. if (box.label == track_label_) {
  34. Object obj;
  35. obj.rect.x = box.left;
  36. obj.rect.y = box.top;
  37. obj.rect.width = box.right - box.left;
  38. obj.rect.height = box.bottom - box.top;
  39. obj.label = box.class_id; // 假设 Object::label 存的是 int 类型的 class_id
  40. obj.prob = box.score;
  41. if (obj.rect.width > 0 && obj.rect.height > 0 && obj.prob > 0) { // 至少prob > 0
  42. objects.push_back(obj); // 只添加有效的对象
  43. }
  44. }
  45. }
  46. std::vector<STrack> output_stracks = track_map_[input_buffer.first] ->update(objects);
  47. for (const auto& track : output_stracks) {
  48. const std::vector<float>& tlwh = track.tlwh;
  49. metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
  50. }
  51. for (auto& output_buffer : output_buffers_)
  52. {
  53. // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
  54. output_buffer.second->push(metaData);
  55. }
  56. }
  57. if (!has_data)
  58. {
  59. std::unique_lock<std::mutex> lock(mutex_);
  60. cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
  61. return !running_; // 等待时检查退出条件
  62. });
  63. }
  64. }
  65. };
  66. } // namespace Node