Spaces:
Sleeping
Sleeping
zyu
commited on
Commit
·
3266b95
1
Parent(s):
3eacb4c
fix: bug fix
Browse filesimprove the UX.
- disable all components while a process is in progress, such as loading a model or generating output.
app.py
CHANGED
@@ -100,12 +100,15 @@ def init_session_state():
|
|
100 |
'current_config': None
|
101 |
}
|
102 |
|
103 |
-
if 'first_run' not in st.session_state:
|
104 |
-
st.session_state.first_run = True
|
105 |
-
|
106 |
if 'translate_in_progress' not in st.session_state:
|
107 |
st.session_state.translate_in_progress = False
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
if 'translate_button' in st.session_state and st.session_state.translate_button == True:
|
110 |
st.session_state.translate_in_progress = True
|
111 |
|
@@ -115,6 +118,20 @@ def init_session_state():
|
|
115 |
'output': None
|
116 |
}
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
def main():
|
119 |
hold_deterministic(SEED)
|
120 |
config = load_json(DATASETS_MODEL_INFO_PATH)
|
@@ -123,19 +140,51 @@ def main():
|
|
123 |
|
124 |
init_session_state()
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
with left:
|
127 |
-
|
|
|
128 |
language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
|
129 |
-
language_pair = st.selectbox("Language pair for translation", language_pairs_list)
|
130 |
src_lang, tgt_lang = language_pair.split("-")
|
131 |
epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
|
132 |
-
epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
|
133 |
-
btn_select_model = st.button(
|
134 |
-
|
|
|
|
|
|
|
135 |
model_status_box = st.empty()
|
136 |
|
137 |
# Load model to cache, if the user has selected a model for the first time
|
138 |
if btn_select_model:
|
|
|
139 |
current_config = f"{dataset}_{language_pair}_{epsilon}"
|
140 |
|
141 |
st.session_state.model_state['loaded'] = False
|
@@ -144,49 +193,51 @@ def main():
|
|
144 |
model_path = load_selected_model(config, dataset, language_pair, epsilon)
|
145 |
model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
|
146 |
model_status_box.success('Model loaded!')
|
147 |
-
st.session_state.model_state['loaded'] = True
|
148 |
-
st.session_state.model_state['current_config'] = current_config
|
149 |
-
st.session_state.first_run = False
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
else:
|
155 |
-
input_text_content = "Enter Text Here"
|
156 |
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
output_text_content = None
|
161 |
|
162 |
-
input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN)
|
163 |
|
164 |
msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
|
165 |
-
|
166 |
-
|
167 |
|
168 |
st.write(msg_model)
|
169 |
|
170 |
btn_translate = st.button("Translate",
|
171 |
-
disabled=
|
172 |
use_container_width=True,
|
173 |
key="translate_button")
|
174 |
result_container = st.empty()
|
175 |
|
176 |
if output_text_content is not None and not st.session_state.translate_in_progress:
|
177 |
-
result_container.
|
178 |
-
|
179 |
-
|
|
|
180 |
|
181 |
# Load model from cache when click translate button, if the user has selected a model previously
|
182 |
if not st.session_state.select_model_button and st.session_state.translate_button:
|
183 |
model_config = st.session_state.model_state['current_config']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
dataset, language_pair, epsilon = model_config.split("_")
|
185 |
model_path = load_selected_model(config, dataset, language_pair, epsilon)
|
186 |
model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
|
187 |
st.session_state.model_state['loaded'] = True
|
188 |
|
189 |
-
if btn_translate
|
190 |
st.session_state.translate_in_progress = True
|
191 |
with right:
|
192 |
with st.spinner("Translating..."):
|
|
|
100 |
'current_config': None
|
101 |
}
|
102 |
|
|
|
|
|
|
|
103 |
if 'translate_in_progress' not in st.session_state:
|
104 |
st.session_state.translate_in_progress = False
|
105 |
|
106 |
+
if "load_model_in_progress" not in st.session_state:
|
107 |
+
st.session_state.load_model_in_progress = False
|
108 |
+
|
109 |
+
if "select_model_button" in st.session_state and st.session_state.select_model_button == True:
|
110 |
+
st.session_state.load_model_in_progress = True
|
111 |
+
|
112 |
if 'translate_button' in st.session_state and st.session_state.translate_button == True:
|
113 |
st.session_state.translate_in_progress = True
|
114 |
|
|
|
118 |
'output': None
|
119 |
}
|
120 |
|
121 |
+
|
122 |
+
def get_translation_result():
|
123 |
+
if "translation_result" in st.session_state and st.session_state.translation_result['input'] is not None:
|
124 |
+
input_text_content = st.session_state.translation_result['input']
|
125 |
+
else:
|
126 |
+
input_text_content = "Enter Text Here"
|
127 |
+
|
128 |
+
if "translation_result" in st.session_state and st.session_state.translation_result['output'] is not None:
|
129 |
+
output_text_content = st.session_state.translation_result['output']
|
130 |
+
else:
|
131 |
+
output_text_content = None
|
132 |
+
return input_text_content, output_text_content
|
133 |
+
|
134 |
+
|
135 |
def main():
|
136 |
hold_deterministic(SEED)
|
137 |
config = load_json(DATASETS_MODEL_INFO_PATH)
|
|
|
140 |
|
141 |
init_session_state()
|
142 |
|
143 |
+
st.write(st.session_state)
|
144 |
+
|
145 |
+
with right:
|
146 |
+
right_placeholder = st.empty()
|
147 |
+
|
148 |
+
if st.session_state.load_model_in_progress:
|
149 |
+
|
150 |
+
# Placeholder for right column, to display the input text area and translation result. If do not overwrite the
|
151 |
+
# right column from previous run, the translate button and input text area will be available for user to interace
|
152 |
+
# during the loading of model.
|
153 |
+
disable = True
|
154 |
+
with right_placeholder.container():
|
155 |
+
input_text_content, output_text_content = get_translation_result()
|
156 |
+
input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN, disabled=disable)
|
157 |
+
|
158 |
+
msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
|
159 |
+
if st.session_state.model_state['current_config'] is None \
|
160 |
+
else f"Current Model: {st.session_state.model_state['current_config']}"
|
161 |
+
|
162 |
+
st.write(msg_model)
|
163 |
+
|
164 |
+
btn_translate = st.button("Translate",
|
165 |
+
disabled=disable,
|
166 |
+
use_container_width=True,
|
167 |
+
key="translate_button")
|
168 |
+
|
169 |
+
|
170 |
with left:
|
171 |
+
disable = st.session_state.translate_in_progress or st.session_state.load_model_in_progress
|
172 |
+
dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()), disabled=disable)
|
173 |
language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
|
174 |
+
language_pair = st.selectbox("Language pair for translation", language_pairs_list, disabled=disable)
|
175 |
src_lang, tgt_lang = language_pair.split("-")
|
176 |
epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
|
177 |
+
epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True, disabled=disable)
|
178 |
+
btn_select_model = st.button(
|
179 |
+
"Select Model",
|
180 |
+
disabled=disable,
|
181 |
+
use_container_width=True,
|
182 |
+
key="select_model_button")
|
183 |
model_status_box = st.empty()
|
184 |
|
185 |
# Load model to cache, if the user has selected a model for the first time
|
186 |
if btn_select_model:
|
187 |
+
st.session_state.load_model_in_progress = True
|
188 |
current_config = f"{dataset}_{language_pair}_{epsilon}"
|
189 |
|
190 |
st.session_state.model_state['loaded'] = False
|
|
|
193 |
model_path = load_selected_model(config, dataset, language_pair, epsilon)
|
194 |
model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
|
195 |
model_status_box.success('Model loaded!')
|
|
|
|
|
|
|
196 |
|
197 |
+
st.session_state.model_state['current_config'] = current_config
|
198 |
+
st.session_state.load_model_in_progress = False
|
199 |
+
st.rerun()
|
|
|
|
|
200 |
|
201 |
+
with right_placeholder.container():
|
202 |
+
disable = st.session_state.load_model_in_progress or st.session_state.translate_in_progress
|
203 |
+
input_text_content, output_text_content = get_translation_result()
|
|
|
204 |
|
205 |
+
input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN, disabled=disable, key="input_text")
|
206 |
|
207 |
msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
|
208 |
+
if st.session_state.model_state['current_config'] is None \
|
209 |
+
else f"Current Model: {st.session_state.model_state['current_config']}"
|
210 |
|
211 |
st.write(msg_model)
|
212 |
|
213 |
btn_translate = st.button("Translate",
|
214 |
+
disabled=(disable or st.session_state.translate_in_progress),
|
215 |
use_container_width=True,
|
216 |
key="translate_button")
|
217 |
result_container = st.empty()
|
218 |
|
219 |
if output_text_content is not None and not st.session_state.translate_in_progress:
|
220 |
+
with result_container.container():
|
221 |
+
st.write("**Translation:**")
|
222 |
+
output_container = result_container.container(border=True)
|
223 |
+
output_container.write("".join([postprocess(output_text_content)]))
|
224 |
|
225 |
# Load model from cache when click translate button, if the user has selected a model previously
|
226 |
if not st.session_state.select_model_button and st.session_state.translate_button:
|
227 |
model_config = st.session_state.model_state['current_config']
|
228 |
+
if model_config is None:
|
229 |
+
|
230 |
+
# If the user click translate button without selecting a model, set st.session_state.translate_in_progress to False,
|
231 |
+
# to avoid death of program and then refresh the page
|
232 |
+
st.session_state.translate_in_progress = False
|
233 |
+
st.rerun()
|
234 |
+
|
235 |
dataset, language_pair, epsilon = model_config.split("_")
|
236 |
model_path = load_selected_model(config, dataset, language_pair, epsilon)
|
237 |
model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
|
238 |
st.session_state.model_state['loaded'] = True
|
239 |
|
240 |
+
if btn_translate:
|
241 |
st.session_state.translate_in_progress = True
|
242 |
with right:
|
243 |
with st.spinner("Translating..."):
|