
update graph

leon, 4 weeks ago
commit 348c9a585e
4 changed files with 233 additions and 5 deletions
  1. src/graph/graph.cpp: +198 -2
  2. src/graph/graph.hpp: +33 -1
  3. src/infer/trt/yolo/yolo.cu: +1 -1
  4. src/infer/trt/yolo/yolo.hpp: +1 -1

src/graph/graph.cpp: +198 -2

@@ -5,16 +5,212 @@ namespace Graph
 
 using json = nlohmann::json;
 
-void read_models()
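+// Helper to map string to ModelType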
+static ModelType string_to_model_type(const std::string& type_str) 
 {
+    if (type_str == "YOLOV5")         return ModelType::YOLOV5;
+    if (type_str == "YOLOV5SEG")      return ModelType::YOLOV5SEG;
+    if (type_str == "YOLO11")         return ModelType::YOLO11;
+    if (type_str == "YOLO11POSE")     return ModelType::YOLO11POSE;
+    if (type_str == "YOLO11SEG")      return ModelType::YOLO11SEG;
+    if (type_str == "DEPTH_ANYTHING") return ModelType::DEPTH_ANYTHING;
+    throw std::runtime_error("Unknown model type string: " + type_str);
+}
 
+// Helper to map string to DecodeType
+static GNode::DecodeType string_to_decode_type(const std::string& type_str)
+{
+    if (type_str == "GPU") return GNode::DecodeType::GPU;
+    if (type_str == "CPU") return GNode::DecodeType::CPU;
+    throw std::runtime_error("Unknown decode type string: " + type_str);
 }
 
 
 void Graph::create_from_json(const std::string& json_path)
 {
+    // 1. Read and parse JSON file
+    std::ifstream json_file(json_path);
+    if (!json_file.is_open()) {
+        throw std::runtime_error("Failed to open JSON file: " + json_path);
+    }
+
     nlohmann::json config;
-    std::ifstream file(json_path);
+    try
+    {
+        json_file >> config;
+    }
+    catch (const nlohmann::json::parse_error& e)
+    {
+        throw std::runtime_error("Failed to parse JSON: " + std::string(e.what()));
+    }
+
+    // shared_models_.clear();
+    // configured_pipelines_.clear();
+
+    // 2. Load Models
+    if (config.contains("models")) 
+    {
+        for (auto& [model_id, model_config] : config["models"].items()) 
+        {
+            try
+            {
+                std::string path = model_config.at("model_path").get<std::string>();
+                std::string type_str = model_config.at("model_type").get<std::string>();
+                std::vector<std::string> names = model_config.at("names").get<std::vector<std::string>>();
+                int gpu_id = model_config.at("gpu_id").get<int>();
+                float conf_thresh = model_config.value("confidence_threshold", 0.25f); // Use .value for optional with default
+                float nms_thresh = model_config.value("nms_threshold", 0.45f);
+
+                ModelType model_type_enum = string_to_model_type(type_str);
+
+                // Load the model using your load function
+                std::shared_ptr<Infer> model_instance = load(
+                    path, model_type_enum, names, gpu_id, conf_thresh, nms_thresh);
+
+                if (!model_instance) {
+                    throw std::runtime_error("Failed to load model: " + model_id);
+                }
+
+                shared_models_[model_id] = model_instance;
+                std::cout << "Loaded model: " << model_id << std::endl;
+
+            } 
+            catch (const std::exception& e) 
+            {
+                throw std::runtime_error("Error processing model '" + model_id + "': " + e.what());
+            }
+        }
+    }
+
+    // 3. Create Pipelines
+    if (config.contains("pipelines")) 
+    {
+        for (const auto& pipeline_config : config["pipelines"]) 
+        {
+            try 
+            {
+                PipelineInstance current_pipeline;
+                current_pipeline.pipeline_id = pipeline_config.at("pipeline_id").get<std::string>();
+                current_pipeline.description = pipeline_config.value("description", ""); // Optional description
+
+                std::cout << "Creating pipeline: " << current_pipeline.pipeline_id << std::endl;
+
+                // Temporary map to hold nodes of the current pipeline for linking
+                std::unordered_map<std::string, std::shared_ptr<GNode::BaseNode>> current_pipeline_nodes_map;
+
+                if (pipeline_config.contains("nodes")) 
+                {
+                    for (const auto& node_config : pipeline_config["nodes"]) 
+                    {
+                        std::string node_id = node_config.at("node_id").get<std::string>();
+                        std::string node_type = node_config.at("node_type").get<std::string>();
+                        const auto& params = node_config.at("params");
+
+                        std::shared_ptr<GNode::BaseNode> new_node = nullptr;
+
+                        // --- Instantiate Node based on type ---
+                        if (node_type == "Source") 
+                        {
+                            std::string url = params.at("stream_url").get<std::string>();
+                            int gpu_id = params.at("gpu_id").get<int>();
+                            std::string decode_str = params.at("decode_type").get<std::string>();
+                            int skip = params.value("skip_frame", 0); // Optional skip_frame
+
+                            GNode::DecodeType decode_type = string_to_decode_type(decode_str);
+                            auto stream_node = std::make_shared<GNode::StreamNode>(node_id, url, gpu_id, decode_type);
+                            stream_node->set_skip_frame(skip);
+                            new_node = stream_node;
+                        } 
+                        else if (node_type == "Inference") 
+                        {
+                            std::string model_id_ref = params.at("model_id").get<std::string>();
+                            if (shared_models_.find(model_id_ref) == shared_models_.end()) 
+                            {
+                                throw std::runtime_error("Model ID '" + model_id_ref + "' not found for node '" + node_id + "'");
+                            }
+                            std::shared_ptr<Infer> model_ptr = shared_models_.at(model_id_ref);
+                            auto infer_node = std::make_shared<GNode::InferNode>(node_id);
+                            infer_node->set_model_instance(model_ptr, model_ptr->get_gpu_id());
+                            new_node = infer_node;
+                        } 
+                        else if (node_type == "Tracker") 
+                        {
+                            std::string track_name = params.at("track_name").get<std::string>();
+                            int track_frame = params.value("track_frame", 30);
+                            int track_dist = params.value("track_distance", 30);
+                            new_node = std::make_shared<GNode::TrackNode>(node_id, track_name, track_frame, track_dist);
+                        }
+                        else if (node_type == "Analyzer") 
+                        {
+                            new_node = std::make_shared<GNode::AnalyzeNode>(node_id);
+                        }
+                        else if (node_type == "Drawer") 
+                        {
+                            new_node = std::make_shared<GNode::DrawNode>(node_id);
+                        } 
+                        else if (node_type == "Push")
+                        {
+                            new_node = std::make_shared<GNode::PushNode>(node_id);
+                        }
+                        else if (node_type == "Recorder") 
+                        {
+                            std::string record_path = params.at("record_path").get<std::string>();
+                            auto record_node = std::make_shared<GNode::RecordNode>(node_id);
+                            record_node->set_record_path(record_path);
+                            if (params.contains("fps")) 
+                            {
+                                record_node->set_fps(params["fps"].get<int>());
+                            }
+                            if (params.contains("fourcc")) 
+                            {
+                                std::string fourcc_str = params["fourcc"].get<std::string>();
+                                if (fourcc_str.length() == 4) 
+                                {
+                                    record_node->set_fourcc(cv::VideoWriter::fourcc(fourcc_str[0], fourcc_str[1], fourcc_str[2], fourcc_str[3]));
+                                } 
+                                else 
+                                {
+                                    std::cerr << "Warning: Invalid fourcc string '" << fourcc_str << "' for node " << node_id << ". Using default." << std::endl;
+                                }
+                            }
+                            new_node = record_node;
+                        } 
+                        else 
+                        {
+                            throw std::runtime_error("Unknown node type '" + node_type + "' for node ID '" + node_id + "'");
+                        }
+
+                        if (new_node) 
+                        {
+                            current_pipeline.nodes.push_back(new_node);
+                            current_pipeline_nodes_map[node_id] = new_node;
+                            std::cout << "  Created node: " << node_id << " (" << node_type << ")" << std::endl;
+                        }
+                    }
+
+                    // --- Link nodes within the current pipeline ---
+                    if (current_pipeline.nodes.size() > 1) 
+                    {
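+                        // Queue capacity and overflow policy are hard-coded here rather than read from JSON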
+                        int max_queue_size = 100;
+                        OverflowStrategy strategy = OverflowStrategy::BlockTimeout;
+
+                        for (size_t i = 0; i < current_pipeline.nodes.size() - 1; ++i) 
+                        {
+                            GNode::LinkNode(current_pipeline.nodes[i],
+                                            current_pipeline.nodes[i + 1],
+                                            max_queue_size,
+                                            strategy);
+                        }
+                    }
+                }
+                configured_pipelines_.push_back(std::move(current_pipeline));
+            } 
+            catch (const std::exception& e) 
+            {
+                std::string pipeline_id = pipeline_config.value("pipeline_id", "UNKNOWN");
+                throw std::runtime_error("Error processing pipeline '" + pipeline_id + "': " + e.what());
+            }
+        }
+    }
 }
 
 }
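
For reference, here is a minimal config sketch that this parser would accept, inferred from the keys read above. All IDs, paths, URLs, and class names are placeholders, not files from this repository:

{
    "models": {
        "det0": {
            "model_path": "models/yolo11.engine",
            "model_type": "YOLO11",
            "names": ["person", "car"],
            "gpu_id": 0,
            "confidence_threshold": 0.25,
            "nms_threshold": 0.45
        }
    },
    "pipelines": [
        {
            "pipeline_id": "cam0",
            "description": "detect, track, draw, record",
            "nodes": [
                { "node_id": "src",   "node_type": "Source",    "params": { "stream_url": "rtsp://host/stream", "gpu_id": 0, "decode_type": "GPU", "skip_frame": 1 } },
                { "node_id": "infer", "node_type": "Inference", "params": { "model_id": "det0" } },
                { "node_id": "track", "node_type": "Tracker",   "params": { "track_name": "person", "track_frame": 30, "track_distance": 30 } },
                { "node_id": "draw",  "node_type": "Drawer",    "params": {} },
                { "node_id": "rec",   "node_type": "Recorder",  "params": { "record_path": "out.mp4", "fps": 25, "fourcc": "mp4v" } }
            ]
        }
    ]
}

Nodes are linked in listed order (Source → Inference → Tracker → Drawer → Recorder above), and every node must carry a "params" object, even an empty one, since the parser calls node_config.at("params").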

src/graph/graph.hpp: +33 -1

@@ -6,6 +6,7 @@
 #include <string>
 #include <unordered_map>
 #include "common/json.hpp"
+#include "infer/infer.hpp"
 #include "nodes/base/base.hpp"
 #include "nodes/stream/streamNode.hpp"
 #include "nodes/infer/inferNode.hpp"
@@ -39,10 +40,41 @@ public:
     }
 
     // Get the loaded shared models (read-only)
-    const std::map<std::string, std::shared_ptr<GNode::Infer>>& getSharedModels() const {
+    const std::unordered_map<std::string, std::shared_ptr<Infer>>& getSharedModels() const {
         return shared_models_;
     }
 
+    void start_pipelines()
+    {
+        for (auto& instance : configured_pipelines_)
+        {
+            if (!instance.nodes.empty())
+            {
+                std::cout << "Starting pipeline: " << instance.pipeline_id << std::endl;
+                // Start nodes in reverse order so downstream consumers are
+                // ready before upstream producers begin pushing frames
+                for (auto it = instance.nodes.rbegin(); it != instance.nodes.rend(); ++it)
+                {
+                    (*it)->start(); // Assuming a start() method exists
+                }
+            }
+        }
+    }
+
+    void stop_pipelines()
+    {
+        for (auto& instance : configured_pipelines_)
+        {
+            if (!instance.nodes.empty())
+            {
+                std::cout << "Stopping pipeline: " << instance.pipeline_id << std::endl;
+                // Stop nodes front-to-back so the source stops feeding frames first
+                for (const auto& node : instance.nodes)
+                {
+                    node->stop(); // Assuming a stop() method exists
+                }
+            }
+        }
+    }
+
 
     void create_from_json(const std::string& json_path);
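
Together with create_from_json, these helpers cover the full pipeline lifecycle. A minimal usage sketch, assuming the class is Graph inside namespace Graph and is default-constructible (neither detail appears in this diff), and that the config above is saved as config.json:

#include "graph/graph.hpp"

int main()
{
    Graph::Graph graph;                     // hypothetical: constructor not shown in this diff
    graph.create_from_json("config.json");  // parses config, loads models, builds and links nodes
    graph.start_pipelines();                // starts each pipeline's nodes in reverse order (sinks first)

    // ... run until shutdown is requested ...

    graph.stop_pipelines();                 // stops nodes in forward order (sources first)
    return 0;
}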
 

src/infer/trt/yolo/yolo.cu: +1 -1

@@ -778,7 +778,7 @@ Infer *loadraw(const std::string &engine_file, ModelType model_type, const std::vector<std::string> &names, float confidence_threshold,
                float nms_threshold, int gpu_id) 
 {
     YoloModelImpl *impl = new YoloModelImpl();
-    if (!impl->load(engine_file, model_type, names, confidence_threshold, nms_threshold)) 
+    if (!impl->load(engine_file, model_type, names, confidence_threshold, nms_threshold, gpu_id)) 
     {
         delete impl;
         impl = nullptr;

src/infer/trt/yolo/yolo.hpp: +1 -1

@@ -87,7 +87,7 @@ namespace yolo
         
         std::shared_ptr<data::InstanceSegmentMap> decode_segment(int imemory, float* pbox, void *stream);
     
-        bool load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold, float nms_threshold);
+        bool load(const std::string &engine_file, ModelType model_type, const std::vector<std::string>& names, float confidence_threshold, float nms_threshold, int gpu_id);
     
         virtual Result forward(const tensor::Image &image, int slice_width, int slice_height, float overlap_width_ratio, float overlap_height_ratio, void *stream = nullptr);