# Streamlit demo: private neural machine translation fine-tuned with DP-SGD.
# (Removed scraped page artifacts — "Spaces", file size, commit hashes, and a
# column of line numbers — that were not part of the original source file.)
import json
import os
import random
import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
@st.cache_resource
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq checkpoint and its tokenizer, cached across reruns.

    The ``st.cache_resource`` decorator keeps the loaded objects alive for the
    Streamlit session so repeated UI interactions do not reload them.
    """
    seq2seq_model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    return seq2seq_model, tok
def load_json(file_path):
    """Parse a UTF-8 JSON file at *file_path* and return the decoded object."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return json.load(handle)
def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize *input_text* into numpy model inputs padded/truncated to MAX_SEQ_LEN.

    ``src_lang``/``tgt_lang`` are accepted for interface compatibility; the
    translate-task prefix that would use them is currently disabled.
    """
    return tokenizer(
        input_text,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Generate a translation of *input_text* with beam search and return it.

    Tokenizes the input, runs ``model.generate`` with ``NUM_BEAMS`` beams, and
    decodes the first (only) sequence with special tokens stripped.
    """
    encoded = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    generated = model.generate(**encoded, num_beams=NUM_BEAMS)
    decoded = tokenizer.batch_decode(generated.sequences, skip_special_tokens=True)
    return decoded[0]
def hold_deterministic(seed):
    """Seed every RNG source (stdlib random, numpy, torch CPU and CUDA).

    Args:
        seed: Integer seed applied to all generators for reproducible runs.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Fix: manual_seed() only seeds the *current* CUDA device; manual_seed_all()
    # seeds every device, and is a silent no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
def main():
    """Drive the Streamlit UI: choose dataset/language/epsilon, load, translate."""
    hold_deterministic(SEED)

    st.title("Neural Machine Translation with DP-SGD")
    st.write("This is a demo for private neural machine translation with DP-SGD. More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)")

    # Dataset and language-pair selection; pair_info maps pair -> epsilon -> model path.
    dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
    pair_info = DATASETS_MODEL_INFO[dataset]["languages pairs"]
    language_pair = st.selectbox("Language pair for translation", list(pair_info.keys()))
    src_lang, tgt_lang = language_pair.split("-")

    # Privacy budget chosen from the budgets available for this pair.
    epsilon = st.radio("Select a privacy budget epsilon", list(pair_info[language_pair]['epsilon'].keys()))

    st_model_load = st.text(f'Loading model trained on {dataset} with epsilon {epsilon}...')
    model_path = pair_info[language_pair]['epsilon'][str(epsilon)]
    if not os.path.exists(model_path):
        # Fall back to the public pretrained checkpoint when the fine-tuned
        # model is not present on disk.
        st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
        model_path = DEFAULT_MODEL
    model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
    st.success('Model loaded!')
    st_model_load.text("")

    input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=200)
    if st.button("Translate"):
        st.write("Translation")
        st.success(translate(input_text, model, tokenizer, src_lang, tgt_lang))
if __name__ == '__main__':
    # Config file mapping dataset -> "languages pairs" -> pair -> "epsilon" -> model path,
    # expected in the current working directory.
    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    # Fallback checkpoint used when a fine-tuned model path is missing on disk;
    # also always used as the tokenizer source.
    DEFAULT_MODEL = 'google/mt5-small'
    MAX_SEQ_LEN = 512  # tokenizer max_length for padding/truncation
    NUM_BEAMS = 3      # beam-search width for generation
    SEED = 2023        # global RNG seed for reproducibility
    main()