leon 4 semanas atrás
pai
commit
0bb0c65fa3
3 arquivos alterados com 16 adições e 7 exclusões
  1. 6 6
      src/main.cpp
  2. 7 0
      src/nodes/infer/inferNode.cpp
  3. 3 1
      src/nodes/infer/inferNode.hpp

+ 6 - 6
src/main.cpp

@@ -47,9 +47,9 @@ void test_yolo()
     std::shared_ptr<GNode::StreamNode> src_node0   = std::make_shared<GNode::StreamNode>("src0", "carperson.mp4", 0, GNode::DecodeType::GPU);
     src_node0->set_skip_frame(1);
 
-    std::shared_ptr<Infer> yolo_model = load("model/carperson.engine", ModelType::YOLOV5, names, 0, 0.25, 0.45);
+    std::shared_ptr<Infer> yolo_model = load("model/carperson.engine", ModelType::YOLOV5, names, 1, 0.25, 0.45);
     std::shared_ptr<GNode::InferNode> infer_node   = std::make_shared<GNode::InferNode>("yolov5");
-    infer_node->set_model_instance(yolo_model, ModelType::YOLO11);
+    infer_node->set_model_instance(yolo_model, ModelType::YOLO11, 1);
 
     std::shared_ptr<GNode::TrackNode> track_node     = std::make_shared<GNode::TrackNode>("tracker", "person", 30, 30);
 
@@ -88,13 +88,13 @@ void test_multi()
     std::shared_ptr<GNode::StreamNode> src_node2   = std::make_shared<GNode::StreamNode>("src2", "rtsp://admin:lww123456@172.16.22.16:554/Streaming/Channels/301", 0, GNode::DecodeType::GPU);
     src_node2->set_skip_frame(10);
 
-    std::shared_ptr<Infer> yolo_model = load("model/yolo11s.engine", ModelType::YOLO11, names, 0, 0.25, 0.45);
+    std::shared_ptr<Infer> yolo_model = load("model/yolo11s.engine", ModelType::YOLO11, names, 1, 0.25, 0.45);
     std::shared_ptr<GNode::InferNode> infer_node1   = std::make_shared<GNode::InferNode>("yolo11");
-    infer_node1->set_model_instance(yolo_model, ModelType::YOLO11);
+    infer_node1->set_model_instance(yolo_model, ModelType::YOLO11, 1);
 
-    std::shared_ptr<Infer> depth_model = load("model/depth.engine", ModelType::DEPTH_ANYTHING, {}, 0, 0.25, 0.45);
+    std::shared_ptr<Infer> depth_model = load("model/depth.engine", ModelType::DEPTH_ANYTHING, {}, 1, 0.25, 0.45);
     std::shared_ptr<GNode::InferNode> infer_node2   = std::make_shared<GNode::InferNode>("depth");
-    infer_node2->set_model_instance(depth_model, ModelType::DEPTH_ANYTHING);
+    infer_node2->set_model_instance(depth_model, ModelType::DEPTH_ANYTHING, 1);
 
     std::shared_ptr<GNode::DrawNode> draw_node     = std::make_shared<GNode::DrawNode>("draw");
     std::shared_ptr<GNode::HttpPushNode> push_node = std::make_shared<GNode::HttpPushNode>("push", "172.16.20.168", 8080, "/push");

+ 7 - 0
src/nodes/infer/inferNode.cpp

@@ -2,6 +2,7 @@
 #include "nodes/infer/inferNode.hpp"
 #include "common/image.hpp"
 #include "common/utils.hpp"
+#include "common/check.hpp"
 #include <unordered_map>
 #include <random>
 #include <algorithm>
@@ -24,6 +25,12 @@ void print_mat(const cv::Mat& mat, int max_rows = 10, int max_cols = 10)
 void InferNode::work()
 {
     printf("InferNode %s\n", name_.c_str());
+    if (!model_)
+    {
+        printf("model is nullptr\n");
+        return;
+    }
+    checkRuntime(cudaSetDevice(device_id_));
     while (running_)
     {
         // Timer timer("InferNode");

+ 3 - 1
src/nodes/infer/inferNode.hpp

@@ -15,9 +15,10 @@ public:
     InferNode(const std::string& name) : BaseNode(name, NODE_TYPE::MID_NODE) {}
     virtual ~InferNode()  { stop(); };
 
-    void set_model_instance(std::shared_ptr<Infer> model, ModelType model_type)
+    void set_model_instance(std::shared_ptr<Infer> model, ModelType model_type, int device_id=0)
     {
         model_ = model;
+        device_id_ = device_id;
         model_type_ = model_type;
     }
 
@@ -26,6 +27,7 @@ public:
 private:
     std::shared_ptr<Infer> model_ = nullptr;
     ModelType model_type_;
+    int device_id_;
 };
 
 }