memory.hpp 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. #ifndef __MEMORY_HPP__
  2. #define __MEMORY_HPP__
  #include <cstddef>
  #include <cstdint>
  #include <initializer_list>
  #include <memory>
  #include <new>
  #include <string>
  #include <vector>
  7. namespace tensor
  8. {
  /// Type-erased holder for a host ("cpu") buffer and a device ("gpu") buffer.
  ///
  /// Each side tracks a raw pointer, a byte count, a capacity and an ownership
  /// flag. The constructor taking buffers, the destructor, *_realloc(),
  /// release*() and reference() are defined out of line.
  /// NOTE(review): their exact allocation/ownership behaviour is not visible
  /// here; confirm it against the out-of-line definitions.
  ///
  /// Copying is disabled. A member-wise copy would duplicate cpu_/gpu_ together
  /// with owner_cpu_/owner_gpu_, so two objects would both claim ownership of
  /// the same buffers. If the destructor releases owned buffers, as
  /// release_cpu()/release_gpu() suggest, that is a double free. Because a
  /// destructor is user-declared, no move operations are generated either, so
  /// the type is neither copyable nor movable.
  class BaseMemory
  {
  public:
      BaseMemory() = default;

      /// Wraps caller-supplied host/device buffers of the given sizes in bytes.
      /// NOTE(review): presumably equivalent to reference(), i.e. non-owning;
      /// confirm in the out-of-line definition.
      BaseMemory(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes);

      BaseMemory(const BaseMemory &) = delete;
      BaseMemory &operator=(const BaseMemory &) = delete;

      virtual ~BaseMemory();

      /// Requests a device buffer of `bytes` bytes and returns its address.
      /// NOTE(review): growth policy and failure behaviour are defined out of line.
      virtual void *gpu_realloc(size_t bytes);

      /// Host-side counterpart of gpu_realloc().
      virtual void *cpu_realloc(size_t bytes);

      void release_gpu();
      void release_cpu();
      void release();

      /// Ownership flags; both default to true for a default-constructed object.
      bool owner_gpu() const noexcept { return owner_gpu_; }
      bool owner_cpu() const noexcept { return owner_cpu_; }

      /// Current byte counts of the host/device buffers (0 when nothing is held).
      size_t cpu_bytes() const noexcept { return cpu_bytes_; }
      size_t gpu_bytes() const noexcept { return gpu_bytes_; }

      /// Raw buffer addresses (nullptr when nothing is held). Kept virtual and
      /// not noexcept so that existing overrides elsewhere still compile.
      virtual void *get_gpu() const { return gpu_; }
      virtual void *get_cpu() const { return cpu_; }

      /// Attaches external buffers.
      /// NOTE(review): presumably marks them as not owned; confirm out of line.
      void reference(void *cpu, size_t cpu_bytes, void *gpu, size_t gpu_bytes);

  protected:
      // Host side: address, bytes in use, and (presumably) bytes allocated.
      void *cpu_ = nullptr;
      size_t cpu_bytes_ = 0;
      size_t cpu_capacity_ = 0;
      bool owner_cpu_ = true;

      // Device side: same layout as the host side.
      void *gpu_ = nullptr;
      size_t gpu_bytes_ = 0;
      size_t gpu_capacity_ = 0;
      bool owner_gpu_ = true;
  };
  37. template <typename _DT> class Memory : public BaseMemory
  38. {
  39. public:
  40. Memory() = default;
  41. Memory(const Memory &other) = delete;
  42. Memory &operator=(const Memory &other) = delete;
  43. virtual _DT *gpu(size_t size) { return (_DT *)BaseMemory::gpu_realloc(size * sizeof(_DT)); }
  44. virtual _DT *cpu(size_t size) { return (_DT *)BaseMemory::cpu_realloc(size * sizeof(_DT)); }
  45. inline size_t cpu_size() const { return cpu_bytes_ / sizeof(_DT); }
  46. inline size_t gpu_size() const { return gpu_bytes_ / sizeof(_DT); }
  47. virtual inline _DT *gpu() const { return (_DT *)gpu_; }
  48. virtual inline _DT *cpu() const { return (_DT *)cpu_; }
  49. };
  50. } // namespace tensor
  51. #endif