restapi.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Run a Flask REST API exposing one or more YOLOv5s models
"""
  5. import argparse
  6. import io
  7. import torch
  8. from flask import Flask, request
  9. from PIL import Image
  10. app = Flask(__name__)
  11. models = {}
  12. DETECTION_URL = '/v1/object-detection/<model>'
  13. @app.route(DETECTION_URL, methods=['POST'])
  14. def predict(model):
  15. if request.method != 'POST':
  16. return
  17. if request.files.get('image'):
  18. # Method 1
  19. # with request.files["image"] as f:
  20. # im = Image.open(io.BytesIO(f.read()))
  21. # Method 2
  22. im_file = request.files['image']
  23. im_bytes = im_file.read()
  24. im = Image.open(io.BytesIO(im_bytes))
  25. if model in models:
  26. results = models[model](im, size=640) # reduce size=320 for faster inference
  27. return results.pandas().xyxy[0].to_json(orient='records')
  28. if __name__ == '__main__':
  29. parser = argparse.ArgumentParser(description='Flask API exposing YOLOv5 model')
  30. parser.add_argument('--port', default=5000, type=int, help='port number')
  31. parser.add_argument('--model', nargs='+', default=['yolov5s'], help='model(s) to run, i.e. --model yolov5n yolov5s')
  32. opt = parser.parse_args()
  33. for m in opt.model:
  34. models[m] = torch.hub.load('ultralytics/yolov5', m, force_reload=True, skip_validation=True)
  35. app.run(host='0.0.0.0', port=opt.port) # debug=True causes Restarting with stat