train.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666
  1. # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
  2. """
  3. Train a YOLOv5 segment model on a segment dataset
  4. Models and datasets download automatically from the latest YOLOv5 release.
  5. Usage - Single-GPU training:
  6. $ python segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 # from pretrained (recommended)
  7. $ python segment/train.py --data coco128-seg.yaml --weights '' --cfg yolov5s-seg.yaml --img 640 # from scratch
  8. Usage - Multi-GPU DDP training:
  9. $ python -m torch.distributed.run --nproc_per_node 4 --master_port 1 segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --img 640 --device 0,1,2,3
  10. Models: https://github.com/ultralytics/yolov5/tree/master/models
  11. Datasets: https://github.com/ultralytics/yolov5/tree/master/data
  12. Tutorial: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data
  13. """
  14. import argparse
  15. import math
  16. import os
  17. import random
  18. import subprocess
  19. import sys
  20. import time
  21. from copy import deepcopy
  22. from datetime import datetime
  23. from pathlib import Path
  24. import numpy as np
  25. import torch
  26. import torch.distributed as dist
  27. import torch.nn as nn
  28. import yaml
  29. from torch.optim import lr_scheduler
  30. from tqdm import tqdm
  31. FILE = Path(__file__).resolve()
  32. ROOT = FILE.parents[1] # YOLOv5 root directory
  33. if str(ROOT) not in sys.path:
  34. sys.path.append(str(ROOT)) # add ROOT to PATH
  35. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  36. import segment.val as validate # for end-of-epoch mAP
  37. from models.experimental import attempt_load
  38. from models.yolo import SegmentationModel
  39. from utils.autoanchor import check_anchors
  40. from utils.autobatch import check_train_batch_size
  41. from utils.callbacks import Callbacks
  42. from utils.downloads import attempt_download, is_url
  43. from utils.general import (LOGGER, TQDM_BAR_FORMAT, check_amp, check_dataset, check_file, check_git_info,
  44. check_git_status, check_img_size, check_requirements, check_suffix, check_yaml, colorstr,
  45. get_latest_run, increment_path, init_seeds, intersect_dicts, labels_to_class_weights,
  46. labels_to_image_weights, one_cycle, print_args, print_mutation, strip_optimizer, yaml_save)
  47. from utils.loggers import GenericLogger
  48. from utils.plots import plot_evolve, plot_labels
  49. from utils.segment.dataloaders import create_dataloader
  50. from utils.segment.loss import ComputeLoss
  51. from utils.segment.metrics import KEYS, fitness
  52. from utils.segment.plots import plot_images_and_masks, plot_results_with_masks
  53. from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
  54. smart_resume, torch_distributed_zero_first)
  55. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  56. RANK = int(os.getenv('RANK', -1))
  57. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  58. GIT_INFO = check_git_info()
