|
@@ -0,0 +1,77 @@
|
|
|
+import cv2
|
|
|
+import base64
|
|
|
+import threading
|
|
|
+import requests
|
|
|
+
|
|
|
+prompt_words = {
|
|
|
+ "phone" : "1",
|
|
|
+ "fall" : "2",
|
|
|
+ "fire" : "3"
|
|
|
+}
|
|
|
+
|
|
|
+def cv2_to_base64(image):
|
|
|
+ image1 = cv2.imencode('.jpg', image)[1]
|
|
|
+ image_code = str(base64.b64encode(image1))[2:-1]
|
|
|
+ return image_code
|
|
|
+
|
|
|
+
|
|
|
+def get_cut_position(boxes, w, h):
|
|
|
+ xmin, ymin, xmax, ymax = w, h, 0, 0
|
|
|
+ if len(boxes) == 1:
|
|
|
+ xmin, ymin, xmax, ymax = boxes[0]
|
|
|
+ box_w = xmax - xmin
|
|
|
+ box_h = ymax - ymin
|
|
|
+ xmin = max(0, xmin - box_w // 5)
|
|
|
+ ymin = max(0, ymin - box_h // 5)
|
|
|
+ xmax = min(w, xmax + box_w // 5)
|
|
|
+ ymax = max(h, ymax + box_h // 5)
|
|
|
+ else:
|
|
|
+ for box in boxes:
|
|
|
+ x1, y1, x2, y2 = box
|
|
|
+ xmin = min(x1, xmin)
|
|
|
+ ymin = min(y1, ymin)
|
|
|
+ xmax = max(xmax, x2)
|
|
|
+ ymax = max(ymax, y2)
|
|
|
+ box_w = xmax - xmin
|
|
|
+ box_h = ymax - ymin
|
|
|
+ xmin = max(0, xmin - box_w // 2)
|
|
|
+ ymin = max(0, ymin - box_h // 2)
|
|
|
+ xmax = min(w, xmax + box_w // 2)
|
|
|
+ ymax = max(h, ymax + box_h // 2)
|
|
|
+ return xmin, ymin, xmax, ymax
|
|
|
+
|
|
|
+def post_illegal_data():
|
|
|
+ pass
|
|
|
+
|
|
|
+def task(boxes, image, task_name, llm_url):
|
|
|
+ """
|
|
|
+ boxes : 违章框的列表 [[x1,y1,x2,y2], ...]
|
|
|
+ image : 图片的opencv格式
|
|
|
+ task_name : 任务名称 phone : 打电话, fall : 摔倒, fire : 烟雾火焰
|
|
|
+ """
|
|
|
+ h, w, _ = image.shape
|
|
|
+ xmin, ymin, xmax, ymax = get_cut_position(boxes, w, h)
|
|
|
+
|
|
|
+ crop_image = image[ymin:ymax, xmin:xmax]
|
|
|
+ crop_image_base64 = cv2_to_base64(crop_image)
|
|
|
+ prompt_word = prompt_words[task_name]
|
|
|
+
|
|
|
+ data = {
|
|
|
+ "imageData" : crop_image_base64,
|
|
|
+ "text" : prompt_word
|
|
|
+ }
|
|
|
+ try:
|
|
|
+ response = requests.post(llm_url, json=data)
|
|
|
+ except Exception as error:
|
|
|
+ print(error)
|
|
|
+ return
|
|
|
+ res = response.json()
|
|
|
+
|
|
|
+ if res["data"]["illegal"] == 1:
|
|
|
+ post_illegal_data()
|
|
|
+
|
|
|
+
|
|
|
+def create_llm_task(boxes, image, task_name, llm_url):
|
|
|
+ llm_thread = threading.Thread(target=task, args=(boxes, image, task_name, llm_url), daemon=True)
|
|
|
+ llm_thread.start()
|
|
|
+
|