// trackNode.cpp (1.3 KB)

  1. #include "nodes/track/trackNode.hpp"
  2. namespace GNode
  3. {
  4. void TrackNode::handle_data(std::shared_ptr<meta::MetaData>& meta_data)
  5. {
  6. if (!track_map_[meta_data->from] )
  7. {
  8. track_map_[meta_data->from] = std::make_shared<BYTETracker>(frame_rate_, track_buffer_);
  9. }
  10. std::vector<Object> objects;
  11. for (const auto& box : meta_data->boxes)
  12. {
  13. // 只处理需要的 label
  14. if (box.label == track_label_) {
  15. Object obj;
  16. obj.rect.x = box.left;
  17. obj.rect.y = box.top;
  18. obj.rect.width = box.right - box.left;
  19. obj.rect.height = box.bottom - box.top;
  20. obj.label = box.class_id; // 假设 Object::label 存的是 int 类型的 class_id
  21. obj.prob = box.score;
  22. if (obj.rect.width > 0 && obj.rect.height > 0 && obj.prob > 0) { // 至少prob > 0
  23. objects.push_back(obj); // 只添加有效的对象
  24. }
  25. }
  26. }
  27. std::vector<STrack> output_stracks = track_map_[input_buffer.first] ->update(objects);
  28. for (const auto& track : output_stracks) {
  29. const std::vector<float>& tlwh = track.tlwh;
  30. meta_data->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
  31. }
  32. };
  33. } // namespace Node