File size: 6,099 Bytes
8dde699
 
 
 
 
46c7e88
8dde699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46c7e88
 
 
2263e75
46c7e88
 
 
8dde699
 
 
 
 
 
 
 
46c7e88
 
8dde699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2263e75
8dde699
 
 
 
 
 
 
 
 
 
 
2263e75
8dde699
 
 
 
 
2263e75
8dde699
 
 
 
 
2263e75
8dde699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683aeb8
8dde699
 
 
 
 
 
 
 
 
 
 
 
 
2263e75
8dde699
 
 
 
 
 
 
e93c3f9
8dde699
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import logging
import warnings

import gradio as gr
import pytube as pt
import psutil
import torch
import whisper
from huggingface_hub import hf_hub_download, model_info
from transformers.utils.logging import disable_progress_bar

# Silence library warnings and the transformers download progress bar to keep logs clean.
warnings.filterwarnings("ignore")
disable_progress_bar()

# Fine-tuned French Whisper checkpoint used when the caller does not pick a model,
# and the filename of the OpenAI-format weights inside that Hub repository.
DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-french"
CHECKPOINT_FILENAME = "checkpoint_openai.pt"

# Keyword arguments forwarded to `model.transcribe`; commented entries are
# tuning knobs left in place for experimentation.
GEN_KWARGS = {
    "task": "transcribe",
    "language": "fr",
    # "without_timestamps": True,
    # decode options
    # "beam_size": 5,
    # "patience": 2,
    # disable fallback
    # "compression_ratio_threshold": None,
    # "logprob_threshold": None,
    # vad threshold
    # "no_speech_threshold": None,
}

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# device = 0 if torch.cuda.is_available() else "cpu"
# Prefer the first CUDA device; fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Model will be loaded on device `{device}`")

# In-process cache: model_name -> loaded whisper model, so repeated requests
# for the same checkpoint do not reload it.
cached_models = {}


def _print_memory_info():
    """Log host RAM usage (free, used percentage, total) via psutil."""
    vm = psutil.virtual_memory()
    gib = 1024 ** 3
    free_gb = vm.available / gib
    total_gb = vm.total / gib
    logger.info(
        f"Memory info - Free: {free_gb:.2f} Gb, used: {vm.percent}%, total: {total_gb:.2f} Gb"
    )


def print_cuda_memory_info():
    """Log CUDA memory usage (free, used, total) for the current device.

    Fix: ``torch.cuda.mem_get_info()`` returns ``(free, total)`` in bytes;
    the original bound the first value to a variable named ``used_mem``,
    which was misleading (the logged numbers were nonetheless correct).
    """
    free_mem, total_mem = torch.cuda.mem_get_info()
    logger.info(
        f"CUDA memory info - Free: {free_mem / 1024 ** 3:.2f} Gb, used: {(total_mem - free_mem) / 1024 ** 3:.2f} Gb, total: {total_mem / 1024 ** 3:.2f} Gb"
    )


def print_memory_info():
    """Log both host RAM and CUDA memory usage (host first, then GPU)."""
    _print_memory_info()
    print_cuda_memory_info()


def maybe_load_cached_pipeline(model_name):
    """Return the whisper model for ``model_name``, downloading and loading it on first use.

    Loaded models are kept in the module-level ``cached_models`` dict so each
    checkpoint is fetched and loaded at most once per process.
    """
    if model_name in cached_models:
        return cached_models[model_name]

    # Cache miss: pull the OpenAI-format checkpoint from the Hub and load it.
    checkpoint_path = hf_hub_download(repo_id=model_name, filename=CHECKPOINT_FILENAME)
    loaded_model = whisper.load_model(checkpoint_path, device=device)
    logger.info(f"`{model_name}` has been loaded on device `{device}`")

    print_memory_info()

    cached_models[model_name] = loaded_model
    return loaded_model


def infer(model, filename, with_timestamps):
    """Transcribe ``filename`` with ``model``.

    When ``with_timestamps`` is truthy, return one paragraph per segment with
    its index and start/end times; otherwise return the plain transcript text.
    """
    if not with_timestamps:
        # Plain transcript only.
        return model.transcribe(filename, without_timestamps=True, **GEN_KWARGS)["text"]

    outputs = model.transcribe(filename, **GEN_KWARGS)
    formatted = []
    for segment in outputs["segments"]:
        formatted.append(
            f'Segment {segment["id"]+1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
        )
    return "\n\n".join(formatted)


