tensorrt8.hpp 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. #ifndef TENSORRT8_HPP__
  2. #define TENSORRT8_HPP__
  #include <memory>
  #include <string>
  #include <vector>

  #include "common/memory.hpp"
  /// Thin abstraction over a TensorRT 8 engine: an abstract Engine interface plus
  /// free functions to load one from disk and pretty-print shapes.
  namespace TensorRT8
  {

  /// Element type of an engine binding.
  /// NOTE(review): the explicit numeric values appear to mirror nvinfer1::DataType
  /// (kFLOAT=0 .. kUINT8=5) so implementations can cast between the two — confirm
  /// before reordering or inserting enumerators.
  enum class DType : int { FLOAT = 0, HALF = 1, INT8 = 2, INT32 = 3, BOOL = 4, UINT8 = 5 };

  /// Abstract interface to a loaded engine. Concrete instances are produced by load()
  /// and shared via std::shared_ptr<Engine>.
  ///
  /// Bindings can be addressed either by tensor name or by integer binding index;
  /// every name-based accessor has an index-based overload.
  class Engine
  {
  public:
      /// Engine is a polymorphic base handed out through base-class smart pointers, so
      /// its destructor must be virtual: otherwise destroying an implementation through
      /// an Engine* (delete, unique_ptr<Engine>, or shared_ptr<Engine> built from an
      /// Engine*) skips the derived destructor and is undefined behavior.
      virtual ~Engine() = default;

      /// Run inference over the given bindings.
      /// @param bindings  one buffer pointer per binding — presumably ordered by binding
      ///                  index (0 .. num_bindings()-1); TODO confirm against implementation.
      /// @param stream    opaque stream handle (presumably a cudaStream_t); nullptr by default.
      /// @param input_consumed_event  opaque event handle, presumably signalled once the
      ///                  input buffers may be reused; nullptr by default.
      /// @return presumably true on success.
      /// NOTE: default arguments of a virtual function are taken from the static type of
      /// the call expression, not from the overrider — call through Engine (or keep the
      /// same defaults in implementations) to get these values.
      virtual bool forward(const std::vector<void *> &bindings, void *stream = nullptr,
                           void *input_consumed_event = nullptr) = 0;

      /// Binding index of the tensor called @p name (behavior for unknown names is
      /// implementation-defined — check the implementation).
      virtual int index(const std::string &name) = 0;

      /// Current (runtime) dimensions of a binding.
      virtual std::vector<int> run_dims(const std::string &name) = 0;
      virtual std::vector<int> run_dims(int ibinding) = 0;

      /// Dimensions recorded in the engine itself; presumably may contain -1 for
      /// dynamic axes — TODO confirm.
      virtual std::vector<int> static_dims(const std::string &name) = 0;
      virtual std::vector<int> static_dims(int ibinding) = 0;

      /// Number of elements of a binding (presumably the product of its run_dims).
      virtual int numel(const std::string &name) = 0;
      virtual int numel(int ibinding) = 0;

      /// Total number of bindings (inputs and outputs).
      virtual int num_bindings() = 0;

      /// Whether binding @p ibinding is an input.
      virtual bool is_input(int ibinding) = 0;

      /// Set the runtime dimensions of a binding (for dynamic-shape engines).
      /// @return presumably false if the dimensions were rejected.
      virtual bool set_run_dims(const std::string &name, const std::vector<int> &dims) = 0;
      virtual bool set_run_dims(int ibinding, const std::vector<int> &dims) = 0;

      /// Element type of a binding.
      virtual DType dtype(const std::string &name) = 0;
      virtual DType dtype(int ibinding) = 0;

      /// Whether any binding has a dynamic dimension.
      virtual bool has_dynamic_dim() = 0;

      /// Print a human-readable description of the engine (destination is
      /// implementation-defined).
      virtual void print() = 0;
  };

  /// Load an engine from @p file.
  /// NOTE(review): failure reporting (nullptr vs. exception) is not visible here — confirm.
  std::shared_ptr<Engine> load(const std::string &file);

  /// Render @p shape as a string (exact format defined by the implementation).
  std::string format_shape(const std::vector<int> &shape);

  } // namespace TensorRT8
  32. #endif