zyu commited on
Commit
499a604
·
1 Parent(s): ce4c2ae
Files changed (2) hide show
  1. app.py +9 -6
  2. requirements.txt +2 -1
app.py CHANGED
@@ -10,8 +10,15 @@ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
10
 
11
  @st.cache_resource
12
  def load_model(model_name, tokenizer_name):
13
- model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
14
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
 
 
 
 
 
 
 
15
  return model, tokenizer
16
 
17
 
@@ -66,10 +73,6 @@ def main():
66
 
67
  model_path = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
68
 
69
- if not os.path.exists(model_path):
70
- st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
71
- model_path = DEFAULT_MODEL
72
-
73
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
74
  st.success('Model loaded!')
75
  st_model_load.text("")
 
10
 
11
  @st.cache_resource
12
  def load_model(model_name, tokenizer_name):
13
+ try:
14
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
15
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
16
+ except Exception as e:
17
+ st.error(f"Error loading model: {e}")
18
+ st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
19
+ model_path = DEFAULT_MODEL
20
+ model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_path)
21
+ tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)
22
  return model, tokenizer
23
 
24
 
 
73
 
74
  model_path = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
75
 
 
 
 
 
76
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
77
  st.success('Model loaded!')
78
  st_model_load.text("")
requirements.txt CHANGED
@@ -8,4 +8,5 @@ torch==1.13.1
8
  transformers==4.26.0
9
  streamlit==1.38.0
10
  optax==0.1.4
11
- orbax==0.1.1
 
 
8
  transformers==4.26.0
9
  streamlit==1.38.0
10
  optax==0.1.4
11
+ orbax==0.1.1
12
+ sentencepiece==0.1.97