leon 1 miesiąc temu
rodzic
commit
879908866f

+ 1 - 1
src/infer/infer.cpp

@@ -14,7 +14,7 @@ std::shared_ptr<Infer> load(const std::string& model_path, ModelType model_type,
     case ModelType::YOLOV8:
     case ModelType::YOLO11:
     case ModelType::YOLO11POSE:
-        infer = yolo::load_yolo(model_path, model_type, gpu_id, confidence_threshold, nms_threshold);
+        infer = yolo::load_yolo(model_path, model_type, names, gpu_id, confidence_threshold, nms_threshold);
         break;
     default:
         break;

+ 7 - 5
src/infer/trt/yolo.cu

@@ -362,7 +362,7 @@ void YoloModelImpl::preprocess(int ibatch, affine::LetterBoxMatrix &affine, void
 }
 
 
-bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, float confidence_threshold, float nms_threshold)
+bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold, float nms_threshold)
 {
     trt_ = TensorRT::load(engine_file);
     if (trt_ == nullptr) return false;
@@ -372,6 +372,7 @@ bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, f
     this->confidence_threshold_ = confidence_threshold;
     this->nms_threshold_ = nms_threshold;
     this->model_type_ = model_type;
+    this->class_names_ = names;
 
     auto input_dim = trt_->static_dims(0);
     bbox_head_dims_ = trt_->static_dims(1);
@@ -527,6 +528,7 @@ data::BoxArray YoloModelImpl::forwards(void *stream)
         if (keepflag == 1) 
         {
             data::Box result_object_box(pbox[0], pbox[1], pbox[2], pbox[3], pbox[4], label);
+            result_object_box.label = class_names_[label];
             if (model_type_ == ModelType::YOLO11POSE)
             {
                 result_object_box.keypoints.reserve(KEY_POINT_NUM);
@@ -542,11 +544,11 @@ data::BoxArray YoloModelImpl::forwards(void *stream)
     return result;
 }
 
-Infer *loadraw(const std::string &engine_file, ModelType model_type, float confidence_threshold,
+Infer *loadraw(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold,
                float nms_threshold) 
 {
     YoloModelImpl *impl = new YoloModelImpl();
-    if (!impl->load(engine_file, model_type, confidence_threshold, nms_threshold)) 
+    if (!impl->load(engine_file, model_type, names, confidence_threshold, nms_threshold)) 
     {
         delete impl;
         impl = nullptr;
@@ -555,10 +557,10 @@ Infer *loadraw(const std::string &engine_file, ModelType model_type, float confi
     return impl;
 }
 
-std::shared_ptr<Infer> load_yolo(const std::string &engine_file, ModelType model_type, int gpu_id, float confidence_threshold, float nms_threshold) 
+std::shared_ptr<Infer> load_yolo(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, int gpu_id, float confidence_threshold, float nms_threshold) 
 {
     checkRuntime(cudaSetDevice(gpu_id));
-    return std::shared_ptr<YoloModelImpl>((YoloModelImpl *)loadraw(engine_file, model_type, confidence_threshold, nms_threshold));
+    return std::shared_ptr<YoloModelImpl>((YoloModelImpl *)loadraw(engine_file, model_type, names, confidence_threshold, nms_threshold));
 }
 
 }

+ 4 - 2
src/infer/trt/yolo.hpp

@@ -24,6 +24,8 @@ namespace yolo
     {
     public:
         ModelType model_type_;
+
+        std::vector<std::string> class_names_;
     
         // for sahi crop image
         std::shared_ptr<slice::SliceImage> slice_;
@@ -63,10 +65,10 @@ namespace yolo
         virtual data::BoxArray forwards(void *stream = nullptr);
 };
 
-Infer *loadraw(const std::string &engine_file, ModelType model_type, float confidence_threshold,
+Infer *loadraw(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold,
     float nms_threshold);
 
-std::shared_ptr<Infer> load_yolo(const std::string &engine_file, ModelType model_type, int gpu_id, float confidence_threshold, float nms_threshold);
+std::shared_ptr<Infer> load_yolo(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, int gpu_id, float confidence_threshold, float nms_threshold);
 
 } // namespace yolo
 

+ 14 - 6
src/nodes/infer/inferNode.cpp

@@ -24,12 +24,20 @@ void InferNode::work()
             int width = image.cols;
             int height = image.rows;
 
-            // cv::imwrite("test.jpg", image);
-            int x = rand() % width;
-            int y = rand() % height;
-            int w = rand() % (width - x);
-            int h = rand() % (height - y);
-            metaData->boxes.push_back(data::Box(x, y, x + w, y + h, 0.9, 0));
+            // // cv::imwrite("test.jpg", image);
+            // int x = rand() % width;
+            // int y = rand() % height;
+            // int w = rand() % (width - x);
+            // int h = rand() % (height - y);
+            // metaData->boxes.push_back(data::Box(x, y, x + w, y + h, 0.9, 0));
+
+            auto res = model_->forward(image);
+            for (auto& r : res)
+            {
+                r.label = model_->get_class_name(r.class_id);
+                metaData->boxes.push_back(r);
+            }
+            metaData->boxes = res;
             for (auto& output_buffer : output_buffers_)
             {
                 // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());