Explorar o código

创建请求大模型服务代码

leon hai 5 días
pai
achega
f752e7ad01
Modificáronse 1 ficheiros con 77 adicións e 0 borrados
  1. 77 0
      reqllm/reqllm.py

+ 77 - 0
reqllm/reqllm.py

@@ -0,0 +1,77 @@
+import cv2
+import base64
+import threading
+import requests
+
+prompt_words = {
+    "phone" : "1",
+    "fall"  : "2",
+    "fire"  : "3"
+}
+
+def cv2_to_base64(image):
+    image1 = cv2.imencode('.jpg', image)[1]
+    image_code = str(base64.b64encode(image1))[2:-1]
+    return image_code
+
+
+def get_cut_position(boxes, w, h):
+    xmin, ymin, xmax, ymax = w, h, 0, 0
+    if len(boxes) == 1:
+        xmin, ymin, xmax, ymax = boxes[0]
+        box_w = xmax - xmin
+        box_h = ymax - ymin
+        xmin = max(0, xmin - box_w // 5)
+        ymin = max(0, ymin - box_h // 5)
+        xmax = min(w, xmax + box_w // 5)
+        ymax = max(h, ymax + box_h // 5)
+    else:
+        for box in boxes:
+            x1, y1, x2, y2 = box
+            xmin = min(x1, xmin)
+            ymin = min(y1, ymin)
+            xmax = max(xmax, x2)
+            ymax = max(ymax, y2)
+        box_w = xmax - xmin
+        box_h = ymax - ymin
+        xmin = max(0, xmin - box_w // 2)
+        ymin = max(0, ymin - box_h // 2)
+        xmax = min(w, xmax + box_w // 2)
+        ymax = max(h, ymax + box_h // 2)
+    return xmin, ymin, xmax, ymax
+
+def post_illegal_data():
+    pass
+
+def task(boxes, image, task_name, llm_url):
+    """
+    boxes : 违章框的列表 [[x1,y1,x2,y2], ...]
+    image : 图片的opencv格式
+    task_name : 任务名称 phone : 打电话, fall : 摔倒, fire : 烟雾火焰
+    """
+    h, w, _ = image.shape
+    xmin, ymin, xmax, ymax = get_cut_position(boxes, w, h)
+
+    crop_image = image[ymin:ymax, xmin:xmax]
+    crop_image_base64 = cv2_to_base64(crop_image)
+    prompt_word = prompt_words[task_name]
+    
+    data = {
+        "imageData" : crop_image_base64,
+        "text" : prompt_word
+    }
+    try:
+        response = requests.post(llm_url, json=data)
+    except Exception as error:
+        print(error)
+        return
+    res = response.json()
+
+    if res["data"]["illegal"] == 1:
+        post_illegal_data()
+
+
+def create_llm_task(boxes, image, task_name, llm_url):
+    llm_thread = threading.Thread(target=task, args=(boxes, image, task_name, llm_url), daemon=True)
+    llm_thread.start()
+