leon 1 mese fa
parent
commit
ff1e344aa1
3 ha cambiato i file con 22 aggiunte e 9 eliminazioni
  1. 1 1
      src/main.cpp
  2. 18 7
      src/nodes/infer/inferNode.cpp
  3. 3 1
      src/nodes/infer/inferNode.hpp

+ 1 - 1
src/main.cpp

@@ -20,7 +20,7 @@ int main()
 
     std::shared_ptr<Infer> model = load("model/yolo11s.engine", ModelType::YOLO11, names, 0, 0.25, 0.45);
     std::shared_ptr<Node::InferNode> infer_node   = std::make_shared<Node::InferNode>("infer");
-    infer_node->set_model_instance(model);
+    infer_node->set_model_instance(model, ModelType::YOLO11);
 
     std::shared_ptr<Node::DrawNode> draw_node     = std::make_shared<Node::DrawNode>("draw");
     std::shared_ptr<Node::HttpPushNode> push_node = std::make_shared<Node::HttpPushNode>("push", "172.16.20.168", 8080, "/push");

+ 18 - 7
src/nodes/infer/inferNode.cpp

@@ -31,17 +31,28 @@ void InferNode::work()
                 printf("model is nullptr\n");
                 continue;
             }
-            auto res = model_->forward(tensor::cvimg(image));
-            for (auto& r : res)
+            auto det_result = model_->forward(tensor::cvimg(image));
+            if (model_type_ == ModelType::DEPTH_ANYTHING)
             {
-                metaData->boxes.push_back(r);
+                auto result = std::get<cv::Mat>(det_result);
+                // do something
+                continue;
             }
-            metaData->boxes = res;
-            for (auto& output_buffer : output_buffers_)
+            else if (model_type_ == ModelType::YOLO11 || model_type_ == ModelType::YOLOV8 || model_type_ == ModelType::YOLOV5)
             {
-                // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
-                output_buffer.second->push(metaData);
+                auto result = std::get<data::BoxArray>(det_result);
+                for (auto& r : result)
+                {
+                    metaData->boxes.push_back(r);
+                }
+                metaData->boxes = result;
+                for (auto& output_buffer : output_buffers_)
+                {
+                    // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
+                    output_buffer.second->push(metaData);
+                }
             }
+            
         }
         if (!has_data)
         {

+ 3 - 1
src/nodes/infer/inferNode.hpp

@@ -15,15 +15,17 @@ public:
     InferNode(const std::string& name) : BaseNode(name, NODE_TYPE::MID_NODE) {}
     virtual ~InferNode()  { stop(); };
 
-    void set_model_instance(std::shared_ptr<Infer> model)
+    void set_model_instance(std::shared_ptr<Infer> model, ModelType model_type)
     {
         model_ = model;
+        model_type_ = model_type;
     }
 
     void work() override;
 
 private:
     std::shared_ptr<Infer> model_ = nullptr;
+    ModelType model_type_;
 };
 
 }