plots.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
  2. """
  3. Plotting utils
  4. """
  5. import contextlib
  6. import math
  7. import os
  8. from copy import copy
  9. from pathlib import Path
  10. from urllib.error import URLError
  11. import cv2
  12. import matplotlib
  13. import matplotlib.pyplot as plt
  14. import numpy as np
  15. import pandas as pd
  16. import seaborn as sn
  17. import torch
  18. from PIL import Image, ImageDraw, ImageFont
  19. from scipy.ndimage.filters import gaussian_filter1d
  20. from utils import TryExcept, threaded
  21. from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_boxes, increment_path,
  22. is_ascii, xywh2xyxy, xyxy2xywh)
  23. from utils.metrics import fitness
  24. from utils.segment.general import scale_image
  # Settings
  RANK = int(os.environ.get('RANK', -1))  # value of the RANK environment variable, -1 when it is unset
  matplotlib.rc('font', size=11)  # default font size for every figure drawn by this module
  matplotlib.use('Agg')  # for writing to files only
  class Colors:
      """Ultralytics color palette https://ultralytics.com/

      An instance is callable: ``colors(i)`` returns the palette entry at
      ``int(i)`` modulo the palette length as an ``(r, g, b)`` tuple, or as
      ``(b, g, r)`` when ``bgr=True``.
      """

      def __init__(self):
          # hex = matplotlib.colors.TABLEAU_COLORS.values()
          hex_codes = ('FF3838', 'FF9D97', 'FF701F', 'FFB21D', 'CFD231', '48F90A', '92CC17', '3DDB86', '1A9334',
                       '00D4BB', '2C99A8', '00C2FF', '344593', '6473FF', '0018EC', '8438FF', '520085', 'CB38FF',
                       'FF95C8', 'FF37C7')
          self.palette = [self.hex2rgb('#' + code) for code in hex_codes]
          self.n = len(self.palette)

      def __call__(self, i, bgr=False):
          """Return palette color ``i`` (wrapping around), in BGR order when ``bgr`` is true."""
          rgb = self.palette[int(i) % self.n]
          if bgr:
              return rgb[2], rgb[1], rgb[0]
          return rgb

      @staticmethod
      def hex2rgb(h):  # rgb order (PIL)
          """Convert a ``'#RRGGBB'`` string to an ``(r, g, b)`` tuple of ints."""
          return tuple(int(h[start:start + 2], 16) for start in (1, 3, 5))


  colors = Colors()  # create instance for 'from utils.plots import colors'
  def check_pil_font(font=FONT, size=10):
      """Return a PIL font object for `font` at `size`, never None.

      Resolution order:
        1. `font` as given if that path exists, otherwise the same file name under CONFIG_DIR
           (or, if that does not exist either, the bare file name handed to PIL).
        2. If loading fails, `check_font()` is called (presumably to download the file - see
           utils.general) and loading is retried.
        3. If that still fails with TypeError (old Pillow, see linked issue) or URLError (offline),
           PIL's built-in default font is returned.

      Previously the TypeError path fell off the end of the function and returned None, which only
      surfaced later as an attribute error when the font was used for drawing.
      """
      font = Path(font)
      font = font if font.exists() else (CONFIG_DIR / font.name)
      try:
          return ImageFont.truetype(str(font) if font.exists() else font.name, size)
      except Exception:  # download if missing
          try:
              check_font(font)
              return ImageFont.truetype(str(font), size)
          except TypeError:
              check_requirements('Pillow>=8.4.0')  # known issue https://github.com/ultralytics/yolov5/issues/5374
          except URLError:  # not online
              pass
          return ImageFont.load_default()  # always hand back a usable font
  class Annotator:
      """YOLOv5 Annotator for train/val mosaics and jpgs and detect/hub inference annotations.

      Drawing goes through PIL when `pil=True` or when `example` contains non-ASCII characters,
      otherwise through cv2 directly on the numpy image passed in.
      """

      def __init__(self, im, line_width=None, font_size=None, font='Arial.ttf', pil=False, example='abc'):
          """
          im:         image to draw on (must be contiguous)
          line_width: box line width; derived from the image size when falsy
          font_size:  PIL font size; derived from the image size when falsy
          font:       font file used in PIL mode for ASCII labels
          pil:        force PIL drawing
          example:    sample of the label text; non-ASCII content switches to PIL and a unicode font
          """
          assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to Annotator() input images.'
          non_ascii = not is_ascii(example)  # non-latin labels, i.e. asian, arabic, cyrillic
          self.pil = pil or non_ascii
          if self.pil:  # use PIL
              self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
              self.draw = ImageDraw.Draw(self.im)
              self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font,
                                         size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12))
          else:  # use cv2
              self.im = im
          self.lw = line_width or max(round(sum(im.shape) / 2 * 0.003), 2)  # line width

      def _text_size(self, text):
          """Return (width, height) of `text` rendered with self.font on old and new Pillow versions."""
          if hasattr(self.font, 'getbbox'):  # getsize() is deprecated since Pillow 9.2 and removed in 10
              _, _, w, h = self.font.getbbox(text)  # right/bottom edge, same extent getsize() reported
              return w, h
          return self.font.getsize(text)

      def box_label(self, box, label='', color=(128, 128, 128), txt_color=(255, 255, 255)):
          """Add one xyxy box to the image, with an optional filled label tag above (or inside) it."""
          # NOTE(review): a non-ASCII label on a cv2-mode annotator enters the PIL branch although
          # self.draw/self.font were never created - pass such text as `example` to the constructor.
          if self.pil or not is_ascii(label):
              self.draw.rectangle(box, width=self.lw, outline=color)  # box
              if label:
                  w, h = self._text_size(label)  # text width, height
                  outside = box[1] - h >= 0  # label fits outside box
                  self.draw.rectangle(
                      (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
                       box[1] + 1 if outside else box[1] + h + 1),
                      fill=color,
                  )
                  # self.draw.text((box[0], box[1]), label, fill=txt_color, font=self.font, anchor='ls')  # for PIL>8.0
                  self.draw.text((box[0], box[1] - h if outside else box[1]), label, fill=txt_color, font=self.font)
          else:  # cv2
              p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
              cv2.rectangle(self.im, p1, p2, color, thickness=self.lw, lineType=cv2.LINE_AA)
              if label:
                  tf = max(self.lw - 1, 1)  # font thickness
                  w, h = cv2.getTextSize(label, 0, fontScale=self.lw / 3, thickness=tf)[0]  # text width, height
                  outside = p1[1] - h >= 3
                  p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
                  cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA)  # filled
                  cv2.putText(self.im,
                              label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                              0,
                              self.lw / 3,
                              txt_color,
                              thickness=tf,
                              lineType=cv2.LINE_AA)

      def masks(self, masks, colors, im_gpu, alpha=0.5, retina_masks=False):
          """Plot masks at once.
          Args:
              masks (tensor): predicted masks on cuda, shape: [n, h, w]
              colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
              im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
              alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
              retina_masks (bool): write the blended image as-is instead of passing it through scale_image()
          """
          if self.pil:
              # convert to numpy first
              self.im = np.asarray(self.im).copy()
          if len(masks) == 0:
              # Nothing to blend: just write the input image. The blending below indexes
              # inv_alph_masks[-1], which raises IndexError for an empty stack, so it is skipped.
              self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
          else:
              colors = torch.tensor(colors, device=im_gpu.device, dtype=torch.float32) / 255.0
              colors = colors[:, None, None]  # shape(n,1,1,3)
              masks = masks.unsqueeze(3)  # shape(n,h,w,1)
              masks_color = masks * (colors * alpha)  # shape(n,h,w,3)

              inv_alph_masks = (1 - masks * alpha).cumprod(0)  # shape(n,h,w,1)
              mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)

              im_gpu = im_gpu.flip(dims=[0])  # flip channel
              im_gpu = im_gpu.permute(1, 2, 0).contiguous()  # shape(h,w,3)
              im_gpu = im_gpu * inv_alph_masks[-1] + mcs
              im_mask = (im_gpu * 255).byte().cpu().numpy()
              self.im[:] = im_mask if retina_masks else scale_image(im_gpu.shape, im_mask, self.im.shape)
          if self.pil:
              # convert im back to PIL and update draw
              self.fromarray(self.im)

      def rectangle(self, xy, fill=None, outline=None, width=1):
          """Add rectangle to image (PIL-only)."""
          self.draw.rectangle(xy, fill, outline, width)

      def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
          """Add text to image (PIL-only); with anchor='bottom', `xy` marks the bottom of the text."""
          if anchor == 'bottom':  # start y from font bottom
              _, h = self._text_size(text)  # text height
              xy = (xy[0], xy[1] + 1 - h)  # new coordinate: works for tuples and leaves the caller's xy untouched
          self.draw.text(xy, text, fill=txt_color, font=self.font)

      def fromarray(self, im):
          """Update self.im (and the PIL draw handle) from a numpy array or PIL image."""
          self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
          self.draw = ImageDraw.Draw(self.im)

      def result(self):
          """Return annotated image as array."""
          return np.asarray(self.im)
  def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')):
      """
      Save a grid image of the first `n` channels of batch item 0, plus that item's raw features as .npy.

      Nothing is written for modules whose type contains 'Detect' or for 1-pixel-high/wide feature maps.

      x:              Features to be visualized
      module_type:    Module type
      stage:          Module stage within model
      n:              Maximum number of feature maps to plot
      save_dir:       Directory to save results
      """
      if 'Detect' in module_type:
          return
      _, channels, height, width = x.shape  # batch, channels, height, width
      if height <= 1 or width <= 1:
          return

      f = save_dir / f"stage{stage}_{module_type.split('.')[-1]}_features.png"  # filename
      first = x[0].cpu()  # batch index 0
      blocks = torch.chunk(first, channels, dim=0)  # one block per channel
      n = min(n, channels)  # number of plots

      _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)  # ceil(n/8) rows x 8 cols
      ax = ax.ravel()
      plt.subplots_adjust(wspace=0.05, hspace=0.05)
      for axis, block in zip(ax, blocks[:n]):
          axis.imshow(block.squeeze())  # cmap='gray'
          axis.axis('off')

      LOGGER.info(f'Saving {f}... ({n}/{channels})')
      plt.savefig(f, dpi=300, bbox_inches='tight')
      plt.close()
      np.save(str(f.with_suffix('.npy')), first.numpy())  # npy save
  def hist2d(x, y, n=100):
      """2d histogram used in labels.png and evolve.png.

      Bins the (x, y) points on an n-edge grid spanning their min/max and returns, for every
      input point, the natural log of the count of the bin that point falls into.
      """
      x_edges = np.linspace(x.min(), x.max(), n)
      y_edges = np.linspace(y.min(), y.max(), n)
      counts, x_edges, y_edges = np.histogram2d(x, y, (x_edges, y_edges))
      rows, cols = counts.shape
      # digitize is 1-based and puts the maximum past the last bin, hence the shift and clip
      ix = (np.digitize(x, x_edges) - 1).clip(0, rows - 1)
      iy = (np.digitize(y, y_edges) - 1).clip(0, cols - 1)
      return np.log(counts[ix, iy])
  def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
      """Low-pass `data` with a Butterworth filter applied forward and backward (scipy filtfilt).

      data:   signal to smooth
      cutoff: cutoff frequency, in the same unit as `fs`
      fs:     sampling frequency
      order:  filter order
      """
      from scipy.signal import butter, filtfilt

      # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
      nyquist = 0.5 * fs
      b, a = butter(order, cutoff / nyquist, btype='low', analog=False)  # cutoff normalised by Nyquist
      return filtfilt(b, a, data)  # forward-backward filter
  def output_to_target(output, max_det=300):
      """Convert model output to target format [batch_id, class_id, x, y, w, h, conf] for plotting.

      output:  iterable of per-image detection tensors whose first six columns are
               4 box values (fed to xyxy2xywh), confidence and class
      max_det: maximum number of rows taken from each image
      Returns one numpy array with the rows of all images stacked.
      """

      def rows_for(batch_id, detections):
          # first max_det rows, split column-wise into box(4) / conf(1) / cls(1)
          box, conf, cls = detections[:max_det, :6].cpu().split((4, 1, 1), 1)
          batch_col = torch.full((conf.shape[0], 1), batch_id)
          return torch.cat((batch_col, cls, xyxy2xywh(box), conf), 1)

      return torch.cat([rows_for(i, o) for i, o in enumerate(output)], 0).numpy()
  @threaded
  def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
      """Plot image grid with labels and save it to `fname`.

      images:  batch of images, shape (batch, channels, height, width); values <= 1 are scaled by 255
      targets: rows of [image_index, class, x, y, w, h] (labels) or with a 7th confidence column
               (predictions); predictions below 0.25 confidence are not drawn
      paths:   optional per-image file paths, whose names are written on the tiles
      fname:   output file name
      names:   optional class-index -> name mapping used for the box labels

      At most 16 images are plotted. The caller's `images` are never modified.
      """
      if isinstance(images, torch.Tensor):
          images = images.cpu().float().numpy()
      if isinstance(targets, torch.Tensor):
          targets = targets.cpu().numpy()

      max_size = 1920  # max image size
      max_subplots = 16  # max image subplots, i.e. 4x4
      bs, _, h, w = images.shape  # batch size, _, height, width
      bs = min(bs, max_subplots)  # limit plot images
      ns = np.ceil(bs ** 0.5)  # number of subplots (square)
      images = images[:bs]  # only the plotted images are needed from here on
      if np.max(images[0]) <= 1:
          # de-normalise (optional) into a NEW array: an in-place `*=` would also rescale the caller's
          # batch, because a CPU float tensor shares its memory with the numpy view created above
          images = images * 255

      # Build Image
      mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
      for i, im in enumerate(images):
          x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
          mosaic[y:y + h, x:x + w, :] = im.transpose(1, 2, 0)

      # Resize (optional)
      scale = max_size / ns / max(h, w)
      if scale < 1:
          h = math.ceil(scale * h)
          w = math.ceil(scale * w)
          mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

      # Annotate
      fs = int((h + w) * ns * 0.01)  # font size
      annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
      for i in range(bs):  # exactly the tiles placed above (the old `range(i + 1)` ran one past a full grid)
          x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
          annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2)  # borders
          if paths:
              annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220))  # filenames
          if len(targets) > 0:
              ti = targets[targets[:, 0] == i]  # image targets
              boxes = xywh2xyxy(ti[:, 2:6]).T
              classes = ti[:, 1].astype('int')
              labels = ti.shape[1] == 6  # labels if no conf column
              conf = None if labels else ti[:, 6]  # check for confidence presence (label vs pred)

              if boxes.shape[1]:
                  if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                      boxes[[0, 2]] *= w  # scale to pixels
                      boxes[[1, 3]] *= h
                  elif scale < 1:  # absolute coords need scale if image scales
                      boxes *= scale
              boxes[[0, 2]] += x
              boxes[[1, 3]] += y
              for j, box in enumerate(boxes.T.tolist()):
                  cls = classes[j]
                  color = colors(cls)
                  cls = names[cls] if names else cls
                  if labels or conf[j] > 0.25:  # 0.25 conf thresh
                      label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
                      annotator.box_label(box, label, color=color)
      annotator.im.save(fname)  # save
  def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
      """Plot LR simulating training for full epochs and save it as 'LR.png' in `save_dir`.

      The scheduler is stepped `epochs` times and the learning rate of the first parameter group
      is recorded after every step.

      A shallow copy of the scheduler still points at the caller's optimizer, so stepping it writes
      into the real `param_groups`. The hyperparameters of every group are therefore snapshotted
      first and restored afterwards, so the caller's optimizer is left as it was.
      """
      saved_groups = [dict(group) for group in optimizer.param_groups]  # shallow snapshot (lr, momentum, ...)
      scheduler = copy(scheduler)  # step counters advance on the copy, not on the original
      y = []
      try:
          for _ in range(epochs):
              scheduler.step()
              y.append(optimizer.param_groups[0]['lr'])
      finally:
          for group, saved in zip(optimizer.param_groups, saved_groups):
              group.update(saved)  # undo everything the simulated schedule changed
      plt.plot(y, '.-', label='LR')
      plt.xlabel('epoch')
      plt.ylabel('LR')
      plt.grid()
      plt.xlim(0, epochs)
      plt.ylim(0)
      plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
      plt.close()
  def plot_val_txt():  # from utils.plots import *; plot_val_txt()
      """Plot val.txt histograms.

      Reads 'val.txt' from the working directory, converts its first four columns with xyxy2xywh
      and writes a 2D histogram of the box centres to 'hist2d.png' and the two 1D histograms to
      'hist1d.png'. Both figures are closed after saving so repeated calls do not pile up figures.
      """
      x = np.loadtxt('val.txt', dtype=np.float32)
      box = xyxy2xywh(x[:, :4])
      cx, cy = box[:, 0], box[:, 1]

      fig, ax = plt.subplots(1, 1, figsize=(6, 6), tight_layout=True)
      ax.hist2d(cx, cy, bins=600, cmax=10, cmin=0)
      ax.set_aspect('equal')
      fig.savefig('hist2d.png', dpi=300)
      plt.close(fig)

      fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True)
      ax[0].hist(cx, bins=600)
      ax[1].hist(cy, bins=600)
      fig.savefig('hist1d.png', dpi=200)
      plt.close(fig)
  def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
      """Plot targets.txt histograms.

      Reads 'targets.txt' from the working directory and writes one 100-bin histogram per each of
      its first four columns (titled x, y, width, height targets) to 'targets.jpg'. Each legend
      shows that column's mean +/- standard deviation. The figure is closed after saving.
      """
      x = np.loadtxt('targets.txt', dtype=np.float32).T
      s = ['x targets', 'y targets', 'width targets', 'height targets']
      fig, ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)
      ax = ax.ravel()
      for axis, column, title in zip(ax, x, s):
          axis.hist(column, bins=100, label=f'{column.mean():.3g} +/- {column.std():.3g}')
          axis.legend()
          axis.set_title(title)
      fig.savefig('targets.jpg', dpi=200)
      plt.close(fig)  # release the figure, as the other plot_* helpers do
  def plot_val_study(file='', dir='', x=None):  # from utils.plots import *; plot_val_study()
      """Plot file=study.txt generated by val.py (or plot all study*.txt in dir).

      file: path to a study file; its parent directory is searched for 'study*.txt'
      dir:  directory to search when `file` is empty
      x:    optional x values for the additional per-metric plots; defaults to the row index of
            each file (computed per file, so files of different length no longer share one axis)

      Saves 'study.png' next to the study files and closes the figures it created.
      """
      save_dir = Path(file).parent if file else Path(dir)
      plot2 = False  # plot additional results
      fig = None
      if plot2:
          fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
          ax = ax.ravel()

      fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
      # for f in [save_dir / f'study_coco_{x}.txt' for x in ['yolov5n6', 'yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
      for f in sorted(save_dir.glob('study*.txt')):
          y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
          xs = np.arange(y.shape[1]) if x is None else np.array(x)  # do not overwrite the `x` argument
          if plot2:
              s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
              for i in range(7):
                  ax[i].plot(xs, y[i], '.-', linewidth=2, markersize=8)
                  ax[i].set_title(s[i])

          j = y[3].argmax() + 1  # plot up to and including the best mAP@.5:.95 row
          ax2.plot(y[5, 1:j],
                   y[3, 1:j] * 1E2,
                   '.-',
                   linewidth=2,
                   markersize=8,
                   label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

      ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
               'k.-',
               linewidth=2,
               markersize=8,
               alpha=.25,
               label='EfficientDet')

      ax2.grid(alpha=0.2)
      ax2.set_yticks(np.arange(20, 60, 5))
      ax2.set_xlim(0, 57)
      ax2.set_ylim(25, 55)
      ax2.set_xlabel('GPU Speed (ms/img)')
      ax2.set_ylabel('COCO AP val')
      ax2.legend(loc='lower right')
      f = save_dir / 'study.png'
      print(f'Saving {f}...')
      fig2.savefig(f, dpi=300)
      plt.close(fig2)
      if fig is not None:
          plt.close(fig)
  @TryExcept()  # known issue https://github.com/ultralytics/yolov5/issues/5395
  def plot_labels(labels, names=(), save_dir=Path('')):
      """Plot dataset labels.

      labels:   numpy array with rows [class, x, y, width, height]
      names:    class names, either a dict (index -> name) or a sequence; used as x tick labels
                when there are fewer than 30 of them
      save_dir: directory receiving 'labels_correlogram.jpg' and 'labels.jpg'

      The caller's `labels` array is not modified, and the matplotlib backend is switched back
      to 'Agg' even when plotting fails part-way.
      """
      LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
      c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
      nc = int(c.max() + 1)  # number of classes
      x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

      # seaborn correlogram
      sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
      plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
      plt.close()

      # matplotlib labels
      matplotlib.use('svg')  # faster
      try:
          ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
          y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
          with contextlib.suppress(Exception):  # color histogram bars by class
              for i in range(nc):
                  y[2].patches[i].set_color([v / 255 for v in colors(i)])  # known issue #3195
          ax[0].set_ylabel('instances')
          if 0 < len(names) < 30:
              tick_labels = list(names.values()) if isinstance(names, dict) else list(names)
              ax[0].set_xticks(range(len(names)))
              ax[0].set_xticklabels(tick_labels, rotation=90, fontsize=10)
          else:
              ax[0].set_xlabel('classes')
          sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
          sn.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

          # rectangles: first 1000 boxes, all re-centred, drawn on a white 2000x2000 canvas
          rects = labels[:1000].copy()  # work on a copy so the dataset labels stay intact
          rects[:, 1:3] = 0.5  # center
          rects[:, 1:] = xywh2xyxy(rects[:, 1:]) * 2000
          img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
          draw = ImageDraw.Draw(img)
          for cls, *box in rects:
              draw.rectangle(box, width=1, outline=colors(cls))  # plot
          ax[1].imshow(img)
          ax[1].axis('off')

          for a in [0, 1, 2, 3]:
              for s in ['top', 'right', 'left', 'bottom']:
                  ax[a].spines[s].set_visible(False)

          plt.savefig(save_dir / 'labels.jpg', dpi=200)
      finally:
          matplotlib.use('Agg')  # always restore the module's default backend
          plt.close()
  def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
      """Show classification image grid with labels (optional) and predictions (optional).

      im:      batch of normalised images; de-normalised with utils.augmentations.denormalize
      labels:  optional true class indices, shown as tile titles
      pred:    optional predicted class indices, appended to the titles
      names:   class names indexable by class index; defaults to 'class0'..'class999'
      nmax:    maximum number of images to plot
      verbose: also log the file name and the true/predicted class names
      f:       output file
      Returns `f`.
      """
      from utils.augmentations import denormalize

      names = names or [f'class{i}' for i in range(1000)]
      blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im), dim=0)  # one block per image
      n = min(len(blocks), nmax)  # number of plots
      m = min(8, round(n ** 0.5))  # columns, at most 8
      # squeeze=False always yields a 2D array of axes; without it a 2x1 grid (n == 2) came back
      # as a 1D array that was then wrapped in a list and could not be indexed per image
      fig, ax = plt.subplots(math.ceil(n / m), m, squeeze=False)  # ceil(n/m) rows x m cols
      ax = ax.ravel()

      # plt.subplots_adjust(wspace=0.05, hspace=0.05)
      for i in range(n):
          ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
          ax[i].axis('off')
          if labels is not None:
              s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
              ax[i].set_title(s, fontsize=8, verticalalignment='top')
      for unused in ax[n:]:
          unused.axis('off')  # hide empty tiles when n does not fill the grid
      plt.savefig(f, dpi=300, bbox_inches='tight')
      plt.close()
      if verbose:
          LOGGER.info(f'Saving {f}')
          if labels is not None:
              LOGGER.info('True:     ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
          if pred is not None:
              LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
      return f
  def plot_evolve(evolve_csv='path/to/evolve.csv'):  # from utils.plots import *; plot_evolve()
      """Plot evolve.csv hyp evolution results.

      Every column from index 7 onwards is scattered against the fitness of its row (points
      colored by local density), the value from the fittest row is marked and printed, and the
      figure is saved next to the CSV with a '.png' suffix.
      """
      evolve_csv = Path(evolve_csv)
      data = pd.read_csv(evolve_csv)
      columns = [name.strip() for name in data.columns]
      values = data.values
      fit = fitness(values)
      best = np.argmax(fit)  # max fitness index
      top_fit = fit.max()
      plt.figure(figsize=(10, 12), tight_layout=True)
      matplotlib.rc('font', **{'size': 8})
      print(f'Best results from row {best} of {evolve_csv}:')
      for col, key in enumerate(columns[7:], start=7):
          slot = col - 7  # position in the 6x5 subplot grid
          samples = values[:, col]
          best_value = samples[best]  # best single result
          plt.subplot(6, 5, slot + 1)
          plt.scatter(samples, fit, c=hist2d(samples, fit, 20), cmap='viridis', alpha=.8, edgecolors='none')
          plt.plot(best_value, top_fit, 'k+', markersize=15)
          plt.title(f'{key} = {best_value:.3g}', fontdict={'size': 9})
          if slot % 5:  # y ticks only on the first column of the grid
              plt.yticks([])
          print(f'{key:>15}: {best_value:.3g}')
      out_file = evolve_csv.with_suffix('.png')  # filename
      plt.savefig(out_file, dpi=200)
      plt.close()
      print(f'Saved {out_file}')
  def plot_results(file='path/to/results.csv', dir=''):
      """Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')

      All 'results*.csv' files in the directory of `file` (or in `dir` when `file` is empty) are
      drawn onto one shared 2x5 grid - raw curve plus a gaussian-smoothed curve per panel - and
      saved as 'results.png' there. A file that fails to plot is logged and skipped.
      """
      save_dir = Path(file).parent if file else Path(dir)
      fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True)
      ax = ax.ravel()
      files = list(save_dir.glob('results*.csv'))
      assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
      column_order = [1, 2, 3, 4, 5, 8, 9, 10, 6, 7]  # CSV column shown in each panel, in panel order
      for f in files:
          try:
              data = pd.read_csv(f)
              titles = [name.strip() for name in data.columns]
              table = data.values
              x = table[:, 0]
              for panel, col in zip(ax, column_order):
                  y = table[:, col].astype('float')
                  # y[y == 0] = np.nan  # don't show zero values
                  panel.plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=8)  # actual results
                  panel.plot(x, gaussian_filter1d(y, sigma=3), ':', label='smooth', linewidth=2)  # smoothing line
                  panel.set_title(titles[col], fontsize=12)
                  # if col in [8, 9, 10]:  # share train and val loss y axes
                  #     panel.get_shared_y_axes().join(panel, ax[i - 5])
          except Exception as e:
              LOGGER.info(f'Warning: Plotting error for {f}: {e}')
      ax[1].legend()
      fig.savefig(save_dir / 'results.png', dpi=200)
      plt.close()
  def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
      """Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()

      start, stop: row range to plot after the first 90 and last 30 rows of each log are dropped
                   (stop=0 means up to the end)
      labels:      optional legend label per file; defaults to the file stem without 'frames_'
      save_dir:    directory searched for 'frames*.txt' and receiving 'idetection_profile.png'

      A file that fails to plot is reported and skipped. The figure is closed after saving.
      """
      fig, ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)
      ax = ax.ravel()
      s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
      files = list(Path(save_dir).glob('frames*.txt'))
      for fi, f in enumerate(files):
          try:
              results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
              n = results.shape[1]  # number of rows
              x = np.arange(start, min(stop, n) if stop else n)
              results = results[:, x]
              t = (results[0] - results[0].min())  # set t0=0s
              results[0] = x
              for i, a in enumerate(ax):
                  if i < len(results):
                      label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
                      a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
                      a.set_title(s[i])
                      a.set_xlabel('time (s)')
                      # if fi == len(files) - 1:
                      #     a.set_ylim(bottom=0)
                      for side in ['top', 'right']:
                          a.spines[side].set_visible(False)
                  else:
                      a.remove()
          except Exception as e:
              print(f'Warning: Plotting error for {f}; {e}')
      ax[1].legend()
      plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
      plt.close(fig)  # release the figure, as the other plot_* helpers do
  def save_one_box(xyxy, im, file=Path('im.jpg'), gain=1.02, pad=10, square=False, BGR=False, save=True):
      """Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop.

      xyxy:   box corners of the crop
      im:     image array indexed as [row, column, channel]
      file:   target path; passed through increment_path and forced to a '.jpg' suffix
      gain:   factor applied to the box width/height
      pad:    pixels added to the box width/height after `gain`
      square: use the longer side for both width and height
      BGR:    keep the channel order of `im` in the crop; otherwise the channels are reversed
      save:   write the crop to disk
      Returns the crop.
      """
      boxes = xyxy2xywh(torch.tensor(xyxy).view(-1, 4))  # boxes
      if square:
          longest_side = boxes[:, 2:].max(1)[0]
          boxes[:, 2:] = longest_side.unsqueeze(1)  # attempt rectangle to square
      boxes[:, 2:] = boxes[:, 2:] * gain + pad  # box wh * gain + pad
      corners = xywh2xyxy(boxes).long()
      clip_boxes(corners, im.shape)
      x1, y1, x2, y2 = (int(v) for v in corners[0])
      channel_step = 1 if BGR else -1
      crop = im[y1:y2, x1:x2, ::channel_step]
      if not save:
          return crop
      file.parent.mkdir(parents=True, exist_ok=True)  # make directory
      f = str(increment_path(file).with_suffix('.jpg'))
      # cv2.imwrite(f, crop)  # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
      Image.fromarray(crop[..., ::-1]).save(f, quality=95, subsampling=0)  # save RGB
      return crop