gencent / app /utils /speech_to_text.py
hjaved202's picture
Upload folder using huggingface_hub
a350173 verified
raw
history blame
3.8 kB
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import numpy as np
import torch
import librosa
class ASRConfig:
    """Plain settings holder for the ASR (speech-to-text) components.

    Attributes:
        model_id: Checkpoint identifier handed to ``from_pretrained`` when the
            recognizer loads its model and processor.
        language: Language forwarded to generation; skipped when falsy.
        sampling_rate: Rate (Hz) incoming audio is resampled to when its own
            rate differs.
        device: Device string the model is moved to.
        torch_dtype: Tensor dtype used when loading the model.

    The ``device`` and ``torch_dtype`` defaults are computed once, when the
    class is defined, from ``torch.cuda.is_available()``.
    """

    def __init__(
        self,
        model_id="openai/whisper-large-v2",
        language="english",
        sampling_rate=16000,
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ):
        # Which checkpoint to load and which language to decode.
        self.model_id, self.language = model_id, language
        # Target rate for audio fed to the model.
        self.sampling_rate = sampling_rate
        # Where the model runs and at what precision.
        self.device, self.torch_dtype = device, torch_dtype
class SpeechRecognizer:
    """Speech-to-text wrapper around a ``transformers`` Whisper ASR pipeline.

    Constructing an instance loads the model and processor named by the
    configuration and combines them into an ``automatic-speech-recognition``
    pipeline. :meth:`transcribe` then normalises ``(sample_rate, audio_array)``
    input and runs it through that pipeline.
    """

    def __init__(self, config: "ASRConfig | None" = None):
        """Store the configuration and load the model.

        Args:
            config: ASR settings; a default ``ASRConfig()`` is created when
                omitted.

        Raises:
            RuntimeError: If the model, processor or pipeline cannot be set up.
        """
        self.config = config if config is not None else ASRConfig()
        print(f"Using ASR configuration: {self.config.__dict__}")
        self._setup_model()

    def _setup_model(self):
        """Initialize the Whisper model, its processor and the ASR pipeline.

        Sets ``self.model``, ``self.processor`` and ``self.pipe``.

        Raises:
            RuntimeError: Wrapping any exception raised while loading; the
                original exception is attached as ``__cause__``.
        """
        try:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                self.config.model_id,
                torch_dtype=self.config.torch_dtype,
                use_safetensors=True,
            ).to(self.config.device)
            self.processor = AutoProcessor.from_pretrained(self.config.model_id)
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                torch_dtype=self.config.torch_dtype,
                device=self.config.device,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to set up Whisper model: {e}") from e

    @staticmethod
    def _to_float32(audio_array: np.ndarray) -> np.ndarray:
        """Return *audio_array* as float32 samples.

        Integer arrays are scaled by the maximum value of their dtype; arrays
        of any other dtype (e.g. float64) are only cast, because ``np.iinfo``
        rejects non-integer dtypes.
        """
        if audio_array.dtype == np.float32:
            return audio_array
        if np.issubdtype(audio_array.dtype, np.integer):
            return audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
        return audio_array.astype(np.float32)

    async def transcribe(self, audio: tuple, prompt: str = None) -> str:
        """Transcribe the provided audio using the Whisper pipeline.

        Note: although declared ``async``, this method contains no ``await``;
        preprocessing and inference run synchronously in the calling event
        loop.

        Args:
            audio: A ``(sample_rate, audio_array)`` pair where ``audio_array``
                is a ``numpy.ndarray``.
            prompt: Optional text used to guide transcription. When given it
                is converted with ``processor.get_prompt_ids`` and passed to
                generation as ``prompt_ids``.

        Returns:
            The pipeline's ``"text"`` output with surrounding whitespace
            stripped.

        Raises:
            ValueError: If ``audio`` is falsy or does not have two elements.
            RuntimeError: For any failure during preprocessing or inference,
                including a non-ndarray ``audio_array`` (the ``TypeError`` is
                wrapped). The original exception is attached as ``__cause__``.
        """
        if not audio or len(audio) != 2:
            raise ValueError("Invalid audio input. Expected a tuple (sample_rate, audio_array).")

        try:
            sample_rate, audio_array = audio

            if not isinstance(audio_array, np.ndarray):
                raise TypeError(f"Expected numpy.ndarray for audio data, got {type(audio_array)}")

            # NOTE(review): multi-channel arrays are not downmixed here —
            # presumably callers supply mono audio; confirm.
            audio_array = self._to_float32(audio_array)

            # Bring the audio to the rate the recognizer is configured for.
            if sample_rate != self.config.sampling_rate:
                audio_array = librosa.resample(
                    audio_array,
                    orig_sr=sample_rate,
                    target_sr=self.config.sampling_rate,
                )

            generate_kwargs = {}
            if self.config.language:
                generate_kwargs["language"] = self.config.language
            if prompt:
                prompt_ids = self.processor.get_prompt_ids(prompt, return_tensors="pt").to(self.config.device)
                generate_kwargs["prompt_ids"] = prompt_ids

            result = self.pipe(audio_array, generate_kwargs=generate_kwargs)
            return result["text"].strip()
        except Exception as e:
            raise RuntimeError(f"Transcription failed: {e}") from e