```python
from transformers import AutoTokenizer, AutoFeatureExtractor, SpeechEncoderDecoderModel
from datasets import load_dataset

encoder_id = "facebook/wav2vec2-base-960h"  # acoustic model encoder
decoder_id = "google-bert/bert-base-uncased"  # text decoder

feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
tokenizer = AutoTokenizer.from_pretrained(decoder_id)

# combine pre-trained encoder and pre-trained decoder to form a Seq2Seq model
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id)

model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# load an audio input and pre-process (normalise mean/std to 0/1)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values

# load its corresponding transcription and tokenize to generate labels
labels = tokenizer(ds[0]["text"], return_tensors="pt").input_ids

# the forward function automatically creates the correct decoder_input_ids
loss = model(input_values=input_values, labels=labels).loss
loss.backward()
```
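Once fine-tuned, the same objects can be reused for inference via [`~GenerationMixin.generate`]; a minimal sketch with greedy decoding (shown for illustration only, since a freshly combined model has randomly initialized cross-attention weights and needs training before its transcriptions are meaningful):

```python
# transcribe the pre-processed audio with greedy decoding
generated_ids = model.generate(input_values)
transcription = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcription)
```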
## SpeechEncoderDecoderConfig

[[autodoc]] SpeechEncoderDecoderConfig
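A joint configuration can also be composed directly from an encoder and a decoder configuration; a minimal sketch using default [`Wav2Vec2Config`] and [`BertConfig`] objects (chosen here purely for illustration):

```python
from transformers import BertConfig, SpeechEncoderDecoderConfig, Wav2Vec2Config

# compose a joint seq2seq configuration from the two sub-model configurations
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(Wav2Vec2Config(), BertConfig())
```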
## SpeechEncoderDecoderModel

[[autodoc]] SpeechEncoderDecoderModel
    - forward
    - from_encoder_decoder_pretrained
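As with any model in the library, a warm-started [`SpeechEncoderDecoderModel`] can be persisted with [`~PreTrainedModel.save_pretrained`] and restored as a single checkpoint with [`~PreTrainedModel.from_pretrained`]; a minimal sketch (the local path is just a placeholder):

```python
# persist the combined checkpoint and reload it as one model
model.save_pretrained("./wav2vec2-bert")
model = SpeechEncoderDecoderModel.from_pretrained("./wav2vec2-bert")
```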
## FlaxSpeechEncoderDecoderModel

[[autodoc]] FlaxSpeechEncoderDecoderModel
    - __call__
    - from_encoder_decoder_pretrained
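The Flax class exposes the same warm-starting helper as its PyTorch counterpart; a minimal sketch (the checkpoints mirror the example above and this assumes Flax weights are available for both models):

```python
from transformers import FlaxSpeechEncoderDecoderModel

# initialize a Flax speech seq2seq model from pre-trained encoder and decoder checkpoints
model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h", "google-bert/bert-base-uncased"
)
```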