zyu commited on
Commit
3266b95
·
1 Parent(s): 3eacb4c

fix: bug fix

Browse files

improve the UX.
- disable all components while a process is in progress, such as loading a model or generating output.

Files changed (1) hide show
  1. app.py +79 -28
app.py CHANGED
@@ -100,12 +100,15 @@ def init_session_state():
100
  'current_config': None
101
  }
102
 
103
- if 'first_run' not in st.session_state:
104
- st.session_state.first_run = True
105
-
106
  if 'translate_in_progress' not in st.session_state:
107
  st.session_state.translate_in_progress = False
108
 
 
 
 
 
 
 
109
  if 'translate_button' in st.session_state and st.session_state.translate_button == True:
110
  st.session_state.translate_in_progress = True
111
 
@@ -115,6 +118,20 @@ def init_session_state():
115
  'output': None
116
  }
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  def main():
119
  hold_deterministic(SEED)
120
  config = load_json(DATASETS_MODEL_INFO_PATH)
@@ -123,19 +140,51 @@ def main():
123
 
124
  init_session_state()
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  with left:
127
- dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
 
128
  language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
129
- language_pair = st.selectbox("Language pair for translation", language_pairs_list)
130
  src_lang, tgt_lang = language_pair.split("-")
131
  epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
132
- epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
133
- btn_select_model = st.button("Select Model", disabled=st.session_state.translate_in_progress,
134
- use_container_width=True, key="select_model_button")
 
 
 
135
  model_status_box = st.empty()
136
 
137
  # Load model to cache, if the user has selected a model for the first time
138
  if btn_select_model:
 
139
  current_config = f"{dataset}_{language_pair}_{epsilon}"
140
 
141
  st.session_state.model_state['loaded'] = False
@@ -144,49 +193,51 @@ def main():
144
  model_path = load_selected_model(config, dataset, language_pair, epsilon)
145
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
146
  model_status_box.success('Model loaded!')
147
- st.session_state.model_state['loaded'] = True
148
- st.session_state.model_state['current_config'] = current_config
149
- st.session_state.first_run = False
150
 
151
- with right:
152
- if "translation_result" in st.session_state and st.session_state.translation_result['input'] is not None:
153
- input_text_content = st.session_state.translation_result['input']
154
- else:
155
- input_text_content = "Enter Text Here"
156
 
157
- if "translation_result" in st.session_state and st.session_state.translation_result['output'] is not None:
158
- output_text_content = st.session_state.translation_result['output']
159
- else:
160
- output_text_content = None
161
 
162
- input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN)
163
 
164
  msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
165
- if st.session_state.model_state['current_config'] is None \
166
- else f"Current Model: {st.session_state.model_state['current_config']}"
167
 
168
  st.write(msg_model)
169
 
170
  btn_translate = st.button("Translate",
171
- disabled=not st.session_state.model_state['loaded'],
172
  use_container_width=True,
173
  key="translate_button")
174
  result_container = st.empty()
175
 
176
  if output_text_content is not None and not st.session_state.translate_in_progress:
177
- result_container.write("**Translation:**")
178
- output_container = result_container.container(border=True)
179
- output_container.write("".join([postprocess(output_text_content)]))
 
180
 
181
  # Load model from cache when click translate button, if the user has selected a model previously
182
  if not st.session_state.select_model_button and st.session_state.translate_button:
183
  model_config = st.session_state.model_state['current_config']
 
 
 
 
 
 
 
184
  dataset, language_pair, epsilon = model_config.split("_")
185
  model_path = load_selected_model(config, dataset, language_pair, epsilon)
186
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
187
  st.session_state.model_state['loaded'] = True
188
 
189
- if btn_translate and st.session_state.model_state['loaded']:
190
  st.session_state.translate_in_progress = True
191
  with right:
192
  with st.spinner("Translating..."):
 
100
  'current_config': None
101
  }
102
 
 
 
 
103
  if 'translate_in_progress' not in st.session_state:
104
  st.session_state.translate_in_progress = False
105
 
106
+ if "load_model_in_progress" not in st.session_state:
107
+ st.session_state.load_model_in_progress = False
108
+
109
+ if "select_model_button" in st.session_state and st.session_state.select_model_button == True:
110
+ st.session_state.load_model_in_progress = True
111
+
112
  if 'translate_button' in st.session_state and st.session_state.translate_button == True:
113
  st.session_state.translate_in_progress = True
114
 
 
118
  'output': None
119
  }
120
 
