leon 4 settimane fa
parent
commit
caadb484f2
2 ha cambiato i file con 5 aggiunte e 4 eliminazioni
  1. 4 3
      src/infer/trt/yolo/yolo.cu
  2. 1 1
      src/infer/trt/yolo/yolo.hpp

+ 4 - 3
src/infer/trt/yolo/yolo.cu

@@ -447,7 +447,7 @@ void YoloModelImpl::preprocess(int ibatch, void *stream)
 }
 
 
-std::shared_ptr<data::InstanceSegmentMap> YoloModelImpl::decode_segment(int imemory, float* pbox)
+std::shared_ptr<data::InstanceSegmentMap> YoloModelImpl::decode_segment(int imemory, float* pbox, void *stream)
 {
     int row_index = pbox[7];
     int batch_index = pbox[8];
@@ -456,7 +456,7 @@ std::shared_ptr<data::InstanceSegmentMap> YoloModelImpl::decode_segment(int imem
 
     int start_x = slice_->slice_start_point_.cpu()[batch_index*2];
     int start_y = slice_->slice_start_point_.cpu()[batch_index*2+1];
-    
+
     int mask_dim = segment_head_dims_[1];
 
     float *mask_weights = bbox_output_device +
@@ -487,6 +487,7 @@ std::shared_ptr<data::InstanceSegmentMap> YoloModelImpl::decode_segment(int imem
     int mask_out_width = box_width * scale_to_predict_x + 0.5f;
     int mask_out_height = box_height * scale_to_predict_y + 0.5f;
 
+    cudaStream_t stream_ = (cudaStream_t)stream;
     if (mask_out_width > 0 && mask_out_height > 0)
     {
         if (imemory >= (int)box_segment_cache_.size()) 
@@ -762,7 +763,7 @@ Result YoloModelImpl::forwards(void *stream)
             }
             else if (model_type_ == ModelType::YOLO11SEG)
             {
-                result_object_box.seg = decode_segment(imemory, pbox);
+                result_object_box.seg = decode_segment(imemory, pbox, stream);
             }
             
             result.emplace_back(result_object_box);

+ 1 - 1
src/infer/trt/yolo/yolo.hpp

@@ -68,7 +68,7 @@ namespace yolo
 
         void cal_affine_matrix(affine::LetterBoxMatrix &affine, void *stream = nullptr);
         
-        std::shared_ptr<data::InstanceSegmentMap> decode_segment(int imemory, float* pbox);
+        std::shared_ptr<data::InstanceSegmentMap> decode_segment(int imemory, float* pbox, void *stream);
     
         bool load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold, float nms_threshold);