infer.py 8.0 KB


  1. '''
  2. @File : infer.py
  3. @Time : 2024/10/17 11:18:00
  4. @Author : leon
  5. @Version : 1.0
  6. @Desc : Detects people inside preset fence areas and reports doors that are closed while someone is present.
  7. '''
  8. import cv2
  9. import torch
  10. import numpy as np
  11. from ultralytics import YOLO
  12. from shapely.geometry.polygon import Polygon
  13. from collections import namedtuple
  14. from typing import List, Tuple
  # Detection result: box corners plus a confidence score and a text label.
  Box = namedtuple("Box", "left top right bottom confidence label")
  class PerspectiveMatrix:
      """Pair a perspective (homography) matrix with the size of its output image.

      Attributes:
          matrix: the transform matrix, stored exactly as given.
          target: the output size; it is later handed to ``cv2.warpPerspective``
              as ``dsize``, which OpenCV reads as (width, height).
      """

      def __init__(self, matrix: np.ndarray, target: Tuple[int, int]) -> None:
          self.matrix = matrix
          self.target = target

      def __repr__(self) -> str:
          # Two decimals keep the printed matrix compact.
          matrix_str = np.array2string(
              self.matrix, formatter={"float_kind": lambda x: f"{x:.2f}"}
          )
          return f"PerspectiveMatrix(matrix={matrix_str}, target={self.target})"

      @staticmethod
      def perspective_matrix(
          src: List[Tuple[int, int]],
          dst: List[Tuple[int, int]],
          target: Tuple[int, int],
      ) -> "PerspectiveMatrix":
          """Estimate the homography that maps ``src`` points onto ``dst`` points.

          Args:
              src: points on the source image, e.g.
                  [(x1, y1), (x2, y2), (x3, y3), (x4, y4)].
              dst: the matching points on the destination image, in the same
                  order as ``src``.
              target: output image size stored alongside the matrix.

          Returns:
              A ``PerspectiveMatrix`` holding the estimated matrix and ``target``.

          Raises:
              ValueError: if OpenCV cannot estimate a matrix from the points
                  (for example when they are degenerate).
          """
          src_pts = np.array(src, dtype=np.float32)
          dst_pts = np.array(dst, dtype=np.float32)

          # The homography is estimated with findHomography; the second return
          # value (the inlier mask) is not needed here.
          matrix, _ = cv2.findHomography(src_pts, dst_pts)
          if matrix is None:
              # Without this check the failure would only surface later, inside
              # __repr__ or cv2.warpPerspective, with an unrelated error.
              raise ValueError(
                  f"Could not estimate a homography from src={src} to dst={dst}."
              )
          return PerspectiveMatrix(matrix=matrix, target=target)
  class ImageTransformer:
      """Namespace for image warping helpers."""

      @staticmethod
      def PerspectiveTransform(
          image: np.ndarray,
          perspective_matrix: PerspectiveMatrix,
          flags=cv2.INTER_LINEAR,
          borderMode=cv2.BORDER_CONSTANT,
          borderValue=(114, 114, 114),
      ) -> np.ndarray:
          """Warp ``image`` with the given perspective matrix.

          Args:
              image: the image to warp.
              perspective_matrix: supplies the matrix (``.matrix``) and the
                  output size (``.target``) passed to ``cv2.warpPerspective``.
              flags: interpolation flag forwarded to ``cv2.warpPerspective``.
              borderMode: border mode forwarded to ``cv2.warpPerspective``.
              borderValue: fill value used with a constant border.

          Returns:
              The warped image returned by ``cv2.warpPerspective``.

          Raises:
              ValueError: if ``image`` or ``perspective_matrix`` is None.
          """
          if image is None or perspective_matrix is None:
              # The message names the actual parameter; it used to mention a
              # non-existent "warpaffine_matrix".
              raise ValueError("Input image and perspective_matrix cannot be None.")

          return cv2.warpPerspective(
              image,
              perspective_matrix.matrix,
              perspective_matrix.target,
              flags=flags,
              borderMode=borderMode,
              borderValue=borderValue,
          )
  class AreaManage(object):
      """Hold the door quadrilaterals and the person-activity zones.

      ``door_areas``, ``door_perspective_matrix`` and ``person_areas_polygon``
      are parallel lists: the index returned by ``inner_area`` is used by
      callers to pick the door belonging to that zone.
      """

      # A person counts as inside a zone when more than this fraction of the
      # person's polygon area overlaps the zone.
      MIN_OVERLAP_RATIO = 0.4

      def __init__(self):
          # Corner points of each door. Alternative coordinates kept for
          # reference (currently disabled):
          # [(934, 143), (958, 130), (961, 165), (936, 181)],
          self.door_areas = [
              [(564, 399), (682, 432), (684, 528), (574, 493)],
          ]

          # Every door quadrilateral is mapped onto a 224x224 square.
          self.dst_door_points = [(0, 0), (224, 0), (224, 224), (0, 224)]
          self.door_perspective_matrix = []
          for p in self.door_areas:
              perspective_matrix = PerspectiveMatrix.perspective_matrix(
                  p, self.dst_door_points, (224, 224)
              )
              self.door_perspective_matrix.append(perspective_matrix)

          # Preset person-activity zones, one per door. Alternative coordinates
          # kept for reference (currently disabled):
          # [(900, 152), (914, 84), (637, 20), (604, 73)],
          self.person_areas = [
              [(860, 200), (958, 226), (687, 432), (586, 395)],
          ]
          self.person_areas_polygon = [Polygon(area) for area in self.person_areas]

      def update(self, person_area):
          """Replace the zone that overlaps ``person_area`` the most.

          Args:
              person_area: polygon vertices [(x, y), ...] of the new zone.

          The new polygon takes the slot of the existing zone with the largest
          intersection area. Nothing changes when it overlaps no zone, or when
          there are no zones at all.
          """
          person_area_polygon = Polygon(person_area)
          areas = [
              person_area_polygon.intersection(pap).area
              for pap in self.person_areas_polygon
          ]
          # default=0 makes an empty zone list a no-op; max() alone would raise.
          max_area = max(areas, default=0)
          if max_area > 0:
              max_idx = areas.index(max_area)
              self.person_areas_polygon[max_idx] = person_area_polygon

      def inner_area(self, person_polygon):
          """Return the index of the first zone that contains the person.

          Args:
              person_polygon: polygon of the detected person.

          Returns:
              Index of the first zone whose overlap with ``person_polygon``
              exceeds ``MIN_OVERLAP_RATIO`` of the person's own area, or -1 when
              no zone qualifies.
          """
          person_size = person_polygon.area
          if person_size <= 0:
              # A zero-area box cannot be inside any zone; this guard also
              # avoids dividing by zero below.
              return -1

          for i, zone in enumerate(self.person_areas_polygon):
              overlap_ratio = person_polygon.intersection(zone).area / person_size
              if overlap_ratio > self.MIN_OVERLAP_RATIO:
                  return i
          return -1
  89. """
  90. 1. 人员检测
  91. 2. 检查区域内是否有人
  92. 3. 如果区域内有人,分类该区域的门是否关闭
  93. 4. 如果关闭,上报违章
  94. """
  class DoorInference(object):
      """Report doors that are closed while a person is inside the matching zone.

      For each image: detect people, assign each person to a zone, classify the
      door of every occupied zone, and return boxes for the closed doors
      together with the people found in those zones.

      Args:
          human_model_path: path of the person-detection model.
          door_model_path: path of the door-classification model.
          person_areas: fence polygons
              [[(x, y), (x, y), ...], [(x, y), ...], ...]. Each one replaces the
              preset zone it overlaps most. A falsy value keeps the presets.
          device_id: GPU index; -1 selects the CPU.
          confidence_threshold: confidence threshold of the person detector.
      """

      def __init__(
          self,
          human_model_path,
          door_model_path,
          person_areas,
          device_id=0,
          confidence_threshold=0.5,
      ) -> None:
          use_cuda = torch.cuda.is_available() and device_id != -1
          self.device = torch.device(f"cuda:{device_id}" if use_cuda else "cpu")
          self.confidence_threshold = confidence_threshold
          self.human_model = YOLO(human_model_path).to(self.device)
          self.door_model = YOLO(door_model_path).to(self.device)
          # Class index -> label of the door classifier.
          self.door_names = {0: 'block', 1: 'close', 2: 'open'}
          self.am = AreaManage()
          if person_areas:
              for person_area in person_areas:
                  self.am.update(person_area)

      def __repr__(self):
          return f"DoorInference(device = {self.device})"

      def rect(self, point_list):
          """Return the axis-aligned bounding rectangle of the given points.

          Args:
              point_list: points [(x1, y1), (x2, y2), ..., (xn, yn)].

          Returns:
              ``(left, top, right, bottom)``: the minimum x and y followed by
              the maximum x and y.
          """
          xs = [point[0] for point in point_list]
          ys = [point[1] for point in point_list]
          return min(xs), min(ys), max(xs), max(ys)

      def person_detect(self, image):
          """Run the person detector and return every detection.

          Returns:
              ``(objs, confs)``. ``objs`` holds one four-corner polygon per
              detection, ordered top-left, top-right, bottom-right, bottom-left.
              ``confs`` holds the matching confidences.
          """
          objs = []
          confs = []
          # classes=[0] keeps only class id 0 (presumably "person"; confirm
          # against the model's class list).
          results = self.human_model(
              image,
              stream=False,
              classes=[0],
              conf=self.confidence_threshold,
              iou=0.3,
              imgsz=640,
          )
          for result in results:
              # Move the boxes to the CPU once; no per-box transfer is needed.
              for box in result.boxes.cpu():
                  left, top, right, bottom = box.xyxy.tolist()[0]
                  objs.append(
                      [(left, top), (right, top), (right, bottom), (left, bottom)]
                  )
                  confs.append(box.conf.tolist()[0])
          return objs, confs

      def is_door_close(self, image, index) -> Tuple[bool, float]:
          """Classify whether the door at ``index`` is closed.

          Args:
              image: the full input image.
              index: position of the door in ``self.am.door_perspective_matrix``.

          Returns:
              ``(closed, confidence)``: ``closed`` is True when the top-1 class
              is 'close'; ``confidence`` is the top-1 score.
          """
          # Warp the door quadrilateral to a square so the classifier sees only
          # the door, which a plain bounding-box crop would not guarantee.
          door_image = ImageTransformer.PerspectiveTransform(
              image, self.am.door_perspective_matrix[index]
          )
          probs = self.door_model(door_image)[0].probs
          class_idx = probs.top1
          class_conf = probs.top1conf.cpu().numpy().item()
          return self.door_names.get(class_idx) == 'close', class_conf

      def __call__(self, image) -> "List[Box]":
          """Return the boxes to draw for ``image``.

          Args:
              image: the image to analyse.

          Returns:
              A list of ``Box``. For every occupied zone whose door is
              classified as closed it contains one "door_close" box (the
              bounding rectangle of the door area) followed by one "person" box
              per person in that zone. The list is empty when nothing is
              reported.
          """
          # One bucket per configured zone, instead of a fixed three.
          inner_person = {i: [] for i in range(len(self.am.person_areas_polygon))}

          person_boxes, person_confs = self.person_detect(image)
          for person_box, person_conf in zip(person_boxes, person_confs):
              idx = self.am.inner_area(Polygon(person_box))
              if idx == -1:
                  continue
              inner_person[idx].append({"box": person_box, "conf": person_conf})

          result = []
          for i, persons in inner_person.items():
              if not persons:
                  continue
              close, class_conf = self.is_door_close(image, i)
              if not close:
                  continue
              left, top, right, bottom = self.rect(self.am.door_areas[i])
              result.append(
                  Box(left=left, top=top, right=right, bottom=bottom,
                      confidence=class_conf, label="door_close")
              )
              for person in persons:
                  left, top, right, bottom = self.rect(person["box"])
                  result.append(
                      Box(left=left, top=top, right=right, bottom=bottom,
                          confidence=person["conf"], label="person")
                  )
          return result