'''
@File    :   infer.py
@Time    :   2024/10/17 11:18:00
@Author  :   leon 
@Version :   1.0
@Desc    :   None
'''
import cv2
import torch
import time
import numpy as np
from ultralytics import YOLO
from shapely.geometry.polygon import Polygon

from collections import namedtuple
from typing import List, Tuple
import requests

from logger import logger
from stream import StreamCapture 

# One reported detection: box corners in image coordinates, a confidence score and a
# text label (DoorInference.__call__ emits "door_close" and "person" boxes).
Box = namedtuple("Box", "left top right bottom confidence label")

class PerspectiveMatrix:
    """A perspective (homography) matrix paired with the output size it warps into.

    Attributes:
        matrix: transformation matrix handed to ``cv2.warpPerspective``
            (the 3x3 homography when built via ``perspective_matrix``).
        target: output image size, used as ``dsize`` when warping
            (OpenCV order: width, height).
    """

    def __init__(self, matrix: np.ndarray, target: Tuple[int, int]) -> None:
        self.matrix = matrix
        self.target = target

    def __repr__(self):
        matrix_str = np.array2string(self.matrix, formatter={'float_kind': lambda x: f"{x:.2f}"})
        return f"PerspectiveMatrix(matrix={matrix_str}, target={self.target})"

    @classmethod
    def perspective_matrix(cls, src: List[Tuple[int, int]], dst: List[Tuple[int, int]], target: Tuple[int, int]) -> "PerspectiveMatrix":
        """Compute the homography that maps ``src`` points onto ``dst`` points.

        Args:
            src: at least 4 points in the source image, e.g.
                [(x1, y1), (x2, y2), (x3, y3), (x4, y4)].
            dst: the corresponding points in the target image, same count as ``src``.
            target: output image size stored on the result.

        Returns:
            PerspectiveMatrix wrapping the computed matrix and ``target``.

        Raises:
            ValueError: if the point lists are not lists of (x, y) pairs, differ in
                length, contain fewer than 4 points, or OpenCV cannot estimate a
                homography from them (e.g. degenerate/collinear points).
        """
        src_pts = np.asarray(src, dtype=np.float32)
        dst_pts = np.asarray(dst, dtype=np.float32)

        if src_pts.ndim != 2 or src_pts.shape[1] != 2 or src_pts.shape != dst_pts.shape:
            raise ValueError("src and dst must be equally long lists of (x, y) points.")
        if len(src_pts) < 4:
            raise ValueError("At least 4 point correspondences are required.")

        # findHomography (default least-squares method) also accepts more than 4
        # correspondences; cv2.getPerspectiveTransform would require exactly 4.
        matrix, _ = cv2.findHomography(src_pts, dst_pts)
        # OpenCV yields no matrix when the homography cannot be estimated; fail here
        # instead of deferring an opaque error to warpPerspective.
        if matrix is None or matrix.size == 0:
            raise ValueError("Could not estimate a homography from the given points.")

        return cls(matrix=matrix, target=target)

class ImageTransformer:
    """Image warping helpers built on OpenCV."""

    @staticmethod
    def PerspectiveTransform(
        image: np.ndarray,
        perspective_matrix: PerspectiveMatrix,
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(114, 114, 114)
    ) -> np.ndarray:
        """Warp ``image`` with ``perspective_matrix``.

        Args:
            image: source image.
            perspective_matrix: provides the transform (``.matrix``) and the output
                size (``.target``, passed to OpenCV as ``dsize`` = (width, height)).
            flags: interpolation flags for ``cv2.warpPerspective``.
            borderMode: pixel extrapolation method for areas outside the source.
            borderValue: fill colour used with ``cv2.BORDER_CONSTANT``.

        Returns:
            The warped image with size ``perspective_matrix.target``.

        Raises:
            ValueError: if ``image``, ``perspective_matrix`` or its ``matrix`` is None.
        """
        if image is None or perspective_matrix is None:
            raise ValueError("Input image and perspective_matrix cannot be None.")
        if perspective_matrix.matrix is None:
            # Reject early instead of letting OpenCV fail with an opaque error.
            raise ValueError("perspective_matrix.matrix cannot be None.")

        return cv2.warpPerspective(
            image,
            perspective_matrix.matrix,
            perspective_matrix.target,
            flags=flags,
            borderMode=borderMode,
            borderValue=borderValue)

