File size: 3,844 Bytes
34d1f8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.logging import print_log
from terminaltables import AsciiTable
def fast_hist(preds, labels, num_classes):
    """Compute the confusion matrix for every batch.

    Rows index ground-truth labels and columns index predicted labels, so
    ``hist[i, j]`` counts points of ground-truth class ``i`` predicted as
    class ``j``. Points whose ground-truth label lies outside
    ``[0, num_classes)`` are skipped entirely.

    Args:
        preds (np.ndarray): Prediction labels of points with shape of
            (num_points, ).
        labels (np.ndarray): Ground truth labels of points with shape of
            (num_points, ).
        num_classes (int): Number of classes.

    Returns:
        np.ndarray: Calculated confusion matrix with shape of
        (num_classes, num_classes).
    """
    valid = (labels >= 0) & (labels < num_classes)
    # Cast both operands: np.bincount only accepts integer input, so float
    # predictions would otherwise raise a TypeError.
    flat_index = (num_classes * labels[valid].astype(int) +
                  preds[valid].astype(int))
    bin_count = np.bincount(flat_index, minlength=num_classes**2)
    # NOTE: predictions are not range-checked. A prediction outside
    # [0, num_classes) shifts the flat index into a neighbouring row (or makes
    # np.bincount raise if the index goes negative); the slice below only
    # drops overflow past the last bin. Callers should pass predictions in
    # [0, num_classes).
    return bin_count[:num_classes**2].reshape(num_classes, num_classes)
def per_class_iou(hist):
    """Compute the per class iou.

    IoU of class ``i`` is
    ``hist[i, i] / (hist[i, :].sum() + hist[:, i].sum() - hist[i, i])``.
    A class whose row and column are both empty (it appears neither in the
    ground truth nor in the predictions) has a zero denominator and gets
    NaN, so it can be skipped with ``np.nanmean``.

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        np.ndarray: Calculated per class iou with shape of (num_classes, ).
    """
    intersection = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    # 0 / 0 for absent classes is expected (yields NaN); silence numpy's
    # RuntimeWarning instead of emitting it on every evaluation.
    with np.errstate(divide='ignore', invalid='ignore'):
        return intersection / union
def get_acc(hist):
    """Return the overall accuracy encoded in a confusion matrix.

    The accuracy is the sum of the diagonal (correctly classified points)
    divided by the sum of all entries (all counted points).

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        float: Calculated overall accuracy.
    """
    correct = np.diag(hist).sum()
    total = hist.sum()
    return correct / total
def get_acc_cls(hist):
    """Compute the class average accuracy.

    Per-class accuracy is ``hist[i, i] / hist[i, :].sum()``; with rows
    indexing ground truth (as ``fast_hist`` produces) this is the recall of
    class ``i``. Classes with an empty row give NaN and are excluded from the
    mean by ``np.nanmean``. If every row is empty the result is NaN (and
    ``np.nanmean`` itself warns about an empty slice).

    Args:
        hist (np.ndarray): Overall confusion matrix with shape of
            (num_classes, num_classes).

    Returns:
        float: Calculated class average accuracy.
    """
    # 0 / 0 for classes without ground truth is expected (yields NaN);
    # silence numpy's RuntimeWarning for it.
    with np.errstate(divide='ignore', invalid='ignore'):
        per_class_acc = np.diag(hist) / hist.sum(axis=1)
    return np.nanmean(per_class_acc)
def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
    """Semantic Segmentation Evaluation.

    Evaluate the result of the Semantic Segmentation. A single confusion
    matrix is accumulated over all samples, and per-class IoU, mean IoU,
    overall accuracy and class-average accuracy are derived from it. The
    results are printed as a table and returned as a dict.

    Args:
        gt_labels (list[np.ndarray]): Ground truth labels, one array of
            per-point labels per sample.
        seg_preds (list[np.ndarray]): Predicted labels, one array per sample,
            aligned element-wise with ``gt_labels``.
        label2cat (dict): Map from label (``0`` to ``len(label2cat) - 1``) to
            category name.
        ignore_index (int): Ground truth label that will be ignored in
            evaluation. If it is also a valid class index, that class's IoU
            is reported as NaN and excluded from mIoU.
        logger (logging.Logger | str, optional): The way to print the mAP
            summary. See `mmengine.logging.print_log()` for details.
            Default: None.

    Returns:
        dict[str, float]: Per-category IoU keyed by category name, plus the
        ``'miou'``, ``'acc'`` and ``'acc_cls'`` entries.
    """
    assert len(seg_preds) == len(gt_labels), \
        'gt_labels and seg_preds must have the same length'
    num_classes = len(label2cat)

    # Accumulate into a fixed-size matrix: empty input then yields a
    # well-formed all-zero histogram instead of the int 0 from sum([]).
    hist = np.zeros((num_classes, num_classes), dtype=np.int64)
    for gt, pred in zip(gt_labels, seg_preds):
        # astype copies, so the masking below never mutates caller arrays.
        gt_seg = gt.astype(np.int64)
        pred_seg = pred.astype(np.int64)
        # filter out ignored points: fast_hist skips negative labels
        ignored = gt_seg == ignore_index
        pred_seg[ignored] = -1
        gt_seg[ignored] = -1
        # calculate one instance result
        hist += fast_hist(pred_seg, gt_seg, num_classes)

    iou = per_class_iou(hist)
    # If ignore_index is a valid class index, exclude that class from mIoU.
    # The lower bound matters: a negative ignore_index (e.g. -1) would
    # otherwise index from the end and wrongly blank out the last class.
    if 0 <= ignore_index < len(iou):
        iou[ignore_index] = np.nan
    miou = np.nanmean(iou)
    acc = get_acc(hist)
    acc_cls = get_acc_cls(hist)

    class_names = [label2cat[i] for i in range(num_classes)]
    ret_dict = {name: float(value) for name, value in zip(class_names, iou)}
    ret_dict['miou'] = float(miou)
    ret_dict['acc'] = float(acc)
    ret_dict['acc_cls'] = float(acc_cls)

    header = ['classes', *class_names, 'miou', 'acc', 'acc_cls']
    results = ['results'] + [
        f'{value:.4f}' for value in (*iou, miou, acc, acc_cls)
    ]
    table = AsciiTable([header, results])
    table.inner_footing_row_border = True
    print_log('\n' + table.table, logger=logger)
    return ret_dict
|