浏览代码

random box

leon 1 月之前
父节点
当前提交
1802e4b705
共有 4 个文件被更改,包括 179 次插入6 次删除
  1. 2 1
      .vscode/settings.json
  2. 154 0
      src/infer/cpm.hpp
  3. 14 1
      src/nodes/draw/drawNode.cpp
  4. 9 4
      src/nodes/infer/inferNode.cpp

+ 2 - 1
.vscode/settings.json

@@ -53,6 +53,7 @@
         "variant": "cpp",
         "vector": "cpp",
         "algorithm": "cpp",
-        "tuple": "cpp"
+        "tuple": "cpp",
+        "future": "cpp"
     }
 }

+ 154 - 0
src/infer/cpm.hpp

@@ -0,0 +1,154 @@
+#ifndef __CPM_HPP__
+#define __CPM_HPP__
+
+// Comsumer Producer Model
+
+#include <algorithm>
+#include <condition_variable>
+#include <future>
+#include <memory>
+#include <queue>
+#include <thread>
+
+namespace cpm {
+
+template <typename Result, typename Input, typename Model>
+class Instance {
+ protected:
+  struct Item {
+    Input input;
+    std::shared_ptr<std::promise<Result>> pro;
+  };
+
+  std::condition_variable cond_;
+  std::queue<Item> input_queue_;
+  std::mutex queue_lock_;
+  std::shared_ptr<std::thread> worker_;
+  volatile bool run_ = false;
+  volatile int max_items_processed_ = 0;
+  void *stream_ = nullptr;
+
+ public:
+  virtual ~Instance() { stop(); }
+
+  void stop() {
+    run_ = false;
+    cond_.notify_one();
+    {
+      std::unique_lock<std::mutex> l(queue_lock_);
+      while (!input_queue_.empty()) {
+        auto &item = input_queue_.front();
+        if (item.pro) item.pro->set_value(Result());
+        input_queue_.pop();
+      }
+    };
+
+    if (worker_) {
+      worker_->join();
+      worker_.reset();
+    }
+  }
+
+  virtual std::shared_future<Result> commit(const Input &input) {
+    Item item;
+    item.input = input;
+    item.pro.reset(new std::promise<Result>());
+    {
+      std::unique_lock<std::mutex> __lock_(queue_lock_);
+      input_queue_.push(item);
+    }
+    cond_.notify_one();
+    return item.pro->get_future();
+  }
+
+  virtual std::vector<std::shared_future<Result>> commits(const std::vector<Input> &inputs) {
+    std::vector<std::shared_future<Result>> output;
+    {
+      std::unique_lock<std::mutex> __lock_(queue_lock_);
+      for (int i = 0; i < (int)inputs.size(); ++i) {
+        Item item;
+        item.input = inputs[i];
+        item.pro.reset(new std::promise<Result>());
+        output.emplace_back(item.pro->get_future());
+        input_queue_.push(item);
+      }
+    }
+    cond_.notify_one();
+    return output;
+  }
+
+  template <typename LoadMethod>
+  bool start(const LoadMethod &loadmethod, int max_items_processed = 1, void *stream = nullptr) {
+    stop();
+
+    this->stream_ = stream;
+    this->max_items_processed_ = max_items_processed;
+    std::promise<bool> status;
+    worker_ = std::make_shared<std::thread>(&Instance::worker<LoadMethod>, this,
+                                            std::ref(loadmethod), std::ref(status));
+    return status.get_future().get();
+  }
+
+ private:
+  template <typename LoadMethod>
+  void worker(const LoadMethod &loadmethod, std::promise<bool> &status) {
+    std::shared_ptr<Model> model = loadmethod();
+    if (model == nullptr) {
+      status.set_value(false);
+      return;
+    }
+
+    run_ = true;
+    status.set_value(true);
+
+    std::vector<Item> fetch_items;
+    std::vector<Input> inputs;
+    while (get_items_and_wait(fetch_items, max_items_processed_)) {
+      inputs.resize(fetch_items.size());
+      std::transform(fetch_items.begin(), fetch_items.end(), inputs.begin(),
+                     [](Item &item) { return item.input; });
+
+      auto ret = model->forwards(inputs, stream_);
+      for (int i = 0; i < (int)fetch_items.size(); ++i) {
+        if (i < (int)ret.size()) {
+          fetch_items[i].pro->set_value(ret[i]);
+        } else {
+          fetch_items[i].pro->set_value(Result());
+        }
+      }
+      inputs.clear();
+      fetch_items.clear();
+    }
+    model.reset();
+    run_ = false;
+  }
+
+  virtual bool get_items_and_wait(std::vector<Item> &fetch_items, int max_size) {
+    std::unique_lock<std::mutex> l(queue_lock_);
+    cond_.wait(l, [&]() { return !run_ || !input_queue_.empty(); });
+
+    if (!run_) return false;
+
+    fetch_items.clear();
+    for (int i = 0; i < max_size && !input_queue_.empty(); ++i) {
+      fetch_items.emplace_back(std::move(input_queue_.front()));
+      input_queue_.pop();
+    }
+    return true;
+  }
+
+  virtual bool get_item_and_wait(Item &fetch_item) {
+    std::unique_lock<std::mutex> l(queue_lock_);
+    cond_.wait(l, [&]() { return !run_ || !input_queue_.empty(); });
+
+    if (!run_) return false;
+
+    fetch_item = std::move(input_queue_.front());
+    input_queue_.pop();
+    return true;
+  }
+};
+
+}  // namespace cpm
+
+#endif  // __CPM_HPP__

