Spaces:
Running
Running
import os
import subprocess
import sys
import urllib.request

import gradio as gr
import torch
# ---------- Helper functions ---------- | |
def download_file(url, dest_path):
    """Download ``url`` to ``dest_path`` unless the destination already exists.

    The payload is first written to ``<dest_path>.part`` and only moved into
    place once the transfer has finished. An interrupted download therefore
    never leaves a truncated file at ``dest_path`` that a later run would
    mistake for a complete one.

    Args:
        url: Source URL, in any scheme ``urllib.request`` understands.
        dest_path: Path of the file to create.

    Raises:
        urllib.error.URLError: If the URL cannot be opened or the transfer
            ends early (``ContentTooShortError``).
        OSError: If the destination cannot be written.
    """
    if os.path.exists(dest_path):
        print(f"File {dest_path} already exists.")
        return

    print(f"Downloading {url} to {dest_path} ...")
    partial_path = f"{dest_path}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Rename only after a complete transfer so dest_path is never partial.
        os.replace(partial_path, dest_path)
    finally:
        # After a successful rename the temp file is gone and this is a no-op;
        # after a failure it removes whatever was written so far.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def clone_repo(repo_url, folder_name):
    """Clone ``repo_url`` into ``folder_name`` unless that path already exists.

    git is invoked with an argument list rather than a shell string, so the
    URL is never interpreted by a shell. The target directory is passed
    explicitly, which keeps the existence check and the clone destination in
    agreement even when ``folder_name`` differs from the repository's name.

    Args:
        repo_url: URL of the git repository to clone.
        folder_name: Directory to clone into; also the path checked for
            an existing clone.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status.
        FileNotFoundError: If the ``git`` executable cannot be found.
    """
    if os.path.exists(folder_name):
        print(f"Repository {folder_name} already exists.")
        return

    print(f"Cloning repository {repo_url} ...")
    # check=True surfaces a failed clone here instead of as a confusing
    # ImportError later when the cloned modules are imported.
    subprocess.run(["git", "clone", repo_url, folder_name], check=True)
# ---------- Setup: clone repository and download models ----------
# The catt repository supplies the modules imported further down in this file.
clone_repo("https://github.com/abjadai/catt.git", "catt")

# Make the cloned sources importable. The entry is a relative path, so it
# resolves against the current working directory.
if "catt" not in sys.path:
    sys.path.append("catt")

# Checkpoints are stored under ./models. exist_ok=True replaces the separate
# existence check and is safe if the directory appears between check and create.
os.makedirs("models", exist_ok=True)

# Release URLs of the two checkpoints and the local paths they are saved to.
url_ed = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
url_eo = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"
ckpt_path_ed = os.path.join("models", "best_ed_mlm_ns_epoch_178.pt")
ckpt_path_eo = os.path.join("models", "best_eo_mlm_ns_epoch_193.pt")

download_file(url_ed, ckpt_path_ed)
download_file(url_eo, ckpt_path_eo)
# ---------- Import required modules from the cloned repository ---------- | |
from tashkeel_tokenizer import TashkeelTokenizer | |
from utils import remove_non_arabic | |
from ed_pl import TashkeelModel as TashkeelModel_ED | |
from eo_pl import TashkeelModel as TashkeelModel_EO | |
# ---------- Global model initialization ----------
# A single tokenizer instance is handed to both models.
tokenizer = TashkeelTokenizer()

# Run on the GPU when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Module-level slots for the two models; load_models() assigns them.
model_ed = model_eo = None
def _load_tashkeel_model(model_cls, ckpt_path, n_layers, max_seq_len):
    """Build one model, restore its checkpoint, and place it on ``device`` in eval mode."""
    model = model_cls(
        tokenizer,
        max_seq_len=max_seq_len,
        n_layers=n_layers,
        learnable_pos_emb=False,
    )
    # NOTE(security): torch.load unpickles the checkpoint, and unpickling can
    # execute arbitrary code. The files loaded here are the ones downloaded
    # from the project's release URLs. Where the installed torch supports it,
    # weights_only=True would restrict unpickling -- TODO confirm these
    # checkpoints load that way before switching.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model