121
+
122
+ def get_translation_result():
123
+ if "translation_result" in st.session_state and st.session_state.translation_result['input'] is not None:
124
+ input_text_content = st.session_state.translation_result['input']
125
+ else:
126
+ input_text_content = "Enter Text Here"
127
+
128
+ if "translation_result" in st.session_state and st.session_state.translation_result['output'] is not None:
129
+ output_text_content = st.session_state.translation_result['output']
130
+ else:
131
+ output_text_content = None
132
+ return input_text_content, output_text_content
133
+
134
+
135
  def main():
136
  hold_deterministic(SEED)
137
  config = load_json(DATASETS_MODEL_INFO_PATH)
 
140
 
141
  init_session_state()
142
 
143
+ st.write(st.session_state)
144
+
145
+ with right:
146
+ right_placeholder = st.empty()
147
+
148
+ if st.session_state.load_model_in_progress:
149
+
150
+ # Placeholder for right column, to display the input text area and translation result. If do not overwrite the
151
+ # right column from previous run, the translate button and input text area will be available for user to interace
152
+ # during the loading of model.
153
+ disable = True
154
+ with right_placeholder.container():
155
+ input_text_content, output_text_content = get_translation_result()
156
+ input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN, disabled=disable)
157
+
158
+ msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
159
+ if st.session_state.model_state['current_config'] is None \
160
+ else f"Current Model: {st.session_state.model_state['current_config']}"
161
+
162
+ st.write(msg_model)
163
+
164
+ btn_translate = st.button("Translate",
165
+ disabled=disable,
166
+ use_container_width=True,
167
+ key="translate_button")
168
+
169
+
170
  with left:
171
+ disable = st.session_state.translate_in_progress or st.session_state.load_model_in_progress
172
+ dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()), disabled=disable)
173
  language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
174
+ language_pair = st.selectbox("Language pair for translation", language_pairs_list, disabled=disable)
175
  src_lang, tgt_lang = language_pair.split("-")
176
  epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
177
+ epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True, disabled=disable)
178
+ btn_select_model = st.button(
179
+ "Select Model",
180
+ disabled=disable,
181
+ use_container_width=True,
182
+ key="select_model_button")
183
  model_status_box = st.empty()
184
 
185
  # Load model to cache, if the user has selected a model for the first time
186
  if btn_select_model:
187
+ st.session_state.load_model_in_progress = True
188
  current_config = f"{dataset}_{language_pair}_{epsilon}"
189
 
190
  st.session_state.model_state['loaded'] = False
 
193
  model_path = load_selected_model(config, dataset, language_pair, epsilon)
194
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
195
  model_status_box.success('Model loaded!')
 
 
 
196
 
197
+ st.session_state.model_state['current_config'] = current_config
198
+ st.session_state.load_model_in_progress = False
199
+ st.rerun()
 
 
200
 
201
+ with right_placeholder.container():
202
+ disable = st.session_state.load_model_in_progress or st.session_state.translate_in_progress
203
+ input_text_content, output_text_content = get_translation_result()
 
204
 
205
+ input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN, disabled=disable, key="input_text")
206
 
207
  msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
208
+ if st.session_state.model_state['current_config'] is None \
209
+ else f"Current Model: {st.session_state.model_state['current_config']}"
210
 
211
  st.write(msg_model)
212
 
213
  btn_translate = st.button("Translate",
214
+ disabled=(disable or st.session_state.translate_in_progress),
215
  use_container_width=True,
216
  key="translate_button")
217
  result_container = st.empty()
218
 
219
  if output_text_content is not None and not st.session_state.translate_in_progress:
220
+ with result_container.container():
221
+ st.write("**Translation:**")
222
+ output_container = result_container.container(border=True)
223
+ output_container.write("".join([postprocess(output_text_content)]))
224
 
225
  # Load model from cache when click translate button, if the user has selected a model previously
226
  if not st.session_state.select_model_button and st.session_state.translate_button:
227
  model_config = st.session_state.model_state['current_config']
228
+ if model_config is None:
229
+
230
+ # If the user click translate button without selecting a model, set st.session_state.translate_in_progress to False,
231
+ # to avoid death of program and then refresh the page
232
+ st.session_state.translate_in_progress = False
233
+ st.rerun()
234
+
235
  dataset, language_pair, epsilon = model_config.split("_")
236
  model_path = load_selected_model(config, dataset, language_pair, epsilon)
237
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
238
  st.session_state.model_state['loaded'] = True
239
 
240
+ if btn_translate:
241
  st.session_state.translate_in_progress = True
242
  with right:
243
  with st.spinner("Translating..."):