resnet.cu

#include "infer.hpp"
#include "resnet.hpp"

#include <algorithm>  // std::min
#include <cfloat>     // FLT_MAX
#include <cstring>    // memcpy
#include <tuple>      // std::tuple, std::get, std::make_tuple
namespace resnet {

using namespace std;

#define GPU_BLOCK_THREADS 512

#define checkRuntime(call)                                                                 \
  do {                                                                                     \
    auto ___call__ret_code__ = (call);                                                     \
    if (___call__ret_code__ != cudaSuccess) {                                              \
      INFO("CUDA Runtime error💥 %s # %s, code = %s [ %d ]", #call,                        \
           cudaGetErrorString(___call__ret_code__), cudaGetErrorName(___call__ret_code__), \
           ___call__ret_code__);                                                           \
      abort();                                                                             \
    }                                                                                      \
  } while (0)

#define checkKernel(...)                 \
  do {                                   \
    { (__VA_ARGS__); }                   \
    checkRuntime(cudaPeekAtLastError()); \
  } while (0)

enum class NormType : int { None = 0, MeanStd = 1, AlphaBeta = 2 };
enum class ChannelType : int { None = 0, SwapRB = 1 };
/* Normalization config: supports mean/std, alpha/beta, and optional R/B channel swap. */
struct Norm {
  float mean[3];
  float std[3];
  float alpha, beta;
  NormType type = NormType::None;
  ChannelType channel_type = ChannelType::None;

  // out = (x * alpha - mean) / std
  static Norm mean_std(const float mean[3], const float std[3], float alpha = 1 / 255.0f,
                       ChannelType channel_type = ChannelType::None);

  // out = x * alpha + beta
  static Norm alpha_beta(float alpha, float beta = 0, ChannelType channel_type = ChannelType::None);

  // None
  static Norm None();
};
Norm Norm::mean_std(const float mean[3], const float std[3], float alpha,
                    ChannelType channel_type) {
  Norm out;
  out.type = NormType::MeanStd;
  out.alpha = alpha;
  out.channel_type = channel_type;
  memcpy(out.mean, mean, sizeof(out.mean));
  memcpy(out.std, std, sizeof(out.std));
  return out;
}

Norm Norm::alpha_beta(float alpha, float beta, ChannelType channel_type) {
  Norm out;
  out.type = NormType::AlphaBeta;
  out.alpha = alpha;
  out.beta = beta;
  out.channel_type = channel_type;
  return out;
}

Norm Norm::None() { return Norm(); }

static dim3 grid_dims(int numJobs) {
  int numBlockThreads = numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
  return dim3(((numJobs + numBlockThreads - 1) / (float)numBlockThreads));
}

static dim3 block_dims(int numJobs) {
  return numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
}

inline int upbound(int n, int align = 32) { return (n + align - 1) / align * align; }
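
// Warp-affine preprocessing kernel: one thread per destination pixel. Each thread maps its
// (dx, dy) destination coordinate back into the source image through the 2x3 affine matrix
// (the dst-to-image inverse), samples the BGR source bilinearly (out-of-range pixels fall back
// to const_value_st), optionally swaps R/B, applies the selected normalization, and writes the
// result as planar CHW floats.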
static __global__ void warp_affine_bilinear_and_normalize_plane_kernel(
    uint8_t *src, int src_line_size, int src_width, int src_height, float *dst, int dst_width,
    int dst_height, uint8_t const_value_st, float *warp_affine_matrix_2_3, Norm norm) {
  int dx = blockDim.x * blockIdx.x + threadIdx.x;
  int dy = blockDim.y * blockIdx.y + threadIdx.y;
  if (dx >= dst_width || dy >= dst_height) return;

  float m_x1 = warp_affine_matrix_2_3[0];
  float m_y1 = warp_affine_matrix_2_3[1];
  float m_z1 = warp_affine_matrix_2_3[2];
  float m_x2 = warp_affine_matrix_2_3[3];
  float m_y2 = warp_affine_matrix_2_3[4];
  float m_z2 = warp_affine_matrix_2_3[5];
  float src_x = m_x1 * dx + m_y1 * dy + m_z1;
  float src_y = m_x2 * dx + m_y2 * dy + m_z2;
  float c0, c1, c2;

  if (src_x <= -1 || src_x >= src_width || src_y <= -1 || src_y >= src_height) {
    // out of range
    c0 = const_value_st;
    c1 = const_value_st;
    c2 = const_value_st;
  } else {
    int y_low = floorf(src_y);
    int x_low = floorf(src_x);
    int y_high = y_low + 1;
    int x_high = x_low + 1;
    uint8_t const_value[] = {const_value_st, const_value_st, const_value_st};
    float ly = src_y - y_low;
    float lx = src_x - x_low;
    float hy = 1 - ly;
    float hx = 1 - lx;
    float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
    uint8_t *v1 = const_value;
    uint8_t *v2 = const_value;
    uint8_t *v3 = const_value;
    uint8_t *v4 = const_value;
    if (y_low >= 0) {
      if (x_low >= 0) v1 = src + y_low * src_line_size + x_low * 3;
      if (x_high < src_width) v2 = src + y_low * src_line_size + x_high * 3;
    }
    if (y_high < src_height) {
      if (x_low >= 0) v3 = src + y_high * src_line_size + x_low * 3;
      if (x_high < src_width) v4 = src + y_high * src_line_size + x_high * 3;
    }
    // Round to nearest, matching OpenCV's bilinear interpolation.
    c0 = floorf(w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0] + 0.5f);
    c1 = floorf(w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1] + 0.5f);
    c2 = floorf(w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2] + 0.5f);
  }
  if (norm.channel_type == ChannelType::SwapRB) {
    float t = c2;
    c2 = c0;
    c0 = t;
  }

  if (norm.type == NormType::MeanStd) {
    c0 = (c0 * norm.alpha - norm.mean[0]) / norm.std[0];
    c1 = (c1 * norm.alpha - norm.mean[1]) / norm.std[1];
    c2 = (c2 * norm.alpha - norm.mean[2]) / norm.std[2];
  } else if (norm.type == NormType::AlphaBeta) {
    c0 = c0 * norm.alpha + norm.beta;
    c1 = c1 * norm.alpha + norm.beta;
    c2 = c2 * norm.alpha + norm.beta;
  }

  int area = dst_width * dst_height;
  float *pdst_c0 = dst + dy * dst_width + dx;
  float *pdst_c1 = pdst_c0 + area;
  float *pdst_c2 = pdst_c1 + area;
  *pdst_c0 = c0;
  *pdst_c1 = c1;
  *pdst_c2 = c2;
}
static void warp_affine_bilinear_and_normalize_plane(uint8_t *src, int src_line_size, int src_width,
                                                     int src_height, float *dst, int dst_width,
                                                     int dst_height, float *matrix_2_3,
                                                     uint8_t const_value, const Norm &norm,
                                                     cudaStream_t stream) {
  dim3 grid((dst_width + 31) / 32, (dst_height + 31) / 32);
  dim3 block(32, 32);
  checkKernel(warp_affine_bilinear_and_normalize_plane_kernel<<<grid, block, 0, stream>>>(
      src, src_line_size, src_width, src_height, dst, dst_width, dst_height, const_value,
      matrix_2_3, norm));
}
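
// AffineMatrix holds the forward (image -> network input, i2d) and inverse (d2i) 2x3 affine
// transforms. compute() builds an isotropic scale with zero translation (the centered letterbox
// variant is kept commented out below) and inverts the 2x2 linear part analytically via its
// determinant; the kernel above consumes d2i to map destination pixels back to the source image.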
struct AffineMatrix {
  float i2d[6];  // image to dst(network), 2x3 matrix
  float d2i[6];  // dst to image, 2x3 matrix

  void compute(const std::tuple<int, int> &from, const std::tuple<int, int> &to) {
    float scale_x = get<0>(to) / (float)get<0>(from);
    float scale_y = get<1>(to) / (float)get<1>(from);
    float scale = std::min(scale_x, scale_y);

    // letter box
    // i2d[0] = scale;
    // i2d[1] = 0;
    // i2d[2] = -scale * get<0>(from) * 0.5 + get<0>(to) * 0.5 + scale * 0.5 - 0.5;
    // i2d[3] = 0;
    // i2d[4] = scale;
    // i2d[5] = -scale * get<1>(from) * 0.5 + get<1>(to) * 0.5 + scale * 0.5 - 0.5;

    // resize
    i2d[0] = scale;
    i2d[1] = 0;
    i2d[2] = 0;
    i2d[3] = 0;
    i2d[4] = scale;
    i2d[5] = 0;

    double D = i2d[0] * i2d[4] - i2d[1] * i2d[3];
    D = D != 0. ? double(1.) / D : double(0.);
    double A11 = i2d[4] * D, A22 = i2d[0] * D, A12 = -i2d[1] * D, A21 = -i2d[3] * D;
    double b1 = -A11 * i2d[2] - A12 * i2d[5];
    double b2 = -A21 * i2d[2] - A22 * i2d[5];
    d2i[0] = A11;
    d2i[1] = A12;
    d2i[2] = b1;
    d2i[3] = A21;
    d2i[4] = A22;
    d2i[5] = b2;
  }
};
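
// Single-block softmax + argmax kernel. Pass 1: each thread scans a strided slice of `predict`
// for its local max and index, then thread 0 reduces the per-thread results in shared memory and
// writes the argmax to `max_index`. Pass 2: subtract the global max, exponentiate in place, and
// accumulate the exponentials. Pass 3: divide by the total so `predict` holds probabilities.
// Dynamic shared memory must hold blockDim.x floats plus blockDim.x ints (see the launcher below).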
static __global__ void softmax(float *predict, int length, int *max_index) {
  extern __shared__ float shared_data[];
  float *shared_max_vals = shared_data;
  int *shared_max_indices = (int *)&shared_max_vals[blockDim.x];
  int tid = threadIdx.x;

  // 1. Find the maximum value and its index, staging per-thread results in shared memory.
  float max_val = -FLT_MAX;
  int max_idx = -1;
  for (int i = tid; i < length; i += blockDim.x) {
    if (predict[i] > max_val) {
      max_val = predict[i];
      max_idx = i;
    }
  }
  shared_max_vals[tid] = max_val;
  shared_max_indices[tid] = max_idx;
  __syncthreads();

  // Reduce to the global maximum and its index across all threads.
  if (tid == 0) {
    for (int i = 1; i < blockDim.x; i++) {
      if (shared_max_vals[i] > shared_max_vals[0]) {
        shared_max_vals[0] = shared_max_vals[i];
        shared_max_indices[0] = shared_max_indices[i];
      }
    }
    *max_index = shared_max_indices[0];
  }
  __syncthreads();
  max_val = shared_max_vals[0];
  // Ensure every thread has read the global max before shared memory is reused for the sum.
  __syncthreads();

  // 2. Exponentiate and accumulate the per-thread partial sums.
  float sum_exp = 0.0f;
  for (int i = tid; i < length; i += blockDim.x) {
    predict[i] = expf(predict[i] - max_val);
    sum_exp += predict[i];
  }
  shared_max_vals[tid] = sum_exp;
  __syncthreads();

  // Reduce the partial sums across all threads.
  if (tid == 0) {
    for (int i = 1; i < blockDim.x; i++) {
      shared_max_vals[0] += shared_max_vals[i];
    }
  }
  __syncthreads();
  float total_sum = shared_max_vals[0];

  // 3. Divide each element by the total to obtain softmax probabilities.
  for (int i = tid; i < length; i += blockDim.x) {
    predict[i] /= total_sum;
  }
}
static void classifier_softmax(float *predict, int length, int *max_index, cudaStream_t stream) {
  int block_size = 256;
  // The kernel needs block_size floats (values) plus block_size ints (indices) of shared memory.
  checkKernel(softmax<<<1, block_size, block_size * (sizeof(float) + sizeof(int)), stream>>>(
      predict, length, max_index));
}
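
// InferImpl wires the pieces together: adjust_memory() sizes the GPU/CPU buffers for the batch,
// preprocess() uploads each image and warps it into its slot of the planar input tensor, and
// forwards() runs the TensorRT engine, applies the per-image softmax/argmax on the GPU, and
// copies the probabilities and class indices back to the host.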
class InferImpl : public Infer {
 public:
  shared_ptr<trt::Infer> trt_;
  string engine_file_;
  vector<shared_ptr<trt::Memory<unsigned char>>> preprocess_buffers_;
  trt::Memory<float> input_buffer_, output_array_;
  trt::Memory<int> classes_indices_;
  int network_input_width_, network_input_height_;
  Norm normalize_;
  int num_classes_ = 0;
  bool isdynamic_model_ = false;

  virtual ~InferImpl() = default;

  void adjust_memory(int batch_size) {
    // the inference batch_size
    size_t input_numel = network_input_width_ * network_input_height_ * 3;
    input_buffer_.gpu(batch_size * input_numel);
    output_array_.gpu(batch_size * num_classes_);
    output_array_.cpu(batch_size * num_classes_);
    classes_indices_.gpu(batch_size);
    classes_indices_.cpu(batch_size);
    if ((int)preprocess_buffers_.size() < batch_size) {
      for (int i = preprocess_buffers_.size(); i < batch_size; ++i)
        preprocess_buffers_.push_back(make_shared<trt::Memory<unsigned char>>());
    }
  }
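
  // Per-image preprocessing. The workspace is laid out as [d2i affine matrix (padded to 32 bytes) |
  // raw BGR image bytes] on both the host and device sides, so a single buffer covers the staging
  // copies before the warp-affine kernel fills this image's slot of the network input tensor.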
  void preprocess(int ibatch, const Image &image,
                  shared_ptr<trt::Memory<unsigned char>> preprocess_buffer,
                  void *stream = nullptr) {
    AffineMatrix affine;
    affine.compute(make_tuple(image.width, image.height),
                   make_tuple(network_input_width_, network_input_height_));
    size_t input_numel = network_input_width_ * network_input_height_ * 3;
    float *input_device = input_buffer_.gpu() + ibatch * input_numel;
    size_t size_image = image.width * image.height * 3;
    size_t size_matrix = upbound(sizeof(affine.d2i), 32);

    uint8_t *gpu_workspace = preprocess_buffer->gpu(size_matrix + size_image);
    float *affine_matrix_device = (float *)gpu_workspace;
    uint8_t *image_device = gpu_workspace + size_matrix;

    uint8_t *cpu_workspace = preprocess_buffer->cpu(size_matrix + size_image);
    float *affine_matrix_host = (float *)cpu_workspace;
    uint8_t *image_host = cpu_workspace + size_matrix;

    // Stage through the host workspace, then upload asynchronously on the caller's stream.
    cudaStream_t stream_ = (cudaStream_t)stream;
    memcpy(image_host, image.bgrptr, size_image);
    memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
    checkRuntime(
        cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream_));
    checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i),
                                 cudaMemcpyHostToDevice, stream_));

    warp_affine_bilinear_and_normalize_plane(image_device, image.width * 3, image.width,
                                             image.height, input_device, network_input_width_,
                                             network_input_height_, affine_matrix_device, 114,
                                             normalize_, stream_);
  }
  bool load(const string &engine_file) {
    trt_ = trt::load(engine_file);
    if (trt_ == nullptr) return false;

    trt_->print();
    auto input_dim = trt_->static_dims(0);
    network_input_width_ = input_dim[3];
    network_input_height_ = input_dim[2];
    isdynamic_model_ = trt_->has_dynamic_dim();

    // normalize_ = Norm::alpha_beta(1 / 255.0f, 0.0f, ChannelType::SwapRB);
    // ImageNet statistics: mean [0.485, 0.456, 0.406], std [0.229, 0.224, 0.225]
    float mean[3] = {0.485f, 0.456f, 0.406f};
    float std[3] = {0.229f, 0.224f, 0.225f};
    normalize_ = Norm::mean_std(mean, std, 1 / 255.0f, ChannelType::SwapRB);
    num_classes_ = trt_->static_dims(1)[1];
    return true;
  }
  virtual Attribute forward(const Image &image, void *stream = nullptr) override {
    auto output = forwards({image}, stream);
    if (output.empty()) return {};
    return output[0];
  }
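
  // Batched inference: for dynamic-shape engines the run dims are set to the actual image count;
  // for static-shape engines the batch must not exceed the engine's fixed batch size. After the
  // TensorRT forward pass, each image's logits are turned into probabilities and an argmax index
  // on the GPU, and both are copied back once the stream has been synchronized.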
  virtual vector<Attribute> forwards(const vector<Image> &images, void *stream = nullptr) override {
    int num_image = images.size();
    if (num_image == 0) return {};

    auto input_dims = trt_->static_dims(0);
    int infer_batch_size = input_dims[0];
    if (infer_batch_size != num_image) {
      if (isdynamic_model_) {
        infer_batch_size = num_image;
        input_dims[0] = num_image;
        if (!trt_->set_run_dims(0, input_dims)) return {};
      } else {
        if (infer_batch_size < num_image) {
          INFO(
              "When using a static-shape model, the number of images[%d] must be "
              "less than or equal to the maximum batch size[%d].",
              num_image, infer_batch_size);
          return {};
        }
      }
    }
    adjust_memory(infer_batch_size);

    cudaStream_t stream_ = (cudaStream_t)stream;
    for (int i = 0; i < num_image; ++i)
      preprocess(i, images[i], preprocess_buffers_[i], stream);

    float *output_array_device = output_array_.gpu();
    vector<void *> bindings{input_buffer_.gpu(), output_array_device};
    if (!trt_->forward(bindings, stream)) {
      INFO("TensorRT forward failed.");
      return {};
    }

    for (int ib = 0; ib < num_image; ++ib) {
      float *output_array_device = output_array_.gpu() + ib * num_classes_;
      int *classes_indices_device = classes_indices_.gpu() + ib;
      classifier_softmax(output_array_device, num_classes_, classes_indices_device, stream_);
    }

    checkRuntime(cudaMemcpyAsync(output_array_.cpu(), output_array_.gpu(),
                                 output_array_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    checkRuntime(cudaMemcpyAsync(classes_indices_.cpu(), classes_indices_.gpu(),
                                 classes_indices_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    checkRuntime(cudaStreamSynchronize(stream_));

    vector<Attribute> arrout;
    arrout.reserve(num_image);
    for (int ib = 0; ib < num_image; ++ib) {
      float *output_array_cpu = output_array_.cpu() + ib * num_classes_;
      int *max_index = classes_indices_.cpu() + ib;
      int index = *max_index;
      float max_score = output_array_cpu[index];
      arrout.emplace_back(max_score, index);
    }
    return arrout;
  }
};

Infer *loadraw(const std::string &engine_file) {
  InferImpl *impl = new InferImpl();
  if (!impl->load(engine_file)) {
    delete impl;
    impl = nullptr;
  }
  return impl;
}

shared_ptr<Infer> load(const string &engine_file) {
  return std::shared_ptr<InferImpl>((InferImpl *)loadraw(engine_file));
}

}  // namespace resnet
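
// Usage sketch (illustrative only, not part of this file). It assumes Image can be built from a
// BGR uint8_t buffer plus width/height, matching the fields used in preprocess() above, and that
// an engine file such as "resnet50.engine" exists; check infer.hpp/resnet.hpp for the exact
// Image/Attribute definitions before relying on this.
//
//   auto model = resnet::load("resnet50.engine");
//   if (model) {
//     resnet::Image img{/* bgr data ptr */ frame_data, frame_width, frame_height};
//     resnet::Attribute result = model->forward(img);  // top-1 class index and its probability
//   }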