yolo.cu

#include "infer/trt/yolo.hpp"

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <memory>
#include <unordered_map>
#include <vector>

#include "infer/slice/slice.hpp"
#include "infer/trt/affine.hpp"
#include "common/check.hpp"

#define GPU_BLOCK_THREADS 512
namespace yolo
{

static const int NUM_BOX_ELEMENT = 8;        // left, top, right, bottom, confidence, class, keepflag, row_index (output)
static const int MAX_IMAGE_BOXES = 1024 * 4; // capacity of the shared output box buffer
static const int KEY_POINT_NUM = 17;         // number of keypoints per box (pose models)
static dim3 grid_dims(int numJobs)
{
    int numBlockThreads = numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
    // Ceiling division so the grid covers every job.
    return dim3((numJobs + numBlockThreads - 1) / numBlockThreads);
}

static dim3 block_dims(int numJobs)
{
    return dim3(numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS);
}
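
// Launch-shape example: with num_bboxes = 8400 (the candidate count of a
// typical 640x640 anchor-free head), block_dims gives 512 threads and
// grid_dims gives ceil(8400 / 512) = 17 blocks, so each thread decodes
// exactly one candidate row.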
static __host__ __device__ void affine_project(float *matrix, float x, float y, float *ox, float *oy)
{
    *ox = matrix[0] * x + matrix[1] * y + matrix[2];
    *oy = matrix[3] * x + matrix[4] * y + matrix[5];
}
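
// `matrix` is a row-major 2x3 affine transform [m0 m1 m2; m3 m4 m5] applied as
//   ox = m0*x + m1*y + m2,  oy = m3*x + m4*y + m5.
// The kernels pass the inverse letterbox matrix (d2i), mapping network-input
// coordinates back to slice-pixel coordinates. As a sketch of the usual
// convention (the exact matrix is computed in affine.hpp): for letterbox
// scale s and padding offset (dx, dy), d2i is [1/s, 0, -dx/s; 0, 1/s, -dy/s].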
static __global__ void decode_kernel_v5(float *predict, int num_bboxes, int num_classes,
                                        int output_cdim, float confidence_threshold,
                                        float *invert_affine_matrix, float *parray, int *box_count,
                                        int max_image_boxes, int start_x, int start_y)
{
    int position = blockDim.x * blockIdx.x + threadIdx.x;
    if (position >= num_bboxes) return;

    // YOLOv5 row layout: [cx, cy, w, h, objectness, class_score_0 .. class_score_{n-1}]
    float *pitem = predict + output_cdim * position;
    float objectness = pitem[4];
    if (objectness < confidence_threshold) return;

    // Pick the highest-scoring class.
    float *class_confidence = pitem + 5;
    float confidence = *class_confidence++;
    int label = 0;
    for (int i = 1; i < num_classes; ++i, ++class_confidence)
    {
        if (*class_confidence > confidence)
        {
            confidence = *class_confidence;
            label = i;
        }
    }
    confidence *= objectness; // final score = objectness * class score
    if (confidence < confidence_threshold) return;

    int index = atomicAdd(box_count, 1);
    if (index >= max_image_boxes) return;

    // Convert center/size to corners, then map back to slice-pixel coordinates.
    float cx = *pitem++;
    float cy = *pitem++;
    float width = *pitem++;
    float height = *pitem++;
    float left = cx - width * 0.5f;
    float top = cy - height * 0.5f;
    float right = cx + width * 0.5f;
    float bottom = cy + height * 0.5f;
    affine_project(invert_affine_matrix, left, top, &left, &top);
    affine_project(invert_affine_matrix, right, bottom, &right, &bottom);

    float *pout_item = parray + index * NUM_BOX_ELEMENT;
    *pout_item++ = left + start_x; // offset by the slice origin into full-image coordinates
    *pout_item++ = top + start_y;
    *pout_item++ = right + start_x;
    *pout_item++ = bottom + start_y;
    *pout_item++ = confidence;
    *pout_item++ = label;
    *pout_item++ = 1; // 1 = keep, 0 = ignore
    *pout_item++ = position;
}
static __global__ void decode_kernel_v8(float *predict, int num_bboxes, int num_classes,
                                        int output_cdim, float confidence_threshold,
                                        float *invert_affine_matrix, float *parray, int *box_count,
                                        int max_image_boxes, int start_x, int start_y)
{
    int position = blockDim.x * blockIdx.x + threadIdx.x;
    if (position >= num_bboxes) return;

    // YOLOv8/YOLO11 row layout: [cx, cy, w, h, class_score_0 .. class_score_{n-1}]
    // (no separate objectness term, unlike YOLOv5).
    float *pitem = predict + output_cdim * position;
    float *class_confidence = pitem + 4;
    float confidence = *class_confidence++;
    int label = 0;
    for (int i = 1; i < num_classes; ++i, ++class_confidence)
    {
        if (*class_confidence > confidence)
        {
            confidence = *class_confidence;
            label = i;
        }
    }
    if (confidence < confidence_threshold) return;

    int index = atomicAdd(box_count, 1);
    if (index >= max_image_boxes) return;

    float cx = *pitem++;
    float cy = *pitem++;
    float width = *pitem++;
    float height = *pitem++;
    float left = cx - width * 0.5f;
    float top = cy - height * 0.5f;
    float right = cx + width * 0.5f;
    float bottom = cy + height * 0.5f;
    affine_project(invert_affine_matrix, left, top, &left, &top);
    affine_project(invert_affine_matrix, right, bottom, &right, &bottom);

    float *pout_item = parray + index * NUM_BOX_ELEMENT;
    *pout_item++ = left + start_x;
    *pout_item++ = top + start_y;
    *pout_item++ = right + start_x;
    *pout_item++ = bottom + start_y;
    *pout_item++ = confidence;
    *pout_item++ = label;
    *pout_item++ = 1; // 1 = keep, 0 = ignore
    *pout_item++ = position;
}
static __global__ void decode_kernel_11pose(float *predict, int num_bboxes, int num_classes,
                                            int output_cdim, float confidence_threshold,
                                            float *invert_affine_matrix, float *parray,
                                            int *box_count, int max_image_boxes, int start_x, int start_y)
{
    int position = blockDim.x * blockIdx.x + threadIdx.x;
    if (position >= num_bboxes) return;

    // Pose row layout: [cx, cy, w, h, class scores..., (kpt_x, kpt_y, kpt_score) * KEY_POINT_NUM]
    float *pitem = predict + output_cdim * position;
    float *class_confidence = pitem + 4;
    float *key_points = pitem + 4 + num_classes;
    float confidence = *class_confidence++;
    int label = 0;
    for (int i = 1; i < num_classes; ++i, ++class_confidence)
    {
        if (*class_confidence > confidence)
        {
            confidence = *class_confidence;
            label = i;
        }
    }
    if (confidence < confidence_threshold) return;

    int index = atomicAdd(box_count, 1);
    if (index >= max_image_boxes) return;

    float cx = *pitem++;
    float cy = *pitem++;
    float width = *pitem++;
    float height = *pitem++;
    float left = cx - width * 0.5f;
    float top = cy - height * 0.5f;
    float right = cx + width * 0.5f;
    float bottom = cy + height * 0.5f;
    affine_project(invert_affine_matrix, left, top, &left, &top);
    affine_project(invert_affine_matrix, right, bottom, &right, &bottom);

    // Pose boxes use a wider stride: the 8 box elements plus 3 floats per keypoint.
    float *pout_item = parray + index * (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3);
    *pout_item++ = left + start_x;
    *pout_item++ = top + start_y;
    *pout_item++ = right + start_x;
    *pout_item++ = bottom + start_y;
    *pout_item++ = confidence;
    *pout_item++ = label;
    *pout_item++ = 1; // 1 = keep, 0 = ignore
    *pout_item++ = position;
    for (int i = 0; i < KEY_POINT_NUM; i++)
    {
        float x = *key_points++;
        float y = *key_points++;
        affine_project(invert_affine_matrix, x, y, &x, &y);
        float score = *key_points++;
        *pout_item++ = x + start_x;
        *pout_item++ = y + start_y;
        *pout_item++ = score;
    }
}
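
// Output stride note: with the constants above, each pose box occupies
// NUM_BOX_ELEMENT + KEY_POINT_NUM * 3 = 8 + 51 = 59 floats in `parray`,
// versus 8 floats for a plain detection box; host code reading this buffer
// must use the matching stride (see forwards()).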
static __device__ float box_iou(float aleft, float atop, float aright, float abottom, float bleft,
                                float btop, float bright, float bbottom)
{
    float cleft = max(aleft, bleft);
    float ctop = max(atop, btop);
    float cright = min(aright, bright);
    float cbottom = min(abottom, bbottom);
    float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f);
    if (c_area == 0.0f) return 0.0f;
    float a_area = max(0.0f, aright - aleft) * max(0.0f, abottom - atop);
    float b_area = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop);
    return c_area / (a_area + b_area - c_area);
}
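
// Worked example: boxes (0,0,10,10) and (5,5,15,15) intersect in a 5x5 patch,
// so IoU = 25 / (100 + 100 - 25) ≈ 0.143, comfortably below a typical 0.45
// NMS threshold, and neither box would suppress the other.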
static __global__ void fast_nms_kernel(float *bboxes, int *box_count, int max_image_boxes, float threshold)
{
    int position = (blockDim.x * blockIdx.x + threadIdx.x);
    // box_count may have been bumped past capacity by the decode kernels, so clamp it.
    int count = min((int)*box_count, max_image_boxes);
    if (position >= count) return;

    // left, top, right, bottom, confidence, class, keepflag
    float *pcurrent = bboxes + position * NUM_BOX_ELEMENT;
    for (int i = 0; i < count; ++i)
    {
        float *pitem = bboxes + i * NUM_BOX_ELEMENT;
        if (i == position || pcurrent[5] != pitem[5]) continue;
        if (pitem[4] >= pcurrent[4])
        {
            if (pitem[4] == pcurrent[4] && i < position) continue;
            float iou = box_iou(pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3],
                                pitem[0], pitem[1], pitem[2], pitem[3]);
            if (iou > threshold)
            {
                pcurrent[6] = 0; // 1 = keep, 0 = ignore
                return;
            }
        }
    }
}
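
// NMS strategy: one thread per surviving candidate scans every other box of
// the same class and clears its own keepflag if a higher-confidence box
// overlaps it by more than `threshold`; on equal scores the index comparison
// guarantees exactly one box of the pair survives. This is O(N^2) per launch,
// but N is clamped to max_image_boxes and no host-side sort is required.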
static __global__ void fast_nms_pose_kernel(float *bboxes, int *box_count, int max_image_boxes, float threshold)
{
    int position = (blockDim.x * blockIdx.x + threadIdx.x);
    int count = min((int)*box_count, max_image_boxes);
    if (position >= count) return;

    // Same as fast_nms_kernel, but with the wider per-box stride of the pose output.
    float *pcurrent = bboxes + position * (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3);
    for (int i = 0; i < count; ++i)
    {
        float *pitem = bboxes + i * (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3);
        if (i == position || pcurrent[5] != pitem[5]) continue;
        if (pitem[4] >= pcurrent[4])
        {
            if (pitem[4] == pcurrent[4] && i < position) continue;
            float iou = box_iou(pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3],
                                pitem[0], pitem[1], pitem[2], pitem[3]);
            if (iou > threshold)
            {
                pcurrent[6] = 0; // 1 = keep, 0 = ignore
                return;
            }
        }
    }
}
static void decode_kernel_invoker_v8(float *predict, int num_bboxes, int num_classes, int output_cdim,
                                     float confidence_threshold, float nms_threshold,
                                     float *invert_affine_matrix, float *parray, int *box_count, int max_image_boxes,
                                     int start_x, int start_y, cudaStream_t stream)
{
    // Note: nms_threshold is carried in the signature but unused here; NMS runs
    // afterwards in fast_nms_kernel_invoker.
    auto grid = grid_dims(num_bboxes);
    auto block = block_dims(num_bboxes);
    checkKernel(decode_kernel_v8<<<grid, block, 0, stream>>>(
        predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
        parray, box_count, max_image_boxes, start_x, start_y));
}

static void decode_kernel_invoker_v5(float *predict, int num_bboxes, int num_classes, int output_cdim,
                                     float confidence_threshold, float nms_threshold,
                                     float *invert_affine_matrix, float *parray, int *box_count, int max_image_boxes,
                                     int start_x, int start_y, cudaStream_t stream)
{
    auto grid = grid_dims(num_bboxes);
    auto block = block_dims(num_bboxes);
    checkKernel(decode_kernel_v5<<<grid, block, 0, stream>>>(
        predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
        parray, box_count, max_image_boxes, start_x, start_y));
}

static void decode_kernel_invoker_v11pose(float *predict, int num_bboxes, int num_classes, int output_cdim,
                                          float confidence_threshold, float nms_threshold,
                                          float *invert_affine_matrix, float *parray, int *box_count, int max_image_boxes,
                                          int start_x, int start_y, cudaStream_t stream)
{
    auto grid = grid_dims(num_bboxes);
    auto block = block_dims(num_bboxes);
    checkKernel(decode_kernel_11pose<<<grid, block, 0, stream>>>(
        predict, num_bboxes, num_classes, output_cdim, confidence_threshold, invert_affine_matrix,
        parray, box_count, max_image_boxes, start_x, start_y));
}
static void fast_nms_kernel_invoker(float *parray, int *box_count, int max_image_boxes, float nms_threshold, cudaStream_t stream)
{
    auto grid = grid_dims(max_image_boxes);
    auto block = block_dims(max_image_boxes);
    checkKernel(fast_nms_kernel<<<grid, block, 0, stream>>>(parray, box_count, max_image_boxes, nms_threshold));
}

static void fast_nms_pose_kernel_invoker(float *parray, int *box_count, int max_image_boxes, float nms_threshold, cudaStream_t stream)
{
    auto grid = grid_dims(max_image_boxes);
    auto block = block_dims(max_image_boxes);
    checkKernel(fast_nms_pose_kernel<<<grid, block, 0, stream>>>(parray, box_count, max_image_boxes, nms_threshold));
}
void YoloModelImpl::adjust_memory(int batch_size)
{
    // Allocate for the inference batch size.
    size_t input_numel = network_input_width_ * network_input_height_ * 3;
    input_buffer_.gpu(batch_size * input_numel);
    bbox_predict_.gpu(batch_size * bbox_head_dims_[1] * bbox_head_dims_[2]);
    // Pose boxes carry keypoints, so they need the wider per-box stride; sizing
    // by NUM_BOX_ELEMENT alone would overflow when decode_kernel_11pose writes.
    int box_element = (model_type_ == ModelType::YOLO11POSE)
                          ? (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3)
                          : NUM_BOX_ELEMENT;
    output_boxarray_.gpu(MAX_IMAGE_BOXES * box_element);
    output_boxarray_.cpu(MAX_IMAGE_BOXES * box_element);
    affine_matrix_.gpu(6);
    affine_matrix_.cpu(6);
    box_count_.gpu(1);
    box_count_.cpu(1);
}
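
// Sizing example: with MAX_IMAGE_BOXES = 4096, the output box buffer is
// 4096 * 8 * sizeof(float) = 128 KB for detection models and
// 4096 * 59 * sizeof(float) ≈ 944 KB for YOLO11 pose.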
void YoloModelImpl::preprocess(int ibatch, affine::LetterBoxMatrix &affine, void *stream)
{
    affine.compute(std::make_tuple(slice_->slice_width_, slice_->slice_height_),
                   std::make_tuple(network_input_width_, network_input_height_));
    size_t input_numel = network_input_width_ * network_input_height_ * 3;
    float *input_device = input_buffer_.gpu() + ibatch * input_numel;
    size_t size_image = slice_->slice_width_ * slice_->slice_height_ * 3;
    float *affine_matrix_device = affine_matrix_.gpu();
    uint8_t *image_device = slice_->output_images_.gpu() + ibatch * size_image;
    float *affine_matrix_host = affine_matrix_.cpu();

    // Stage the matrix in host memory, then copy asynchronously on the stream.
    cudaStream_t stream_ = (cudaStream_t)stream;
    memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
    checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i),
                                 cudaMemcpyHostToDevice, stream_));
    // Warp the slice into the letterboxed network input (pad value 114) and normalize.
    affine::warp_affine_bilinear_and_normalize_plane(image_device, slice_->slice_width_ * 3, slice_->slice_width_,
                                                     slice_->slice_height_, input_device, network_input_width_,
                                                     network_input_height_, affine_matrix_device, 114,
                                                     normalize_, stream_);
}
bool YoloModelImpl::load(const std::string &engine_file, ModelType model_type, float confidence_threshold, float nms_threshold)
{
    trt_ = TensorRT::load(engine_file);
    if (trt_ == nullptr) return false;
    trt_->print();

    this->confidence_threshold_ = confidence_threshold;
    this->nms_threshold_ = nms_threshold;
    this->model_type_ = model_type;

    auto input_dim = trt_->static_dims(0);
    bbox_head_dims_ = trt_->static_dims(1);
    network_input_width_ = input_dim[3];
    network_input_height_ = input_dim[2];
    isdynamic_model_ = trt_->has_dynamic_dim();
    normalize_ = affine::Norm::alpha_beta(1 / 255.0f, 0.0f, affine::ChannelType::SwapRB);

    // Derive the class count from the head's channel dimension.
    if (this->model_type_ == ModelType::YOLOV8 || this->model_type_ == ModelType::YOLO11)
    {
        num_classes_ = bbox_head_dims_[2] - 4; // 4 box values
    }
    else if (this->model_type_ == ModelType::YOLOV5)
    {
        num_classes_ = bbox_head_dims_[2] - 5; // 4 box values + objectness
    }
    else if (this->model_type_ == ModelType::YOLO11POSE)
    {
        num_classes_ = bbox_head_dims_[2] - 4 - KEY_POINT_NUM * 3; // 4 box values + keypoint triplets
    }
    return true;
}
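
// Example head shapes (assuming the engine exposes boxes as rows, e.g.
// output0 = [1, 8400, 84]): a COCO YOLOv8 head yields num_classes_ = 84 - 4 = 80,
// a COCO YOLOv5 head ([1, 25200, 85]) yields 85 - 5 = 80, and a person-only
// YOLO11 pose head ([1, 8400, 56]) yields 56 - 4 - 51 = 1.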
data::BoxArray YoloModelImpl::forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream)
{
    slice_->slice(image, slice_width, slice_height, overlap_width_ratio, overlap_height_ratio, stream);
    return forwards(stream);
}

data::BoxArray YoloModelImpl::forward(const tensor::Image &image, void *stream)
{
    slice_->autoSlice(image, stream);
    return forwards(stream);
}
data::BoxArray YoloModelImpl::forwards(void *stream)
{
    int num_image = slice_->slice_num_h_ * slice_->slice_num_v_;
    if (num_image == 0) return {};

    auto input_dims = trt_->static_dims(0);
    int infer_batch_size = input_dims[0];
    if (infer_batch_size != num_image)
    {
        if (isdynamic_model_)
        {
            infer_batch_size = num_image;
            input_dims[0] = num_image;
            if (!trt_->set_run_dims(0, input_dims))
            {
                printf("Failed to set run dims.\n");
                return {};
            }
        }
        else
        {
            if (infer_batch_size < num_image)
            {
                printf(
                    "When using a static-shape model, the number of images [%d] must be "
                    "less than or equal to the maximum batch size [%d].\n",
                    num_image, infer_batch_size);
                return {};
            }
        }
    }
    adjust_memory(infer_batch_size);

    affine::LetterBoxMatrix affine_matrix;
    cudaStream_t stream_ = (cudaStream_t)stream;
    // All slices share one size, so every iteration computes the same matrix.
    for (int i = 0; i < num_image; ++i)
        preprocess(i, affine_matrix, stream);

    float *bbox_output_device = bbox_predict_.gpu();
#ifdef TRT10
    if (!trt_->forward(std::unordered_map<std::string, const void *>{
            { "images", input_buffer_.gpu() },
            { "output0", bbox_predict_.gpu() }
        }, stream_))
    {
        printf("Failed to run TensorRT forward.\n");
        return {};
    }
#else
    std::vector<void *> bindings{input_buffer_.gpu(), bbox_output_device};
    if (!trt_->forward(bindings, stream))
    {
        printf("Failed to run TensorRT forward.\n");
        return {};
    }
#endif
    int *box_count = box_count_.gpu();
    checkRuntime(cudaMemsetAsync(box_count, 0, sizeof(int), stream_));
    for (int ib = 0; ib < num_image; ++ib)
    {
        // Each slice's boxes are shifted by its start point back into full-image coordinates.
        int start_x = slice_->slice_start_point_.cpu()[ib * 2];
        int start_y = slice_->slice_start_point_.cpu()[ib * 2 + 1];
        // All slices decode into one shared buffer; box_count accumulates across slices.
        float *boxarray_device = output_boxarray_.gpu();
        float *affine_matrix_device = affine_matrix_.gpu();
        float *image_based_bbox_output =
            bbox_output_device + ib * (bbox_head_dims_[1] * bbox_head_dims_[2]);
        if (model_type_ == ModelType::YOLOV5)
        {
            decode_kernel_invoker_v5(image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                                     bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                                     affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, start_x, start_y, stream_);
        }
        else if (model_type_ == ModelType::YOLOV8 || model_type_ == ModelType::YOLO11)
        {
            decode_kernel_invoker_v8(image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                                     bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                                     affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, start_x, start_y, stream_);
        }
        else if (model_type_ == ModelType::YOLO11POSE)
        {
            decode_kernel_invoker_v11pose(image_based_bbox_output, bbox_head_dims_[1], num_classes_,
                                          bbox_head_dims_[2], confidence_threshold_, nms_threshold_,
                                          affine_matrix_device, boxarray_device, box_count, MAX_IMAGE_BOXES, start_x, start_y, stream_);
        }
    }
    float *boxarray_device = output_boxarray_.gpu();
    if (model_type_ == ModelType::YOLO11POSE)
    {
        fast_nms_pose_kernel_invoker(boxarray_device, box_count, MAX_IMAGE_BOXES, nms_threshold_, stream_);
    }
    else
    {
        fast_nms_kernel_invoker(boxarray_device, box_count, MAX_IMAGE_BOXES, nms_threshold_, stream_);
    }
    checkRuntime(cudaMemcpyAsync(output_boxarray_.cpu(), output_boxarray_.gpu(),
                                 output_boxarray_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    checkRuntime(cudaMemcpyAsync(box_count_.cpu(), box_count_.gpu(),
                                 box_count_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    checkRuntime(cudaStreamSynchronize(stream_));

    data::BoxArray result;
    float *parray = output_boxarray_.cpu();
    int count = std::min(MAX_IMAGE_BOXES, *(box_count_.cpu()));
    // Pose boxes use the wider stride; see decode_kernel_11pose.
    int box_element = (model_type_ == ModelType::YOLO11POSE)
                          ? (NUM_BOX_ELEMENT + KEY_POINT_NUM * 3)
                          : NUM_BOX_ELEMENT;
    for (int i = 0; i < count; ++i)
    {
        float *pbox = parray + i * box_element;
        int label = pbox[5];
        int keepflag = pbox[6];
        if (keepflag == 1)
        {
            data::Box result_object_box(pbox[0], pbox[1], pbox[2], pbox[3], pbox[4], label);
            if (model_type_ == ModelType::YOLO11POSE)
            {
                result_object_box.keypoints.reserve(KEY_POINT_NUM);
                for (int j = 0; j < KEY_POINT_NUM; j++)
                {
                    result_object_box.keypoints.emplace_back(pbox[8 + j * 3], pbox[8 + j * 3 + 1], pbox[8 + j * 3 + 2]);
                }
            }
            result.emplace_back(result_object_box);
        }
    }
    return result;
}
Infer *loadraw(const std::string &engine_file, ModelType model_type, float confidence_threshold,
               float nms_threshold)
{
    YoloModelImpl *impl = new YoloModelImpl();
    if (!impl->load(engine_file, model_type, confidence_threshold, nms_threshold))
    {
        // Return early on failure; dereferencing impl after deleting it would crash.
        delete impl;
        return nullptr;
    }
    impl->slice_ = std::make_shared<slice::SliceImage>();
    return impl;
}
std::shared_ptr<Infer> load_yolo(const std::string &engine_file, ModelType model_type, int gpu_id, float confidence_threshold, float nms_threshold)
{
    checkRuntime(cudaSetDevice(gpu_id));
    return std::shared_ptr<YoloModelImpl>((YoloModelImpl *)loadraw(engine_file, model_type, confidence_threshold, nms_threshold));
}

} // namespace yolo
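
// Usage sketch (hedged: the engine path, thresholds, and image-wrapping step
// below are placeholders, and the exact forward() signature comes from the
// Infer interface in yolo.hpp):
//
//     auto model = yolo::load_yolo("yolo11n.engine", yolo::ModelType::YOLO11,
//                                  /*gpu_id=*/0, /*confidence_threshold=*/0.25f,
//                                  /*nms_threshold=*/0.45f);
//     if (model)
//     {
//         tensor::Image image = /* wrap a BGR uint8 frame here */;
//         auto boxes = model->forward(image, /*slice_width=*/640, /*slice_height=*/640,
//                                     /*overlap_width_ratio=*/0.2f,
//                                     /*overlap_height_ratio=*/0.2f, /*stream=*/nullptr);
//     }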