Spaces:
Sleeping
Sleeping
zyu
fix: resolved the issue that the input text disappears while generating translation for the first run.
3cb0c3e
import json | |
import os | |
import random | |
import re | |
import numpy as np | |
import streamlit as st | |
import torch | |
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer | |
import logging | |
# Module-level logger for this app; name follows the importing module.
logger = logging.getLogger(name=__name__)
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and its tokenizer.

    Falls back to ``DEFAULT_MODEL`` (for both model and tokenizer) when the
    requested model or tokenizer raises ``OSError``, and reports the problem
    in the UI via ``st.error``.

    Args:
        model_name: Hub id or local path of the model weights.
        tokenizer_name: Hub id or local path of the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        RuntimeError: if loading fails with any non-``OSError`` exception; the
            original exception is attached as ``__cause__``.
    """
    try:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    except OSError as e:
        # Assumption: ``from_pretrained`` signals an unresolvable id/path with
        # OSError — confirm. Fall back to the base model instead of failing.
        st.error(f"Error loading model: {e}")
        st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        # Chain explicitly so the traceback names the real cause as the direct cause.
        raise RuntimeError("Error loading model") from e
    return model, tokenizer
def load_json(file_path):
    """Read *file_path* as UTF-8 text and return the decoded JSON document."""
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)
def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize *input_text* into NumPy arrays ready for ``model.generate``.

    The text is truncated or padded to exactly ``MAX_SEQ_LEN`` tokens.
    ``src_lang`` and ``tgt_lang`` are currently unused: a
    ``"translate {src_lang} to {tgt_lang}: "`` task prefix was considered and
    disabled, so the raw text is fed to the model unchanged.
    """
    # NOTE(review): fixed-length padding presumably keeps input shapes constant
    # across calls (avoiding Flax re-compilation) — confirm before changing it.
    encoding_options = {
        "max_length": MAX_SEQ_LEN,
        "padding": "max_length",
        "truncation": True,
        "return_tensors": "np",
    }
    return tokenizer(input_text, **encoding_options)
def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate *input_text* with beam search and return the decoded string.

    Inputs are built by ``preprocess``; generation uses ``NUM_BEAMS`` beams and
    special tokens are skipped while decoding. Only the first decoded sequence
    is returned.
    """
    encoded = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    generated = model.generate(**encoded, num_beams=NUM_BEAMS)
    decoded = tokenizer.batch_decode(generated.sequences, skip_special_tokens=True)
    return decoded[0]
def hold_deterministic(seed):
    """Seed NumPy, PyTorch (CPU and CUDA) and Python's ``random`` with *seed*."""
    seeders = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed, random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
# Matches "<extra_id" followed by anything up to and including the next ">".
_EXTRA_ID_PATTERN = re.compile(r"<extra_id[^>]*>")


def postprocess(output_text):
    """Return *output_text* with every ``<extra_id...>`` token removed."""
    return _EXTRA_ID_PATTERN.sub("", output_text)
def display_ui():
    """Render the static page header and create the two-column layout.

    ``st.set_page_config`` is issued before any other element of this function
    (and ``main`` renders nothing before calling it).

    Returns:
        The ``(left, right)`` columns created by ``st.columns(2)``.
    """
    st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
    st.title("Neural Machine Translation with DP-SGD")
    # NOTE(review): both markdown links have empty link text and may render
    # invisibly (badge images may have been lost) — confirm the intended text.
    project_links = (
        "[](https://github.com/trusthlt/dp-nmt)"
        " "
        "[](https://aclanthology.org/2024.eacl-demo.11/)"
    )
    st.write(project_links)
    st.write("This is a demo for private neural machine translation with DP-SGD.")
    left_column, right_column = st.columns(2)
    return left_column, right_column
