leon il y a 1 mois
Parent
commit
586ef8b7ed
2 fichiers modifiés avec 13 ajouts et 13 suppressions
  1. 4 4
      src/common/data.hpp
  2. 9 9
      src/infer/trt/yolo.cu

+ 4 - 4
src/common/data.hpp

@@ -35,14 +35,14 @@ struct Box
     float left, top, right, bottom, score;
     int class_id;
     std::string label;
-    std::vector<Point> points;
+    std::vector<Point> keypoints;
     Box() : left(0), top(0), right(0), bottom(0), score(0), class_id(0), label("") {}
     Box(float left, float top, float right, float bottom, float score, int class_id) 
         : left(left), top(top), right(right), bottom(bottom), score(score), class_id(class_id), label("") {}
     Box(float left, float top, float right, float bottom, float score, int class_id,  const std::string& label) 
         : left(left), top(top), right(right), bottom(bottom), score(score), class_id(class_id), label(label) {}
-    Box(const Box& b) : left(b.left), top(b.top), right(b.right), bottom(b.bottom), score(b.score), class_id(b.class_id), label(b.label), points(b.points) {}
-    Box(const Box&& b) : left(b.left), top(b.top), right(b.right), bottom(b.bottom), score(b.score), class_id(b.class_id), label(b.label), points(b.points) {}
+    Box(const Box& b) : left(b.left), top(b.top), right(b.right), bottom(b.bottom), score(b.score), class_id(b.class_id), label(b.label), keypoints(b.keypoints) {}
+    Box(const Box&& b) : left(b.left), top(b.top), right(b.right), bottom(b.bottom), score(b.score), class_id(b.class_id), label(b.label), keypoints(b.keypoints) {}
     Box& operator=(const Box& b)
     {
         left = b.left;
@@ -52,7 +52,7 @@ struct Box
         score = b.score;
         class_id = b.class_id;
         label = b.label;
-        points = b.points;
+        keypoints = b.keypoints;
         return *this;
     }
 };

+ 9 - 9
src/infer/trt/yolo.cu

@@ -380,7 +380,7 @@ bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, f
     isdynamic_model_ = trt_->has_dynamic_dim();
 
     normalize_ = affine::Norm::alpha_beta(1 / 255.0f, 0.0f, affine::ChannelType::SwapRB);
-    if (this->model_type_ == ModelType::YOLOV8 || this->model_type_ == ModelType::YOLOV11)
+    if (this->model_type_ == ModelType::YOLOV8 || this->model_type_ == ModelType::YOLO11)
     {
         num_classes_ = bbox_head_dims_[2] - 4;
     }
@@ -388,7 +388,7 @@ bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, f
     {
         num_classes_ = bbox_head_dims_[2] - 5;
     }
-    else if (this->model_type_ == ModelType::YOLOV11POSE)
+    else if (this->model_type_ == ModelType::YOLO11POSE)
     {
         num_classes_ = bbox_head_dims_[2] - 4 - KEY_POINT_NUM * 3;
         // NUM_BOX_ELEMENT = 8 + KEY_POINT_NUM * 3;
@@ -483,13 +483,13 @@ data::BoxArray YoloModelImpl::forwards(void *stream)
                                 bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                                 affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, start_x, start_y, stream_);
         }
-        else if (model_type_ == ModelType::YOLOV8 || model_type_ == ModelType::YOLOV11)
+        else if (model_type_ == ModelType::YOLOV8 || model_type_ == ModelType::YOLO11)
         {
             decode_kernel_invoker_v8(image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                                 bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                                 affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, start_x, start_y, stream_);
         }
-        else if (model_type_ == ModelType::YOLOV11POSE)
+        else if (model_type_ == ModelType::YOLO11POSE)
         {
             decode_kernel_invoker_v11pose(image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                 bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
@@ -498,7 +498,7 @@ data::BoxArray YoloModelImpl::forwards(void *stream)
         
     }
     float *boxarray_device =  output_boxarray_.gpu();
-    if (model_type_ == ModelType::YOLOV11POSE)
+    if (model_type_ == ModelType::YOLO11POSE)
     {
         fast_nms_pose_kernel_invoker(boxarray_device, box_count, MAX_IMAGE_BOXES, nms_threshold_, stream_);
     }
@@ -520,19 +520,19 @@ data::BoxArray YoloModelImpl::forwards(void *stream)
 
     for (int i = 0; i < count; ++i) 
     {
-        int box_element = (model_type_ == ModelType::YOLOV11POSE) ? (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3) : NUM_BOX_ELEMENT;
+        int box_element = (model_type_ == ModelType::YOLO11POSE) ? (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3) : NUM_BOX_ELEMENT;
         float *pbox = parray + i * box_element;
         int label = pbox[5];
         int keepflag = pbox[6];
         if (keepflag == 1) 
         {
             data::Box result_object_box(pbox[0], pbox[1], pbox[2], pbox[3], pbox[4], label);
-            if (model_type_ == ModelType::YOLOV11POSE)
+            if (model_type_ == ModelType::YOLO11POSE)
             {
-                result_object_box.pose.reserve(KEY_POINT_NUM);
+                result_object_box.keypoints.reserve(KEY_POINT_NUM);
                 for (int i = 0; i < KEY_POINT_NUM; i++)
                 {
-                    result_object_box.pose.emplace_back(pbox[8+i*3], pbox[8+i*3+1], pbox[8+i*3+2]);
+                    result_object_box.keypoints.emplace_back(pbox[8+i*3], pbox[8+i*3+1], pbox[8+i*3+2]);
                 }
             }