'''
@File    :   infer.py
@Time    :   2024/10/17 11:18:00
@Author  :   leon 
@Version :   1.0
@Desc    :   Door-state inference: detect people inside preset areas and check whether the paired door is closed.
'''
import cv2
import torch
import numpy as np
from ultralytics import YOLO
from shapely.geometry.polygon import Polygon

from collections import namedtuple
from typing import List, Tuple

# Axis-aligned detection result: pixel corners, a score, and a string label
# ("door_close" or "person" in this module).
Box = namedtuple("Box", "left top right bottom confidence label")

class PerspectiveMatrix:
    """A 3x3 perspective (homography) matrix paired with its output size.

    Attributes:
        matrix: 3x3 transform mapping source-image points onto the target image.
        target: output size handed to ``cv2.warpPerspective`` as ``dsize``,
            i.e. ``(width, height)``.
    """

    def __init__(self, matrix: np.ndarray, target: Tuple[int, int]) -> None:
        self.matrix = matrix
        self.target = target

    def __repr__(self):
        matrix_str = np.array2string(self.matrix, formatter={'float_kind': lambda x: f"{x:.2f}"})
        return f"PerspectiveMatrix(matrix={matrix_str}, target={self.target})"

    @staticmethod
    def perspective_matrix(src: List[Tuple[int, int]], dst: List[Tuple[int, int]], target: Tuple[int, int]) -> "PerspectiveMatrix":
        """
        Compute the perspective transform that maps ``src`` onto ``dst``.

        Args:
            src: at least 4 points in the source image, e.g. [(x1, y1), ..., (x4, y4)].
            dst: the corresponding points in the target image, in the same order.
            target: output size ``(width, height)`` used when warping.

        Returns:
            PerspectiveMatrix wrapping the estimated homography and ``target``.

        Raises:
            ValueError: if ``src``/``dst`` are not equally sized lists of at least
                4 (x, y) points, or if no homography can be estimated
                (e.g. the points are degenerate/collinear).
        """
        src_pts = np.asarray(src, dtype=np.float32)
        dst_pts = np.asarray(dst, dtype=np.float32)

        # Validate up front: otherwise cv2 fails with an opaque error, or
        # findHomography silently returns None and warpPerspective crashes later.
        if (src_pts.ndim != 2 or src_pts.shape[1] != 2
                or src_pts.shape != dst_pts.shape or len(src_pts) < 4):
            raise ValueError(
                f"src and dst must be matching lists of at least 4 (x, y) points, "
                f"got shapes {src_pts.shape} and {dst_pts.shape}")

        # Homography via least squares (method 0); exact for exactly 4 points.
        # cv2.getPerspectiveTransform(src_pts, dst_pts) would be an alternative for exactly 4 points.
        matrix, _ = cv2.findHomography(src_pts, dst_pts)
        if matrix is None:
            raise ValueError("Could not estimate a homography; check that the points are not degenerate.")

        return PerspectiveMatrix(matrix=matrix, target=target)

class ImageTransformer:
    """Stateless image-warping helpers."""

    @staticmethod
    def PerspectiveTransform(
        image: np.ndarray,
        perspective_matrix: PerspectiveMatrix,
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(114, 114, 114)
    ) -> np.ndarray:
        """
        Warp ``image`` with ``perspective_matrix``.

        Args:
            image: source image.
            perspective_matrix: transform plus output size (``.target`` is used
                as ``dsize``).
            flags: interpolation flags forwarded to ``cv2.warpPerspective``.
            borderMode: pixel extrapolation mode forwarded to ``cv2.warpPerspective``.
            borderValue: fill colour for constant borders (defaults to grey 114).

        Returns:
            The warped image of size ``perspective_matrix.target``.

        Raises:
            ValueError: if ``image`` or ``perspective_matrix`` is None.
        """
        if image is None or perspective_matrix is None:
            # Name the real parameter (the old message referred to a nonexistent
            # "warpaffine_matrix").
            raise ValueError("Input image and perspective_matrix cannot be None.")

        return cv2.warpPerspective(
            image,
            perspective_matrix.matrix,
            perspective_matrix.target,
            flags=flags,
            borderMode=borderMode,
            borderValue=borderValue,
        )