def load_selected_model(config, dataset, language_pair, epsilon, checkpoints_dir=None):
    """Resolve the checkpoint to load for a (dataset, language pair, epsilon) choice.

    Args:
        config: Parsed dataset/model info, indexed as
            ``config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]``.
        dataset: Dataset key in ``config``.
        language_pair: Language-pair key under the dataset (``"<src>-<tgt>"`` form,
            as split by ``main``).
        epsilon: Privacy budget; converted with ``str()`` for the lookup.
        checkpoints_dir: Root directory of local checkpoints. Defaults to the
            module-level ``CHECKPOINTS_DIR`` when that global is defined.

    Returns:
        The checkpoint string itself when it contains ``"privalingo"`` (a Hub id),
        otherwise ``<checkpoints_dir>/<ckpt>/<DEFAULT_MODEL basename>`` if that path
        exists, otherwise ``DEFAULT_MODEL`` (with an ``st.error`` notice).
    """
    ckpt = config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
    logger.info("Loading model from %s", ckpt)
    if "privalingo" in ckpt:
        return ckpt  # load model from huggingface hub

    if checkpoints_dir is None:
        try:
            checkpoints_dir = CHECKPOINTS_DIR
        except NameError:
            # The entry point may not define CHECKPOINTS_DIR; treat the local
            # checkpoint as missing instead of crashing with NameError.
            checkpoints_dir = None

    model_path = None
    if checkpoints_dir is not None:
        model_name = DEFAULT_MODEL.split('/')[-1]
        candidate = os.path.join(checkpoints_dir, ckpt, model_name)
        if os.path.exists(candidate):
            model_path = candidate

    if model_path is None:
        st.error(f"Model not found. Using default model: {DEFAULT_MODEL}")
        model_path = DEFAULT_MODEL
    return model_path
def init_session_state():
    """Create the ``st.session_state`` entries this app relies on.

    Existing values are never overwritten. Additionally, when a keyed button
    ("select_model_button" / "translate_button") reads ``True`` in this run,
    the matching ``*_in_progress`` flag is raised so that widgets rendered
    early in the run can already be shown as disabled.
    """
    state = st.session_state
    if 'model_state' not in state:
        state.model_state = {'loaded': False, 'current_config': None}
    for flag_name in ('translate_in_progress', 'load_model_in_progress'):
        if flag_name not in state:
            state[flag_name] = False
    button_to_flag = (
        ("select_model_button", "load_model_in_progress"),
        ("translate_button", "translate_in_progress"),
    )
    for button_key, flag_name in button_to_flag:
        if button_key in state and state[button_key] == True:  # noqa: E712
            state[flag_name] = True
    if 'translation_result' not in state:
        state.translation_result = {'input': None, 'output': None}
def get_translation_result():
    """Return the ``(input, output)`` texts used to pre-fill the right column.

    Both come from ``st.session_state.translation_result``. The input falls
    back to the literal "Enter Text Here" and the output to ``None`` when
    nothing has been stored yet.
    """
    if "translation_result" not in st.session_state:
        return "Enter Text Here", None
    stored = st.session_state.translation_result
    input_text_content = stored['input'] if stored['input'] is not None else "Enter Text Here"
    # A stored ``None`` already equals the fallback, so it passes through as-is.
    output_text_content = stored['output']
    return input_text_content, output_text_content
def set_input_text_content():
    """``on_change`` callback of the "input_text" text area.

    Copies the widget's value into ``translation_result['input']``, which
    ``get_translation_result`` hands back as the text area's initial value on
    later runs.
    """
    state = st.session_state
    if 'input_text' not in state:
        return
    state.translation_result['input'] = state.input_text
def _model_status_message():
    """Return the right-column hint describing the currently confirmed model."""
    current_config = st.session_state.model_state['current_config']
    if current_config is None:
        return "Please confirm model selection via the 'Select Model' Button first!"
    return f"Current Model: {current_config}"


