
Add Qwen inference code

leon 1 week ago
parent
commit
70b68c5b2d

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+models/
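
The new ignore rule keeps downloaded model weights out of the repository; the models/.gitkeep file added below preserves the empty directory in fresh checkouts.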

+ 3 - 0
.vscode/settings.json

@@ -0,0 +1,3 @@
+{
+    "python.analysis.autoImportCompletions": true
+}

+ 2 - 0
Makefile

@@ -1,4 +1,5 @@
 start:
+	cd server && \
 	gunicorn -w 2 -b 0.0.0.0:18000 -k uvicorn.workers.UvicornWorker server:app --daemon
 	@sleep 1s
 	@ps -ef | grep gunicorn
@@ -16,6 +17,7 @@ status:
 	@ps -ef | grep gunicorn 
 
 debug:
+	cd server && \
 	python3 server.py
 
 

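Note: the `cd server && \` continuation is needed because Make runs each recipe line in its own shell; chaining with `&&` keeps the directory change in effect for the gunicorn and python3 commands on the following line.
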
+ 0 - 0
asserts/.gitkeep


+ 0 - 0
models/.gitkeep


+ 0 - 44
server.py

@@ -1,44 +0,0 @@
-from pydantic import BaseModel, field_validator, model_validator, Field
-from typing import List, Optional, Generic, TypeVar
-from utils import base64_to_cv2
-from fastapi import FastAPI
-import uvicorn
-
-app = FastAPI()
-
-T = TypeVar("T")
-
-class APIRequest(BaseModel):
-    imageData : str
-    text  : str
-
-class APIResponse(BaseModel, Generic[T]):
-    success: bool
-    data: Optional[List[T]]
-    msg: Optional[List[str]]
-
-@app.post("/llm/detect")
-@app.post("/llm/detect/")
-async def detect(item: APIRequest):
-    # illegal == 0 means no violation
-    # illegal == 1 means a violation
-    response = {
-        "sucess": "OK", 
-        "data": {"illegal" : 0}, 
-        "msg": ""
-    }
-    image = base64_to_cv2(item.imageData)
-    if not image.size:
-        response["sucess"] = "FAILED"
-        response["msg"] = "Decode Image Error"
-        return response
-    
-    # prompt text
-    text = item.text
-    # if the LLM detects a violation, set:
-    # response["data"]["illegal"] = 1
-    
-    return response
-
-if __name__ == "__main__":
-    uvicorn.run('server:app', host="0.0.0.0", port=18000)

BIN
server/__pycache__/qwenvl.cpython-312.pyc


BIN
server/__pycache__/server.cpython-312.pyc


+ 62 - 0
server/qwenvl.py

@@ -0,0 +1,62 @@
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+from qwen_vl_utils import process_vision_info
+
+
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    "../models/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
+)
+
+# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
+# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+#     "Qwen/Qwen2.5-VL-3B-Instruct",
+#     torch_dtype=torch.bfloat16,
+#     attn_implementation="flash_attention_2",
+#     device_map="auto",
+# )
+
+# default processor
+processor = AutoProcessor.from_pretrained("../models/Qwen2.5-VL-3B-Instruct")
+
+# The default range for the number of visual tokens per image in the model is 4-16384.
+# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
+# min_pixels = 256*28*28
+# max_pixels = 1280*28*28
+# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+
+if __name__ == "__main__":
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
+                },
+                {"type": "text", "text": "Describe this image."},
+            ],
+        }
+    ]
+
+    # Preparation for inference
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cuda")
+
+    # Inference: Generation of the output
+    generated_ids = model.generate(**inputs, max_new_tokens=128)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    print(output_text)

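The same prepare/generate/decode sequence is repeated verbatim in server/server.py below; a minimal sketch of factoring it into a shared helper (the run_inference name and the device handling are assumptions, not part of this commit):

# Hypothetical helper (not in this commit): factors out the repeated
# prepare/generate/decode steps shared by qwenvl.py and server.py.
from qwen_vl_utils import process_vision_info

def run_inference(model, processor, messages, max_new_tokens=128):
    # Build the chat prompt string from the message list.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Collect the image/video inputs referenced by the messages.
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Follow the model's device rather than hard-coding "cuda".
    inputs = inputs.to(model.device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt tokens so only the generated continuation is decoded.
    trimmed = [
        out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )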
+ 77 - 0
server/server.py

@@ -0,0 +1,77 @@
+from pydantic import BaseModel, field_validator, model_validator, Field
+from typing import List, Optional, Generic, TypeVar
+from fastapi import FastAPI
+import uvicorn
+
+from qwenvl import model
+from qwenvl import processor
+from qwen_vl_utils import process_vision_info
+
+app = FastAPI()
+
+T = TypeVar("T")
+
+class APIRequest(BaseModel):
+    imageData: str
+    text: str
+
+class APIResponse(BaseModel, Generic[T]):
+    success: bool
+    data: Optional[List[T]]
+    msg: Optional[List[str]]
+
+@app.post("/llm/detect")
+@app.post("/llm/detect/")
+async def detect(item: APIRequest):
+    # illegal == 0 means no violation
+    # illegal == 1 means a violation
+    response = {
+        "sucess": "OK", 
+        "data": {"illegal" : 0}, 
+        "msg": ""
+    }
+    
+    # prompt text
+    prompt_text  = item.text
+    base64_image = item.imageData
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "image",
+                    "image": f"data:image;base64,{base64_image}",
+                },
+                {"type": "text", "text": prompt_text},
+            ],
+        }
+    ]
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    image_inputs, video_inputs = process_vision_info(messages)
+    inputs = processor(
+        text=[text],
+        images=image_inputs,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = inputs.to("cuda")
+
+    # Inference: Generation of the output
+    generated_ids = model.generate(**inputs, max_new_tokens=128)
+    generated_ids_trimmed = [
+        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+    ]
+    output_text = processor.batch_decode(
+        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )
+    print(output_text)
+    # if the LLM detects a violation, set:
+    # response["data"]["illegal"] = 1
+    
+    return response
+
+if __name__ == "__main__":
+    uvicorn.run('server:app', host="0.0.0.0", port=18000)
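
A minimal client sketch for the new endpoint (assumes a locally running server on port 18000; the requests library and the sample image path are illustrative, not part of this commit):

# Hypothetical client (not in this commit): posts a base64-encoded image
# and a prompt to the /llm/detect endpoint.
import base64
import requests

with open("demo.jpeg", "rb") as f:  # illustrative image path
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "imageData": image_b64,
    "text": "Is there a traffic violation in this image?",
}
resp = requests.post("http://localhost:18000/llm/detect", json=payload)
print(resp.json())  # e.g. {"success": "OK", "data": {"illegal": 0}, "msg": ""}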

+ 0 - 0
utils.py → server/utils.py