// infer.cu — TensorRT engine loading and inference wrapper (namespace trt)
  1. #include <NvInfer.h>
  2. #include <cuda_runtime.h>
  3. #include <stdarg.h>
  4. #include <fstream>
  5. #include <numeric>
  6. #include <sstream>
  7. #include <unordered_map>
  8. #include "infer.hpp"
  9. namespace trt {
  10. using namespace std;
  11. using namespace nvinfer1;
// checkRuntime(call): evaluate a CUDA Runtime API call and abort the process
// with a descriptive log line (expression text, message, name, code) if it
// returns anything other than cudaSuccess.
#define checkRuntime(call) \
do { \
auto ___call__ret_code__ = (call); \
if (___call__ret_code__ != cudaSuccess) { \
INFO("CUDA Runtime error💥 %s # %s, code = %s [ %d ]", #call, \
cudaGetErrorString(___call__ret_code__), cudaGetErrorName(___call__ret_code__), \
___call__ret_code__); \
abort(); \
} \
} while (0)
// checkKernel(...): run a kernel launch (or any statement) and immediately
// surface launch-time errors; cudaPeekAtLastError reads the sticky error
// without clearing it.
#define checkKernel(...) \
do { \
{ (__VA_ARGS__); } \
checkRuntime(cudaPeekAtLastError()); \
} while (0)
// Assert(op): abort with the stringified expression when op is falsy.
#define Assert(op) \
do { \
bool cond = !(!(op)); \
if (!cond) { \
INFO("Assert failed, " #op); \
abort(); \
} \
} while (0)
// Assertf(op, fmt, ...): like Assert but appends a printf-style message.
// The caller's format string must be a literal — it is concatenated into the
// INFO format at compile time.
#define Assertf(op, ...) \
do { \
bool cond = !(!(op)); \
if (!cond) { \
INFO("Assert failed, " #op " : " __VA_ARGS__); \
abort(); \
} \
} while (0)
  43. static string file_name(const string &path, bool include_suffix) {
  44. if (path.empty()) return "";
  45. int p = path.rfind('/');
  46. int e = path.rfind('\\');
  47. p = std::max(p, e);
  48. p += 1;
  49. // include suffix
  50. if (include_suffix) return path.substr(p);
  51. int u = path.rfind('.');
  52. if (u == -1) return path.substr(p);
  53. if (u <= p) u = path.size();
  54. return path.substr(p, u - p);
  55. }
  56. void __log_func(const char *file, int line, const char *fmt, ...) {
  57. va_list vl;
  58. va_start(vl, fmt);
  59. char buffer[2048];
  60. string filename = file_name(file, true);
  61. int n = snprintf(buffer, sizeof(buffer), "[%s:%d]: ", filename.c_str(), line);
  62. vsnprintf(buffer + n, sizeof(buffer) - n, fmt, vl);
  63. fprintf(stdout, "%s\n", buffer);
  64. }
  65. static std::string format_shape(const Dims &shape) {
  66. stringstream output;
  67. char buf[64];
  68. const char *fmts[] = {"%d", "x%d"};
  69. for (int i = 0; i < shape.nbDims; ++i) {
  70. snprintf(buf, sizeof(buf), fmts[i != 0], shape.d[i]);
  71. output << buf;
  72. }
  73. return output.str();
  74. }
  75. Timer::Timer() {
  76. checkRuntime(cudaEventCreate((cudaEvent_t *)&start_));
  77. checkRuntime(cudaEventCreate((cudaEvent_t *)&stop_));
  78. }
  79. Timer::~Timer() {
  80. checkRuntime(cudaEventDestroy((cudaEvent_t)start_));
  81. checkRuntime(cudaEventDestroy((cudaEvent_t)stop_));
  82. }
  83. void Timer::start(void *stream) {
  84. stream_ = stream;
  85. checkRuntime(cudaEventRecord((cudaEvent_t)start_, (cudaStream_t)stream_));
  86. }
  87. float Timer::stop(const char *prefix, bool print) {
  88. checkRuntime(cudaEventRecord((cudaEvent_t)stop_, (cudaStream_t)stream_));
  89. checkRuntime(cudaEventSynchronize((cudaEvent_t)stop_));
  90. float latency = 0;
  91. checkRuntime(cudaEventElapsedTime(&latency, (cudaEvent_t)start_, (cudaEvent_t)stop_));
  92. if (print) {
  93. printf("[%s]: %.5f ms\n", prefix, latency);
  94. }
  95. return latency;
  96. }
// Construct a memory object that references (does not own) the given
// host/device buffers; see reference() for the normalization/ownership rules.
BaseMemory::BaseMemory(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes) {
  reference(cpu, cpu_bytes, gpu, gpu_bytes);
}
  100. void BaseMemory::reference(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes) {
  101. release();
  102. if (cpu == nullptr || cpu_bytes == 0) {
  103. cpu = nullptr;
  104. cpu_bytes = 0;
  105. }
  106. if (gpu == nullptr || gpu_bytes == 0) {
  107. gpu = nullptr;
  108. gpu_bytes = 0;
  109. }
  110. this->cpu_ = cpu;
  111. this->cpu_capacity_ = cpu_bytes;
  112. this->cpu_bytes_ = cpu_bytes;
  113. this->gpu_ = gpu;
  114. this->gpu_capacity_ = gpu_bytes;
  115. this->gpu_bytes_ = gpu_bytes;
  116. this->owner_cpu_ = !(cpu && cpu_bytes > 0);
  117. this->owner_gpu_ = !(gpu && gpu_bytes > 0);
  118. }
// Free any owned buffers on destruction.
BaseMemory::~BaseMemory() { release(); }
  120. void *BaseMemory::gpu_realloc(size_t bytes) {
  121. if (gpu_capacity_ < bytes) {
  122. release_gpu();
  123. gpu_capacity_ = bytes;
  124. checkRuntime(cudaMalloc(&gpu_, bytes));
  125. // checkRuntime(cudaMemset(gpu_, 0, size));
  126. }
  127. gpu_bytes_ = bytes;
  128. return gpu_;
  129. }
  130. void *BaseMemory::cpu_realloc(size_t bytes) {
  131. if (cpu_capacity_ < bytes) {
  132. release_cpu();
  133. cpu_capacity_ = bytes;
  134. checkRuntime(cudaMallocHost(&cpu_, bytes));
  135. Assert(cpu_ != nullptr);
  136. // memset(cpu_, 0, size);
  137. }
  138. cpu_bytes_ = bytes;
  139. return cpu_;
  140. }
  141. void BaseMemory::release_cpu() {
  142. if (cpu_) {
  143. if (owner_cpu_) {
  144. checkRuntime(cudaFreeHost(cpu_));
  145. }
  146. cpu_ = nullptr;
  147. }
  148. cpu_capacity_ = 0;
  149. cpu_bytes_ = 0;
  150. }
  151. void BaseMemory::release_gpu() {
  152. if (gpu_) {
  153. if (owner_gpu_) {
  154. checkRuntime(cudaFree(gpu_));
  155. }
  156. gpu_ = nullptr;
  157. }
  158. gpu_capacity_ = 0;
  159. gpu_bytes_ = 0;
  160. }
// Release both sides (host first, then device); owned buffers are freed,
// referenced ones are forgotten.
void BaseMemory::release() {
  release_cpu();
  release_gpu();
}
  165. class __native_nvinfer_logger : public ILogger {
  166. public:
  167. virtual void log(Severity severity, const char *msg) noexcept override {
  168. if (severity == Severity::kINTERNAL_ERROR) {
  169. INFO("NVInfer INTERNAL_ERROR: %s", msg);
  170. abort();
  171. } else if (severity == Severity::kERROR) {
  172. INFO("NVInfer: %s", msg);
  173. }
  174. // else if (severity == Severity::kWARNING) {
  175. // INFO("NVInfer: %s", msg);
  176. // }
  177. // else if (severity == Severity::kINFO) {
  178. // INFO("NVInfer: %s", msg);
  179. // }
  180. // else {
  181. // INFO("%s", msg);
  182. // }
  183. }
  184. };
  185. static __native_nvinfer_logger gLogger;
  186. template <typename _T>
  187. static void destroy_nvidia_pointer(_T *ptr) {
  188. if (ptr) ptr->destroy();
  189. }
  190. static std::vector<uint8_t> load_file(const string &file) {
  191. ifstream in(file, ios::in | ios::binary);
  192. if (!in.is_open()) return {};
  193. in.seekg(0, ios::end);
  194. size_t length = in.tellg();
  195. std::vector<uint8_t> data;
  196. if (length > 0) {
  197. in.seekg(0, ios::beg);
  198. data.resize(length);
  199. in.read((char *)&data[0], length);
  200. }
  201. in.close();
  202. return data;
  203. }
  204. class __native_engine_context {
  205. public:
  206. virtual ~__native_engine_context() { destroy(); }
  207. bool construct(const void *pdata, size_t size) {
  208. destroy();
  209. if (pdata == nullptr || size == 0) return false;
  210. runtime_ = shared_ptr<IRuntime>(createInferRuntime(gLogger), destroy_nvidia_pointer<IRuntime>);
  211. if (runtime_ == nullptr) return false;
  212. engine_ = shared_ptr<ICudaEngine>(runtime_->deserializeCudaEngine(pdata, size, nullptr),
  213. destroy_nvidia_pointer<ICudaEngine>);
  214. if (engine_ == nullptr) return false;
  215. context_ = shared_ptr<IExecutionContext>(engine_->createExecutionContext(),
  216. destroy_nvidia_pointer<IExecutionContext>);
  217. return context_ != nullptr;
  218. }
  219. private:
  220. void destroy() {
  221. context_.reset();
  222. engine_.reset();
  223. runtime_.reset();
  224. }
  225. public:
  226. shared_ptr<IExecutionContext> context_;
  227. shared_ptr<ICudaEngine> engine_;
  228. shared_ptr<IRuntime> runtime_ = nullptr;
  229. };
  230. class InferImpl : public Infer {
  231. public:
  232. shared_ptr<__native_engine_context> context_;
  233. unordered_map<string, int> binding_name_to_index_;
  234. virtual ~InferImpl() = default;
  235. bool construct(const void *data, size_t size) {
  236. context_ = make_shared<__native_engine_context>();
  237. if (!context_->construct(data, size)) {
  238. return false;
  239. }
  240. setup();
  241. return true;
  242. }
  243. bool load(const string &file) {
  244. auto data = load_file(file);
  245. if (data.empty()) {
  246. INFO("An empty file has been loaded. Please confirm your file path: %s", file.c_str());
  247. return false;
  248. }
  249. return this->construct(data.data(), data.size());
  250. }
  251. void setup() {
  252. auto engine = this->context_->engine_;
  253. int nbBindings = engine->getNbBindings();
  254. binding_name_to_index_.clear();
  255. for (int i = 0; i < nbBindings; ++i) {
  256. const char *bindingName = engine->getBindingName(i);
  257. binding_name_to_index_[bindingName] = i;
  258. }
  259. }
  260. virtual int index(const std::string &name) override {
  261. auto iter = binding_name_to_index_.find(name);
  262. Assertf(iter != binding_name_to_index_.end(), "Can not found the binding name: %s",
  263. name.c_str());
  264. return iter->second;
  265. }
  266. virtual bool forward(const std::vector<void *> &bindings, void *stream,
  267. void *input_consum_event) override {
  268. return this->context_->context_->enqueueV2((void**)bindings.data(), (cudaStream_t)stream,
  269. (cudaEvent_t *)input_consum_event);
  270. }
  271. virtual std::vector<int> run_dims(const std::string &name) override {
  272. return run_dims(index(name));
  273. }
  274. virtual std::vector<int> run_dims(int ibinding) override {
  275. auto dim = this->context_->context_->getBindingDimensions(ibinding);
  276. return std::vector<int>(dim.d, dim.d + dim.nbDims);
  277. }
  278. virtual std::vector<int> static_dims(const std::string &name) override {
  279. return static_dims(index(name));
  280. }
  281. virtual std::vector<int> static_dims(int ibinding) override {
  282. auto dim = this->context_->engine_->getBindingDimensions(ibinding);
  283. return std::vector<int>(dim.d, dim.d + dim.nbDims);
  284. }
  285. virtual int num_bindings() override { return this->context_->engine_->getNbBindings(); }
  286. virtual bool is_input(int ibinding) override {
  287. return this->context_->engine_->bindingIsInput(ibinding);
  288. }
  289. virtual bool set_run_dims(const std::string &name, const std::vector<int> &dims) override {
  290. return this->set_run_dims(index(name), dims);
  291. }
  292. virtual bool set_run_dims(int ibinding, const std::vector<int> &dims) override {
  293. Dims d;
  294. memcpy(d.d, dims.data(), sizeof(int) * dims.size());
  295. d.nbDims = dims.size();
  296. return this->context_->context_->setBindingDimensions(ibinding, d);
  297. }
  298. virtual int numel(const std::string &name) override { return numel(index(name)); }
  299. virtual int numel(int ibinding) override {
  300. auto dim = this->context_->context_->getBindingDimensions(ibinding);
  301. return std::accumulate(dim.d, dim.d + dim.nbDims, 1, std::multiplies<int>());
  302. }
  303. virtual DType dtype(const std::string &name) override { return dtype(index(name)); }
  304. virtual DType dtype(int ibinding) override {
  305. return (DType)this->context_->engine_->getBindingDataType(ibinding);
  306. }
  307. virtual bool has_dynamic_dim() override {
  308. // check if any input or output bindings have dynamic shapes
  309. // code from ChatGPT
  310. int numBindings = this->context_->engine_->getNbBindings();
  311. for (int i = 0; i < numBindings; ++i) {
  312. nvinfer1::Dims dims = this->context_->engine_->getBindingDimensions(i);
  313. for (int j = 0; j < dims.nbDims; ++j) {
  314. if (dims.d[j] == -1) return true;
  315. }
  316. }
  317. return false;
  318. }
  319. virtual void print() override {
  320. INFO("Infer %p [%s]", this, has_dynamic_dim() ? "DynamicShape" : "StaticShape");
  321. int num_input = 0;
  322. int num_output = 0;
  323. auto engine = this->context_->engine_;
  324. for (int i = 0; i < engine->getNbBindings(); ++i) {
  325. if (engine->bindingIsInput(i))
  326. num_input++;
  327. else
  328. num_output++;
  329. }
  330. INFO("Inputs: %d", num_input);
  331. for (int i = 0; i < num_input; ++i) {
  332. auto name = engine->getBindingName(i);
  333. auto dim = engine->getBindingDimensions(i);
  334. INFO("\t%d.%s : shape {%s}", i, name, format_shape(dim).c_str());
  335. }
  336. INFO("Outputs: %d", num_output);
  337. for (int i = 0; i < num_output; ++i) {
  338. auto name = engine->getBindingName(i + num_input);
  339. auto dim = engine->getBindingDimensions(i + num_input);
  340. INFO("\t%d.%s : shape {%s}", i, name, format_shape(dim).c_str());
  341. }
  342. }
  343. };
  344. Infer *loadraw(const std::string &file) {
  345. InferImpl *impl = new InferImpl();
  346. if (!impl->load(file)) {
  347. delete impl;
  348. impl = nullptr;
  349. }
  350. return impl;
  351. }
// Load an engine file and wrap the result in a shared_ptr. The pointer is
// created as shared_ptr<InferImpl> (hence the cast) so the deleter destroys
// the concrete type directly, then converts to shared_ptr<Infer> on return.
// A failed load yields an empty shared_ptr (loadraw returns nullptr).
std::shared_ptr<Infer> load(const std::string &file) {
  return std::shared_ptr<InferImpl>((InferImpl *)loadraw(file));
}
  355. std::string format_shape(const std::vector<int> &shape) {
  356. stringstream output;
  357. char buf[64];
  358. const char *fmts[] = {"%d", "x%d"};
  359. for (int i = 0; i < (int)shape.size(); ++i) {
  360. snprintf(buf, sizeof(buf), fmts[i != 0], shape[i]);
  361. output << buf;
  362. }
  363. return output.str();
  364. }
  365. }; // namespace trt