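"""
EN-UK Translator: a Gradio Space that translates English text, audio, and images
into Ukrainian. Audio is first transcribed with Moonshine and images are first
OCR'd with docTR; the resulting English text is translated by the kulyk-en-uk model.
"""
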
import sys
import time
try:
    import spaces
except ImportError:
    print("ZeroGPU is not available, skipping...")

    class spaces:  # no-op stand-in so the @spaces.GPU decorators below still work
        @staticmethod
        def GPU(func):
            return func
import torch
import torchaudio
import gradio as gr
import torchaudio.transforms as T
import polars as pl
from importlib.metadata import version
from gradio.utils import is_zero_gpu_space
from gradio.themes import Base
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    MoonshineForConditionalGeneration,
)
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
use_zero_gpu = is_zero_gpu_space()
use_cuda = torch.cuda.is_available()
if use_zero_gpu:
    spaces_version = version("spaces")
    print("ZeroGPU is available, changing inference call.")
else:
    spaces_version = "N/A"
    print("ZeroGPU is not available, skipping...")

print(f"Spaces version: {spaces_version}")

if use_cuda:
    print("CUDA is available, setting correct `device` variable.")
    device = "cuda"
    torch_dtype = torch.bfloat16
else:
    device = "cpu"
    # bfloat16 is poorly supported on many CPUs, so fall back to float32 there.
    torch_dtype = torch.float32
# Config
model_name = "Yehor/kulyk-en-uk"
concurrency_limit = 5
current_theme = Base()
# Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load ASR
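# Moonshine only transcribes the uploaded English audio; the resulting transcript
# is then translated to Ukrainian by the kulyk-en-uk model loaded above.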
audio_processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-base")
audio_model = MoonshineForConditionalGeneration.from_pretrained(
    "UsefulSensors/moonshine-base", attn_implementation="sdpa"
)
audio_model.to(device)
audio_model.to(torch_dtype)
# Load OCR
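# docTR extracts English text from uploaded images; the recognized text is then
# translated the same way as plain text input.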
ocr_model = ocr_predictor(pretrained=True)
ocr_model.to(device)
# Examples
examples_text = [
    "WP: F-16s are unlikely to make a significant difference on the battlefield",
    "Missile and 7 of 8 Shaheeds shot down over Ukraine",
    "Olympic Games 2024. Schedule of competitions for Ukrainian athletes on 28 July",
    "Harris' campaign raised more than $200 million in less than a week",
    "Over the week, the NBU sold almost $800 million on the interbank market",
    "Paris 2024. Day 2: Text broadcast",
]
examples_audio = [
    "example_1.wav",
    "example_2.wav",
    "example_3.wav",
    "example_4.wav",
    "example_5.wav",
    "example_6.wav",
    "example_7.wav",
]
examples_image = [
    "example_1.jpg",
    "example_2.jpg",
    "example_3.jpg",
    "example_4.jpg",
    "example_5.jpg",
    "example_6.jpg",
]
title = "EN-UK Translator"
authors_table = """
## Authors
Follow them on social networks and **contact them** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use egorsmkv@gmail.com |
""".strip()
description_head = f"""
# {title}
This space translates your text, audio, or images from English into Ukrainian using the [kulyk-en-uk](https://huggingface.co/Yehor/kulyk-en-uk) model. Also, check out the [UK-EN Translator](https://huggingface.co/spaces/Yehor/uk-en-translator).
""".strip()
tech_env = f"""
#### Environment
- Python: {sys.version}
#### Models
- [kulyk-en-uk](https://huggingface.co/Yehor/kulyk-en-uk)
- [moonshine-base](https://huggingface.co/UsefulSensors/moonshine-base)
- [doctr](https://github.com/mindee/doctr)
""".strip()
tech_libraries = f"""
#### Libraries
- torch: {version("torch")}
- gradio: {version("gradio")}
- transformers: {version("transformers")}
""".strip()

def translate(text: str) -> str:
    """Translate a single English text snippet into Ukrainian."""
    prompt = "Translate the text to Ukrainian:\n" + text
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
    ).to(model.device)
    output = model.generate(
        input_ids,
        max_new_tokens=2048,
        # Greedy Search
        do_sample=False,
        repetition_penalty=1.05,
        # Sampling
        # do_sample=True,
        # temperature=0.1,
        # # top_k=1,
        # min_p=0.9,
        # repetition_penalty=1.05,
    )
    prompt_len = input_ids.shape[1]
    generated_tokens = output[:, prompt_len:]
    translated_text = tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True
    )[0]
    return translated_text.strip()

@spaces.GPU
def inference_text(text, progress=gr.Progress()):
    """Translate each non-empty line of the pasted text."""
    if not text:
        raise gr.Error("Please paste your text.")
    progress(0, desc="Translating...")
    results = []
    sentences = text.split("\n")
    non_empty_sentences = []
    for sentence in sentences:
        s = sentence.strip()
        if len(s) != 0:
            non_empty_sentences.append(s)
    for sentence in progress.tqdm(
        non_empty_sentences, desc="Translating...", unit="sentence"
    ):
        t0 = time.time()
        translated_text = translate(sentence)
        elapsed_time = round(time.time() - t0, 2)
        translated_text = translated_text.strip()
        results.append(
            {
                "sentence": sentence,
                "translated_text": translated_text,
                "elapsed_time": elapsed_time,
            }
        )
    gr.Info("Finished!", duration=2)
    return pl.DataFrame(results)

@spaces.GPU
def inference_audio(audio, progress=gr.Progress()):
    """Transcribe the uploaded English audio with Moonshine, then translate it."""
    if not audio:
        raise gr.Error("Please upload an audio file.")
    progress(0, desc="Translating...")
    if isinstance(audio, str):
        audio_array, sr = torchaudio.load(audio)
        audio_array = audio_array.squeeze()
    else:
        # gr.Audio with type="numpy" yields (sample_rate, data)
        sr, audio_array = audio
        audio_array = torch.from_numpy(audio_array).float()
    r_sr = audio_processor.feature_extractor.sampling_rate
    print("Audio processor SR:", r_sr)
    print("Audio file SR:", sr)
    if r_sr != sr:
        print("Resampling...")
        resampler = T.Resample(orig_freq=sr, new_freq=r_sr)
        audio_array = resampler(audio_array)
    inputs = audio_processor(audio_array, return_tensors="pt", sampling_rate=r_sr)
    inputs = inputs.to(device, dtype=torch_dtype)
    # To avoid hallucination loops, limit the length of the generated transcript
    # based on the expected number of tokens per second (at most 6.5 tokens/second).
    token_limit_factor = 6.5 / audio_processor.feature_extractor.sampling_rate
    seq_lens = inputs.attention_mask.sum(dim=-1)
    max_length = int((seq_lens * token_limit_factor).max().item())
    generated_ids = audio_model.generate(**inputs, max_length=max_length)
    predictions = audio_processor.batch_decode(generated_ids, skip_special_tokens=True)
    print("Predictions:", predictions)
    text = predictions[0]
    print("Text:", text)
    results = []
    sentences = text.split("\n")
    non_empty_sentences = []
    for sentence in sentences:
        s = sentence.strip()
        if len(s) != 0:
            non_empty_sentences.append(s)
    for sentence in progress.tqdm(
        non_empty_sentences, desc="Translating...", unit="sentence"
    ):
        t0 = time.time()
        translated_text = translate(sentence)
        elapsed_time = round(time.time() - t0, 2)
        results.append(
            {
                "sentence": sentence,
                "translated_text": translated_text,
                "elapsed_time": elapsed_time,
            }
        )
    gr.Info("Finished!", duration=2)
    return pl.DataFrame(results)

@spaces.GPU
def inference_image(image, progress=gr.Progress()):
    """Run OCR on the uploaded image with docTR, then translate the extracted text."""
    if not image:
        raise gr.Error("Please upload an image file.")
    progress(0, desc="Translating...")
    if isinstance(image, str):
        doc = DocumentFile.from_images(image)
    else:
        raise gr.Error("Please upload an image file.")
    result = ocr_model(doc)
    text = result.render()
    print("Text:", text)
    results = []
    sentences = [text.replace("\n", " ")]
    for sentence in progress.tqdm(sentences, desc="Translating...", unit="sentence"):
        t0 = time.time()
        translated_text = translate(sentence)
        elapsed_time = round(time.time() - t0, 2)
        results.append(
            {
                "sentence": sentence,
                "translated_text": translated_text,
                "elapsed_time": elapsed_time,
            }
        )
    gr.Info("Finished!", duration=2)
    return pl.DataFrame(results)
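
# UI builders: one tab per input type (text, audio, image) plus info tabs.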
def create_app():
    tab = gr.Blocks(
        title=title,
        analytics_enabled=False,
        theme=current_theme,
    )
    with tab:
        gr.Markdown(description_head)
        gr.Markdown("## Usage")
        translated_text = gr.DataFrame(
            label="Translated text",
        )
        text = gr.Textbox(label="Text", autofocus=True, lines=5)
        gr.Button("Translate").click(
            inference_text,
            concurrency_limit=concurrency_limit,
            inputs=text,
            outputs=translated_text,
        )
        with gr.Row():
            gr.Examples(label="Choose an example", inputs=text, examples=examples_text)
    return tab

def create_audio_app():
    with gr.Blocks(theme=current_theme) as tab:
        gr.Markdown(description_head)
        gr.Markdown("## Usage")
        translated_text = gr.DataFrame(
            label="Translated text",
        )
        audio = gr.Audio(label="Audio file", sources="upload", type="filepath")
        gr.Button("Translate").click(
            inference_audio,
            concurrency_limit=concurrency_limit,
            inputs=audio,
            outputs=translated_text,
        )
        with gr.Row():
            gr.Examples(
                label="Choose an example", inputs=audio, examples=examples_audio
            )
    return tab

def create_image_app():
    with gr.Blocks(theme=current_theme) as tab:
        gr.Markdown(description_head)
        gr.Markdown("## Usage")
        translated_text = gr.DataFrame(
            label="Translated text",
        )
        image = gr.Image(label="Image file", sources="upload", type="filepath")
        gr.Button("Translate").click(
            inference_image,
            concurrency_limit=concurrency_limit,
            inputs=image,
            outputs=translated_text,
        )
        with gr.Row():
            gr.Examples(
                label="Choose an example", inputs=image, examples=examples_image
            )
    return tab

def create_env():
    with gr.Blocks(theme=current_theme) as tab:
        gr.Markdown(tech_env)
        gr.Markdown(tech_libraries)
    return tab


def create_authors():
    with gr.Blocks(theme=current_theme) as tab:
        gr.Markdown(authors_table)
    return tab

def create_demo():
    app_tab = create_app()
    app_audio_tab = create_audio_app()
    app_image_tab = create_image_app()
    authors_tab = create_authors()
    env_tab = create_env()

    return gr.TabbedInterface(
        [app_tab, app_audio_tab, app_image_tab, authors_tab, env_tab],
        tab_names=[
            "✍️ Text",
            "🔊 Audio",
            "👀 Image",
            "👥 Authors",
            "📦 Environment, Models, and Libraries",
        ],
    )


if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch()