leon 4 tygodni temu
rodzic
commit
06d559dff2

+ 25 - 1
src/common/data.hpp

@@ -6,7 +6,7 @@
 #include <map>
 #include <iostream>
 
-
+#include "common/check.hpp"
 
 namespace data
 {
@@ -29,6 +29,29 @@ struct Point
     }
 };
 
+struct InstanceSegmentMap 
+{
+    int width = 0, height = 0;      // width % 8 == 0
+    unsigned char *data = nullptr;  // is width * height memory
+  
+    InstanceSegmentMap(int width, int height)
+    {
+        this->width = width;
+        this->height = height;
+        checkRuntime(cudaMallocHost(&this->data, width * height));
+    }
+    virtual ~InstanceSegmentMap()
+    {
+        if (this->data) 
+        {
+            checkRuntime(cudaFreeHost(this->data));
+            this->data = nullptr;
+        }
+        this->width = 0;
+        this->height = 0; 
+    }
+};
+
 
 struct Box
 {
@@ -37,6 +60,7 @@ struct Box
     int class_id;
     std::string label;
     std::vector<Point> keypoints;
+    std::shared_ptr<InstanceSegmentMap> seg;
     Box() : left(0), top(0), right(0), bottom(0), score(0), class_id(0), label("") {}
     Box(float left, float top, float right, float bottom, float score, int class_id) 
         : left(left), top(top), right(right), bottom(bottom), score(score), class_id(class_id), label("") {}

+ 6 - 5
src/infer/infer.hpp

@@ -8,11 +8,12 @@
 #include <variant>
 
 enum class ModelType : int{
-    YOLOV5  = 0,
-    YOLOV8  = 1,
-    YOLO11 = 2,
-    YOLO11POSE = 3,
-    DEPTH_ANYTHING = 4
+    YOLOV5         = 0,
+    YOLOV5SEG      = 1,
+    YOLO11         = 2,
+    YOLO11POSE     = 3,
+    YOLO11SEG      = 4,
+    DEPTH_ANYTHING = 5
 };
 
 using Result = std::variant<data::BoxArray, cv::Mat>;

+ 63 - 2
src/infer/trt/affine.cu

@@ -168,6 +168,54 @@ static __global__ void warp_affine_bilinear_single_channel_kernel(
 }
 
 
+static __global__ void warp_affine_bilinear_single_channel_mask_kernel(
+    float *src, int src_line_size, int src_width, int src_height, uint8_t *dst, int dst_width,
+    int dst_height, float const_value_st, float *warp_affine_matrix_2_3) 
+{
+    int dx = blockDim.x * blockIdx.x + threadIdx.x;
+    int dy = blockDim.y * blockIdx.y + threadIdx.y;
+    if (dx >= dst_width || dy >= dst_height) return;
+
+    float m_x1 = warp_affine_matrix_2_3[0];
+    float m_y1 = warp_affine_matrix_2_3[1];
+    float m_z1 = warp_affine_matrix_2_3[2];
+    float m_x2 = warp_affine_matrix_2_3[3];
+    float m_y2 = warp_affine_matrix_2_3[4];
+    float m_z2 = warp_affine_matrix_2_3[5];
+
+    float src_x = m_x1 * dx + m_y1 * dy + m_z1;
+    float src_y = m_x2 * dx + m_y2 * dy + m_z2;
+    float c0;
+
+    if (src_x < 0 || src_x >= src_width || src_y < 0 || src_y >= src_height) 
+    {
+        c0 = const_value_st;
+    } 
+    else 
+    {
+        int y_low = __float2int_rz(src_y);
+        int x_low = __float2int_rz(src_x);
+        int y_high = y_low + 1;
+        int x_high = x_low + 1;
+
+        float w1 = (1 - (src_y - y_low)) * (1 - (src_x - x_low));
+        float w2 = (1 - (src_y - y_low)) * (src_x - x_low);
+        float w3 = (src_y - y_low) * (1 - (src_x - x_low));
+        float w4 = (src_y - y_low) * (src_x - x_low);
+
+        float *v1 = (y_low  >= 0 && y_low  < src_height && x_low  >= 0 && x_low  < src_width) ? src + y_low  * src_line_size + x_low  : &const_value_st;
+        float *v2 = (y_low  >= 0 && y_low  < src_height && x_high >= 0 && x_high < src_width) ? src + y_low  * src_line_size + x_high : &const_value_st;
+        float *v3 = (y_high >= 0 && y_high < src_height && x_low  >= 0 && x_low  < src_width) ? src + y_high * src_line_size + x_low  : &const_value_st;
+        float *v4 = (y_high >= 0 && y_high < src_height && x_high >= 0 && x_high < src_width) ? src + y_high * src_line_size + x_high : &const_value_st;
+
+        c0 = w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0];
+        c0 = c0 > 0.5 ? c0 : 0;
+    }
+
+    dst[dy * dst_width + dx] = c0 * 255;
+}
+
+
 void warp_affine_bilinear_and_normalize_plane(uint8_t *src, int src_line_size, int src_width,
                                                     int src_height, float *dst, int dst_width,
                                                     int dst_height, float *matrix_2_3,
@@ -191,8 +239,21 @@ void warp_affine_bilinear_single_channel_plane(float *src, int src_line_size, in
     dim3 block(32, 32);
 
     checkKernel(warp_affine_bilinear_single_channel_kernel<<<grid, block, 0, stream>>>(
-    src, src_line_size, src_width, src_height, dst, dst_width, dst_height, const_value,
-    matrix_2_3));
+        src, src_line_size, src_width, src_height, dst, dst_width, dst_height, const_value,
+        matrix_2_3));
+}
+
+void warp_affine_bilinear_single_channel_mask_plane(float *src, int src_line_size, int src_width,
+    int src_height, uint8_t *dst, int dst_width,
+    int dst_height, float *matrix_2_3,
+    uint8_t const_value, cudaStream_t stream)
+{
+    dim3 grid((dst_width + 31) / 32, (dst_height + 31) / 32);
+    dim3 block(32, 32);
+
+    checkKernel(warp_affine_bilinear_single_channel_mask_kernel<<<grid, block, 0, stream>>>(
+        src, src_line_size, src_width, src_height, dst, dst_width, dst_height, const_value,
+        matrix_2_3));
 }
 
 } // namespace affine

