"""Streamlit demo for neural machine translation models trained with DP-SGD."""
import json
import logging
import os
import random
import re

import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

# Pure configuration constants live at module level so that the functions
# below do not depend on the ``__main__`` guard having been executed.
DEFAULT_MODEL = 'google/mt5-small'
MAX_SEQ_LEN = 512
NUM_BEAMS = 3
SEED = 2023
MAX_INPUT_LEN = 500
DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
# NOTE(review): the original script referenced CHECKPOINTS_DIR without ever
# defining it (NameError for every non-hub checkpoint). The default directory
# below is an assumption -- confirm the real location or set the env variable.
CHECKPOINTS_DIR = os.environ.get(
    "CHECKPOINTS_DIR", os.path.join(os.getcwd(), "checkpoints")
)

# Matches any ``<...>`` span in the decoded model output.
# NOTE(review): the pattern in the original was the truncated ``]*>``; this is
# the minimal reconstruction -- confirm it matches the intended tokens.
_TAG_PATTERN = re.compile(r"<[^>]*>")


@st.cache_resource(show_spinner=False)
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and its tokenizer (cached by Streamlit).

    Args:
        model_name: Identifier or path passed to
            ``FlaxAutoModelForSeq2SeqLM.from_pretrained``.
        tokenizer_name: Identifier or path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. If loading raises ``OSError``, both
        are loaded from ``DEFAULT_MODEL`` instead.

    Raises:
        RuntimeError: If loading fails with any exception other than
            ``OSError``.
    """
    try:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    except OSError as e:
        st.error(f"Error loading model: {e}")
        st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(DEFAULT_MODEL)
        tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        # Chain the cause so the underlying failure stays in the traceback.
        raise RuntimeError("Error loading model") from e
    return model, tokenizer


def load_json(file_path):
    """Read the UTF-8 encoded JSON file at *file_path* and return its content."""
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize *input_text* into NumPy tensors padded/truncated to MAX_SEQ_LEN.

    ``src_lang`` and ``tgt_lang`` are accepted for interface compatibility but
    are currently unused: no ``translate X to Y:`` task prefix is prepended.
    """
    return tokenizer(
        input_text,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )


def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate *input_text* with beam search and return the first decoded string."""
    model_inputs = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    model_outputs = model.generate(**model_inputs, num_beams=NUM_BEAMS)
    prediction = tokenizer.batch_decode(
        model_outputs.sequences, skip_special_tokens=True
    )
    return prediction[0]


def hold_deterministic(seed):
    """Seed the NumPy, PyTorch (CPU and CUDA) and ``random`` generators."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)


def postprocess(output_text):
    """Return *output_text* with every ``<...>`` span removed."""
    return _TAG_PATTERN.sub("", output_text)


def display_ui():
    """Configure the page, render the header and return the two layout columns."""
    st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
    st.title("Neural Machine Translation with DP-SGD")
    st.write(
        "[![Star](https://img.shields.io/github/stars/trusthlt/dp-nmt.svg?logo=github&style=social)](https://github.com/trusthlt/dp-nmt)"
        "   "
        "[![ACL](https://img.shields.io/badge/ACL-Link-red.svg?logo=&link=https%3A%2F%2Faclanthology.org%2F2024.eacl-demo.11%2F)](https://aclanthology.org/2024.eacl-demo.11/)"
    )
    st.write("This is a demo for private neural machine translation with DP-SGD.")
    left, right = st.columns(2)
    return left, right


def load_selected_model(config, dataset, language_pair, epsilon):
    """Resolve the model path for the selected dataset, language pair and epsilon.

    The checkpoint name is looked up in *config*. Names containing
    ``"privalingo"`` are returned unchanged (to be loaded from the Hugging
    Face hub); any other name is resolved below ``CHECKPOINTS_DIR``. If the
    resolved local path does not exist, an error is shown and
    ``DEFAULT_MODEL`` is returned.
    """
    ckpt = config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
    logger.info("Loading model from %s", ckpt)
    if "privalingo" in ckpt:
        return ckpt  # load model from huggingface hub

    model_name = DEFAULT_MODEL.split('/')[-1]
    model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
    if not os.path.exists(model_path):
        st.error(f"Model not found. Using default model: {DEFAULT_MODEL}")
        model_path = DEFAULT_MODEL
    return model_path


def init_session_state():
    """Create missing session-state entries and mirror button clicks into flags."""
    state = st.session_state
    if 'model_state' not in state:
        state.model_state = {'loaded': False, 'current_config': None}
    if 'translate_in_progress' not in state:
        state.translate_in_progress = False
    if "load_model_in_progress" not in state:
        state.load_model_in_progress = False
    # A button's keyed state is True only on the run triggered by its click.
    if "select_model_button" in state and state.select_model_button:
        state.load_model_in_progress = True
    if 'translate_button' in state and state.translate_button:
        state.translate_in_progress = True
    if 'translation_result' not in state:
        state.translation_result = {'input': None, 'output': None}


