File size: 2,026 Bytes
d28af7f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import numpy as np
from . import metric as metric_path
from . import predictor as predictor_path
class Evaluator(object):
    """Perform evaluation on a single (downstream) task.

    Supports both offline evaluation (``__call__``: score prediction files
    that were already written to ``predictor.pred_dir``) and online
    evaluation (``evaluate``: run the predictor on a model, then score the
    outputs it returns).

    TODO(huxu): save evaluation results.
    """

    def __init__(self, config, eval_dataloader=None):
        """Instantiate the metric and predictor named by ``config``.

        Args:
            config: object exposing ``metric`` and ``predictor`` attributes,
                each the name of a class looked up in the ``metric`` /
                ``predictor`` modules of this package. The whole config is
                passed to both constructors.
            eval_dataloader: default dataloader used by ``evaluate`` when no
                dataloader is passed explicitly.

        Raises:
            ValueError: if ``config.metric`` or ``config.predictor`` is None.
        """
        if config.metric is None:
            # Build one message string; ValueError("...", x) would render
            # its message as a tuple.
            raise ValueError("config.metric is {}".format(config.metric))
        metric_cls = getattr(metric_path, config.metric)
        self.metric = metric_cls(config)
        if config.predictor is None:
            raise ValueError(
                "config.predictor is {}".format(config.predictor))
        predictor_cls = getattr(predictor_path, config.predictor)
        self.predictor = predictor_cls(config)
        self.eval_dataloader = eval_dataloader

    def __call__(self):
        """Score prediction files already saved under ``predictor.pred_dir``.

        Each ``*_merged.npy`` file (in sorted filename order) is loaded with
        ``np.load``, scored, and passed to ``metric.print_computed_metrics``.
        Then ``merged.npy`` is loaded and scored.

        Returns:
            ``{"results": results, "metric": self.metric}`` where ``results``
            are the metrics computed on ``merged.npy``, or ``{}`` when a
            ``FileNotFoundError`` is raised inside the evaluation (e.g.
            ``merged.npy`` does not exist).
        """
        pred_dir = self.predictor.pred_dir
        try:
            print(pred_dir)
            # Escape the directory so glob metacharacters in the path are
            # matched literally; sort so per-file metrics print in a
            # deterministic order (glob's order is arbitrary).
            pattern = os.path.join(glob.escape(pred_dir), "*_merged.npy")
            for pred_file in sorted(glob.glob(pattern)):
                outputs = np.load(pred_file)
                results = self.metric.compute_metrics(outputs)
                self.metric.print_computed_metrics(results)
            outputs = np.load(os.path.join(pred_dir, "merged.npy"))
            results = self.metric.compute_metrics(outputs)
            return {"results": results, "metric": self.metric}
        except FileNotFoundError:
            print("\n[missing]", pred_dir)
            return {}

    def evaluate(self, model, eval_dataloader=None, output_file="merged"):
        """Run the predictor on ``model`` and score the outputs it returns.

        Args:
            model: model handed to ``predictor.predict_loop``.
            eval_dataloader: dataloader to evaluate on; falls back to the one
                given at construction when None.
            output_file: forwarded unchanged to ``predictor.predict_loop``.

        Returns:
            Whatever ``metric.compute_metrics`` returns.
        """
        if eval_dataloader is None:
            eval_dataloader = self.eval_dataloader
        outputs = self.predictor.predict_loop(
            model, eval_dataloader, output_file)
        # predict_loop's result is unpacked as keyword arguments (so it is
        # presumably a mapping), whereas __call__ passes loaded arrays
        # positionally.
        results = self.metric.compute_metrics(**outputs)
        return results
|