leon vor 1 Tag
Ursprung
Commit
78e52c8b90
1 geänderte Dateien mit 26 neuen und 26 gelöschten Zeilen
  1. 26 26
      server/server.py

+ 26 - 26
server/server.py

@@ -7,7 +7,7 @@ import logging
 
 from qwenvl import model
 from qwenvl import processor
-# from qwen_vl_utils import process_vision_info
+from qwen_vl_utils import process_vision_info
 
 logger = logging.getLogger()
 logger.setLevel(logging.INFO)
@@ -61,32 +61,32 @@ async def detect(item: APIRequest):
             ],
         }
     ]
-    # text = processor.apply_chat_template(
-    #     messages, tokenize=False, add_generation_prompt=True
-    # )
-    # image_inputs, video_inputs = process_vision_info(messages)
-    # inputs = processor(
-    #     text=[text],
-    #     images=image_inputs,
-    #     videos=video_inputs,
-    #     padding=True,
-    #     return_tensors="pt",
-    # )
-    # inputs = inputs.to("cuda")
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cuda")
 
-    # # Inference: Generation of the output
-    # generated_ids = model.generate(**inputs, max_new_tokens=128)
-    # generated_ids_trimmed = [
-    #     out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    # ]
-    # output_text = processor.batch_decode(
-    #     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    # )
-    # print(output_text)
-    # # 大模型检测后如果有违章
-    # # response["data"]["illegal"] = 1
-    # if "yes" in output_text[0].lower():
-    #     response["data"]["illegal"] = 1
+    # Inference: Generation of the output
+    generated_ids = model.generate(**inputs, max_new_tokens=128)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    print(output_text)
+    # 大模型检测后如果有违章
+    # response["data"]["illegal"] = 1
+    if "yes" in output_text[0].lower():
+        response["data"]["illegal"] = 1
     return response
 
 if __name__ == "__main__":