```python
import numpy as np
import torch
from evaluate import load
# Module-level handle to the "wer" metric from the `evaluate` library; its
# .compute(...) method is called by compute_metrics.
wer = load(path="wer")
def compute_metrics(eval_pred):
    """Compute the word error rate (WER) for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` may also be a
            tuple whose first element holds the logits (some models return
            extra outputs alongside them). ``labels`` holds token ids.
            Positions equal to -100 are treated as padding. -100 is assumed
            to be the ignore index used when labels were padded; confirm
            against the data collator.

    Returns:
        dict: ``{"wer_score": <score>}``, where the score comes from the
        module-level ``wer`` metric.

    Relies on the module-level ``processor`` (defined elsewhere) for
    ``batch_decode`` and the pad token id, and on the module-level ``wer``.
    """
    logits, labels = eval_pred
    # Models that return several outputs hand them over as a tuple; the
    # logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predicted = np.argmax(logits, axis=-1)

    # -100 is not a real token id, so batch_decode cannot handle it. Map it to
    # the pad token, which skip_special_tokens=True then drops. np.where
    # builds a new array, so the caller's labels are left untouched.
    tokenizer = getattr(processor, "tokenizer", processor)
    labels = np.asarray(labels)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    # NOTE(review): if `processor` is a CTC (e.g. Wav2Vec2) processor, labels
    # are usually decoded with group_tokens=False so genuine repeated
    # characters are not collapsed. Confirm the processor type.
    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(predicted, skip_special_tokens=True)
    wer_score = wer.compute(predictions=decoded_predictions, references=decoded_labels)
    return {"wer_score": wer_score}
```

Train!