class AreaManage(object):
    """Preset door regions and the person activity regions paired with them.

    ``door_areas[i]`` and ``person_areas[i]`` describe the same door: the
    person area is where people are expected when that door is in use.
    """

    def __init__(self):
        # Door quadrilaterals in source-image pixels, listed in the same corner
        # order as ``dst_door_points`` so each maps onto the full output square.
        self.door_areas = [
            # [(934, 143), (958, 130), (961, 165), (936, 181)],  # disabled door
            [(564, 399), (682, 432), (684, 528), (574, 493)]]

        # One PerspectiveMatrix per door area, warping it to a 224x224 image.
        self.door_perspective_matrix = []

        self.dst_door_points = [(0, 0), (224, 0), (224, 224), (0, 224)]

        for p in self.door_areas:
            perspective_matrix = PerspectiveMatrix.perspective_matrix(p, self.dst_door_points, (224, 224))
            self.door_perspective_matrix.append(perspective_matrix)

        # Preset person activity areas, one per door (index-aligned with door_areas).
        self.person_areas = [
            # [(900, 152),  (914, 84), (637, 20), (604, 73)],  # disabled area
            [(860, 200), (958, 226), (687, 432), (586, 395)]]

        self.person_areas_polygon = [Polygon(area) for area in self.person_areas]

    def update(self, person_area):
        """Replace the preset person area that overlaps ``person_area`` the most.

        Args:
            person_area: polygon vertices [(x, y), ...] of a user-supplied area.

        If ``person_area`` does not intersect any preset area (or there are no
        preset areas), nothing changes.
        """
        new_polygon = Polygon(person_area)
        overlaps = [new_polygon.intersection(p).area for p in self.person_areas_polygon]
        if not overlaps:
            return
        # First index with the largest overlap (same tie-break as list.index(max)).
        best_idx = max(range(len(overlaps)), key=overlaps.__getitem__)
        if overlaps[best_idx] > 0:
            self.person_areas_polygon[best_idx] = new_polygon
            # Keep the raw vertex list in sync with the polygon it describes.
            self.person_areas[best_idx] = list(person_area)

    def inner_area(self, person_polygon, min_overlap_ratio=0.4):
        """Return the index of the first person area that contains the person.

        Args:
            person_polygon: shapely polygon of the person's detection box.
            min_overlap_ratio: fraction of the person's own area that must lie
                inside an activity area for the person to count as inside it.

        Returns:
            Index into ``person_areas_polygon``, or -1 if no area qualifies
            (including degenerate zero-area boxes).
        """
        person_area = person_polygon.area
        if person_area <= 0:
            # A zero-width/height box has no meaningful overlap ratio and would
            # otherwise raise ZeroDivisionError.
            return -1
        for idx, area_polygon in enumerate(self.person_areas_polygon):
            overlap_ratio = person_polygon.intersection(area_polygon).area / person_area
            if overlap_ratio > min_overlap_ratio:
                return idx
        return -1


        

"""
1. 人员检测
2. 检查区域内是否有人
3. 如果区域内有人,分类该区域的门是否关闭
4. 如果关闭,上报违章
"""
class DoorInference(object):
    """
    human_model_path : 检测人的模型地址
    door_model_path  : 门分类的模型地址
    person_areas     : 电子围栏区域 [[(x,y),(x,y),(x,y),(x,y),...],[(x,y),(x,y),(x,y),(x,y),...],...]
    device_id        : 显卡id, -1表示使用cpu
    confidence_threshold : 检测人的模型的阈值
    """
    def __init__(self, human_model_path, door_model_path, person_areas, device_id=0, confidence_threshold= 0.5) -> "DoorInference":
        self.device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() and device_id !=-1 else "cpu")
        self.confidence_threshold = confidence_threshold
        self.human_model = YOLO(human_model_path).to(self.device)
        self.door_model  = YOLO(door_model_path).to(self.device)

        self.door_names  = {0: 'block', 1: 'close', 2: 'open'}

        self.am = AreaManage()

        if person_areas:
            for person_area in person_areas:
                self.am.update(person_area)

    def __repr__(self):
        infer_str = f"DoorInference(device = {self.device})"
        return infer_str
    
    def rect(self, point_list):
        """
        根据给定的点计算矩形框
        :param point_list: [(x1, y1), (x2, y2), ..., (xn, yn)] 多个点
        :return: 左上角和右下角的坐标,表示矩形的框
        """
        # 获取所有点的 x 和 y 坐标的最小和最大值
        min_x = min(point[0] for point in point_list)
        max_x = max(point[0] for point in point_list)
        min_y = min(point[1] for point in point_list)
        max_y = max(point[1] for point in point_list)

        # 计算矩形框的左上角和右下角
        left, top = min_x, min_y
        right, bottom = max_x, max_y

        return left, top ,right, bottom

    """
    返回所有检测到的人
    """
    def person_detect(self, image):
        objs  = []
        confs = []
        results = self.human_model(image, stream=False, classes=[0], conf=self.confidence_threshold, iou=0.3, imgsz=640)
        for result in results:
            boxes = result.boxes.cpu()
            for box in boxes:
                conf = box.conf.cpu().numpy().tolist()[0]
                left, top, right, bottom = box.xyxy.tolist()[0]
                objs.append([(left, top), (right, top), (right, bottom), (left, bottom)])
                confs.append(conf)
        return objs, confs

    """
    判断门是否关闭
    """
    def is_door_close(self, image, index) -> bool:
        # 图片透视变换, 只需要门的区域
        # 比只使用矩形框更大程度的只保留门的区域
        door_image  = ImageTransformer.PerspectiveTransform(image, self.am.door_perspective_matrix[index])
        # cv2.imshow('0',door_image)
        # cv2.waitKey(30000)
        res = self.door_model(door_image)
        class_idx = res[0].probs.top1
        class_conf = res[0].probs.top1conf.cpu().numpy().item()
        return class_idx == 1, class_conf

    """ 
    image 输入待识别图片
    返回需要画框的门和人的坐标
    """
    def __call__(self, image):
        inner_person = {0:[], 1:[], 2:[]}
        person_boxes, person_confs = self.person_detect(image)
        for person_box, person_conf in zip(person_boxes, person_confs):
            person_polygon = Polygon(person_box)
            idx = self.am.inner_area(person_polygon)
            if idx == -1: continue
            inner_person[idx].append({"box" : person_box, "conf" : person_conf})

        result = []
        for i, persons in inner_person.items():
            if len(persons) == 0:
                continue
            close, class_conf = self.is_door_close(image, i)
            if close:
                left, top ,right, bottom = self.rect(self.am.door_areas[i])
                result.append(Box(left=left, top=top, right=right, bottom=bottom, confidence=class_conf, label="door_close"))
                for person in persons:
                    left, top ,right, bottom = self.rect(person["box"])
                    conf = person["conf"]
                    result.append(Box(left=left, top=top, right=right, bottom=bottom, confidence=conf, label="person"))
        return result