def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictionary
    """Train a YOLOv5 segmentation model and return the last validation results.

    Args:
        hyp: path to a hyperparameter YAML file, or an already-loaded hyperparameter dict.
            NOTE: a dict argument is modified in place below ('weight_decay', 'box', 'cls' and
            'obj' are rescaled, 'label_smoothing' is set), so callers should pass a copy
            (main() passes hyp.copy()).
        opt: options namespace produced by parse_opt(); opt.hyp is overwritten with a copy of
            the loaded hyperparameters so that it is saved alongside the run.
        device: torch.device to train on.
        callbacks: Callbacks instance, forwarded to validate.run().

    Returns:
        `results` from the most recent validate.run() call. On processes that never validate
        (DDP ranks other than 0) this is the initial tuple of 12 zeros.

    Side effects:
        Creates save_dir/weights (or save_dir when evolving), writes hyp.yaml/opt.yaml when not
        evolving, and saves last.pt/best.pt (and optional epoch checkpoints) on rank -1/0.
    """
    save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, mask_ratio = \
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze, opt.mask_ratio
    # callbacks.run('on_pretrain_routine_start')

    # Directories
    w = save_dir / 'weights'  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
    last, best = w / 'last.pt', w / 'best.pt'

    # Hyperparameters
    if isinstance(hyp, str):
        with open(hyp, errors='ignore') as f:
            hyp = yaml.safe_load(f)  # load hyps dict
    LOGGER.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
    opt.hyp = hyp.copy()  # for saving hyps to checkpoints

    # Save run settings
    if not evolve:
        yaml_save(save_dir / 'hyp.yaml', hyp)
        yaml_save(save_dir / 'opt.yaml', vars(opt))

    # Loggers
    data_dict = None
    if RANK in {-1, 0}:
        # NOTE: `logger` only exists on rank -1/0; every later use is guarded by a rank check
        # (or by RANK == -1 in the AutoBatch branch).
        logger = GenericLogger(opt=opt, console_logger=LOGGER)

    # Config
    plots = not evolve and not opt.noplots  # create plots
    overlap = not opt.no_overlap
    cuda = device.type != 'cpu'
    init_seeds(opt.seed + 1 + RANK, deterministic=True)  # per-rank seed offset
    with torch_distributed_zero_first(LOCAL_RANK):
        data_dict = data_dict or check_dataset(data)  # check if None
    train_path, val_path = data_dict['train'], data_dict['val']
    nc = 1 if single_cls else int(data_dict['nc'])  # number of classes
    names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names']  # class names
    is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt')  # COCO dataset

    # Model
    check_suffix(weights, '.pt')  # check weights
    pretrained = weights.endswith('.pt')
    if pretrained:
        with torch_distributed_zero_first(LOCAL_RANK):
            weights = attempt_download(weights)  # download if not found locally
        ckpt = torch.load(weights, map_location='cpu')  # load checkpoint to CPU to avoid CUDA memory leak
        model = SegmentationModel(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)
        # Drop checkpoint anchors when the user supplied a cfg or explicit anchors (unless resuming).
        exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else []  # exclude keys
        csd = ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
        csd = intersect_dicts(csd, model.state_dict(), exclude=exclude)  # intersect
        model.load_state_dict(csd, strict=False)  # load
        LOGGER.info(f'Transferred {len(csd)}/{len(model.state_dict())} items from {weights}')  # report
    else:
        model = SegmentationModel(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
    amp = check_amp(model)  # check AMP

    # Freeze
    # A single value N freezes the first N modules; several values list module indices explicitly.
    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False

    # Image size
    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
    imgsz = check_img_size(opt.imgsz, gs, floor=gs * 2)  # verify imgsz is gs-multiple

    # Batch size
    if RANK == -1 and batch_size == -1:  # single-GPU only, estimate best batch size
        batch_size = check_train_batch_size(model, imgsz, amp)
        logger.update_params({'batch_size': batch_size})
        # loggers.on_params_update({"batch_size": batch_size})

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size), 1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    optimizer = smart_optimizer(model, opt.optimizer, hyp['lr0'], hyp['momentum'], hyp['weight_decay'])

    # Scheduler
    if opt.cos_lr:
        lf = one_cycle(1, hyp['lrf'], epochs)  # cosine 1->hyp['lrf']
    else:
        lf = lambda x: (1 - x / epochs) * (1.0 - hyp['lrf']) + hyp['lrf']  # linear
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)  # plot_lr_scheduler(optimizer, scheduler, epochs)

    # EMA
    ema = ModelEMA(model) if RANK in {-1, 0} else None

    # Resume
    best_fitness, start_epoch = 0.0, 0
    if pretrained:
        if resume:
            best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
        del ckpt, csd  # release checkpoint memory before training

    # DP mode
    if cuda and RANK == -1 and torch.cuda.device_count() > 1:
        LOGGER.warning(
            'WARNING ⚠️ DP not recommended, use torch.distributed.run for best DDP Multi-GPU results.\n'
            'See Multi-GPU Tutorial at https://docs.ultralytics.com/yolov5/tutorials/multi_gpu_training to get started.'
        )
        model = torch.nn.DataParallel(model)

    # SyncBatchNorm
    if opt.sync_bn and cuda and RANK != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        LOGGER.info('Using SyncBatchNorm()')

    # Trainloader
    train_loader, dataset = create_dataloader(
        train_path,
        imgsz,
        batch_size // WORLD_SIZE,  # per-process batch size
        gs,
        single_cls,
        hyp=hyp,
        augment=True,
        cache=None if opt.cache == 'val' else opt.cache,
        rect=opt.rect,
        rank=LOCAL_RANK,
        workers=workers,
        image_weights=opt.image_weights,
        quad=opt.quad,
        prefix=colorstr('train: '),
        shuffle=True,
        mask_downsample_ratio=mask_ratio,
        overlap_mask=overlap,
    )
    labels = np.concatenate(dataset.labels, 0)
    mlc = int(labels[:, 0].max())  # max label class
    assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'

    # Process 0
    if RANK in {-1, 0}:
        val_loader = create_dataloader(val_path,
                                       imgsz,
                                       batch_size // WORLD_SIZE * 2,
                                       gs,
                                       single_cls,
                                       hyp=hyp,
                                       cache=None if noval else opt.cache,
                                       rect=True,
                                       rank=-1,
                                       workers=workers * 2,
                                       pad=0.5,
                                       mask_downsample_ratio=mask_ratio,
                                       overlap_mask=overlap,
                                       prefix=colorstr('val: '))[0]

        if not resume:
            if not opt.noautoanchor:
                check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)  # run AutoAnchor
            model.half().float()  # pre-reduce anchor precision

            if plots:
                plot_labels(labels, names, save_dir)
        # callbacks.run('on_pretrain_routine_end', labels, names)

    # DDP mode
    if cuda and RANK != -1:
        model = smart_DDP(model)

    # Model attributes
    nl = de_parallel(model).model[-1].nl  # number of detection layers (to scale hyps)
    hyp['box'] *= 3 / nl  # scale to layers
    hyp['cls'] *= nc / 80 * 3 / nl  # scale to classes and layers
    hyp['obj'] *= (imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
    hyp['label_smoothing'] = opt.label_smoothing
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
    model.names = names

    # Start training
    t0 = time.time()
    nb = len(train_loader)  # number of batches
    nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations: warmup_epochs * nb, but at least 100
    # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
    last_opt_step = -1
    maps = np.zeros(nc)  # mAP per class
    results = (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)  # placeholder: 12 zeros, same length as KEYS[4:16] used for final logging
    scheduler.last_epoch = start_epoch - 1  # do not move
    scaler = torch.cuda.amp.GradScaler(enabled=amp)
    stopper, stop = EarlyStopping(patience=opt.patience), False
    compute_loss = ComputeLoss(model, overlap=overlap)  # init loss class
    # callbacks.run('on_train_start')
    LOGGER.info(f'Image sizes {imgsz} train, {imgsz} val\n'
                f'Using {train_loader.num_workers * WORLD_SIZE} dataloader workers\n'
                f"Logging results to {colorstr('bold', save_dir)}\n"
                f'Starting training for {epochs} epochs...')
    for epoch in range(start_epoch, epochs):  # epoch ------------------------------------------------------------------
        # callbacks.run('on_train_epoch_start')
        model.train()

        # Update image weights (optional, single-GPU only)
        if opt.image_weights:
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx

        # Update mosaic border (optional)
        # b = int(random.uniform(0.25 * imgsz, 0.75 * imgsz + gs) // gs * gs)
        # dataset.mosaic_border = [b - imgsz, -b]  # height, width borders

        mloss = torch.zeros(4, device=device)  # mean losses (box, seg, obj, cls as printed below)
        if RANK != -1:
            train_loader.sampler.set_epoch(epoch)  # reshuffle the distributed sampler each epoch
        pbar = enumerate(train_loader)
        LOGGER.info(('\n' + '%11s' * 8) %
                    ('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Instances', 'Size'))
        if RANK in {-1, 0}:
            pbar = tqdm(pbar, total=nb, bar_format=TQDM_BAR_FORMAT)  # progress bar
        optimizer.zero_grad()
        for i, (imgs, targets, paths, _, masks) in pbar:  # batch ------------------------------------------------------
            # callbacks.run('on_train_batch_start')
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device, non_blocking=True).float() / 255  # uint8 to float32, 0-255 to 0.0-1.0

            # Warmup
            if ni <= nw:
                xi = [0, nw]  # x interp
                # compute_loss.gr = np.interp(ni, xi, [0.0, 1.0])  # iou loss ratio (obj_loss = 1.0 or iou)
                # Gradient accumulation ramps from 1 step up to nbs / batch_size steps during warmup.
                accumulate = max(1, np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # Param group 0 starts at hyp['warmup_bias_lr'] (biases, per the original note); other groups
                    # start at 0.0. All groups ramp linearly to initial_lr * lf(epoch).
                    x['lr'] = np.interp(ni, xi, [hyp['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi, [hyp['warmup_momentum'], hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5) + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]]  # new shape (stretched to gs-multiple)
                    imgs = nn.functional.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)

            # Forward
            with torch.cuda.amp.autocast(amp):
                pred = model(imgs)  # forward
                loss, loss_items = compute_loss(pred, targets.to(device), masks=masks.to(device).float())
                if RANK != -1:
                    loss *= WORLD_SIZE  # gradient averaged between devices in DDP mode
                if opt.quad:
                    loss *= 4.

            # Backward
            scaler.scale(loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            # Step only once `accumulate` batches have been back-propagated since the last step.
            if ni - last_opt_step >= accumulate:
                scaler.unscale_(optimizer)  # unscale gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # clip gradients
                scaler.step(optimizer)  # optimizer.step
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
                last_opt_step = ni

            # Log
            if RANK in {-1, 0}:
                mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
                mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                pbar.set_description(('%11s' * 2 + '%11.4g' * 6) %
                                     (f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
                # callbacks.run('on_train_batch_end', model, ni, imgs, targets, paths)
                # if callbacks.stop_training:
                #    return

                # Mosaic plots
                if plots:
                    if ni < 3:
                        plot_images_and_masks(imgs, targets, masks, paths, save_dir / f'train_batch{ni}.jpg')
                    if ni == 10:
                        files = sorted(save_dir.glob('train*.jpg'))
                        logger.log_images(files, 'Mosaics', epoch)
            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        lr = [x['lr'] for x in optimizer.param_groups]  # for loggers
        scheduler.step()

        if RANK in {-1, 0}:
            # mAP
            # callbacks.run('on_train_epoch_end', epoch=epoch)
            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
            final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
            if not noval or final_epoch:  # Calculate mAP
                results, maps, _ = validate.run(data_dict,
                                                batch_size=batch_size // WORLD_SIZE * 2,
                                                imgsz=imgsz,
                                                half=amp,
                                                model=ema.ema,
                                                single_cls=single_cls,
                                                dataloader=val_loader,
                                                save_dir=save_dir,
                                                plots=False,
                                                callbacks=callbacks,
                                                compute_loss=compute_loss,
                                                mask_downsample_ratio=mask_ratio,
                                                overlap=overlap)

            # Update best mAP
            fi = fitness(np.array(results).reshape(1, -1))  # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
            stop = stopper(epoch=epoch, fitness=fi)  # early stop check
            if fi > best_fitness:
                best_fitness = fi
            log_vals = list(mloss) + list(results) + lr
            # callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
            # Log val metrics and media
            metrics_dict = dict(zip(KEYS, log_vals))
            logger.log_metrics(metrics_dict, epoch)

            # Save model
            if (not nosave) or (final_epoch and not evolve):  # if save
                ckpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'model': deepcopy(de_parallel(model)).half(),
                    'ema': deepcopy(ema.ema).half(),
                    'updates': ema.updates,
                    'optimizer': optimizer.state_dict(),
                    'opt': vars(opt),
                    'git': GIT_INFO,  # {remote, branch, commit} if a git repo
                    'date': datetime.now().isoformat()}

                # Save last, best and delete
                torch.save(ckpt, last)
                if best_fitness == fi:
                    torch.save(ckpt, best)
                if opt.save_period > 0 and epoch % opt.save_period == 0:
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
                    logger.log_model(w / f'epoch{epoch}.pt')
                del ckpt
                # callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

        # EarlyStopping
        if RANK != -1:  # if DDP training
            # Rank 0 decides; every rank must receive the decision so that all of them leave the loop together.
            broadcast_list = [stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            if RANK != 0:
                stop = broadcast_list[0]
        if stop:
            break  # must break all DDP ranks

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training -----------------------------------------------------------------------------------------------------
    if RANK in {-1, 0}:
        # NOTE(review): if start_epoch >= epochs the loop above never runs, and `epoch` (plus `mloss`, `lr`)
        # is unbound here, which raises NameError.
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = validate.run(
                        data_dict,
                        batch_size=batch_size // WORLD_SIZE * 2,
                        imgsz=imgsz,
                        model=attempt_load(f, device).half(),
                        iou_thres=0.65 if is_coco else 0.60,  # best pycocotools at iou 0.65
                        single_cls=single_cls,
                        dataloader=val_loader,
                        save_dir=save_dir,
                        save_json=is_coco,
                        verbose=True,
                        plots=plots,
                        callbacks=callbacks,
                        compute_loss=compute_loss,
                        mask_downsample_ratio=mask_ratio,
                        overlap=overlap)  # val best model with plots
                    if is_coco:
                        # callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
                        metrics_dict = dict(zip(KEYS, list(mloss) + list(results) + lr))
                        logger.log_metrics(metrics_dict, epoch)

        # callbacks.run('on_train_end', last, best, epoch, results)
        # on train end callback using genericLogger
        logger.log_metrics(dict(zip(KEYS[4:16], results)), epochs)
        if not opt.evolve:
            logger.log_model(best, epoch)
        if plots:
            plot_results_with_masks(file=save_dir / 'results.csv')  # save results.png
            files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))]
            files = [(save_dir / f) for f in files if (save_dir / f).exists()]  # filter
            LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
            logger.log_images(files, 'Results', epoch + 1)
            logger.log_images(sorted(save_dir.glob('val*.jpg')), 'Validation', epoch + 1)
    torch.cuda.empty_cache()
    return results
  413. def parse_opt(known=False):
  414. parser = argparse.ArgumentParser()
  415. parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s-seg.pt', help='initial weights path')
  416. parser.add_argument('--cfg', type=str, default='', help='model.yaml path')
  417. parser.add_argument('--data', type=str, default=ROOT / 'data/coco128-seg.yaml', help='dataset.yaml path')
  418. parser.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-low.yaml', help='hyperparameters path')
  419. parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
  420. parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
  421. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
  422. parser.add_argument('--rect', action='store_true', help='rectangular training')
  423. parser.add_argument('--resume', nargs='?', const=True, default=False, help='resume most recent training')
  424. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
  425. parser.add_argument('--noval', action='store_true', help='only validate final epoch')
  426. parser.add_argument('--noautoanchor', action='store_true', help='disable AutoAnchor')
  427. parser.add_argument('--noplots', action='store_true', help='save no plot files')
  428. parser.add_argument('--evolve', type=int, nargs='?', const=300, help='evolve hyperparameters for x generations')
  429. parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
  430. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='image --cache ram/disk')
  431. parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training')
  432. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
  433. parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%')
  434. parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
  435. parser.add_argument('--optimizer', type=str, choices=['SGD', 'Adam', 'AdamW'], default='SGD', help='optimizer')
  436. parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
  437. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
  438. parser.add_argument('--project', default=ROOT / 'runs/train-seg', help='save to project/name')
  439. parser.add_argument('--name', default='exp', help='save to project/name')
  440. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
  441. parser.add_argument('--quad', action='store_true', help='quad dataloader')
  442. parser.add_argument('--cos-lr', action='store_true', help='cosine LR scheduler')
  443. parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
  444. parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
  445. parser.add_argument('--freeze', nargs='+', type=int, default=[0], help='Freeze layers: backbone=10, first3=0 1 2')
  446. parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
  447. parser.add_argument('--seed', type=int, default=0, help='Global training seed')
  448. parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
  449. # Instance Segmentation Args
  450. parser.add_argument('--mask-ratio', type=int, default=4, help='Downsample the truth masks to saving memory')
  451. parser.add_argument('--no-overlap', action='store_true', help='Overlap masks train faster at slightly less mAP')
  452. return parser.parse_known_args()[0] if known else parser.parse_args()
  453. def main(opt, callbacks=Callbacks()):
  454. # Checks
  455. if RANK in {-1, 0}:
  456. print_args(vars(opt))
  457. check_git_status()
  458. check_requirements(ROOT / 'requirements.txt')
  459. # Resume
  460. if opt.resume and not opt.evolve: # resume from specified or most recent last.pt
  461. last = Path(check_file(opt.resume) if isinstance(opt.resume, str) else get_latest_run())
  462. opt_yaml = last.parent.parent / 'opt.yaml' # train options yaml
  463. opt_data = opt.data # original dataset
  464. if opt_yaml.is_file():
  465. with open(opt_yaml, errors='ignore') as f:
  466. d = yaml.safe_load(f)
  467. else:
  468. d = torch.load(last, map_location='cpu')['opt']
  469. opt = argparse.Namespace(**d) # replace
  470. opt.cfg, opt.weights, opt.resume = '', str(last), True # reinstate
  471. if is_url(opt_data):
  472. opt.data = check_file(opt_data) # avoid HUB resume auth timeout
  473. else:
  474. opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
  475. check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks
  476. assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
  477. if opt.evolve:
  478. if opt.project == str(ROOT / 'runs/train-seg'): # if default project name, rename to runs/evolve-seg
  479. opt.project = str(ROOT / 'runs/evolve-seg')
  480. opt.exist_ok, opt.resume = opt.resume, False # pass resume to exist_ok and disable resume
  481. if opt.name == 'cfg':
  482. opt.name = Path(opt.cfg).stem # use model.yaml as name
  483. opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
  484. # DDP mode
  485. device = select_device(opt.device, batch_size=opt.batch_size)
  486. if LOCAL_RANK != -1:
  487. msg = 'is not compatible with YOLOv5 Multi-GPU DDP training'
  488. assert not opt.image_weights, f'--image-weights {msg}'
  489. assert not opt.evolve, f'--evolve {msg}'
  490. assert opt.batch_size != -1, f'AutoBatch with --batch-size -1 {msg}, please pass a valid --batch-size'
  491. assert opt.batch_size % WORLD_SIZE == 0, f'--batch-size {opt.batch_size} must be multiple of WORLD_SIZE'
  492. assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command'
  493. torch.cuda.set_device(LOCAL_RANK)
  494. device = torch.device('cuda', LOCAL_RANK)
  495. dist.init_process_group(backend='nccl' if dist.is_nccl_available() else 'gloo')
  496. # Train
  497. if not opt.evolve:
  498. train(opt.hyp, opt, device, callbacks)
  499. # Evolve hyperparameters (optional)
  500. else:
  501. # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
  502. meta = {
  503. 'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
  504. 'lrf': (1, 0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
  505. 'momentum': (0.3, 0.6, 0.98), # SGD momentum/Adam beta1
  506. 'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
  507. 'warmup_epochs': (1, 0.0, 5.0), # warmup epochs (fractions ok)
  508. 'warmup_momentum': (1, 0.0, 0.95), # warmup initial momentum
  509. 'warmup_bias_lr': (1, 0.0, 0.2), # warmup initial bias lr
  510. 'box': (1, 0.02, 0.2), # box loss gain
  511. 'cls': (1, 0.2, 4.0), # cls loss gain
  512. 'cls_pw': (1, 0.5, 2.0), # cls BCELoss positive_weight
  513. 'obj': (1, 0.2, 4.0), # obj loss gain (scale with pixels)
  514. 'obj_pw': (1, 0.5, 2.0), # obj BCELoss positive_weight
  515. 'iou_t': (0, 0.1, 0.7), # IoU training threshold
  516. 'anchor_t': (1, 2.0, 8.0), # anchor-multiple threshold
  517. 'anchors': (2, 2.0, 10.0), # anchors per output grid (0 to ignore)
  518. 'fl_gamma': (0, 0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5)
  519. 'hsv_h': (1, 0.0, 0.1), # image HSV-Hue augmentation (fraction)
  520. 'hsv_s': (1, 0.0, 0.9), # image HSV-Saturation augmentation (fraction)
  521. 'hsv_v': (1, 0.0, 0.9), # image HSV-Value augmentation (fraction)
  522. 'degrees': (1, 0.0, 45.0), # image rotation (+/- deg)
  523. 'translate': (1, 0.0, 0.9), # image translation (+/- fraction)
  524. 'scale': (1, 0.0, 0.9), # image scale (+/- gain)
  525. 'shear': (1, 0.0, 10.0), # image shear (+/- deg)
  526. 'perspective': (0, 0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
  527. 'flipud': (1, 0.0, 1.0), # image flip up-down (probability)
  528. 'fliplr': (0, 0.0, 1.0), # image flip left-right (probability)
  529. 'mosaic': (1, 0.0, 1.0), # image mixup (probability)
  530. 'mixup': (1, 0.0, 1.0), # image mixup (probability)
  531. 'copy_paste': (1, 0.0, 1.0)} # segment copy-paste (probability)
  532. with open(opt.hyp, errors='ignore') as f:
  533. hyp = yaml.safe_load(f) # load hyps dict
  534. if 'anchors' not in hyp: # anchors commented in hyp.yaml
  535. hyp['anchors'] = 3
  536. if opt.noautoanchor:
  537. del hyp['anchors'], meta['anchors']
  538. opt.noval, opt.nosave, save_dir = True, True, Path(opt.save_dir) # only val/save final epoch
  539. # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
  540. evolve_yaml, evolve_csv = save_dir / 'hyp_evolve.yaml', save_dir / 'evolve.csv'
  541. if opt.bucket:
  542. # download evolve.csv if exists
  543. subprocess.run([
  544. 'gsutil',
  545. 'cp',
  546. f'gs://{opt.bucket}/evolve.csv',
  547. str(evolve_csv), ])
  548. for _ in range(opt.evolve): # generations to evolve
  549. if evolve_csv.exists(): # if evolve.csv exists: select best hyps and mutate
  550. # Select parent(s)
  551. parent = 'single' # parent selection method: 'single' or 'weighted'
  552. x = np.loadtxt(evolve_csv, ndmin=2, delimiter=',', skiprows=1)
  553. n = min(5, len(x)) # number of previous results to consider
  554. x = x[np.argsort(-fitness(x))][:n] # top n mutations
  555. w = fitness(x) - fitness(x).min() + 1E-6 # weights (sum > 0)
  556. if parent == 'single' or len(x) == 1:
  557. # x = x[random.randint(0, n - 1)] # random selection
  558. x = x[random.choices(range(n), weights=w)[0]] # weighted selection
  559. elif parent == 'weighted':
  560. x = (x * w.reshape(n, 1)).sum(0) / w.sum() # weighted combination
  561. # Mutate
  562. mp, s = 0.8, 0.2 # mutation probability, sigma
  563. npr = np.random
  564. npr.seed(int(time.time()))
  565. g = np.array([meta[k][0] for k in hyp.keys()]) # gains 0-1
  566. ng = len(meta)
  567. v = np.ones(ng)
  568. while all(v == 1): # mutate until a change occurs (prevent duplicates)
  569. v = (g * (npr.random(ng) < mp) * npr.randn(ng) * npr.random() * s + 1).clip(0.3, 3.0)
  570. for i, k in enumerate(hyp.keys()): # plt.hist(v.ravel(), 300)
  571. hyp[k] = float(x[i + 12] * v[i]) # mutate
  572. # Constrain to limits
  573. for k, v in meta.items():
  574. hyp[k] = max(hyp[k], v[1]) # lower limit
  575. hyp[k] = min(hyp[k], v[2]) # upper limit
  576. hyp[k] = round(hyp[k], 5) # significant digits
  577. # Train mutation
  578. results = train(hyp.copy(), opt, device, callbacks)
  579. callbacks = Callbacks()
  580. # Write mutation results
  581. print_mutation(KEYS[4:16], results, hyp.copy(), save_dir, opt.bucket)
  582. # Plot results
  583. plot_evolve(evolve_csv)
  584. LOGGER.info(f'Hyperparameter evolution finished {opt.evolve} generations\n'
  585. f"Results saved to {colorstr('bold', save_dir)}\n"
  586. f'Usage example: $ python train.py --hyp {evolve_yaml}')
  587. def run(**kwargs):
  588. # Usage: import train; train.run(data='coco128.yaml', imgsz=320, weights='yolov5m.pt')
  589. opt = parse_opt(True)
  590. for k, v in kwargs.items():
  591. setattr(opt, k, v)
  592. main(opt)
  593. return opt
  594. if __name__ == '__main__':
  595. opt = parse_opt()
  596. main(opt)