leon 1 mēnesi atpakaļ
vecāks
revīzija
c3d7ba1d42
1 mainītis faili ar 3 papildinājumiem un 3 dzēšanām
  1. 3 3
      src/infer/opencv/yolov5.cpp

+ 3 - 3
src/infer/opencv/yolov5.cpp

@@ -29,7 +29,7 @@ public:
 
     std::vector<std::string> names_;
 
-    bool load(const std::string& model_path, const std::vector<std::string>& names, int gpu_id=0, float confidence_threshold=0.5f, float nms_threshold=0.45f)
+    bool load(const std::string& model_path, const std::vector<std::string>& names, float confidence_threshold=0.5f, float nms_threshold=0.45f)
     {
         net_ = std::make_shared<cv::dnn::Net>(cv::dnn::readNet(model_path));
         // 获取模型输入层名称
@@ -172,10 +172,10 @@ public:
     }
 };
 
-std::shared_ptr<Infer> load(const std::string &engine_file, YoloType yolo_type, int gpu_id, float confidence_threshold, float nms_threshold)
+std::shared_ptr<Infer> load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, int gpu_id, float confidence_threshold, float nms_threshold)
 {
     // checkRuntime(cudaSetDevice(gpu_id));
-    return std::shared_ptr<Yolov5InferImpl>((Yolov5InferImpl *)(new Yolov5InferImpl(engine_file, yolo_type, confidence_threshold, nms_threshold)));
+    return std::shared_ptr<Yolov5InferImpl>((Yolov5InferImpl *)(new Yolov5InferImpl(engine_file, names, confidence_threshold, nms_threshold)));
 }