+ 6 - 0
src/infer/trt/affine.hpp

@@ -113,6 +113,12 @@ void warp_affine_bilinear_single_channel_plane(float *src, int src_line_size, in
     int dst_height, float *matrix_2_3,
     float const_value, cudaStream_t stream);
 
+
+void warp_affine_bilinear_single_channel_mask_plane(float *src, int src_line_size, int src_width,
+    int src_height, uint8_t *dst, int dst_width,
+    int dst_height, float *matrix_2_3,
+    uint8_t const_value, cudaStream_t stream);
+
 }
 
 #endif // AFFINE_HPP__

+ 264 - 42
src/infer/trt/yolo/yolo.cu

@@ -11,7 +11,12 @@
 namespace yolo
 {
 
-static const int NUM_BOX_ELEMENT = 8;  // left, top, right, bottom, confidence, class, keepflag, row_index(output)
+static const int NUM_BOX_ELEMENT = 9;  
+// left, top, right, bottom, confidence, class, keepflag, row_index(output), batch_index
+// 9个元素,分别是:左上角坐标,右下角坐标,置信度,类别,是否保留,行索引(mask weights),batch_index
+// 其中行索引用于找到mask weights,batch_index用于找到当前batch的图片位置
+
+
 static const int MAX_IMAGE_BOXES = 1024 * 4;
 
 static const int KEY_POINT_NUM   = 17; // 关键点数量
@@ -31,10 +36,11 @@ static __host__ __device__ void affine_project(float *matrix, float x, float y,
     *oy = matrix[3] * x + matrix[4] * y + matrix[5];
 }
 
-static __global__ void decode_kernel_v5(float *predict, int num_bboxes, int num_classes,
-                                              int output_cdim, float confidence_threshold,
-                                              float *invert_affine_matrix, float *parray, int *box_count,
-                                              int max_image_boxes, int start_x, int start_y) 
+static __global__ void decode_kernel_v5(
+    float *predict, int num_bboxes, int num_classes,
+    int output_cdim, float confidence_threshold,
+    float *invert_affine_matrix, float *parray, int *box_count,
+    int max_image_boxes, int start_x, int start_y, int batch_index) 
 {
     int position = blockDim.x * blockIdx.x + threadIdx.x;
     if (position >= num_bboxes) return;
@@ -81,12 +87,15 @@ static __global__ void decode_kernel_v5(float *predict, int num_bboxes, int num_
     *pout_item++ = label;
     *pout_item++ = 1;  // 1 = keep, 0 = ignore
     *pout_item++ = position;
+    *pout_item++ = batch_index; // batch_index
+    // 这里的batch_index是为了在后续的mask weights中使用,方便找到当前batch的图片位置
 }
 
-static __global__ void decode_kernel_v8(float *predict, int num_bboxes, int num_classes,
-                                              int output_cdim, float confidence_threshold,
-                                              float *invert_affine_matrix, float *parray, int *box_count,
-                                              int max_image_boxes, int start_x, int start_y) 
+static __global__ void decode_kernel_v8(
+    float *predict, int num_bboxes, int num_classes,
+    int output_cdim, float confidence_threshold,
+    float *invert_affine_matrix, float *parray, int *box_count,
+    int max_image_boxes, int start_x, int start_y, int batch_index) 
 {
     int position = blockDim.x * blockIdx.x + threadIdx.x;
     if (position >= num_bboxes) return;
@@ -128,12 +137,15 @@ static __global__ void decode_kernel_v8(float *predict, int num_bboxes, int num_
     *pout_item++ = label;
     *pout_item++ = 1;  // 1 = keep, 0 = ignore
     *pout_item++ = position;
+    *pout_item++ = batch_index; // batch_index
+    // 这里的batch_index是为了在后续的mask weights中使用,方便找到当前batch的图片位置
 }
 
-static __global__ void decode_kernel_11pose(float *predict, int num_bboxes, int num_classes,
+static __global__ void decode_kernel_11pose(
+    float *predict, int num_bboxes, int num_classes,
     int output_cdim, float confidence_threshold,
     float *invert_affine_matrix, float *parray,
-    int *box_count, int max_image_boxes, int start_x, int start_y) 
+    int *box_count, int max_image_boxes, int start_x, int start_y, int batch_index) 
 {
     int position = blockDim.x * blockIdx.x + threadIdx.x;
     if (position >= num_bboxes) return;
@@ -177,6 +189,7 @@ static __global__ void decode_kernel_11pose(float *predict, int num_bboxes, int
     *pout_item++ = label;
     *pout_item++ = 1;  // 1 = keep, 0 = ignore
     *pout_item++ = position;
+    *pout_item++ = batch_index; // batch_index
     for (int i = 0; i < KEY_POINT_NUM; i++)
     {
         float x = *key_points++;
@@ -266,44 +279,91 @@ static __global__ void fast_nms_pose_kernel(float *bboxes, int* box_count, int m
     }
 }
 
+static __global__ void decode_single_mask_kernel(int left, int top, float *mask_weights,
+    float *mask_predict, int mask_width,
+    int mask_height, float *mask_out,
+    int mask_dim, int out_width, int out_height) 
+{
+    // mask_predict to mask_out
+    // mask_weights @ mask_predict
+    int dx = blockDim.x * blockIdx.x + threadIdx.x;
+    int dy = blockDim.y * blockIdx.y + threadIdx.y;
+    if (dx >= out_width || dy >= out_height) return;
+
+    int sx = left + dx;
+    int sy = top + dy;
+    if (sx < 0 || sx >= mask_width || sy < 0 || sy >= mask_height) 
+    {
+        mask_out[dy * out_width + dx] = 0;
+        return;
+    }
+
+    float cumprod = 0;
+    for (int ic = 0; ic < mask_dim; ++ic) 
+    {
+        float cval = mask_predict[(ic * mask_height + sy) * mask_width + sx];
+        float wval = mask_weights[ic];
+        cumprod += cval * wval;
+    }
+
+    float alpha = 1.0f / (1.0f + exp(-cumprod));
+    // 在这里先返回float值,再将mask采样回原图后才x255
+    mask_out[dy * out_width + dx] = alpha;
+}
+
+static void decode_single_mask(float left, float top, float *mask_weights, float *mask_predict,
+                            int mask_width, int mask_height, float *mask_out,
+                            int mask_dim, int out_width, int out_height, cudaStream_t stream) 
+{
+    // mask_weights is mask_dim(32 element) gpu pointer
+    dim3 grid((out_width + 31) / 32, (out_height + 31) / 32);
+    dim3 block(32, 32);
+
+    checkKernel(decode_single_mask_kernel<<<grid, block, 0, stream>>>(
+    left, top, mask_weights, mask_predict, mask_width, mask_height, mask_out, mask_dim, out_width,
+    out_height));
+}
+
 static void decode_kernel_invoker_v8(float *predict, int num_bboxes, int num_classes, int output_cdim,
                                   float confidence_threshold, float nms_threshold,
                                   float *invert_affine_matrix, float *parray, int* box_count, int max_image_boxes,
-                                  int start_x, int start_y, cudaStream_t stream) 
+                                  int start_x, int start_y, int batch_index, cudaStream_t stream) 
 {
     auto grid = grid_dims(num_bboxes);
     auto block = block_dims(num_bboxes);
 
     checkKernel(decode_kernel_v8<<<grid, block, 0, stream>>>(
             predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
-            parray, box_count, max_image_boxes, start_x, start_y));
+            parray, box_count, max_image_boxes, start_x, start_y, batch_index));
 }
 
 
-static void decode_kernel_invoker_v5(float *predict, int num_bboxes, int num_classes, int output_cdim,
-                                  float confidence_threshold, float nms_threshold,
-                                  float *invert_affine_matrix, float *parray, int* box_count, int max_image_boxes,
-                                  int start_x, int start_y, cudaStream_t stream) 
+static void decode_kernel_invoker_v5(
+    float *predict, int num_bboxes, int num_classes, int output_cdim,
+    float confidence_threshold, float nms_threshold,
+    float *invert_affine_matrix, float *parray, int* box_count, int max_image_boxes,
+    int start_x, int start_y, int batch_index, cudaStream_t stream) 
 {
     auto grid = grid_dims(num_bboxes);
     auto block = block_dims(num_bboxes);
 
     checkKernel(decode_kernel_v5<<<grid, block, 0, stream>>>(
             predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
-            parray, box_count, max_image_boxes, start_x, start_y));
+            parray, box_count, max_image_boxes, start_x, start_y, batch_index));
 }
 
