zyu committed
Commit 03fbdf1 · 1 Parent(s): c7c367c

bugs fixed:


- Added more instructions and slightly modified the layout.
- Disabled the Translate button while a model is loading, and the Select Model button while a translation is in progress.
- Fixed the input text disappearing when either the Translate or the Select Model button was clicked; the text is now kept in st.session_state (see the sketch below).
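
A minimal, self-contained sketch of the pattern behind the last two fixes (illustrative only: the file name demo_app.py, the widget labels, and the keys load_button / run_button are invented here and are not the names used in app.py). It shows how st.session_state can carry the entered text across the rerun that a button click triggers, and how one button can be greyed out while the other button's work is still running. It assumes a Streamlit version that provides st.rerun().

# demo_app.py -- run with: streamlit run demo_app.py
import time

import streamlit as st


def init_session_state():
    # Runs at the top of every rerun, before any widget is drawn.
    if "model_loaded" not in st.session_state:
        st.session_state.model_loaded = False
    if "saved_input" not in st.session_state:
        st.session_state.saved_input = "Enter Text Here"
    if "run_in_progress" not in st.session_state:
        st.session_state.run_in_progress = False
    # A click on the Run button during the previous run is still visible via
    # its key, so the work is flagged as in progress before the buttons render.
    if "run_button" in st.session_state and st.session_state.run_button:
        st.session_state.run_in_progress = True


def main():
    init_session_state()

    # Seed the text area from session state so the text survives reruns.
    input_text = st.text_area("Enter Text", st.session_state.saved_input)

    # Each button is disabled while the other action is pending: loading is
    # blocked during a run, and running is blocked until a model is loaded.
    load_clicked = st.button("Load model",
                             disabled=st.session_state.run_in_progress,
                             key="load_button")
    run_clicked = st.button("Run",
                            disabled=not st.session_state.model_loaded,
                            key="run_button")

    if load_clicked:
        with st.spinner("Loading..."):
            time.sleep(2)                              # stand-in for model loading
        st.session_state.model_loaded = True

    if run_clicked and st.session_state.model_loaded:
        with st.spinner("Working..."):
            time.sleep(2)                              # stand-in for translation
        st.session_state.saved_input = input_text      # keep the text for the next rerun
        st.session_state.run_in_progress = False
        st.rerun()                                     # redraw with both buttons enabled


if __name__ == "__main__":
    main()

The init_session_state(), translate_in_progress, and translation_result keys added in this commit implement the same idea inside app.py's existing layout.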

Files changed (1)
  1. app.py +89 -12
app.py CHANGED

@@ -7,6 +7,9 @@ import numpy as np
 import streamlit as st
 import torch
 from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
+import logging
+
+logger = logging.getLogger(__name__)


 @st.cache_resource(show_spinner=False)
@@ -14,12 +17,15 @@ def load_model(model_name, tokenizer_name):
     try:
         model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-    except Exception as e:
+    except OSError as e:
         st.error(f"Error loading model: {e}")
         st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
         model_path = DEFAULT_MODEL
         model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_path)
         tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        raise RuntimeError("Error loading model")
     return model, tokenizer


@@ -75,6 +81,7 @@ def display_ui():

 def load_selected_model(config, dataset, language_pair, epsilon):
     ckpt = config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
+    logger.info(f"Loading model from {ckpt}")
     if "privalingo" in ckpt:
         model_path = ckpt  # load model from huggingface hub
     else:
@@ -86,12 +93,36 @@ def load_selected_model(config, dataset, language_pair, epsilon):
     return model_path


+def init_session_state():
+    if 'model_state' not in st.session_state:
+        st.session_state.model_state = {
+            'loaded': False,
+            'current_config': None
+        }
+
+    if 'first_run' not in st.session_state:
+        st.session_state.first_run = True
+
+    if 'translate_in_progress' not in st.session_state:
+        st.session_state.translate_in_progress = False
+
+    if 'translate_button' in st.session_state and st.session_state.translate_button == True:
+        st.session_state.translate_in_progress = True
+
+    if 'translation_result' not in st.session_state:
+        st.session_state.translation_result = {
+            'input': None,
+            'output': None
+        }
+
 def main():
     hold_deterministic(SEED)
     config = load_json(DATASETS_MODEL_INFO_PATH)

     left, right = display_ui()

+    init_session_state()
+
     with left:
         dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
         language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
@@ -99,28 +130,74 @@ def main():
         src_lang, tgt_lang = language_pair.split("-")
         epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
         epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
+        btn_select_model = st.button("Select Model", disabled=st.session_state.translate_in_progress,
+                                     use_container_width=True, key="select_model_button")
         model_status_box = st.empty()

-    with right:
-        input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=MAX_INPUT_LEN)
-        btn_translate = st.button("Translate")
-        result_container = st.empty()
-
-    model_path = load_selected_model(config, dataset, language_pair, epsilon)
+    # Load model to cache, if the user has selected a model for the first time
+    if btn_select_model:
+        current_config = f"{dataset}_{language_pair}_{epsilon}"

-    with left:
+        st.session_state.model_state['loaded'] = False
         model_status_box.write("")
         with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
+            model_path = load_selected_model(config, dataset, language_pair, epsilon)
             model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
             model_status_box.success('Model loaded!')
+        st.session_state.model_state['loaded'] = True
+        st.session_state.model_state['current_config'] = current_config
+        st.session_state.first_run = False

-    if btn_translate:
+    with right:
+        if "translation_result" in st.session_state and st.session_state.translation_result['input'] is not None:
+            input_text_content = st.session_state.translation_result['input']
+        else:
+            input_text_content = "Enter Text Here"
+
+        if "translation_result" in st.session_state and st.session_state.translation_result['output'] is not None:
+            output_text_content = st.session_state.translation_result['output']
+        else:
+            output_text_content = None
+
+        input_text = st.text_area("Enter Text", input_text_content, max_chars=MAX_INPUT_LEN)
+
+        msg_model = "Please confirm model selection via the \'Select Model\' Button first!" \
+            if st.session_state.model_state['current_config'] is None \
+            else f"Current Model:{st.session_state.model_state['current_config']}"
+
+        st.write(msg_model)
+
+        btn_translate = st.button("Translate",
+                                  disabled=not st.session_state.model_state['loaded'],
+                                  use_container_width=True,
+                                  key="translate_button")
+        result_container = st.empty()
+
+        if output_text_content is not None and not st.session_state.translate_in_progress:
+            result_container.write("**Translation:**")
+            output_container = result_container.container(border=True)
+            output_container.write("".join([postprocess(output_text_content)]))
+
+    # Load model from cache when click translate button, if the user has selected a model previously
+    if not st.session_state.select_model_button and st.session_state.translate_button:
+        model_config = st.session_state.model_state['current_config']
+        dataset, language_pair, epsilon = model_config.split("_")
+        model_path = load_selected_model(config, dataset, language_pair, epsilon)
+        model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
+        st.session_state.model_state['loaded'] = True
+
+    if btn_translate and st.session_state.model_state['loaded']:
+        st.session_state.translate_in_progress = True
         with right:
             with st.spinner("Translating..."):
+                input_text = st.session_state.model_state['current_config'] + input_text
                 prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
-            result_container.write("**Translation:**")
-            output_container = result_container.container(border=True)
-            output_container.write("".join([postprocess(prediction)]))
+
+            st.session_state.translation_result['input'] = input_text
+            st.session_state.translation_result['output'] = prediction
+
+            st.session_state.translate_in_progress = False
+            st.rerun()


 if __name__ == '__main__':
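
One design note on the new split between "Select Model" and "Translate": the Translate-only branch calls load_selected_model and load_model again, which is cheap because load_model is decorated with @st.cache_resource, so a repeat call with the same arguments is served from Streamlit's resource cache rather than reloading the checkpoint. A small illustrative sketch of that caching behaviour (hypothetical names, not part of app.py):

# cache_demo.py -- run with: streamlit run cache_demo.py
import time

import streamlit as st


@st.cache_resource(show_spinner=False)
def load_model(model_name):
    # Stand-in for FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name);
    # the sleep mimics an expensive checkpoint load.
    time.sleep(3)
    return {"name": model_name}


start = time.time()
load_model("my-checkpoint")            # first call: executes the body (~3 s)
st.write(f"first call took {time.time() - start:.1f}s")

start = time.time()
load_model("my-checkpoint")            # same arguments: returned from the cache
st.write(f"second call took {time.time() - start:.1f}s")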