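"""Gradio demo for Arabic diacritization (tashkeel) using the CATT models.

On startup this script clones https://github.com/abjadai/catt, downloads the
released Encoder-Decoder and Encoder-Only checkpoints, loads both models, and
serves a simple web UI for diacritizing Arabic text.
"""
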
import os
import sys
import urllib.request
import torch
import gradio as gr

# ---------- Helper functions ----------

def download_file(url, dest_path):
    """Download a file from a URL to a destination path."""
    if not os.path.exists(dest_path):
        print(f"Downloading {url} to {dest_path} ...")
        urllib.request.urlretrieve(url, dest_path)
    else:
        print(f"File {dest_path} already exists.")

def clone_repo(repo_url, folder_name):
    """Clone the repo into folder_name if it does not already exist."""
    if not os.path.exists(folder_name):
        print(f"Cloning repository {repo_url} ...")
        # Clone explicitly into folder_name so the existence check above
        # always matches the clone target.
        os.system(f"git clone {repo_url} {folder_name}")
    else:
        print(f"Repository {folder_name} already exists.")

# ---------- Setup: Clone repository and download models ----------

# Clone the catt repository (which contains the necessary modules)
clone_repo("https://github.com/abjadai/catt.git", "catt")

# Add the cloned repository to the Python path so we can import modules from it.
if "catt" not in sys.path:
    sys.path.append("catt")

# Create the models folder if it does not already exist.
os.makedirs("models", exist_ok=True)

# URLs for model checkpoints
url_ed = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
url_eo = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"

ckpt_path_ed = os.path.join("models", "best_ed_mlm_ns_epoch_178.pt")
ckpt_path_eo = os.path.join("models", "best_eo_mlm_ns_epoch_193.pt")

download_file(url_ed, ckpt_path_ed)
download_file(url_eo, ckpt_path_eo)

# ---------- Import required modules from the cloned repository ----------

from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic
from ed_pl import TashkeelModel as TashkeelModel_ED
from eo_pl import TashkeelModel as TashkeelModel_EO

# ---------- Global model initialization ----------

# Prepare tokenizer (used by both models)
tokenizer = TashkeelTokenizer()

# Determine the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Global variables to hold the models
model_ed = None
model_eo = None

def load_models():
    global model_ed, model_eo
    max_seq_len = 1024

    # Load Encoder-Decoder model
    print("Loading Encoder-Decoder model...")
    model_ed = TashkeelModel_ED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False)
    model_ed.load_state_dict(torch.load(ckpt_path_ed, map_location=device))
    model_ed.eval().to(device)
    print("Encoder-Decoder model loaded.")

    # Load Encoder-Only model
    print("Loading Encoder-Only model...")
    model_eo = TashkeelModel_EO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False)
    model_eo.load_state_dict(torch.load(ckpt_path_eo, map_location=device))
    model_eo.eval().to(device)
    print("Encoder-Only model loaded.")

# Load the models at startup.
load_models()

# ---------- Inference Function ----------

def diacritize_text(model_type, input_text):
    """
    Clean the input Arabic text (removing non-Arabic characters)
    and run the selected model to produce diacritized output.
    """
    if not input_text.strip():
        return "Please enter some Arabic text."

    # Clean the input text
    text_clean = remove_non_arabic(input_text)
    x = [text_clean]

    batch_size = 16  # Fixed batch size.
    verbose = False

    try:
        if model_type == "Encoder-Decoder":
            output_lines = model_ed.do_tashkeel_batch(x, batch_size, verbose)
        elif model_type == "Encoder-Only":
            output_lines = model_eo.do_tashkeel_batch(x, batch_size, verbose)
        else:
            return "Unknown model type selected."
    except Exception as e:
        return f"An error occurred during diacritization: {str(e)}"

    return output_lines[0] if output_lines else "No output produced."
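
# Example usage (with a hypothetical input string):
#   diacritize_text("Encoder-Only", "اللغة العربية")
# would return the input text with the model's predicted diacritics added.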

# ---------- Gradio Interface ----------

title = "Arabic Diacritization with CATT"
description = (
    "Enter Arabic text (without diacritics) below and select a model to perform "
    "automatic diacritization. The Encoder-Decoder model may offer better accuracy, "
    "while the Encoder-Only model is optimized for faster inference."
)

iface = gr.Interface(
    fn=diacritize_text,
    inputs=[
        gr.Dropdown(
            choices=["Encoder-Decoder", "Encoder-Only"],
            value="Encoder-Only",
            label="Model Selection"
        ),
        gr.Textbox(lines=4, placeholder="Enter Arabic text here...", label="Input Text")
    ],
    outputs=gr.Textbox(label="Diacritized Output"),
    title=title,
    description=description,
    allow_flagging="never"
)

# ---------- Launch the Gradio App ----------

if __name__ == "__main__":
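    # Tip: pass share=True to iface.launch() to expose a temporary public URL
    # (useful when running on a remote machine).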
    iface.launch()