-static void decode_kernel_invoker_v11pose(float *predict, int num_bboxes, int num_classes, int output_cdim,
+static void decode_kernel_invoker_v11pose(
+    float *predict, int num_bboxes, int num_classes, int output_cdim,
     float confidence_threshold, float nms_threshold,
     float *invert_affine_matrix, float *parray, int* box_count, int max_image_boxes,
-    int start_x, int start_y, cudaStream_t stream) 
+    int start_x, int start_y, int batch_index, cudaStream_t stream) 
 {
     auto grid = grid_dims(num_bboxes);
     auto block = block_dims(num_bboxes);
 
     checkKernel(decode_kernel_11pose<<<grid, block, 0, stream>>>(
             predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
-            parray, box_count, max_image_boxes, start_x, start_y));
+            parray, box_count, max_image_boxes, start_x, start_y, batch_index));
 }
 
 static void fast_nms_kernel_invoker(float *parray, int* box_count, int max_image_boxes, float nms_threshold, cudaStream_t stream)
@@ -329,18 +389,47 @@ void YoloModelImpl::adjust_memory(int batch_size)
     output_boxarray_.gpu(MAX_IMAGE_BOXES * NUM_BOX_ELEMENT);
     output_boxarray_.cpu(MAX_IMAGE_BOXES * NUM_BOX_ELEMENT);
 
