graph.cpp

#include "graph/graph.hpp"

#include <nlohmann/json.hpp>

#include <fstream>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

namespace Graph
{

using json = nlohmann::json;

// Helper to map a config string to the ModelType enum.
static ModelType string_to_model_type(const std::string& type_str)
{
    if (type_str == "YOLOV5") return ModelType::YOLOV5;
    if (type_str == "YOLOV5SEG") return ModelType::YOLOV5SEG;
    if (type_str == "YOLO11") return ModelType::YOLO11;
    if (type_str == "YOLO11POSE") return ModelType::YOLO11POSE;
    if (type_str == "YOLO11SEG") return ModelType::YOLO11SEG;
    if (type_str == "DEPTH_ANYTHING") return ModelType::DEPTH_ANYTHING;
    throw std::runtime_error("Unknown model type string: " + type_str);
}

// Helper to map a config string to GNode::DecodeType.
static GNode::DecodeType string_to_decode_type(const std::string& type_str)
{
    if (type_str == "GPU") return GNode::DecodeType::GPU;
    if (type_str == "CPU") return GNode::DecodeType::CPU;
    if (type_str == "FOLDER") return GNode::DecodeType::FOLDER;
    throw std::runtime_error("Unknown decode type string: " + type_str);
}
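
/*
 * Expected configuration layout, reconstructed from the parsing logic in
 * create_from_json() below. The values shown are illustrative placeholders,
 * not shipped defaults. "node_type" is one of: Source, Inference, Tracker,
 * Analyzer, Drawer, Push, Recorder. Only confidence_threshold, nms_threshold,
 * description, skip_frame, track_frame, track_distance, and fps are optional;
 * their fallback values appear in the code.
 *
 * {
 *   "models": {
 *     "det0": {
 *       "model_path": "models/yolo11.engine",
 *       "model_type": "YOLO11",
 *       "names": ["person", "car"],
 *       "gpu_id": 0,
 *       "confidence_threshold": 0.25,
 *       "nms_threshold": 0.45
 *     }
 *   },
 *   "pipelines": [
 *     {
 *       "pipeline_id": "cam0",
 *       "description": "optional text",
 *       "nodes": [
 *         { "node_id": "src",  "node_type": "Source",
 *           "params": { "stream_url": "rtsp://...", "gpu_id": 0,
 *                       "decode_type": "GPU", "skip_frame": 0 } },
 *         { "node_id": "det",  "node_type": "Inference",
 *           "params": { "model_id": "det0" } },
 *         { "node_id": "draw", "node_type": "Drawer", "params": {} }
 *       ]
 *     }
 *   ]
 * }
 */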
void Graph::create_from_json(const std::string& json_path)
{
    // 1. Open and parse the JSON configuration file.
    std::ifstream json_file(json_path);
    if (!json_file.is_open())
    {
        throw std::runtime_error("Failed to open JSON file: " + json_path);
    }

    json config;
    try
    {
        json_file >> config;
    }
    catch (const json::parse_error& e)
    {
        throw std::runtime_error("Failed to parse JSON: " + std::string(e.what()));
    }

    // shared_models_.clear();
    // configured_pipelines_.clear();

    // 2. Load the models shared across pipelines.
    if (config.contains("models"))
    {
        for (auto& [model_id, model_config] : config["models"].items())
        {
            try
            {
                std::string path = model_config.at("model_path").get<std::string>();
                std::string type_str = model_config.at("model_type").get<std::string>();
                std::vector<std::string> names =
                    model_config.at("names").get<std::vector<std::string>>();
                int gpu_id = model_config.at("gpu_id").get<int>();
                // .value() reads an optional key, falling back to the given default.
                float conf_thresh = model_config.value("confidence_threshold", 0.25f);
                float nms_thresh = model_config.value("nms_threshold", 0.45f);

                ModelType model_type_enum = string_to_model_type(type_str);

                std::shared_ptr<Infer> model_instance = load(
                    path, model_type_enum, names, gpu_id, conf_thresh, nms_thresh);
                if (!model_instance)
                {
                    throw std::runtime_error("Failed to load model: " + model_id);
                }
                shared_models_[model_id] = model_instance;
                std::cout << "Loaded model: " << model_id << std::endl;
            }
            catch (const std::exception& e)
            {
                throw std::runtime_error("Error processing model '" + model_id + "': " + e.what());
            }
        }
    }
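
    // Because loaded models live in shared_models_, several pipelines (or
    // several Inference nodes) can reference the same "model_id" and reuse one
    // loaded instance instead of loading the weights again, e.g. (illustrative):
    //
    //   "pipelines": [
    //     { "pipeline_id": "cam0", "nodes": [ ...,
    //       { "node_type": "Inference", "params": { "model_id": "det0" } }, ... ] },
    //     { "pipeline_id": "cam1", "nodes": [ ...,
    //       { "node_type": "Inference", "params": { "model_id": "det0" } }, ... ] }
    //   ]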
    // 3. Create the pipelines and their nodes.
    if (config.contains("pipelines"))
    {
        for (const auto& pipeline_config : config["pipelines"])
        {
            try
            {
                PipelineInstance current_pipeline;
                current_pipeline.pipeline_id = pipeline_config.at("pipeline_id").get<std::string>();
                current_pipeline.description = pipeline_config.value("description", ""); // optional
                std::cout << "Creating pipeline: " << current_pipeline.pipeline_id << std::endl;

                // Map from node_id to node. The linking below is purely sequential,
                // so this map is kept only for id-based lookups.
                std::unordered_map<std::string, std::shared_ptr<GNode::BaseNode>> current_pipeline_nodes_map;

                if (pipeline_config.contains("nodes"))
                {
                    for (const auto& node_config : pipeline_config["nodes"])
                    {
                        std::string node_id = node_config.at("node_id").get<std::string>();
                        std::string node_type = node_config.at("node_type").get<std::string>();
                        const auto& params = node_config.at("params");
                        std::shared_ptr<GNode::BaseNode> new_node = nullptr;

                        // --- Instantiate the node based on its type ---
                        if (node_type == "Source")
                        {
                            std::string url = params.at("stream_url").get<std::string>();
                            int gpu_id = params.at("gpu_id").get<int>();
                            std::string decode_str = params.at("decode_type").get<std::string>();
                            int skip = params.value("skip_frame", 0); // optional
                            GNode::DecodeType decode_type = string_to_decode_type(decode_str);
                            auto stream_node = std::make_shared<GNode::StreamNode>(node_id, url, gpu_id, decode_type);
                            stream_node->set_skip_frame(skip);
                            new_node = stream_node;
                        }
                        else if (node_type == "Inference")
                        {
                            std::string model_id_ref = params.at("model_id").get<std::string>();
                            if (shared_models_.find(model_id_ref) == shared_models_.end())
                            {
                                throw std::runtime_error("Model ID '" + model_id_ref + "' not found for node '" + node_id + "'");
                            }
                            std::shared_ptr<Infer> model_ptr = shared_models_.at(model_id_ref);
                            auto infer_node = std::make_shared<GNode::InferNode>(node_id);
                            infer_node->set_model_instance(model_ptr, model_ptr->get_gpu_id());
                            new_node = infer_node;
                        }
                        else if (node_type == "Tracker")
                        {
                            std::string track_name = params.at("track_name").get<std::string>();
                            int track_frame = params.value("track_frame", 30);
                            int track_dist = params.value("track_distance", 30);
                            new_node = std::make_shared<GNode::TrackNode>(node_id, track_name, track_frame, track_dist);
                        }
                        else if (node_type == "Analyzer")
                        {
                            new_node = std::make_shared<GNode::AnalyzeNode>(node_id);
                        }
                        else if (node_type == "Drawer")
                        {
                            new_node = std::make_shared<GNode::DrawNode>(node_id);
                        }
                        else if (node_type == "Push")
                        {
                            new_node = std::make_shared<GNode::HttpPushNode>(node_id);
                        }
                        else if (node_type == "Recorder")
                        {
                            std::string gst_pipeline = params.at("gst_pipeline").get<std::string>();
                            int fps = params.value("fps", 25);
                            new_node = std::make_shared<GNode::RtmpNode>(node_id, gst_pipeline, fps);
                        }
                        else
                        {
                            throw std::runtime_error("Unknown node type '" + node_type + "' for node ID '" + node_id + "'");
                        }

                        if (new_node)
                        {
                            current_pipeline.nodes.push_back(new_node);
                            current_pipeline_nodes_map[node_id] = new_node;
                            std::cout << " Created node: " << node_id << " (" << node_type << ")" << std::endl;
                        }
                    }

                    // --- Link the nodes into a linear chain, in declaration order ---
                    if (current_pipeline.nodes.size() > 1)
                    {
                        int max_queue_size = 100;
                        OverflowStrategy strategy = OverflowStrategy::BlockTimeout;
                        for (size_t i = 0; i < current_pipeline.nodes.size() - 1; ++i)
                        {
                            GNode::LinkNode(current_pipeline.nodes[i],
                                            current_pipeline.nodes[i + 1],
                                            max_queue_size,
                                            strategy);
                        }
                    }
                }
                configured_pipelines_.push_back(std::move(current_pipeline));
            }
            catch (const std::exception& e)
            {
                std::string pipeline_id = pipeline_config.value("pipeline_id", "UNKNOWN");
                throw std::runtime_error("Error processing pipeline '" + pipeline_id + "': " + e.what());
            }
        }
    }
}
} // namespace Graph
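
/*
 * Minimal usage sketch (hypothetical driver, not part of this file; it
 * assumes Graph is default-constructible and only exercises create_from_json()):
 *
 *     #include "graph/graph.hpp"
 *     #include <iostream>
 *
 *     int main()
 *     {
 *         Graph::Graph graph;
 *         try
 *         {
 *             // Any parse error, missing key, or failed model load surfaces
 *             // here as a std::runtime_error with a descriptive message.
 *             graph.create_from_json("config/pipelines.json");
 *         }
 *         catch (const std::exception& e)
 *         {
 *             std::cerr << "Graph setup failed: " << e.what() << std::endl;
 *             return 1;
 *         }
 *         return 0;
 *     }
 */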