Explorar el Código

抽象节点,将work放到基类中实现

leon hace 4 semanas
padre
commit
cad3baee7e

+ 1 - 1
src/main.cpp

@@ -138,5 +138,5 @@ int main()
 // 分析节点         
 // YOLO11 seg 
 // 通过配置文件创建 pipeline
-
+// 置信度阈值修改
 // 设置电子围栏

+ 36 - 0
src/nodes/base/base.cpp

@@ -27,6 +27,11 @@ void BaseNode::stop()
 {
     if (running_.exchange(false))
     {
+        // 删除队列中全部元素
+        std::for_each(input_buffers_.begin(), input_buffers_.end(),
+                  [&](const auto &item) { item.second->clear(); });
+        std::for_each(output_buffers_.begin(), output_buffers_.end(),
+                    [&](const auto &item) { item.second->clear(); });
         cond_var_->notify_all();
         if (worker_thread_.joinable())
         {
@@ -36,4 +41,35 @@ void BaseNode::stop()
     }
 }
 
+// Shared worker loop for all nodes: drain each input buffer, let the
+// subclass process the item via handle_data(), then fan it out to every
+// output buffer via send_output_data(). When no input had data, block on
+// the condition variable for up to 100ms (woken early only on shutdown).
+void BaseNode::work()
+{
+    while (running_)
+    {
+        bool has_data = false;
+        for (auto& input_buffer : input_buffers_)
+        {
+            std::shared_ptr<meta::MetaData> meta_data;
+            if (!input_buffer.second->try_pop(meta_data))
+            {
+                continue;
+            }
+            has_data = true;
+            handle_data(meta_data);
+            send_output_data(meta_data);
+        }
+        if (!has_data)
+        {
+            std::unique_lock<std::mutex> lock(mutex_);
+            cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
+                return !running_;  // check the exit condition while waiting
+            });
+        }
+    }  // BUGFIX: this closing brace for the while-loop was missing, so the
+}      // function body was never closed and base.cpp could not compile.
+
+// Default per-item processing hook: the base implementation does nothing,
+// so a node that only forwards data works out of the box. Subclasses
+// (e.g. InferNode) override this to add their node-specific processing.
+void BaseNode::handle_data(std::shared_ptr<meta::MetaData>& data)
+{
+    
+}
+
 } // namespace Node

+ 13 - 1
src/nodes/base/base.hpp

@@ -33,7 +33,8 @@ public:
     BaseNode(const std::string& name, NODE_TYPE type);
     virtual ~BaseNode();
 
-    virtual void work() = 0;
+    virtual void work();
+    virtual void handle_data(std::shared_ptr<meta::MetaData>& data);
 
     void start();
     void stop();
@@ -78,6 +79,17 @@ public:
         output_buffers_.erase(name);
     }
 
+    // Fan one processed item out to every connected output buffer.
+    // Skips null data, and skips null buffer slots — the loops this helper
+    // replaced in streamNode checked `if (output_buffer.second)` before
+    // pushing, so the shared helper must keep that guard.
+    inline void send_output_data(const std::shared_ptr<meta::MetaData>& data) {
+        if (!data) 
+        {
+            return;
+        }
+        for (auto &item : output_buffers_) 
+        {
+            if (item.second)
+            {
+                item.second->push(data);  // BUGFIX: semicolon was missing
+            }
+        }
+    }
+
     inline bool is_running()
     {
         return running_;

+ 1 - 5
src/nodes/draw/drawNode.cpp

@@ -100,11 +100,7 @@ void DrawNode::work()
                 cv::putText(image, text, cv::Point(x, y), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(b, g, r), 2);
             }
             metaData->draw_image = image;
-            for (auto& output_buffer : output_buffers_)
-            {
-                // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
-                output_buffer.second->push(metaData);
-            }
+            send_output_data(metaData);
         }
         if (!has_data) {
             // printf("draw wait data\n");

+ 86 - 59
src/nodes/infer/inferNode.cpp

@@ -22,71 +22,98 @@ void print_mat(const cv::Mat& mat, int max_rows = 10, int max_cols = 10)
     if (mat.rows > max_rows) std::cout << "[...]" << std::endl;
 }
 
-void InferNode::work()
+
+// Per-frame inference hook invoked from BaseNode::work()'s worker thread.
+// Runs the model on the frame and stores the result on the metadata:
+// detection output appends to meta_data->boxes, depth output fills
+// meta_data->depth. Throws on an unexpected result variant.
+//
+// NOTE(review): the removed InferNode::work() null-checked model_ and called
+// checkRuntime(cudaSetDevice(device_id_)) in the worker thread before the
+// loop. The shared BaseNode::work() does neither, so both must happen here —
+// otherwise inference silently runs on the default GPU 0 instead of
+// device_id_, and a null model_ dereferences instead of logging.
+void InferNode::handle_data(std::shared_ptr<meta::MetaData>& meta_data)
+{
+    if (!model_)
+    {
+        PLOGE.printf("InferNode : [%s] model is nullptr", name_.c_str());
+        return;
+    }
+    // Bind this worker thread to the GPU the model was created on.
+    // thread_local: the (cheap) runtime call is made only once per thread.
+    thread_local bool device_bound = false;
+    if (!device_bound)
+    {
+        checkRuntime(cudaSetDevice(device_id_));
+        device_bound = true;
+    }
+
+    cv::Mat image = meta_data->image;
+
+    auto det_result = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
+
+    if (std::holds_alternative<data::BoxArray>(det_result)) 
+    {
+        auto result = std::get<data::BoxArray>(det_result);
+        for (auto& r : result)
+        {
+            meta_data->boxes.push_back(r);
+        }
+    } 
+    else if (std::holds_alternative<cv::Mat>(det_result))
+    {
+        auto depth_mat = std::get<cv::Mat>(det_result);
+        // print_mat(depth_mat);
+        meta_data->depth = depth_mat;
+    }
+    else
+    {
+        PLOGE.printf("InferNode : [%s] Unexpected result type from model", name_.c_str());
+        throw std::runtime_error("Unexpected result type from model");
+    }
+}
 
