# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
YOLO-specific modules

Usage:
    $ python models/yolo.py --cfg yolov5s.yaml
"""

import argparse
import contextlib
import os
import platform
import sys
from copy import deepcopy
from pathlib import Path

FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import *  # noqa
from models.experimental import *  # noqa
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (fuse_conv_and_bn, initialize_weights, model_info, profile, scale_img, select_device,
                               time_sync)

try:
    import thop  # for FLOPs computation
except ImportError:
    thop = None


class Detect(nn.Module):
    # YOLOv5 Detect head for detection models
    stride = None  # strides computed during build
    dynamic = False  # force grid reconstruction
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.empty(0) for _ in range(self.nl)]  # init grid
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use inplace ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

    def _make_grid(self, nx=20, ny=20, i=0, torch_1_10=check_version(torch.__version__, '1.10.0')):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x)  # torch>=0.7 compatibility
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid
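
# Box decode reference for Detect.forward above (illustrative note, not executed): with sigmoid
# outputs (sx, sy, sw, sh) at grid cell (cx, cy) of a layer with stride s and anchor (aw, ah),
#   x = (2 * sx - 0.5 + cx) * s    # self.grid already folds in the -0.5 offset
#   y = (2 * sy - 0.5 + cy) * s
#   w = (2 * sw) ** 2 * aw * s     # self.anchor_grid already folds in the stride
#   h = (2 * sh) ** 2 * ah * s
# so xy may drift slightly outside its cell and wh is capped at 4x the anchor size.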


class Segment(Detect):
    # YOLOv5 Segment head for segmentation models
    def __init__(self, nc=80, anchors=(), nm=32, npr=256, ch=(), inplace=True):
        super().__init__(nc, anchors, ch, inplace)
        self.nm = nm  # number of masks
        self.npr = npr  # number of protos
        self.no = 5 + nc + self.nm  # number of outputs per anchor
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.proto = Proto(ch[0], self.npr, self.nm)  # protos
        self.detect = Detect.forward

    def forward(self, x):
        p = self.proto(x[0])
        x = self.detect(self, x)
        return (x, p) if self.training else (x[0], p) if self.export else (x[0], p, x[1])
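
# Note on Segment outputs (descriptive comment): p holds the prototype masks from Proto, and each
# detection row carries self.nm mask coefficients after its box/objectness/class outputs; the final
# instance masks are typically assembled downstream as sigmoid(coefficients @ protos), outside this file.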


class BaseModel(nn.Module):
    # YOLOv5 base model
    def forward(self, x, profile=False, visualize=False):
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_once(self, x, profile=False, visualize=False):
        y, dt = [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            x = m(x)  # run
            y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
        return x

    def _profile_one_layer(self, m, x, dt):
        c = m == self.model[-1]  # is final layer, copy input as inplace fix
        o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0  # FLOPs
        t = time_sync()
        for _ in range(10):
            m(x.copy() if c else x)
        dt.append((time_sync() - t) * 100)
        if m == self.model[0]:
            LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
        LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
        if c:
            LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")

    def fuse(self):  # fuse model Conv2d() + BatchNorm2d() layers
        LOGGER.info('Fusing layers... ')
        for m in self.model.modules():
            if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
                m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
                delattr(m, 'bn')  # remove batchnorm
                m.forward = m.forward_fuse  # update forward
        self.info()
        return self

    def info(self, verbose=False, img_size=640):  # print model information
        model_info(self, verbose, img_size)

    def _apply(self, fn):
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self


class DetectionModel(BaseModel):
    # YOLOv5 detection model
    def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None):  # model, input channels, number of classes
        super().__init__()
        if isinstance(cfg, dict):
            self.yaml = cfg  # model dict
        else:  # is *.yaml
            import yaml  # for torch hub
            self.yaml_file = Path(cfg).name
            with open(cfg, encoding='ascii', errors='ignore') as f:
                self.yaml = yaml.safe_load(f)  # model dict

        # Define model
        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
        if nc and nc != self.yaml['nc']:
            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
            self.yaml['nc'] = nc  # override yaml value
        if anchors:
            LOGGER.info(f'Overriding model.yaml anchors with anchors={anchors}')
            self.yaml['anchors'] = round(anchors)  # override yaml value
        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch])  # model, savelist
        self.names = [str(i) for i in range(self.yaml['nc'])]  # default names
        self.inplace = self.yaml.get('inplace', True)

        # Build strides, anchors
        m = self.model[-1]  # Detect()
        if isinstance(m, (Detect, Segment)):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
            m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
            check_anchor_order(m)
            m.anchors /= m.stride.view(-1, 1, 1)
            self.stride = m.stride
            self._initialize_biases()  # only run once

        # Init weights, biases
        initialize_weights(self)
        self.info()
        LOGGER.info('')

    def forward(self, x, augment=False, profile=False, visualize=False):
        if augment:
            return self._forward_augment(x)  # augmented inference, None
        return self._forward_once(x, profile, visualize)  # single-scale inference, train

    def _forward_augment(self, x):
        img_size = x.shape[-2:]  # height, width
        s = [1, 0.83, 0.67]  # scales
        f = [None, 3, None]  # flips (2-ud, 3-lr)
        y = []  # outputs
        for si, fi in zip(s, f):
            xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
            yi = self._forward_once(xi)[0]  # forward
            # cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1])  # save
            yi = self._descale_pred(yi, fi, si, img_size)
            y.append(yi)
        y = self._clip_augmented(y)  # clip augmented tails
        return torch.cat(y, 1), None  # augmented inference, train

    def _descale_pred(self, p, flips, scale, img_size):
        # de-scale predictions following augmented inference (inverse operation)
        if self.inplace:
            p[..., :4] /= scale  # de-scale
            if flips == 2:
                p[..., 1] = img_size[0] - p[..., 1]  # de-flip ud
            elif flips == 3:
                p[..., 0] = img_size[1] - p[..., 0]  # de-flip lr
        else:
            x, y, wh = p[..., 0:1] / scale, p[..., 1:2] / scale, p[..., 2:4] / scale  # de-scale
            if flips == 2:
                y = img_size[0] - y  # de-flip ud
            elif flips == 3:
                x = img_size[1] - x  # de-flip lr
            p = torch.cat((x, y, wh, p[..., 4:]), -1)
        return p

    def _clip_augmented(self, y):
        # Clip YOLOv5 augmented inference tails
        nl = self.model[-1].nl  # number of detection layers (P3-P5)
        g = sum(4 ** x for x in range(nl))  # grid points
        e = 1  # exclude layer count
        i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e))  # indices
        y[0] = y[0][:, :-i]  # large
        i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e))  # indices
        y[-1] = y[-1][:, i:]  # small
        return y
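
    # Worked example for _clip_augmented above (assuming nl=3 and a 640x640 input at scale 1):
    # P3/P4/P5 contribute 3*80*80=19200, 3*40*40=4800 and 3*20*20=1200 boxes (25200 total), and
    # g = 1 + 4 + 16 = 21, so 25200 // 21 = 1200. y[0] (largest image scale) drops its last 1200
    # rows, the coarse P5 tail, while y[-1] (smallest image scale) drops its first 16 * (N // 21)
    # rows, the fine P3 tail, since each output is concatenated in P3, P4, P5 order.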

    def _initialize_biases(self, cf=None):  # initialize biases into Detect(), cf is class frequency
        # https://arxiv.org/abs/1708.02002 section 3.3
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
        m = self.model[-1]  # Detect() module
        for mi, s in zip(m.m, m.stride):  # from
            b = mi.bias.view(m.na, -1)  # conv.bias(255) to (3,85)
            b.data[:, 4] += math.log(8 / (640 / s) ** 2)  # obj (8 objects per 640 image)
            b.data[:, 5:5 + m.nc] += math.log(0.6 / (m.nc - 0.99999)) if cf is None else torch.log(cf / cf.sum())  # cls
            mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
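

# Prior check for _initialize_biases above (descriptive note): with stride 8 on a 640 input the
# layer has (640 / 8) ** 2 = 6400 cells, so the objectness bias starts near log(8 / 6400) ~ -6.68
# (sigmoid ~ 0.00125, i.e. roughly 8 expected objects per image per anchor), and the class bias
# log(0.6 / (nc - 0.99999)) likewise starts each class score at a small prior probability.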


Model = DetectionModel  # retain YOLOv5 'Model' class for backwards compatibility


class SegmentationModel(DetectionModel):
    # YOLOv5 segmentation model
    def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, anchors=None):
        super().__init__(cfg, ch, nc, anchors)


class ClassificationModel(BaseModel):
    # YOLOv5 classification model
    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):  # yaml, model, number of classes, cutoff index
        super().__init__()
        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)

    def _from_detection_model(self, model, nc=1000, cutoff=10):
        # Create a YOLOv5 classification model from a YOLOv5 detection model
        if isinstance(model, DetectMultiBackend):
            model = model.model  # unwrap DetectMultiBackend
        model.model = model.model[:cutoff]  # backbone
        m = model.model[-1]  # last layer
        ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels  # ch into module
        c = Classify(ch, nc)  # Classify()
        c.i, c.f, c.type = m.i, m.f, 'models.common.Classify'  # index, from, type
        model.model[-1] = c  # replace
        self.model = model.model
        self.stride = model.stride
        self.save = []
        self.nc = nc

    def _from_yaml(self, cfg):
        # Create a YOLOv5 classification model from a *.yaml file
        self.model = None


def parse_model(d, ch):  # model_dict, input_channels(3)
    # Parse a YOLOv5 model.yaml dictionary
    LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
    anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchors
    no = na * (nc + 5)  # number of outputs = anchors * (classes + 5)

    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        m = eval(m) if isinstance(m, str) else m  # eval strings
        for j, a in enumerate(args):
            with contextlib.suppress(NameError):
                args[j] = eval(a) if isinstance(a, str) else a  # eval strings

        n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
        if m in {
                Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3x}:
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        # TODO: channel, gw, gd
        elif m in {Detect, Segment}:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)
            if m is Segment:
                args[3] = make_divisible(args[3] * gw, 8)
        elif m is Contract:
            c2 = ch[f] * args[0] ** 2
        elif m is Expand:
            c2 = ch[f] // args[0] ** 2
        else:
            c2 = ch[f]

        m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # module
        t = str(m)[8:-2].replace('__main__.', '')  # module type
        np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number params
        LOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}')  # print
        save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        ch.append(c2)
    return nn.Sequential(*layers), sorted(save)
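

# Input sketch for parse_model above (illustrative and abbreviated; the real configs live in models/*.yaml):
# d = {'nc': 80, 'depth_multiple': 0.33, 'width_multiple': 0.50,
#      'anchors': [[10, 13, 16, 30, 33, 23], ...],           # one list of w,h pairs per detection layer
#      'backbone': [[-1, 1, 'Conv', [64, 6, 2, 2]], ...],    # rows of [from, number, module, args]
#      'head': [..., [[17, 20, 23], 1, 'Detect', ['nc', 'anchors']]]}
# is parsed with ch=[3] for RGB input; the returned savelist records which layer outputs later layers reuse.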


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='yolov5s.yaml', help='model.yaml')
    parser.add_argument('--batch-size', type=int, default=1, help='total batch size for all GPUs')
    parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--profile', action='store_true', help='profile model speed')
    parser.add_argument('--line-profile', action='store_true', help='profile model speed layer by layer')
    parser.add_argument('--test', action='store_true', help='test all yolo*.yaml')
    opt = parser.parse_args()
    opt.cfg = check_yaml(opt.cfg)  # check YAML
    print_args(vars(opt))
    device = select_device(opt.device)

    # Create model
    im = torch.rand(opt.batch_size, 3, 640, 640).to(device)
    model = Model(opt.cfg).to(device)

    # Options
    if opt.line_profile:  # profile layer by layer
        model(im, profile=True)

    elif opt.profile:  # profile forward-backward
        results = profile(input=im, ops=[model], n=3)

    elif opt.test:  # test all models
        for cfg in Path(ROOT / 'models').rglob('yolo*.yaml'):
            try:
                _ = Model(cfg)
            except Exception as e:
                print(f'Error in {cfg}: {e}')

    else:  # report fused model summary
        model.fuse()