leon 1 月之前
父节点
当前提交
89212c77f9
共有 1 个文件被更改，包括 47 次插入和 0 次删除
  1. 47 0
      src/infer/trt/affine.cu

+ 47 - 0
src/infer/trt/affine.cu

@@ -121,6 +121,53 @@ static __global__ void warp_affine_bilinear_and_normalize_plane_kernel(
     *pdst_c2 = c2;
 }
 
+static __global__ void warp_affine_bilinear_single_channel_kernel(  // Affine-warps one float plane from src into dst using bilinear sampling.
+    float *src, int src_line_size, int src_width, int src_height, float *dst, int dst_width,  // src_line_size is the src row stride in float elements (it is added to a float*).
+    int dst_height, float const_value_st, float *warp_affine_matrix_2_3)  // const_value_st: fill value for samples outside src; the matrix is 6 floats read inside the kernel, so it must be device-accessible.
+{  // Each thread produces one destination pixel; (dx, dy) come from the x and y launch dimensions.
+    int dx = blockDim.x * blockIdx.x + threadIdx.x;
+    int dy = blockDim.y * blockIdx.y + threadIdx.y;
+    if (dx >= dst_width || dy >= dst_height) return;  // threads past the dst edge do nothing
+    // Row-major 2x3 matrix: [m_x1 m_y1 m_z1; m_x2 m_y2 m_z2].
+    float m_x1 = warp_affine_matrix_2_3[0];
+    float m_y1 = warp_affine_matrix_2_3[1];
+    float m_z1 = warp_affine_matrix_2_3[2];
+    float m_x2 = warp_affine_matrix_2_3[3];
+    float m_y2 = warp_affine_matrix_2_3[4];
+    float m_z2 = warp_affine_matrix_2_3[5];
+    // The matrix maps the destination pixel (dx, dy) to a source coordinate (dst -> src direction).
+    float src_x = m_x1 * dx + m_y1 * dy + m_z1;
+    float src_y = m_x2 * dx + m_y2 * dy + m_z2;
+    float c0;  // value written to dst for this pixel
+    // Source coordinate outside [0, src_width) x [0, src_height): emit the fill value unchanged.
+    if (src_x < 0 || src_x >= src_width || src_y < 0 || src_y >= src_height) 
+    {
+        c0 = const_value_st;
+    } 
+    else  // coordinate lies inside the source: blend the four surrounding samples
+    {
+        int y_low = __float2int_rz(src_y);  // src_y >= 0 in this branch, so rounding toward zero equals floor
+        int x_low = __float2int_rz(src_x);  // same reasoning for src_x
+        int y_high = y_low + 1;
+        int x_high = x_low + 1;
+        // Bilinear weights for (y_low,x_low), (y_low,x_high), (y_high,x_low), (y_high,x_high); they sum to 1.
+        float w1 = (1 - (src_y - y_low)) * (1 - (src_x - x_low));
+        float w2 = (1 - (src_y - y_low)) * (src_x - x_low);
+        float w3 = (src_y - y_low) * (1 - (src_x - x_low));
+        float w4 = (src_y - y_low) * (src_x - x_low);
+        // A neighbour outside src (x_high / y_high can equal src_width / src_height) points at the fill value instead of src.
+        float *v1 = (y_low  >= 0 && y_low  < src_height && x_low  >= 0 && x_low  < src_width) ? src + y_low  * src_line_size + x_low  : &const_value_st;
+        float *v2 = (y_low  >= 0 && y_low  < src_height && x_high >= 0 && x_high < src_width) ? src + y_low  * src_line_size + x_high : &const_value_st;
+        float *v3 = (y_high >= 0 && y_high < src_height && x_low  >= 0 && x_low  < src_width) ? src + y_high * src_line_size + x_low  : &const_value_st;
+        float *v4 = (y_high >= 0 && y_high < src_height && x_high >= 0 && x_high < src_width) ? src + y_high * src_line_size + x_high : &const_value_st;
+
+        c0 = w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0];  // weighted sum of the four taps
+    }
+    // dst is indexed with dst_width as its row stride (no separate pitch parameter).
+    dst[dy * dst_width + dx] = c0;
+}
+
+
 void warp_affine_bilinear_and_normalize_plane(uint8_t *src, int src_line_size, int src_width,
                                                     int src_height, float *dst, int dst_width,
                                                     int dst_height, float *matrix_2_3,