# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
"""
Model validation metrics
"""

import math
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from utils import TryExcept, threaded


def fitness(x):
    # Model fitness as a weighted combination of metrics
    w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
    return (x[:, :4] * w).sum(1)
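
# Usage sketch (example values assumed, not from the original file): with the default weights,
# only mAP@0.5 and mAP@0.5:0.95 contribute to fitness.
#   x = np.array([[0.70, 0.60, 0.55, 0.40]])  # one row of [P, R, mAP@0.5, mAP@0.5:0.95]
#   fitness(x)  # ≈ array([0.415]) since 0.1 * 0.55 + 0.9 * 0.40 = 0.415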


def smooth(y, f=0.05):
    # Box filter of fraction f
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = np.ones(nf // 2)  # ones padding
    yp = np.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return np.convolve(yp, np.ones(nf) / nf, mode='valid')  # y-smoothed
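
# Usage sketch (example values assumed): for a 100-point curve and f=0.1, the kernel size is
# nf = round(100 * 0.1 * 2) // 2 + 1 = 11, so each output point averages an 11-wide window of the
# edge-padded input, and the output keeps the input length.
#   y = np.random.rand(100)
#   y_s = smooth(y, 0.1)  # y_s.shape == (100,)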


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16, prefix=''):
    """ Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp:  True positives (nparray, nx1 or nx10).
        conf:  Objectness value from 0-1 (nparray).
        pred_cls:  Predicted object classes (nparray).
        target_cls:  True object classes (nparray).
        plot:  Plot precision-recall curve at mAP@0.5
        save_dir:  Plot save directory
    # Returns
        The average precision as computed in py-faster-rcnn.
    """

    # Sort by objectness
    i = np.argsort(-conf)
    tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

    # Find unique classes
    unique_classes, nt = np.unique(target_cls, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes, number of detections

    # Create Precision-Recall curve and compute AP for each class
    px, py = np.linspace(0, 1, 1000), []  # for plotting
    ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
    for ci, c in enumerate(unique_classes):
        i = pred_cls == c
        n_l = nt[ci]  # number of labels
        n_p = i.sum()  # number of predictions
        if n_p == 0 or n_l == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (n_l + eps)  # recall curve
        r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
            if plot and j == 0:
                py.append(np.interp(px, mrec, mpre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
    names = dict(enumerate(names))  # to dict
    if plot:
        plot_pr_curve(px, py, ap, Path(save_dir) / f'{prefix}PR_curve.png', names)
        plot_mc_curve(px, f1, Path(save_dir) / f'{prefix}F1_curve.png', names, ylabel='F1')
        plot_mc_curve(px, p, Path(save_dir) / f'{prefix}P_curve.png', names, ylabel='Precision')
        plot_mc_curve(px, r, Path(save_dir) / f'{prefix}R_curve.png', names, ylabel='Recall')

    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    return tp, fp, p, r, f1, ap, unique_classes.astype(int)
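
# Usage sketch (shapes and variable names assumed for illustration): tp holds one column per IoU
# threshold (typically 10 thresholds from 0.5 to 0.95 in the YOLOv5 val loop) and names maps class
# index to class name. mAP@0.5 and mAP@0.5:0.95 follow from the returned per-class AP matrix.
#   tp = np.random.rand(50, 10) > 0.5           # 50 predictions x 10 IoU thresholds
#   conf = np.random.rand(50)                   # confidence per prediction
#   pred_cls = np.random.randint(0, 3, 50)      # predicted class per prediction
#   target_cls = np.random.randint(0, 3, 80)    # ground-truth class per label
#   _, _, p, r, f1, ap, classes = ap_per_class(tp, conf, pred_cls, target_cls, names={0: 'a', 1: 'b', 2: 'c'})
#   ap50, ap = ap[:, 0], ap.mean(1)             # AP@0.5 and AP@0.5:0.95 per class
#   map50, map = ap50.mean(), ap.mean()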


def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves
    # Arguments
        recall:    The recall curve (list)
        precision: The precision curve (list)
    # Returns
        Average precision, precision curve, recall curve
    """

    # Append sentinel values to beginning and end
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([1.0], precision, [0.0]))

    # Compute the precision envelope
    mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))

    # Integrate area under curve
    method = 'interp'  # methods: 'continuous', 'interp'
    if method == 'interp':
        x = np.linspace(0, 1, 101)  # 101-point interp (COCO)
        ap = np.trapz(np.interp(x, mrec, mpre), x)  # integrate
    else:  # 'continuous'
        i = np.where(mrec[1:] != mrec[:-1])[0]  # points where x axis (recall) changes
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])  # area under curve

    return ap, mpre, mrec
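
# Worked sketch (toy values assumed): a detector whose precision stays at 1.0 up to recall 0.5 and
# then falls to 0.5 at recall 1.0 has its precision curve replaced by its monotone envelope and
# integrated over 101 recall points.
#   ap, mpre, mrec = compute_ap([0.5, 1.0], [1.0, 0.5])
#   # mrec == [0.0, 0.5, 1.0, 1.0], mpre (envelope) == [1.0, 1.0, 0.5, 0.0], ap ≈ 0.87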


class ConfusionMatrix:
    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
    def __init__(self, nc, conf=0.25, iou_thres=0.45):
        self.matrix = np.zeros((nc + 1, nc + 1))
        self.nc = nc  # number of classes
        self.conf = conf
        self.iou_thres = iou_thres

    def process_batch(self, detections, labels):
  108. """
  109. Return intersection-over-union (Jaccard index) of boxes.
  110. Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
  111. Arguments:
  112. detections (Array[N, 6]), x1, y1, x2, y2, conf, class
  113. labels (Array[M, 5]), class, x1, y1, x2, y2
  114. Returns:
  115. None, updates confusion matrix accordingly
  116. """
        if detections is None:
            gt_classes = labels.int()
            for gc in gt_classes:
                self.matrix[self.nc, gc] += 1  # background FN
            return

        detections = detections[detections[:, 4] > self.conf]
        gt_classes = labels[:, 0].int()
        detection_classes = detections[:, 5].int()
        iou = box_iou(labels[:, 1:], detections[:, :4])

        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
        else:
            matches = np.zeros((0, 3))

        n = matches.shape[0] > 0
        m0, m1, _ = matches.transpose().astype(int)
        for i, gc in enumerate(gt_classes):
            j = m0 == i
            if n and sum(j) == 1:
                self.matrix[detection_classes[m1[j]], gc] += 1  # correct
            else:
                self.matrix[self.nc, gc] += 1  # true background

        if n:
            for i, dc in enumerate(detection_classes):
                if not any(m1 == i):
                    self.matrix[dc, self.nc] += 1  # predicted background

    def tp_fp(self):
        tp = self.matrix.diagonal()  # true positives
        fp = self.matrix.sum(1) - tp  # false positives
        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)
        return tp[:-1], fp[:-1]  # remove background class

    @TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
    def plot(self, normalize=True, save_dir='', names=()):
        import seaborn as sn

        array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)  # normalize columns
        array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

        fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
        nc, nn = self.nc, len(names)  # number of classes, names
        sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size
        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
        ticklabels = (names + ['background']) if labels else 'auto'
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
            sn.heatmap(array,
                       ax=ax,
                       annot=nc < 30,
                       annot_kws={
                           'size': 8},
                       cmap='Blues',
                       fmt='.2f',
                       square=True,
                       vmin=0.0,
                       xticklabels=ticklabels,
                       yticklabels=ticklabels).set_facecolor((1, 1, 1))
        ax.set_xlabel('True')
        ax.set_ylabel('Predicted')
        ax.set_title('Confusion Matrix')
        fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
        plt.close(fig)

    def print(self):
        for i in range(self.nc + 1):
            print(' '.join(map(str, self.matrix[i])))
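
# Usage sketch (tensor values assumed for illustration): detections are (x1, y1, x2, y2, conf, cls)
# rows and labels are (cls, x1, y1, x2, y2) rows, matching the process_batch() docstring; plot()
# additionally requires seaborn.
#   cm = ConfusionMatrix(nc=2)
#   detections = torch.tensor([[0., 0., 10., 10., 0.9, 0.]])  # one prediction of class 0
#   labels = torch.tensor([[0., 0., 0., 10., 10.]])           # one ground-truth box of class 0
#   cm.process_batch(detections, labels)                      # increments cm.matrix[0, 0]
#   tp, fp = cm.tp_fp()                                       # per-class TP/FP counts
#   cm.plot(save_dir='.', names=['class0', 'class1'])         # writes confusion_matrix.png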


def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    if xywh:  # transform from xywh to xyxy
        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
    else:  # x1, y1, x2, y2 = box1
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
        w1, h1 = b1_x2 - b1_x1, (b1_y2 - b1_y1).clamp(eps)
        w2, h2 = b2_x2 - b2_x1, (b2_y2 - b2_y1).clamp(eps)

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    if CIoU or DIoU or GIoU:
        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
                with torch.no_grad():
                    alpha = v / (v - iou + (1 + eps))
                return iou - (rho2 / c2 + v * alpha)  # CIoU
            return iou - rho2 / c2  # DIoU
        c_area = cw * ch + eps  # convex area
        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
    return iou  # IoU
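
# Worked sketch (toy boxes assumed): two 2x2 boxes in xyxy format, offset by one unit, overlap by 1,
# so plain IoU = 1 / (4 + 4 - 1) ≈ 0.143; CIoU additionally subtracts the normalized center-distance
# and aspect-ratio penalties, so it comes out lower.
#   b1 = torch.tensor([[0., 0., 2., 2.]])
#   b2 = torch.tensor([[1., 1., 3., 3.]])
#   bbox_iou(b1, b2, xywh=False)             # ≈ 0.143
#   bbox_iou(b1, b2, xywh=False, CIoU=True)  # smaller than plain IoU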


def box_iou(box1, box2, eps=1e-7):
    # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
    """
    Return intersection-over-union (Jaccard index) of boxes.
    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Arguments:
        box1 (Tensor[N, 4])
        box2 (Tensor[M, 4])
    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in boxes1 and boxes2
    """

    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
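
# Usage sketch (toy boxes assumed): box_iou broadcasts an [N, 4] and an [M, 4] tensor into an
# [N, M] pairwise IoU matrix, which is what ConfusionMatrix.process_batch() relies on.
#   a = torch.tensor([[0., 0., 2., 2.], [0., 0., 1., 1.]])  # N = 2
#   b = torch.tensor([[1., 1., 3., 3.]])                    # M = 1
#   box_iou(a, b)  # shape (2, 1); first row ≈ 0.143, second row 0.0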


def bbox_ioa(box1, box2, eps=1e-7):
    """ Returns the intersection over box2 area given box1, box2. Boxes are x1y1x2y2
    box1:       np.array of shape(4)
    box2:       np.array of shape(nx4)
    returns:    np.array of shape(n)
    """

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.T

    # Intersection area
    inter_area = (np.minimum(b1_x2, b2_x2) - np.maximum(b1_x1, b2_x1)).clip(0) * \
                 (np.minimum(b1_y2, b2_y2) - np.maximum(b1_y1, b2_y1)).clip(0)

    # box2 area
    box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) + eps

    # Intersection over box2 area
    return inter_area / box2_area
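
# Usage sketch (toy boxes assumed): unlike IoU, the denominator here is only the box2 area, so the
# result is the fraction of each box2 that box1 covers.
#   box1 = np.array([0., 0., 2., 2.])
#   box2 = np.array([[1., 1., 3., 3.], [0., 0., 2., 2.]])
#   bbox_ioa(box1, box2)  # ≈ [0.25, 1.0]: box1 covers 25% of the first box and all of the second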


def wh_iou(wh1, wh2, eps=1e-7):
    # Returns the nxm IoU matrix. wh1 is nx2, wh2 is mx2
    wh1 = wh1[:, None]  # [N,1,2]
    wh2 = wh2[None]  # [1,M,2]
    inter = torch.min(wh1, wh2).prod(2)  # [N,M]
    return inter / (wh1.prod(2) + wh2.prod(2) - inter + eps)  # iou = inter / (area1 + area2 - inter)
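
# Usage sketch (toy sizes assumed): wh_iou compares width/height pairs as if the boxes shared a
# common corner, i.e. a pure shape comparison between label and anchor sizes.
#   wh1 = torch.tensor([[2., 2.]])              # one label shape
#   wh2 = torch.tensor([[2., 2.], [4., 4.]])    # two anchor shapes
#   wh_iou(wh1, wh2)  # ≈ tensor([[1.00, 0.25]])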


# Plots ----------------------------------------------------------------------------------------------------------------


@threaded
def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
    # Precision-recall curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = np.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)

    ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f mAP@0.5' % ap[:, 0].mean())
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
    ax.set_title('Precision-Recall Curve')
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)


@threaded
def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
    # Metric-confidence curve
    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
    else:
        ax.plot(px, py.T, linewidth=1, color='grey')  # plot(confidence, metric)

    y = smooth(py.mean(0), 0.05)
    ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc='upper left')
    ax.set_title(f'{ylabel}-Confidence Curve')
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
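
# Usage sketch (shapes assumed from ap_per_class above): plot_pr_curve expects py as a list of
# per-class precision curves sampled at the same points as px (stacked internally), while
# plot_mc_curve expects a (num_classes, len(px)) metric array such as the F1 matrix; both write a
# PNG to save_dir.
#   px = np.linspace(0, 1, 1000)
#   py = [np.linspace(1, 0, 1000), np.linspace(1, 0.5, 1000)]  # toy per-class precision curves
#   ap = np.array([[0.5], [0.75]])                             # toy AP@0.5 per class
#   plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names={0: 'a', 1: 'b'})
#   plot_mc_curve(px, np.stack(py), save_dir=Path('f1_curve.png'), names={0: 'a', 1: 'b'}, ylabel='F1')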