# Page metadata from the Spaces file view: status Running; file size 4,752 bytes;
# revisions c119904 / 9058e43.
import os
import subprocess
import sys
import traceback
import urllib.request

import gradio as gr
import torch
# ---------- Helper functions ----------
def download_file(url, dest_path):
    """Download ``url`` to ``dest_path`` unless that path already exists.

    The payload is first written to a sibling ``<dest_path>.part`` file and is
    only renamed onto ``dest_path`` once the transfer has completed. An
    interrupted or truncated download therefore never leaves a partial file at
    ``dest_path``, which the existence check would otherwise treat as finished
    on every later run.

    Args:
        url: Source URL, passed to ``urllib.request.urlretrieve``.
        dest_path: Final file path (str or path-like).

    Raises:
        Whatever ``urlretrieve`` raises (e.g. ``urllib.error.URLError`` or
        ``urllib.error.ContentTooShortError``); the temporary file is removed
        before the exception propagates.
    """
    if os.path.exists(dest_path):
        print(f"File {dest_path} already exists.")
        return
    print(f"Downloading {url} to {dest_path} ...")
    tmp_path = os.fspath(dest_path) + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        # os.replace is an atomic rename on POSIX: dest_path is either absent
        # or complete, never half-written.
        os.replace(tmp_path, dest_path)
    except BaseException:
        # Also covers KeyboardInterrupt: don't leave a stale partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
def clone_repo(repo_url, folder_name):
    """Clone ``repo_url`` into ``folder_name`` unless that path already exists.

    Args:
        repo_url: Any repository location accepted by ``git clone``.
        folder_name: Target directory. It is passed to ``git clone``
            explicitly, so the existence check and the clone destination
            always refer to the same path.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits non-zero.
        FileNotFoundError: If the ``git`` executable cannot be found.
    """
    if os.path.exists(folder_name):
        print(f"Repository {folder_name} already exists.")
        return
    print(f"Cloning repository {repo_url} ...")
    # An argument list (no shell) avoids quoting/injection problems. check=True
    # reports a failed clone here instead of as an ImportError further down.
    subprocess.run(["git", "clone", repo_url, folder_name], check=True)
# ---------- Setup: Clone repository and download models ----------
# Clone the catt repository; it provides the tokenizer, the text-cleaning helper
# and the model classes imported further down.
clone_repo("https://github.com/abjadai/catt.git", "catt")
# Make the cloned repository importable. The path is relative, so this (like the
# clone above) assumes the process keeps the same working directory.
if "catt" not in sys.path:
    sys.path.append("catt")
# Checkpoint folder; exist_ok makes this idempotent without a check-then-create race.
os.makedirs("models", exist_ok=True)
# Release URLs for the two model checkpoints.
url_ed = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
url_eo = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"
# Local file names mirror the release asset names (last URL path segment).
ckpt_path_ed = os.path.join("models", url_ed.rsplit("/", 1)[-1])
ckpt_path_eo = os.path.join("models", url_eo.rsplit("/", 1)[-1])
download_file(url_ed, ckpt_path_ed)
download_file(url_eo, ckpt_path_eo)
# ---------- Import required modules from the cloned repository ----------
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
from ed_pl import TashkeelModel as TashkeelModel_ED
from eo_pl import TashkeelModel as TashkeelModel_EO
# ---------- Global model initialization ----------
# A single tokenizer instance is shared by both models.
tokenizer = TashkeelTokenizer()
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)
# Placeholders; load_models() assigns the real models.
model_ed = model_eo = None
def _load_checkpoint_model(model_cls, ckpt_path, n_layers, label, max_seq_len=1024):
    """Build ``model_cls``, load weights from ``ckpt_path`` and return it in eval mode on ``device``.

    Args:
        model_cls: A TashkeelModel class from the cloned catt repository.
        ckpt_path: Path to a saved state dict for that model.
        n_layers: Layer count; it must match the checkpoint, because
            ``load_state_dict`` is strict by default.
        label: Human-readable model name used in the progress messages.
        max_seq_len: Maximum sequence length passed to the model constructor.

    Returns:
        The loaded model, already switched to eval mode and moved to ``device``.
    """
    print(f"Loading {label} model...")
    model = model_cls(tokenizer, max_seq_len=max_seq_len, n_layers=n_layers, learnable_pos_emb=False)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval().to(device)
    print(f"{label} model loaded.")
    return model