+    if (has_segment_)
+    {
+        segment_predict_.gpu(batch_size * segment_head_dims_[1] * segment_head_dims_[2] *
+                                segment_head_dims_[3]);
+    }
+
     affine_matrix_.gpu(6);
     affine_matrix_.cpu(6);
 
+    invert_affine_matrix_.gpu(6);
+    invert_affine_matrix_.cpu(6);
+
+    mask_affine_matrix_.gpu(6);
+    mask_affine_matrix_.cpu(6);
+
     box_count_.gpu(1);
     box_count_.cpu(1);
 }
 
-void YoloModelImpl::preprocess(int ibatch, affine::LetterBoxMatrix &affine, void *stream)
+void YoloModelImpl::cal_affine_matrix(affine::LetterBoxMatrix &affine, void *stream)
 {
     affine.compute(std::make_tuple(slice_->slice_width_, slice_->slice_height_),
-                std::make_tuple(network_input_width_, network_input_height_));
+        std::make_tuple(network_input_width_, network_input_height_));
+    
+    float *affine_matrix_device = affine_matrix_.gpu();
+    float *affine_matrix_host   = affine_matrix_.cpu();
 
+    float *invert_affine_matrix_device = invert_affine_matrix_.gpu();
+    float *invert_affine_matrix_host = invert_affine_matrix_.cpu();
+    
+    cudaStream_t stream_ = (cudaStream_t)stream;
+    memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
+    checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i),
+                                cudaMemcpyHostToDevice, stream_));
+
+    memcpy(invert_affine_matrix_host, affine.i2d, sizeof(affine.i2d));
+    checkRuntime(cudaMemcpyAsync(invert_affine_matrix_device, invert_affine_matrix_host, sizeof(affine.i2d), cudaMemcpyHostToDevice, stream_));
+}
+
+void YoloModelImpl::preprocess(int ibatch, void *stream)
+{
     size_t input_numel = network_input_width_ * network_input_height_ * 3;
     float *input_device = input_buffer_.gpu() + ibatch * input_numel;
     size_t size_image = slice_->slice_width_ * slice_->slice_height_ * 3;
@@ -348,13 +437,8 @@ void YoloModelImpl::preprocess(int ibatch, affine::LetterBoxMatrix &affine, void
     float *affine_matrix_device = affine_matrix_.gpu();
     uint8_t *image_device = slice_->output_images_.gpu() + ibatch * size_image;
 
-    float *affine_matrix_host = affine_matrix_.cpu();
-
     // speed up
     cudaStream_t stream_ = (cudaStream_t)stream;
-    memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
-    checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i),
-                                cudaMemcpyHostToDevice, stream_));
 
     affine::warp_affine_bilinear_and_normalize_plane(image_device, slice_->slice_width_ * 3, slice_->slice_width_,
                                             slice_->slice_height_, input_device, network_input_width_,
@@ -363,6 +447,97 @@ void YoloModelImpl::preprocess(int ibatch, affine::LetterBoxMatrix &affine, void
 }
 
 
+std::shared_ptr<data::InstanceSegmentMap> YoloModelImpl::decode_segment(int imemory, float* pbox)
+{
+    int row_index = pbox[7];
+    int batch_index = pbox[8];
+    std::shared_ptr<data::InstanceSegmentMap> seg = nullptr;
+    float *bbox_output_device = bbox_predict_.gpu();
+
+    int start_x = slice_->slice_start_point_.cpu()[batch_index*2];
+    int start_y = slice_->slice_start_point_.cpu()[batch_index*2+1];
+
+    float *mask_weights = bbox_output_device +
+        (batch_index * bbox_head_dims_[1] + row_index) * bbox_head_dims_[2] +
+        num_classes_ + 4;
+    
+    float *mask_head_predict = segment_predict_.gpu();
+
+    // 变回640 x 640下的坐标
+    float left, top, right, bottom;
+    float *i2d = invert_affine_matrix_.cpu();
+    affine_project(i2d, pbox[0] - start_x, pbox[1] - start_y, &left, &top);
+    affine_project(i2d, pbox[2] - start_x, pbox[3] - start_y, &right, &bottom);
+
+    // 原始框大小
+    int oirginal_box_width  = pbox[2] - pbox[0];
+    int oirginal_box_height = pbox[3] - pbox[1];
+
+    float box_width = right - left;
+    float box_height = bottom - top;
+
+    // 变成160 x 160下的坐标
+    float scale_to_predict_x = segment_head_dims_[3] / (float)network_input_width_;
+    float scale_to_predict_y = segment_head_dims_[2] / (float)network_input_height_;
+
+    left = left * scale_to_predict_x + 0.5f;
+    top = top * scale_to_predict_y + 0.5f;
+    int mask_out_width = box_width * scale_to_predict_x + 0.5f;
+    int mask_out_height = box_height * scale_to_predict_y + 0.5f;
+
+    if (mask_out_width > 0 && mask_out_height > 0)
+    {
+        if (imemory >= (int)box_segment_cache_.size()) 
+        {
+            box_segment_cache_.push_back(std::make_shared<tensor::Memory<float>>());
+        }
+        int bytes_of_mask_out = mask_out_width * mask_out_height;
+        auto box_segment_output_memory = box_segment_cache_[imemory];
+
+        seg = std::make_shared<data::InstanceSegmentMap>(oirginal_box_width, oirginal_box_height);
+
+        float *mask_out_device = box_segment_output_memory->gpu(bytes_of_mask_out);
+        unsigned char *original_mask_out_host = seg->data;
+
+        decode_single_mask(left, top, mask_weights,
+            mask_head_predict + batch_index * segment_head_dims_[1] *
+                                    segment_head_dims_[2] *
+                                    segment_head_dims_[3],
+            segment_head_dims_[3], segment_head_dims_[2], mask_out_device,
+            mask_dim, mask_out_width, mask_out_height, stream_);
+        
+        tensor::Memory<unsigned char> original_mask_out;
+        original_mask_out.gpu(oirginal_box_width * oirginal_box_height);
+        unsigned char *original_mask_out_device = original_mask_out.gpu();
+
+        // 将160 x 160下的mask变换回原图下的mask 的变换矩阵
+        affine::LetterBoxMatrix mask_affine_matrix;
+        mask_affine_matrix.compute(std::make_tuple(mask_out_width, mask_out_height),
+                                std::make_tuple(oirginal_box_width, oirginal_box_height));
+        
+        float *mask_affine_matrix_device = mask_affine_matrix_.gpu();
+        float *mask_affine_matrix_host = mask_affine_matrix_.cpu();
+
+        memcpy(mask_affine_matrix_host, mask_affine_matrix.d2i, sizeof(mask_affine_matrix.d2i));
+        checkRuntime(cudaMemcpyAsync(mask_affine_matrix_device, mask_affine_matrix_host,
+                                    sizeof(mask_affine_matrix.d2i), cudaMemcpyHostToDevice,
+                                    stream_));
+        
+        // 单通道的变换矩阵
+        // 在这里做过插值后将mask的值由0-1 变为 0-255,并且将 < 0.5的丢弃,不然范围会很大。
+        // 先变为0-255再做插值会有锯齿
+        affine::warp_affine_bilinear_single_channel_mask_plane(
+            mask_out_device, mask_out_width, mask_out_width, mask_out_height,
+            original_mask_out_device, oirginal_box_width, oirginal_box_height,
+            mask_affine_matrix_device, 0, stream_);
+        checkRuntime(cudaMemcpyAsync(original_mask_out_host, original_mask_out_device,
+                                    original_mask_out.gpu_bytes(),
+                                    cudaMemcpyDeviceToHost, stream_));
+        return seg;
+    }
+    
+}
+
 bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold, float nms_threshold)
 {
     trt_ = TensorRT::load(engine_file);
@@ -377,12 +552,20 @@ bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, c
 
     auto input_dim = trt_->static_dims(0);
     bbox_head_dims_ = trt_->static_dims(1);
+
+    has_segment_ = this->model_type_ == ModelType::YOLO11SEG;
+
+    if (has_segment_) 
+    {
+        segment_head_dims_ = trt_->static_dims(2);
+    }
+
     network_input_width_ = input_dim[3];
     network_input_height_ = input_dim[2];
     isdynamic_model_ = trt_->has_dynamic_dim();
 
     normalize_ = affine::Norm::alpha_beta(1 / 255.0f, 0.0f, affine::ChannelType::SwapRB);
-    if (this->model_type_ == ModelType::YOLOV8 || this->model_type_ == ModelType::YOLO11)
+    if (this->model_type_ == ModelType::YOLO11)
     {
         num_classes_ = bbox_head_dims_[2] - 4;
     }
@@ -395,6 +578,10 @@ bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, c
         num_classes_ = bbox_head_dims_[2] - 4 - KEY_POINT_NUM * 3;
         // NUM_BOX_ELEMENT = 8 + KEY_POINT_NUM * 3;
     }
