# Streamlit demo app: differentially private neural machine translation (DP-SGD).
import json
import os
import random

import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and its tokenizer from pretrained checkpoints.

    Args:
        model_name: local path or hub id of the model checkpoint.
        tokenizer_name: local path or hub id of the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer
def load_json(file_path):
    """Read a UTF-8 encoded JSON file and return the parsed object."""
    with open(file_path, encoding='utf-8') as fh:
        return json.load(fh)
def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize ``input_text`` into fixed-length numpy tensors for generation.

    Pads/truncates to the module-level ``MAX_SEQ_LEN``. ``src_lang`` and
    ``tgt_lang`` are currently unused: a "translate X to Y: " task prefix
    existed at one point but is disabled for the fine-tuned checkpoints.
    """
    return tokenizer(
        input_text,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate ``input_text`` with beam search and return the decoded string."""
    inputs = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    generated = model.generate(**inputs, num_beams=NUM_BEAMS)
    decoded = tokenizer.batch_decode(generated.sequences, skip_special_tokens=True)
    return decoded[0]
def hold_deterministic(seed):
    """Seed every RNG used by the app so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's legacy global generator, and
    PyTorch (CPU plus all CUDA devices when present).

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not just the current device; silently a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
def main():
    """Streamlit entry point.

    Lets the user pick a dataset, language pair and privacy budget, loads
    the matching DP-fine-tuned checkpoint (falling back to the base MODEL
    when no local checkpoint exists) and translates the entered text.
    """
    hold_deterministic(SEED)

    st.title("Neural Machine Translation with DP-SGD")
    st.write(
        "This is a demo for private neural machine translation with DP-SGD. "
        "More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)"
    )

    # Dataset and language-pair selection, driven by the JSON metadata file.
    selected_dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
    pair_info = DATASETS_MODEL_INFO[selected_dataset]["languages pairs"]
    chosen_pair = st.selectbox("Language pair for translation", list(pair_info.keys()))
    src_lang, tgt_lang = chosen_pair.split("-")

    # Each pair maps privacy budgets (epsilon) to checkpoint directories.
    eps_table = pair_info[chosen_pair]['epsilon']
    epsilon = st.radio("Select a privacy budget epsilon", list(eps_table.keys()))

    loading_banner = st.text(f'Loading model trained on {selected_dataset} with epsilon {epsilon}...')
    checkpoint = eps_table[str(epsilon)]
    model_path = os.path.join(CHECKPOINTS_DIR, checkpoint, MODEL.split('/')[-1])
    if not os.path.exists(model_path):
        # No local fine-tuned checkpoint: fall back to the base pretrained model.
        st.error(f"Model not found. Use {MODEL} instead")
        model_path = MODEL
    model, tokenizer = load_model(model_path, tokenizer_name=MODEL)
    st.success('Model loaded!')
    loading_banner.text("")

    input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=200)
    if st.button("Translate"):
        st.write("Translation")
        st.success(translate(input_text, model, tokenizer, src_lang, tgt_lang))
if __name__ == '__main__':
    # Resolve data/checkpoint locations relative to the working directory
    # with os.path.join (portable) instead of raw "/" string concatenation.
    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "app", "dataset_and_model_info.json")
    CHECKPOINTS_DIR = os.path.join(os.getcwd(), "checkpoints")
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    # Base pretrained model: used as tokenizer source and checkpoint fallback.
    MODEL = 'google/mt5-small'
    MAX_SEQ_LEN = 512  # tokenizer pad/truncate length
    NUM_BEAMS = 3      # beam search width for generation
    SEED = 2023        # global RNG seed for reproducibility
    main()