File size: 2,008 Bytes
5fa1a76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
Your predictions need to be converted to logits first, and then reshaped to match the size of the labels before you can call [`~evaluate.EvaluationModule.compute`]:

import numpy as np
import torch
from torch import nn
@torch.no_grad()
def compute_metrics(eval_pred):
    """Score a batch of segmentation logits against reference label maps.

    Args:
        eval_pred: A two-item sequence unpacked as ``(logits, labels)``.
            ``logits`` must be a NumPy array (it is handed to
            ``torch.from_numpy``); ``labels`` must expose ``.shape`` with at
            least two trailing dimensions, which are used as the resize
            target.  Presumably logits are ``(batch, classes, h, w)`` and
            labels ``(batch, H, W)`` -- confirm against the caller.

    Returns:
        The mapping produced by ``metric.compute``, with every NumPy-array
        value replaced by a plain Python list.
    """
    raw_logits, references = eval_pred

    # Bilinearly resize the logits to the labels' trailing two dimensions,
    # then pick the highest-scoring index along dim 1.
    upsampled = nn.functional.interpolate(
        torch.from_numpy(raw_logits),
        size=references.shape[-2:],
        mode="bilinear",
        align_corners=False,
    )
    predicted = upsampled.argmax(dim=1).detach().cpu().numpy()

    scores = metric.compute(
        predictions=predicted,
        references=references,
        num_labels=num_labels,
        ignore_index=255,
        reduce_labels=False,
    )

    # Swap array-valued entries for lists in place; only existing keys are
    # reassigned, so the mapping's key order is unchanged.
    scores.update(
        {
            name: score.tolist()
            for name, score in scores.items()
            if isinstance(score, np.ndarray)
        }
    )
    return scores

def compute_metrics(eval_pred):
    """Score a batch of segmentation logits with TensorFlow ops.

    Args:
        eval_pred: A two-item sequence unpacked as ``(logits, labels)``.
            ``logits`` is transposed with ``perm=[0, 2, 3, 1]`` (presumably
            channels-first to channels-last -- confirm against the model
            output) and resized to ``tf.shape(labels)[1:]``.

    Returns:
        A new dict holding the values from ``metric.compute`` with every key
        prefixed by ``"val_"``.  The ``per_category_accuracy`` and
        ``per_category_iou`` entries are replaced by one entry per position,
        named ``accuracy_<label>`` / ``iou_<label>`` via ``id2label``.
    """
    raw_logits, references = eval_pred

    channels_last = tf.transpose(raw_logits, perm=[0, 2, 3, 1])
    target_size = tf.shape(references)[1:]
    upsampled = tf.image.resize(channels_last, size=target_size, method="bilinear")
    predicted = tf.argmax(upsampled, axis=-1)

    results = metric.compute(
        predictions=predicted,
        references=references,
        num_labels=num_labels,
        ignore_index=-1,
        reduce_labels=image_processor.do_reduce_labels,
    )

    # Pull the per-category values out and re-insert them as flat,
    # label-named entries.
    category_accuracy = results.pop("per_category_accuracy").tolist()
    category_iou = results.pop("per_category_iou").tolist()
    for i, accuracy in enumerate(category_accuracy):
        results[f"accuracy_{id2label[i]}"] = accuracy
    for i, iou in enumerate(category_iou):
        results[f"iou_{id2label[i]}"] = iou

    prefixed = {}
    for key, value in results.items():
        prefixed["val_" + key] = value
    return prefixed

Your `compute_metrics` function is ready to go now, and you'll return to it when you set up your training.