+    else if (this->yolo_type_ == ModelType::YOLO11SEG)
+	{
+	    num_classes_ = bbox_head_dims_[2] - 4 - segment_head_dims_[1];
+	}
     return true;
 }
 
@@ -447,22 +634,48 @@ Result YoloModelImpl::forwards(void *stream)
 
     affine::LetterBoxMatrix affine_matrix;
     cudaStream_t stream_ = (cudaStream_t)stream;
+    // 切割后的小图每张都一样,所以只用计算一次
+    cal_affine_matrix(affine_matrix);
     for (int i = 0; i < num_image; ++i)
-        preprocess(i, affine_matrix, stream);
+        preprocess(i, stream);
 
     float *bbox_output_device = bbox_predict_.gpu();
     #ifdef TRT10
-    if (!trt_->forward(std::unordered_map<std::string, const void *>{
+    std::unordered_map<std::string, const void *> bindings;
+    if (has_segment_)
+    {
+        float *segment_output_device = segment_predict_.gpu();
+        bindings = {
             { "images", input_buffer_.gpu() }, 
-            { "output0", bbox_predict_.gpu() }
-        }, stream_))
+            { "output0", bbox_output_device },
+            { "output1", segment_output_device }
+        };
+    }
+    else
+    {
+        bindings = {
+            { "images", input_buffer_.gpu() }, 
+            { "output0", bbox_output_device }
+        };
+        
+    } 
+    if (!trt_->forward(bindings, stream_))
     {
         printf("Failed to tensorRT forward.");
         return {};
     }
     #else
     std::vector<void *> bindings{input_buffer_.gpu(), bbox_output_device};