class AreaManage(object):
    """Preset door regions and the person-activity region paired with each door.

    Index ``i`` of ``door_areas`` / ``door_perspective_matrix`` belongs to the same
    door as index ``i`` of ``person_areas`` / ``person_areas_polygon``.
    """

    def __init__(self):
        # Quadrilaterals outlining each door in the camera image (pixel coordinates).
        self.door_areas = [
            [(934, 143), (958, 130), (961, 165), (936, 181)],
            [(564, 399), (682, 432), (684, 528), (574, 493)]]

        # Per-door homographies that warp the door quadrilateral onto a 224x224
        # square (the image later fed to the door classifier).
        self.door_perspective_matrix = []

        self.dst_door_points = [(0, 0), (224, 0), (224, 224), (0, 224)]

        for door_area in self.door_areas:
            perspective_matrix = PerspectiveMatrix.perspective_matrix(door_area, self.dst_door_points, (224, 224))
            self.door_perspective_matrix.append(perspective_matrix)

        # Preset person-activity areas, one per door above (two doors).
        self.person_areas = [
            [(900, 152), (914, 84), (637, 20), (604, 73)],
            [(860, 200), (958, 226), (687, 432), (586, 395)]]

        self.person_areas_polygon = [Polygon(area) for area in self.person_areas]

    def update(self, person_area):
        """Replace the preset person area that overlaps ``person_area`` the most.

        ``person_area`` is a list of (x, y) vertices. If it does not overlap any
        preset area (zero intersection area), nothing changes.
        """
        person_area_polygon = Polygon(person_area)
        overlaps = [person_area_polygon.intersection(pap).area for pap in self.person_areas_polygon]
        if not overlaps:
            return
        max_area = max(overlaps)
        if max_area > 0:
            max_idx = overlaps.index(max_area)
            self.person_areas_polygon[max_idx] = person_area_polygon
            # Keep the raw vertex list in sync with the polygon actually used.
            self.person_areas[max_idx] = list(person_area)

    def inner_area(self, person_polygon, threshold=0.5):
        """Return the index of the first area that contains most of ``person_polygon``.

        A person counts as inside area ``i`` when more than ``threshold`` of the
        person polygon's own area intersects that area.

        Returns:
            The area index, or -1 if no area qualifies. A zero-area (degenerate)
            person polygon also yields -1 instead of dividing by zero.
        """
        person_area = person_polygon.area
        if person_area <= 0:
            return -1
        for i, area_polygon in enumerate(self.person_areas_polygon):
            if person_polygon.intersection(area_polygon).area / person_area > threshold:
                return i
        return -1


        

