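"""Streamlit demo for DP-NMT: neural machine translation fine-tuned with DP-SGD.

Loads Flax seq2seq checkpoints, selected by dataset, language pair, and privacy
budget epsilon, either from local checkpoints or from the Hugging Face Hub, and
translates user input.
"""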

import json
import os
import random
import re

import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer


@st.cache_resource(show_spinner=False)
def load_model(model_name, tokenizer_name):
    # Cached by Streamlit so each checkpoint is loaded at most once per session.
    try:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.error(f"Model not found. Falling back to {DEFAULT_MODEL}.")
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    return model, tokenizer


def load_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data
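
# Illustrative shape of dataset_and_model_info.json, inferred from how
# DATASETS_MODEL_INFO is indexed in main(); the dataset name, pair, and
# checkpoint values below are placeholders, not the file's real contents:
# {
#     "WMT16": {
#         "languages pairs": {
#             "de-en": {
#                 "epsilon": {
#                     "1": "relative/path/to/local/checkpoint",
#                     "inf": "privalingo/hub-model-id"
#                 }
#             }
#         }
#     }
# }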


def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    # src_lang and tgt_lang are only consumed by the optional task prefix
    # below, which is left disabled.
    # task_prefix = f"translate {src_lang} to {tgt_lang}: "
    # input_text = task_prefix + input_text
    model_inputs = tokenizer(
        input_text, max_length=MAX_SEQ_LEN, padding="max_length", truncation=True, return_tensors="np"
    )
    return model_inputs
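
# The returned BatchEncoding holds "input_ids" and "attention_mask" as NumPy
# arrays of shape (1, MAX_SEQ_LEN), ready to be splatted into model.generate().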


def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    model_inputs = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    # Flax generate() returns an output object whose token ids are in .sequences.
    model_outputs = model.generate(**model_inputs, num_beams=NUM_BEAMS)
    prediction = tokenizer.batch_decode(model_outputs.sequences, skip_special_tokens=True)
    return prediction[0]
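
# Illustrative usage (the output string is made up; the actual translation
# depends on the selected checkpoint):
#     model, tokenizer = load_model(DEFAULT_MODEL, DEFAULT_MODEL)
#     translate("Hello world", model, tokenizer, "en", "de")  # -> e.g. "Hallo Welt"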


def hold_deterministic(seed):
    # Seed every RNG the app may touch so repeated runs produce the same output.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)


def postprocess(output_text):
    # Strip mT5 sentinel tokens such as <extra_id_0> from the decoded text.
    output = re.sub(r"<extra_id[^>]*>", "", output_text)
    return output
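
# Example: postprocess("<extra_id_0> Hallo Welt") returns " Hallo Welt".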


def main():
    hold_deterministic(SEED)
    st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
    st.title("Neural Machine Translation with DP-SGD")
    st.write(
        "[GitHub](https://github.com/trusthlt/dp-nmt)"
        " "
        "[Paper](https://aclanthology.org/2024.eacl-demo.11/)"
    )
    st.write("This is a demo for private neural machine translation with DP-SGD.")
    left, right = st.columns(2)
    with left:
        dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
        language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
        language_pair = st.selectbox("Language pair for translation", language_pairs_list)
        src_lang, tgt_lang = language_pair.split("-")
        epsilon_options = list(DATASETS_MODEL_INFO[dataset]["languages pairs"][language_pair]["epsilon"].keys())
        epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
    with right:
        input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=MAX_INPUT_LEN)
        btn_translate = st.button("Translate")
    ckpt = DATASETS_MODEL_INFO[dataset]["languages pairs"][language_pair]["epsilon"][str(epsilon)]
    if "privalingo" in ckpt:
        # The checkpoint is loaded from the Hugging Face Hub rather than from a local directory.
        model_path = ckpt
    else:
        model_name = DEFAULT_MODEL.split("/")[-1]
        model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
        if not os.path.exists(model_path):
            with left:
                st.error(f"Model not found. Falling back to {DEFAULT_MODEL}.")
            model_path = DEFAULT_MODEL
    with left:
        with st.spinner(f"Loading model trained on {dataset} with epsilon {epsilon}..."):
            model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
        sc_box_model_loaded = st.success("Model loaded!")
    if btn_translate:
        with right:
            with st.spinner("Translating..."):
                prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
            st.write("**Translation:**")
            result_container = st.container(border=True)
            result_container.write(postprocess(prediction))


if __name__ == "__main__":
    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
    print(DATASETS_MODEL_INFO_PATH)
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    # CHECKPOINTS_DIR is referenced in main() but was missing from the extracted
    # file; a "checkpoints" directory under the working directory is an assumption.
    CHECKPOINTS_DIR = os.path.join(os.getcwd(), "checkpoints")
    DEFAULT_MODEL = "google/mt5-small"
    MAX_SEQ_LEN = 512
    NUM_BEAMS = 3
    SEED = 2023
    MAX_INPUT_LEN = 500
    main()