-            auto det_result = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
+// void InferNode::work()
+// {
+//     PLOGI.printf("InferNode : [%s] start", name_.c_str());
+//     if (!model_)
+//     {
+//         PLOGE.printf("InferNode : [%s] model is nullptr", name_.c_str());
+//         return;
+//     }
+//     // 不同线程都需要指定显卡id,默认为0号显卡
+//     // 在主线程中创建的模型,创建时指定了显卡id
+//     // 推理的时候创建了另一个线程,需要指定同样的显卡id
+//     checkRuntime(cudaSetDevice(device_id_));
+//     while (running_)
+//     {
+//         // Timer timer("InferNode");
+//         bool has_data = false;
+//         for (auto& input_buffer : input_buffers_)
+//         {
+//             std::shared_ptr<meta::MetaData> metaData;
+//             if (!input_buffer.second->try_pop(metaData))
+//             {
+//                 continue;
+//             }
+//             has_data = true;
+//             // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
+//             cv::Mat image = metaData->image;
+//             int width = image.cols;
+//             int height = image.rows;
 
-            if (std::holds_alternative<data::BoxArray>(det_result)) 
-            {
-                auto result = std::get<data::BoxArray>(det_result);
-                for (auto& r : result)
-                {
-                    metaData->boxes.push_back(r);
-                }
-            } 
-            else if(std::holds_alternative<cv::Mat>(det_result))
-            {
-                auto depth_mat = std::get<cv::Mat>(det_result);
-                // print_mat(depth_mat);
-                metaData->depth = depth_mat;
+//             auto det_result = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
 
-            }
-            else
-            {
-                PLOGE.printf("InferNode : [%s] Unexpected result type from model", name_.c_str());
-                throw std::runtime_error("Unexpected result type from model");
-            }
-            for (auto& output_buffer : output_buffers_)
-            {
-                // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
-                output_buffer.second->push(metaData);
-            }           
-        }
-        if (!has_data)
-        {
-            std::unique_lock<std::mutex> lock(mutex_);
-            cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
-                return !running_;  // 等待时检查退出条件
-            });
-        }
-    }
-};
+//             if (std::holds_alternative<data::BoxArray>(det_result)) 
+//             {
+//                 auto result = std::get<data::BoxArray>(det_result);
+//                 for (auto& r : result)
+//                 {
+//                     metaData->boxes.push_back(r);
+//                 }
+//             } 
+//             else if(std::holds_alternative<cv::Mat>(det_result))
+//             {
+//                 auto depth_mat = std::get<cv::Mat>(det_result);
+//                 // print_mat(depth_mat);
+//                 metaData->depth = depth_mat;
+
+//             }
+//             else
+//             {
+//                 PLOGE.printf("InferNode : [%s] Unexpected result type from model", name_.c_str());
+//                 throw std::runtime_error("Unexpected result type from model");
+//             }
+//             send_output_data(metaData);          
+//         }
+//         if (!has_data)
+//         {
+//             std::unique_lock<std::mutex> lock(mutex_);
+//             cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
+//                 return !running_;  // 等待时检查退出条件
+//             });
+//         }
+//     }
+// }
 
 }   // namespace StreamNode

+ 2 - 18
src/nodes/stream/streamNode.cpp

@@ -172,14 +172,7 @@ void StreamNode::process_stream_cpu()
         auto metaData = std::make_shared<meta::MetaData>();
         metaData->image = frame.clone();
         metaData->from = name_;
-
-        for (auto& output_buffer : output_buffers_)
-        {
-            if (output_buffer.second) 
-            { 
-                 output_buffer.second->push(metaData);
-            }
-        }
+        send_output_data(metaData);
     }
     PLOGI.printf("StreamNode [%s]: Exiting CPU processing loop (Running: %s, Status: %d).",
             name_.c_str(), running_ ? "true" : "false", static_cast<int>(status_));
@@ -293,16 +286,7 @@ void StreamNode::process_stream_gpu()
             auto metaData = std::make_shared<meta::MetaData>();
             metaData->image = frame_gpu.clone(); // CLONE is crucial here!
             metaData->from = name_;
-
-            bool pushed = false;
-            for (auto& output_buffer : output_buffers_)
-            {
-                if (output_buffer.second) 
-                {
-                    output_buffer.second->push(metaData);
-                    pushed = true;
-                }
-            }
+            send_output_data(metaData);
         }
         if (status_ == StreamStatus::ERROR) 
         {

+ 1 - 6
src/nodes/track/trackNode.cpp

@@ -53,12 +53,7 @@ void TrackNode::work()
                 const std::vector<float>& tlwh = track.tlwh;
                 metaData->track_boxes.emplace_back(tlwh[0], tlwh[1], tlwh[0] + tlwh[2], tlwh[1] + tlwh[3], track.score, track.track_id, track_label_);
             }
-
-            for (auto& output_buffer : output_buffers_)
-            {
-                // printf("Node %s push data to %s\n", name_.c_str(), output_buffer.first.c_str());
-                output_buffer.second->push(metaData);
-            }    
+            send_output_data(metaData);    
         }
         if (!has_data)
         {