// inferNode.cpp

#include "nodes/base/base.hpp"
#include "nodes/infer/inferNode.hpp"
#include "common/image.hpp"
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <iostream>
#include <mutex>
#include <random>
#include <stdexcept>
#include <unordered_map>
#include <variant>
  7. namespace GNode
  8. {
  9. void print_mat(const cv::Mat& mat, int max_rows = 10, int max_cols = 10)
  10. {
  11. for (int i = 0; i < std::min(max_rows, mat.rows); i++) {
  12. for (int j = 0; j < std::min(max_cols, mat.cols); j++) {
  13. std::cout << mat.at<float>(i,j) << " ";
  14. }
  15. std::cout << (mat.cols > max_cols ? "..." : "") << std::endl;
  16. }
  17. if (mat.rows > max_rows) std::cout << "[...]" << std::endl;
  18. }
  19. void InferNode::work()
  20. {
  21. printf("InferNode %s\n", name_.c_str());
  22. while (running_)
  23. {
  24. bool has_data = false;
  25. for (auto& input_buffer : input_buffers_)
  26. {
  27. std::shared_ptr<meta::MetaData> metaData;
  28. if (!input_buffer.second->try_pop(metaData))
  29. {
  30. continue;
  31. }
  32. has_data = true;
  33. // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
  34. cv::Mat image = metaData->image;
  35. int width = image.cols;
  36. int height = image.rows;
  37. // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
  38. if (!model_)
  39. {
  40. printf("model is nullptr\n");
  41. continue;
  42. }
  43. auto det_result = model_->forward(tensor::cvimg(image));
  44. if (std::holds_alternative<data::BoxArray>(det_result))
  45. {
  46. auto result = std::get<data::BoxArray>(det_result);
  47. for (auto& r : result)
  48. {
  49. metaData->boxes.push_back(r);
  50. }
  51. }
  52. else if(std::holds_alternative<cv::Mat>(det_result))
  53. {
  54. auto depth_mat = std::get<cv::Mat>(det_result);
  55. // print_mat(depth_mat);
  56. metaData->depth = depth_mat;
  57. }
  58. else
  59. {
  60. printf("Unexpected result type from model");
  61. throw std::runtime_error("Unexpected result type from model");
  62. }
  63. for (auto& output_buffer : output_buffers_)
  64. {
  65. // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
  66. output_buffer.second->push(metaData);
  67. }
  68. }
  69. if (!has_data)
  70. {
  71. std::unique_lock<std::mutex> lock(mutex_);
  72. cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
  73. return !running_; // 等待时检查退出条件
  74. });
  75. }
  76. }
  77. };
  78. } // namespace StreamNode