-    if (!trt_->forward(bindings, stream)) 
+    if (has_segment_)
+    {
+        float *segment_output_device = segment_predicr_.gpu();
+        bindings = { input_buffer_.gpu(), bbox_output_device, segment_predicr_.gpu()};
+    }
+    else
+    {
+        bindings = { input_buffer_.gpu(), bbox_output_device };
+    } 
+    if (!trt_->forward(bindings, stream_))
     {
         printf("Failed to tensorRT forward.");
         return {};
@@ -481,17 +694,21 @@ Result YoloModelImpl::forwards(void *stream)
         float *affine_matrix_device = affine_matrix_.gpu();
         float *image_based_bbox_output =
             bbox_output_device + ib * (bbox_head_dims_[1] * bbox_head_dims_[2]);
-        if (model_type_ == ModelType::YOLOV5)
+        if (model_type_ == ModelType::YOLOV5|| model_type_ == ModelType::YOLOV5SEG)
         {
-            decode_kernel_invoker_v5(image_based_bbox_output, bbox_head_dims_[1], num_classes_,
-                                bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
-                                affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, start_x, start_y, stream_);
+            decode_kernel_invoker_v5(
+                image_based_bbox_output, bbox_head_dims_[1], num_classes_,
+                bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
+                affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, 
+                start_x, start_y, ib, stream_);
         }
-        else if (model_type_ == ModelType::YOLOV8 || model_type_ == ModelType::YOLO11)
+        else if (model_type_ == ModelType::YOLO11 || model_type_ == ModelType::YOLO11SEG)
         {
-            decode_kernel_invoker_v8(image_based_bbox_output, bbox_head_dims_[1], num_classes_,
-                                bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
-                                affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, start_x, start_y, stream_);
+            decode_kernel_invoker_v8(
+                image_based_bbox_output, bbox_head_dims_[1], num_classes_,
+                bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
+                affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, 
+                start_x, start_y, ib, stream_);
         }
         else if (model_type_ == ModelType::YOLO11POSE)
         {
@@ -522,6 +739,7 @@ Result YoloModelImpl::forwards(void *stream)
     float *parray = output_boxarray_.cpu();
     int count = min(MAX_IMAGE_BOXES, *(box_count_.cpu()));
 
+    int imemory = 0;
     for (int i = 0; i < count; ++i) 
     {
         int box_element = (model_type_ == ModelType::YOLO11POSE) ? (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3) : NUM_BOX_ELEMENT;
@@ -540,6 +758,10 @@ Result YoloModelImpl::forwards(void *stream)
                     result_object_box.keypoints.emplace_back(pbox[8+i*3], pbox[8+i*3+1], pbox[8+i*3+2]);
                 }
             }
+            else if (model_type_ == ModelType::YOLO11SEG)
+            {
+                result_object_box.seg = decode_segment(imemory, pbox);
+            }
             
             result.emplace_back(result_object_box);
         }

