File size: 454 Bytes
5fa1a76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
from evaluate import load
import torch

# Word-error-rate metric object, loaded once at import time and reused on
# every evaluation pass.
wer = load("wer")


def compute_metrics(eval_pred):
    """Return the word error rate of the model's predictions.

    Args:
        eval_pred: A ``(logits, labels)`` pair. ``logits`` must support
            ``.argmax(-1)``; ``labels`` holds the reference token ids.

    Returns:
        A dict with the single key ``"wer_score"`` holding the value
        returned by ``wer.compute``.

    Relies on a module-level ``processor`` (defined outside this block)
    that provides ``batch_decode``.
    """
    logits, labels = eval_pred
    # Pick the highest-scoring token id at each position.
    predicted = logits.argmax(-1)

    # NOTE(review): assumes padded label positions are marked with -100
    # (the usual Hugging Face collator convention) -- confirm against the
    # data collator. -100 is not a real token id, so it is mapped to the
    # pad token before decoding; skip_special_tokens=True then drops it.
    # This is a no-op when the labels contain no -100 entries.
    tokenizer = getattr(processor, "tokenizer", processor)
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is not None:
        labels = torch.as_tensor(labels)
        # masked_fill returns a new tensor, so the caller's labels are
        # left unmodified.
        labels = labels.masked_fill(labels == -100, pad_token_id)

    decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
    decoded_predictions = processor.batch_decode(
        predicted, skip_special_tokens=True
    )

    wer_score = wer.compute(
        predictions=decoded_predictions, references=decoded_labels
    )
    return {"wer_score": wer_score}


# Train!