def load_models():
    """Load both diacritization models into the module globals ``model_ed`` and ``model_eo``.

    Each global is assigned only after its weights have loaded successfully.
    """
    global model_ed, model_eo
    # Layer counts match the released checkpoints: 3 for ED, 6 for EO.
    model_ed = _load_checkpoint_model(TashkeelModel_ED, ckpt_path_ed, n_layers=3, label="Encoder-Decoder")
    model_eo = _load_checkpoint_model(TashkeelModel_EO, ckpt_path_eo, n_layers=6, label="Encoder-Only")


# Load the models at startup (this runs at module import time).
load_models()
# ---------- Inference Function ----------
def diacritize_text(model_type, input_text):
    """Diacritize ``input_text`` with the model selected by ``model_type``.

    The text is first cleaned with ``remove_non_arabic``. It is then passed to
    the chosen model's ``do_tashkeel_batch`` as a one-element batch.

    Args:
        model_type: ``"Encoder-Decoder"`` or ``"Encoder-Only"`` (the dropdown choices).
        input_text: Raw text from the input box; ``None`` is treated as empty.

    Returns:
        The diacritized text, or a human-readable message in these cases:
        empty input, nothing left after cleaning, an unknown or unloaded
        model, or an inference failure. Errors are returned as strings rather
        than raised, so they appear in the output box.
    """
    # Guard against None as well as blank strings.
    if not input_text or not input_text.strip():
        return "Please enter some Arabic text."

    if model_type == "Encoder-Decoder":
        model = model_ed
    elif model_type == "Encoder-Only":
        model = model_eo
    else:
        return "Unknown model type selected."
    if model is None:
        return "The selected model is not loaded."

    text_clean = remove_non_arabic(input_text)
    # Input made only of non-Arabic characters can clean down to nothing;
    # don't send an empty string to the model.
    if not text_clean.strip():
        return "No Arabic text found in the input."

    batch_size = 16  # Fixed batch size.
    verbose = False
    try:
        # Inference only: skip autograd bookkeeping. This is harmless if
        # do_tashkeel_batch already disables gradients itself.
        with torch.no_grad():
            output_lines = model.do_tashkeel_batch([text_clean], batch_size, verbose)
    except Exception as e:
        # UI boundary: keep the full traceback in the server log and show the
        # user a short message instead of failing the request.
        traceback.print_exc()
        return f"An error occurred during diacritization: {e}"
    return output_lines[0] if output_lines else "No output produced."
# ---------- Gradio Interface ----------
title = "Arabic Diacritization with CATT"
description = (
    "Enter Arabic text (without diacritics) below and select a model to perform "
    "automatic diacritization. The Encoder-Decoder model may offer better accuracy, "
    "while the Encoder-Only model is optimized for faster inference."
)

# Input widgets, listed in the same order as diacritize_text's parameters.
model_selector = gr.Dropdown(
    choices=["Encoder-Decoder", "Encoder-Only"],
    value="Encoder-Only",
    label="Model Selection",
)
text_input = gr.Textbox(lines=4, placeholder="Enter Arabic text here...", label="Input Text")
text_output = gr.Textbox(label="Diacritized Output")

iface = gr.Interface(
    fn=diacritize_text,
    inputs=[model_selector, text_input],
    outputs=text_output,
    title=title,
    description=description,
    # Disables the flagging button.
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # check against the installed version.
    allow_flagging="never",
)
# ---------- Launch the Gradio App ----------
if __name__ == "__main__":
    # Start the web UI only when this file is executed directly, not when imported.
    iface.launch()