+ 14 - 2
src/infer/trt/yolo/yolo.hpp

@@ -9,6 +9,8 @@
 #include "infer/slice/slice.hpp"
 #include "infer/trt/affine.hpp"
 
+#include "common/data.hpp"
+
 #ifdef TRT10
 #include "common/tensorrt.hpp"
 namespace TensorRT = TensorRT10;
@@ -17,6 +19,7 @@ namespace TensorRT = TensorRT10;
 namespace TensorRT = TensorRT8;
 #endif
 
+
 namespace yolo
 {
 
@@ -35,12 +38,18 @@ namespace yolo
         tensor::Memory<int> box_count_;
     
         tensor::Memory<float> affine_matrix_;
-        tensor::Memory<float>  input_buffer_, bbox_predict_, output_boxarray_;
+        tensor::Memory<float> invert_affine_matrix_;
+        tensor::Memory<float> mask_affine_matrix_;
+        tensor::Memory<float> input_buffer_, bbox_predict_, segment_predict_, output_boxarray_;
+
+        std::vector<std::shared_ptr<tensor::Memory<float>>> box_segment_cache_;
     
         int network_input_width_, network_input_height_;
         affine::Norm normalize_;
         std::vector<int> bbox_head_dims_;
+        std::vector<int> segment_head_dims_;
         bool isdynamic_model_ = false;
+        bool has_segment_ = false;
     
         float confidence_threshold_;
         float nms_threshold_;
@@ -55,8 +64,11 @@ namespace yolo
     
         void adjust_memory(int batch_size);
     
-        void preprocess(int ibatch, affine::LetterBoxMatrix &affine, void *stream = nullptr);
+        void preprocess(int ibatch, void *stream = nullptr);
+
+        void cal_affine_matrix(affine::LetterBoxMatrix &affine, void *stream = nullptr);
         
+        std::shared_ptr<data::InstanceSegmentMap> decode_segment(int imemory, float* pbox);
     
         bool load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold, float nms_threshold);