+ 14 - 1
src/nodes/draw/drawNode.cpp

@@ -2,10 +2,22 @@
 #include "nodes/draw/drawNode.hpp"
 #include "nodes/draw/position.hpp"
 #include <opencv2/opencv.hpp>
+#include <chrono>
+
 
 namespace Node
 {
 
+static std::string getTimeString() {
+    auto now = std::chrono::system_clock::now();
+    auto t = std::chrono::system_clock::to_time_t(now);
+    std::tm tm = *std::localtime(&t);
+
+    std::ostringstream oss;
+    oss << std::put_time(&tm, "%Y_%m_%d_%H_%M_%S");
+    return oss.str();
+}
+
 static std::tuple<int, int, int> getFontSize(const std::string& text)
 {
     int baseline = 0;
@@ -40,7 +52,8 @@ void DrawNode::work()
                 cv::putText(image, box.label, cv::Point(x, y), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 255, 0), 2);
             }
             metaData->draw_image = image;
-            cv::imwrite("dtest.jpg", image);
+            std::string image_name = getTimeString() + ".jpg";
+            cv::imwrite(image_name, image);
             for (auto& output_buffer : output_buffers_)
             {
                 printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());

+ 9 - 4
src/nodes/infer/inferNode.cpp

@@ -1,7 +1,7 @@
 #include "nodes/base/base.hpp"
 #include "nodes/infer/inferNode.hpp"
 #include <unordered_map>
-
+#include <random>
 namespace Node
 {
 
@@ -20,11 +20,16 @@ void InferNode::work()
             }
             printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
             // do something
-            // cv::Mat image = metaData->image;
+            cv::Mat image = metaData->image;
+            int width = image.cols;
+            int height = image.rows;
 
             // cv::imwrite("test.jpg", image);
-            metaData->boxes.push_back(data::Box(0, 0, 100, 100, 0.9, "test"));
-            metaData->boxes.push_back(data::Box(0, 0, 100, 100, 0.9, "test"));
+            float x = rand() % width;
+            float y = rand() % height;
+            float w = rand() % (width - x);
+            float h = rand() % (height - y);
+            metaData->boxes.push_back(data::Box(x, y, x + w, y + h, 0.9, "test"));
             for (auto& output_buffer : output_buffers_)
             {
                 printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());