trackNode.cpp 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. #include "nodes/track/trackNode.hpp"
  2. namespace GNode
  3. {
  4. void TrackNode::work()
  5. {
  6. PLOGI.printf("TrackNode : [%s] start", name_.c_str());
  7. for (const auto& input_buffer : input_buffers_)
  8. {
  9. track_map_[input_buffer.first] = std::make_shared<BYTETracker>(frame_rate_, track_buffer_);
  10. }
  11. while (running_)
  12. {
  13. bool has_data = false;
  14. for (auto& input_buffer : input_buffers_)
  15. {
  16. std::shared_ptr<meta::MetaData> metaData;
  17. if (!input_buffer.second->try_pop(metaData))
  18. {
  19. continue;
  20. }
  21. has_data = true;
  22. // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
  23. // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
  24. if (!track_map_[input_buffer.first] )
  25. {
  26. PLOGE.printf("TrackNode : [%s] track is nullptr", name_.c_str());
  27. continue;
  28. }
  29. std::vector<Object> objects;
  30. for (const auto& box : metaData->boxes)
  31. {
  32. // 只处理需要的 label
  33. if (box.label == track_label_) {
  34. Object obj;
  35. obj.rect.x = box.left;
  36. obj.rect.y = box.top;
  37. obj.rect.width = box.right - box.left;
  38. obj.rect.height = box.bottom - box.top;
  39. obj.label = box.class_id; // 假设 Object::label 存的是 int 类型的 class_id
  40. obj.prob = box.score;
  41. if (obj.rect.width > 0 && obj.rect.height > 0 && obj.prob > 0) { // 至少prob > 0
  42. objects.push_back(obj); // 只添加有效的对象
  43. }
  44. }
  45. }
  46. std::vector<STrack> output_stracks = track_map_[input_buffer.first] ->update(objects);
  47. for (const auto& track : output_stracks) {
  48. const std::vector<float>& tlwh = track.tlwh;
  49. metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
  50. }
  51. send_output_data(metaData);
  52. }
  53. if (!has_data)
  54. {
  55. std::unique_lock<std::mutex> lock(mutex_);
  56. cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
  57. return !running_; // 等待时检查退出条件
  58. });
  59. }
  60. }
  61. };
  62. } // namespace Node