callbacks.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
  2. """
  3. Callback utils
  4. """
  5. import threading
  6. class Callbacks:
  7. """"
  8. Handles all registered callbacks for YOLOv5 Hooks
  9. """
  10. def __init__(self):
  11. # Define the available callbacks
  12. self._callbacks = {
  13. 'on_pretrain_routine_start': [],
  14. 'on_pretrain_routine_end': [],
  15. 'on_train_start': [],
  16. 'on_train_epoch_start': [],
  17. 'on_train_batch_start': [],
  18. 'optimizer_step': [],
  19. 'on_before_zero_grad': [],
  20. 'on_train_batch_end': [],
  21. 'on_train_epoch_end': [],
  22. 'on_val_start': [],
  23. 'on_val_batch_start': [],
  24. 'on_val_image_end': [],
  25. 'on_val_batch_end': [],
  26. 'on_val_end': [],
  27. 'on_fit_epoch_end': [], # fit = train + val
  28. 'on_model_save': [],
  29. 'on_train_end': [],
  30. 'on_params_update': [],
  31. 'teardown': [], }
  32. self.stop_training = False # set True to interrupt training
  33. def register_action(self, hook, name='', callback=None):
  34. """
  35. Register a new action to a callback hook
  36. Args:
  37. hook: The callback hook name to register the action to
  38. name: The name of the action for later reference
  39. callback: The callback to fire
  40. """
  41. assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
  42. assert callable(callback), f"callback '{callback}' is not callable"
  43. self._callbacks[hook].append({'name': name, 'callback': callback})
  44. def get_registered_actions(self, hook=None):
  45. """"
  46. Returns all the registered actions by callback hook
  47. Args:
  48. hook: The name of the hook to check, defaults to all
  49. """
  50. return self._callbacks[hook] if hook else self._callbacks
  51. def run(self, hook, *args, thread=False, **kwargs):
  52. """
  53. Loop through the registered actions and fire all callbacks on main thread
  54. Args:
  55. hook: The name of the hook to check, defaults to all
  56. args: Arguments to receive from YOLOv5
  57. thread: (boolean) Run callbacks in daemon thread
  58. kwargs: Keyword Arguments to receive from YOLOv5
  59. """
  60. assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
  61. for logger in self._callbacks[hook]:
  62. if thread:
  63. threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
  64. else:
  65. logger['callback'](*args, **kwargs)