def get_translation_result():
    """Return ``(input_text, output_text)`` stored in the session state.

    The input falls back to the placeholder ``"Enter Text Here"`` and the
    output to ``None`` when nothing has been stored yet.
    """
    stored = st.session_state.get("translation_result") or {}
    input_text_content = stored.get('input')
    if input_text_content is None:
        input_text_content = "Enter Text Here"
    return input_text_content, stored.get('output')


def set_input_text_content():
    """``on_change`` callback: remember the current content of the input widget."""
    if 'input_text' in st.session_state:
        st.session_state.translation_result['input'] = st.session_state.input_text


def _model_status_message():
    """Return the status line naming the selected model, or a hint to select one."""
    current_config = st.session_state.model_state['current_config']
    if current_config is None:
        return "Please confirm model selection via the 'Select Model' Button first!"
    return f"Current Model: {current_config}"


def main():
    """Render the demo page and handle model selection and translation requests."""
    hold_deterministic(SEED)
    config = load_json(DATASETS_MODEL_INFO_PATH)
    left, right = display_ui()
    init_session_state()

    with right:
        right_placeholder = st.empty()

    if st.session_state.load_model_in_progress:
        # Overwrite the right column with disabled widgets. Without this, the
        # widgets from the previous run would stay available for the user to
        # interact with while the model is being loaded.
        with right_placeholder.container():
            input_text_content, output_text_content = get_translation_result()
            st.text_area("Enter Text", input_text_content,
                         max_chars=MAX_INPUT_LEN, disabled=True)
            st.write(_model_status_message())
            st.button("Translate", disabled=True, use_container_width=True,
                      key="translate_button")

    with left:
        disable = (st.session_state.translate_in_progress
                   or st.session_state.load_model_in_progress)
        dataset = st.selectbox("Choose a dataset used for fine-tuning",
                               list(DATASETS_MODEL_INFO.keys()),
                               disabled=disable)
        language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
        language_pair = st.selectbox("Language pair for translation",
                                     language_pairs_list, disabled=disable)
        src_lang, tgt_lang = language_pair.split("-")
        epsilon_options = list(
            DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys()
        )
        epsilon = st.radio("Select a privacy budget epsilon", epsilon_options,
                           horizontal=True, disabled=disable)
        btn_select_model = st.button(
            "Select Model", disabled=disable, use_container_width=True,
            key="select_model_button")
        model_status_box = st.empty()

        # Load model to cache, if the user has selected a model for the first time
        if btn_select_model:
            st.session_state.load_model_in_progress = True
            current_config = f"{dataset}_{language_pair}_{epsilon}"
            st.session_state.model_state['loaded'] = False
            model_status_box.write("")
            with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
                model_path = load_selected_model(config, dataset, language_pair, epsilon)
                model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
                model_status_box.success('Model loaded!')
            st.session_state.model_state['current_config'] = current_config
            st.session_state.load_model_in_progress = False
            st.rerun()

    with right_placeholder.container():
        disable = (st.session_state.load_model_in_progress
                   or st.session_state.translate_in_progress)
        input_text_content, output_text_content = get_translation_result()
        input_text = st.text_area(
            "Enter Text",
            input_text_content,
            max_chars=MAX_INPUT_LEN,
            disabled=disable,
            key="input_text",
            on_change=set_input_text_content,
        )
        st.write(_model_status_message())
        btn_translate = st.button("Translate", disabled=disable,
                                  use_container_width=True, key="translate_button")
        result_container = st.empty()
        if output_text_content is not None and not st.session_state.translate_in_progress:
            with result_container.container():
                st.write("**Translation:**")
                # Create the bordered box inside this container. Calling
                # ``result_container.container(...)`` again would replace the
                # placeholder's content and drop the heading written above.
                output_container = st.container(border=True)
                output_container.write(postprocess(output_text_content))

    # Load model from cache when click translate button, if the user has selected a model previously
    if not st.session_state.select_model_button and st.session_state.translate_button:
        model_config = st.session_state.model_state['current_config']
        if model_config is None:
            # The user clicked translate without selecting a model: reset the
            # in-progress flag so the UI is not left disabled, then refresh.
            st.session_state.translate_in_progress = False
            st.rerun()
        # Split from the right so dataset names containing "_" stay intact.
        dataset, language_pair, epsilon = model_config.rsplit("_", 2)
        # Use the language pair of the loaded model, not the current selection.
        src_lang, tgt_lang = language_pair.split("-")
        model_path = load_selected_model(config, dataset, language_pair, epsilon)
        model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
        st.session_state.model_state['loaded'] = True

    if btn_translate:
        st.session_state.translate_in_progress = True
        with right:
            with st.spinner("Translating..."):
                prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
        st.session_state.translation_result['input'] = input_text
        st.session_state.translation_result['output'] = prediction
        st.session_state.translate_in_progress = False
        st.rerun()


if __name__ == '__main__':
    logger.info(DATASETS_MODEL_INFO_PATH)
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    main()