Ver código fonte

天骄条件变量使用,尽可能避免cpu空转

leon 1 mês atrás
pai
commit
7c2669bff2

+ 14 - 0
src/common/queue.hpp

@@ -10,16 +10,28 @@ class SharedQueue {
 public:
     SharedQueue() = default;
 
+    void set_push_callback(std::function<void()> callback) 
+    {
+        std::lock_guard<std::mutex> lock(mutex_);
+        push_callback_ = callback;
+    }
+
     void push(const T& item) {
         std::lock_guard<std::mutex> lock(mutex_);
         queue_.push(item);
         cond_var_.notify_one();
+        if (push_callback_) {  // 触发回调
+            push_callback_();
+        }
     }
 
     void push(T&& item) {
         std::lock_guard<std::mutex> lock(mutex_);
         queue_.push(std::move(item));
         cond_var_.notify_one();
+        if (push_callback_) {  // 触发回调
+            push_callback_();
+        }
     }
 
     bool try_pop(T& item) {
@@ -54,6 +66,8 @@ private:
     mutable std::mutex mutex_;
     std::queue<T> queue_;
     std::condition_variable cond_var_;
+    // 回调函数,用于通知节点的条件变量
+    std::function<void()> push_callback_;
 };
 
 #endif  // QUEUE_HPP__

+ 1 - 1
src/nodes/base/base.cpp

@@ -30,7 +30,7 @@ void BaseNode::start()
 
 void BaseNode::stop()
 {
-    if (running_)
+    if (running_.exchange(false))
     {
         cond_var_->notify_all();
         running_ = false;

+ 3 - 0
src/nodes/base/base.hpp

@@ -43,6 +43,9 @@ public:
     inline void add_output_buffer(const std::string& name, std::shared_ptr<SharedQueue<std::shared_ptr<meta::MetaData>>> buffer)
     {
         std::unique_lock<std::mutex> lock(mutex_);
+        std::unique_lock<std::mutex> lock(mutex_);
+        // 设置回调,当数据push时通知当前节点
+        buffer->set_push_callback([this]() { cond_var_->notify_one(); });
         output_buffers_[name] = buffer;
     }
 

+ 8 - 1
src/nodes/draw/drawNode.cpp

@@ -22,6 +22,7 @@ void DrawNode::work()
     printf("DrawNode %s\n", name_.c_str());
     while (running_)
     {
+        bool has_data = false;
         for (auto& input_buffer : input_buffers_)
         {
             std::shared_ptr<meta::MetaData> metaData;
@@ -29,8 +30,8 @@ void DrawNode::work()
             {
                 continue;
             }
+            has_data = true;
             // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
-            // do something
             cv::Mat image = metaData->image.clone();
             int image_width = image.cols;
             int image_height = image.rows;
@@ -52,6 +53,12 @@ void DrawNode::work()
                 output_buffer.second->push(metaData);
             }
         }
+        if (!has_data) {
+            std::unique_lock<std::mutex> lock(mutex_);
+            cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
+                return !running_;  // 等待时检查退出条件
+            });
+        }
     }
 }
 

+ 8 - 0
src/nodes/httpPush/httpPush.cpp

@@ -10,6 +10,7 @@ void HttpPushNode::work()
     printf("HttpPush %s\n", name_.c_str());
     while (running_)
     {
+        bool has_data = false;
         for (auto& input_buffer : input_buffers_)
         {
             std::shared_ptr<meta::MetaData> metaData;
@@ -17,12 +18,19 @@ void HttpPushNode::work()
             {
                 continue;
             }
+            has_data = true;
             // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
             // do something
             cv::Mat image = metaData->draw_image;
             std::string image_name = "result/" + metaData->from + "_" + getTimeString() + ".jpg";
             cv::imwrite(image_name, image);
         }
+        if (!has_data) {
+            std::unique_lock<std::mutex> lock(mutex_);
+            cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
+                return !running_;  // 等待时检查退出条件
+            });
+        }
     }
 };
 

+ 14 - 9
src/nodes/infer/inferNode.cpp

@@ -11,7 +11,7 @@ void InferNode::work()
     printf("InferNode %s\n", name_.c_str());
     while (running_)
     {
-        
+        bool has_data = false;
         for (auto& input_buffer : input_buffers_)
         {
             std::shared_ptr<meta::MetaData> metaData;
@@ -19,20 +19,18 @@ void InferNode::work()
             {
                 continue;
             }
+            has_data = true;
             // printf("Node %s get data from %s\n", name_.c_str(), input_buffer.first.c_str());
-            // do something
             cv::Mat image = metaData->image;
             int width = image.cols;
             int height = image.rows;
 
-            // // cv::imwrite("test.jpg", image);
-            // int x = rand() % width;
-            // int y = rand() % height;
-            // int w = rand() % (width - x);
-            // int h = rand() % (height - y);
-            // metaData->boxes.push_back(data::Box(x, y, x + w, y + h, 0.9, 0));
-
             // auto res = model_->forward(tensor::cvimg(image), image.cols, image.rows, 0.0f, 0.0f);
+            if (!model_)
+            {
+                printf("model is nullptr\n");
+                continue;
+            }
             auto res = model_->forward(tensor::cvimg(image));
             for (auto& r : res)
             {
@@ -45,6 +43,13 @@ void InferNode::work()
                 output_buffer.second->push(metaData);
             }
         }
+        if (!has_data)
+        {
+            std::unique_lock<std::mutex> lock(mutex_);
+            cond_var_->wait_for(lock, std::chrono::milliseconds(100), [this] {
+                return !running_;  // 等待时检查退出条件
+            });
+        }
     }
 };
 

+ 0 - 8
src/nodes/stream/streamNode.cpp

@@ -72,20 +72,12 @@ void StreamNode::work_gpu()
         }
         int ndecoded_frame = decoder_->decode(packet_data, packet_size, pts);
         for(int i = 0; i < ndecoded_frame; ++i){
-
-            /* 因为decoder获取的frame内存,是YUV-NV12格式的。储存内存大小是 [height * 1.5] * width byte
-             因此构造一个height * 1.5,  width 大小的空间
-             然后由opencv函数,把YUV-NV12转换到BGR,转换后的image则是正常的height, width, CV_8UC3
-            */
             cv::Mat frame(decoder_->get_height(), decoder_->get_width(), CV_8UC3, decoder_->get_frame(&pts, &frame_index));
-            //cv::cvtColor(image, image, cv::COLOR_YUV2BGR_NV12);
             frame_index = frame_index + 1;
-            // INFO("write imgs/img_%05d.jpg  %dx%d", frame_index, decoder->get_width(), decoder->get_height());
             
             frame_count_++;
             if (frame_count_ % skip_frame_ != 0)
             {
-                // printf("Skip frame %d\n", frame_count_);
                 continue;
             }
 

+ 3 - 1
src/nodes/stream/streamNode.hpp

@@ -61,7 +61,7 @@ public:
             status_ = StreamStatus::OPENED;
         }
     }
-    virtual ~StreamNode() { running_ = false; };
+    virtual ~StreamNode() { };
 
     void set_stream_url(const std::string& stream_url)
     {
@@ -77,6 +77,8 @@ public:
         skip_frame_ = skip_frame;
     }
 
+    
+
     void work() override;
     void work_cpu();
     void work_gpu();