from typing import Optional

import librosa
import numpy as np
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

class ASRConfig:
    """Configuration class for ASR transcription."""
    def __init__(
        self,
        model_id="openai/whisper-large-v2",
        language="english",
        sampling_rate=16000,
        device="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    ):
        self.model_id = model_id
        self.language = language
        self.sampling_rate = sampling_rate
        self.device = device
        self.torch_dtype = torch_dtype

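# Example (a minimal sketch, not part of the original code): ASRConfig accepts
# any Hugging Face Whisper checkpoint via model_id; smaller checkpoints such
# as "openai/whisper-small" trade accuracy for speed and memory, e.g.:
#
#     config = ASRConfig(model_id="openai/whisper-small", language="french")
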
class SpeechRecognizer:
    def __init__(self, config: Optional[ASRConfig] = None):
        self.config = config if config is not None else ASRConfig()
        print(f"Using ASR configuration: {self.config.__dict__}")
        self._setup_model()

    def _setup_model(self):
        """Initialize Whisper model and processor."""
        try:
            self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                self.config.model_id,
                torch_dtype=self.config.torch_dtype,
                use_safetensors=True,
            ).to(self.config.device)

            self.processor = AutoProcessor.from_pretrained(self.config.model_id)
            self.pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                torch_dtype=self.config.torch_dtype,
                device=self.config.device,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to set up Whisper model: {e}") from e

    async def transcribe(self, audio: tuple, prompt: Optional[str] = None) -> str:
        """
        Transcribes the provided audio using the Whisper pipeline.

        Note: the underlying pipeline call is blocking; despite the async
        signature, it will stall the event loop for long inputs unless
        offloaded to an executor.

        Args:
            audio (tuple): A tuple containing (sample_rate, audio_array).
            prompt (Optional[str]): An optional text prompt to guide transcription.

        Returns:
            str: Transcription of the audio.
        """
        if not audio or len(audio) != 2:
            raise ValueError("Invalid audio input. Expected a tuple (sample_rate, audio_array).")

        try:
            # Extract the raw audio data (audio_array) from the input tuple
            sample_rate, audio_array = audio

            # Ensure the audio is a numpy array and has the expected format
            if not isinstance(audio_array, np.ndarray):
                raise TypeError(f"Expected numpy.ndarray for audio data, got {type(audio_array)}")

            # Convert to float32 in [-1, 1]. np.iinfo only applies to integer
            # dtypes, so non-float32 float inputs are cast directly instead.
            if np.issubdtype(audio_array.dtype, np.integer):
                audio_array = audio_array.astype(np.float32) / np.iinfo(audio_array.dtype).max
            elif audio_array.dtype != np.float32:
                audio_array = audio_array.astype(np.float32)

            # Resample audio if the sample rate differs from the configured rate
            if sample_rate != self.config.sampling_rate:
                audio_array = librosa.resample(
                    audio_array, orig_sr=sample_rate, target_sr=self.config.sampling_rate
                )

            # Prepare generate_kwargs for the pipeline
            generate_kwargs = {}
            if self.config.language:
                generate_kwargs["language"] = self.config.language
            if prompt:
                prompt_ids = self.processor.get_prompt_ids(prompt, return_tensors="pt").to(self.config.device)
                generate_kwargs["prompt_ids"] = prompt_ids

            # Run transcription through the pipeline
            result = self.pipe(audio_array, generate_kwargs=generate_kwargs)
            return result["text"].strip()
        except Exception as e:
            raise RuntimeError(f"Transcription failed: {e}") from e