Bisher commited on
Commit
c119904
·
verified ·
1 Parent(s): 03ed844

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +150 -0
app.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import urllib.request
4
+ import torch
5
+ import gradio as gr
6
+
7
+ # ---------- Helper functions ----------
8
+
9
+ def download_file(url, dest_path):
10
+ """Download a file from a URL to a destination path."""
11
+ if not os.path.exists(dest_path):
12
+ print(f"Downloading {url} to {dest_path} ...")
13
+ urllib.request.urlretrieve(url, dest_path)
14
+ else:
15
+ print(f"File {dest_path} already exists.")
16
+
17
+ def clone_repo(repo_url, folder_name):
18
+ """Clone the repo if the folder does not exist."""
19
+ if not os.path.exists(folder_name):
20
+ print(f"Cloning repository {repo_url} ...")
21
+ os.system(f"git clone {repo_url}")
22
+ else:
23
+ print(f"Repository {folder_name} already exists.")
24
+
25
+ # ---------- Setup: Clone repository and download models ----------
26
+
27
+ # Clone the catt repository (which contains the necessary modules)
28
+ clone_repo("https://github.com/abjadai/catt.git", "catt")
29
+
30
+ # Add the cloned repository to the Python path so we can import modules from it.
31
+ if "catt" not in sys.path:
32
+ sys.path.append("catt")
33
+
34
+ # Create models folder if not exists
35
+ if not os.path.exists("models"):
36
+ os.makedirs("models")
37
+
38
+ # URLs for model checkpoints
39
+ url_ed = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
40
+ url_eo = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"
41
+
42
+ ckpt_path_ed = os.path.join("models", "best_ed_mlm_ns_epoch_178.pt")
43
+ ckpt_path_eo = os.path.join("models", "best_eo_mlm_ns_epoch_193.pt")
44
+
45
+ download_file(url_ed, ckpt_path_ed)
46
+ download_file(url_eo, ckpt_path_eo)
47
+
48
+ # ---------- Import required modules from the cloned repository ----------
49
+
50
+ # Import the tokenizer and utility functions (shared by both models)
51
+ from tashkeel_tokenizer import TashkeelTokenizer
52
+ from utils import remove_non_arabic
53
+
54
+ # Import the two model implementations under different names
55
+ from ed_pl import TashkeelModel as TashkeelModel_ED
56
+ from eo_pl import TashkeelModel as TashkeelModel_EO
57
+
58
+ # ---------- Global model initialization ----------
59
+
60
+ # Prepare tokenizer (used by both models)
61
+ tokenizer = TashkeelTokenizer()
62
+
63
+ # Determine the device
64
+ device = "cuda" if torch.cuda.is_available() else "cpu"
65
+ print("Using device:", device)
66
+
67
+ # Global variables to hold the models
68
+ model_ed = None
69
+ model_eo = None
70
+
71
+ def load_models():
72
+ global model_ed, model_eo
73
+ max_seq_len = 1024
74
+
75
+ # Load Encoder-Decoder model
76
+ print("Loading Encoder-Decoder model...")
77
+ model_ed = TashkeelModel_ED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
78
+ model_ed.load_state_dict(torch.load(ckpt_path_ed, map_location=device))
79
+ model_ed.eval().to(device)
80
+ print("Encoder-Decoder model loaded.")
81
+
82
+ # Load Encoder-Only model
83
+ print("Loading Encoder-Only model...")
84
+ model_eo = TashkeelModel_EO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
85
+ model_eo.load_state_dict(torch.load(ckpt_path_eo, map_location=device))
86
+ model_eo.eval().to(device)
87
+ print("Encoder-Only model loaded.")
88
+
89
+ # Load the models at startup.
90
+ load_models()
91
+
92
+ # ---------- Inference Function ----------
93
+
94
+ def diacritize_text(model_type, input_text):
95
+ """
96
+ Process the input Arabic text (removing non-Arabic characters),
97
+ and run the appropriate model for diacritization.
98
+ """
99
+ if not input_text.strip():
100
+ return "Please enter some Arabic text."
101
+
102
+ # Clean the input text
103
+ text_clean = remove_non_arabic(input_text)
104
+ x = [text_clean]
105
+
106
+ batch_size = 16 # Using a fixed batch size; adjust if necessary.
107
+ verbose = False
108
+
109
+ try:
110
+ if model_type == "Encoder-Decoder":
111
+ output_lines = model_ed.do_tashkeel_batch(x, batch_size, verbose)
112
+ elif model_type == "Encoder-Only":
113
+ output_lines = model_eo.do_tashkeel_batch(x, batch_size, verbose)
114
+ else:
115
+ return "Unknown model type selected."
116
+ except Exception as e:
117
+ return f"An error occurred during diacritization: {str(e)}"
118
+
119
+ # Return the first (and only) line from the output
120
+ return output_lines[0] if output_lines else "No output produced."
121
+
122
+ # ---------- Gradio Interface ----------
123
+
124
+ title = "Arabic Diacritization with CATT"
125
+ description = (
126
+ "Enter Arabic text (without diacritics) below and select a model to perform "
127
+ "automatic diacritization. The Encoder-Decoder model may offer better accuracy, "
128
+ "while the Encoder-Only model is optimized for faster inference."
129
+ )
130
+
131
+ iface = gr.Interface(
132
+ fn=diacritize_text,
133
+ inputs=[
134
+ gr.inputs.Dropdown(
135
+ choices=["Encoder-Decoder", "Encoder-Only"],
136
+ default="Encoder-Only",
137
+ label="Model Selection"
138
+ ),
139
+ gr.inputs.Textbox(lines=4, placeholder="Enter Arabic text here...", label="Input Text")
140
+ ],
141
+ outputs=gr.outputs.Textbox(label="Diacritized Output"),
142
+ title=title,
143
+ description=description,
144
+ allow_flagging="never"
145
+ )
146
+
147
+ # ---------- Launch the Gradio App ----------
148
+
149
+ if __name__ == "__main__":
150
+ iface.launch()