def load_models():
    """Load both diacritization models into ``model_ed`` and ``model_eo``.

    Each global is assigned only after its model has been constructed, had
    its checkpoint restored, and been switched to eval mode on ``device``,
    so a failure part-way through never leaves a half-initialised model
    visible to the inference function.

    Raises:
        Whatever the model constructors, ``torch.load`` or
        ``load_state_dict`` raise (for example a missing or mismatching
        checkpoint file).
    """
    global model_ed, model_eo
    max_seq_len = 1024

    # Encoder-Decoder variant: 3 layers.
    print("Loading Encoder-Decoder model...")
    model_ed = _load_tashkeel_model(TashkeelModel_ED, ckpt_path_ed, 3, max_seq_len)
    print("Encoder-Decoder model loaded.")

    # Encoder-Only variant: 6 layers.
    print("Loading Encoder-Only model...")
    model_eo = _load_tashkeel_model(TashkeelModel_EO, ckpt_path_eo, 6, max_seq_len)
    print("Encoder-Only model loaded.")


# Load the models once at startup so requests do not pay the loading cost.
load_models()
# ---------- Inference Function ---------- | |
def diacritize_text(model_type, input_text):
    """Diacritize Arabic text with the selected model.

    The input is stripped of non-Arabic characters via ``remove_non_arabic``
    and passed as a one-element batch to the chosen model's
    ``do_tashkeel_batch``.

    Args:
        model_type: ``"Encoder-Decoder"`` or ``"Encoder-Only"``.
        input_text: Raw user text; ``None`` is treated like an empty string.

    Returns:
        The diacritized text, or a human-readable message when the input is
        empty (before or after cleaning), the model type is unknown, the
        model raises, or no output is produced. Errors are reported through
        the return value rather than raised, because the result is shown
        directly in the UI.
    """
    # Guard against None as well as blank input before calling .strip().
    if input_text is None or not input_text.strip():
        return "Please enter some Arabic text."

    # Resolve the model first so an invalid selection is rejected before any
    # further work is done.
    if model_type == "Encoder-Decoder":
        model = model_ed
    elif model_type == "Encoder-Only":
        model = model_eo
    else:
        return "Unknown model type selected."

    # Clean the input text; it may contain nothing once non-Arabic
    # characters are removed (assumes remove_non_arabic returns a str).
    text_clean = remove_non_arabic(input_text)
    if not text_clean.strip():
        return "Please enter some Arabic text."

    batch = [text_clean]
    batch_size = 16  # Fixed batch size.
    verbose = False
    try:
        output_lines = model.do_tashkeel_batch(batch, batch_size, verbose)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        return f"An error occurred during diacritization: {str(e)}"

    return output_lines[0] if output_lines else "No output produced."
# ---------- Gradio Interface ----------
# Heading and explanatory text shown above the form.
title = "Arabic Diacritization with CATT"
description = (
    "Enter Arabic text (without diacritics) below and select a model to "
    "perform automatic diacritization. "
    "The Encoder-Decoder model may offer better accuracy, "
    "while the Encoder-Only model is optimized for faster inference."
)

# The inputs are listed in the positional order of diacritize_text's
# parameters: model selection first, then the text.
iface = gr.Interface(
    title=title,
    description=description,
    fn=diacritize_text,
    inputs=[
        gr.Dropdown(
            label="Model Selection",
            choices=["Encoder-Decoder", "Encoder-Only"],
            value="Encoder-Only",
        ),
        gr.Textbox(
            label="Input Text",
            lines=4,
            placeholder="Enter Arabic text here...",
        ),
    ],
    outputs=gr.Textbox(label="Diacritized Output"),
    allow_flagging="never",
)

# ---------- Launch the Gradio App ----------
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()