import numpy as np
from mmengine.logging import print_log
from terminaltables import AsciiTable

from .scannet_utils.evaluate_semantic_instance import scannet_eval


def aggregate_predictions(masks, labels, scores, valid_class_ids):
    """Maps predictions to ScanNet evaluator format.

    Args:
        masks (list[torch.Tensor]): Per scene predicted instance masks.
        labels (list[torch.Tensor]): Per scene predicted instance labels.
        scores (list[torch.Tensor]): Per scene predicted instance scores.
        valid_class_ids (tuple[int]): Ids of valid categories.

    Returns:
        list[dict]: Per scene aggregated predictions.
    """
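    # For illustration only (values below are made up): a scene with two
    # predicted instances is aggregated into a dict of the form
    #   {'0_0': {'mask': <0/1 np.int64 array>, 'label_id': 3, 'conf': 0.9},
    #    '0_1': {'mask': <0/1 np.int64 array>, 'label_id': 5, 'conf': 0.7}},
    # keyed by '<scene index>_<instance index>', which is the per instance
    # format expected by the ScanNet evaluator.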
    infos = []
    for id, (mask, label, score) in enumerate(zip(masks, labels, scores)):
        mask = mask.clone().numpy()
        label = label.clone().numpy()
        score = score.clone().numpy()
        info = dict()
        n_instances = mask.max() + 1
        for i in range(n_instances):
            file_name = f'{id}_{i}'
            info[file_name] = dict()
            info[file_name]['mask'] = (mask == i).astype(np.int64)
            info[file_name]['label_id'] = valid_class_ids[label[i]]
            info[file_name]['conf'] = score[i]
        infos.append(info)
    return infos


def rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids):
    """Maps gt instance and semantic masks to instance masks for ScanNet
    evaluator.

    Args:
        gt_semantic_masks (list[torch.Tensor]): Per scene gt semantic masks.
        gt_instance_masks (list[torch.Tensor]): Per scene gt instance masks.
        valid_class_ids (tuple[int]): Ids of valid categories.

    Returns:
        list[np.ndarray]: Per scene instance masks.
    """
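    # ScanNet-style id encoding: each labelled gt point receives the value
    # `class_id * 1000 + instance_id`, e.g. instance 2 of raw class id 3
    # becomes 3 * 1000 + 2 = 3002 (numbers are only illustrative), so one
    # integer carries both the semantic class and the instance index.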
    renamed_instance_masks = []
    for semantic_mask, instance_mask in zip(gt_semantic_masks,
                                            gt_instance_masks):
        semantic_mask = semantic_mask.clone().numpy()
        instance_mask = instance_mask.clone().numpy()
        unique = np.unique(instance_mask)
        assert len(unique) < 1000
        for i in unique:
            semantic_instance = semantic_mask[instance_mask == i]
            semantic_unique = np.unique(semantic_instance)
            assert len(semantic_unique) == 1
            if semantic_unique[0] < len(valid_class_ids):
                instance_mask[
                    instance_mask ==
                    i] = 1000 * valid_class_ids[semantic_unique[0]] + i
        renamed_instance_masks.append(instance_mask)
    return renamed_instance_masks


def instance_seg_eval(gt_semantic_masks,
                      gt_instance_masks,
                      pred_instance_masks,
                      pred_instance_labels,
                      pred_instance_scores,
                      valid_class_ids,
                      class_labels,
                      options=None,
                      logger=None):
    """Instance Segmentation Evaluation.

    Evaluate the result of the instance segmentation.

    Args:
        gt_semantic_masks (list[torch.Tensor]): Ground truth semantic masks.
        gt_instance_masks (list[torch.Tensor]): Ground truth instance masks.
        pred_instance_masks (list[torch.Tensor]): Predicted instance masks.
        pred_instance_labels (list[torch.Tensor]): Predicted instance labels.
        pred_instance_scores (list[torch.Tensor]): Predicted instance scores.
        valid_class_ids (tuple[int]): Ids of valid categories.
        class_labels (tuple[str]): Names of valid categories.
        options (dict, optional): Additional options. Keys may contain:
            `overlaps`, `min_region_sizes`, `distance_threshes`,
            `distance_confs`. Default: None.
        logger (logging.Logger | str, optional): The way to print the mAP
            summary. See `mmengine.logging.print_log()` for details.
            Default: None.

    Returns:
        dict[str, float]: Dict of results.
    """
    assert len(valid_class_ids) == len(class_labels)
    id_to_label = {
        valid_class_ids[i]: class_labels[i]
        for i in range(len(valid_class_ids))
    }
    preds = aggregate_predictions(
        masks=pred_instance_masks,
        labels=pred_instance_labels,
        scores=pred_instance_scores,
        valid_class_ids=valid_class_ids)
    gts = rename_gt(gt_semantic_masks, gt_instance_masks, valid_class_ids)
    metrics = scannet_eval(
        preds=preds,
        gts=gts,
        options=options,
        valid_class_ids=valid_class_ids,
        class_labels=class_labels,
        id_to_label=id_to_label)
    # Print a per class AP_0.25 / AP_0.50 / AP table before returning.
    header = ['classes', 'AP_0.25', 'AP_0.50', 'AP']
    rows = []
    for label, data in metrics['classes'].items():
        aps = [data['ap25%'], data['ap50%'], data['ap']]
        rows.append([label] + [f'{ap:.4f}' for ap in aps])
    aps = metrics['all_ap_25%'], metrics['all_ap_50%'], metrics['all_ap']
    footer = ['Overall'] + [f'{ap:.4f}' for ap in aps]
    table = AsciiTable([header] + rows + [footer])
    table.inner_footing_row_border = True
    print_log('\n' + table.table, logger=logger)
    return metrics
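

if __name__ == '__main__':
    # A minimal synthetic sanity check, not part of the public API: the class
    # ids and names below are made up for illustration and the expected
    # result is only indicative. Run the module with `python -m ...` so that
    # the relative import above resolves.
    import torch

    # One scene with 300 points and two gt instances: points 0-149 carry
    # semantic label 0 (raw class id 3), points 150-299 label 1 (id 4).
    gt_semantic = torch.cat([
        torch.zeros(150, dtype=torch.long),
        torch.ones(150, dtype=torch.long)
    ])
    gt_instance = gt_semantic.clone()
    # Perfect predictions: the same two instances with confident scores.
    pred_mask = gt_instance.clone()
    pred_labels = torch.tensor([0, 1])
    pred_scores = torch.tensor([0.9, 0.8])

    metrics = instance_seg_eval(
        gt_semantic_masks=[gt_semantic],
        gt_instance_masks=[gt_instance],
        pred_instance_masks=[pred_mask],
        pred_instance_labels=[pred_labels],
        pred_instance_scores=[pred_scores],
        valid_class_ids=(3, 4),
        class_labels=('cabinet', 'bed'))
    # With perfect predictions every AP should be close to 1.0.
    print(metrics['all_ap'], metrics['all_ap_50%'], metrics['all_ap_25%'])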