leon 5 kuukautta sitten
commit
7462d842e1
6 muutettua tiedostoa jossa 1033 lisäystä ja 0 poistoa
  1. 59 0
      .vscode/settings.json
  2. 444 0
      src/infer.cu
  3. 98 0
      src/infer.hpp
  4. 27 0
      src/main.cpp
  5. 364 0
      src/resnet.cu
  6. 41 0
      src/resnet.hpp

+ 59 - 0
.vscode/settings.json

@@ -0,0 +1,59 @@
+{
+    "files.associations": {
+        "vector": "cpp",
+        "__bit_reference": "cpp",
+        "__hash_table": "cpp",
+        "__locale": "cpp",
+        "__node_handle": "cpp",
+        "__split_buffer": "cpp",
+        "__verbose_abort": "cpp",
+        "array": "cpp",
+        "bitset": "cpp",
+        "cctype": "cpp",
+        "charconv": "cpp",
+        "clocale": "cpp",
+        "cmath": "cpp",
+        "cstdarg": "cpp",
+        "cstddef": "cpp",
+        "cstdint": "cpp",
+        "cstdio": "cpp",
+        "cstdlib": "cpp",
+        "cstring": "cpp",
+        "ctime": "cpp",
+        "cwchar": "cpp",
+        "cwctype": "cpp",
+        "deque": "cpp",
+        "execution": "cpp",
+        "memory": "cpp",
+        "forward_list": "cpp",
+        "fstream": "cpp",
+        "future": "cpp",
+        "initializer_list": "cpp",
+        "iomanip": "cpp",
+        "ios": "cpp",
+        "iosfwd": "cpp",
+        "istream": "cpp",
+        "limits": "cpp",
+        "locale": "cpp",
+        "mutex": "cpp",
+        "new": "cpp",
+        "optional": "cpp",
+        "print": "cpp",
+        "queue": "cpp",
+        "ratio": "cpp",
+        "sstream": "cpp",
+        "stack": "cpp",
+        "stdexcept": "cpp",
+        "streambuf": "cpp",
+        "string": "cpp",
+        "string_view": "cpp",
+        "typeinfo": "cpp",
+        "unordered_map": "cpp",
+        "variant": "cpp",
+        "algorithm": "cpp",
+        "iterator": "cpp",
+        "tuple": "cpp",
+        "utility": "cpp",
+        "type_traits": "cpp"
+    }
+}

+ 444 - 0
src/infer.cu

