# platform_wearing_detect.py (4.9 KB)
  1. import time
  2. import torch
  3. import cv2
  4. import threading
  5. from datetime import datetime
  6. from ultralytics import YOLO
  7. from globals import stop_event,redis_client
  8. from config import SAVE_IMG_PATH,POST_IMG_PATH3,PLATFORM_WEARING_MODEL,PLATFORM_WEARING_VIDEO_SOURCES
  9. def init_wearing_detection():
  10. redis_client.set("platform_wearing_human_in_postion",'False')
  11. redis_client.delete("platform_wearing_items_nums")
  12. redis_client.delete("platform_wearing_detection_img")
  13. redis_client.set("platform_wearing_detection_img_flag",'False')
  14. def start_wearing_detection(start_events):
  15. # Create threads for each video stream and model
  16. threads = []
  17. for model_path in PLATFORM_WEARING_MODEL:
  18. event = threading.Event()
  19. start_events.append(event)
  20. thread = threading.Thread(target=process_video, args=(model_path,PLATFORM_WEARING_VIDEO_SOURCES,event))
  21. threads.append(thread)
  22. thread.daemon=True
  23. thread.start()
  24. # Wait for all threads to complete
  25. for thread in threads:
  26. thread.join()
  27. def process_video(model_path, video_source, start_event):
  28. model = YOLO(model_path)
  29. cap = cv2.VideoCapture(video_source)
  30. while cap.isOpened():
  31. # Read a frame from the video
  32. success, frame = cap.read()
  33. if stop_event.is_set():#控制停止推理
  34. break
  35. if success:
  36. if cap.get(cv2.CAP_PROP_POS_FRAMES) % 10 != 0:#跳帧检测,
  37. continue
  38. x, y, w, h = 650, 0, 980, 1440#剪裁画面的中心区域
  39. # Crop the frame to the ROI
  40. frame = frame[y:y+h, x:x+w]
  41. # Run YOLOv8 inference on the frame
  42. if model_path==PLATFORM_WEARING_MODEL[0]:#yolov8n,专门用来检测人
  43. #model.classes = [0]#设置只检测人一个类别
  44. results = model.predict(frame,conf=0.6,verbose=False,classes=[0])#这里的results是一个生成器
  45. for r in results:
  46. ##下面这些都是tensor类型
  47. boxes = r.boxes.xyxy # 提取所有检测到的边界框坐标
  48. confidences = r.boxes.conf # 提取所有检测到的置信度
  49. classes = r.boxes.cls # 提取所有检测到的类别索引
  50. for i in range(len(boxes)):
  51. confidence = confidences[i].item()
  52. cls = int(classes[i].item())
  53. label = model.names[cls]
  54. if label=="person" and redis_client.get("platform_wearing_human_in_postion")=='False':
  55. redis_client.set("platform_wearing_human_in_postion",'True')
  56. start_event.set()
  57. if model_path==PLATFORM_WEARING_MODEL[1]:
  58. results = model.predict(frame,conf=0.6,verbose=False)
  59. for r in results:
  60. boxes=r.obb.xyxyxyxy
  61. confidences=r.obb.conf
  62. classes=r.obb.cls
  63. wearing_items={"belt" :0,
  64. 'helmet': 0,
  65. 'shoes': 0
  66. }
  67. for i in range(len(boxes)):
  68. confidence = confidences[i].item()
  69. cls = int(classes[i].item())
  70. label = model.names[cls]
  71. wearing_items[label] += 1
  72. #因为安全带检测有四个标签,当检测到两个及以上的时候,就认为有安全带
  73. wearing_items["belt"] = 1 if wearing_items["belt"] > 2 else 0
  74. wearing_items_nums = [wearing_items["belt"], wearing_items["helmet"], wearing_items["shoes"]]
  75. if redis_client.exists("platform_wearing_items_nums"):
  76. redis_client.delete("platform_wearing_items_nums")
  77. redis_client.rpush("platform_wearing_items_nums", *wearing_items_nums)
  78. if redis_client.get("platform_wearing_detection_img_flag")=='True' and not redis_client.exists("platform_wearing_detection_img"):
  79. save_time=datetime.now().strftime('%Y%m%d_%H%M')
  80. imgpath = f"{SAVE_IMG_PATH}/platform_wearing_detection_{save_time}.jpg"
  81. post_path= f"{POST_IMG_PATH3}/platform_wearing_detection_{save_time}.jpg"
  82. annotated_frame = results[0].plot()
  83. cv2.imwrite(imgpath, annotated_frame)
  84. redis_client.set("platform_wearing_detection_img",post_path)
  85. start_event.set()
  86. else:
  87. # Break the loop if the end of the video is reached
  88. break
  89. # Release the video capture object and close the display window
  90. cap.release()
  91. if torch.cuda.is_available():
  92. torch.cuda.empty_cache()
  93. del model