plots.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import contextlib
  2. import math
  3. from pathlib import Path
  4. import cv2
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import pandas as pd
  8. import torch
  9. from .. import threaded
  10. from ..general import xywh2xyxy
  11. from ..plots import Annotator, colors
  12. @threaded
  13. def plot_images_and_masks(images, targets, masks, paths=None, fname='images.jpg', names=None):
  14. # Plot image grid with labels
  15. if isinstance(images, torch.Tensor):
  16. images = images.cpu().float().numpy()
  17. if isinstance(targets, torch.Tensor):
  18. targets = targets.cpu().numpy()
  19. if isinstance(masks, torch.Tensor):
  20. masks = masks.cpu().numpy().astype(int)
  21. max_size = 1920 # max image size
  22. max_subplots = 16 # max image subplots, i.e. 4x4
  23. bs, _, h, w = images.shape # batch size, _, height, width
  24. bs = min(bs, max_subplots) # limit plot images
  25. ns = np.ceil(bs ** 0.5) # number of subplots (square)
  26. if np.max(images[0]) <= 1:
  27. images *= 255 # de-normalise (optional)
  28. # Build Image
  29. mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
  30. for i, im in enumerate(images):
  31. if i == max_subplots: # if last batch has fewer images than we expect
  32. break
  33. x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
  34. im = im.transpose(1, 2, 0)
  35. mosaic[y:y + h, x:x + w, :] = im
  36. # Resize (optional)
  37. scale = max_size / ns / max(h, w)
  38. if scale < 1:
  39. h = math.ceil(scale * h)
  40. w = math.ceil(scale * w)
  41. mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
  42. # Annotate
  43. fs = int((h + w) * ns * 0.01) # font size
  44. annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
  45. for i in range(i + 1):
  46. x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
  47. annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
  48. if paths:
  49. annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
  50. if len(targets) > 0:
  51. idx = targets[:, 0] == i
  52. ti = targets[idx] # image targets
  53. boxes = xywh2xyxy(ti[:, 2:6]).T
  54. classes = ti[:, 1].astype('int')
  55. labels = ti.shape[1] == 6 # labels if no conf column
  56. conf = None if labels else ti[:, 6] # check for confidence presence (label vs pred)
  57. if boxes.shape[1]:
  58. if boxes.max() <= 1.01: # if normalized with tolerance 0.01
  59. boxes[[0, 2]] *= w # scale to pixels
  60. boxes[[1, 3]] *= h
  61. elif scale < 1: # absolute coords need scale if image scales
  62. boxes *= scale
  63. boxes[[0, 2]] += x
  64. boxes[[1, 3]] += y
  65. for j, box in enumerate(boxes.T.tolist()):
  66. cls = classes[j]
  67. color = colors(cls)
  68. cls = names[cls] if names else cls
  69. if labels or conf[j] > 0.25: # 0.25 conf thresh
  70. label = f'{cls}' if labels else f'{cls} {conf[j]:.1f}'
  71. annotator.box_label(box, label, color=color)
  72. # Plot masks
  73. if len(masks):
  74. if masks.max() > 1.0: # mean that masks are overlap
  75. image_masks = masks[[i]] # (1, 640, 640)
  76. nl = len(ti)
  77. index = np.arange(nl).reshape(nl, 1, 1) + 1
  78. image_masks = np.repeat(image_masks, nl, axis=0)
  79. image_masks = np.where(image_masks == index, 1.0, 0.0)
  80. else:
  81. image_masks = masks[idx]
  82. im = np.asarray(annotator.im).copy()
  83. for j, box in enumerate(boxes.T.tolist()):
  84. if labels or conf[j] > 0.25: # 0.25 conf thresh
  85. color = colors(classes[j])
  86. mh, mw = image_masks[j].shape
  87. if mh != h or mw != w:
  88. mask = image_masks[j].astype(np.uint8)
  89. mask = cv2.resize(mask, (w, h))
  90. mask = mask.astype(bool)
  91. else:
  92. mask = image_masks[j].astype(bool)
  93. with contextlib.suppress(Exception):
  94. im[y:y + h, x:x + w, :][mask] = im[y:y + h, x:x + w, :][mask] * 0.4 + np.array(color) * 0.6
  95. annotator.fromarray(im)
  96. annotator.im.save(fname) # save
  97. def plot_results_with_masks(file='path/to/results.csv', dir='', best=True):
  98. # Plot training results.csv. Usage: from utils.plots import *; plot_results('path/to/results.csv')
  99. save_dir = Path(file).parent if file else Path(dir)
  100. fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True)
  101. ax = ax.ravel()
  102. files = list(save_dir.glob('results*.csv'))
  103. assert len(files), f'No results.csv files found in {save_dir.resolve()}, nothing to plot.'
  104. for f in files:
  105. try:
  106. data = pd.read_csv(f)
  107. index = np.argmax(0.9 * data.values[:, 8] + 0.1 * data.values[:, 7] + 0.9 * data.values[:, 12] +
  108. 0.1 * data.values[:, 11])
  109. s = [x.strip() for x in data.columns]
  110. x = data.values[:, 0]
  111. for i, j in enumerate([1, 2, 3, 4, 5, 6, 9, 10, 13, 14, 15, 16, 7, 8, 11, 12]):
  112. y = data.values[:, j]
  113. # y[y == 0] = np.nan # don't show zero values
  114. ax[i].plot(x, y, marker='.', label=f.stem, linewidth=2, markersize=2)
  115. if best:
  116. # best
  117. ax[i].scatter(index, y[index], color='r', label=f'best:{index}', marker='*', linewidth=3)
  118. ax[i].set_title(s[j] + f'\n{round(y[index], 5)}')
  119. else:
  120. # last
  121. ax[i].scatter(x[-1], y[-1], color='r', label='last', marker='*', linewidth=3)
  122. ax[i].set_title(s[j] + f'\n{round(y[-1], 5)}')
  123. # if j in [8, 9, 10]: # share train and val loss y axes
  124. # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
  125. except Exception as e:
  126. print(f'Warning: Plotting error for {f}: {e}')
  127. ax[1].legend()
  128. fig.savefig(save_dir / 'results.png', dpi=200)
  129. plt.close()