import json
import os
import random

import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer


@st.cache_resource
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and its tokenizer.

    Wrapped in ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects for the same arguments instead of reloading them.

    Args:
        model_name: Checkpoint directory or hub id passed to
            ``FlaxAutoModelForSeq2SeqLM.from_pretrained``.
        tokenizer_name: Hub id or path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer


def load_json(file_path):
    """Read and parse the UTF-8 encoded JSON file at ``file_path``."""
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize ``input_text`` into NumPy arrays suitable for ``model.generate``.

    The input is truncated/padded to exactly ``MAX_SEQ_LEN`` tokens.
    ``src_lang`` and ``tgt_lang`` are accepted for interface stability but are
    currently unused: no "translate X to Y:" task prefix is prepended.

    Returns:
        The tokenizer's output (``input_ids``, ``attention_mask``, ...).
    """
    return tokenizer(
        input_text,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )


def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate a single string with beam search and return the decoded text.

    Args:
        input_text: Source-language text.
        model: Flax seq2seq model returned by ``load_model``.
        tokenizer: Tokenizer matching ``model``.
        src_lang: Source language code (forwarded to ``preprocess``).
        tgt_lang: Target language code (forwarded to ``preprocess``).

    Returns:
        The best beam, decoded without special tokens.
    """
    model_inputs = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    # Without an explicit max_length, generate() uses the generation config,
    # whose library default is 20 tokens, so longer translations would be
    # cut off mid-sentence. Allow outputs as long as the accepted input.
    model_outputs = model.generate(
        **model_inputs, num_beams=NUM_BEAMS, max_length=MAX_SEQ_LEN
    )
    prediction = tokenizer.batch_decode(
        model_outputs.sequences, skip_special_tokens=True
    )
    return prediction[0]


def hold_deterministic(seed):
    """Seed NumPy, PyTorch (CPU and CUDA) and Python's ``random`` with ``seed``.

    NOTE(review): the model is a Flax/JAX model, whose randomness is driven by
    explicit PRNG keys rather than these global seeds -- confirm whether this
    is still needed.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)


def main():
    """Render the Streamlit demo: pick a fine-tuned model, then translate text."""
    hold_deterministic(SEED)
    st.title("Neural Machine Translation with DP-SGD")
    st.write(
        "This is a demo for private neural machine translation with DP-SGD. "
        "More detail can be found in the "
        "[repository](https://github.com/trusthlt/dp-nmt)"
    )

    dataset = st.selectbox(
        "Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys())
    )
    pairs_info = DATASETS_MODEL_INFO[dataset]["languages pairs"]
    language_pair = st.selectbox(
        "Language pair for translation", list(pairs_info.keys())
    )
    src_lang, tgt_lang = language_pair.split("-")

    epsilon_info = pairs_info[language_pair]["epsilon"]
    epsilon = st.radio(
        "Select a privacy budget epsilon", list(epsilon_info.keys())
    )

    st_model_load = st.text(
        f"Loading model trained on {dataset} with epsilon {epsilon}..."
    )
    ckpt = epsilon_info[str(epsilon)]
    model_name = MODEL.split("/")[-1]
    model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
    if not os.path.exists(model_path):
        # Not fatal: the app falls back to the base model, so warn, don't error.
        st.warning(f"Model not found. Using {MODEL} instead")
        model_path = MODEL
    model, tokenizer = load_model(model_path, tokenizer_name=MODEL)
    st.success("Model loaded!")
    st_model_load.text("")

    # `placeholder` (not the initial value) so the hint is never translated.
    input_text = st.text_area(
        "Enter Text", placeholder="Enter Text Here", max_chars=200
    )
    if st.button("Translate"):
        if not input_text.strip():
            st.warning("Please enter some text to translate.")
        else:
            st.write("Translation")
            prediction = translate(
                input_text, model, tokenizer, src_lang, tgt_lang
            )
            st.success(prediction)


if __name__ == "__main__":
    DATASETS_MODEL_INFO_PATH = os.path.join(
        os.getcwd(), "app", "dataset_and_model_info.json"
    )
    CHECKPOINTS_DIR = os.path.join(os.getcwd(), "checkpoints")
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    MODEL = "google/mt5-small"
    MAX_SEQ_LEN = 512
    NUM_BEAMS = 3
    SEED = 2023
    main()