server.py

from pydantic import BaseModel
from typing import List, Optional, Generic, TypeVar
from fastapi import FastAPI
import uvicorn
import logging
from qwenvl import model, processor
from qwen_vl_utils import process_vision_info

logger = logging.getLogger()
logger.setLevel(logging.INFO)
handler1 = logging.StreamHandler()
# Note: logging.FileHandler does not create directories, so ../log must already exist.
handler2 = logging.FileHandler(filename='../log/llmserver.log')
formatter = logging.Formatter(
    "%(asctime)s - %(module)s - %(funcName)s - line:%(lineno)d - %(levelname)s - %(message)s"
)
handler1.setFormatter(formatter)
handler2.setFormatter(formatter)
logger.addHandler(handler1)  # send log output to the console
logger.addHandler(handler2)  # send log output to the file

app = FastAPI()

T = TypeVar("T")


class APIRequest(BaseModel):
    imageData: str
    text: str


class APIResponse(BaseModel, Generic[T]):
    success: bool
    data: Optional[List[T]]
    msg: Optional[List[str]]

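# Illustrative request/response shapes (my assumption based on the models above,
# not output captured from a running server):
#   request:  {"imageData": "<base64-encoded image>", "text": "<prompt text>"}
#   response: {"success": true, "data": {"illegal": 0}, "msg": ""}
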
@app.post("/llm/detect")
@app.post("/llm/detect/")
async def detect(item: APIRequest):
    # illegal == 0 means no violation was found
    # illegal == 1 means a violation was found
    response = {
        "success": True,
        "data": {"illegal": 0},
        "msg": ""
    }
    # Prompt text and base64-encoded image from the request
    prompt_text = item.text
    base64_image = item.imageData
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"data:image;base64,{base64_image}",
                },
                {"type": "text", "text": prompt_text},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
    # Inference: generate the answer
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens so that only newly generated tokens are decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    logger.info(output_text)
    # If the model detects a violation, flag it:
    # response["data"]["illegal"] = 1
    if "yes" in output_text[0].lower():
        response["data"]["illegal"] = 1
    return response


if __name__ == "__main__":
    uvicorn.run('server:app', host="0.0.0.0", port=18000)
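
A minimal client sketch for exercising the endpoint above. This is not part of the original file: it assumes the server is running locally on port 18000, and test.jpg is a hypothetical image path.

import base64
import requests

# Base64-encode the image so it can be embedded in the JSON payload (hypothetical path)
with open("test.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "imageData": image_b64,
    # The server flags a violation when the model's answer contains "yes"
    "text": "Is there a traffic violation in this image? Answer yes or no.",
}
resp = requests.post("http://localhost:18000/llm/detect", json=payload)
print(resp.json())  # e.g. {"success": true, "data": {"illegal": 0}, "msg": ""}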