def download_from_youtube(yt_url, downloaded_filename="audio.wav"):
    """Download the first audio-only stream of a YouTube video and return its local path."""
    audio_stream = pt.YouTube(yt_url).streams.filter(only_audio=True)[0]
    # stream.download(filename="audio.mp3")
    audio_stream.download(filename=downloaded_filename)
    return downloaded_filename


def transcribe(microphone, file_upload, yt_url, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    """Transcribe audio from the microphone, an uploaded file, or a YouTube URL.

    Input precedence: microphone > uploaded file > YouTube URL. When several
    inputs are supplied, the highest-precedence one is used and a warning is
    prepended to the returned transcript. Returns an error string when no
    input is provided at all.

    Fix: the original used four independent ``if`` statements, so when all
    three inputs were supplied the specific three-input warning was
    immediately overwritten by the later two-input warnings; the dangling
    ``elif`` for the empty-input error was also attached only to the last
    ``if``. A single ``if``/``elif`` chain keeps the most specific warning
    and makes the empty-input check unconditional.
    """
    warn_output = ""
    if (microphone is not None) and (file_upload is not None) and yt_url:
        warn_output = (
            "WARNING: You've uploaded an audio file, used the microphone, and pasted a YouTube URL. "
            "The recorded file from the microphone will be used, the uploaded audio and the YouTube URL will be discarded.\n"
        )
    elif (microphone is not None) and (file_upload is not None):
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        )
    elif (microphone is not None) and yt_url:
        warn_output = (
            "WARNING: You've used the microphone and pasted a YouTube URL. "
            "The recorded file from the microphone will be used and the YouTube URL will be discarded.\n"
        )
    elif (file_upload is not None) and yt_url:
        warn_output = (
            "WARNING: You've uploaded an audio file and pasted a YouTube URL. "
            "The uploaded audio will be used and the YouTube URL will be discarded.\n"
        )
    elif (microphone is None) and (file_upload is None) and (not yt_url):
        return "ERROR: You have to either use the microphone, upload an audio file or paste a YouTube URL"

    # Pick the input to transcribe, following the precedence above.
    if microphone is not None:
        file = microphone
        logging_prefix = f"Transcription by `{model_name}` of microphone:"
    elif file_upload is not None:
        file = file_upload
        logging_prefix = f"Transcription by `{model_name}` of uploaded file:"
    else:
        file = download_from_youtube(yt_url)
        logging_prefix = f'Transcription by `{model_name}` of "{yt_url}":'

    model = maybe_load_cached_pipeline(model_name)
    text = infer(model, file, with_timestamps)

    logger.info(logging_prefix + "\n" + text + "\n")

    return warn_output + text


# Eagerly load the default checkpoint at startup so the first request is fast.
maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)

# Gradio UI: three optional input sources plus a timestamps toggle, wired to
# `transcribe`. NOTE(review): `gr.inputs.*`/`gr.outputs.*` and the `optional`/
# `layout`/`theme` kwargs are the legacy Gradio 2.x/3.x API — keep the pinned
# gradio version or migrate to `gr.Audio`/`gr.Textbox` components.
demo = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", label="Record", optional=True),
        gr.inputs.Audio(source="upload", type="filepath", label="Upload File", optional=True),
        gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL", optional=True),
        gr.Checkbox(label="With timestamps?"),
    ],
    # outputs="text",
    outputs=gr.outputs.Textbox(label="Transcription"),
    layout="horizontal",
    theme="huggingface",
    title="Whisper French Demo 🇫🇷 : Transcribe Audio",
    description=(
        "**Transcribe long-form microphone, audio inputs or YouTube videos with the click of a button!** \n\nDemo uses the the fine-tuned"
        f" checkpoint [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length."
    ),
    allow_flagging="never",
)


# demo.launch(server_name="0.0.0.0", debug=True, share=True)
demo.launch(enable_queue=True)