123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- #include "infer/trt/classfier/classifier.hpp"
- #include <mutex>
- #include <vector>
- #include <algorithm>
- namespace cls
- {
/**
 * In-place softmax over one vector of logits, plus the argmax of the raw logits.
 *
 * Launch layout: the kernel never reads blockIdx, so it is meant to be launched
 * with a single block. Threads walk the vector with a step of blockDim.x, so any
 * block size covers any `length`.
 *
 * Dynamic shared memory required: blockDim.x floats followed by blockDim.x ints,
 * i.e. blockDim.x * (sizeof(float) + sizeof(int)) bytes.
 *
 * @param predict    `length` logits, dereferenced on the device; overwritten with
 *                   the softmax probabilities.
 * @param length     number of elements in `predict`.
 * @param max_index  receives the index of the largest logit. It is -1 when no
 *                   element compared greater than -FLT_MAX (e.g. length <= 0).
 */
static __global__ void softmax(float *predict, int length, int *max_index) {
    extern __shared__ float shared_data[];
    float *shared_max_vals = shared_data;
    int *shared_max_indices = (int *)&shared_max_vals[blockDim.x];

    int tid = threadIdx.x;

    // 1. Each thread finds the maximum (and its index) over the elements it owns
    //    and publishes the pair to shared memory.
    float max_val = -FLT_MAX;
    int max_idx = -1;
    for (int i = tid; i < length; i += blockDim.x) {
        if (predict[i] > max_val) {
            max_val = predict[i];
            max_idx = i;
        }
    }
    shared_max_vals[tid] = max_val;
    shared_max_indices[tid] = max_idx;
    __syncthreads();

    // Thread 0 folds the per-thread candidates into the block-wide maximum.
    if (tid == 0) {
        for (int i = 1; i < blockDim.x; i++) {
            if (shared_max_vals[i] > shared_max_vals[0]) {
                shared_max_vals[0] = shared_max_vals[i];
                shared_max_indices[0] = shared_max_indices[i];
            }
        }
        *max_index = shared_max_indices[0];
    }
    __syncthreads();
    max_val = shared_max_vals[0];
    // Every thread must have read the block maximum before the same shared slots
    // are reused for the partial sums below; without this barrier thread 0 could
    // overwrite shared_max_vals[0] while another warp has not read it yet.
    __syncthreads();

    // 2. Exponentiate (shifted by the maximum for numerical stability) and
    //    accumulate a per-thread partial sum.
    float sum_exp = 0.0f;
    for (int i = tid; i < length; i += blockDim.x) {
        predict[i] = expf(predict[i] - max_val);
        sum_exp += predict[i];
    }
    shared_max_vals[tid] = sum_exp;
    __syncthreads();

    // Thread 0 adds up the partial sums of all threads.
    if (tid == 0) {
        for (int i = 1; i < blockDim.x; i++) {
            shared_max_vals[0] += shared_max_vals[i];
        }
    }
    __syncthreads();
    float total_sum = shared_max_vals[0];

    // 3. Normalise: divide every element by the total to obtain the softmax value.
    for (int i = tid; i < length; i += blockDim.x) {
        predict[i] /= total_sum;
    }
}
/**
 * Enqueue the single-block softmax kernel for one logit vector on `stream`.
 *
 * @param predict    logits passed to the kernel; overwritten with probabilities.
 * @param length     number of logits.
 * @param max_index  output slot passed to the kernel for the argmax index.
 * @param stream     CUDA stream the kernel is launched on.
 */
static void classfier_softmax(float *predict, int length, int *max_index, cudaStream_t stream) {
    const int block_size = 256;
    // The kernel carves its dynamic shared memory into block_size floats (values)
    // followed by block_size ints (indices), so both arrays must be paid for.
    const size_t shared_bytes = block_size * (sizeof(float) + sizeof(int));
    checkKernel(softmax<<<1, block_size, shared_bytes, stream>>>(predict, length, max_index));
}
-
/**
 * Load a TensorRT engine and cache the values derived from its bindings.
 *
 * @param engine_file  path handed to TensorRT::load.
 * @param gpu_id       stored in device_id_ (also when loading fails).
 * @return true when the engine was loaded, false when TensorRT::load gave nullptr.
 */
bool ClassifierModelImpl::load(const std::string &engine_file, int gpu_id)
{
    trt_       = TensorRT::load(engine_file);
    device_id_ = gpu_id;

    if (trt_ != nullptr)
    {
        trt_->print();

        // Binding 0: element [2] is stored as the input height, [3] as the width.
        auto input_shape      = trt_->static_dims(0);
        network_input_width_  = input_shape[3];
        network_input_height_ = input_shape[2];

        isdynamic_model_ = trt_->has_dynamic_dim();

        // Binding 1: element [1] is stored as the number of classes.
        auto output_shape = trt_->static_dims(1);
        num_classes_      = output_shape[1];

        // Per-channel mean / std normalisation applied after scaling by 1/255,
        // with R and B swapped.
        float channel_mean[3] = {0.485f, 0.456f, 0.406f};
        float channel_std[3]  = {0.229f, 0.224f, 0.225f};
        normalize_ = affine::Norm::mean_std(channel_mean, channel_std, 1.0f / 255.0f,
                                            affine::ChannelType::SwapRB);
        // Alternative kept from the original source (plain 1/255 scaling):
        // normalize_ = affine::Norm::alpha_beta(1 / 255.0f, 0.0f, affine::ChannelType::SwapRB);
        return true;
    }
    return false;
}
/**
 * Warp one crop into slot `ibatch` of the network input buffer.
 *
 * Computes the crop/resize matrix for a (w, h) source region mapped onto the
 * network input size, uploads matrix.d2i to the device, runs the bilinear
 * warp + normalisation into input_buffer_, then waits for the stream.
 *
 * @param ibatch  index of the destination slot inside input_buffer_.
 * @param matrix  scratch matrix object; its contents are recomputed here.
 * @param x, y    NOTE(review): not used anywhere in this function, so the crop
 *                origin is never applied here - confirm whether
 *                CropResizeMatrix::compute should receive it.
 * @param w, h    size of the source region passed to matrix.compute.
 * @param stream  CUDA stream handle passed as void*; cast to cudaStream_t.
 */
void ClassifierModelImpl::preprocess(int ibatch, affine::CropResizeMatrix& matrix, int x, int y, int w, int h, void *stream)
{
    matrix.compute(
        std::make_tuple(w, h),
        std::make_tuple(network_input_width_, network_input_height_));

    // Widen before multiplying so the element count is not computed in narrow
    // integer arithmetic.
    const size_t input_numel =
        static_cast<size_t>(network_input_width_) * network_input_height_ * 3;
    float *input_device = input_buffer_.gpu() + ibatch * input_numel;
    uint8_t *image_device = preprocess_buffer_.gpu();
    float *affine_matrix_device = affine_matrix_.gpu();
    float *affine_matrix_host = affine_matrix_.cpu();

    cudaStream_t stream_ = (cudaStream_t)stream;
    memcpy(affine_matrix_host, matrix.d2i, sizeof(matrix.d2i));
    checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(matrix.d2i),
                                 cudaMemcpyHostToDevice, stream_));
    // NOTE(review): `image` is neither a parameter nor a local of this function.
    // Presumably it is meant to be the source frame whose pixels were uploaded to
    // preprocess_buff_ by the caller; its width/height must be made available here
    // (member or extra parameter) - confirm against the class declaration.
    affine::warp_affine_bilinear_and_normalize_plane(image_device, image.width * 3, image.width,
                                                     image.height, input_device, network_input_width_,
                                                     network_input_height_, affine_matrix_device, 114,
                                                     normalize_, stream_);
    // The host/device matrix buffers are reused by the next call, so wait for
    // this crop's copy and warp to finish before returning.
    checkRuntime(cudaStreamSynchronize(stream_));
}
/**
 * Whole-image inference overload; this classifier does no work here.
 *
 * `virtual` may only appear on the in-class declaration, so it is omitted on
 * this out-of-class definition.
 *
 * @return a default-constructed Result.
 *         NOTE(review): assumes Result is default-constructible - confirm.
 */
