# autoanchor.py
  1. # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
  2. """
  3. AutoAnchor utils
  4. """
  5. import random
  6. import numpy as np
  7. import torch
  8. import yaml
  9. from tqdm import tqdm
  10. from utils import TryExcept
  11. from utils.general import LOGGER, TQDM_BAR_FORMAT, colorstr
  12. PREFIX = colorstr('AutoAnchor: ')
  13. def check_anchor_order(m):
  14. # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
  15. a = m.anchors.prod(-1).mean(-1).view(-1) # mean anchor area per output layer
  16. da = a[-1] - a[0] # delta a
  17. ds = m.stride[-1] - m.stride[0] # delta s
  18. if da and (da.sign() != ds.sign()): # same order
  19. LOGGER.info(f'{PREFIX}Reversing anchor order')
  20. m.anchors[:] = m.anchors.flip(0)
  21. @TryExcept(f'{PREFIX}ERROR')
  22. def check_anchors(dataset, model, thr=4.0, imgsz=640):
  23. # Check anchor fit to data, recompute if necessary
  24. m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
  25. shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  26. scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale
  27. wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh
  28. def metric(k): # compute metric
  29. r = wh[:, None] / k[None]
  30. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  31. best = x.max(1)[0] # best_x
  32. aat = (x > 1 / thr).float().sum(1).mean() # anchors above threshold
  33. bpr = (best > 1 / thr).float().mean() # best possible recall
  34. return bpr, aat
  35. stride = m.stride.to(m.anchors.device).view(-1, 1, 1) # model strides
  36. anchors = m.anchors.clone() * stride # current anchors
  37. bpr, aat = metric(anchors.cpu().view(-1, 2))
  38. s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). '
  39. if bpr > 0.98: # threshold to recompute
  40. LOGGER.info(f'{s}Current anchors are a good fit to dataset ✅')
  41. else:
  42. LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
  43. na = m.anchors.numel() // 2 # number of anchors
  44. anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
  45. new_bpr = metric(anchors)[0]
  46. if new_bpr > bpr: # replace anchors
  47. anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
  48. m.anchors[:] = anchors.clone().view_as(m.anchors)
  49. check_anchor_order(m) # must be in pixel-space (not grid-space)
  50. m.anchors /= stride
  51. s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)'
  52. else:
  53. s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)'
  54. LOGGER.info(s)
  55. def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True):
  56. """ Creates kmeans-evolved anchors from training dataset
  57. Arguments:
  58. dataset: path to data.yaml, or a loaded dataset
  59. n: number of anchors
  60. img_size: image size used for training
  61. thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0
  62. gen: generations to evolve anchors using genetic algorithm
  63. verbose: print all results
  64. Return:
  65. k: kmeans evolved anchors
  66. Usage:
  67. from utils.autoanchor import *; _ = kmean_anchors()
  68. """
  69. from scipy.cluster.vq import kmeans
  70. npr = np.random
  71. thr = 1 / thr
  72. def metric(k, wh): # compute metrics
  73. r = wh[:, None] / k[None]
  74. x = torch.min(r, 1 / r).min(2)[0] # ratio metric
  75. # x = wh_iou(wh, torch.tensor(k)) # iou metric
  76. return x, x.max(1)[0] # x, best_x
  77. def anchor_fitness(k): # mutation fitness
  78. _, best = metric(torch.tensor(k, dtype=torch.float32), wh)
  79. return (best * (best > thr).float()).mean() # fitness
  80. def print_results(k, verbose=True):
  81. k = k[np.argsort(k.prod(1))] # sort small to large
  82. x, best = metric(k, wh0)
  83. bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr
  84. s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \
  85. f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \
  86. f'past_thr={x[x > thr].mean():.3f}-mean: '
  87. for x in k:
  88. s += '%i,%i, ' % (round(x[0]), round(x[1]))
  89. if verbose:
  90. LOGGER.info(s[:-2])
  91. return k
  92. if isinstance(dataset, str): # *.yaml file
  93. with open(dataset, errors='ignore') as f:
  94. data_dict = yaml.safe_load(f) # model dict
  95. from utils.dataloaders import LoadImagesAndLabels
  96. dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True)
  97. # Get label wh
  98. shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
  99. wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
  100. # Filter
  101. i = (wh0 < 3.0).any(1).sum()
  102. if i:
  103. LOGGER.info(f'{PREFIX}WARNING ⚠️ Extremely small objects found: {i} of {len(wh0)} labels are <3 pixels in size')
  104. wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels
  105. # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
  106. # Kmeans init
  107. try:
  108. LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...')
  109. assert n <= len(wh) # apply overdetermined constraint
  110. s = wh.std(0) # sigmas for whitening
  111. k = kmeans(wh / s, n, iter=30)[0] * s # points
  112. assert n == len(k) # kmeans may return fewer points than requested if wh is insufficient or too similar
  113. except Exception:
  114. LOGGER.warning(f'{PREFIX}WARNING ⚠️ switching strategies from kmeans to random init')
  115. k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size # random init
  116. wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0))
  117. k = print_results(k, verbose=False)
  118. # Plot
  119. # k, d = [None] * 20, [None] * 20
  120. # for i in tqdm(range(1, 21)):
  121. # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance
  122. # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True)
  123. # ax = ax.ravel()
  124. # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.')
  125. # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh
  126. # ax[0].hist(wh[wh[:, 0]<100, 0],400)
  127. # ax[1].hist(wh[wh[:, 1]<100, 1],400)
  128. # fig.savefig('wh.png', dpi=200)
  129. # Evolve
  130. f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
  131. pbar = tqdm(range(gen), bar_format=TQDM_BAR_FORMAT) # progress bar
  132. for _ in pbar:
  133. v = np.ones(sh)
  134. while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
  135. v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
  136. kg = (k.copy() * v).clip(min=2.0)
  137. fg = anchor_fitness(kg)
  138. if fg > f:
  139. f, k = fg, kg.copy()
  140. pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}'
  141. if verbose:
  142. print_results(k, verbose)
  143. return print_results(k).astype(np.float32)