dataloaders.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
  2. """
  3. Dataloaders
  4. """
  5. import os
  6. import random
  7. import cv2
  8. import numpy as np
  9. import torch
  10. from torch.utils.data import DataLoader, distributed
  11. from ..augmentations import augment_hsv, copy_paste, letterbox
  12. from ..dataloaders import InfiniteDataLoader, LoadImagesAndLabels, seed_worker
  13. from ..general import LOGGER, xyn2xy, xywhn2xyxy, xyxy2xywhn
  14. from ..torch_utils import torch_distributed_zero_first
  15. from .augmentations import mixup, random_perspective
  # Process rank read from the RANK environment variable; -1 when the variable is unset.
  RANK = int(os.environ.get('RANK', '-1'))
  def create_dataloader(path,
                        imgsz,
                        batch_size,
                        stride,
                        single_cls=False,
                        hyp=None,
                        augment=False,
                        cache=False,
                        pad=0.0,
                        rect=False,
                        rank=-1,
                        workers=8,
                        image_weights=False,
                        quad=False,
                        prefix='',
                        shuffle=False,
                        mask_downsample_ratio=1,
                        overlap_mask=False,
                        seed=0):
      """
      Build a ``LoadImagesAndLabelsAndMasks`` dataset and the loader that iterates it.

      Args:
          path: dataset source, forwarded to ``LoadImagesAndLabelsAndMasks``.
          imgsz: target image size, forwarded as the dataset's ``img_size``.
          batch_size (int): requested batch size; capped at ``len(dataset)``.
          stride: model stride, cast to ``int`` for the dataset.
          single_cls, hyp, augment, pad, rect, image_weights, prefix: forwarded to the dataset.
          cache: forwarded as the dataset's ``cache_images``.
          rank (int): distributed rank; ``-1`` means no ``DistributedSampler`` is used.
          workers (int): upper bound on the number of loader workers.
          quad (bool): collate with ``collate_fn4`` instead of ``collate_fn``.
          shuffle (bool): shuffle samples; forced to ``False`` when ``rect`` is set.
          mask_downsample_ratio (int): forwarded as the dataset's ``downsample_ratio``.
          overlap_mask (bool): forwarded as the dataset's ``overlap``.
          seed (int): combined with the module-level ``RANK`` to seed the loader's generator.

      Returns:
          tuple: ``(loader, dataset)``. ``loader`` is a plain ``DataLoader`` when
          ``image_weights`` is set, otherwise an ``InfiniteDataLoader``.
      """
      if rect and shuffle:
          LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
          shuffle = False
      with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
          dataset = LoadImagesAndLabelsAndMasks(
              path,
              imgsz,
              batch_size,
              augment=augment,  # augmentation
              hyp=hyp,  # hyperparameters
              rect=rect,  # rectangular batches
              cache_images=cache,
              single_cls=single_cls,
              stride=int(stride),
              pad=pad,
              image_weights=image_weights,
              prefix=prefix,
              downsample_ratio=mask_downsample_ratio,
              overlap=overlap_mask)
      batch_size = min(batch_size, len(dataset))
      nd = torch.cuda.device_count()  # number of CUDA devices
      # os.cpu_count() returns None when the CPU count cannot be determined; fall back to 1.
      cpu_count = os.cpu_count() or 1
      nw = min([cpu_count // max(nd, 1), batch_size if batch_size > 1 else 0, workers])  # number of workers
      sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
      loader = DataLoader if image_weights else InfiniteDataLoader  # only DataLoader allows for attribute updates
      generator = torch.Generator()
      generator.manual_seed(6148914691236517205 + seed + RANK)
      # NOTE(review): LoadImagesAndLabelsAndMasks does not define collate_fn4 in this file, so
      # quad=True uses the one inherited from LoadImagesAndLabels; confirm it accepts the
      # 5-item (img, labels, path, shapes, masks) samples this dataset returns.
      return loader(
          dataset,
          batch_size=batch_size,
          shuffle=shuffle and sampler is None,  # a sampler and shuffle=True are mutually exclusive
          num_workers=nw,
          sampler=sampler,
          pin_memory=True,
          collate_fn=LoadImagesAndLabelsAndMasks.collate_fn4 if quad else LoadImagesAndLabelsAndMasks.collate_fn,
          worker_init_fn=seed_worker,
          generator=generator,
      ), dataset
  class LoadImagesAndLabelsAndMasks(LoadImagesAndLabels):  # for training/testing
      """
      ``LoadImagesAndLabels`` variant that also yields instance segmentation masks.

      Each sample is ``(img, labels_out, im_file, shapes, masks)``:
        - ``img``: CHW, RGB tensor (converted from the HWC/BGR array at the end of ``__getitem__``).
        - ``labels_out``: ``(nl, 6)`` tensor; column 0 is left for the batch index (written by
          ``collate_fn``) and columns 1-5 hold the label row, whose columns 1-4 are
          normalized xywh boxes.
        - ``masks``: ``(nl, H // downsample_ratio, W // downsample_ratio)`` binary masks, or,
          when ``overlap`` is set, a single ``(1, H // downsample_ratio, W // downsample_ratio)``
          mask whose pixel values are 1-based instance indices.
      """
      def __init__(
          self,
          path,
          img_size=640,
          batch_size=16,
          augment=False,
          hyp=None,
          rect=False,
          image_weights=False,
          cache_images=False,
          single_cls=False,
          stride=32,
          pad=0,
          min_items=0,
          prefix='',
          downsample_ratio=1,
          overlap=False,
      ):
          """
          Args:
              downsample_ratio (int): integer factor by which masks are shrunk relative to the image.
              overlap (bool): return one overlap mask per image instead of one mask per instance.
              Remaining arguments are forwarded positionally to ``LoadImagesAndLabels``.
          """
          super().__init__(path, img_size, batch_size, augment, hyp, rect, image_weights, cache_images, single_cls,
                           stride, pad, min_items, prefix)
          self.downsample_ratio = downsample_ratio
          self.overlap = overlap
      def __getitem__(self, index):
          """
          Load one sample, applying mosaic/MixUp and the other augmentations when enabled.

          Args:
              index (int): position in ``self.indices`` (linear, shuffled, or image-weighted order).

          Returns:
              tuple: ``(img, labels_out, im_file, shapes, masks)``. ``shapes`` is ``None`` for
              mosaic samples, otherwise ``((h0, w0), ((h / h0, w / w0), pad))`` for rescaling
              predictions back to the original image.
          """
          index = self.indices[index]  # linear, shuffled, or image_weights
          hyp = self.hyp
          mosaic = self.mosaic and random.random() < hyp['mosaic']
          masks = []
          if mosaic:
              # Load mosaic (load_mosaic already applies copy_paste and random_perspective)
              img, labels, segments = self.load_mosaic(index)
              shapes = None
              # MixUp augmentation
              if random.random() < hyp['mixup']:
                  img, labels, segments = mixup(img, labels, segments, *self.load_mosaic(random.randint(0, self.n - 1)))
          else:
              # Load image
              img, (h0, w0), (h, w) = self.load_image(index)
              # Letterbox
              shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size  # final letterboxed shape
              img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
              shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling
              labels = self.labels[index].copy()
              # [array, array, ....], array.shape=(num_points, 2), xyxyxyxy
              # Copy the container so that replacing elements below leaves self.segments untouched.
              segments = self.segments[index].copy()
              for i_s, segment in enumerate(segments):
                  # normalized xy -> pixel xy in the letterboxed image
                  segments[i_s] = xyn2xy(
                      segment,
                      ratio[0] * w,
                      ratio[1] * h,
                      padw=pad[0],
                      padh=pad[1],
                  )
              if labels.size:  # normalized xywh to pixel xyxy format
                  labels[:, 1:] = xywhn2xyxy(labels[:, 1:], ratio[0] * w, ratio[1] * h, padw=pad[0], padh=pad[1])
              if self.augment:
                  img, labels, segments = random_perspective(img,
                                                             labels,
                                                             segments=segments,
                                                             degrees=hyp['degrees'],
                                                             translate=hyp['translate'],
                                                             scale=hyp['scale'],
                                                             shear=hyp['shear'],
                                                             perspective=hyp['perspective'])
          nl = len(labels)  # number of labels
          if nl:
              labels[:, 1:5] = xyxy2xywhn(labels[:, 1:5], w=img.shape[1], h=img.shape[0], clip=True, eps=1e-3)
              if self.overlap:
                  masks, sorted_idx = polygons2masks_overlap(img.shape[:2],
                                                             segments,
                                                             downsample_ratio=self.downsample_ratio)
                  masks = masks[None]  # (640, 640) -> (1, 640, 640)
                  labels = labels[sorted_idx]
                  # sorted_idx has one entry per segment, so labels without a segment are
                  # dropped here; keep nl in sync or labels_out below gets the wrong row count.
                  nl = len(labels)
              else:
                  masks = polygons2masks(img.shape[:2], segments, color=1, downsample_ratio=self.downsample_ratio)
          if len(masks):
              masks = torch.from_numpy(masks)
          else:
              # No masks were produced: emit empty masks at the downsampled resolution.
              masks = torch.zeros(1 if self.overlap else nl,
                                  img.shape[0] // self.downsample_ratio,
                                  img.shape[1] // self.downsample_ratio)
          # TODO: albumentations support
          if self.augment:
              # Albumentations: only transforms that change neither boxes nor masks are
              # expected here, so masks are not passed through.
              # NOTE(review): if a transform ever drops labels, masks are not filtered to match.
              img, labels = self.albumentations(img, labels)
              nl = len(labels)  # update after albumentations
              # HSV color-space
              augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
              # Flip up-down (labels are normalized here, so y -> 1 - y)
              if random.random() < hyp['flipud']:
                  img = np.flipud(img)
                  if nl:
                      labels[:, 2] = 1 - labels[:, 2]
                      masks = torch.flip(masks, dims=[1])
              # Flip left-right (x -> 1 - x)
              if random.random() < hyp['fliplr']:
                  img = np.fliplr(img)
                  if nl:
                      labels[:, 1] = 1 - labels[:, 1]
                      masks = torch.flip(masks, dims=[2])
          labels_out = torch.zeros((nl, 6))
          if nl:
              labels_out[:, 1:] = torch.from_numpy(labels)
          # Convert
          img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
          img = np.ascontiguousarray(img)
          return (torch.from_numpy(img), labels_out, self.im_files[index], shapes, masks)
      def load_mosaic(self, index):
          """
          YOLOv5 4-mosaic loader: tile ``index`` plus 3 random images around a random center.

          The canvas is ``(2 * s, 2 * s)`` with ``s = self.img_size``, filled with 114. Labels
          and segments are shifted into canvas pixel coordinates and clipped to ``[0, 2 * s]``.
          Then ``copy_paste`` and ``random_perspective`` are applied, with
          ``border=self.mosaic_border``.

          Returns:
              tuple: ``(img4, labels4, segments4)`` with pixel xyxy boxes in label columns 1-4.
          """
          labels4, segments4 = [], []
          s = self.img_size
          yc, xc = (int(random.uniform(-x, 2 * s + x)) for x in self.mosaic_border)  # mosaic center x, y
          indices = [index] + random.choices(self.indices, k=3)  # 3 additional image indices
          for i, idx in enumerate(indices):
              # Load image
              img, _, (h, w) = self.load_image(idx)
              # place img in img4
              if i == 0:  # top left
                  img4 = np.full((s * 2, s * 2, img.shape[2]), 114, dtype=np.uint8)  # base image with 4 tiles
                  x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
                  x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
              elif i == 1:  # top right
                  x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, s * 2), yc
                  x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
              elif i == 2:  # bottom left
                  x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(s * 2, yc + h)
                  x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
              elif i == 3:  # bottom right
                  x1a, y1a, x2a, y2a = xc, yc, min(xc + w, s * 2), min(s * 2, yc + h)
                  x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
              img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b]  # img4[ymin:ymax, xmin:xmax]
              # Offset that maps tile pixel coordinates into canvas coordinates
              padw = x1a - x1b
              padh = y1a - y1b
              labels, segments = self.labels[idx].copy(), self.segments[idx].copy()
              if labels.size:
                  labels[:, 1:] = xywhn2xyxy(labels[:, 1:], w, h, padw, padh)  # normalized xywh to pixel xyxy format
                  segments = [xyn2xy(x, w, h, padw, padh) for x in segments]
              labels4.append(labels)
              segments4.extend(segments)
          # Concat/clip labels
          labels4 = np.concatenate(labels4, 0)
          for x in (labels4[:, 1:], *segments4):
              np.clip(x, 0, 2 * s, out=x)  # clip when using random_perspective()
          # Augment
          img4, labels4, segments4 = copy_paste(img4, labels4, segments4, p=self.hyp['copy_paste'])
          img4, labels4, segments4 = random_perspective(img4,
                                                        labels4,
                                                        segments4,
                                                        degrees=self.hyp['degrees'],
                                                        translate=self.hyp['translate'],
                                                        scale=self.hyp['scale'],
                                                        shear=self.hyp['shear'],
                                                        perspective=self.hyp['perspective'],
                                                        border=self.mosaic_border)  # border to remove
          return img4, labels4, segments4
      @staticmethod
      def collate_fn(batch):
          """
          Stack samples into a batch.

          Writes each sample's position in the batch into column 0 of its labels (in place),
          concatenates labels and masks along dim 0, and stacks images.

          Returns:
              tuple: ``(imgs, labels, paths, shapes, masks)``.
          """
          img, label, path, shapes, masks = zip(*batch)  # transposed
          batched_masks = torch.cat(masks, 0)
          for i, lb in enumerate(label):
              lb[:, 0] = i  # add target image index for build_targets()
          return torch.stack(img, 0), torch.cat(label, 0), path, shapes, batched_masks
  def polygon2mask(img_size, polygons, color=1, downsample_ratio=1):
      """
      Rasterize polygons into one uint8 mask, then shrink it by ``downsample_ratio``.

      Args:
          img_size (tuple): (height, width) of the canvas the polygons are drawn on.
          polygons (np.ndarray): [N, M] array-like; N polygons, each a flat run of M
              coordinates (x0, y0, x1, y1, ...), so M must be even.
          color (int): fill value written inside the polygons.
          downsample_ratio (int): integer factor by which the mask is shrunk.

      Returns:
          np.ndarray: mask of shape (height // downsample_ratio, width // downsample_ratio).
      """
      canvas = np.zeros(img_size, dtype=np.uint8)
      points = np.asarray(polygons).astype(np.int32)
      points = points.reshape(points.shape[0], -1, 2)  # (N, M) -> (N, M / 2, 2) point lists
      cv2.fillPoly(canvas, points, color=color)
      out_h = img_size[0] // downsample_ratio
      out_w = img_size[1] // downsample_ratio
      # Fill at full resolution first and resize afterwards, so that mask-ratio=1 matches
      # the way the loss is computed.
      return cv2.resize(canvas, (out_w, out_h))  # cv2.resize takes (width, height)
  def polygons2masks(img_size, polygons, color, downsample_ratio=1):
      """
      Rasterize every polygon into its own mask and stack the results.

      Args:
          img_size (tuple): (height, width) of the canvas.
          polygons (list[np.ndarray]): one coordinate array per instance; each is flattened
              to (x0, y0, x1, y1, ...) before being drawn.
          color (int): fill value passed to ``polygon2mask``.
          downsample_ratio (int): shrink factor passed to ``polygon2mask``.

      Returns:
          np.ndarray: masks stacked along axis 0, one per polygon (an empty array when
          ``polygons`` is empty).
      """
      return np.array([
          polygon2mask(img_size, [polygon.reshape(-1)], color, downsample_ratio)
          for polygon in polygons
      ])
  def polygons2masks_overlap(img_size, segments, downsample_ratio=1):
      """
      Return one overlap mask for all segments plus the order they were painted in.

      Every segment is rasterized with ``polygon2mask``. The masks are sorted by descending
      pixel area and painted in that order: pixels covered by the k-th sorted segment get
      value ``k + 1``, so where instances overlap the later (smaller) one wins, and 0 stays
      background.

      Args:
          img_size (tuple): (height, width) of the canvas the polygons are drawn on.
          segments (list[np.ndarray]): polygon coordinates, one array per instance.
          downsample_ratio (int): integer factor by which the output mask is shrunk.

      Returns:
          tuple: ``(masks, index)``.
              - ``masks``: shape ``(height // downsample_ratio, width // downsample_ratio)``;
                uint8 for at most 255 segments, int32 otherwise.
              - ``index``: the area-descending order, i.e. ``segments[index[k]]`` is painted
                with value ``k + 1``.
      """
      masks = np.zeros((img_size[0] // downsample_ratio, img_size[1] // downsample_ratio),
                       dtype=np.int32 if len(segments) > 255 else np.uint8)
      ms = [
          polygon2mask(img_size, [segment.reshape(-1)], downsample_ratio=downsample_ratio, color=1)
          for segment in segments
      ]
      # Use signed areas: summing the uint8 masks yields unsigned values, and negating an
      # unsigned array wraps around, which sorted zero-area masks first instead of last.
      areas = np.asarray([m.sum() for m in ms], dtype=np.int64)
      index = np.argsort(-areas)
      ms = np.array(ms)[index]
      for i, m in enumerate(ms):
          # Direct assignment gives the same result as the former "add m * (i + 1), then clip
          # to i + 1" for 0/1 masks. Unlike that form, it cannot wrap around in uint8, and it
          # cannot raise OverflowError under NumPy 2 when i + 1 > 255.
          masks[m > 0] = i + 1
      return masks, index