# PrivalingoDemo / app.py
# Commit 10a832a ("init app") by Zhuo; file size 3.48 kB.
import json
import os
import random
import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
@st.cache_resource
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model together with a tokenizer.

    The function is wrapped in ``st.cache_resource``.

    Args:
        model_name: Name or path passed to
            ``FlaxAutoModelForSeq2SeqLM.from_pretrained``.
        tokenizer_name: Name or path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Tuple elements are evaluated left to right, so the model is loaded
    # first and the tokenizer second.
    return (
        FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(tokenizer_name),
    )
def load_json(file_path):
    """Parse the UTF-8 encoded JSON file at ``file_path``.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON content.
    """
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)
def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize ``input_text`` into fixed-length model inputs.

    The text is truncated and padded to ``MAX_SEQ_LEN`` tokens and the
    tokenizer is asked for NumPy tensors.

    Args:
        input_text: Text to tokenize.
        tokenizer: Callable tokenizer invoked on ``input_text``.
        src_lang: Currently unused; no task prefix is prepended.
        tgt_lang: Currently unused; no task prefix is prepended.

    Returns:
        Whatever the tokenizer call returns.
    """
    tokenizer_options = {
        "max_length": MAX_SEQ_LEN,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "np",
    }
    return tokenizer(input_text, **tokenizer_options)
def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate ``input_text`` and return the first decoded sequence.

    Args:
        input_text: Text to translate.
        model: Object exposing ``generate``; called with the tokenized
            inputs and ``num_beams=NUM_BEAMS``.
        tokenizer: Tokenizer used both for encoding (via ``preprocess``)
            and for ``batch_decode``.
        src_lang: Forwarded to ``preprocess``.
        tgt_lang: Forwarded to ``preprocess``.

    Returns:
        The first element of the decoded batch, with special tokens
        skipped.
    """
    encoded = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    generated = model.generate(**encoded, num_beams=NUM_BEAMS)
    return tokenizer.batch_decode(
        generated.sequences, skip_special_tokens=True
    )[0]
def hold_deterministic(seed):
    """Seed the NumPy, PyTorch (CPU and CUDA) and ``random`` generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def main():
    """Render the Streamlit demo page and run a translation on request.

    The page lets the user pick a dataset, a language pair and a privacy
    budget epsilon from ``DATASETS_MODEL_INFO``, loads the matching
    checkpoint from ``CHECKPOINTS_DIR`` (falling back to ``MODEL`` when the
    checkpoint path does not exist), and translates the entered text when
    the "Translate" button is pressed.
    """
    hold_deterministic(SEED)

    st.title("Neural Machine Translation with DP-SGD")
    st.write("This is a demo for private neural machine translation with DP-SGD. More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)")

    # Each selection narrows the options offered by the next widget.
    dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO))
    pairs_info = DATASETS_MODEL_INFO[dataset]["languages pairs"]
    language_pair = st.selectbox("Language pair for translation", list(pairs_info))
    src_lang, tgt_lang = language_pair.split("-")
    checkpoint_by_epsilon = pairs_info[language_pair]['epsilon']
    epsilon = st.radio("Select a privacy budget epsilon", list(checkpoint_by_epsilon))

    # Placeholder text shown while loading; blanked once the model is ready.
    loading_notice = st.text(f'Loading model trained on {dataset} with epsilon {epsilon}...')
    ckpt = checkpoint_by_epsilon[str(epsilon)]
    model_path = os.path.join(CHECKPOINTS_DIR, ckpt, MODEL.split('/')[-1])
    if not os.path.exists(model_path):
        # No local checkpoint: report it and load the base model instead.
        st.error(f"Model not found. Use {MODEL} instead")
        model_path = MODEL
    model, tokenizer = load_model(model_path, tokenizer_name=MODEL)
    st.success('Model loaded!')
    loading_notice.text("")

    input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=200)
    if not st.button("Translate"):
        return
    st.write("Translation")
    st.success(translate(input_text, model, tokenizer, src_lang, tgt_lang))
if __name__ == '__main__':
    # Settings read as module-level globals by the functions in this file.
    SEED = 2023
    NUM_BEAMS = 3
    MAX_SEQ_LEN = 512
    MODEL = 'google/mt5-small'

    # Both locations are resolved against the current working directory.
    CHECKPOINTS_DIR = f"{os.getcwd()}/checkpoints"
    DATASETS_MODEL_INFO_PATH = f"{os.getcwd()}/app/dataset_and_model_info.json"
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)

    main()