PyTorch
ssl-aasist
custom_code
File size: 2,026 Bytes
d28af7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import numpy as np

from . import metric as metric_path
from . import predictor as predictor_path


class Evaluator(object):
    """
    Perform evaluation on a single (downstream) task, offline or online.

    Offline (``__call__``): re-score prediction files that the predictor has
    already written under ``self.predictor.pred_dir``.
    Online (``evaluate``): run the predictor over a dataloader with a model
    and score the outputs it returns.

    TODO(huxu) saving evaluation results.
    """

    def __init__(self, config, eval_dataloader=None):
        """
        Build the metric and predictor named by ``config``.

        Args:
            config: object exposing ``metric`` and ``predictor`` attributes,
                each naming a class in the sibling ``metric`` / ``predictor``
                modules. ``config`` itself is passed to both constructors.
            eval_dataloader: default dataloader used by ``evaluate`` when no
                dataloader is passed explicitly.

        Raises:
            ValueError: if ``config.metric`` or ``config.predictor`` is None.
            AttributeError: if a configured name does not exist in its module.
        """
        # Validate both fields up front so a bad config fails before any
        # metric/predictor object is constructed.
        if config.metric is None:
            raise ValueError(
                f"config.metric must be set, got {config.metric!r}")
        if config.predictor is None:
            raise ValueError(
                f"config.predictor must be set, got {config.predictor!r}")
        metric_cls = getattr(metric_path, config.metric)
        self.metric = metric_cls(config)
        predictor_cls = getattr(predictor_path, config.predictor)
        self.predictor = predictor_cls(config)
        self.eval_dataloader = eval_dataloader

    def __call__(self):
        """
        Score prediction files already on disk in ``self.predictor.pred_dir``.

        Every ``*_merged.npy`` file in that directory is loaded, scored with
        ``self.metric.compute_metrics`` and printed with
        ``self.metric.print_computed_metrics``. Then ``merged.npy`` is loaded
        and scored.

        Returns:
            ``{"results": <metrics for merged.npy>, "metric": self.metric}``,
            or ``{}`` if a prediction file is missing (FileNotFoundError).
        """
        pred_dir = self.predictor.pred_dir
        # Escape the directory so glob metacharacters in its name (e.g. "[")
        # are matched literally; only the file-name part is a pattern.
        pattern = os.path.join(glob.escape(pred_dir), "*_merged.npy")
        try:
            print(pred_dir)
            # glob returns matches in arbitrary order; sort for stable output.
            for pred_file in sorted(glob.glob(pattern)):
                outputs = np.load(pred_file)
                results = self.metric.compute_metrics(outputs)
                self.metric.print_computed_metrics(results)

            outputs = np.load(os.path.join(pred_dir, "merged.npy"))
            results = self.metric.compute_metrics(outputs)
            return {"results": results, "metric": self.metric}
        except FileNotFoundError:
            print("\n[missing]", pred_dir)
            return {}

    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
        """
        Run prediction with ``model`` and score the result.

        Args:
            model: passed through to ``self.predictor.predict_loop``.
            eval_dataloader: dataloader to predict on; falls back to the one
                given at construction time when None.
            output_file: passed through to ``predict_loop`` (presumably the
                output file stem -- confirm against the predictor).

        Returns:
            Whatever ``self.metric.compute_metrics`` returns.
        """
        if eval_dataloader is None:
            eval_dataloader = self.eval_dataloader
        outputs = self.predictor.predict_loop(
            model, eval_dataloader, output_file)
        # predict_loop's result is unpacked as keyword arguments, so it is
        # assumed to be a mapping whose keys match compute_metrics' params.
        results = self.metric.compute_metrics(**outputs)
        return results