Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import urllib.request
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
# ---------- Helper functions ----------
|
8 |
+
|
9 |
+
def download_file(url, dest_path):
    """Download ``url`` to ``dest_path`` unless ``dest_path`` already exists.

    The payload is written to ``<dest_path>.part`` first and renamed onto
    ``dest_path`` only after the transfer has finished. An interrupted
    download therefore never leaves a truncated file that the existence
    check would mistake for a complete one on the next run.

    Args:
        url: Location of the file to fetch (any scheme ``urllib`` handles).
        dest_path: Local path the file should end up at. Missing parent
            directories are created.

    Raises:
        OSError: Propagated from ``urllib.request.urlretrieve`` or the file
            system calls when the transfer or the rename fails.
    """
    if os.path.exists(dest_path):
        print(f"File {dest_path} already exists.")
        return

    print(f"Downloading {url} to {dest_path} ...")

    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    partial_path = f"{dest_path}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # os.replace is atomic on the same file system, so dest_path only
        # ever appears fully written.
        os.replace(partial_path, dest_path)
    finally:
        # After a successful rename the partial file no longer exists; on any
        # failure this removes the leftover fragment.
        if os.path.exists(partial_path):
            os.remove(partial_path)
|
16 |
+
|
17 |
+
def clone_repo(repo_url, folder_name):
    """Clone ``repo_url`` into ``folder_name`` unless that path already exists.

    Git is invoked with an argument list rather than a shell string, and the
    target directory is passed explicitly so the checkout always lands at the
    same path the existence check looks at.

    Args:
        repo_url: URL (or local path) of the repository to clone.
        folder_name: Directory the repository is cloned into.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status.
        OSError: If the ``git`` executable cannot be started.
    """
    # Imported locally because the file's import header does not provide it.
    import subprocess

    if os.path.exists(folder_name):
        print(f"Repository {folder_name} already exists.")
        return

    print(f"Cloning repository {repo_url} ...")
    # "--" stops git from interpreting the URL or folder as an option.
    subprocess.run(["git", "clone", "--", repo_url, folder_name], check=True)
|
24 |
+
|
25 |
+
# ---------- Setup: Clone repository and download models ----------

# Clone the catt repository (which contains the necessary modules).
clone_repo("https://github.com/abjadai/catt.git", "catt")

# Add the cloned repository to the Python path so its modules can be imported
# by their bare names further down.
if "catt" not in sys.path:
    sys.path.append("catt")

# exist_ok=True makes this a no-op when the folder is already there and avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs("models", exist_ok=True)

# URLs for model checkpoints
url_ed = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
url_eo = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"

# Local checkpoint locations, read later by load_models().
ckpt_path_ed = os.path.join("models", "best_ed_mlm_ns_epoch_178.pt")
ckpt_path_eo = os.path.join("models", "best_eo_mlm_ns_epoch_193.pt")

download_file(url_ed, ckpt_path_ed)
download_file(url_eo, ckpt_path_eo)
|
47 |
+
|
48 |
+
# ---------- Import required modules from the cloned repository ----------
|
49 |
+
|
50 |
+
# Import the tokenizer and utility functions (shared by both models)
|
51 |
+
from tashkeel_tokenizer import TashkeelTokenizer
|
52 |
+
from utils import remove_non_arabic
|
53 |
+
|
54 |
+
# Import the two model implementations under different names
|
55 |
+
from ed_pl import TashkeelModel as TashkeelModel_ED
|
56 |
+
from eo_pl import TashkeelModel as TashkeelModel_EO
|
57 |
+
|
58 |
+
# ---------- Global model initialization ----------

# A single tokenizer instance is passed to both model constructors.
tokenizer = TashkeelTokenizer()

