comet_utils.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import logging
  2. import os
  3. from urllib.parse import urlparse
  4. try:
  5. import comet_ml
  6. except (ModuleNotFoundError, ImportError):
  7. comet_ml = None
  8. import yaml
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Prefix that marks a weights/resume argument as a Comet resource URL.
COMET_PREFIX = 'comet://'
# Name of the Comet model whose assets are searched for checkpoints;
# overridable through the COMET_MODEL_NAME environment variable.
COMET_MODEL_NAME = os.getenv('COMET_MODEL_NAME', 'yolov5')
# Checkpoint filename used when the Comet URL carries no query string;
# overridable through the COMET_DEFAULT_CHECKPOINT_FILENAME environment variable.
COMET_DEFAULT_CHECKPOINT_FILENAME = os.getenv('COMET_DEFAULT_CHECKPOINT_FILENAME', 'last.pt')
  13. def download_model_checkpoint(opt, experiment):
  14. model_dir = f'{opt.project}/{experiment.name}'
  15. os.makedirs(model_dir, exist_ok=True)
  16. model_name = COMET_MODEL_NAME
  17. model_asset_list = experiment.get_model_asset_list(model_name)
  18. if len(model_asset_list) == 0:
  19. logger.error(f'COMET ERROR: No checkpoints found for model name : {model_name}')
  20. return
  21. model_asset_list = sorted(
  22. model_asset_list,
  23. key=lambda x: x['step'],
  24. reverse=True,
  25. )
  26. logged_checkpoint_map = {asset['fileName']: asset['assetId'] for asset in model_asset_list}
  27. resource_url = urlparse(opt.weights)
  28. checkpoint_filename = resource_url.query
  29. if checkpoint_filename:
  30. asset_id = logged_checkpoint_map.get(checkpoint_filename)
  31. else:
  32. asset_id = logged_checkpoint_map.get(COMET_DEFAULT_CHECKPOINT_FILENAME)
  33. checkpoint_filename = COMET_DEFAULT_CHECKPOINT_FILENAME
  34. if asset_id is None:
  35. logger.error(f'COMET ERROR: Checkpoint {checkpoint_filename} not found in the given Experiment')
  36. return
  37. try:
  38. logger.info(f'COMET INFO: Downloading checkpoint {checkpoint_filename}')
  39. asset_filename = checkpoint_filename
  40. model_binary = experiment.get_asset(asset_id, return_type='binary', stream=False)
  41. model_download_path = f'{model_dir}/{asset_filename}'
  42. with open(model_download_path, 'wb') as f:
  43. f.write(model_binary)
  44. opt.weights = model_download_path
  45. except Exception as e:
  46. logger.warning('COMET WARNING: Unable to download checkpoint from Comet')
  47. logger.exception(e)
  48. def set_opt_parameters(opt, experiment):
  49. """Update the opts Namespace with parameters
  50. from Comet's ExistingExperiment when resuming a run
  51. Args:
  52. opt (argparse.Namespace): Namespace of command line options
  53. experiment (comet_ml.APIExperiment): Comet API Experiment object
  54. """
  55. asset_list = experiment.get_asset_list()
  56. resume_string = opt.resume
  57. for asset in asset_list:
  58. if asset['fileName'] == 'opt.yaml':
  59. asset_id = asset['assetId']
  60. asset_binary = experiment.get_asset(asset_id, return_type='binary', stream=False)
  61. opt_dict = yaml.safe_load(asset_binary)
  62. for key, value in opt_dict.items():
  63. setattr(opt, key, value)
  64. opt.resume = resume_string
  65. # Save hyperparameters to YAML file
  66. # Necessary to pass checks in training script
  67. save_dir = f'{opt.project}/{experiment.name}'
  68. os.makedirs(save_dir, exist_ok=True)
  69. hyp_yaml_path = f'{save_dir}/hyp.yaml'
  70. with open(hyp_yaml_path, 'w') as f:
  71. yaml.dump(opt.hyp, f)
  72. opt.hyp = hyp_yaml_path
  73. def check_comet_weights(opt):
  74. """Downloads model weights from Comet and updates the
  75. weights path to point to saved weights location
  76. Args:
  77. opt (argparse.Namespace): Command Line arguments passed
  78. to YOLOv5 training script
  79. Returns:
  80. None/bool: Return True if weights are successfully downloaded
  81. else return None
  82. """
  83. if comet_ml is None:
  84. return
  85. if isinstance(opt.weights, str):
  86. if opt.weights.startswith(COMET_PREFIX):
  87. api = comet_ml.API()
  88. resource = urlparse(opt.weights)
  89. experiment_path = f'{resource.netloc}{resource.path}'
  90. experiment = api.get(experiment_path)
  91. download_model_checkpoint(opt, experiment)
  92. return True
  93. return None
  94. def check_comet_resume(opt):
  95. """Restores run parameters to its original state based on the model checkpoint
  96. and logged Experiment parameters.
  97. Args:
  98. opt (argparse.Namespace): Command Line arguments passed
  99. to YOLOv5 training script
  100. Returns:
  101. None/bool: Return True if the run is restored successfully
  102. else return None
  103. """
  104. if comet_ml is None:
  105. return
  106. if isinstance(opt.resume, str):
  107. if opt.resume.startswith(COMET_PREFIX):
  108. api = comet_ml.API()
  109. resource = urlparse(opt.resume)
  110. experiment_path = f'{resource.netloc}{resource.path}'
  111. experiment = api.get(experiment_path)
  112. set_opt_parameters(opt, experiment)
  113. download_model_checkpoint(opt, experiment)
  114. return True
  115. return None