@@ -0,0 +1,444 @@
+
+#include <NvInfer.h>
+#include <cuda_runtime.h>
+#include <stdarg.h>
+
+#include <fstream>
+#include <numeric>
+#include <sstream>
+#include <unordered_map>
+
+#include "infer.hpp"
+
+namespace trt {
+
+using namespace std;
+using namespace nvinfer1;
+
+// checkRuntime(call): evaluates a CUDA runtime call once. When the returned
+// status differs from cudaSuccess, the failing expression, the error string,
+// the error name and the numeric code are logged through INFO and the process
+// is aborted.
+#define checkRuntime(call)                                                        \
+  do {                                                                            \
+    auto trt_runtime_status_ = (call);                                            \
+    if (trt_runtime_status_ == cudaSuccess) break;                                \
+    INFO("CUDA Runtime error💥 %s # %s, code = %s [ %d ]", #call,                  \
+         cudaGetErrorString(trt_runtime_status_),                                 \
+         cudaGetErrorName(trt_runtime_status_), trt_runtime_status_);             \
+    abort();                                                                      \
+  } while (0)
+
+// checkKernel(...): evaluates the given expression (normally a kernel launch)
+// and then passes cudaPeekAtLastError() through checkRuntime, so an error
+// reported right after the launch aborts the process.
+#define checkKernel(...)                   \
+  do {                                     \
+    {                                      \
+      (__VA_ARGS__);                       \
+    }                                      \
+    checkRuntime(cudaPeekAtLastError());   \
+  } while (0)
+
+// Assert(op): aborts the process after logging when `op` evaluates to false.
+// The condition text is passed as a "%s" argument rather than being spliced
+// into the format string, so a '%' inside the condition (e.g. `a % b == 0`)
+// cannot be misread as a conversion specifier. The local flag has a
+// macro-specific name so that `Assert(cond)` on a caller variable named `cond`
+// does not end up initialising the flag from itself.
+#define Assert(op)                               \
+  do {                                           \
+    const bool trt_assert_ok_ = !(!(op));        \
+    if (!trt_assert_ok_) {                       \
+      INFO("Assert failed, %s", #op);            \
+      abort();                                   \
+    }                                            \
+  } while (0)
+
+// Assertf(op, fmt, ...): like Assert, with an additional printf-style message.
+// The message is rendered into a fixed 1024-byte buffer first (longer messages
+// are truncated) and then logged together with the condition text.
+#define Assertf(op, ...)                                                  \
+  do {                                                                    \
+    const bool trt_assert_ok_ = !(!(op));                                 \
+    if (!trt_assert_ok_) {                                                \
+      char trt_assert_msg_[1024];                                         \
+      snprintf(trt_assert_msg_, sizeof(trt_assert_msg_), __VA_ARGS__);    \
+      INFO("Assert failed, %s : %s", #op, trt_assert_msg_);               \
+      abort();                                                            \
+    }                                                                     \
+  } while (0)
+
+// Returns the last component of `path`, i.e. the text following the final '/'
+// or '\\'. With include_suffix == false everything from the last '.' onwards
+// is dropped, unless that '.' sits at or before the first character of the
+// component, in which case the component is returned unchanged.
+static string file_name(const string &path, bool include_suffix) {
+  if (path.empty()) return "";
+
+  // First character after the last separator; 0 when there is no separator.
+  const size_t sep = path.find_last_of("/\\");
+  const size_t begin = (sep == string::npos) ? 0 : sep + 1;
+  if (include_suffix) return path.substr(begin);
+
+  const size_t dot = path.rfind('.');
+  if (dot == string::npos) return path.substr(begin);
+
+  // A dot that does not come after the component's first character is not
+  // treated as an extension separator.
+  const size_t end = (dot <= begin) ? path.size() : dot;
+  return path.substr(begin, end - begin);
+}
+
+// Formats "[<file name>:<line>]: <message>" into a 2048-byte stack buffer and
+// writes it, followed by a newline, to stdout. `fmt` and the variadic
+// arguments follow printf conventions; output that does not fit is truncated.
+void __log_func(const char *file, int line, const char *fmt, ...) {
+  char buffer[2048];
+  const string filename = file_name(file, true);
+
+  // snprintf returns the length it *wanted* to write (or a negative value on
+  // error); clamp it so the message offset always stays inside the buffer.
+  int n = snprintf(buffer, sizeof(buffer), "[%s:%d]: ", filename.c_str(), line);
+  if (n < 0) {
+    buffer[0] = '\0';
+    n = 0;
+  } else if (n >= (int)sizeof(buffer)) {
+    n = (int)sizeof(buffer) - 1;
+  }
+
+  va_list vl;
+  va_start(vl, fmt);
+  vsnprintf(buffer + n, sizeof(buffer) - n, fmt, vl);
+  va_end(vl);  // every va_start must be paired with va_end
+
+  fprintf(stdout, "%s\n", buffer);
+}
+
+// Renders the first nbDims extents of `shape` as decimal numbers joined by
+// 'x' (e.g. "1x3x224x224"); an empty string when nbDims is 0.
+static std::string format_shape(const Dims &shape) {
+  std::string text;
+  for (int i = 0; i < shape.nbDims; ++i) {
+    if (i > 0) text += 'x';
+    text += std::to_string(shape.d[i]);
+  }
+  return text;
+}
+
+// Creates the two CUDA events used as start and stop markers. The handles are
+// stored in the type-erased void* members start_ and stop_.
+Timer::Timer() {
+  cudaEvent_t *begin_slot = reinterpret_cast<cudaEvent_t *>(&start_);
+  cudaEvent_t *end_slot = reinterpret_cast<cudaEvent_t *>(&stop_);
+  checkRuntime(cudaEventCreate(begin_slot));
+  checkRuntime(cudaEventCreate(end_slot));
+}
+
+// Destroys both events created by the constructor; aborts via checkRuntime if
+// either destruction reports an error.
+Timer::~Timer() {
+  cudaEvent_t begin_event = static_cast<cudaEvent_t>(start_);
+  cudaEvent_t end_event = static_cast<cudaEvent_t>(stop_);
+  checkRuntime(cudaEventDestroy(begin_event));
+  checkRuntime(cudaEventDestroy(end_event));
+}
+
+// Remembers `stream` (a type-erased cudaStream_t) for the matching stop() call
+// and records the start event on it.
+void Timer::start(void *stream) {
+  stream_ = stream;
+  cudaStream_t cuda_stream = static_cast<cudaStream_t>(stream_);
+  checkRuntime(cudaEventRecord(static_cast<cudaEvent_t>(start_), cuda_stream));
+}
+
+// Records the stop event on the stream given to start(), blocks until that
+// event has completed and returns the elapsed time between the two events in
+// milliseconds. When `print` is true the value is also written to stdout,
+// labelled with `prefix`.
+float Timer::stop(const char *prefix, bool print) {
+  cudaEvent_t begin_event = static_cast<cudaEvent_t>(start_);
+  cudaEvent_t end_event = static_cast<cudaEvent_t>(stop_);
+
+  checkRuntime(cudaEventRecord(end_event, static_cast<cudaStream_t>(stream_)));
+  checkRuntime(cudaEventSynchronize(end_event));
+
+  float elapsed_ms = 0;
+  checkRuntime(cudaEventElapsedTime(&elapsed_ms, begin_event, end_event));
+
+  if (!print) return elapsed_ms;
+  printf("[%s]: %.5f ms\n", prefix, elapsed_ms);
+  return elapsed_ms;
+}
+
+// Builds a BaseMemory that wraps caller-provided buffers; all of the work is
+// delegated to reference().
+BaseMemory::BaseMemory(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes) {
+  this->reference(cpu, cpu_bytes, gpu, gpu_bytes);
+}
+
+// Releases whatever this object currently holds and then adopts the given
+// host/device buffers without taking ownership of them. A side whose pointer
+// is null or whose size is 0 is treated as absent: it is stored as
+// (nullptr, 0) and the object is flagged as owner of that side, so a later
+// *_realloc allocation for it is freed by release_*().
+void BaseMemory::reference(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes) {
+  release();
+
+  const bool has_cpu = (cpu != nullptr) && (cpu_bytes != 0);
+  const bool has_gpu = (gpu != nullptr) && (gpu_bytes != 0);
+
+  cpu_ = has_cpu ? cpu : nullptr;
+  cpu_capacity_ = has_cpu ? cpu_bytes : 0;
+  cpu_bytes_ = cpu_capacity_;
+
+  gpu_ = has_gpu ? gpu : nullptr;
+  gpu_capacity_ = has_gpu ? gpu_bytes : 0;
+  gpu_bytes_ = gpu_capacity_;
+
+  // Externally supplied buffers are never freed by this object.
+  owner_cpu_ = !has_cpu;
+  owner_gpu_ = !has_gpu;
+}
+
+// Frees the owned host and device buffers (see release()).
+BaseMemory::~BaseMemory() {
+  this->release();
+}
+
+// Ensures the device buffer can hold at least `bytes` bytes and returns it.
+// The buffer is only re-allocated (with cudaMalloc) when the current capacity
+// is too small; the previous contents are not preserved in that case. The
+// logical size reported by gpu_bytes() is always set to `bytes`.
+void *BaseMemory::gpu_realloc(size_t bytes) {
+  if (gpu_capacity_ < bytes) {
+    release_gpu();
+
+    checkRuntime(cudaMalloc(&gpu_, bytes));
+    gpu_capacity_ = bytes;
+    // The new buffer was allocated here, so it must be freed here as well,
+    // even if the object previously only referenced an external buffer
+    // (owner_gpu_ == false). Without this the allocation would leak.
+    owner_gpu_ = true;
+  }
+  gpu_bytes_ = bytes;
+  return gpu_;
+}
+
+// Ensures the host buffer can hold at least `bytes` bytes and returns it.
+// The buffer is only re-allocated (with cudaMallocHost) when the current
+// capacity is too small; the previous contents are not preserved in that
+// case. The logical size reported by cpu_bytes() is always set to `bytes`.
+void *BaseMemory::cpu_realloc(size_t bytes) {
+  if (cpu_capacity_ < bytes) {
+    release_cpu();
+
+    checkRuntime(cudaMallocHost(&cpu_, bytes));
+    Assert(cpu_ != nullptr);
+    cpu_capacity_ = bytes;
+    // The new buffer was allocated here, so it must be freed here as well,
+    // even if the object previously only referenced an external buffer
+    // (owner_cpu_ == false). Without this the allocation would leak.
+    owner_cpu_ = true;
+  }
+  cpu_bytes_ = bytes;
+  return cpu_;
+}
+
+// Drops the host buffer: frees it with cudaFreeHost when this object owns it,
+// then resets pointer, capacity and size to empty.
+void BaseMemory::release_cpu() {
+  if (cpu_ != nullptr && owner_cpu_) {
+    checkRuntime(cudaFreeHost(cpu_));
+  }
+  cpu_ = nullptr;
+  cpu_bytes_ = 0;
+  cpu_capacity_ = 0;
+}
+
+// Drops the device buffer: frees it with cudaFree when this object owns it,
+// then resets pointer, capacity and size to empty.
+void BaseMemory::release_gpu() {
+  if (gpu_ != nullptr && owner_gpu_) {
+    checkRuntime(cudaFree(gpu_));
+  }
+  gpu_ = nullptr;
+  gpu_bytes_ = 0;
+  gpu_capacity_ = 0;
+}
+
+// Releases the host buffer first and the device buffer second.
+void BaseMemory::release() {
+  this->release_cpu();
+  this->release_gpu();
+}
+
+// ILogger implementation handed to TensorRT. Internal errors are logged and
+// abort the process, ordinary errors are logged, every other severity is
+// silently dropped.
+class __native_nvinfer_logger : public ILogger {
+ public:
+  virtual void log(Severity severity, const char *msg) noexcept override {
+    switch (severity) {
+      case Severity::kINTERNAL_ERROR:
+        INFO("NVInfer INTERNAL_ERROR: %s", msg);
+        abort();
+      case Severity::kERROR:
+        INFO("NVInfer: %s", msg);
+        break;
+      default:
+        // Warnings and informational messages are intentionally not printed.
+        break;
+    }
+  }
+};
+// Single logger instance shared by every runtime created in this file.
+static __native_nvinfer_logger gLogger;
+
+// shared_ptr deleter for TensorRT objects: calls destroy() on non-null
+// pointers and ignores null ones.
+template <typename _T>
+static void destroy_nvidia_pointer(_T *ptr) {
+  if (ptr == nullptr) return;
+  ptr->destroy();
+}
+
+// Reads the whole file at `file` in binary mode and returns its bytes.
+// Returns an empty vector when the file cannot be opened, its size cannot be
+// determined, it is empty, or fewer bytes than expected could be read.
+static std::vector<uint8_t> load_file(const string &file) {
+  ifstream in(file, ios::in | ios::binary);
+  if (!in.is_open()) return {};
+
+  in.seekg(0, ios::end);
+  // tellg() reports failure as -1; keep the value signed so the failure is
+  // not turned into an enormous unsigned length.
+  const streamoff length = in.tellg();
+  if (length <= 0) return {};
+
+  in.seekg(0, ios::beg);
+  std::vector<uint8_t> data(static_cast<size_t>(length));
+  in.read(reinterpret_cast<char *>(data.data()), length);
+
+  // A short read would leave a partially filled buffer behind; report it as
+  // a failed load instead of returning truncated data.
+  if (in.gcount() != length) return {};
+  return data;
+}
+
+// Bundles the TensorRT runtime, the deserialized engine and one execution
+// context. Each object is held in a shared_ptr whose deleter calls destroy().
+class __native_engine_context {
+ public:
+  virtual ~__native_engine_context() { destroy(); }
+
+  // Rebuilds runtime, engine and execution context from the serialized engine
+  // in [pdata, pdata + size). Anything held before is destroyed first.
+  // Returns false as soon as the input is empty or one creation step yields
+  // a null pointer.
+  bool construct(const void *pdata, size_t size) {
+    destroy();
+
+    if (pdata == nullptr || size == 0) return false;
+
+    runtime_.reset(createInferRuntime(gLogger), destroy_nvidia_pointer<IRuntime>);
+    if (!runtime_) return false;
+
+    engine_.reset(runtime_->deserializeCudaEngine(pdata, size, nullptr),
+                  destroy_nvidia_pointer<ICudaEngine>);
+    if (!engine_) return false;
+
+    context_.reset(engine_->createExecutionContext(),
+                   destroy_nvidia_pointer<IExecutionContext>);
+    return static_cast<bool>(context_);
+  }
+
+ private:
+  // Drops the objects in the order context, engine, runtime.
+  void destroy() {
+    context_.reset();
+    engine_.reset();
+    runtime_.reset();
+  }
+
+ public:
+  shared_ptr<IExecutionContext> context_;
+  shared_ptr<ICudaEngine> engine_;
+  shared_ptr<IRuntime> runtime_ = nullptr;
+};
+
+// TensorRT-backed implementation of the trt::Infer interface. It owns one
+// engine/execution-context bundle and a lookup table from binding name to
+// binding index that is filled by setup().
+class InferImpl : public Infer {
+ public:
+  shared_ptr<__native_engine_context> context_;
+  unordered_map<string, int> binding_name_to_index_;
+
+  virtual ~InferImpl() = default;
+
+  // Deserializes an engine from memory and builds the name -> index table.
+  // Returns false when the engine context could not be constructed.
+  bool construct(const void *data, size_t size) {
+    context_ = make_shared<__native_engine_context>();
+    if (!context_->construct(data, size)) {
+      return false;
+    }
+
+    setup();
+    return true;
+  }
+
+  // Loads a serialized engine from `file`; false on read or build failure.
+  bool load(const string &file) {
+    auto data = load_file(file);
+    if (data.empty()) {
+      INFO("An empty file has been loaded. Please confirm your file path: %s", file.c_str());
+      return false;
+    }
+    return this->construct(data.data(), data.size());
+  }
+
+  // Rebuilds binding_name_to_index_ from the engine's bindings.
+  void setup() {
+    auto engine = this->context_->engine_;
+    int nbBindings = engine->getNbBindings();
+
+    binding_name_to_index_.clear();
+    for (int i = 0; i < nbBindings; ++i) {
+      const char *bindingName = engine->getBindingName(i);
+      binding_name_to_index_[bindingName] = i;
+    }
+  }
+
+  // Returns the binding index for `name`; aborts (Assertf) on unknown names.
+  virtual int index(const std::string &name) override {
+    auto iter = binding_name_to_index_.find(name);
+    Assertf(iter != binding_name_to_index_.end(), "Can not found the binding name: %s",
+            name.c_str());
+    return iter->second;
+  }
+
+  // Enqueues inference on `stream` (a type-erased cudaStream_t) using the
+  // given binding pointers, ordered by binding index.
+  virtual bool forward(const std::vector<void *> &bindings, void *stream,
+                       void *input_consum_event) override {
+    return this->context_->context_->enqueueV2((void **)bindings.data(), (cudaStream_t)stream,
+                                               (cudaEvent_t *)input_consum_event);
+  }
+
+  virtual std::vector<int> run_dims(const std::string &name) override {
+    return run_dims(index(name));
+  }
+
+  // Dimensions currently set on the execution context for this binding.
+  virtual std::vector<int> run_dims(int ibinding) override {
+    auto dim = this->context_->context_->getBindingDimensions(ibinding);
+    return std::vector<int>(dim.d, dim.d + dim.nbDims);
+  }
+
+  virtual std::vector<int> static_dims(const std::string &name) override {
+    return static_dims(index(name));
+  }
+
+  // Dimensions stored in the engine for this binding.
+  virtual std::vector<int> static_dims(int ibinding) override {
+    auto dim = this->context_->engine_->getBindingDimensions(ibinding);
+    return std::vector<int>(dim.d, dim.d + dim.nbDims);
+  }
+
+  virtual int num_bindings() override { return this->context_->engine_->getNbBindings(); }
+
+  virtual bool is_input(int ibinding) override {
+    return this->context_->engine_->bindingIsInput(ibinding);
+  }
+
+  virtual bool set_run_dims(const std::string &name, const std::vector<int> &dims) override {
+    return this->set_run_dims(index(name), dims);
+  }
+
+  // Sets the runtime dimensions of a binding. Dims::d is a fixed-size array
+  // of Dims::MAX_DIMS entries, so longer inputs are rejected instead of being
+  // copied past the end of that array.
+  virtual bool set_run_dims(int ibinding, const std::vector<int> &dims) override {
+    if (dims.size() > (size_t)Dims::MAX_DIMS) {
+      INFO("set_run_dims: %d dimensions exceed the supported maximum of %d", (int)dims.size(),
+           (int)Dims::MAX_DIMS);
+      return false;
+    }
+
+    Dims d{};
+    d.nbDims = (int)dims.size();
+    if (!dims.empty()) memcpy(d.d, dims.data(), sizeof(int) * dims.size());
+    return this->context_->context_->setBindingDimensions(ibinding, d);
+  }
+
+  virtual int numel(const std::string &name) override { return numel(index(name)); }
+
+  // Product of the binding's current runtime dimensions.
+  virtual int numel(int ibinding) override {
+    auto dim = this->context_->context_->getBindingDimensions(ibinding);
+    return std::accumulate(dim.d, dim.d + dim.nbDims, 1, std::multiplies<int>());
+  }
+
+  virtual DType dtype(const std::string &name) override { return dtype(index(name)); }
+
+  virtual DType dtype(int ibinding) override {
+    return (DType)this->context_->engine_->getBindingDataType(ibinding);
+  }
+
+  // True when any engine binding has a dimension equal to -1.
+  virtual bool has_dynamic_dim() override {
+    int numBindings = this->context_->engine_->getNbBindings();
+    for (int i = 0; i < numBindings; ++i) {
+      nvinfer1::Dims dims = this->context_->engine_->getBindingDimensions(i);
+      for (int j = 0; j < dims.nbDims; ++j) {
+        if (dims.d[j] == -1) return true;
+      }
+    }
+    return false;
+  }
+
+  // Logs the input and output bindings with their engine shapes. Every
+  // binding is classified through bindingIsInput() rather than assuming that
+  // the inputs occupy the lowest binding indices.
+  virtual void print() override {
+    INFO("Infer %p [%s]", this, has_dynamic_dim() ? "DynamicShape" : "StaticShape");
+
+    auto engine = this->context_->engine_;
+    const int nbBindings = engine->getNbBindings();
+
+    int num_input = 0;
+    for (int i = 0; i < nbBindings; ++i) {
+      if (engine->bindingIsInput(i)) num_input++;
+    }
+    const int num_output = nbBindings - num_input;
+
+    INFO("Inputs: %d", num_input);
+    for (int i = 0, ordinal = 0; i < nbBindings; ++i) {
+      if (!engine->bindingIsInput(i)) continue;
+      auto name = engine->getBindingName(i);
+      auto dim = engine->getBindingDimensions(i);
+      INFO("\t%d.%s : shape {%s}", ordinal++, name, format_shape(dim).c_str());
+    }
+
+    INFO("Outputs: %d", num_output);
+    for (int i = 0, ordinal = 0; i < nbBindings; ++i) {
+      if (engine->bindingIsInput(i)) continue;
+      auto name = engine->getBindingName(i);
+      auto dim = engine->getBindingDimensions(i);
+      INFO("\t%d.%s : shape {%s}", ordinal++, name, format_shape(dim).c_str());
+    }
+  }
+};
+
+// Creates an InferImpl on the heap and loads the engine at `file` into it.
+// Returns the new object, or nullptr (after deleting it) when loading fails.
+// The caller owns the returned pointer.
+Infer *loadraw(const std::string &file) {
+  InferImpl *instance = new InferImpl();
+  if (instance->load(file)) return instance;
+
+  delete instance;
+  return nullptr;
+}
+
+// Shared-pointer flavour of loadraw(): the result compares equal to nullptr
+// when the engine could not be loaded.
+std::shared_ptr<Infer> load(const std::string &file) {
+  InferImpl *instance = (InferImpl *)loadraw(file);
+  return std::shared_ptr<InferImpl>(instance);
+}
+
+// Renders `shape` as decimal extents joined by 'x' (e.g. "1x3x224x224");
+// an empty vector yields an empty string.
+std::string format_shape(const std::vector<int> &shape) {
+  std::string text;
+  for (int i = 0; i < (int)shape.size(); ++i) {
+    if (i > 0) text += 'x';
+    text += std::to_string(shape[i]);
+  }
+  return text;
+}
+};  // namespace trt

+ 98 - 0
src/infer.hpp

@@ -0,0 +1,98 @@
+#ifndef __INFER_HPP__
+#define __INFER_HPP__
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace trt {
+
+#define INFO(...) trt::__log_func(__FILE__, __LINE__, __VA_ARGS__)  // printf-style logging that adds the call site's file and line
+void __log_func(const char *file, int line, const char *fmt, ...);  // backend of INFO; `fmt` and the variadic arguments follow printf conventions
+
+// Element type of a binding. InferImpl::dtype() produces these values by
+// casting the engine's binding data type, so the numbering must stay as is.
+enum class DType : int {
+  FLOAT = 0,
+  HALF = 1,
+  INT8 = 2,
+  INT32 = 3,
+  BOOL = 4,
+  UINT8 = 5
+};
+
+// Measures elapsed GPU time between start() and stop() with a pair of CUDA
+// events. Event and stream handles are stored as void* so that this header
+// does not need the CUDA headers.
+class Timer {
+ public:
+  // Creates the start and stop events.
+  Timer();
+  // Destroys the start and stop events.
+  virtual ~Timer();
+  // Records the start event on `stream` (a type-erased cudaStream_t).
+  void start(void *stream = nullptr);
+  // Records the stop event, waits for it and returns the elapsed time in
+  // milliseconds; optionally prints it labelled with `prefix`.
+  float stop(const char *prefix = "Timer", bool print = true);
+
+ private:
+  // Initialised to null so that stop() without a preceding start() uses the
+  // default stream instead of an indeterminate pointer value.
+  void *start_ = nullptr, *stop_ = nullptr;
+  void *stream_ = nullptr;
+};
+
+// Untyped pair of host and device buffers. A buffer is either allocated by
+// this object through *_realloc() or supplied by the caller through
+// reference(); the owner_* flags record whether release_*() has to free it.
+class BaseMemory {
+ public:
+  BaseMemory() = default;
+  // Wraps caller-provided buffers, see reference().
+  BaseMemory(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes);
+  virtual ~BaseMemory();
+
+  // Make sure the buffer holds at least `bytes` bytes and return it.
+  virtual void *gpu_realloc(size_t bytes);
+  virtual void *cpu_realloc(size_t bytes);
+
+  void release_gpu();
+  void release_cpu();
+  void release();
+
+  bool owner_gpu() const { return owner_gpu_; }
+  bool owner_cpu() const { return owner_cpu_; }
+
+  // Logical sizes in bytes, as set by the last *_realloc()/reference() call.
+  size_t cpu_bytes() const { return cpu_bytes_; }
+  size_t gpu_bytes() const { return gpu_bytes_; }
+
+  virtual void *get_gpu() const { return gpu_; }
+  virtual void *get_cpu() const { return cpu_; }
+
+  // Releases the current buffers and adopts the given ones without owning them.
+  void reference(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes);
+
+ protected:
+  void *cpu_ = nullptr;
+  size_t cpu_bytes_ = 0, cpu_capacity_ = 0;
+  bool owner_cpu_ = true;
+
+  void *gpu_ = nullptr;
+  size_t gpu_bytes_ = 0, gpu_capacity_ = 0;
+  bool owner_gpu_ = true;
+};
+
+// Typed view on BaseMemory: sizes are expressed in elements of _DT instead of
+// bytes. Copying is disabled.
+template <typename _DT>
+class Memory : public BaseMemory {
+ public:
+  Memory() = default;
+  Memory(const Memory &other) = delete;
+  Memory &operator=(const Memory &other) = delete;
+
+  // Make room for `size` elements on the device / host and return the buffer.
+  virtual _DT *gpu(size_t size) {
+    return static_cast<_DT *>(BaseMemory::gpu_realloc(size * sizeof(_DT)));
+  }
+  virtual _DT *cpu(size_t size) {
+    return static_cast<_DT *>(BaseMemory::cpu_realloc(size * sizeof(_DT)));
+  }
+
+  // Number of whole _DT elements covered by the current logical byte size.
+  size_t cpu_size() const { return cpu_bytes_ / sizeof(_DT); }
+  size_t gpu_size() const { return gpu_bytes_ / sizeof(_DT); }
+
+  // Current buffers without any (re)allocation; may be nullptr.
+  virtual _DT *gpu() const { return static_cast<_DT *>(gpu_); }
+  virtual _DT *cpu() const { return static_cast<_DT *>(cpu_); }
+};
+
+// Abstract interface of a loaded inference engine. Bindings can be addressed
+// either by name or by index.
+class Infer {
+ public:
+  // Objects are handed out through base-class pointers (see load()), so the
+  // destructor has to be virtual for deletion through Infer* to be valid.
+  virtual ~Infer() = default;
+
+  // Enqueues inference on `stream`; `bindings` holds one pointer per binding.
+  virtual bool forward(const std::vector<void *> &bindings, void *stream = nullptr,
+                       void *input_consum_event = nullptr) = 0;
+  // Binding index for a binding name.
+  virtual int index(const std::string &name) = 0;
+  // Dimensions currently set for execution.
+  virtual std::vector<int> run_dims(const std::string &name) = 0;
+  virtual std::vector<int> run_dims(int ibinding) = 0;
+  // Dimensions stored in the engine.
+  virtual std::vector<int> static_dims(const std::string &name) = 0;
+  virtual std::vector<int> static_dims(int ibinding) = 0;
+  // Number of elements of a binding.
+  virtual int numel(const std::string &name) = 0;
+  virtual int numel(int ibinding) = 0;
+  virtual int num_bindings() = 0;
+  virtual bool is_input(int ibinding) = 0;
+  // Sets the execution dimensions of a binding; false on failure.
+  virtual bool set_run_dims(const std::string &name, const std::vector<int> &dims) = 0;
+  virtual bool set_run_dims(int ibinding, const std::vector<int> &dims) = 0;
+  virtual DType dtype(const std::string &name) = 0;
+  virtual DType dtype(int ibinding) = 0;
+  virtual bool has_dynamic_dim() = 0;
+  // Logs a summary of the bindings.
+  virtual void print() = 0;
+};
+
+std::shared_ptr<Infer> load(const std::string &file);  // loads a serialized engine file; the result compares equal to nullptr on failure
+std::string format_shape(const std::vector<int> &shape);  // joins the extents with 'x', e.g. "1x3x224x224"
+
+}  // namespace trt
+
+#endif  // __INFER_HPP__

+ 27 - 0
src/main.cpp

@@ -0,0 +1,27 @@
+
+#include <opencv2/opencv.hpp>
+
+#include "infer.hpp"
+#include "resnet.hpp"
+
+using namespace std;
+
+
+// Wraps a cv::Mat as a non-owning resnet::Image (pixel pointer, width,
+// height). The Mat must outlive the returned Image.
+resnet::Image cvimg(const cv::Mat &image) {
+  // The declared return type lives in namespace resnet; there is no
+  // namespace yolo in this project.
+  return resnet::Image(image.data, image.cols, image.rows);
+}
+
+
+
+// Loads one test image and one engine and runs a single forward pass.
+void single_inference() {
+  const char *image_path = "inference/car.jpg";
+  cv::Mat image = cv::imread(image_path);
+  if (image.empty()) {
+    // imread returns an empty Mat on failure; forwarding it would hand a
+    // null pixel pointer and a zero-sized image to the preprocessing code.
+    INFO("Failed to read image: %s", image_path);
+    return;
+  }
+
+  auto model = resnet::load("resnet.engine");
+  if (model == nullptr) return;
+
+  auto result = model->forward(cvimg(image));
+  (void)result;  // the result is not consumed yet
+}
+
+// Program entry point: runs the single-image inference demo.
+int main() {
+  // perf() and batch_inference() are neither declared nor defined in this
+  // translation unit or its headers, so calling them cannot compile; only
+  // the demo that exists is invoked.
+  single_inference();
+  return 0;
+}

+ 364 - 0
src/resnet.cu

@@ -0,0 +1,364 @@
+#include "infer.hpp"
+#include "resnet.hpp"
+
+namespace resnet
+{
+
+using namespace std;
+
+// Upper bound on threads per block used by grid_dims()/block_dims().
+#define GPU_BLOCK_THREADS 512
+
+// checkRuntime(call): evaluates a CUDA runtime call once. When the returned
+// status differs from cudaSuccess, the failing expression, the error string,
+// the error name and the numeric code are logged through INFO and the process
+// is aborted.
+#define checkRuntime(call)                                                        \
+  do {                                                                            \
+    auto resnet_runtime_status_ = (call);                                         \
+    if (resnet_runtime_status_ == cudaSuccess) break;                             \
+    INFO("CUDA Runtime error💥 %s # %s, code = %s [ %d ]", #call,                  \
+         cudaGetErrorString(resnet_runtime_status_),                              \
+         cudaGetErrorName(resnet_runtime_status_), resnet_runtime_status_);       \
+    abort();                                                                      \
+  } while (0)
+
+// checkKernel(...): evaluates the given expression (normally a kernel launch)
+// and then passes cudaPeekAtLastError() through checkRuntime.
+#define checkKernel(...)                   \
+  do {                                     \
+    {                                      \
+      (__VA_ARGS__);                       \
+    }                                      \
+    checkRuntime(cudaPeekAtLastError());   \
+  } while (0)
+
+// Which normalization formula the preprocessing kernel applies.
+enum class NormType : int {
+  None = 0,
+  MeanStd = 1,
+  AlphaBeta = 2
+};
+
+// Whether the first and third channel are swapped before normalization.
+enum class ChannelType : int {
+  None = 0,
+  SwapRB = 1
+};
+
+/* Normalization settings: supports mean/std, alpha/beta and swapping R and B. */
+struct Norm {
+  float mean[3];
+  float std[3];
+  float alpha, beta;
+  NormType type = NormType::None;
+  ChannelType channel_type = ChannelType::None;
+
+  // Factory for: out = (x * alpha - mean) / std
+  static Norm mean_std(const float mean[3], const float std[3], float alpha = 1 / 255.0f,
+                       ChannelType channel_type = ChannelType::None);
+
+  // Factory for: out = x * alpha + beta
+  static Norm alpha_beta(float alpha, float beta = 0,
+                         ChannelType channel_type = ChannelType::None);
+
+  // Factory for "no normalization".
+  static Norm None();
+};
+
+// Builds a MeanStd normalization: out = (x * alpha - mean) / std, with one
+// mean/std value per channel. `beta` is left unset.
+Norm Norm::mean_std(const float mean[3], const float std[3], float alpha,
+                    ChannelType channel_type) {
+  Norm result;
+  result.type = NormType::MeanStd;
+  result.channel_type = channel_type;
+  result.alpha = alpha;
+  for (int c = 0; c < 3; ++c) {
+    result.mean[c] = mean[c];
+    result.std[c] = std[c];
+  }
+  return result;
+}
+
+// Builds an AlphaBeta normalization: out = x * alpha + beta. `mean` and `std`
+// are left unset.
+Norm Norm::alpha_beta(float alpha, float beta, ChannelType channel_type) {
+  Norm result;
+  result.channel_type = channel_type;
+  result.type = NormType::AlphaBeta;
+  result.beta = beta;
+  result.alpha = alpha;
+  return result;
+}
+
+// Returns a default-constructed Norm (type None, channel_type None).
+Norm Norm::None() {
+  Norm result;
+  return result;
+}
+
+// Number of blocks needed to cover `numJobs` work items when each block runs
+// min(numJobs, GPU_BLOCK_THREADS) threads (see block_dims()).
+// Non-positive job counts yield a grid of one block, which avoids the
+// division by zero the block size of 0 would otherwise cause.
+static dim3 grid_dims(int numJobs) {
+  if (numJobs <= 0) return dim3(1);
+
+  const int numBlockThreads = numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
+  // Integer ceil-division; the previous float division produced a fractional
+  // intermediate that was then truncated.
+  return dim3((numJobs + numBlockThreads - 1) / numBlockThreads);
+}
+
+// Threads per block for `numJobs` work items: numJobs itself when it is below
+// GPU_BLOCK_THREADS, GPU_BLOCK_THREADS otherwise.
+static dim3 block_dims(int numJobs) {
+  if (numJobs < GPU_BLOCK_THREADS) return numJobs;
+  return GPU_BLOCK_THREADS;
+}
+
+
+// Warps a packed 3-channel 8-bit image into a planar float image.
+// One thread handles one destination pixel (dx, dy) of a 2-D launch; threads
+// outside dst_width x dst_height return immediately. The destination pixel is
+// mapped to source coordinates with the 2x3 matrix, sampled bilinearly (taps
+// outside the source use const_value_st), optionally has channel 0 and 2
+// swapped, is normalized according to `norm` and written to three planes of
+// dst_width * dst_height floats each.
+static __global__ void warp_affine_bilinear_and_normalize_plane_kernel(
+    uint8_t *src, int src_line_size, int src_width, int src_height, float *dst, int dst_width,
+    int dst_height, uint8_t const_value_st, float *warp_affine_matrix_2_3, Norm norm) {
+  const int dx = blockDim.x * blockIdx.x + threadIdx.x;
+  const int dy = blockDim.y * blockIdx.y + threadIdx.y;
+  if (dx >= dst_width || dy >= dst_height) return;
+
+  // Row-major 2x3 matrix: [a b tx; c d ty].
+  const float a = warp_affine_matrix_2_3[0];
+  const float b = warp_affine_matrix_2_3[1];
+  const float tx = warp_affine_matrix_2_3[2];
+  const float c = warp_affine_matrix_2_3[3];
+  const float d = warp_affine_matrix_2_3[4];
+  const float ty = warp_affine_matrix_2_3[5];
+
+  const float src_x = a * dx + b * dy + tx;
+  const float src_y = c * dx + d * dy + ty;
+
+  // Start from the fill value; it is kept when the sample lies entirely
+  // outside the source image.
+  float c0 = const_value_st;
+  float c1 = const_value_st;
+  float c2 = const_value_st;
+
+  const bool outside = src_x <= -1 || src_x >= src_width || src_y <= -1 || src_y >= src_height;
+  if (!outside) {
+    const int y_low = floorf(src_y);
+    const int x_low = floorf(src_x);
+    const int y_high = y_low + 1;
+    const int x_high = x_low + 1;
+
+    // Pixel used for taps that fall outside the source image.
+    uint8_t border[] = {const_value_st, const_value_st, const_value_st};
+
+    const float ly = src_y - y_low;
+    const float lx = src_x - x_low;
+    const float hy = 1 - ly;
+    const float hx = 1 - lx;
+    const float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
+
+    // Top-left, top-right, bottom-left and bottom-right taps.
+    uint8_t *v1 = border;
+    uint8_t *v2 = border;
+    uint8_t *v3 = border;
+    uint8_t *v4 = border;
+
+    if (y_low >= 0) {
+      uint8_t *row = src + y_low * src_line_size;
+      if (x_low >= 0) v1 = row + x_low * 3;
+      if (x_high < src_width) v2 = row + x_high * 3;
+    }
+
+    if (y_high < src_height) {
+      uint8_t *row = src + y_high * src_line_size;
+      if (x_low >= 0) v3 = row + x_low * 3;
+      if (x_high < src_width) v4 = row + x_high * 3;
+    }
+
+    // Round to nearest, same as OpenCV.
+    c0 = floorf(w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0] + 0.5f);
+    c1 = floorf(w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1] + 0.5f);
+    c2 = floorf(w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2] + 0.5f);
+  }
+
+  if (norm.channel_type == ChannelType::SwapRB) {
+    const float first = c0;
+    c0 = c2;
+    c2 = first;
+  }
+
+  switch (norm.type) {
+    case NormType::MeanStd:
+      c0 = (c0 * norm.alpha - norm.mean[0]) / norm.std[0];
+      c1 = (c1 * norm.alpha - norm.mean[1]) / norm.std[1];
+      c2 = (c2 * norm.alpha - norm.mean[2]) / norm.std[2];
+      break;
+    case NormType::AlphaBeta:
+      c0 = c0 * norm.alpha + norm.beta;
+      c1 = c1 * norm.alpha + norm.beta;
+      c2 = c2 * norm.alpha + norm.beta;
+      break;
+    default:
+      break;
+  }
+
+  // Planar output: the three channel planes follow each other in memory.
+  const int area = dst_width * dst_height;
+  float *plane = dst + dy * dst_width + dx;
+  *plane = c0;
+  plane += area;
+  *plane = c1;
+  plane += area;
+  *plane = c2;
+}
+
+// Launches the warp-affine/normalize kernel on `stream` with 32x32-thread
+// blocks and enough blocks to cover dst_width x dst_height, then checks for a
+// launch error. `src`, `dst` and `matrix_2_3` are passed to the kernel as-is.
+static void warp_affine_bilinear_and_normalize_plane(uint8_t *src, int src_line_size, int src_width,
+                                                     int src_height, float *dst, int dst_width,
+                                                     int dst_height, float *matrix_2_3,
+                                                     uint8_t const_value, const Norm &norm,
+                                                     cudaStream_t stream) {
+  const dim3 threads(32, 32);
+  const dim3 blocks((dst_width + 31) / 32, (dst_height + 31) / 32);
+
+  checkKernel(warp_affine_bilinear_and_normalize_plane_kernel<<<blocks, threads, 0, stream>>>(
+      src, src_line_size, src_width, src_height, dst, dst_width, dst_height, const_value,
+      matrix_2_3, norm));
+}
+
+
+// Pair of row-major 2x3 affine matrices relating an image to the network
+// input and back.
+struct AffineMatrix {
+  float i2d[6];  // image to dst(network), 2x3 matrix
+  float d2i[6];  // dst to image, 2x3 matrix
+
+  // Fills i2d with a uniform scale (the smaller of the per-axis ratios
+  // to/from) and no translation, then stores its inverse in d2i.
+  // `from` and `to` are (width, height) pairs.
+  void compute(const std::tuple<int, int> &from, const std::tuple<int, int> &to) {
+    const float scale_x = get<0>(to) / (float)get<0>(from);
+    const float scale_y = get<1>(to) / (float)get<1>(from);
+    const float scale = std::min(scale_x, scale_y);
+
+    // Plain resize anchored at the origin. (A centred letter-box variant
+    // would additionally set the translation terms i2d[2] and i2d[5].)
+    i2d[0] = scale;
+    i2d[1] = 0;
+    i2d[2] = 0;
+    i2d[3] = 0;
+    i2d[4] = scale;
+    i2d[5] = 0;
+
+    // Invert the 2x2 linear part; a zero determinant gives an all-zero inverse.
+    double D = i2d[0] * i2d[4] - i2d[1] * i2d[3];
+    D = D != 0. ? double(1.) / D : double(0.);
+    const double A11 = i2d[4] * D;
+    const double A22 = i2d[0] * D;
+    const double A12 = -i2d[1] * D;
+    const double A21 = -i2d[3] * D;
+    const double b1 = -A11 * i2d[2] - A12 * i2d[5];
+    const double b2 = -A21 * i2d[2] - A22 * i2d[5];
+
+    d2i[0] = A11;
+    d2i[1] = A12;
+    d2i[2] = b1;
+    d2i[3] = A21;
+    d2i[4] = A22;
+    d2i[5] = b2;
+  }
+};
+
+
+// ResNet classifier built on trt::Infer. Images are uploaded, warped to the
+// network input size and normalized on the GPU, the engine is executed, and
+// the per-image scores are reduced to one Attribute each.
+class InferImpl : public Infer {
+ public:
+  shared_ptr<trt::Infer> trt_;
+  string engine_file_;
+  vector<shared_ptr<trt::Memory<unsigned char>>> preprocess_buffers_;
+  trt::Memory<float> input_buffer_, output_array_;
+  int network_input_width_, network_input_height_;
+  Norm normalize_;
+  int num_classes_ = 0;
+  bool isdynamic_model_ = false;
+
+  virtual ~InferImpl() = default;
+
+  // Rounds n up to the next multiple of align.
+  static size_t upbound(size_t n, size_t align) { return (n + align - 1) / align * align; }
+
+  // Sizes the input/output buffers for `batch_size` images and makes sure one
+  // preprocessing workspace exists per image.
+  void adjust_memory(int batch_size) {
+    size_t input_numel = (size_t)network_input_width_ * network_input_height_ * 3;
+    size_t output_numel = (size_t)batch_size * num_classes_;
+    input_buffer_.gpu(batch_size * input_numel);
+    output_array_.gpu(output_numel);
+    output_array_.cpu(output_numel);
+
+    for (int i = (int)preprocess_buffers_.size(); i < batch_size; ++i)
+      preprocess_buffers_.push_back(make_shared<trt::Memory<unsigned char>>());
+  }
+
+  // Uploads one image plus its dst->image matrix into `preprocess_buffer` and
+  // launches the warp/normalize kernel that writes slot `ibatch` of the input
+  // buffer. All device work is issued on `stream`.
+  void preprocess(int ibatch, const Image &image,
+                  shared_ptr<trt::Memory<unsigned char>> preprocess_buffer,
+                  void *stream = nullptr) {
+    AffineMatrix affine;
+    affine.compute(make_tuple(image.width, image.height),
+                   make_tuple(network_input_width_, network_input_height_));
+
+    size_t input_numel = (size_t)network_input_width_ * network_input_height_ * 3;
+    float *input_device = input_buffer_.gpu() + ibatch * input_numel;
+    size_t size_image = (size_t)image.width * image.height * 3;
+    // Workspace layout: [matrix, padded to 32 bytes][image bytes].
+    size_t size_matrix = upbound(sizeof(affine.d2i), 32);
+    uint8_t *gpu_workspace = preprocess_buffer->gpu(size_matrix + size_image);
+    float *affine_matrix_device = (float *)gpu_workspace;
+    uint8_t *image_device = gpu_workspace + size_matrix;
+
+    uint8_t *cpu_workspace = preprocess_buffer->cpu(size_matrix + size_image);
+    float *affine_matrix_host = (float *)cpu_workspace;
+    uint8_t *image_host = cpu_workspace + size_matrix;
+
+    // Stage through the host buffer allocated by Memory::cpu() before the
+    // asynchronous copies.
+    cudaStream_t stream_ = (cudaStream_t)stream;
+    memcpy(image_host, image.bgrptr, size_image);
+    memcpy(affine_matrix_host, affine.d2i, sizeof(affine.d2i));
+    checkRuntime(
+        cudaMemcpyAsync(image_device, image_host, size_image, cudaMemcpyHostToDevice, stream_));
+    checkRuntime(cudaMemcpyAsync(affine_matrix_device, affine_matrix_host, sizeof(affine.d2i),
+                                 cudaMemcpyHostToDevice, stream_));
+
+    warp_affine_bilinear_and_normalize_plane(image_device, image.width * 3, image.width,
+                                             image.height, input_device, network_input_width_,
+                                             network_input_height_, affine_matrix_device, 114,
+                                             normalize_, stream_);
+  }
+
+  // Loads the engine and derives the input size and class count from it.
+  // Returns false when the engine cannot be loaded or does not have the
+  // layout forwards() relies on: exactly two bindings, binding 0 being a
+  // 4-D input and binding 1 the score output.
+  bool load(const string &engine_file) {
+    trt_ = trt::load(engine_file);
+    if (trt_ == nullptr) return false;
+    engine_file_ = engine_file;
+
+    trt_->print();
+
+    if (trt_->num_bindings() != 2) {
+      INFO("Expected an engine with 2 bindings, got %d.", trt_->num_bindings());
+      return false;
+    }
+
+    auto input_dim = trt_->static_dims(0);
+    if (input_dim.size() != 4) {
+      INFO("Expected a 4-D input binding, got %d dimensions.", (int)input_dim.size());
+      return false;
+    }
+
+    network_input_width_ = input_dim[3];
+    network_input_height_ = input_dim[2];
+    isdynamic_model_ = trt_->has_dynamic_dim();
+
+    normalize_ = Norm::alpha_beta(1 / 255.0f, 0.0f, ChannelType::SwapRB);
+
+    // NOTE(review): assumes binding 1 is laid out as [batch, scores...], so
+    // the per-image score count is the product of all non-batch dimensions.
+    // Confirm against the exported model.
+    auto output_dim = trt_->static_dims(1);
+    if (output_dim.size() < 2) {
+      INFO("Expected an output binding with at least 2 dimensions, got %d.",
+           (int)output_dim.size());
+      return false;
+    }
+    num_classes_ = 1;
+    for (size_t i = 1; i < output_dim.size(); ++i) num_classes_ *= output_dim[i];
+    if (num_classes_ <= 0) {
+      INFO("Invalid number of classes derived from the output binding: %d.", num_classes_);
+      return false;
+    }
+    return true;
+  }
+
+  // Classifies a single image; a default Attribute is returned on failure.
+  virtual Attribute forward(const Image &image, void *stream = nullptr) override {
+    auto output = forwards({image}, stream);
+    if (output.empty()) return {};
+    return output[0];
+  }
+
+  // Classifies a batch of images. Returns one Attribute per image, or an
+  // empty vector when the input is empty/invalid, the batch does not fit a
+  // static-shape engine, or the engine fails to run. The call blocks until
+  // the results have been copied back to the host.
+  virtual vector<Attribute> forwards(const vector<Image> &images, void *stream = nullptr) override {
+    int num_image = images.size();
+    if (num_image == 0) return {};
+
+    for (int i = 0; i < num_image; ++i) {
+      // preprocess() reads width * height * 3 bytes from bgrptr and divides
+      // by the image size, so reject unusable images up front.
+      if (images[i].bgrptr == nullptr || images[i].width <= 0 || images[i].height <= 0) {
+        INFO("Image %d is empty or has an invalid size.", i);
+        return {};
+      }
+    }
+
+    auto input_dims = trt_->static_dims(0);
+    int infer_batch_size = input_dims[0];
+    if (infer_batch_size != num_image) {
+      if (isdynamic_model_) {
+        infer_batch_size = num_image;
+        input_dims[0] = num_image;
+        if (!trt_->set_run_dims(0, input_dims)) return {};
+      } else {
+        if (infer_batch_size < num_image) {
+          INFO(
+              "When using static shape model, number of images[%d] must be "
+              "less than or equal to the maximum batch[%d].",
+              num_image, infer_batch_size);
+          return {};
+        }
+      }
+    }
+    adjust_memory(infer_batch_size);
+
+    cudaStream_t stream_ = (cudaStream_t)stream;
+    for (int i = 0; i < num_image; ++i)
+      preprocess(i, images[i], preprocess_buffers_[i], stream);
+
+    float *output_array_device = output_array_.gpu();
+    vector<void *> bindings{input_buffer_.gpu(), output_array_device};
+
+    if (!trt_->forward(bindings, stream)) {
+      INFO("Failed to tensorRT forward.");
+      return {};
+    }
+
+    // Bring the scores back and wait for everything queued on the stream.
+    checkRuntime(cudaMemcpyAsync(output_array_.cpu(), output_array_device,
+                                 output_array_.gpu_bytes(), cudaMemcpyDeviceToHost, stream_));
+    checkRuntime(cudaStreamSynchronize(stream_));
+
+    // NOTE(review): the highest raw score is reported as the confidence;
+    // whether the engine outputs logits or probabilities cannot be told from
+    // here — confirm before interpreting the value as a probability.
+    vector<Attribute> arrout(num_image);
+    const float *scores = output_array_.cpu();
+    for (int ib = 0; ib < num_image; ++ib) {
+      const float *row = scores + (size_t)ib * num_classes_;
+      int best = 0;
+      for (int c = 1; c < num_classes_; ++c) {
+        if (row[c] > row[best]) best = c;
+      }
+      arrout[ib] = Attribute(row[best], best);
+    }
+
+    return arrout;
+  }
+};
+
+// Creates an InferImpl on the heap and loads `engine_file` into it. Returns
+// the new object, or nullptr (after deleting it) when loading fails. The
+// caller owns the returned pointer.
+Infer *loadraw(const std::string &engine_file) {
+  InferImpl *instance = new InferImpl();
+  if (instance->load(engine_file)) return instance;
+
+  delete instance;
+  return nullptr;
+}
+
+// Shared-pointer flavour of loadraw(): the result compares equal to nullptr
+// when the engine could not be loaded.
+shared_ptr<Infer> load(const string &engine_file) {
+  InferImpl *instance = (InferImpl *)loadraw(engine_file);
+  return std::shared_ptr<InferImpl>(instance);
+}
+
+}