# Use CUDA when torch reports it as available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# Placeholders for the two models; load_models() rebinds them.
model_ed = model_eo = None
|
70 |
+
|
71 |
+
def load_models():
    """Build both CATT models and store them in ``model_ed`` / ``model_eo``.

    Each model is constructed with the shared ``tokenizer``, gets its state
    dict loaded from the matching checkpoint path (mapped onto ``device``),
    and then has ``eval()`` and ``to(device)`` called on it. Progress is
    reported on stdout.
    """
    global model_ed, model_eo
    sequence_limit = 1024  # same maximum sequence length for both models

    def _restore(model, checkpoint_path):
        # Load the checkpoint onto the target device, then call eval() and to().
        state = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state)
        model.eval().to(device)

    # Encoder-Decoder variant: 3 layers.
    print("Loading Encoder-Decoder model...")
    model_ed = TashkeelModel_ED(
        tokenizer,
        max_seq_len=sequence_limit,
        n_layers=3,
        learnable_pos_emb=False,
    )
    _restore(model_ed, ckpt_path_ed)
    print("Encoder-Decoder model loaded.")

    # Encoder-Only variant: 6 layers.
    print("Loading Encoder-Only model...")
    model_eo = TashkeelModel_EO(
        tokenizer,
        max_seq_len=sequence_limit,
        n_layers=6,
        learnable_pos_emb=False,
    )
    _restore(model_eo, ckpt_path_eo)
    print("Encoder-Only model loaded.")
|
88 |
+
|
89 |
+
# Load both models once at import time so model_ed / model_eo are populated
# before diacritize_text is ever called.
load_models()
|
91 |
+
|
92 |
+
# ---------- Inference Function ----------
|
93 |
+
|
94 |
+
def diacritize_text(model_type, input_text):
    """Diacritize Arabic text with the selected model.

    The input is passed through ``remove_non_arabic`` and then handed, as a
    one-element batch, to the chosen model's ``do_tashkeel_batch``.

    Args:
        model_type: ``"Encoder-Decoder"`` or ``"Encoder-Only"``.
        input_text: Text entered by the user; may be ``None`` or empty.

    Returns:
        The first output line produced by the model, or a human-readable
        message when the input is empty, the model type is unknown, the model
        raises, or no output is produced.
    """
    empty_message = "Please enter some Arabic text."

    # Guard against a missing value as well as blank input.
    if input_text is None or not input_text.strip():
        return empty_message

    # Clean the input text.
    # NOTE(review): presumably remove_non_arabic returns a string with the
    # non-Arabic characters stripped — confirm against the catt utils module.
    text_clean = remove_non_arabic(input_text)
    if not text_clean.strip():
        # Nothing is left to diacritize, so do not run the model on an empty
        # string.
        return empty_message

    if model_type == "Encoder-Decoder":
        model = model_ed
    elif model_type == "Encoder-Only":
        model = model_eo
    else:
        return "Unknown model type selected."

    batch_size = 16  # Using a fixed batch size; adjust if necessary.
    verbose = False

    try:
        output_lines = model.do_tashkeel_batch([text_clean], batch_size, verbose)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the
        # request handler.
        return f"An error occurred during diacritization: {str(e)}"

    # Return the first (and only) line from the output.
    return output_lines[0] if output_lines else "No output produced."
|
121 |
+
|
122 |
+
# ---------- Gradio Interface ----------

title = "Arabic Diacritization with CATT"
description = (
    "Enter Arabic text (without diacritics) below and select a model to perform "
    "automatic diacritization. The Encoder-Decoder model may offer better accuracy, "
    "while the Encoder-Only model is optimized for faster inference."
)

# Components come from the top-level gradio namespace: the gr.inputs /
# gr.outputs modules and the `default=` keyword are deprecated in Gradio 3 and
# removed in Gradio 4+, where the initial selection is passed as `value=`.
iface = gr.Interface(
    fn=diacritize_text,
    inputs=[
        gr.Dropdown(
            choices=["Encoder-Decoder", "Encoder-Only"],
            value="Encoder-Only",
            label="Model Selection",
        ),
        gr.Textbox(lines=4, placeholder="Enter Arabic text here...", label="Input Text"),
    ],
    outputs=gr.Textbox(label="Diacritized Output"),
    title=title,
    description=description,
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode` — confirm against the pinned Gradio version.
    allow_flagging="never",
)

# ---------- Launch the Gradio App ----------

if __name__ == "__main__":
    iface.launch()
|