leon 5 months ago
parent
commit
2507494339
3 changed files with 180 additions and 1 deletion
  1. .vscode/settings.json    +2 -1
  2. src/cpm.hpp              +153 -0
  3. src/main.cpp             +25 -0

+ 2 - 1
.vscode/settings.json

@@ -54,6 +54,7 @@
         "iterator": "cpp",
         "tuple": "cpp",
         "utility": "cpp",
-        "type_traits": "cpp"
+        "type_traits": "cpp",
+        "condition_variable": "cpp"
     }
 }

+ 153 - 0
src/cpm.hpp

@@ -0,0 +1,153 @@
+#ifndef CPM_HPP
+#define CPM_HPP
+
+// Consumer-Producer Model
+
+#include <algorithm>
+#include <atomic>
+#include <condition_variable>
+#include <future>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <vector>
+
+namespace cpm {
+
+template <typename Result, typename Input, typename Model>
+class Instance {
+ protected:
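+  // A queued unit of work: one input plus the promise that delivers its result.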
+  struct Item {
+    Input input;
+    std::shared_ptr<std::promise<Result>> pro;
+  };
+
+  std::condition_variable cond_;
+  std::queue<Item> input_queue_;
+  std::mutex queue_lock_;
+  std::shared_ptr<std::thread> worker_;
+  std::atomic<bool> run_{false};
+  int max_items_processed_ = 0;
+  void *stream_ = nullptr;
+
+ public:
+  virtual ~Instance() { stop(); }
+
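+  // Flush the queue, complete any pending promises with a default Result, and join the worker.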
+  void stop() {
+    {
+      // Setting run_ under the lock closes the race where the worker evaluates the
+      // wait predicate just before the notification and then sleeps through it.
+      std::unique_lock<std::mutex> l(queue_lock_);
+      run_ = false;
+      while (!input_queue_.empty()) {
+        auto &item = input_queue_.front();
+        if (item.pro) item.pro->set_value(Result());
+        input_queue_.pop();
+      }
+    }
+    cond_.notify_one();
+
+    if (worker_) {
+      worker_->join();
+      worker_.reset();
+    }
+  }
+
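+  // Enqueue one input and return a future for its eventual result.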
+  virtual std::shared_future<Result> commit(const Input &input) {
+    Item item;
+    item.input = input;
+    item.pro.reset(new std::promise<Result>());
+    {
+      std::unique_lock<std::mutex> lock(queue_lock_);
+      input_queue_.push(item);
+    }
+    cond_.notify_one();
+    return item.pro->get_future();
+  }
+
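+  // Enqueue a whole batch under a single lock acquisition, returning one future per input.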
+  virtual std::vector<std::shared_future<Result>> commits(const std::vector<Input> &inputs) {
+    std::vector<std::shared_future<Result>> output;
+    {
+      std::unique_lock<std::mutex> lock(queue_lock_);
+      for (int i = 0; i < (int)inputs.size(); ++i) {
+        Item item;
+        item.input = inputs[i];
+        item.pro.reset(new std::promise<Result>());
+        output.emplace_back(item.pro->get_future());
+        input_queue_.push(item);
+      }
+    }
+    cond_.notify_one();
+    return output;
+  }
+
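+  // Launch the worker, load the model on it, and block until loading succeeds or fails.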
+  template <typename LoadMethod>
+  bool start(const LoadMethod &loadmethod, int max_items_processed = 1, void *stream = nullptr) {
+    stop();
+
+    this->stream_ = stream;
+    this->max_items_processed_ = max_items_processed;
+    std::promise<bool> status;
+    worker_ = std::make_shared<std::thread>(&Instance::worker<LoadMethod>, this,
+                                            std::ref(loadmethod), std::ref(status));
+    return status.get_future().get();
+  }
+
+ private:
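+  // Worker loop: load the model, then repeatedly drain up to max_items_processed_
+  // queued items, run one batched forwards() call, and fulfil each item's promise.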
+  template <typename LoadMethod>
+  void worker(const LoadMethod &loadmethod, std::promise<bool> &status) {
+    std::shared_ptr<Model> model = loadmethod();
+    if (model == nullptr) {
+      status.set_value(false);
+      return;
+    }
+
+    run_ = true;
+    status.set_value(true);
+
+    std::vector<Item> fetch_items;
+    std::vector<Input> inputs;
+    while (get_items_and_wait(fetch_items, max_items_processed_)) {
+      inputs.resize(fetch_items.size());
+      std::transform(fetch_items.begin(), fetch_items.end(), inputs.begin(),
+                     [](Item &item) { return item.input; });
+
+      auto ret = model->forwards(inputs, stream_);
+      for (int i = 0; i < (int)fetch_items.size(); ++i) {
+        if (i < (int)ret.size()) {
+          fetch_items[i].pro->set_value(ret[i]);
+        } else {
+          fetch_items[i].pro->set_value(Result());
+        }
+      }
+      inputs.clear();
+      fetch_items.clear();
+    }
+    model.reset();
+    run_ = false;
+  }
+
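+  // Block until input arrives or stop() is requested; then drain up to max_size items.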
+  virtual bool get_items_and_wait(std::vector<Item> &fetch_items, int max_size) {
+    std::unique_lock<std::mutex> l(queue_lock_);
+    cond_.wait(l, [&]() { return !run_ || !input_queue_.empty(); });
+
+    if (!run_) return false;
+
+    fetch_items.clear();
+    for (int i = 0; i < max_size && !input_queue_.empty(); ++i) {
+      fetch_items.emplace_back(std::move(input_queue_.front()));
+      input_queue_.pop();
+    }
+    return true;
+  }
+
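+  // Single-item variant of the fetch; unused by the current worker loop.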
+  virtual bool get_item_and_wait(Item &fetch_item) {
+    std::unique_lock<std::mutex> l(queue_lock_);
+    cond_.wait(l, [&]() { return !run_ || !input_queue_.empty(); });
+
+    if (!run_) return false;
+
+    fetch_item = std::move(input_queue_.front());
+    input_queue_.pop();
+    return true;
+  }
+};
+}  // namespace cpm
+
+#endif  // CPM_HPP
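
cpm::Instance places only two requirements on Model: the load method handed to start() must return a std::shared_ptr<Model>, and the worker calls model->forwards(inputs, stream_) expecting one Result per Input. A minimal sketch of a conforming model follows; EchoModel, its int-to-string behaviour, and the main() below are illustrative assumptions, not part of this commit:

#include <cstdio>
#include <memory>
#include <string>
#include <vector>

#include "cpm.hpp"

// Hypothetical model: "inference" just renders each int input as a string.
struct EchoModel {
  // Signature mirrors the call site in cpm::Instance::worker().
  std::vector<std::string> forwards(const std::vector<int> &inputs, void * /*stream*/) {
    std::vector<std::string> out;
    for (int v : inputs) out.push_back(std::to_string(v));
    return out;
  }
};

int main() {
  cpm::Instance<std::string, int, EchoModel> instance;
  // The load method runs on the worker thread; returning nullptr makes start() report failure.
  if (!instance.start([] { return std::make_shared<EchoModel>(); })) return 1;
  std::printf("%s\n", instance.commit(41).get().c_str());  // prints "41"
  return 0;  // ~Instance() calls stop(), which drains the queue and joins the worker.
}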

+ 25 - 0
src/main.cpp

@@ -9,7 +9,31 @@ using namespace std;
 
 resnet::Image cvimg(const cv::Mat &image) { return resnet::Image(image.data, image.cols, image.rows); }
 
+void perf() {
+  int max_infer_batch = 1;
+  int batch = 1;
+  std::vector<cv::Mat> images{cv::imread("inference/car.jpg"), cv::imread("inference/car.jpg"),
+                              cv::imread("inference/car.jpg")};
 
+  // Pad the image list up to the requested batch size.
+  for (int i = (int)images.size(); i < batch; ++i) images.push_back(images[i % 3]);
+
+  cpm::Instance<resnet::Attribute, resnet::Image, resnet::Infer> cpmi;
+  bool ok = cpmi.start([] { return resnet::load("resnet.engine"); },
+                       max_infer_batch);
+
+  if (!ok) return;
+
+  std::vector<resnet::Image> resnetimages(images.size());
+  std::transform(images.begin(), images.end(), resnetimages.begin(), cvimg);
+
+  trt::Timer timer;
+
+  for (int i = 0; i < 100; ++i) {
+    timer.start();
+    cpmi.commit(resnetimages[0]).get();
+    timer.stop("BATCH1");
+  }
+}
 
 void single_inference() {
   cv::Mat image = cv::imread("inference/car.jpg");
@@ -28,5 +52,6 @@ void single_inference() {
 int main() {
   // [BATCH1]: 0.48650 ms
   single_inference();
+  perf();
   return 0;
 }
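
Note that perf() above only measures single-item latency through commit(). To exercise the queue's batching path, several inputs can be submitted in one call via commits() and the futures drained afterwards. A hedged sketch reusing perf()'s local names (and assuming max_infer_batch is raised above 1 so the worker can actually fuse items):

// Submit every preprocessed image in one lock acquisition, then wait for all results.
auto futures = cpmi.commits(resnetimages);
for (auto &fut : futures) fut.get();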