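# app.py: last commit 9058e43 ("Update app.py", verified) by Bisher; file size 4.75 kB.
# (This header was recovered from a web-page copy; "raw", "history" and "blame" were page links.)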
import os
import subprocess
import sys
import urllib.request

import gradio as gr
import torch
# ---------- Helper functions ----------
def download_file(url, dest_path):
    """Download ``url`` to ``dest_path`` unless that file already exists.

    The data is first written to ``<dest_path>.part`` and renamed into place only
    after the transfer completes. An interrupted download therefore never leaves
    a truncated file that a later run would mistake for a finished one.

    Args:
        url: Source URL (anything ``urllib.request.urlretrieve`` accepts).
        dest_path: Local file path to create.

    Raises:
        OSError: (including ``urllib.error.URLError``) if the download or the
            final rename fails. Any partial temporary file is removed first.
    """
    if os.path.exists(dest_path):
        print(f"File {dest_path} already exists.")
        return

    print(f"Downloading {url} to {dest_path} ...")
    tmp_path = os.fspath(dest_path) + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # Atomic on the same filesystem: dest_path is either absent or complete.
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Also covers KeyboardInterrupt, so a Ctrl+C mid-download cleans up too.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
def clone_repo(repo_url, folder_name):
    """Clone ``repo_url`` into ``folder_name`` unless that folder already exists.

    Args:
        repo_url: URL (or local path) of the git repository.
        folder_name: Destination directory. It is passed explicitly to
            ``git clone`` so the existence check and the clone target agree.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with a non-zero status.
        FileNotFoundError: if the ``git`` executable cannot be found.
    """
    if os.path.exists(folder_name):
        print(f"Repository {folder_name} already exists.")
        return

    print(f"Cloning repository {repo_url} ...")
    # An argument list with no shell avoids quoting and injection problems.
    # check=True reports a failed clone here instead of letting the caller
    # continue with a missing checkout.
    subprocess.run(["git", "clone", repo_url, folder_name], check=True)
# ---------- Setup: Clone repository and download models ----------
# Clone the catt repository, which provides the tokenizer, utils and model
# modules imported below.
clone_repo("https://github.com/abjadai/catt.git", "catt")

# Make modules from the cloned repository importable.
if "catt" not in sys.path:
    sys.path.append("catt")

# Create the checkpoint folder. exist_ok=True replaces the check-then-create
# pattern, which was racy and redundant.
os.makedirs("models", exist_ok=True)

# URLs for model checkpoints (release assets of the catt repository).
url_ed = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
url_eo = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"
ckpt_path_ed = os.path.join("models", "best_ed_mlm_ns_epoch_178.pt")
ckpt_path_eo = os.path.join("models", "best_eo_mlm_ns_epoch_193.pt")

# Download is skipped for checkpoints that are already on disk.
download_file(url_ed, ckpt_path_ed)
download_file(url_eo, ckpt_path_eo)
# ---------- Import required modules from the cloned repository ----------
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
from ed_pl import TashkeelModel as TashkeelModel_ED
from eo_pl import TashkeelModel as TashkeelModel_EO
# ---------- Global model initialization ----------
# One tokenizer instance is shared by both diacritization models.
tokenizer = TashkeelTokenizer()

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Placeholders that load_models() fills in; both stay None until it runs.
model_ed, model_eo = None, None
def _load_tashkeel_model(model_cls, n_layers, ckpt_path, max_seq_len=1024):
    """Instantiate ``model_cls``, load the weights stored at ``ckpt_path``, and return it.

    The model uses the module-level ``tokenizer``. It is returned in eval mode,
    moved to the module-level ``device``.
    """
    model = model_cls(tokenizer, max_seq_len=max_seq_len, n_layers=n_layers, learnable_pos_emb=False)
    # NOTE(review): torch.load unpickles the checkpoint, and these files are
    # downloaded from the network. Consider passing weights_only=True (available
    # since PyTorch 1.13, the default since 2.6) once the minimum torch version allows it.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.eval().to(device)


def load_models():
    """Load the Encoder-Decoder and Encoder-Only CATT models into the module globals.

    Each global is assigned only after its checkpoint has loaded successfully.
    A failed load therefore never leaves a randomly initialised model in
    ``model_ed`` / ``model_eo`` for ``diacritize_text`` to use.
    """
    global model_ed, model_eo
    max_seq_len = 1024

    # Load Encoder-Decoder model
    print("Loading Encoder-Decoder model...")
    model_ed = _load_tashkeel_model(
        TashkeelModel_ED, n_layers=3, ckpt_path=ckpt_path_ed, max_seq_len=max_seq_len
    )
    print("Encoder-Decoder model loaded.")

    # Load Encoder-Only model
    print("Loading Encoder-Only model...")
    model_eo = _load_tashkeel_model(
        TashkeelModel_EO, n_layers=6, ckpt_path=ckpt_path_eo, max_seq_len=max_seq_len
    )
    print("Encoder-Only model loaded.")


# Load the models at startup.
load_models()
# ---------- Inference Function ----------
def diacritize_text(model_type, input_text):
    """Diacritize Arabic text with the model selected in the UI.

    Non-Arabic characters are removed with ``remove_non_arabic`` before the
    text is passed to the chosen model.

    Args:
        model_type: "Encoder-Decoder" or "Encoder-Only".
        input_text: Raw user input. ``None`` is treated like an empty string.

    Returns:
        The diacritized text. Instead of the text, a human-readable message is
        returned when:
        - the input is empty;
        - the input contains no Arabic text;
        - the model name is unknown;
        - inference raises.
        Inference errors are returned rather than raised so the Gradio output
        box always shows something.
    """
    if not input_text or not input_text.strip():
        return "Please enter some Arabic text."

    # Reject an unknown selection before doing any cleaning or model work.
    if model_type not in ("Encoder-Decoder", "Encoder-Only"):
        return "Unknown model type selected."

    text_clean = remove_non_arabic(input_text)
    # Input with no Arabic characters is empty after cleaning, so report that
    # instead of running the model on an empty string.
    if not text_clean.strip():
        return "Please enter some Arabic text."

    batch_size = 16  # Fixed batch size; each request is a single-item batch.
    verbose = False
    model = model_ed if model_type == "Encoder-Decoder" else model_eo
    try:
        output_lines = model.do_tashkeel_batch([text_clean], batch_size, verbose)
    except Exception as e:  # UI boundary: show the failure to the user.
        return f"An error occurred during diacritization: {e}"
    return output_lines[0] if output_lines else "No output produced."
# ---------- Gradio Interface ----------
title = "Arabic Diacritization with CATT"
description = (
    "Enter Arabic text (without diacritics) below and select a model to perform "
    "automatic diacritization. The Encoder-Decoder model may offer better accuracy, "
    "while the Encoder-Only model is optimized for faster inference."
)

# UI components, created in the same order they are passed to the Interface.
_model_selector = gr.Dropdown(
    choices=["Encoder-Decoder", "Encoder-Only"],
    value="Encoder-Only",
    label="Model Selection",
)
_text_input = gr.Textbox(lines=4, placeholder="Enter Arabic text here...", label="Input Text")
_text_output = gr.Textbox(label="Diacritized Output")

# The input order must match diacritize_text(model_type, input_text).
# NOTE(review): newer Gradio releases deprecate `allow_flagging` in favour of
# `flagging_mode`; switch once the deployed Gradio version is confirmed.
iface = gr.Interface(
    fn=diacritize_text,
    inputs=[_model_selector, _text_input],
    outputs=_text_output,
    title=title,
    description=description,
    allow_flagging="never",
)

# ---------- Launch the Gradio App ----------
if __name__ == "__main__":
    iface.launch()