Result ClassifierModelImpl::forward(const tensor::Image &image, void *stream)
{
    return Result{};
}
/**
 * Sliced-image inference overload; this classifier does no work here.
 *
 * `virtual` may only appear on the in-class declaration, so it is omitted on
 * this out-of-class definition.
 *
 * @return a default-constructed Result.
 *         NOTE(review): assumes Result is default-constructible - confirm.
 */
Result ClassifierModelImpl::forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream)
{
    return Result{};
}
/**
 * Classify the boxes whose label appears in class_names_ and overwrite each
 * such box's label with the predicted class name.
 *
 * Steps: select matching boxes, configure the batch size, upload the image,
 * warp every box into the network input, run TensorRT, run softmax/argmax per
 * box on the device, copy the results back and relabel the boxes in place.
 * The whole call is serialised by mutex_.
 *
 * @param image   source image; image.bgrptr / width / height are read.
 * @param boxes   boxes to inspect; labels of classified boxes are modified.
 * @param stream  CUDA stream handle passed as void*; cast to cudaStream_t.
 * @return a default-constructed Result on every path; the actual output is
 *         delivered through `boxes`.
 *         NOTE(review): the original mixed bare `return;` with
 *         `return cv::Mat();` and had no return on success. Confirm that a
 *         default-constructed Result is what callers expect.
 */
