Gokulavelan committed on
Commit ea5aa75 · 1 Parent(s): 33d067f

files push

Files changed (7)
  1. AudioHandler.py +100 -0
  2. app copy.py +7 -0
  3. app.py +13 -4
  4. config.py +42 -0
  5. diarization_util.py +141 -0
  6. requirements.txt +10 -0
  7. util.py +41 -0
AudioHandler.py ADDED
@@ -0,0 +1,100 @@
+ import logging
+ import torch
+ import os
+ import base64
+
+ from pyannote.audio import Pipeline
+ from transformers import pipeline, AutoModelForCausalLM
+ from huggingface_hub import HfApi
+ from pydantic import ValidationError
+ from diarization_util import diarize  # needed by run_diarization below
+
+ logger = logging.getLogger(__name__)
+
+
+ class AudioHandler:
+     def __init__(self, model_settings):
+         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+         logger.info(f"Using device: {device.type}")
+         torch_dtype = torch.float32 if device.type == "cpu" else torch.float16
+
+         self.device = device
+         self.torch_dtype = torch_dtype
+
+         # Load assistant model
+         self.assistant_model = (
+             AutoModelForCausalLM.from_pretrained(
+                 model_settings.assistant_model,
+                 torch_dtype=torch_dtype,
+                 low_cpu_mem_usage=True,
+                 use_safetensors=True
+             ).to(device)
+             if model_settings.assistant_model
+             else None
+         )
+
+         # Load ASR pipeline
+         self.asr_pipeline = pipeline(
+             "automatic-speech-recognition",
+             model=model_settings.asr_model,
+             torch_dtype=torch_dtype,
+             device=device
+         )
+
+         # Load diarization pipeline if available
+         if model_settings.diarization_model:
+             HfApi().whoami(model_settings.hf_token)
+             self.diarization_pipeline = Pipeline.from_pretrained(
+                 checkpoint_path=model_settings.diarization_model,
+                 use_auth_token=model_settings.hf_token,
+             ).to(device)
+         else:
+             self.diarization_pipeline = None
+
+     def run_asr(self, file, parameters):
+         """Run Automatic Speech Recognition (ASR)."""
+         generate_kwargs = {
+             "task": parameters.task,
+             "language": parameters.language,
+             "assistant_model": self.assistant_model if parameters.assisted else None
+         }
+         return self.asr_pipeline(
+             file,
+             chunk_length_s=parameters.chunk_length_s,
+             batch_size=parameters.batch_size,
+             generate_kwargs=generate_kwargs,
+             return_timestamps=True,
+         )
+
+     def run_diarization(self, file, parameters, asr_outputs):
+         """Run diarization if a diarization pipeline was loaded."""
+         if not self.diarization_pipeline:
+             return []
+         return diarize(self.diarization_pipeline, file, parameters, asr_outputs)
+
+     def run_inference(self, file: bytes, parameters):
+         """Run the complete inference process."""
+         try:
+             logger.info(f"Inference parameters: {parameters}")
+             asr_outputs = self.run_asr(file, parameters)
+         except RuntimeError as e:
+             logger.error(f"ASR inference error: {str(e)}")
+             raise RuntimeError(f"ASR inference error: {str(e)}")
+         except Exception as e:
+             logger.error(f"Unknown error during ASR inference: {str(e)}")
+             raise RuntimeError(f"Unknown error during ASR inference: {str(e)}")
+
+         try:
+             transcript = self.run_diarization(file, parameters, asr_outputs)
+         except RuntimeError as e:
+             logger.error(f"Diarization inference error: {str(e)}")
+             raise RuntimeError(f"Diarization inference error: {str(e)}")
+         except Exception as e:
+             logger.error(f"Unknown error during diarization: {str(e)}")
+             raise RuntimeError(f"Unknown error during diarization: {str(e)}")
+
+         return {
+             "speakers": transcript,
+             "chunks": asr_outputs["chunks"],
+             "text": asr_outputs["text"],
+         }
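
Note: a minimal usage sketch of the AudioHandler API introduced above (not part of this commit). The sample path is hypothetical, and it assumes the HF_TOKEN environment variable consumed by config.py is set.

# Illustrative sketch, not part of the commit: drive AudioHandler directly.
# Assumes HF_TOKEN is exported and "sample.wav" is a local audio file.
from config import model_settings, InferenceConfig
from AudioHandler import AudioHandler

handler = AudioHandler(model_settings)

with open("sample.wav", "rb") as f:   # hypothetical file path
    audio_bytes = f.read()

params = InferenceConfig(task="transcribe", language="en", batch_size=4)
result = handler.run_inference(audio_bytes, params)

print(result["text"])      # full transcript
print(result["chunks"])    # timestamped ASR chunks
print(result["speakers"])  # speaker-attributed segments ([] if diarization is disabled)
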
app copy.py ADDED
@@ -0,0 +1,7 @@
+ import gradio as gr
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
+ demo.launch()
app.py CHANGED
@@ -1,7 +1,16 @@
  import gradio as gr
+ from util import *

- def greet(name):
-     return "Hello " + name + "!!"
+ # Create Gradio Blocks interface
+ with gr.Blocks() as demo:
+     audio_input = gr.Audio(type="filepath", label="Upload Audio")
+     textbox = gr.Textbox(label="Transcription Output", lines=15, interactive=False)

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+     # Set up the audio file processing and display transcription
+     audio_input.change(
+         fn=process_audio,
+         inputs=audio_input,
+         outputs=textbox
+     )
+ # Launch Gradio app
+ demo.launch()
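
Note: transcription of a long file can take a while, so one option (an assumption, not something this commit does) is to enable Gradio's request queue before launching:

# Hedged sketch, not in the committed app.py: queue incoming requests so
# long transcriptions are processed one at a time instead of timing out.
demo.queue().launch()
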
config.py ADDED
@@ -0,0 +1,42 @@
+ import logging
+ import os
+ from pydantic import BaseModel, Field
+ from pydantic_settings import BaseSettings
+ from typing import Optional, Literal
+
+ logger = logging.getLogger(__name__)
+
+
+ class ModelSettings(BaseSettings):
+     asr_model: str = Field(alias='ASR_MODEL')
+     assistant_model: str = Field(alias='ASSISTANT_MODEL')
+     diarization_model: str = Field(alias='DIARIZATION_MODEL')
+     hf_token: str = Field(alias='HF_TOKEN')
+
+
+ class InferenceConfig(BaseModel):
+     task: Literal["transcribe", "translate"] = "transcribe"
+     batch_size: int = 24
+     assisted: bool = False
+     chunk_length_s: int = 30
+     sampling_rate: int = 16000
+     language: Optional[str] = None
+     num_speakers: Optional[int] = None
+     min_speakers: Optional[int] = None
+     max_speakers: Optional[int] = None
+
+ # Instead of model_dump, create a dictionary with the settings and
+ # pass it to the ModelSettings constructor
+ model_settings_data = {
+     "DIARIZATION_MODEL": "pyannote/speaker-diarization-3.1",
+     "HF_TOKEN": os.environ.get("HF_TOKEN"),
+     "ASR_MODEL": "openai/whisper-large-v3",
+     "ASSISTANT_MODEL": "distil-whisper/distil-large-v3"
+ }
+
+ # Initialize ModelSettings with the dictionary data
+ model_settings = ModelSettings(**model_settings_data)
+
+ logger.info(f"asr model: {model_settings.asr_model}")
+ logger.info(f"assist model: {model_settings.assistant_model}")
+ logger.info(f"diar model: {model_settings.diarization_model}")
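
Note: to show how the optional InferenceConfig fields are meant to be used, here is a small illustrative example (not part of the commit); the parameter values are arbitrary.

# Illustrative only: a translation request that bounds the diarizer's
# speaker search between 2 and 4 speakers.
params = InferenceConfig(
    task="translate",
    language="fr",
    batch_size=8,
    min_speakers=2,
    max_speakers=4,
)
print(params.model_dump())  # pydantic v2: dump the resolved settings as a dict
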
diarization_util.py ADDED
@@ -0,0 +1,141 @@
+ import torch
+ import numpy as np
+ from torchaudio import functional as F
+ from transformers.pipelines.audio_utils import ffmpeg_read
+ from starlette.exceptions import HTTPException
+ import sys
+
+ # Code from insanely-fast-whisper:
+ # https://github.com/Vaibhavs10/insanely-fast-whisper
+
+ import logging
+ logger = logging.getLogger(__name__)
+
+
+ def preprocess_inputs(inputs, sampling_rate):
+     inputs = ffmpeg_read(inputs, sampling_rate)
+
+     if sampling_rate != 16000:
+         inputs = F.resample(
+             torch.from_numpy(inputs), sampling_rate, 16000
+         ).numpy()
+
+     if len(inputs.shape) != 1:
+         logger.error(f"Diarization pipeline expects single channel audio, received {inputs.shape}")
+         raise HTTPException(
+             status_code=400,
+             detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}"
+         )
+
+     # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
+     diarizer_inputs = torch.from_numpy(inputs).float()
+     diarizer_inputs = diarizer_inputs.unsqueeze(0)
+
+     return inputs, diarizer_inputs
+
+
+ def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
+     diarization = diarization_pipeline(
+         {"waveform": diarizer_inputs, "sample_rate": parameters.sampling_rate},
+         num_speakers=parameters.num_speakers,
+         min_speakers=parameters.min_speakers,
+         max_speakers=parameters.max_speakers,
+     )
+
+     segments = []
+     for segment, track, label in diarization.itertracks(yield_label=True):
+         segments.append(
+             {
+                 "segment": {"start": segment.start, "end": segment.end},
+                 "track": track,
+                 "label": label,
+             }
+         )
+
+     # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
+     # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
+     new_segments = []
+     prev_segment = cur_segment = segments[0]
+
+     for i in range(1, len(segments)):
+         cur_segment = segments[i]
+
+         # check if we have changed speaker ("label")
+         if cur_segment["label"] != prev_segment["label"] and i < len(segments):
+             # add the start/end times for the super-segment to the new list
+             new_segments.append(
+                 {
+                     "segment": {
+                         "start": prev_segment["segment"]["start"],
+                         "end": cur_segment["segment"]["start"],
+                     },
+                     "speaker": prev_segment["label"],
+                 }
+             )
+             prev_segment = segments[i]
+
+     # add the last segment(s) if there was no speaker change
+     new_segments.append(
+         {
+             "segment": {
+                 "start": prev_segment["segment"]["start"],
+                 "end": cur_segment["segment"]["end"],
+             },
+             "speaker": prev_segment["label"],
+         }
+     )
+
+     return new_segments
+
+
+ def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
+     # get the end timestamps for each chunk from the ASR output
+     end_timestamps = np.array(
+         [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
+     segmented_preds = []
+
+     # align the diarizer timestamps and the ASR timestamps
+     for segment in new_segments:
+         # get the diarizer end timestamp
+         end_time = segment["segment"]["end"]
+         # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
+         upto_idx = np.argmin(np.abs(end_timestamps - end_time))
+
+         if group_by_speaker:
+             segmented_preds.append(
+                 {
+                     "speaker": segment["speaker"],
+                     "text": "".join(
+                         [chunk["text"] for chunk in transcript[: upto_idx + 1]]
+                     ),
+                     "timestamp": (
+                         transcript[0]["timestamp"][0],
+                         transcript[upto_idx]["timestamp"][1],
+                     ),
+                 }
+             )
+         else:
+             for i in range(upto_idx + 1):
+                 segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
+
+         # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
+         transcript = transcript[upto_idx + 1:]
+         end_timestamps = end_timestamps[upto_idx + 1:]
+
+         if len(end_timestamps) == 0:
+             break
+
+     return segmented_preds
+
+
+ def diarize(diarization_pipeline, file, parameters, asr_outputs):
+     _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
+
+     segments = diarize_audio(
+         diarizer_inputs,
+         diarization_pipeline,
+         parameters
+     )
+
+     return post_process_segments_and_transcripts(
+         segments, asr_outputs["chunks"], group_by_speaker=False
+     )
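
Note: a small synthetic example (not part of the commit) showing the shape of data post_process_segments_and_transcripts expects and returns; the segment and chunk values are made up.

# Illustrative only: synthetic diarizer segments and ASR chunks, aligned
# with group_by_speaker=True.
from diarization_util import post_process_segments_and_transcripts

fake_segments = [
    {"segment": {"start": 0.0, "end": 2.0}, "speaker": "SPEAKER_00"},
    {"segment": {"start": 2.0, "end": 4.0}, "speaker": "SPEAKER_01"},
]
fake_chunks = [
    {"timestamp": (0.0, 1.9), "text": " hello there"},
    {"timestamp": (2.1, 3.8), "text": " hi, how are you"},
]

aligned = post_process_segments_and_transcripts(fake_segments, fake_chunks, group_by_speaker=True)
# aligned ==
# [{'speaker': 'SPEAKER_00', 'text': ' hello there', 'timestamp': (0.0, 1.9)},
#  {'speaker': 'SPEAKER_01', 'text': ' hi, how are you', 'timestamp': (2.1, 3.8)}]
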
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ accelerate==0.27.2
+ torch==2.2.1
+ pyannote-audio==3.1.1
+ transformers==4.38.2
+ numpy==1.26.4
+ torchaudio==2.2.1
+ pydantic==2.6.3
+ pydantic-settings==2.2.1
+ starlette
+ gradio
util.py ADDED
@@ -0,0 +1,41 @@
+ from AudioHandler import AudioHandler
+ from config import *
+ handler = AudioHandler(model_settings)
+
+ def format_as_markdown(transcript, chunks):
+     # Combine all transcript entries into a markdown string
+     if transcript:
+         return "\n".join(
+             f"**{segment.get('speaker', 'Speaker')}**: {segment.get('text', '')}" for segment in transcript
+         )
+     else:
+         return "\n".join(
+             f"**[Chunk {i + 1}]**: {chunk.get('text', '')}" for i, chunk in enumerate(chunks)
+         )
+
+
+ def process_audio(audio):
+     try:
+         # Load audio file
+         with open(audio, "rb") as f:
+             audio_data = f.read()
+
+         parameters = InferenceConfig(
+             task="transcribe",
+             language="en",
+             chunk_length_s=30,
+             batch_size=4,
+             assisted=False
+         )
+
+         # Run inference
+         result = handler.run_inference(audio_data, parameters)
+         transcript = result["speakers"]
+         chunks = result["chunks"]
+
+         # Format as markdown for the output
+         output = format_as_markdown(transcript, chunks)
+     except Exception as e:
+         output = f"**Error**: {str(e)}"
+
+     return output
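
Note: a quick way (illustrative, not part of the commit) to exercise process_audio outside the Gradio UI. Importing util instantiates AudioHandler, which downloads and loads the models, so HF_TOKEN must be set; the audio path is hypothetical.

# Illustrative only: run the full pipeline on a local file and print
# the markdown-formatted transcript.
from util import process_audio

markdown = process_audio("sample.wav")  # hypothetical path
print(markdown)
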