torch_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
  2. """
  3. PyTorch utils
  4. """
  5. import math
  6. import os
  7. import platform
  8. import subprocess
  9. import time
  10. import warnings
  11. from contextlib import contextmanager
  12. from copy import deepcopy
  13. from pathlib import Path
  14. import torch
  15. import torch.distributed as dist
  16. import torch.nn as nn
  17. import torch.nn.functional as F
  18. from torch.nn.parallel import DistributedDataParallel as DDP
  19. from utils.general import LOGGER, check_version, colorstr, file_date, git_describe
  20. LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
  21. RANK = int(os.getenv('RANK', -1))
  22. WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
  23. try:
  24. import thop # for FLOPs computation
  25. except ImportError:
  26. thop = None
  27. # Suppress PyTorch warnings
  28. warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
  29. warnings.filterwarnings('ignore', category=UserWarning)
  30. def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
  31. # Applies torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator
  32. def decorate(fn):
  33. return (torch.inference_mode if torch_1_9 else torch.no_grad)()(fn)
  34. return decorate
  35. def smartCrossEntropyLoss(label_smoothing=0.0):
  36. # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
  37. if check_version(torch.__version__, '1.10.0'):
  38. return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
  39. if label_smoothing > 0:
  40. LOGGER.warning(f'WARNING ⚠️ label smoothing {label_smoothing} requires torch>=1.10.0')
  41. return nn.CrossEntropyLoss()
  42. def smart_DDP(model):
  43. # Model DDP creation with checks
  44. assert not check_version(torch.__version__, '1.12.0', pinned=True), \
  45. 'torch==1.12.0 torchvision==0.13.0 DDP training is not supported due to a known issue. ' \
  46. 'Please upgrade or downgrade torch to use DDP. See https://github.com/ultralytics/yolov5/issues/8395'
  47. if check_version(torch.__version__, '1.11.0'):
  48. return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK, static_graph=True)
  49. else:
  50. return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
  51. def reshape_classifier_output(model, n=1000):
  52. # Update a TorchVision classification model to class count 'n' if required
  53. from models.common import Classify
  54. name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
  55. if isinstance(m, Classify): # YOLOv5 Classify() head
  56. if m.linear.out_features != n:
  57. m.linear = nn.Linear(m.linear.in_features, n)
  58. elif isinstance(m, nn.Linear): # ResNet, EfficientNet
  59. if m.out_features != n:
  60. setattr(model, name, nn.Linear(m.in_features, n))
  61. elif isinstance(m, nn.Sequential):
  62. types = [type(x) for x in m]
  63. if nn.Linear in types:
  64. i = types.index(nn.Linear) # nn.Linear index
  65. if m[i].out_features != n:
  66. m[i] = nn.Linear(m[i].in_features, n)
  67. elif nn.Conv2d in types:
  68. i = types.index(nn.Conv2d) # nn.Conv2d index
  69. if m[i].out_channels != n:
  70. m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
  71. @contextmanager
  72. def torch_distributed_zero_first(local_rank: int):
  73. # Decorator to make all processes in distributed training wait for each local_master to do something
  74. if local_rank not in [-1, 0]:
  75. dist.barrier(device_ids=[local_rank])
  76. yield
  77. if local_rank == 0:
  78. dist.barrier(device_ids=[0])
  79. def device_count():
  80. # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
  81. assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
  82. try:
  83. cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows
  84. return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
  85. except Exception:
  86. return 0
  87. def select_device(device='', batch_size=0, newline=True):
  88. # device = None or 'cpu' or 0 or '0' or '0,1,2,3'
  89. s = f'YOLOv5 🚀 {git_describe() or file_date()} Python-{platform.python_version()} torch-{torch.__version__} '
  90. device = str(device).strip().lower().replace('cuda:', '').replace('none', '') # to string, 'cuda:0' to '0'
  91. cpu = device == 'cpu'
  92. mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
  93. if cpu or mps:
  94. os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
  95. elif device: # non-cpu device requested
  96. os.environ['CUDA_VISIBLE_DEVICES'] = device # set environment variable - must be before assert is_available()
  97. assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
  98. f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
  99. if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
  100. devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
  101. n = len(devices) # device count
  102. if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
  103. assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
  104. space = ' ' * (len(s) + 1)
  105. for i, d in enumerate(devices):
  106. p = torch.cuda.get_device_properties(i)
  107. s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n" # bytes to MB
  108. arg = 'cuda:0'
  109. elif mps and getattr(torch, 'has_mps', False) and torch.backends.mps.is_available(): # prefer MPS if available
  110. s += 'MPS\n'
  111. arg = 'mps'
  112. else: # revert to CPU
  113. s += 'CPU\n'
  114. arg = 'cpu'
  115. if not newline:
  116. s = s.rstrip()
  117. LOGGER.info(s)
  118. return torch.device(arg)
  119. def time_sync():
  120. # PyTorch-accurate time
  121. if torch.cuda.is_available():
  122. torch.cuda.synchronize()
  123. return time.time()
  124. def profile(input, ops, n=10, device=None):
  125. """ YOLOv5 speed/memory/FLOPs profiler
  126. Usage:
  127. input = torch.randn(16, 3, 640, 640)
  128. m1 = lambda x: x * torch.sigmoid(x)
  129. m2 = nn.SiLU()
  130. profile(input, [m1, m2], n=100) # profile over 100 iterations
  131. """
  132. results = []
  133. if not isinstance(device, torch.device):
  134. device = select_device(device)
  135. print(f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
  136. f"{'input':>24s}{'output':>24s}")
  137. for x in input if isinstance(input, list) else [input]:
  138. x = x.to(device)
  139. x.requires_grad = True
  140. for m in ops if isinstance(ops, list) else [ops]:
  141. m = m.to(device) if hasattr(m, 'to') else m # device
  142. m = m.half() if hasattr(m, 'half') and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
  143. tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
  144. try:
  145. flops = thop.profile(m, inputs=(x, ), verbose=False)[0] / 1E9 * 2 # GFLOPs
  146. except Exception:
  147. flops = 0
  148. try:
  149. for _ in range(n):
  150. t[0] = time_sync()
  151. y = m(x)
  152. t[1] = time_sync()
  153. try:
  154. _ = (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
  155. t[2] = time_sync()
  156. except Exception: # no backward method
  157. # print(e) # for debug
  158. t[2] = float('nan')
  159. tf += (t[1] - t[0]) * 1000 / n # ms per op forward
  160. tb += (t[2] - t[1]) * 1000 / n # ms per op backward
  161. mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0 # (GB)
  162. s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else 'list' for x in (x, y)) # shapes
  163. p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
  164. print(f'{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}')
  165. results.append([p, flops, mem, tf, tb, s_in, s_out])
  166. except Exception as e:
  167. print(e)
  168. results.append(None)
  169. torch.cuda.empty_cache()
  170. return results
  171. def is_parallel(model):
  172. # Returns True if model is of type DP or DDP
  173. return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
  174. def de_parallel(model):
  175. # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
  176. return model.module if is_parallel(model) else model
  177. def initialize_weights(model):
  178. for m in model.modules():
  179. t = type(m)
  180. if t is nn.Conv2d:
  181. pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  182. elif t is nn.BatchNorm2d:
  183. m.eps = 1e-3
  184. m.momentum = 0.03
  185. elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
  186. m.inplace = True
  187. def find_modules(model, mclass=nn.Conv2d):
  188. # Finds layer indices matching module class 'mclass'
  189. return [i for i, m in enumerate(model.module_list) if isinstance(m, mclass)]
  190. def sparsity(model):
  191. # Return global model sparsity
  192. a, b = 0, 0
  193. for p in model.parameters():
  194. a += p.numel()
  195. b += (p == 0).sum()
  196. return b / a
  197. def prune(model, amount=0.3):
  198. # Prune model to requested global sparsity
  199. import torch.nn.utils.prune as prune
  200. for name, m in model.named_modules():
  201. if isinstance(m, nn.Conv2d):
  202. prune.l1_unstructured(m, name='weight', amount=amount) # prune
  203. prune.remove(m, 'weight') # make permanent
  204. LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')
  205. def fuse_conv_and_bn(conv, bn):
  206. # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
  207. fusedconv = nn.Conv2d(conv.in_channels,
  208. conv.out_channels,
  209. kernel_size=conv.kernel_size,
  210. stride=conv.stride,
  211. padding=conv.padding,
  212. dilation=conv.dilation,
  213. groups=conv.groups,
  214. bias=True).requires_grad_(False).to(conv.weight.device)
  215. # Prepare filters
  216. w_conv = conv.weight.clone().view(conv.out_channels, -1)
  217. w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
  218. fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
  219. # Prepare spatial bias
  220. b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
  221. b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
  222. fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
  223. return fusedconv
  224. def model_info(model, verbose=False, imgsz=640):
  225. # Model information. img_size may be int or list, i.e. img_size=640 or img_size=[640, 320]
  226. n_p = sum(x.numel() for x in model.parameters()) # number parameters
  227. n_g = sum(x.numel() for x in model.parameters() if x.requires_grad) # number gradients
  228. if verbose:
  229. print(f"{'layer':>5} {'name':>40} {'gradient':>9} {'parameters':>12} {'shape':>20} {'mu':>10} {'sigma':>10}")
  230. for i, (name, p) in enumerate(model.named_parameters()):
  231. name = name.replace('module_list.', '')
  232. print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
  233. (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))
  234. try: # FLOPs
  235. p = next(model.parameters())
  236. stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
  237. im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
  238. flops = thop.profile(deepcopy(model), inputs=(im, ), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
  239. imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
  240. fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
  241. except Exception:
  242. fs = ''
  243. name = Path(model.yaml_file).stem.replace('yolov5', 'YOLOv5') if hasattr(model, 'yaml_file') else 'Model'
  244. LOGGER.info(f'{name} summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}')
  245. def scale_img(img, ratio=1.0, same_shape=False, gs=32): # img(16,3,256,416)
  246. # Scales img(bs,3,y,x) by ratio constrained to gs-multiple
  247. if ratio == 1.0:
  248. return img
  249. h, w = img.shape[2:]
  250. s = (int(h * ratio), int(w * ratio)) # new size
  251. img = F.interpolate(img, size=s, mode='bilinear', align_corners=False) # resize
  252. if not same_shape: # pad/crop img
  253. h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
  254. return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
  255. def copy_attr(a, b, include=(), exclude=()):
  256. # Copy attributes from b to a, options to only include [...] and to exclude [...]
  257. for k, v in b.__dict__.items():
  258. if (len(include) and k not in include) or k.startswith('_') or k in exclude:
  259. continue
  260. else:
  261. setattr(a, k, v)
  262. def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
  263. # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
  264. g = [], [], [] # optimizer parameter groups
  265. bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
  266. for v in model.modules():
  267. for p_name, p in v.named_parameters(recurse=0):
  268. if p_name == 'bias': # bias (no decay)
  269. g[2].append(p)
  270. elif p_name == 'weight' and isinstance(v, bn): # weight (no decay)
  271. g[1].append(p)
  272. else:
  273. g[0].append(p) # weight (with decay)
  274. if name == 'Adam':
  275. optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999)) # adjust beta1 to momentum
  276. elif name == 'AdamW':
  277. optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
  278. elif name == 'RMSProp':
  279. optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
  280. elif name == 'SGD':
  281. optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
  282. else:
  283. raise NotImplementedError(f'Optimizer {name} not implemented.')
  284. optimizer.add_param_group({'params': g[0], 'weight_decay': decay}) # add g0 with weight_decay
  285. optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0}) # add g1 (BatchNorm2d weights)
  286. LOGGER.info(f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}) with parameter groups "
  287. f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias')
  288. return optimizer
  289. def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
  290. # YOLOv5 torch.hub.load() wrapper with smart error/issue handling
  291. if check_version(torch.__version__, '1.9.1'):
  292. kwargs['skip_validation'] = True # validation causes GitHub API rate limit errors
  293. if check_version(torch.__version__, '1.12.0'):
  294. kwargs['trust_repo'] = True # argument required starting in torch 0.12
  295. try:
  296. return torch.hub.load(repo, model, **kwargs)
  297. except Exception:
  298. return torch.hub.load(repo, model, force_reload=True, **kwargs)
  299. def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
  300. # Resume training from a partially trained checkpoint
  301. best_fitness = 0.0
  302. start_epoch = ckpt['epoch'] + 1
  303. if ckpt['optimizer'] is not None:
  304. optimizer.load_state_dict(ckpt['optimizer']) # optimizer
  305. best_fitness = ckpt['best_fitness']
  306. if ema and ckpt.get('ema'):
  307. ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA
  308. ema.updates = ckpt['updates']
  309. if resume:
  310. assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.\n' \
  311. f"Start a new training without --resume, i.e. 'python train.py --weights {weights}'"
  312. LOGGER.info(f'Resuming training from {weights} from epoch {start_epoch} to {epochs} total epochs')
  313. if epochs < start_epoch:
  314. LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
  315. epochs += ckpt['epoch'] # finetune additional epochs
  316. return best_fitness, start_epoch, epochs
  317. class EarlyStopping:
  318. # YOLOv5 simple early stopper
  319. def __init__(self, patience=30):
  320. self.best_fitness = 0.0 # i.e. mAP
  321. self.best_epoch = 0
  322. self.patience = patience or float('inf') # epochs to wait after fitness stops improving to stop
  323. self.possible_stop = False # possible stop may occur next epoch
  324. def __call__(self, epoch, fitness):
  325. if fitness >= self.best_fitness: # >= 0 to allow for early zero-fitness stage of training
  326. self.best_epoch = epoch
  327. self.best_fitness = fitness
  328. delta = epoch - self.best_epoch # epochs without improvement
  329. self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
  330. stop = delta >= self.patience # stop training if patience exceeded
  331. if stop:
  332. LOGGER.info(f'Stopping training early as no improvement observed in last {self.patience} epochs. '
  333. f'Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n'
  334. f'To update EarlyStopping(patience={self.patience}) pass a new patience value, '
  335. f'i.e. `python train.py --patience 300` or use `--patience 0` to disable EarlyStopping.')
  336. return stop
  337. class ModelEMA:
  338. """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
  339. Keeps a moving average of everything in the model state_dict (parameters and buffers)
  340. For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
  341. """
  342. def __init__(self, model, decay=0.9999, tau=2000, updates=0):
  343. # Create EMA
  344. self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
  345. self.updates = updates # number of EMA updates
  346. self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
  347. for p in self.ema.parameters():
  348. p.requires_grad_(False)
  349. def update(self, model):
  350. # Update EMA parameters
  351. self.updates += 1
  352. d = self.decay(self.updates)
  353. msd = de_parallel(model).state_dict() # model state_dict
  354. for k, v in self.ema.state_dict().items():
  355. if v.dtype.is_floating_point: # true for FP16 and FP32
  356. v *= d
  357. v += (1 - d) * msd[k].detach()
  358. # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
  359. def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
  360. # Update EMA attributes
  361. copy_attr(self.ema, model, include, exclude)