from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import numpy as np
import torch
import librosa


class ASRConfig:
    """Configuration class for ASR transcription."""

    def __init__(
        self,
        model_id="openai/whisper-large-v2",
        language="english",
        sampling_rate=16000,
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ):
        self.model_id = model_id
        self.language = language
        self.sampling_rate = sampling_rate
        self.device = device
        self.torch_dtype = torch_dtype


class SpeechRecognizer:
    """Whisper-based speech recognizer built on the transformers ASR pipeline."""

    def __init__(self, config: ASRConfig = None):
        self.config = config if config else ASRConfig()
        print(f"Using ASR configuration: {self.config.__dict__}")
        self._setup_model()

    def _setup_model(self):
        """Initialize the Whisper model, processor, and ASR pipeline."""
        try:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                self.config.model_id,
                torch_dtype=self.config.torch_dtype,
                use_safetensors=True,
            ).to(self.config.device)

            self.processor = AutoProcessor.from_pretrained(self.config.model_id)
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                torch_dtype=self.config.torch_dtype,
                device=self.config.device,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to set up Whisper model: {e}") from e

    async def transcribe(self, audio: tuple, prompt: str = None) -> str:
        """
        Transcribes the provided audio using the Whisper pipeline.

        Args:
            audio (tuple): A tuple containing (sample_rate, audio_array).
            prompt (str): An optional text prompt to guide transcription.

        Returns:
            str: Transcription of the audio.
        """
        if not audio or len(audio) != 2:
            raise ValueError("Invalid audio input. Expected a tuple (sample_rate, audio_array).")

        try:
            sample_rate, audio_array = audio

            if not isinstance(audio_array, np.ndarray):
                raise TypeError(f"Expected numpy.ndarray for audio data, got {type(audio_array)}")

            # Normalize integer PCM samples to float32 in [-1.0, 1.0].
            # np.iinfo is only valid for integer dtypes; other float
            # widths are cast directly.
            if np.issubdtype(audio_array.dtype, np.integer):
                audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
            elif audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Resample to the rate the model expects (16 kHz by default).
            if sample_rate != self.config.sampling_rate:
                audio_array = librosa.resample(
                    audio_array, orig_sr=sample_rate, target_sr=self.config.sampling_rate
                )

            generate_kwargs = {}
            if self.config.language:
                generate_kwargs["language"] = self.config.language
            if prompt:
                prompt_ids = self.processor.get_prompt_ids(prompt, return_tensors="pt").to(self.config.device)
                generate_kwargs["prompt_ids"] = prompt_ids

            result = self.pipe(audio_array, generate_kwargs=generate_kwargs)
            return result["text"].strip()
        except Exception as e:
            raise RuntimeError(f"Transcription failed: {e}") from e
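
# Usage sketch (not part of the class definitions above): a minimal example
# of driving SpeechRecognizer end to end. "sample.wav" is a hypothetical
# path; librosa.load (already a dependency above) returns float32 audio plus
# its native sample rate, matching the tuple that transcribe() expects.
if __name__ == "__main__":
    import asyncio

    async def main():
        recognizer = SpeechRecognizer(ASRConfig(language="english"))
        # sr=None keeps the file's native sample rate; transcribe()
        # resamples to the configured 16 kHz itself when needed.
        audio_array, sample_rate = librosa.load("sample.wav", sr=None)
        text = await recognizer.transcribe((sample_rate, audio_array))
        print(text)

    asyncio.run(main())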