def main():
    """Render the DP-NMT demo and handle the "Select Model" / "Translate" actions.

    Streamlit re-executes this function on every interaction. A button click
    is seen twice:
    - ``init_session_state`` turns it into an ``*_in_progress`` flag at the top
      of the run, so widgets drawn early are already disabled.
    - The button's return value further down triggers the actual work, which
      ends with ``st.rerun()`` to redraw the page with the flag cleared.
    """
    hold_deterministic(SEED)
    config = load_json(DATASETS_MODEL_INFO_PATH)
    left, right = display_ui()
    init_session_state()

    with right:
        right_placeholder = st.empty()
        if st.session_state.load_model_in_progress:
            # Elements from the previous run stay on screen until overwritten. Without this
            # disabled copy, the text area and "Translate" button would remain usable while
            # the model is loading.
            with right_placeholder.container():
                input_text_content, _ = get_translation_result()
                st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN, disabled=True)
                st.write(_model_status_message())
                st.button("Translate", disabled=True, use_container_width=True, key="translate_button")

    with left:
        disable = st.session_state.translate_in_progress or st.session_state.load_model_in_progress
        # Read the per-run ``config`` everywhere so the widgets and load_selected_model() agree.
        dataset = st.selectbox("Choose a dataset used for fine-tuning", list(config.keys()), disabled=disable)
        language_pairs = config[dataset]["languages pairs"]
        language_pair = st.selectbox("Language pair for translation", list(language_pairs.keys()), disabled=disable)
        epsilon_options = list(language_pairs[language_pair]['epsilon'].keys())
        epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True, disabled=disable)
        btn_select_model = st.button(
            "Select Model",
            disabled=disable,
            use_container_width=True,
            key="select_model_button")
        model_status_box = st.empty()
        if btn_select_model:
            st.session_state.load_model_in_progress = True
            current_config = f"{dataset}_{language_pair}_{epsilon}"
            st.session_state.model_state['loaded'] = False
            model_status_box.write("")
            try:
                with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
                    model_path = load_selected_model(config, dataset, language_pair, epsilon)
                    # The returned objects are not kept: this run only confirms that the
                    # selection loads. The Translate path below loads it again.
                    load_model(model_path, tokenizer_name=DEFAULT_MODEL)
                model_status_box.success('Model loaded!')
                st.session_state.model_state['current_config'] = current_config
            finally:
                # Always clear the flag. A stale True (e.g. after load_model raised) would keep
                # every widget disabled and make the next run draw the disabled placeholder
                # above plus the real right column, duplicating the "translate_button" widget.
                st.session_state.load_model_in_progress = False
            st.rerun()

    with right_placeholder.container():
        disable = st.session_state.load_model_in_progress or st.session_state.translate_in_progress
        input_text_content, output_text_content = get_translation_result()
        input_text = st.text_area(
            "Enter Text",
            input_text_content,
            max_chars=MAX_INPUT_LEN,
            disabled=disable,
            key="input_text",
            on_change=set_input_text_content,
        )
        st.write(_model_status_message())
        btn_translate = st.button("Translate",
                                  disabled=disable,
                                  use_container_width=True,
                                  key="translate_button")
        result_container = st.empty()
        if output_text_content is not None and not st.session_state.translate_in_progress:
            with result_container.container():
                st.write("**Translation:**")
                # Nest the bordered box inside the active container. Calling
                # result_container.container() a second time would replace the st.empty
                # slot's content and drop the header written just above.
                output_container = st.container(border=True)
                output_container.write(postprocess(output_text_content))

    # Translate was clicked (and Select Model was not). Reload the previously
    # confirmed model; load_model has no memoization in this file, and the
    # translation is then run on the current input text.
    if not st.session_state.select_model_button and st.session_state.translate_button:
        try:
            model_config = st.session_state.model_state['current_config']
            if model_config is None:
                # Translate clicked before any model was confirmed. Refresh the page; the
                # ``finally`` below clears the in-progress flag so the UI is not stuck.
                st.rerun()
            # current_config is f"{dataset}_{language_pair}_{epsilon}". rsplit keeps dataset
            # names containing "_" intact (assumes the pair and epsilon contain no "_").
            dataset, language_pair, epsilon = model_config.rsplit("_", 2)
            # Use the language pair the loaded model was confirmed with, not whatever the
            # selectbox currently shows.
            src_lang, tgt_lang = language_pair.split("-")
            model_path = load_selected_model(config, dataset, language_pair, epsilon)
            model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
            st.session_state.model_state['loaded'] = True
            if btn_translate:
                st.session_state.translate_in_progress = True
                with right:
                    with st.spinner("Translating..."):
                        prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
                st.session_state.translation_result['input'] = input_text
                st.session_state.translation_result['output'] = prediction
        finally:
            # Always clear the flag. After an exception, a stale True would otherwise keep
            # every widget disabled on all later runs.
            st.session_state.translate_in_progress = False
        if btn_translate:
            st.rerun()
if __name__ == '__main__':
    # Entry point. The constants below are module globals read by the functions above.

    # Without a handler and an INFO level, the module's logger.info(...) calls are
    # dropped (the root logger defaults to WARNING). Other libraries stay at WARNING.
    logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s: %(message)s")
    logger.setLevel(logging.INFO)

    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
    logger.info(DATASETS_MODEL_INFO_PATH)
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)

    # Root of locally stored fine-tuned checkpoints, read by load_selected_model().
    # It was previously never defined, so selecting a non-hub checkpoint raised NameError.
    # NOTE(review): the default location is an assumption; override via $CHECKPOINTS_DIR.
    CHECKPOINTS_DIR = os.environ.get("CHECKPOINTS_DIR", os.path.join(os.getcwd(), "checkpoints"))

    DEFAULT_MODEL = 'google/mt5-small'
    MAX_SEQ_LEN = 512
    NUM_BEAMS = 3
    SEED = 2023
    MAX_INPUT_LEN = 500
    main()