yolov5.cpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #ifndef YOLO_HPP__
  2. #define YOLO_HPP__
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>

#include "opencv2/opencv.hpp"
#include "opencv2/dnn.hpp"

#include "src/common/data.hpp"
#include "infer/infer.hpp"
  11. namespace yolo
  12. {
  13. class Yolov5InferImpl : public Infer
  14. {
  15. public:
  16. ModelType model_type;
  17. std::shared_ptr<cv::dnn::Net> net_;
  18. float confidence_threshold_;
  19. float nms_threshold_;
  20. int network_input_width_, network_input_height_;
  21. std::vector<std::string> names_;
  22. bool load(const std::string& model_path, const std::vector<std::string>& names, int gpu_id=0, float confidence_threshold=0.5f, float nms_threshold=0.45f)
  23. {
  24. net_ = std::make_shared<cv::dnn::Net>(cv::dnn::readNet(model_path));
  25. // 获取模型输入层名称
  26. std::vector<std::string> inputNames = net->getLayerNames();
  27. // 获取输入层的形状信息
  28. std::vector<std::vector<int>> inShapes, outShapes;
  29. net.getLayerShapes(cv::dnn::Dict(), 0, inShapes, outShapes);
  30. if (!inShapes.empty()) {
  31. int batchSize = inShapes[0][0]; // 批次大小(通常为1)
  32. int channels = inShapes[0][1]; // 通道数
  33. network_input_height_ = inShapes[0][2]; // 高度
  34. network_input_width_ = inShapes[0][3]; // 宽度
  35. std::cout << "Model Input Shape: " << batchSize << "x"
  36. << channels << "x" << network_input_height_ << "x" << network_input_width_ << std::endl;
  37. } else {
  38. std::cout << "Failed to get input shape!" << std::endl;
  39. }
  40. return true;
  41. }
  42. void warpAffine(cv::Mat& src_image, cv::Mat& dst_image, float *d2i)
  43. {
  44. int src_image_width = src_image.cols;
  45. int src_image_height = src_image.rows;
  46. float scale_x = network_input_width_ / (float)src_image_width;
  47. float scale_y = network_input_height_ / (float)src_image_height;
  48. float scale = std::min(scale_x, scale_y);
  49. float i2d[6];
  50. i2d[0] = scale;
  51. i2d[1] = 0;
  52. i2d[2] = (-scale * src_image_width + network_input_width_ + scale - 1) * 0.5;
  53. i2d[3] = 0;
  54. i2d[4] = scale;
  55. i2d[5] = (-scale * src_image_height + network_input_height_ + scale - 1) * 0.5;
  56. cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);
  57. cv::Mat m2x3_d2i(2, 3, CV_32F, d2i);
  58. cv::invertAffineTransform(m2x3_i2d, m2x3_d2i);
  59. dst_image.create(network_input_height_, network_input_width_, CV_8UC3);
  60. cv::warpAffine(src_image, dst_image, m2x3_i2d, dst_image.size(), cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar::all(114));
  61. }
  62. void decode(std::vector<cv::Mat>& outs,
  63. data::BoxArray& result_boxes,
  64. float *d2i,
  65. int src_image_width,
  66. int src_image_height)
  67. {
  68. data::BoxArray boxes;
  69. int cols = outs[0].size[2];
  70. int rows = outs[0].size[1];
  71. float* predict = (float*)outs[0].data;
  72. int num_classes = cols - 5;
  73. for(int i = 0; i < rows; ++i)
  74. {
  75. float* pitem = predict + i * cols;
  76. float objness = pitem[4];
  77. if (objness < confidence_threshold_)
  78. {
  79. continue;
  80. }
  81. float* pclass = pitem + 5;
  82. int label = std::max_element(pclass, pclass + num_classes) - pclass;
  83. float prob = pclass[label];
  84. float confidence = prob * objness;
  85. if(confidence < confidence_threshold_)
  86. {
  87. continue;
  88. }
  89. float cx = pitem[0];
  90. float cy = pitem[1];
  91. float width = pitem[2];
  92. float height = pitem[3];
  93. // 通过反变换恢复到图像尺度
  94. float left = (cx - width * 0.5) * d2i[0] + d2i[2];
  95. float top = (cy - height * 0.5) * d2i[0] + d2i[5];
  96. float right = (cx + width * 0.5) * d2i[0] + d2i[2];
  97. float bottom = (cy + height * 0.5) * d2i[0] + d2i[5];
  98. boxes.emplace_back(left, top, right, bottom, confidence, names_[label]);
  99. }
  100. std::sort(boxes.begin(), boxes.end(), [](Box& a, Box& b){return a.confidence > b.confidence;});
  101. std::vector<bool> remove_flags(boxes.size());
  102. result_boxes.reserve(boxes.size());
  103. auto iou = [](const Box& a, const Box& b){
  104. int cross_left = std::max(a.left, b.left);
  105. int cross_top = std::max(a.top, b.top);
  106. int cross_right = std::min(a.right, b.right);
  107. int cross_bottom = std::min(a.bottom, b.bottom);
  108. int cross_area = std::max(0, cross_right - cross_left) * std::max(0, cross_bottom - cross_top);
  109. int union_area = std::max(0.f, a.right - a.left) * std::max(0.f, a.bottom - a.top)
  110. + std::max(0.f, b.right - b.left) * std::max(0.f, b.bottom - b.top) - cross_area;
  111. if(cross_area == 0 || union_area == 0) return 0.0f;
  112. return 1.0f * cross_area / union_area;
  113. };
  114. for(int i = 0; i < boxes.size(); ++i)
  115. {
  116. if(remove_flags[i]) continue;
  117. auto& ibox = boxes[i];
  118. result_boxes.emplace_back(ibox);
  119. for (int j = i + 1; j < boxes.size(); ++j)
  120. {
  121. if (remove_flags[j]) continue;
  122. auto& jbox = boxes[j];
  123. if (ibox.class_id == jbox.class_id)
  124. {
  125. // class matched
  126. if (iou(ibox, jbox) >= nms_threshold_)
  127. remove_flags[j] = true;
  128. }
  129. }
  130. }
  131. }
  132. virtual data::BoxArray forward(cv::Mat& image) override
  133. {
  134. float d2i[6];
  135. cv::Mat affine_image;
  136. warpAffine(image, affine_image, d2i);
  137. std::vector<cv::Mat> outs;
  138. auto blob = cv::dnn::blobFromImage(affine_image, 1 / 255.0, cv::Size(network_input_height_, network_input_width_), cv::Scalar(0, 0, 0), true, false);
  139. net_->setInput(blob);
  140. net_->forward(outs, net_.getUnconnectedOutLayersNames());
  141. data::BoxArray result;
  142. decode(outs, result, d2i, image.cols, image.rows);
  143. return result;
  144. }
  145. };
  146. }
  147. #endif // YOLO_HPP__