"""Streamlit demo: neural machine translation with models fine-tuned via DP-SGD."""

import json
import os
import random

import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer

# Module-level configuration. These are read as globals by the helpers below,
# so they are defined here (not only under the ``__main__`` guard) to keep the
# functions usable when this module is imported.
DEFAULT_MODEL = 'google/mt5-small'
MAX_SEQ_LEN = 512
NUM_BEAMS = 3
SEED = 2023


@st.cache_resource
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and a tokenizer.

    The result is cached by Streamlit via ``st.cache_resource``.

    Args:
        model_name: Name or path passed to
            ``FlaxAutoModelForSeq2SeqLM.from_pretrained``.
        tokenizer_name: Name or path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer


def load_json(file_path):
    """Read the UTF-8 encoded JSON file at *file_path* and return its content."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize *input_text* into NumPy model inputs.

    The text is padded and truncated to ``MAX_SEQ_LEN`` tokens.

    Args:
        input_text: Text to tokenize.
        tokenizer: Callable tokenizer (as returned by ``load_model``).
        src_lang: Source language code. Currently unused, see note below.
        tgt_lang: Target language code. Currently unused, see note below.

    Returns:
        The tokenizer output with ``return_tensors="np"``.
    """
    # The task prefix is disabled, which is why src_lang/tgt_lang are unused:
    # task_prefix = f"translate {src_lang} to {tgt_lang}: "
    # input_text = task_prefix + input_text
    # NOTE(review): presumably the fixed-length padding keeps the input shape
    # constant for the Flax model - confirm before changing it.
    return tokenizer(
        input_text,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )


def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate *input_text* with beam search and return the decoded string.

    Args:
        input_text: Text to translate.
        model: Model exposing ``generate`` (as returned by ``load_model``).
        tokenizer: Tokenizer used for both encoding and decoding.
        src_lang: Source language code, forwarded to ``preprocess``.
        tgt_lang: Target language code, forwarded to ``preprocess``.

    Returns:
        The first decoded sequence, with special tokens removed.
    """
    model_inputs = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    model_outputs = model.generate(**model_inputs, num_beams=NUM_BEAMS)
    decoded = tokenizer.batch_decode(
        model_outputs.sequences, skip_special_tokens=True
    )
    # Only one text is passed in, so only the first decoded entry is returned.
    return decoded[0]


def hold_deterministic(seed):
    """Seed the NumPy, PyTorch (CPU and CUDA) and ``random`` generators."""
    # NOTE(review): the model is loaded through Flax; only the generators
    # listed here are seeded - confirm whether that is sufficient.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)


def main():
    """Render the Streamlit page and run translation on request.

    Reads the module-level ``DATASETS_MODEL_INFO`` mapping, indexed as
    ``[dataset]["languages pairs"][language_pair]["epsilon"][epsilon]`` to
    obtain a model path. If that path does not exist on disk, an error is
    shown and ``DEFAULT_MODEL`` is loaded instead.
    """
    hold_deterministic(SEED)

    st.title("Neural Machine Translation with DP-SGD")
    st.write(
        "This is a demo for private neural machine translation with DP-SGD. \n"
        "More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)"
    )

    dataset = st.selectbox(
        "Choose a dataset used for fine-tuning",
        list(DATASETS_MODEL_INFO.keys()),
    )
    language_pairs = DATASETS_MODEL_INFO[dataset]["languages pairs"]
    language_pair = st.selectbox(
        "Language pair for translation", list(language_pairs.keys())
    )
    # Language pair keys are expected to look like "<src>-<tgt>".
    src_lang, tgt_lang = language_pair.split("-")

    epsilon_to_model = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon']
    epsilon = st.radio(
        "Select a privacy budget epsilon", list(epsilon_to_model.keys())
    )

    st_model_load = st.text(f'Loading model trained on {dataset} with epsilon {epsilon}...')

    model_path = epsilon_to_model[str(epsilon)]
    if not os.path.exists(model_path):
        # Fall back to the default checkpoint when the fine-tuned one is absent.
        st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
        model_path = DEFAULT_MODEL

    # The tokenizer is always loaded from DEFAULT_MODEL, whatever model is used.
    model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
    st.success('Model loaded!')
    st_model_load.text("")

    input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=200)

    if st.button("Translate"):
        st.write("Translation")
        prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
        st.success(prediction)


if __name__ == '__main__':
    # The metadata file is looked up in the current working directory.
    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    main()