Result ClassifierModelImpl::forward(const tensor::Image &image, data::BoxArray& boxes, void *stream)
{
    std::lock_guard<std::mutex> lock(mutex_);

    // Keep only the boxes whose current label is one this classifier handles.
    std::vector<data::Box*> classfier_boxes_ptr;
    for (auto& box : boxes)
    {
        if (std::find(class_names_.begin(), class_names_.end(), box.label) != class_names_.end())
        {
            classfier_boxes_ptr.push_back(&box);
        }
    }
    int num_image = static_cast<int>(classfier_boxes_ptr.size());
    if (num_image == 0) { return Result{}; }

    auto input_dims = trt_->static_dims(0);
    int infer_batch_size = input_dims[0];
    if (infer_batch_size != num_image)
    {
        if (isdynamic_model_)
        {
            // Dynamic engines are resized to exactly the number of crops.
            infer_batch_size = num_image;
            input_dims[0] = num_image;
            if (!trt_->set_run_dims(0, input_dims))
            {
                printf("Fail to set run dims\n");
                return Result{};
            }
        }
        else
        {
            // Static engines can run a partial batch but never a larger one.
            if (infer_batch_size < num_image)
            {
                printf(
                    "When using static shape model, number of images[%d] must be "
                    "less than or equal to the maximum batch[%d].",
                    num_image, infer_batch_size);
                return Result{};
            }
        }
    }
    adjust_memory(num_image, image.width, image.height);

    // Upload the full source image once; every crop is warped from it.
    uint8_t *image_device = preprocess_buffer_.gpu();
    uint8_t *image_host = preprocess_buffer_.cpu();
    size_t size_image = static_cast<size_t>(image.width) * image.height * 3;
    cudaStream_t stream_ = (cudaStream_t)stream;
    memcpy(image_host, image.bgrptr, size_image);
    checkRuntime(
        cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream_));

    affine::CropResizeMatrix crmatrix;
    for (int i = 0; i < num_image; i++)
    {
        data::Box* box_ptr = classfier_boxes_ptr[i];
        int x = (int)box_ptr->left;
        int y = (int)box_ptr->top;
        int w = (int)box_ptr->right - x;
        int h = (int)box_ptr->bottom - y;
        preprocess(i, crmatrix, x, y, w, h, stream);
    }

#ifdef TRT10
    if (!trt_->forward(std::unordered_map<std::string, const void *>{
            { "input", input_buffer_.gpu() },
            { "output", output_buffer_.gpu() }
        }, stream_))
    {
        printf("Failed to tensorRT forward.\n");
        return Result{};
    }
#else
    std::vector<void *> bindings{input_buffer_.gpu(), output_buffer_.gpu()};
    if (!trt_->forward(bindings, stream))
    {
        printf("Failed to tensorRT forward.");
        return Result{};
    }
#endif

    // One softmax/argmax launch per crop, each on its own row of the output.
    for (int ib = 0; ib < num_image; ++ib)
    {
        float *output_buffer_device = output_buffer_.gpu() + ib * num_classes_;
        int *classes_indices_device = classes_indices_.gpu() + ib;
        classfier_softmax(output_buffer_device, num_classes_, classes_indices_device, stream_);
    }
    checkRuntime(cudaMemcpyAsync(output_buffer_.cpu(), output_buffer_.gpu(),
                                 output_buffer_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    checkRuntime(cudaMemcpyAsync(classes_indices_.cpu(), classes_indices_.gpu(),
                                 classes_indices_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
    // Host buffers are only valid after the stream has drained.
    checkRuntime(cudaStreamSynchronize(stream_));

    for (int ib = 0; ib < num_image; ++ib)
    {
        int index = classes_indices_.cpu()[ib];
        // The kernel reports -1 when it found no maximum; never index
        // class_names_ with an out-of-range value.
        if (index < 0 || static_cast<size_t>(index) >= class_names_.size())
        {
            printf("Classifier produced an invalid class index[%d] for box %d.\n", index, ib);
            continue;
        }
        classfier_boxes_ptr[ib]->label = class_names_[index];
    }
    return Result{};
}
- }
|