# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Train a YOLOv5 classifier model on a classification dataset

Usage - Single-GPU training:
    $ python classify/train.py --model yolov5s-cls.pt --data imagenette160 --epochs 5 --img 224

Usage - Multi-GPU DDP training:
    $ python -m torch.distributed.run --nproc_per_node 4 --master_port 2022 classify/train.py --model yolov5s-cls.pt --data imagenet --epochs 5 --img 224 --device 0,1,2,3

Datasets:           --data mnist, fashion-mnist, cifar10, cifar100, imagenette, imagewoof, imagenet, or 'path/to/data'
YOLOv5-cls models:  --model yolov5n-cls.pt, yolov5s-cls.pt, yolov5m-cls.pt, yolov5l-cls.pt, yolov5x-cls.pt
Torchvision models: --model resnet50, efficientnet_b0, etc. See https://pytorch.org/vision/stable/models.html
"""

import argparse
import os
import subprocess
import sys
import time
from copy import deepcopy
from datetime import datetime
from pathlib import Path

import torch
import torch.distributed as dist
import torch.hub as hub
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
from torch.cuda import amp
from tqdm import tqdm

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from classify import val as validate
from models.experimental import attempt_load
from models.yolo import ClassificationModel, DetectionModel
from utils.dataloaders import create_classification_dataloader
from utils.general import (DATASETS_DIR, LOGGER, TQDM_BAR_FORMAT, WorkingDirectory, check_git_info, check_git_status,
                           check_requirements, colorstr, download, increment_path, init_seeds, print_args, yaml_save)
from utils.loggers import GenericLogger
from utils.plots import imshow_cls
from utils.torch_utils import (ModelEMA, de_parallel, model_info, reshape_classifier_output, select_device, smart_DDP,
                               smart_optimizer, smartCrossEntropyLoss, torch_distributed_zero_first)

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
GIT_INFO = check_git_info()


def train(opt, device):
    init_seeds(opt.seed + 1 + RANK, deterministic=True)
    save_dir, data, bs, epochs, nw, imgsz, pretrained = \
        opt.save_dir, Path(opt.data), opt.batch_size, opt.epochs, min(os.cpu_count() - 1, opt.workers), \
        opt.imgsz, str(opt.pretrained).lower() == 'true'
    cuda = device.type != 'cpu'

    # Directories
    wdir = save_dir / 'weights'
    wdir.mkdir(parents=True, exist_ok=True)  # make dir
    last, best = wdir / 'last.pt', wdir / 'best.pt'

    # Save run settings
    yaml_save(save_dir / 'opt.yaml', vars(opt))

    # Logger
    logger = GenericLogger(opt=opt, console_logger=LOGGER) if RANK in {-1, 0} else None

    # Download Dataset
    with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
        data_dir = data if data.is_dir() else (DATASETS_DIR / data)
        if not data_dir.is_dir():
            LOGGER.info(f'\nDataset not found ⚠️, missing path {data_dir}, attempting download...')
            t = time.time()
            if str(data) == 'imagenet':
                subprocess.run(['bash', str(ROOT / 'data/scripts/get_imagenet.sh')], check=True)  # arg list, so no shell=True
            else:
                url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{data}.zip'
                download(url, dir=data_dir.parent)
            s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
            LOGGER.info(s)

    # Dataloaders
    nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
    trainloader = create_classification_dataloader(path=data_dir / 'train',
                                                   imgsz=imgsz,
                                                   batch_size=bs // WORLD_SIZE,
                                                   augment=True,
                                                   cache=opt.cache,
                                                   rank=LOCAL_RANK,
                                                   workers=nw)

    test_dir = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val'  # data/test or data/val
    if RANK in {-1, 0}:
        testloader = create_classification_dataloader(path=test_dir,
                                                      imgsz=imgsz,
                                                      batch_size=bs // WORLD_SIZE * 2,
                                                      augment=False,
                                                      cache=opt.cache,
                                                      rank=-1,
                                                      workers=nw)

    # Model
    with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(ROOT):
        if Path(opt.model).is_file() or opt.model.endswith('.pt'):
            model = attempt_load(opt.model, device='cpu', fuse=False)
        elif opt.model in torchvision.models.__dict__:  # TorchVision models i.e. resnet50, efficientnet_b0
            model = torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None)
        else:
            m = hub.list('ultralytics/yolov5')  # + hub.list('pytorch/vision')  # models
            raise ModuleNotFoundError(f'--model {opt.model} not found. Available models are: \n' + '\n'.join(m))
        if isinstance(model, DetectionModel):
            LOGGER.warning("WARNING ⚠️ pass YOLOv5 classifier model with '-cls' suffix, i.e. '--model yolov5s-cls.pt'")
            model = ClassificationModel(model=model, nc=nc, cutoff=opt.cutoff or 10)  # convert to classification model
        reshape_classifier_output(model, nc)  # update class count
    for m in model.modules():
        if not pretrained and hasattr(m, 'reset_parameters'):
            m.reset_parameters()
        if isinstance(m, torch.nn.Dropout) and opt.dropout is not None:
            m.p = opt.dropout  # set dropout
    for p in model.parameters():
        p.requires_grad = True  # for training
    model = model.to(device)

    # Info
    if RANK in {-1, 0}:
        model.names = trainloader.dataset.classes  # attach class names
        model.transforms = testloader.dataset.torch_transforms  # attach inference transforms
        model_info(model)
        if opt.verbose:
            LOGGER.info(model)
        images, labels = next(iter(trainloader))
        file = imshow_cls(images[:25], labels[:25], names=model.names, f=save_dir / 'train_images.jpg')
        logger.log_images(file, name='Train Examples')
        logger.log_graph(model, imgsz)  # log model

    # Optimizer
    optimizer = smart_optimizer(model, opt.optimizer, opt.lr0, momentum=0.9, decay=opt.decay)

    # Scheduler
    lrf = 0.01  # final lr (fraction of lr0)
    # lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - lrf) + lrf  # cosine
    lf = lambda x: (1 - x / epochs) * (1 - lrf) + lrf  # linear
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    # scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=lr0, total_steps=epochs, pct_start=0.1,
    #                                     final_div_factor=1 / 25 / lrf)

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None

    # DDP mode
    if cuda and RANK != -1:
        model = smart_DDP(model)

    # Train
    t0 = time.time()
    criterion = smartCrossEntropyLoss(label_smoothing=opt.label_smoothing)  # loss function
    best_fitness = 0.0
    scaler = amp.GradScaler(enabled=cuda)
    val = test_dir.stem  # 'val' or 'test'
    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} test\n'
                f'Using {nw * WORLD_SIZE} dataloader workers\n'
                f"Logging results to {colorstr('bold', save_dir)}\n"
                f'Starting {opt.model} training on {data} dataset with {nc} classes for {epochs} epochs...\n\n'
                f"{'Epoch':>10}{'GPU_mem':>10}{'train_loss':>12}{f'{val}_loss':>12}{'top1_acc':>12}{'top5_acc':>12}")
    for epoch in range(epochs):  # loop over the dataset multiple times
        tloss, vloss, fitness = 0.0, 0.0, 0.0  # train loss, val loss, fitness
        model.train()
        if RANK != -1:
            trainloader.sampler.set_epoch(epoch)
        pbar = enumerate(trainloader)
        if RANK in {-1, 0}:
            pbar = tqdm(enumerate(trainloader), total=len(trainloader), bar_format=TQDM_BAR_FORMAT)
        for i, (images, labels) in pbar:  # progress bar
            images, labels = images.to(device, non_blocking=True), labels.to(device)

            # Forward
            with amp.autocast(enabled=cuda):  # stability issues when enabled
                loss = criterion(model(images), labels)

            # Backward
            scaler.scale(loss).backward()

            # Optimize
            scaler.unscale_(optimizer)  # unscale gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if ema:
                ema.update(model)

            if RANK in {-1, 0}:
                # Print
                tloss = (tloss * i + loss.item()) / (i + 1)  # update mean losses
                mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                pbar.desc = f"{f'{epoch + 1}/{epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36

                # Test
                if i == len(pbar) - 1:  # last batch
                    top1, top5, vloss = validate.run(model=ema.ema,
                                                     dataloader=testloader,
                                                     criterion=criterion,
                                                     pbar=pbar)  # test accuracy, loss
                    fitness = top1  # define fitness as top1 accuracy

        # Scheduler
        scheduler.step()

        # Log metrics
        if RANK in {-1, 0}:
            # Best fitness
            if fitness > best_fitness:
                best_fitness = fitness

            # Log
            metrics = {
                'train/loss': tloss,
                f'{val}/loss': vloss,
                'metrics/accuracy_top1': top1,
                'metrics/accuracy_top5': top5,
                'lr/0': optimizer.param_groups[0]['lr']}  # learning rate
            logger.log_metrics(metrics, epoch)

            # Save model
            final_epoch = epoch + 1 == epochs
            if (not opt.nosave) or final_epoch:
                ckpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(ema.ema).half(),  # deepcopy(de_parallel(model)).half(),
                    'ema': None,  # deepcopy(ema.ema).half(),
                    'updates': ema.updates,
                    'optimizer': None,  # optimizer.state_dict(),
                    'opt': vars(opt),
                    'git': GIT_INFO,  # {remote, branch, commit} if a git repo
                    'date': datetime.now().isoformat()}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fitness:
                    torch.save(ckpt, best)
                del ckpt

    # Train complete
    if RANK in {-1, 0} and final_epoch:
        LOGGER.info(f'\nTraining complete ({(time.time() - t0) / 3600:.3f} hours)'
                    f"\nResults saved to {colorstr('bold', save_dir)}"
                    f'\nPredict: python classify/predict.py --weights {best} --source im.jpg'
                    f'\nValidate: python classify/val.py --weights {best} --data {data_dir}'
                    f'\nExport: python export.py --weights {best} --include onnx'
                    f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{best}')"
                    f'\nVisualize: https://netron.app\n')

        # Plot examples
        images, labels = (x[:25] for x in next(iter(testloader)))  # first 25 images and labels
        pred = torch.max(ema.ema(images.to(device)), 1)[1]
        file = imshow_cls(images, labels, pred, de_parallel(model).names, verbose=False, f=save_dir / 'test_images.jpg')

        # Log results
        meta = {'epochs': epochs, 'top1_acc': best_fitness, 'date': datetime.now().isoformat()}
        logger.log_images(file, name='Test Examples (true-predicted)', epoch=epoch)
        logger.log_model(best, epochs, metadata=meta)


def parse_opt(known=False):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path')
    parser.add_argument('--data', type=str, default='imagenette160', help='cifar10, cifar100, mnist, imagenet, ...')
    parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
    parser.add_argument('--batch-size', type=int, default=64, help='total batch size for all GPUs')
    parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=224, help='train, val image size (pixels)')
    parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name')
    parser.add_argument('--name', default='exp', help='save to project/name')
    parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='use pretrained weights, i.e. --pretrained False')
    parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer')
    parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate')
    parser.add_argument('--decay', type=float, default=5e-5, help='weight decay')
    parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon')
    parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head')
    parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)')
    parser.add_argument('--verbose', action='store_true', help='Verbose mode')
    parser.add_argument('--seed', type=int, default=0, help='Global training seed')
    parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
    return parser.parse_known_args()[0] if known else parser.parse_args()


def main(opt):
    # Checks
    if RANK in {-1, 0}:
        print_args(vars(opt))
        check_git_status()
        check_requirements(ROOT / 'requirements.txt')

    # DDP mode
    device = select_device(opt.device, batch_size=opt.batch_size)
    if LOCAL_RANK != -1:
        assert opt.batch_size != -1, 'AutoBatch is coming soon for classification, please pass a valid --batch-size'
        assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
        assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device('cuda', LOCAL_RANK)
        dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo')

    # Parameters
    opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)  # increment run

    # Train
    train(opt, device)


def run(**kwargs):
    # Usage: from yolov5 import classify; classify.train.run(data='mnist', imgsz=320, model='yolov5m-cls.pt')
    opt = parse_opt(True)
    for k, v in kwargs.items():
        setattr(opt, k, v)
    main(opt)
    return opt


if __name__ == '__main__':
    opt = parse_opt()
    main(opt)
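

# Example (illustrative, not part of the upstream script): a minimal sketch of programmatic training via run(),
# mirroring the usage note in run() above; it assumes the imagenette160 dataset shortcut and yolov5s-cls.pt weights
# resolve exactly as described in the module docstring. Keyword names match the parse_opt() argument dests.
#
#   opt = run(data='imagenette160', model='yolov5s-cls.pt', epochs=5, imgsz=224, batch_size=32)
#   print(opt.save_dir)  # run directory containing weights/last.pt and weights/best.pt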