inferNode.cpp 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. #include "nodes/base/base.hpp"
  2. #include "nodes/infer/inferNode.hpp"
  3. #include "common/image.hpp"
  4. #include "common/utils.hpp"
  5. #include <unordered_map>
  6. #include <random>
  7. #include <algorithm>
  8. namespace GNode
  9. {
  10. void print_mat(const cv::Mat& mat, int max_rows = 10, int max_cols = 10)
  11. {
  12. for (int i = 0; i < std::min(max_rows, mat.rows); i++) {
  13. for (int j = 0; j < std::min(max_cols, mat.cols); j++) {
  14. std::cout << mat.at<float>(i,j) << " ";
  15. }
  16. std::cout << (mat.cols > max_cols ? "..." : "") << std::endl;
  17. }
  18. if (mat.rows > max_rows) std::cout << "[...]" << std::endl;
  19. }
  20. void InferNode::work()
  21. {
  22. printf("InferNode %s\n", name_.c_str());
  23. while (running_)
  24. {
  25. Timer timer("InferNode");
  26. bool has_data = false;
  27. for (auto& input_buffer : input_buffers_)
  28. {
  29. std::shared_ptr<meta::MetaData> metaData;
  30. if (!input_buffer.second->try_pop(metaData))
  31. {
  32. continue;
  33. }
  34. has_data = true;
  35. // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
  36. cv::Mat image = metaData->image;
  37. int width = image.cols;
  38. int height = image.rows;
  39. // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
  40. if (!model_)
  41. {
  42. printf("model is nullptr\n");
  43. continue;
  44. }
  45. auto det_result = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
  46. if (std::holds_alternative<data::BoxArray>(det_result))
  47. {
  48. auto result = std::get<data::BoxArray>(det_result);
  49. for (auto& r : result)
  50. {
  51. metaData->boxes.push_back(r);
  52. }
  53. }
  54. else if(std::holds_alternative<cv::Mat>(det_result))
  55. {
  56. auto depth_mat = std::get<cv::Mat>(det_result);
  57. // print_mat(depth_mat);
  58. metaData->depth = depth_mat;
  59. }
  60. else
  61. {
  62. printf("Unexpected result type from model");
  63. throw std::runtime_error("Unexpected result type from model");
  64. }
  65. for (auto& output_buffer : output_buffers_)
  66. {
  67. // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
  68. output_buffer.second->push(metaData);
  69. }
  70. }
  71. if (!has_data)
  72. {
  73. std::unique_lock<std::mutex> lock(mutex_);
  74. cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
  75. return !running_; // 等待时检查退出条件
  76. });
  77. }
  78. }
  79. };
  80. } // namespace StreamNode