import json
import os
import random
import re

import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer


@st.cache_resource(show_spinner=False)
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and tokenizer, falling back to DEFAULT_MODEL on failure."""
    try:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.error(f"Model not found. Using {DEFAULT_MODEL} instead.")
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    return model, tokenizer


def load_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data


def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    # task_prefix = f"translate {src_lang} to {tgt_lang}: "
    # input_text = task_prefix + input_text
    model_inputs = tokenizer(
        input_text,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    return model_inputs


def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    model_inputs = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    model_outputs = model.generate(**model_inputs, num_beams=NUM_BEAMS)
    prediction = tokenizer.batch_decode(model_outputs.sequences, skip_special_tokens=True)
    return prediction[0]


def hold_deterministic(seed):
    # Seed every RNG the app may touch so repeated runs produce the same output.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)


def postprocess(output_text):
    # Strip any leftover tag-like markup (e.g. sentinel tokens such as <extra_id_0>).
    output = re.sub(r"<[^>]*>", "", output_text)
    return output


def main():
    hold_deterministic(SEED)
    st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
    st.title("Neural Machine Translation with DP-SGD")
    st.write(
        "[![Star](https://img.shields.io/github/stars/trusthlt/dp-nmt.svg?logo=github&style=social)](https://github.com/trusthlt/dp-nmt)"
        "   "
        "[![ACL](https://img.shields.io/badge/ACL-Link-red.svg?logo=data:image/svg%2bxml;base64,PHN2ZyBoZWlnaHQ9IjI2MC4wOTA0ODIiIHZpZXdCb3g9IjAgMCA2OCA0NiIgd2lkdGg9IjM4NC40ODE1ODIiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyI+PHBhdGggZD0ibTQxLjk3NzU1MyAwdjMuMDE1OGgtMzQuNDkwNjQ3Ni03LjQ4NjkwNTR2Ny40ODQ5OSAyNy45Nzc4OCA3LjUyMTMzaDcuNDg2OTA1NCA0Mi4wMTM4OTY2IDcuNDg2OTA2IDExLjAxMjI5MnYtMTUuMDA2MzJoLTExLjAxMjI5MnYtMjAuNDkyODktNy40ODQ5OWMwLTEuNTczNjkgMC0xLjI1NDAyIDAtMy4wMTU4em0tMjYuOTY3Mzk4IDE3Ljk4NTc4aDI2Ljk2NzM5OHYxMy4wMDc5aC0yNi45NjczOTh6IiBmaWxsPSIjZWQxYzI0IiBmaWxsLXJ1bGU9ImV2ZW5vZGQiLz48L3N2Zz4=&link=https%3A%2F%2Faclanthology.org%2F2024.eacl-demo.11%2F)](https://aclanthology.org/2024.eacl-demo.11/)"
    )
    st.write("This is a demo for private neural machine translation with DP-SGD.")

    left, right = st.columns(2)
    with left:
        dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
        language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
        language_pair = st.selectbox("Language pair for translation", language_pairs_list)
        src_lang, tgt_lang = language_pair.split("-")
        epsilon_options = list(DATASETS_MODEL_INFO[dataset]["languages pairs"][language_pair]["epsilon"].keys())
        epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
    with right:
        input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=MAX_INPUT_LEN)
        btn_translate = st.button("Translate")

    ckpt = DATASETS_MODEL_INFO[dataset]["languages pairs"][language_pair]["epsilon"][str(epsilon)]
    if "privalingo" in ckpt:
        # The checkpoint is loaded from the Hugging Face Hub rather than from local checkpoints.
        model_path = ckpt
    else:
        model_name = DEFAULT_MODEL.split("/")[-1]
        model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
        if not os.path.exists(model_path):
            with left:
                st.error(f"Model not found. Using {DEFAULT_MODEL} instead.")
            model_path = DEFAULT_MODEL

    with left:
        with st.spinner(f"Loading model trained on {dataset} with epsilon {epsilon}..."):
            model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
        st.success("Model loaded!")

    if btn_translate:
        with right:
            with st.spinner("Translating..."):
                prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
            st.write("**Translation:**")
            result_container = st.container(border=True)
            result_container.write(postprocess(prediction))


if __name__ == '__main__':
    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
    print(DATASETS_MODEL_INFO_PATH)
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    # Assumption: local checkpoints live in a "checkpoints" directory under the
    # working directory; this constant was referenced but never defined in the snippet.
    CHECKPOINTS_DIR = os.path.join(os.getcwd(), "checkpoints")
    DEFAULT_MODEL = 'google/mt5-small'
    MAX_SEQ_LEN = 512
    NUM_BEAMS = 3
    SEED = 2023
    MAX_INPUT_LEN = 500
    main()
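# ---------------------------------------------------------------------------
# Illustrative sketch only (not the actual file shipped with the repo): the
# demo expects `dataset_and_model_info.json` in the working directory, mapping
# each dataset to its language pairs and, per pair, a checkpoint for every
# privacy budget epsilon. Dataset names, pairs, and checkpoint identifiers
# below are hypothetical placeholders inferred from the lookups in main():
#
#   {
#       "ExampleDataset": {
#           "languages pairs": {
#               "de-en": {
#                   "epsilon": {
#                       "1": "privalingo/hypothetical-hub-checkpoint-eps1",
#                       "1000": "hypothetical-local-run-eps1000"
#                   }
#               }
#           }
#       }
#   }
#
# Entries containing "privalingo" are passed to from_pretrained() as Hub IDs;
# anything else is resolved under CHECKPOINTS_DIR. The app is launched with
# `streamlit run <this script>`.
# ---------------------------------------------------------------------------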