"""
1. 人员检测
2. 检查区域内是否有人
3. 如果区域内有人,分类该区域的门是否关闭
4. 如果关闭,上报违章
"""
class DoorInference(object):
    """
    human_model_path : 检测人的模型地址
    door_model_path  : 门分类的模型地址
    person_areas     : 电子围栏区域 [[(x,y),(x,y),(x,y),(x,y),...],[(x,y),(x,y),(x,y),(x,y),...],...]
    device_id        : 显卡id, -1表示使用cpu
    confidence_threshold : 检测人的模型的阈值
    """
    def __init__(self, human_model_path, door_model_path, person_areas, device_id=0, confidence_threshold= 0.5) -> "DoorInference":
        self.device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() and device_id !=-1 else "cpu")
        self.confidence_threshold = confidence_threshold
        self.human_model = YOLO(human_model_path).to(self.device)
        self.door_model  = YOLO(door_model_path).to(self.device)

        self.door_names  = {0: 'block', 1: 'close', 2: 'open'}

        self.am = AreaManage()

        if person_areas:
            for person_area in person_areas:
                self.am.update(person_area)

    def __repr__(self):
        infer_str = f"DoorInference(device = {self.device})"
        return infer_str
    
    def rect(self, point_list):
        """
        根据给定的点计算矩形框
        :param point_list: [(x1, y1), (x2, y2), ..., (xn, yn)] 多个点
        :return: 左上角和右下角的坐标,表示矩形的框
        """
        # 获取所有点的 x 和 y 坐标的最小和最大值
        min_x = min(point[0] for point in point_list)
        max_x = max(point[0] for point in point_list)
        min_y = min(point[1] for point in point_list)
        max_y = max(point[1] for point in point_list)

        # 计算矩形框的左上角和右下角
        left, top = min_x, min_y
        right, bottom = max_x, max_y

        return left, top ,right, bottom

    """
    返回所有检测到的人
    """
    def person_detect(self, image):
        objs  = []
        confs = []
        results = self.human_model(image, stream=False, classes=[4], conf=self.confidence_threshold, iou=0.3, imgsz=640)
        for result in results:
            boxes = result.boxes.cpu()
            for box in boxes:
                conf = box.conf.cpu().numpy().tolist()[0]
                left, top, right, bottom = box.xyxy.tolist()[0]
                objs.append([(left, top), (right, top), (right, bottom), (left, bottom)])
                confs.append(conf)
        return objs, confs

    """
    判断门是否关闭
    """
    def is_door_close(self, image, index) -> bool:
        # 图片透视变换, 只需要门的区域
        # 比只使用矩形框更大程度的只保留门的区域
        door_image  = ImageTransformer.PerspectiveTransform(image, self.am.door_perspective_matrix[index])
        # cv2.imshow('0',door_image)
        # cv2.waitKey(30000)
        res = self.door_model(door_image)
        class_idx = res[0].probs.top1
        class_conf = res[0].probs.top1conf.cpu().numpy()
        return class_idx == 1, class_conf

    """ 
    image 输入待识别图片
    返回需要画框的门和人的坐标
    """
    def __call__(self, image):
        inner_person = {0:[], 1:[], 2:[]}
        person_boxes, person_confs = self.person_detect(image)
        
        for person_box, person_conf in zip(person_boxes, person_confs):
            person_polygon = Polygon(person_box)
            idx = self.am.inner_area(person_polygon)
            if idx == -1: continue
            inner_person[idx].append({"box" : person_box, "conf" : person_conf})

        result = []
        for i, persons in inner_person.items():
            if len(persons) == 0:
                continue
            close, class_conf = self.is_door_close(image, i)
            if close:
                left, top ,right, bottom = self.rect(self.am.door_areas[i])
                result.append(Box(left=left, top=top, right=right, bottom=bottom, confidence=class_conf, label="door_close"))
                for person in persons:
                    left, top ,right, bottom = self.rect(person["box"])
                    conf = person["conf"]
                    result.append(Box(left=left, top=top, right=right, bottom=bottom, confidence=conf, label="person"))
        return result
    
    
if __name__ == "__main__":
    logger.info("====== Start Server =======")
    human_model_path = "models/work_clo_person_head_hat.pt"
    door_model_path  = "models/door_classify.pt"
    #image_path = "images/camera1_9ce1aca78db14c5807a271cb154fed5b.jpg"
    #image = cv2.imread(image_path)
    # logger.info(video)
    # (222, 59), (432, 3), (528, 96), (318, 198) 这个区域是为了测试,画大了的
    test_area = [[(222, 59), (432, 3), (528, 96), (318, 198)]]
    instance = DoorInference(human_model_path, door_model_path, person_areas=None)
    # 返回需要门和人
    ip = '172.19.152.231'
    channel = '45'
    stream = StreamCapture(ip, channel)
    posttime = time.time()
    for frame, ret in stream():
        if not ret:continue
        image = frame.copy()
        result = instance(image)
        if len(result) > 0 and time.time() - posttime > 30:
            try:
                posttime  = time.time()
                videoTime = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime())
                fileTime  = time.strftime('%Y-%m-%d-%H:%M:%S',time.localtime())
                filename  = fileTime + ".jpg"
                filenameori = fileTime + "det.jpg"
                logger.info(videoTime)
                logger.info(result)
                for res in result:
                    cv2.rectangle(image, tuple(map(int, (res.left, res.top))), tuple(map(int, (res.right, res.bottom))), (255,0, 0), 4)
                success, encoded_image = cv2.imencode('.jpg', image)
                content = encoded_image.tobytes()
                successori, encoded_imageori = cv2.imencode('.jpg',frame)
                contentori = encoded_imageori.tobytes()
                payload = {'channel': '45',
                            'classIndex': '8',
                            'ip': '172.19.152.231',
                            'videoTime': videoTime,
                            'videoUrl': stream.stream_url}
                files = [
                                ('file', (filename, content, 'image/jpeg')),
                                ('oldFile', (filenameori, contentori, 'image/jpeg')),
                            ]
            
                result = requests.post('http://172.19.145.197/open/api/operate/upload', data=payload, files=files)
                logger.info(result)
            except Exception as error:
                logger.error('Error : ', str(error))
    logger.info("=======  EXIT  =======")
    #cv2.imwrite("result/result.jpg", image)