# server.py — FastAPI service that runs a Qwen-VL model for image violation detection.
  1. from pydantic import BaseModel, field_validator, model_validator, Field
  2. from typing import List, Optional, Generic, TypeVar
  3. from fastapi import FastAPI
  4. import uvicorn
  5. from qwenvl import model
  6. from qwenvl import processor
  7. from qwen_vl_utils import process_vision_info
  8. app = FastAPI()
  9. T = TypeVar("T")
  10. class APIRequest(BaseModel):
  11. imageData : str
  12. text : str
  13. class APIResponse(BaseModel, Generic[T]):
  14. success: bool
  15. data: Optional[List[T]]
  16. msg: Optional[List[str]]
  17. @app.post("/llm/detect")
  18. @app.post("/llm/detect/")
  19. async def detect(item: APIRequest):
  20. # illegal 为 0 代表没有违章
  21. # illegal 为 1 代表有违章
  22. response = {
  23. "sucess": "OK",
  24. "data": {"illegal" : 0},
  25. "msg": ""
  26. }
  27. # 提示词
  28. prompt_text = item.text
  29. base64_image = item.imageData
  30. messages = [
  31. {
  32. "role": "user",
  33. "content": [
  34. {
  35. "type": "image",
  36. "image": f"data:image;base64,{base64_image}",
  37. },
  38. {"type": "text", "text": prompt_text},
  39. ],
  40. }
  41. ]
  42. text = processor.apply_chat_template(
  43. messages, tokenize=False, add_generation_prompt=True
  44. )
  45. image_inputs, video_inputs = process_vision_info(messages)
  46. inputs = processor(
  47. text=[text],
  48. images=image_inputs,
  49. videos=video_inputs,
  50. padding=True,
  51. return_tensors="pt",
  52. )
  53. inputs = inputs.to("cuda")
  54. # Inference: Generation of the output
  55. generated_ids = model.generate(**inputs, max_new_tokens=128)
  56. generated_ids_trimmed = [
  57. out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
  58. ]
  59. output_text = processor.batch_decode(
  60. generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
  61. )
  62. print(output_text)
  63. # 大模型检测后如果有违章
  64. # response["data"]["illegal"] = 1
  65. return response
  66. if __name__ == "__main__":
  67. uvicorn.run('server:app', host="0.0.0.0", port=18000)