"""Gradio demo for Arabic diacritization with the CATT models.

On import this script clones the CATT repository, downloads the two model
checkpoints, and loads both models so the inference function can use them.
"""

import os
import subprocess
import sys
import urllib.request

import torch
import gradio as gr


# ---------- Helper functions ----------

def download_file(url, dest_path):
    """Download a file from a URL to a destination path.

    Nothing is downloaded when ``dest_path`` already exists. The payload is
    first written to ``<dest_path>.part`` and only moved into place once the
    transfer finished, so an interrupted download is never mistaken for a
    complete file on the next run.

    Args:
        url: Address of the file to fetch.
        dest_path: Local path the file should end up at.

    Raises:
        OSError: If the download or the final rename fails (``URLError`` is
            a subclass of ``OSError``).
    """
    if os.path.exists(dest_path):
        print(f"File {dest_path} already exists.")
        return

    print(f"Downloading {url} to {dest_path} ...")
    tmp_path = f"{dest_path}.part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest_path)
    finally:
        # After a successful os.replace the temp file is gone; this only
        # removes the leftovers of a failed or interrupted transfer.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def clone_repo(repo_url, folder_name):
    """Clone the repo into ``folder_name`` if that folder does not exist.

    Args:
        repo_url: URL of the git repository to clone.
        folder_name: Directory that is checked for existence and used as the
            clone target.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with an error.
        FileNotFoundError: If the ``git`` executable cannot be found.
    """
    if os.path.exists(folder_name):
        print(f"Repository {folder_name} already exists.")
        return

    print(f"Cloning repository {repo_url} ...")
    # An argument list (no shell) avoids quoting/injection problems, and the
    # explicit target keeps the clone directory in sync with the check above.
    subprocess.run(["git", "clone", repo_url, folder_name], check=True)


# ---------- Setup: Clone repository and download models ----------

# Clone the catt repository (which contains the necessary modules)
clone_repo("https://github.com/abjadai/catt.git", "catt")

# Add the cloned repository to the Python path so we can import modules from it.
if "catt" not in sys.path:
    sys.path.append("catt")

# Create models folder if it does not exist
os.makedirs("models", exist_ok=True)

# URLs for model checkpoints
url_ed = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
url_eo = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"

ckpt_path_ed = os.path.join("models", "best_ed_mlm_ns_epoch_178.pt")
ckpt_path_eo = os.path.join("models", "best_eo_mlm_ns_epoch_193.pt")

download_file(url_ed, ckpt_path_ed)
download_file(url_eo, ckpt_path_eo)


# ---------- Import required modules from the cloned repository ----------
# These imports must come after the clone and the sys.path update above,
# which is why they are not at the top of the file.

# Import the tokenizer and utility functions (shared by both models)
from tashkeel_tokenizer import TashkeelTokenizer
from utils import remove_non_arabic

# Import the two model implementations under different names
from ed_pl import TashkeelModel as TashkeelModel_ED
from eo_pl import TashkeelModel as TashkeelModel_EO


# ---------- Global model initialization ----------

# Prepare tokenizer (used by both models)
tokenizer = TashkeelTokenizer()

# Determine the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# Global variables to hold the models
model_ed = None
model_eo = None


def _load_checkpoint(model, ckpt_path):
    """Load the weights at ``ckpt_path`` into ``model`` and prepare it for inference."""
    # NOTE(review): torch.load unpickles the checkpoint file; the files come
    # from the URLs above, so only point those at trusted sources.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval().to(device)
    return model


def load_models():
    """Build both models, load their checkpoints, and store them globally.

    Sets the module-level ``model_ed`` and ``model_eo``; both are put in eval
    mode and moved to ``device``.
    """
    global model_ed, model_eo
    max_seq_len = 1024

    # Load Encoder-Decoder model
    print("Loading Encoder-Decoder model...")
    model_ed = _load_checkpoint(
        TashkeelModel_ED(tokenizer, max_seq_len=max_seq_len, n_layers=3, learnable_pos_emb=False),
        ckpt_path_ed,
    )
    print("Encoder-Decoder model loaded.")

    # Load Encoder-Only model
    print("Loading Encoder-Only model...")
    model_eo = _load_checkpoint(
        TashkeelModel_EO(tokenizer, max_seq_len=max_seq_len, n_layers=6, learnable_pos_emb=False),
        ckpt_path_eo,
    )
    print("Encoder-Only model loaded.")


# Load the models at startup.
load_models()


# ---------- Inference Function ----------

def diacritize_text(model_type, input_text):
    """Clean the input text and diacritize it with the selected model.

    Args:
        model_type: Either "Encoder-Decoder" or "Encoder-Only".
        input_text: Text entered by the user; may be empty or ``None``.

    Returns:
        The first output line produced by the model, or a human-readable
        message when there is nothing to process, the model type is unknown,
        the model raised an error, or no output was produced.
    """
    if not input_text or not input_text.strip():
        return "Please enter some Arabic text."

    # Clean the input text
    text_clean = remove_non_arabic(input_text)
    if not text_clean:
        # Nothing is left after cleaning, so there is nothing to diacritize.
        return "Please enter some Arabic text."

    # Built per call so the globals are read after load_models() has run.
    models = {"Encoder-Decoder": model_ed, "Encoder-Only": model_eo}
    if model_type not in models:
        return "Unknown model type selected."

    batch_size = 16  # Using a fixed batch size; adjust if necessary.
    verbose = False
    try:
        output_lines = models[model_type].do_tashkeel_batch([text_clean], batch_size, verbose)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the app.
        return f"An error occurred during diacritization: {e}"

    # Return the first (and only) line from the output
    return output_lines[0] if output_lines else "No output produced."
# ---------- Gradio Interface ----------

title = "Arabic Diacritization with CATT"
description = (
    "Enter Arabic text (without diacritics) below and select a model to perform "
    "automatic diacritization. The Encoder-Decoder model may offer better accuracy, "
    "while the Encoder-Only model is optimized for faster inference."
)

# Components are taken from the top-level gradio namespace: the older
# gr.inputs / gr.outputs namespaces (and the `default=` keyword) were
# deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=diacritize_text,
    inputs=[
        gr.Dropdown(
            choices=["Encoder-Decoder", "Encoder-Only"],
            value="Encoder-Only",
            label="Model Selection",
        ),
        gr.Textbox(lines=4, placeholder="Enter Arabic text here...", label="Input Text"),
    ],
    outputs=gr.Textbox(label="Diacritized Output"),
    title=title,
    description=description,
    allow_flagging="never",
)


# ---------- Launch the Gradio App ----------

if __name__ == "__main__":
    iface.launch()