|
@@ -0,0 +1,177 @@
|
|
|
+#ifndef YOLO_HPP__
|
|
|
+#define YOLO_HPP__
|
|
|
+
|
|
|
+#include "src/common/data.hpp"
|
|
|
+#include "infer/infer.hpp"
|
|
|
+
|
|
|
+#include "opencv2/opencv.hpp"
|
|
|
+#include "opencv2/dnn.hpp"
|
|
|
+#include <memory>
|
|
|
+#include <string>
|
|
|
+#include <vector>
|
|
|
+#include <iostream>
|
|
|
+
|
|
|
+
|
|
|
+namespace yolo
|
|
|
+{
|
|
|
+
|
|
|
+class Yolov5InferImpl : public Infer
|
|
|
+{
|
|
|
+public:
|
|
|
+ ModelType model_type;
|
|
|
+ std::shared_ptr<cv::dnn::Net> net_;
|
|
|
+
|
|
|
+ float confidence_threshold_;
|
|
|
+ float nms_threshold_;
|
|
|
+ int network_input_width_, network_input_height_;
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ std::vector<std::string> names_;
|
|
|
+
|
|
|
+ bool load(const std::string& model_path, const std::vector<std::string>& names, int gpu_id=0, float confidence_threshold=0.5f, float nms_threshold=0.45f)
|
|
|
+ {
|
|
|
+ net_ = std::make_shared<cv::dnn::Net>(cv::dnn::readNet(model_path));
|
|
|
+ // 获取模型输入层名称
|
|
|
+ std::vector<std::string> inputNames = net->getLayerNames();
|
|
|
+
|
|
|
+ // 获取输入层的形状信息
|
|
|
+ std::vector<std::vector<int>> inShapes, outShapes;
|
|
|
+ net.getLayerShapes(cv::dnn::Dict(), 0, inShapes, outShapes);
|
|
|
+
|
|
|
+ if (!inShapes.empty()) {
|
|
|
+ int batchSize = inShapes[0][0]; // 批次大小(通常为1)
|
|
|
+ int channels = inShapes[0][1]; // 通道数
|
|
|
+ network_input_height_ = inShapes[0][2]; // 高度
|
|
|
+ network_input_width_ = inShapes[0][3]; // 宽度
|
|
|
+
|
|
|
+ std::cout << "Model Input Shape: " << batchSize << "x"
|
|
|
+ << channels << "x" << network_input_height_ << "x" << network_input_width_ << std::endl;
|
|
|
+ } else {
|
|
|
+ std::cout << "Failed to get input shape!" << std::endl;
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ void warpAffine(cv::Mat& src_image, cv::Mat& dst_image, float *d2i)
|
|
|
+ {
|
|
|
+ int src_image_width = src_image.cols;
|
|
|
+ int src_image_height = src_image.rows;
|
|
|
+
|
|
|
+ float scale_x = network_input_width_ / (float)src_image_width;
|
|
|
+ float scale_y = network_input_height_ / (float)src_image_height;
|
|
|
+ float scale = std::min(scale_x, scale_y);
|
|
|
+ float i2d[6];
|
|
|
+ i2d[0] = scale;
|
|
|
+ i2d[1] = 0;
|
|
|
+ i2d[2] = (-scale * src_image_width + network_input_width_ + scale - 1) * 0.5;
|
|
|
+ i2d[3] = 0;
|
|
|
+ i2d[4] = scale;
|
|
|
+ i2d[5] = (-scale * src_image_height + network_input_height_ + scale - 1) * 0.5;
|
|
|
+
|
|
|
+ cv::Mat m2x3_i2d(2, 3, CV_32F, i2d);
|
|
|
+ cv::Mat m2x3_d2i(2, 3, CV_32F, d2i);
|
|
|
+ cv::invertAffineTransform(m2x3_i2d, m2x3_d2i);
|
|
|
+
|
|
|
+ dst_image.create(network_input_height_, network_input_width_, CV_8UC3);
|
|
|
+ cv::warpAffine(src_image, dst_image, m2x3_i2d, dst_image.size(), cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar::all(114));
|
|
|
+ }
|
|
|
+
|
|
|
+ void decode(std::vector<cv::Mat>& outs,
|
|
|
+ data::BoxArray& result_boxes,
|
|
|
+ float *d2i,
|
|
|
+ int src_image_width,
|
|
|
+ int src_image_height)
|
|
|
+ {
|
|
|
+ data::BoxArray boxes;
|
|
|
+ int cols = outs[0].size[2];
|
|
|
+ int rows = outs[0].size[1];
|
|
|
+ float* predict = (float*)outs[0].data;
|
|
|
+ int num_classes = cols - 5;
|
|
|
+
|
|
|
+ for(int i = 0; i < rows; ++i)
|
|
|
+ {
|
|
|
+ float* pitem = predict + i * cols;
|
|
|
+ float objness = pitem[4];
|
|
|
+ if (objness < confidence_threshold_)
|
|
|
+ {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ float* pclass = pitem + 5;
|
|
|
+ int label = std::max_element(pclass, pclass + num_classes) - pclass;
|
|
|
+ float prob = pclass[label];
|
|
|
+ float confidence = prob * objness;
|
|
|
+ if(confidence < confidence_threshold_)
|
|
|
+ {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ float cx = pitem[0];
|
|
|
+ float cy = pitem[1];
|
|
|
+ float width = pitem[2];
|
|
|
+ float height = pitem[3];
|
|
|
+
|
|
|
+ // 通过反变换恢复到图像尺度
|
|
|
+ float left = (cx - width * 0.5) * d2i[0] + d2i[2];
|
|
|
+ float top = (cy - height * 0.5) * d2i[0] + d2i[5];
|
|
|
+ float right = (cx + width * 0.5) * d2i[0] + d2i[2];
|
|
|
+ float bottom = (cy + height * 0.5) * d2i[0] + d2i[5];
|
|
|
+ boxes.emplace_back(left, top, right, bottom, confidence, names_[label]);
|
|
|
+ }
|
|
|
+ std::sort(boxes.begin(), boxes.end(), [](Box& a, Box& b){return a.confidence > b.confidence;});
|
|
|
+ std::vector<bool> remove_flags(boxes.size());
|
|
|
+ result_boxes.reserve(boxes.size());
|
|
|
+
|
|
|
+ auto iou = [](const Box& a, const Box& b){
|
|
|
+ int cross_left = std::max(a.left, b.left);
|
|
|
+ int cross_top = std::max(a.top, b.top);
|
|
|
+ int cross_right = std::min(a.right, b.right);
|
|
|
+ int cross_bottom = std::min(a.bottom, b.bottom);
|
|
|
+
|
|
|
+ int cross_area = std::max(0, cross_right - cross_left) * std::max(0, cross_bottom - cross_top);
|
|
|
+ int union_area = std::max(0.f, a.right - a.left) * std::max(0.f, a.bottom - a.top)
|
|
|
+ + std::max(0.f, b.right - b.left) * std::max(0.f, b.bottom - b.top) - cross_area;
|
|
|
+ if(cross_area == 0 || union_area == 0) return 0.0f;
|
|
|
+ return 1.0f * cross_area / union_area;
|
|
|
+ };
|
|
|
+
|
|
|
+ for(int i = 0; i < boxes.size(); ++i)
|
|
|
+ {
|
|
|
+ if(remove_flags[i]) continue;
|
|
|
+
|
|
|
+ auto& ibox = boxes[i];
|
|
|
+ result_boxes.emplace_back(ibox);
|
|
|
+ for (int j = i + 1; j < boxes.size(); ++j)
|
|
|
+ {
|
|
|
+ if (remove_flags[j]) continue;
|
|
|
+
|
|
|
+ auto& jbox = boxes[j];
|
|
|
+ if (ibox.class_id == jbox.class_id)
|
|
|
+ {
|
|
|
+ // class matched
|
|
|
+ if (iou(ibox, jbox) >= nms_threshold_)
|
|
|
+ remove_flags[j] = true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ virtual data::BoxArray forward(cv::Mat& image) override
|
|
|
+ {
|
|
|
+ float d2i[6];
|
|
|
+ cv::Mat affine_image;
|
|
|
+ warpAffine(image, affine_image, d2i);
|
|
|
+ std::vector<cv::Mat> outs;
|
|
|
+ auto blob = cv::dnn::blobFromImage(affine_image, 1 / 255.0, cv::Size(network_input_height_, network_input_width_), cv::Scalar(0, 0, 0), true, false);
|
|
|
+ net_->setInput(blob);
|
|
|
+ net_->forward(outs, net_.getUnconnectedOutLayersNames());
|
|
|
+ data::BoxArray result;
|
|
|
+ decode(outs, result, d2i, image.cols, image.rows);
|
|
|
+ return result;
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+}
|
|
|
+
|
|
|
+#endif // YOLO_HPP__
|