+ 41 - 0
src/resnet.hpp

@@ -0,0 +1,41 @@
+#ifndef RESNET_HPP__
+#define RESNET_HPP__
+
+#include <future>
+#include <memory>
+#include <string>
+#include <vector>
+
+namespace resnet{
+
+// Classification result for one image.
+struct Attribute {
+  // Default member initialisers give a default-constructed Attribute (as
+  // returned on failure paths) defined contents instead of indeterminate
+  // values; -1 marks "no class".
+  float confidence = 0.0f;
+  int class_label = -1;
+
+  Attribute() = default;
+  Attribute(float confidence, int class_label)
+      : confidence(confidence),
+        class_label(class_label) {}
+};
+
+// Non-owning description of an input image: a pointer to the pixel data plus
+// its width and height in pixels. The pointed-to data must stay valid while
+// the Image is in use.
+struct Image {
+  const void *bgrptr = nullptr;
+  int width = 0;
+  int height = 0;
+
+  Image() = default;
+  Image(const void *bgrptr, int width, int height)
+      : bgrptr(bgrptr),
+        width(width),
+        height(height) {}
+};
+
+
+// Abstract classifier interface returned by resnet::load().
+class Infer {
+ public:
+  // Instances are used through base-class pointers, so the destructor has to
+  // be virtual for deletion through Infer* to be valid.
+  virtual ~Infer() = default;
+
+  // Classifies one image. `stream` is an optional, type-erased CUDA stream.
+  virtual Attribute forward(const Image &image, void *stream = nullptr) = 0;
+  // Classifies several images and returns one Attribute per image.
+  virtual std::vector<Attribute> forwards(const std::vector<Image> &images,
+                                         void *stream = nullptr) = 0;
+};
+
+std::shared_ptr<Infer> load(const std::string &engine_file);  // loads a serialized engine; the result compares equal to nullptr on failure
+
+}
+
+#endif