# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmengine.logging import print_log
from terminaltables import AsciiTable


def fast_hist(preds, labels, num_classes):
    """Compute the confusion matrix for one batch of points.

    Points whose ground truth label lies outside ``[0, num_classes)`` are
    dropped before counting. Predictions are not range-checked.

    Args:
        preds (np.ndarray): Integer prediction labels of points with shape
            of (num_points, ).
        labels (np.ndarray): Ground truth labels of points with shape of
            (num_points, ).
        num_classes (int): Number of classes.

    Returns:
        np.ndarray: Confusion matrix of shape (num_classes, num_classes),
        indexed as ``[ground_truth, prediction]``.
    """
    valid = (labels >= 0) & (labels < num_classes)
    # Encode every (label, pred) pair as one flat index so that a single
    # bincount call fills the whole matrix.
    flat_index = num_classes * labels[valid].astype(int) + preds[valid]
    bin_count = np.bincount(flat_index, minlength=num_classes**2)
    # bincount returns a longer array when a flat index exceeds
    # num_classes**2 - 1; the slice keeps the reshape valid in that case.
    return bin_count[:num_classes**2].reshape(num_classes, num_classes)


def per_class_iou(hist):
    """Compute the per class IoU.

    Args:
        hist (np.ndarray): Overall confusion matrix of shape
            (num_classes, num_classes).

    Returns:
        np.ndarray: Per class IoU, i.e. ``diag / (row_sum + col_sum - diag)``.
        A class with an empty row and column yields NaN (0 / 0).
    """
    true_positive = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - true_positive
    return true_positive / union


def get_acc(hist):
    """Compute the overall accuracy.

    Args:
        hist (np.ndarray): Overall confusion matrix of shape
            (num_classes, num_classes).

    Returns:
        float: Overall accuracy, the diagonal sum divided by the total count.
    """
    return np.diag(hist).sum() / hist.sum()


def get_acc_cls(hist):
    """Compute the class average accuracy.

    Args:
        hist (np.ndarray): Overall confusion matrix of shape
            (num_classes, num_classes).

    Returns:
        float: Mean of the per class accuracies (diagonal divided by the
        row sum); classes whose accuracy is NaN are left out of the mean.
    """
    return np.nanmean(np.diag(hist) / hist.sum(axis=1))


def seg_eval(gt_labels, seg_preds, label2cat, ignore_index, logger=None):
    """Semantic Segmentation Evaluation.

    Evaluate the result of the Semantic Segmentation: accumulate one
    confusion matrix over all samples, derive per class IoU, mIoU, overall
    accuracy and class average accuracy, and log them as an ASCII table.

    Args:
        gt_labels (list[np.ndarray]): Ground truth labels, one array per
            sample. Each item must support ``.astype``.
        seg_preds (list[np.ndarray]): Predictions, one array per sample,
            aligned with ``gt_labels``.
        label2cat (dict): Map from label to category name. It is indexed
            with ``0 .. len(label2cat) - 1``.
        ignore_index (int): Label that will be ignored in evaluation. Points
            carrying it are excluded from the confusion matrix. If it is
            also a valid class index, that class is reported as NaN and left
            out of the mIoU.
        logger (logging.Logger | str, optional): Forwarded to
            ``mmengine.logging.print_log`` to print the result table.
            Default: None.

    Returns:
        dict[str, float]: Per class IoU keyed by category name, plus the
        keys ``'miou'``, ``'acc'`` and ``'acc_cls'``.
    """
    assert len(seg_preds) == len(gt_labels)
    num_classes = len(label2cat)

    # Accumulate into a zero matrix instead of sum(list): the total is
    # computed once and stays a 2-D array even when no samples are given.
    total_hist = np.zeros((num_classes, num_classes), dtype=np.int64)
    for gt_item, pred_item in zip(gt_labels, seg_preds):
        # astype returns copies, so the caller's arrays are not modified.
        gt_seg = gt_item.astype(np.int64)
        pred_seg = pred_item.astype(np.int64)

        # filter out ignored points: label -1 is dropped by fast_hist
        ignored = gt_seg == ignore_index
        pred_seg[ignored] = -1
        gt_seg[ignored] = -1

        # calculate one instance result
        total_hist += fast_hist(pred_seg, gt_seg, num_classes)

    iou = per_class_iou(total_hist)
    # If ignore_index is itself a class index, replace its IoU with nan.
    # The lower bound matters: a negative ignore_index (e.g. -1) would
    # otherwise wrap around and erase the IoU of a real class.
    if 0 <= ignore_index < len(iou):
        iou[ignore_index] = np.nan

    miou = np.nanmean(iou)
    acc = get_acc(total_hist)
    acc_cls = get_acc_cls(total_hist)

    class_names = [label2cat[i] for i in range(num_classes)]

    ret_dict = dict()
    for i, class_name in enumerate(class_names):
        ret_dict[class_name] = float(iou[i])
    ret_dict['miou'] = float(miou)
    ret_dict['acc'] = float(acc)
    ret_dict['acc_cls'] = float(acc_cls)

    header = ['classes'] + class_names + ['miou', 'acc', 'acc_cls']
    results_row = ['results']
    results_row.extend(f'{iou[i]:.4f}' for i in range(num_classes))
    results_row.extend([f'{miou:.4f}', f'{acc:.4f}', f'{acc_cls:.4f}'])

    table = AsciiTable([header, results_row])
    table.inner_footing_row_border = True
    print_log('\n' + table.table, logger=logger)

    return ret_dict