// inferNode.cpp
  1. #include "nodes/base/base.hpp"
  2. #include "nodes/infer/inferNode.hpp"
  3. #include "common/image.hpp"
  4. #include <unordered_map>
  5. #include <random>
  6. #include <algorithm>
  7. namespace Node
  8. {
  9. void print_mat(const cv::Mat& mat, int max_rows = 10, int max_cols = 10)
  10. {
  11. for (int i = 0; i < std::min(max_rows, mat.rows); i++) {
  12. for (int j = 0; j < std::min(max_cols, mat.cols); j++) {
  13. std::cout << mat.at<float>(i,j) << " ";
  14. }
  15. std::cout << (mat.cols > max_cols ? "..." : "") << std::endl;
  16. }
  17. if (mat.rows > max_rows) std::cout << "[...]" << std::endl;
  18. }
  19. void InferNode::work()
  20. {
  21. printf("InferNode %s\n", name_.c_str());
  22. while (running_)
  23. {
  24. bool has_data = false;
  25. for (auto& input_buffer : input_buffers_)
  26. {
  27. std::shared_ptr<meta::MetaData> metaData;
  28. if (!input_buffer.second->try_pop(metaData))
  29. {
  30. continue;
  31. }
  32. has_data = true;
  33. // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
  34. cv::Mat image = metaData->image;
  35. int width = image.cols;
  36. int height = image.rows;
  37. // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
  38. if (!model_)
  39. {
  40. printf("model is nullptr\n");
  41. continue;
  42. }
  43. auto det_result = model_->forward(tensor::cvimg(image));
  44. if (std::holds_alternative<data::BoxArray>(det_result))
  45. {
  46. auto result = std::get<data::BoxArray>(det_result);
  47. for (auto& r : result)
  48. {
  49. metaData->boxes.push_back(r);
  50. }
  51. metaData->boxes = result;
  52. // 处理检测框...
  53. }
  54. else if(std::holds_alternative<cv::Mat>(det_result))
  55. {
  56. auto depth_mat = std::get<cv::Mat>(det_result);
  57. print_mat(depth_mat);
  58. metaData->depth = depth_mat;
  59. }
  60. else
  61. {
  62. printf("Unexpected result type from model");
  63. throw std::runtime_error("Unexpected result type from model");
  64. }
  65. for (auto& output_buffer : output_buffers_)
  66. {
  67. // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
  68. output_buffer.second->push(metaData);
  69. }
  70. }
  71. if (!has_data)
  72. {
  73. std::unique_lock<std::mutex> lock(mutex_);
  74. cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
  75. return !running_; // 等待时检查退出条件
  76. });
  77. }
  78. }
  79. };
  80. } // namespace Node