// slice.cu (9.7 KB)

  #include "infer/slice/slice.hpp"
  #include "common/check.hpp"
  #include <algorithm>
  #include <cmath>
  #include <string>
  #include <tuple>
  // Copies one rectangular slice of a packed 3-byte-per-pixel (uchar3) image
  // into a contiguous per-slice output buffer.
  //
  // Launch layout (as configured by slice_plane):
  //   blockIdx.z                              -> slice index
  //   blockIdx.{x,y} * blockDim + threadIdx   -> (x, y) pixel inside the slice
  // No shared memory is used.
  //
  // slice_start_point holds one (start_x, start_y) int pair per slice.
  // Output slice k occupies outs[k * slice_width * slice_height ...] and is
  // stored row-major with row length slice_width. If a pixel's source
  // coordinate falls outside the image (negative or past the edge), it is not
  // written, so it keeps whatever the caller pre-filled into `outs`.
  // slice_num_h / slice_num_v are accepted for interface compatibility but are
  // not used by the kernel.
  static __global__ void slice_kernel(
      const uchar3* __restrict__ image,
      uchar3* __restrict__ outs,
      const int width,
      const int height,
      const int slice_width,
      const int slice_height,
      const int slice_num_h,
      const int slice_num_v,
      const int* __restrict__ slice_start_point)
  {
      const int slice_idx = blockIdx.z;

      // Position of the current pixel relative to the slice origin.
      const int x = blockIdx.x * blockDim.x + threadIdx.x;
      const int y = blockIdx.y * blockDim.y + threadIdx.y;
      if (x >= slice_width || y >= slice_height)
          return;

      const int start_x = slice_start_point[slice_idx * 2];
      const int start_y = slice_start_point[slice_idx * 2 + 1];

      // Corresponding coordinate in the full source image.
      const int dx = start_x + x;
      const int dy = start_y + y;
      // Reject negative coordinates too: a start point can be negative when a
      // slice is larger than the image, and image[] must not be read there.
      if (dx < 0 || dy < 0 || dx >= width || dy >= height)
          return;

      // 64-bit indices so many/large slices cannot overflow int arithmetic.
      const size_t src_index = static_cast<size_t>(dy) * width + dx;
      const size_t slice_pixels = static_cast<size_t>(slice_width) * slice_height;
      const size_t dst_index = static_cast<size_t>(slice_idx) * slice_pixels
                             + static_cast<size_t>(y) * slice_width + x;

      // Read the source pixel and write it into the slice.
      outs[dst_index] = image[src_index];
  }
  // Host-side launcher for slice_kernel.
  //
  // image, outs and slice_start_point are handed straight to the kernel, so
  // they must be device-accessible pointers. The kernel treats `image` as
  // width*height packed 3-byte pixels. `outs` must hold
  // slice_num_h * slice_num_v slices of slice_width*slice_height such pixels.
  // slice_start_point must hold one (x, y) pair per slice.
  // `stream` is a cudaStream_t passed as void* (nullptr selects the default
  // stream). The launch is asynchronous: only launch errors are checked here,
  // and execution errors surface at the next synchronizing call.
  static void slice_plane(const uint8_t* image,
      uint8_t* outs,
      int* slice_start_point,
      const int width,
      const int height,
      const int slice_width,
      const int slice_height,
      const int slice_num_h,
      const int slice_num_v,
      void* stream = nullptr)
  {
      const int slice_total = slice_num_h * slice_num_v;
      // A grid with a zero dimension is an invalid launch configuration;
      // there is simply nothing to copy.
      if (slice_total <= 0 || slice_width <= 0 || slice_height <= 0)
          return;

      cudaStream_t stream_ = static_cast<cudaStream_t>(stream);
      const dim3 block(32, 32);
      // x/y tile the slice (ceil-div), z enumerates the slices.
      const dim3 grid(
          (slice_width + block.x - 1) / block.x,
          (slice_height + block.y - 1) / block.y,
          slice_total);

      slice_kernel<<<grid, block, 0, stream_>>>(
          reinterpret_cast<const uchar3*>(image),
          reinterpret_cast<uchar3*>(outs),
          width, height,
          slice_width, slice_height,
          slice_num_h, slice_num_v,
          slice_start_point);

      // Kernel launches return no status; catch bad launch configurations
      // (e.g. grid.z above the device limit) right away.
      checkRuntime(cudaGetLastError());
  }
  namespace slice
  {
  // Number of slices of length `subDimension` needed to cover `dimension`
  // pixels along one axis when consecutive slices advance by
  // subDimension * (1 - overlapRatio). Always returns at least 1.
  int calculateNumCuts(int dimension, int subDimension, float overlapRatio) {
      const float step = subDimension * (1 - overlapRatio);
      // A non-positive step (overlap >= 100%) or a slice that already spans the
      // whole axis needs exactly one slice. Without this guard a slice larger
      // than the axis yielded a zero or negative count.
      if (step <= 0 || dimension <= subDimension)
          return 1;
      float cuts = static_cast<float>(dimension - subDimension) / step;
      // Floating-point error can leave `cuts` a hair above an integer; taking
      // the ceiling directly would then produce one extra slice, so snap
      // near-integers first.
      if (std::fabs(cuts - std::round(cuts)) < 0.0001f)
          cuts = std::round(cuts);
      return static_cast<int>(std::ceil(cuts)) + 1;
  }

  // Returns (smallest e with 2^e >= resolution) - 1, or -1 when resolution <= 1.
  // Takes a 64-bit value so width*height of large images does not overflow.
  static int calc_resolution_factor(long long resolution)
  {
      int expo = 0;
      // Integer shifts instead of pow(); capped so 1LL << expo never overflows.
      while (expo < 62 && (1LL << expo) < resolution)
          ++expo;
      return expo - 1;
  }

  // Classifies the image shape as "vertical", "horizontal" or "square".
  static std::string calc_aspect_ratio_orientation(int width, int height)
  {
      if (width < height)
          return "vertical";
      if (width > height)
          return "horizontal";
      return "square";
  }

  // Returns (slice_row, slice_col, overlap_height_ratio, overlap_width_ratio).
  // The orientation's long axis gets twice as many splits; "square" (and, to
  // avoid returning uninitialized values, any unrecognised orientation) gets
  // slide x slide. Both overlap ratios are `ratio`.
  static std::tuple<int, int, float, float> calc_ratio_and_slice(const std::string& orientation, int slide = 1, float ratio = 0.1f)
  {
      int slice_row = slide;
      int slice_col = slide;
      if (orientation == "vertical")
          slice_col = slide * 2;
      else if (orientation == "horizontal")
          slice_row = slide * 2;
      return std::make_tuple(slice_row, slice_col, ratio, ratio);
  }

  // Maps a resolution class to slice geometry.
  // Returns (slice_width, slice_height, overlap_width_ratio, overlap_height_ratio),
  // the order in which SliceImage::autoSlice unpacks it.
  // "medium"/"high"/"ultra-high" use 1/2/4 base splits with overlap 0.8/0.4/0.4.
  // Anything else ("low") uses one full-image slice with overlap ratio 1.
  static std::tuple<int, int, float, float> calc_slice_and_overlap_params(
      const std::string& resolution, int width, int height, const std::string& orientation)
  {
      int split_row = 1;
      int split_col = 1;
      float overlap_height_ratio = 1;
      float overlap_width_ratio = 1;
      if (resolution == "medium")
          std::tie(split_row, split_col, overlap_height_ratio, overlap_width_ratio) =
              calc_ratio_and_slice(orientation, 1, 0.8f);
      else if (resolution == "high")
          std::tie(split_row, split_col, overlap_height_ratio, overlap_width_ratio) =
              calc_ratio_and_slice(orientation, 2, 0.4f);
      else if (resolution == "ultra-high")
          std::tie(split_row, split_col, overlap_height_ratio, overlap_width_ratio) =
              calc_ratio_and_slice(orientation, 4, 0.4f);

      // Same pairing as the original: height is divided by split_col and
      // width by split_row.
      const int slice_height = height / split_col;
      const int slice_width = width / split_row;
      return std::make_tuple(slice_width, slice_height, overlap_width_ratio, overlap_height_ratio);
  }

  // Slice geometry for a named resolution class, given the image size.
  static std::tuple<int, int, float, float> get_resolution_selector(const std::string& resolution, int width, int height)
  {
      const std::string orientation = calc_aspect_ratio_orientation(width, height);
      return calc_slice_and_overlap_params(resolution, width, height, orientation);
  }

  // Chooses a resolution class from the pixel count:
  // factor <= 18 -> low, < 21 -> medium, < 24 -> high, otherwise ultra-high.
  static std::tuple<int, int, float, float> get_auto_slice_params(int width, int height)
  {
      // 64-bit product: width * height can exceed INT_MAX for very large images.
      const long long resolution = static_cast<long long>(height) * width;
      const int factor = calc_resolution_factor(resolution);
      if (factor <= 18)
          return get_resolution_selector("low", width, height);
      if (factor < 21)
          return get_resolution_selector("medium", width, height);
      if (factor < 24)
          return get_resolution_selector("high", width, height);
      return get_resolution_selector("ultra-high", width, height);
  }

  // Picks slice size and overlap from the image size, then calls slice().
  void SliceImage::autoSlice(
      const tensor::Image& image,
      void* stream)
  {
      int slice_width;
      int slice_height;
      float overlap_width_ratio;
      float overlap_height_ratio;
      std::tie(slice_width, slice_height, overlap_width_ratio, overlap_height_ratio) = get_auto_slice_params(image.width, image.height);
      slice(image, slice_width, slice_height, overlap_width_ratio, overlap_height_ratio, stream);
  }

  // Uploads image.bgrptr (width*height*3 bytes) to the device and cuts it into
  // slice_num_h_ x slice_num_v_ overlapping slices of slice_width x slice_height,
  // written contiguously into output_images_. Slice (i, j) has index
  // i * slice_num_v_ + j; its (x, y) origin is stored in slice_start_point_.
  // The slicing kernel is enqueued on `stream` without a final sync, so
  // output_images_ is only valid after the caller synchronizes that stream.
  void SliceImage::slice(
      const tensor::Image& image,
      const int slice_width,
      const int slice_height,
      const float overlap_width_ratio,
      const float overlap_height_ratio,
      void* stream)
  {
      slice_width_ = slice_width;
      slice_height_ = slice_height;
      cudaStream_t stream_ = (cudaStream_t)stream;
      const int width = image.width;
      const int height = image.height;

      // Degenerate geometry: report zero slices instead of computing negative
      // counts or launching an empty grid.
      if (width <= 0 || height <= 0 || slice_width <= 0 || slice_height <= 0)
      {
          slice_num_h_ = 0;
          slice_num_v_ = 0;
          return;
      }

      slice_num_h_ = calculateNumCuts(width, slice_width, overlap_width_ratio);
      slice_num_v_ = calculateNumCuts(height, slice_height, overlap_height_ratio);
      /*
      printf("------------------------------------------------------\n"
          "CUDA SAHI CROP IMAGE ✂️\n"
          "------------------------------------------------------\n"
          "%-30s: %-10d\n"
          "%-30s: %-10d\n"
          "%-30s: %-10.2f\n"
          "%-30s: %-10.2f\n"
          "%-30s: %-10d\n"
          "%-30s: %-10d\n"
          "------------------------------------------------------\n",
          "Slice width", slice_width_,
          "Slice height", slice_height_,
          "Overlap width ratio", overlap_width_ratio,
          "Overlap height ratio", overlap_height_ratio,
          "Number of horizontal cuts", slice_num_h_,
          "Number of vertical cuts", slice_num_v_);
      */
      const int slice_num = slice_num_h_ * slice_num_v_;

      // Integer stride between consecutive slice origins (slice size minus overlap).
      const int overlap_width_pixel = static_cast<int>(slice_width * overlap_width_ratio);
      const int overlap_height_pixel = static_cast<int>(slice_height * overlap_height_ratio);
      const int step_x = slice_width - overlap_width_pixel;
      const int step_y = slice_height - overlap_height_pixel;

      // Byte sizes of the 3-byte-per-pixel input and of one output slice,
      // computed in size_t so large images cannot overflow int.
      const size_t size_image = 3 * static_cast<size_t>(width) * height;
      const size_t output_img_size = 3 * static_cast<size_t>(slice_width) * slice_height;
      input_image_.gpu(size_image);
      output_images_.gpu(slice_num * output_img_size);
      // Pre-fill every slice with byte 114. The kernel skips pixels that lie
      // outside the source image, so those keep this value.
      checkRuntime(cudaMemsetAsync(output_images_.gpu(), 114, output_images_.gpu_bytes(), stream_));
      checkRuntime(cudaMemcpyAsync(input_image_.gpu(), image.bgrptr, size_image, cudaMemcpyHostToDevice, stream_));
      uint8_t* input_device = input_image_.gpu();
      uint8_t* output_device = output_images_.gpu();

      slice_start_point_.cpu(slice_num * 2);
      slice_start_point_.gpu(slice_num * 2);
      int* slice_start_point_ptr = slice_start_point_.cpu();
      for (int i = 0; i < slice_num_h_; i++)
      {
          // Pull the last slice back so it ends at the image border. Clamping
          // to 0 last keeps the origin non-negative when a slice is wider than
          // the image (the kernel then leaves the uncovered part padded).
          const int x = std::max(0, std::min(width - slice_width, i * step_x));
          for (int j = 0; j < slice_num_v_; j++)
          {
              const int y = std::max(0, std::min(height - slice_height, j * step_y));
              const int index = (i * slice_num_v_ + j) * 2;
              slice_start_point_ptr[index] = x;
              slice_start_point_ptr[index + 1] = y;
          }
      }
      checkRuntime(cudaMemcpyAsync(slice_start_point_.gpu(), slice_start_point_.cpu(), slice_num * 2 * sizeof(int), cudaMemcpyHostToDevice, stream_));
      // Wait for all copies queued above before launching. NOTE(review): this
      // presumably protects the reused host-side start-point buffer from being
      // overwritten by a later call while the async copy is in flight; confirm
      // before removing.
      checkRuntime(cudaStreamSynchronize(stream_));
      slice_plane(
          input_device, output_device, slice_start_point_.gpu(),
          width, height,
          slice_width, slice_height,
          slice_num_h_, slice_num_v_,
          stream);
      // Debug helper (disabled): dump each slice to a PNG for inspection.
      // checkRuntime(cudaStreamSynchronize(stream_));
      // for (int i = 0; i < slice_num_h_; i++)
      // {
      //     for (int j = 0; j < slice_num_v_; j++)
      //     {
      //         int index = i * slice_num_v_ + j;
      //         slice_position_[index*2] = slice_start_point_ptr[index*2];
      //         slice_position_[index*2+1] = slice_start_point_ptr[index*2+1];
      //         // cv::Mat image = cv::Mat::zeros(slice_height, slice_width, CV_8UC3);
      //         // uint8_t* output_img_data = image.ptr<uint8_t>();
      //         // cudaMemcpyAsync(output_img_data, output_device+index*output_img_size, output_img_size*sizeof(uint8_t), cudaMemcpyDeviceToHost, stream_);
      //         // checkRuntime(cudaStreamSynchronize(stream_));
      //         // cv::imwrite(std::to_string(index) + ".png", image);
      //     }
      // }
  }
  }