yolo.cu

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <unordered_map>
#include <vector>
#include <mutex>
#include <memory>

#include "infer/trt/yolo/yolo.hpp"
#include "infer/slice/slice.hpp"
#include "infer/trt/affine.hpp"
#include "common/check.hpp"

#define GPU_BLOCK_THREADS 512

namespace yolo
{
static const int NUM_BOX_ELEMENT = 9;
// left, top, right, bottom, confidence, class, keepflag, row_index(output), batch_index
// 9 elements per box: top-left corner, bottom-right corner, confidence, class, keep flag,
// row index and batch index. The row index is used to locate the mask weights, and the
// batch index is used to locate the image of the current batch.
static const int MAX_IMAGE_BOXES = 1024 * 4;
static const int KEY_POINT_NUM = 17;  // number of pose keypoints

static dim3 grid_dims(int numJobs)
{
    int numBlockThreads = numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
    return dim3((numJobs + numBlockThreads - 1) / numBlockThreads);
}

static dim3 block_dims(int numJobs)
{
    return numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
}
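// grid_dims/block_dims build a 1-D launch configuration covering numJobs threads: the block
// uses at most GPU_BLOCK_THREADS threads and the grid is the ceiling division of numJobs by
// that block size.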
static __host__ __device__ void affine_project(float *matrix, float x, float y, float *ox, float *oy)
{
    *ox = matrix[0] * x + matrix[1] * y + matrix[2];
    *oy = matrix[3] * x + matrix[4] * y + matrix[5];
}
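// affine_project treats its matrix argument as the two rows of a 2x3 affine transform stored
// row-major:
//   [m00 m01 m02]
//   [m10 m11 m12]
// so a point (x, y) maps to (m00*x + m01*y + m02, m10*x + m11*y + m12).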
static __global__ void decode_kernel_v5(
    float *predict, int num_bboxes, int num_classes,
    int output_cdim, float confidence_threshold,
    float *invert_affine_matrix, float *parray, int *box_count,
    int max_image_boxes, int start_x, int start_y, int batch_index)
{
    int position = blockDim.x * blockIdx.x + threadIdx.x;
    if (position >= num_bboxes) return;

    float *pitem = predict + output_cdim * position;
    float objectness = pitem[4];
    if (objectness < confidence_threshold) return;

    float *class_confidence = pitem + 5;
    float confidence = *class_confidence++;
    int label = 0;
    for (int i = 1; i < num_classes; ++i, ++class_confidence)
    {
        if (*class_confidence > confidence)
        {
            confidence = *class_confidence;
            label = i;
        }
    }

    confidence *= objectness;
    if (confidence < confidence_threshold) return;

    int index = atomicAdd(box_count, 1);
    if (index >= max_image_boxes) return;

    float cx = *pitem++;
    float cy = *pitem++;
    float width = *pitem++;
    float height = *pitem++;
    float left = cx - width * 0.5f;
    float top = cy - height * 0.5f;
    float right = cx + width * 0.5f;
    float bottom = cy + height * 0.5f;
    affine_project(invert_affine_matrix, left, top, &left, &top);
    affine_project(invert_affine_matrix, right, bottom, &right, &bottom);

    float *pout_item = parray + index * NUM_BOX_ELEMENT;
    *pout_item++ = left + start_x;
    *pout_item++ = top + start_y;
    *pout_item++ = right + start_x;
    *pout_item++ = bottom + start_y;
    *pout_item++ = confidence;
    *pout_item++ = label;
    *pout_item++ = 1;  // 1 = keep, 0 = ignore
    *pout_item++ = position;
    *pout_item++ = batch_index;
    // batch_index is stored so that the later mask-weights lookup can find the image of the
    // current batch.
}
static __global__ void decode_kernel_v8(
    float *predict, int num_bboxes, int num_classes,
    int output_cdim, float confidence_threshold,
    float *invert_affine_matrix, float *parray, int *box_count,
    int max_image_boxes, int start_x, int start_y, int batch_index)
{
    int position = blockDim.x * blockIdx.x + threadIdx.x;
    if (position >= num_bboxes) return;

    float *pitem = predict + output_cdim * position;
    float *class_confidence = pitem + 4;
    float confidence = *class_confidence++;
    int label = 0;
    for (int i = 1; i < num_classes; ++i, ++class_confidence)
    {
        if (*class_confidence > confidence)
        {
            confidence = *class_confidence;
            label = i;
        }
    }

    if (confidence < confidence_threshold) return;

    int index = atomicAdd(box_count, 1);
    if (index >= max_image_boxes) return;

    float cx = *pitem++;
    float cy = *pitem++;
    float width = *pitem++;
    float height = *pitem++;
    float left = cx - width * 0.5f;
    float top = cy - height * 0.5f;
    float right = cx + width * 0.5f;
    float bottom = cy + height * 0.5f;
    affine_project(invert_affine_matrix, left, top, &left, &top);
    affine_project(invert_affine_matrix, right, bottom, &right, &bottom);

    float *pout_item = parray + index * NUM_BOX_ELEMENT;
    *pout_item++ = left + start_x;
    *pout_item++ = top + start_y;
    *pout_item++ = right + start_x;
    *pout_item++ = bottom + start_y;
    *pout_item++ = confidence;
    *pout_item++ = label;
    *pout_item++ = 1;  // 1 = keep, 0 = ignore
    *pout_item++ = position;
    *pout_item++ = batch_index;
    // batch_index is stored so that the later mask-weights lookup can find the image of the
    // current batch.
}
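// Note on the head layouts assumed above: the YOLOv5 head stores
// [cx, cy, w, h, objectness, class scores...] per row, so the final confidence is
// objectness * best class score, while the YOLOv8/YOLO11-style head drops the objectness
// term and stores [cx, cy, w, h, class scores...].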
static __global__ void decode_kernel_11pose(
    float *predict, int num_bboxes, int num_classes,
    int output_cdim, float confidence_threshold,
    float *invert_affine_matrix, float *parray,
    int *box_count, int max_image_boxes, int start_x, int start_y, int batch_index)
{
    int position = blockDim.x * blockIdx.x + threadIdx.x;
    if (position >= num_bboxes) return;

    float *pitem = predict + output_cdim * position;
    float *class_confidence = pitem + 4;
    float *key_points = pitem + 4 + num_classes;
    float confidence = *class_confidence++;
    int label = 0;
    for (int i = 1; i < num_classes; ++i, ++class_confidence)
    {
        if (*class_confidence > confidence)
        {
            confidence = *class_confidence;
            label = i;
        }
    }

    if (confidence < confidence_threshold) return;

    int index = atomicAdd(box_count, 1);
    if (index >= max_image_boxes) return;

    float cx = *pitem++;
    float cy = *pitem++;
    float width = *pitem++;
    float height = *pitem++;
    float left = cx - width * 0.5f;
    float top = cy - height * 0.5f;
    float right = cx + width * 0.5f;
    float bottom = cy + height * 0.5f;
    affine_project(invert_affine_matrix, left, top, &left, &top);
    affine_project(invert_affine_matrix, right, bottom, &right, &bottom);

    float *pout_item = parray + index * (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3);
    *pout_item++ = left + start_x;
    *pout_item++ = top + start_y;
    *pout_item++ = right + start_x;
    *pout_item++ = bottom + start_y;
    *pout_item++ = confidence;
    *pout_item++ = label;
    *pout_item++ = 1;  // 1 = keep, 0 = ignore
    *pout_item++ = position;
    *pout_item++ = batch_index;
    for (int i = 0; i < KEY_POINT_NUM; i++)
    {
        float x = *key_points++;
        float y = *key_points++;
        affine_project(invert_affine_matrix, x, y, &x, &y);
        float score = *key_points++;
        *pout_item++ = x + start_x;
        *pout_item++ = y + start_y;
        *pout_item++ = score;
    }
}
static __device__ float box_iou(float aleft, float atop, float aright, float abottom,
                                float bleft, float btop, float bright, float bbottom)
{
    float cleft = max(aleft, bleft);
    float ctop = max(atop, btop);
    float cright = min(aright, bright);
    float cbottom = min(abottom, bbottom);

    float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f);
    if (c_area == 0.0f) return 0.0f;

    float a_area = max(0.0f, aright - aleft) * max(0.0f, abottom - atop);
    float b_area = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop);
    return c_area / (a_area + b_area - c_area);
}
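// Fast NMS strategy used by the two kernels below: one thread per candidate box scans every
// other box of the same class and clears its own keep flag (pcurrent[6]) if a box with higher
// confidence (ties broken by the lower index) overlaps it with IoU above the threshold.
// No sorting is required, at the cost of O(n^2) comparisons.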
static __global__ void fast_nms_kernel(float *bboxes, int *box_count, int max_image_boxes, float threshold)
{
    int position = (blockDim.x * blockIdx.x + threadIdx.x);
    int count = min((int)*box_count, MAX_IMAGE_BOXES);
    if (position >= count) return;

    // left, top, right, bottom, confidence, class, keepflag
    float *pcurrent = bboxes + position * NUM_BOX_ELEMENT;
    for (int i = 0; i < count; ++i)
    {
        float *pitem = bboxes + i * NUM_BOX_ELEMENT;
        if (i == position || pcurrent[5] != pitem[5]) continue;
        if (pitem[4] >= pcurrent[4])
        {
            if (pitem[4] == pcurrent[4] && i < position) continue;
            float iou = box_iou(pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3],
                                pitem[0], pitem[1], pitem[2], pitem[3]);
            if (iou > threshold)
            {
                pcurrent[6] = 0;  // 1 = keep, 0 = ignore
                return;
            }
        }
    }
}
static __global__ void fast_nms_pose_kernel(float *bboxes, int *box_count, int max_image_boxes, float threshold)
{
    int position = (blockDim.x * blockIdx.x + threadIdx.x);
    int count = min((int)*box_count, MAX_IMAGE_BOXES);
    if (position >= count) return;

    // left, top, right, bottom, confidence, class, keepflag
    float *pcurrent = bboxes + position * (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3);
    for (int i = 0; i < count; ++i)
    {
        float *pitem = bboxes + i * (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3);
        if (i == position || pcurrent[5] != pitem[5]) continue;
        if (pitem[4] >= pcurrent[4])
        {
            if (pitem[4] == pcurrent[4] && i < position) continue;
            float iou = box_iou(pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3],
                                pitem[0], pitem[1], pitem[2], pitem[3]);
            if (iou > threshold)
            {
                pcurrent[6] = 0;  // 1 = keep, 0 = ignore
                return;
            }
        }
    }
}
static __global__ void decode_single_mask_kernel(int left, int top, float *mask_weights,
                                                 float *mask_predict, int mask_width,
                                                 int mask_height, float *mask_out,
                                                 int mask_dim, int out_width, int out_height)
{
    // mask_predict to mask_out
    // mask_weights @ mask_predict
    int dx = blockDim.x * blockIdx.x + threadIdx.x;
    int dy = blockDim.y * blockIdx.y + threadIdx.y;
    if (dx >= out_width || dy >= out_height) return;

    int sx = left + dx;
    int sy = top + dy;
    if (sx < 0 || sx >= mask_width || sy < 0 || sy >= mask_height)
    {
        mask_out[dy * out_width + dx] = 0;
        return;
    }

    float cumprod = 0;
    for (int ic = 0; ic < mask_dim; ++ic)
    {
        float cval = mask_predict[(ic * mask_height + sy) * mask_width + sx];
        float wval = mask_weights[ic];
        cumprod += cval * wval;
    }
    float alpha = 1.0f / (1.0f + exp(-cumprod));
    // Keep the float value here; the mask is only scaled by 255 after it has been resampled
    // back to the original image.
    mask_out[dy * out_width + dx] = alpha;
}
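// Each output pixel above is sigmoid(sum_c mask_weights[c] * mask_predict[c, sy, sx]): the
// per-box mask coefficients applied to the prototype masks, which is the usual YOLO
// instance-segmentation decoding.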
static void decode_single_mask(float left, float top, float *mask_weights, float *mask_predict,
                               int mask_width, int mask_height, float *mask_out,
                               int mask_dim, int out_width, int out_height, cudaStream_t stream)
{
    // mask_weights is a GPU pointer to mask_dim (32) elements
    dim3 grid((out_width + 31) / 32, (out_height + 31) / 32);
    dim3 block(32, 32);
    checkKernel(decode_single_mask_kernel<<<grid, block, 0, stream>>>(
        left, top, mask_weights, mask_predict, mask_width, mask_height, mask_out, mask_dim,
        out_width, out_height));
}
static void decode_kernel_invoker_v8(
    float *predict, int num_bboxes, int num_classes, int output_cdim,
    float confidence_threshold, float nms_threshold,
    float *invert_affine_matrix, float *parray, int *box_count, int max_image_boxes,
    int start_x, int start_y, int batch_index, cudaStream_t stream)
{
    auto grid = grid_dims(num_bboxes);
    auto block = block_dims(num_bboxes);
    checkKernel(decode_kernel_v8<<<grid, block, 0, stream>>>(
        predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
        parray, box_count, max_image_boxes, start_x, start_y, batch_index));
}

static void decode_kernel_invoker_v5(
    float *predict, int num_bboxes, int num_classes, int output_cdim,
    float confidence_threshold, float nms_threshold,
    float *invert_affine_matrix, float *parray, int *box_count, int max_image_boxes,
    int start_x, int start_y, int batch_index, cudaStream_t stream)
{
    auto grid = grid_dims(num_bboxes);
    auto block = block_dims(num_bboxes);
    checkKernel(decode_kernel_v5<<<grid, block, 0, stream>>>(
        predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
        parray, box_count, max_image_boxes, start_x, start_y, batch_index));
}

static void decode_kernel_invoker_v11pose(
    float *predict, int num_bboxes, int num_classes, int output_cdim,
    float confidence_threshold, float nms_threshold,
    float *invert_affine_matrix, float *parray, int *box_count, int max_image_boxes,
    int start_x, int start_y, int batch_index, cudaStream_t stream)
{
    auto grid = grid_dims(num_bboxes);
    auto block = block_dims(num_bboxes);
    checkKernel(decode_kernel_11pose<<<grid, block, 0, stream>>>(
        predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
        parray, box_count, max_image_boxes, start_x, start_y, batch_index));
}

static void fast_nms_kernel_invoker(float *parray, int *box_count, int max_image_boxes,
                                    float nms_threshold, cudaStream_t stream)
{
    auto grid = grid_dims(max_image_boxes);
    auto block = block_dims(max_image_boxes);
    checkKernel(fast_nms_kernel<<<grid, block, 0, stream>>>(parray, box_count, max_image_boxes, nms_threshold));
}

static void fast_nms_pose_kernel_invoker(float *parray, int *box_count, int max_image_boxes,
                                         float nms_threshold, cudaStream_t stream)
{
    auto grid = grid_dims(max_image_boxes);
    auto block = block_dims(max_image_boxes);
    checkKernel(fast_nms_pose_kernel<<<grid, block, 0, stream>>>(parray, box_count, max_image_boxes, nms_threshold));
}
void YoloModelImpl::adjust_memory(int batch_size)
{
    // the inference batch_size
    size_t input_numel = network_input_width_ * network_input_height_ * 3;
    input_buffer_.gpu(batch_size * input_numel);
    bbox_predict_.gpu(batch_size * bbox_head_dims_[1] * bbox_head_dims_[2]);

    // Pose boxes carry KEY_POINT_NUM * 3 extra values per record, so size the output array
    // accordingly to match the stride used by decode_kernel_11pose.
    int box_element = (model_type_ == ModelType::YOLO11POSE)
                          ? (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3)
                          : NUM_BOX_ELEMENT;
    output_boxarray_.gpu(MAX_IMAGE_BOXES * box_element);
    output_boxarray_.cpu(MAX_IMAGE_BOXES * box_element);

    if (has_segment_)
    {
        segment_predict_.gpu(batch_size * segment_head_dims_[1] * segment_head_dims_[2] *
                             segment_head_dims_[3]);
    }

    affine_matrix_.gpu(6);
    affine_matrix_.cpu(6);
    invert_affine_matrix_.gpu(6);
    invert_affine_matrix_.cpu(6);
    mask_affine_matrix_.gpu(6);
    mask_affine_matrix_.cpu(6);
    box_count_.gpu(1);
    box_count_.cpu(1);
}
void YoloModelImpl::cal_affine_matrix(affine::LetterBoxMatrix &affine, void *stream)
{
    affine.compute(std::make_tuple(slice_->slice_width_, slice_->slice_height_),
                   std::make_tuple(network_input_width_, network_input_height_));
    float *affine_matrix_device = affine_matrix_.gpu();
    float *affine_matrix_host = affine_matrix_.cpu();
    float *invert_affine_matrix_device = invert_affine_matrix_.gpu();
    float *invert_affine_matrix_host = invert_affine_matrix_.cpu();
    cudaStream_t stream_ = (cudaStream_t)stream;

    memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
    checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i),
                                 cudaMemcpyHostToDevice, stream_));
    memcpy(invert_affine_matrix_host, affine.i2d, sizeof(affine.i2d));
    checkRuntime(cudaMemcpyAsync(invert_affine_matrix_device, invert_affine_matrix_host,
                                 sizeof(affine.i2d), cudaMemcpyHostToDevice, stream_));
}
void YoloModelImpl::preprocess(int ibatch, void *stream)
{
    size_t input_numel = network_input_width_ * network_input_height_ * 3;
    float *input_device = input_buffer_.gpu() + ibatch * input_numel;
    size_t size_image = slice_->slice_width_ * slice_->slice_height_ * 3;
    float *affine_matrix_device = affine_matrix_.gpu();
    uint8_t *image_device = slice_->output_images_.gpu() + ibatch * size_image;

    // speed up
    cudaStream_t stream_ = (cudaStream_t)stream;
    affine::warp_affine_bilinear_and_normalize_plane(image_device, slice_->slice_width_ * 3,
                                                     slice_->slice_width_, slice_->slice_height_,
                                                     input_device, network_input_width_,
                                                     network_input_height_, affine_matrix_device,
                                                     114, normalize_, stream_);
}
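// decode_segment: for a single surviving YOLO11SEG box, 1) use the row_index/batch_index stored
// in the box record to locate its mask coefficients in output0, 2) project the box back into
// network-input coordinates and then into the prototype-mask resolution, 3) run
// decode_single_mask to produce the float mask crop, 4) warp that crop to the original box size
// and copy it into the host InstanceSegmentMap.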
std::shared_ptr<data::InstanceSegmentMap> YoloModelImpl::decode_segment(int imemory, float *pbox, void *stream)
{
    int row_index = pbox[7];
    int batch_index = pbox[8];
    std::shared_ptr<data::InstanceSegmentMap> seg = nullptr;
    float *bbox_output_device = bbox_predict_.gpu();
    int start_x = slice_->slice_start_point_.cpu()[batch_index * 2];
    int start_y = slice_->slice_start_point_.cpu()[batch_index * 2 + 1];
    int mask_dim = segment_head_dims_[1];
    float *mask_weights = bbox_output_device +
                          (batch_index * bbox_head_dims_[1] + row_index) * bbox_head_dims_[2] +
                          num_classes_ + 4;
    float *mask_head_predict = segment_predict_.gpu();

    // Map the box back into network-input coordinates (e.g. 640 x 640).
    float left, top, right, bottom;
    float *i2d = invert_affine_matrix_.cpu();
    affine_project(i2d, pbox[0] - start_x, pbox[1] - start_y, &left, &top);
    affine_project(i2d, pbox[2] - start_x, pbox[3] - start_y, &right, &bottom);

    // Original box size.
    int original_box_width = pbox[2] - pbox[0];
    int original_box_height = pbox[3] - pbox[1];
    float box_width = right - left;
    float box_height = bottom - top;

    // Convert to coordinates in the prototype-mask resolution (e.g. 160 x 160).
    float scale_to_predict_x = segment_head_dims_[3] / (float)network_input_width_;
    float scale_to_predict_y = segment_head_dims_[2] / (float)network_input_height_;
    left = left * scale_to_predict_x + 0.5f;
    top = top * scale_to_predict_y + 0.5f;
    int mask_out_width = box_width * scale_to_predict_x + 0.5f;
    int mask_out_height = box_height * scale_to_predict_y + 0.5f;

    cudaStream_t stream_ = (cudaStream_t)stream;
    if (mask_out_width > 0 && mask_out_height > 0)
    {
        if (imemory >= (int)box_segment_cache_.size())
        {
            box_segment_cache_.push_back(std::make_shared<tensor::Memory<float>>());
        }
        int bytes_of_mask_out = mask_out_width * mask_out_height;
        auto box_segment_output_memory = box_segment_cache_[imemory];
        seg = std::make_shared<data::InstanceSegmentMap>(original_box_width, original_box_height);
        float *mask_out_device = box_segment_output_memory->gpu(bytes_of_mask_out);
        unsigned char *original_mask_out_host = seg->data;
        decode_single_mask(left, top, mask_weights,
                           mask_head_predict + batch_index * segment_head_dims_[1] *
                                                   segment_head_dims_[2] * segment_head_dims_[3],
                           segment_head_dims_[3], segment_head_dims_[2], mask_out_device,
                           mask_dim, mask_out_width, mask_out_height, stream_);

        tensor::Memory<unsigned char> original_mask_out;
        original_mask_out.gpu(original_box_width * original_box_height);
        unsigned char *original_mask_out_device = original_mask_out.gpu();

        // Affine matrix that maps the prototype-resolution mask back to a mask in the
        // original image.
        affine::LetterBoxMatrix mask_affine_matrix;
        mask_affine_matrix.compute(std::make_tuple(mask_out_width, mask_out_height),
                                   std::make_tuple(original_box_width, original_box_height));
        float *mask_affine_matrix_device = mask_affine_matrix_.gpu();
        float *mask_affine_matrix_host = mask_affine_matrix_.cpu();
        memcpy(mask_affine_matrix_host, mask_affine_matrix.d2i, sizeof(mask_affine_matrix.d2i));
        checkRuntime(cudaMemcpyAsync(mask_affine_matrix_device, mask_affine_matrix_host,
                                     sizeof(mask_affine_matrix.d2i), cudaMemcpyHostToDevice,
                                     stream_));

        // Single-channel warp. After interpolation the mask values are rescaled from 0-1 to
        // 0-255 and values below 0.5 are dropped; otherwise the value range would blow up.
        // Scaling to 0-255 before interpolating would produce jagged (aliased) edges.
        affine::warp_affine_bilinear_single_channel_mask_plane(
            mask_out_device, mask_out_width, mask_out_width, mask_out_height,
            original_mask_out_device, original_box_width, original_box_height,
            mask_affine_matrix_device, 0, stream_);
        checkRuntime(cudaMemcpyAsync(original_mask_out_host, original_mask_out_device,
                                     original_mask_out.gpu_bytes(),
                                     cudaMemcpyDeviceToHost, stream_));
    }
    return seg;
}
bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type,
                         const std::vector<std::string> &names, float confidence_threshold,
                         float nms_threshold, int gpu_id)
{
    trt_ = TensorRT::load(engine_file);
    device_id_ = gpu_id;
    if (trt_ == nullptr) return false;

    trt_->print();
    this->confidence_threshold_ = confidence_threshold;
    this->nms_threshold_ = nms_threshold;
    this->model_type_ = model_type;
    this->class_names_ = names;

    auto input_dim = trt_->static_dims(0);
    bbox_head_dims_ = trt_->static_dims(1);
    has_segment_ = this->model_type_ == ModelType::YOLO11SEG;
    if (has_segment_)
    {
        segment_head_dims_ = trt_->static_dims(2);
    }
    network_input_width_ = input_dim[3];
    network_input_height_ = input_dim[2];
    isdynamic_model_ = trt_->has_dynamic_dim();
    normalize_ = affine::Norm::alpha_beta(1 / 255.0f, 0.0f, affine::ChannelType::SwapRB);

    if (this->model_type_ == ModelType::YOLO11)
    {
        num_classes_ = bbox_head_dims_[2] - 4;
    }
    else if (this->model_type_ == ModelType::YOLOV5)
    {
        num_classes_ = bbox_head_dims_[2] - 5;
    }
    else if (this->model_type_ == ModelType::YOLO11POSE)
    {
        num_classes_ = bbox_head_dims_[2] - 4 - KEY_POINT_NUM * 3;
        // NUM_BOX_ELEMENT = 8 + KEY_POINT_NUM * 3;
    }
    else if (this->model_type_ == ModelType::YOLO11SEG)
    {
        num_classes_ = bbox_head_dims_[2] - 4 - segment_head_dims_[1];
    }
    return true;
}
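// Per-row channel layout of output0 implied by the class-count arithmetic in load():
//   YOLOV5:     [cx, cy, w, h, objectness, class scores...]
//   YOLO11:     [cx, cy, w, h, class scores...]
//   YOLO11POSE: [cx, cy, w, h, class scores..., KEY_POINT_NUM x (x, y, score)]
//   YOLO11SEG:  [cx, cy, w, h, class scores..., segment_head_dims_[1] mask coefficients]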
Result YoloModelImpl::forward(const tensor::Image &image, data::BoxArray &boxes, void *stream)
{
    return {};
}

Result YoloModelImpl::forward(const tensor::Image &image, int slice_width, int slice_height,
                              float overlap_width_ratio, float overlap_height_ratio, void *stream)
{
    std::lock_guard<std::mutex> lock(mutex_);  // locks here, unlocks automatically on scope exit
    slice_->slice(image, slice_width, slice_height, overlap_width_ratio, overlap_height_ratio, stream);
    return forwards(stream);
}

Result YoloModelImpl::forward(const tensor::Image &image, void *stream)
{
    std::lock_guard<std::mutex> lock(mutex_);  // locks here, unlocks automatically on scope exit
    slice_->autoSlice(image, stream);
    return forwards(stream);
}
Result YoloModelImpl::forwards(void *stream)
{
    int num_image = slice_->slice_num_h_ * slice_->slice_num_v_;
    if (num_image == 0) return {};

    auto input_dims = trt_->static_dims(0);
    int infer_batch_size = input_dims[0];
    if (infer_batch_size != num_image)
    {
        if (isdynamic_model_)
        {
            infer_batch_size = num_image;
            input_dims[0] = num_image;
            if (!trt_->set_run_dims(0, input_dims))
            {
                printf("Fail to set run dims\n");
                return {};
            }
        }
        else
        {
            if (infer_batch_size < num_image)
            {
                printf(
                    "When using static shape model, number of images[%d] must be "
                    "less than or equal to the maximum batch[%d].\n",
                    num_image, infer_batch_size);
                return {};
            }
        }
    }
    adjust_memory(infer_batch_size);

    affine::LetterBoxMatrix affine_matrix;
    cudaStream_t stream_ = (cudaStream_t)stream;
    // Every slice has the same size, so the affine matrix only needs to be computed once.
    cal_affine_matrix(affine_matrix, stream);
    for (int i = 0; i < num_image; ++i)
        preprocess(i, stream);
    float *bbox_output_device = bbox_predict_.gpu();
#ifdef TRT10
    std::unordered_map<std::string, const void *> bindings;
    if (has_segment_)
    {
        float *segment_output_device = segment_predict_.gpu();
        bindings = {
            { "images", input_buffer_.gpu() },
            { "output0", bbox_output_device },
            { "output1", segment_output_device }
        };
    }
    else
    {
        bindings = {
            { "images", input_buffer_.gpu() },
            { "output0", bbox_output_device }
        };
    }
    if (!trt_->forward(bindings, stream_))
    {
        printf("Failed to run TensorRT forward.\n");
        return {};
    }
#else
    std::vector<void *> bindings{ input_buffer_.gpu(), bbox_output_device };
    if (has_segment_)
    {
        float *segment_output_device = segment_predict_.gpu();
        bindings = { input_buffer_.gpu(), bbox_output_device, segment_output_device };
    }
    if (!trt_->forward(bindings, stream_))
    {
        printf("Failed to run TensorRT forward.\n");
        return {};
    }
#endif
    int *box_count = box_count_.gpu();
    checkRuntime(cudaMemsetAsync(box_count, 0, sizeof(int), stream_));
    for (int ib = 0; ib < num_image; ++ib)
    {
        int start_x = slice_->slice_start_point_.cpu()[ib * 2];
        int start_y = slice_->slice_start_point_.cpu()[ib * 2 + 1];
        // float *boxarray_device =
        //     output_boxarray_.gpu() + ib * (MAX_IMAGE_BOXES * NUM_BOX_ELEMENT);
        float *boxarray_device = output_boxarray_.gpu();
        float *affine_matrix_device = affine_matrix_.gpu();
        float *image_based_bbox_output =
            bbox_output_device + ib * (bbox_head_dims_[1] * bbox_head_dims_[2]);
        if (model_type_ == ModelType::YOLOV5 || model_type_ == ModelType::YOLOV5SEG)
        {
            decode_kernel_invoker_v5(
                image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES,
                start_x, start_y, ib, stream_);
        }
        else if (model_type_ == ModelType::YOLO11 || model_type_ == ModelType::YOLO11SEG)
        {
            decode_kernel_invoker_v8(
                image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES,
                start_x, start_y, ib, stream_);
        }
        else if (model_type_ == ModelType::YOLO11POSE)
        {
            decode_kernel_invoker_v11pose(
                image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES,
                start_x, start_y, ib, stream_);
        }
    }

    float *boxarray_device = output_boxarray_.gpu();
    if (model_type_ == ModelType::YOLO11POSE)
    {
        fast_nms_pose_kernel_invoker(boxarray_device, box_count, MAX_IMAGE_BOXES, nms_threshold_, stream_);
    }
    else
    {
        fast_nms_kernel_invoker(boxarray_device, box_count, MAX_IMAGE_BOXES, nms_threshold_, stream_);
    }
    checkRuntime(cudaMemcpyAsync(output_boxarray_.cpu(), output_boxarray_.gpu(),
                                 output_boxarray_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    checkRuntime(cudaMemcpyAsync(box_count_.cpu(), box_count_.gpu(),
                                 box_count_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    checkRuntime(cudaStreamSynchronize(stream_));
    data::BoxArray result;
    float *parray = output_boxarray_.cpu();
    int count = std::min(MAX_IMAGE_BOXES, *(box_count_.cpu()));
    int imemory = 0;
    for (int i = 0; i < count; ++i)
    {
        int box_element = (model_type_ == ModelType::YOLO11POSE)
                              ? (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3)
                              : NUM_BOX_ELEMENT;
        float *pbox = parray + i * box_element;
        int label = pbox[5];
        int keepflag = pbox[6];
        if (keepflag == 1)
        {
            data::Box result_object_box(pbox[0], pbox[1], pbox[2], pbox[3], pbox[4], label);
            result_object_box.label = class_names_[label];
            result_object_box.class_id = label - class_names_.size();
            if (model_type_ == ModelType::YOLO11POSE)
            {
                result_object_box.keypoints.reserve(KEY_POINT_NUM);
                for (int k = 0; k < KEY_POINT_NUM; k++)
                {
                    result_object_box.keypoints.emplace_back(pbox[9 + k * 3], pbox[9 + k * 3 + 1],
                                                             pbox[9 + k * 3 + 2]);
                }
            }
            else if (model_type_ == ModelType::YOLO11SEG)
            {
                auto seg = decode_segment(imemory, pbox, stream);
                result_object_box.seg_mask = cv::Mat(seg->height, seg->width, CV_8UC1, seg->data);
            }
            result.emplace_back(result_object_box);
        }
    }
    return result;
}
Infer *loadraw(const std::string &engine_file, ModelType model_type,
               const std::vector<std::string> &names, float confidence_threshold,
               float nms_threshold, int gpu_id)
{
    YoloModelImpl *impl = new YoloModelImpl();
    if (!impl->load(engine_file, model_type, names, confidence_threshold, nms_threshold, gpu_id))
    {
        delete impl;
        return nullptr;
    }
    impl->slice_ = std::make_shared<slice::SliceImage>();
    return impl;
}
std::shared_ptr<Infer> load_yolo(const std::string &engine_file, ModelType model_type,
                                 const std::vector<std::string> &names, int gpu_id,
                                 float confidence_threshold, float nms_threshold)
{
    try
    {
        checkRuntime(cudaSetDevice(gpu_id));
        return std::shared_ptr<YoloModelImpl>((YoloModelImpl *)loadraw(
            engine_file, model_type, names, confidence_threshold, nms_threshold, gpu_id));
    }
    catch (const std::exception &ex)
    {
        return nullptr;
    }
}
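// Minimal usage sketch (engine path, class list and threshold values below are placeholders,
// not values taken from this file):
//   auto model = yolo::load_yolo("model.engine", yolo::ModelType::YOLO11,
//                                {"class0", "class1"}, /*gpu_id=*/0,
//                                /*confidence_threshold=*/0.25f, /*nms_threshold=*/0.45f);
//   if (model) auto boxes = model->forward(image, stream);  // image: tensor::Image, stream: cudaStream_t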
}  // namespace yolo