Spaces:
Sleeping
Sleeping
zyu
committed on
Commit
·
03fbdf1
1
Parent(s):
c7c367c
bugs fixed:
Browse files
- added more instructions and slightly modified the layout
- disable translate while loading model and vice versa.
- input text would disappear when clicking either the Translate button or the Select Model button.
app.py
CHANGED
@@ -7,6 +7,9 @@ import numpy as np
|
|
7 |
import streamlit as st
|
8 |
import torch
|
9 |
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
|
|
|
10 |
|
11 |
|
12 |
@st.cache_resource(show_spinner=False)
|
@@ -14,12 +17,15 @@ def load_model(model_name, tokenizer_name):
|
|
14 |
try:
|
15 |
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
|
16 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
17 |
-
except
|
18 |
st.error(f"Error loading model: {e}")
|
19 |
st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
|
20 |
model_path = DEFAULT_MODEL
|
21 |
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_path)
|
22 |
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
|
|
|
|
|
|
|
23 |
return model, tokenizer
|
24 |
|
25 |
|
@@ -75,6 +81,7 @@ def display_ui():
|
|
75 |
|
76 |
def load_selected_model(config, dataset, language_pair, epsilon):
|
77 |
ckpt = config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
|
|
|
78 |
if "privalingo" in ckpt:
|
79 |
model_path = ckpt # load model from huggingface hub
|
80 |
else:
|
@@ -86,12 +93,36 @@ def load_selected_model(config, dataset, language_pair, epsilon):
|
|
86 |
return model_path
|
87 |
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
def main():
|
90 |
hold_deterministic(SEED)
|
91 |
config = load_json(DATASETS_MODEL_INFO_PATH)
|
92 |
|
93 |
left, right = display_ui()
|
94 |
|
|
|
|
|
95 |
with left:
|
96 |
dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
|
97 |
language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
|
@@ -99,28 +130,74 @@ def main():
|
|
99 |
src_lang, tgt_lang = language_pair.split("-")
|
100 |
epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
|
101 |
epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
|
|
|
|
|
102 |
model_status_box = st.empty()
|
103 |
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
result_container = st.empty()
|
108 |
-
|
109 |
-
model_path = load_selected_model(config, dataset, language_pair, epsilon)
|
110 |
|
111 |
-
|
112 |
model_status_box.write("")
|
113 |
with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
|
|
|
114 |
model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
|
115 |
model_status_box.success('Model loaded!')
|
|
|
|
|
|
|
116 |
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
with right:
|
119 |
with st.spinner("Translating..."):
|
|
|
120 |
prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
124 |
|
125 |
|
126 |
if __name__ == '__main__':
|
|
|
7 |
import streamlit as st
|
8 |
import torch
|
9 |
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
|
10 |
+
import logging
|
11 |
+
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
|
14 |
|
15 |
@st.cache_resource(show_spinner=False)
|
|
|
17 |
try:
|
18 |
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
|
19 |
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
20 |
+
except OSError as e:
|
21 |
st.error(f"Error loading model: {e}")
|
22 |
st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
|
23 |
model_path = DEFAULT_MODEL
|
24 |
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_path)
|
25 |
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
|
26 |
+
except Exception as e:
|
27 |
+
st.error(f"Error loading model: {e}")
|
28 |
+
raise RuntimeError("Error loading model")
|
29 |
return model, tokenizer
|
30 |
|
31 |
|
|
|
81 |
|
82 |
def load_selected_model(config, dataset, language_pair, epsilon):
|
83 |
ckpt = config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
|
84 |
+
logger.info(f"Loading model from {ckpt}")
|
85 |
if "privalingo" in ckpt:
|
86 |
model_path = ckpt # load model from huggingface hub
|
87 |
else:
|
|
|
93 |
return model_path
|
94 |
|
95 |
|
96 |
+
def init_session_state():
|
97 |
+
if 'model_state' not in st.session_state:
|
98 |
+
st.session_state.model_state = {
|
99 |
+
'loaded': False,
|
100 |
+
'current_config': None
|
101 |
+
}
|
102 |
+
|
103 |
+
if 'first_run' not in st.session_state:
|
104 |
+
st.session_state.first_run = True
|
105 |
+
|
106 |
+
if 'translate_in_progress' not in st.session_state:
|
107 |
+
st.session_state.translate_in_progress = False
|
108 |
+
|
109 |
+
if 'translate_button' in st.session_state and st.session_state.translate_button == True:
|
110 |
+
st.session_state.translate_in_progress = True
|
111 |
+
|
112 |
+
if 'translation_result' not in st.session_state:
|
113 |
+
st.session_state.translation_result = {
|
114 |
+
'input': None,
|
115 |
+
'output': None
|
116 |
+
}
|
117 |
+
|
118 |
def main():
|
119 |
hold_deterministic(SEED)
|
120 |
config = load_json(DATASETS_MODEL_INFO_PATH)
|
121 |
|
122 |
left, right = display_ui()
|
123 |
|
124 |
+
init_session_state()
|
125 |
+
|
126 |
with left:
|
127 |
dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
|
128 |
language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
|
|
|
130 |
src_lang, tgt_lang = language_pair.split("-")
|
131 |
epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
|
132 |
epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
|
133 |
+
btn_select_model = st.button("Select Model", disabled=st.session_state.translate_in_progress,
|
134 |
+
use_container_width=True, key="select_model_button")
|
135 |
model_status_box = st.empty()
|
136 |
|
137 |
+
# Load model to cache, if the user has selected a model for the first time
|
138 |
+
if btn_select_model:
|
139 |
+
current_config = f"{dataset}_{language_pair}_{epsilon}"
|
|
|
|
|
|
|
140 |
|
141 |
+
st.session_state.model_state['loaded'] = False
|
142 |
model_status_box.write("")
|
143 |
with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
|
144 |
+
model_path = load_selected_model(config, dataset, language_pair, epsilon)
|
145 |
model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
|
146 |
model_status_box.success('Model loaded!')
|
147 |
+
st.session_state.model_state['loaded'] = True
|
148 |
+
st.session_state.model_state['current_config'] = current_config
|
149 |
+
st.session_state.first_run = False
|
150 |
|
151 |
+
with right:
|
152 |
+
if "translation_result" in st.session_state and st.session_state.translation_result['input'] is not None:
|
153 |
+
input_text_content = st.session_state.translation_result['input']
|
154 |
+
else:
|
155 |
+
input_text_content = "Enter Text Here"
|
156 |
+
|
157 |
+
if "translation_result" in st.session_state and st.session_state.translation_result['output'] is not None:
|
158 |
+
output_text_content = st.session_state.translation_result['output']
|
159 |
+
else:
|
160 |
+
output_text_content = None
|
161 |
+
|
162 |
+
input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN)
|
163 |
+
|
164 |
+
msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
|
165 |
+
if st.session_state.model_state['current_config'] is None \
|
166 |
+
else f"Current Model:{st.session_state.model_state['current_config']}"
|
167 |
+
|
168 |
+
st.write(msg_model)
|
169 |
+
|
170 |
+
btn_translate = st.button("Translate",
|
171 |
+
disabled=not st.session_state.model_state['loaded'],
|
172 |
+
use_container_width=True,
|
173 |
+
key="translate_button")
|
174 |
+
result_container = st.empty()
|
175 |
+
|
176 |
+
if output_text_content is not None and not st.session_state.translate_in_progress:
|
177 |
+
result_container.write("**Translation:**")
|
178 |
+
output_container = result_container.container(border=True)
|
179 |
+
output_container.write("".join([postprocess(output_text_content)]))
|
180 |
+
|
181 |
+
# Load model from cache when click translate button, if the user has selected a model previously
|
182 |
+
if not st.session_state.select_model_button and st.session_state.translate_button:
|
183 |
+
model_config = st.session_state.model_state['current_config']
|
184 |
+
dataset, language_pair, epsilon = model_config.split("_")
|
185 |
+
model_path = load_selected_model(config, dataset, language_pair, epsilon)
|
186 |
+
model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
|
187 |
+
st.session_state.model_state['loaded'] = True
|
188 |
+
|
189 |
+
if btn_translate and st.session_state.model_state['loaded']:
|
190 |
+
st.session_state.translate_in_progress = True
|
191 |
with right:
|
192 |
with st.spinner("Translating..."):
|
193 |
+
input_text = st.session_state.model_state['current_config'] + input_text
|
194 |
prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
|
195 |
+
|
196 |
+
st.session_state.translation_result['input'] = input_text
|
197 |
+
st.session_state.translation_result['output'] = prediction
|
198 |
+
|
199 |
+
st.session_state.translate_in_progress = False
|
200 |
+
st.rerun()
|
201 |
|
202 |
|
203 |
if __name__ == '__main__':
|