zyu commited on
Commit
93d0613
·
1 Parent(s): 45d73d8

update model paths. improve readability of code.

Browse files
Files changed (2) hide show
  1. app.py +30 -20
  2. dataset_and_model_info.json +4 -4
app.py CHANGED
@@ -57,11 +57,8 @@ def postprocess(output_text):
57
  return output
58
 
59
 
60
- def main():
61
- hold_deterministic(SEED)
62
-
63
  st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
64
-
65
  st.title("Neural Machine Translation with DP-SGD")
66
 
67
  st.write(
@@ -73,6 +70,27 @@ def main():
73
  st.write("This is a demo for private neural machine translation with DP-SGD.")
74
 
75
  left, right = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  with left:
77
  dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
78
  language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
@@ -80,36 +98,28 @@ def main():
80
  src_lang, tgt_lang = language_pair.split("-")
81
  epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
82
  epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
 
83
 
84
  with right:
85
  input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=MAX_INPUT_LEN)
86
  btn_translate = st.button("Translate")
 
87
 
88
- ckpt = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
89
-
90
- if "privalingo" in ckpt:
91
- # that means the model is loaded from huggingface hub rather checkpoints locally
92
- model_path = ckpt
93
- else:
94
- model_name = DEFAULT_MODEL.split('/')[-1]
95
- model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
96
- if not os.path.exists(model_path):
97
- with left:
98
- st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
99
- model_path = DEFAULT_MODEL
100
 
101
  with left:
 
102
  with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
103
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
104
- sc_box_model_loaded = st.success('Model loaded!')
105
 
106
  if btn_translate:
107
  with right:
108
  with st.spinner("Translating..."):
109
  prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
110
- st.write("**Translation:**")
111
- result_container = st.container(border=True)
112
- result_container.write("".join([postprocess(prediction)]))
113
 
114
 
115
  if __name__ == '__main__':
 
57
  return output
58
 
59
 
60
+ def display_ui():
 
 
61
  st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
 
62
  st.title("Neural Machine Translation with DP-SGD")
63
 
64
  st.write(
 
70
  st.write("This is a demo for private neural machine translation with DP-SGD.")
71
 
72
  left, right = st.columns(2)
73
+ return left, right
74
+
75
+
76
+ def load_selected_model(config, dataset, language_pair, epsilon):
77
+ ckpt = config[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
78
+ if "privalingo" in ckpt:
79
+ model_path = ckpt # load model from huggingface hub
80
+ else:
81
+ model_name = DEFAULT_MODEL.split('/')[-1]
82
+ model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
83
+ if not os.path.exists(model_path):
84
+ st.error(f"Model not found. Using default model: {DEFAULT_MODEL}")
85
+ model_path = DEFAULT_MODEL
86
+ return model_path
87
+
88
+
89
+ def main():
90
+ hold_deterministic(SEED)
91
+
92
+ left, right = display_ui()
93
+
94
  with left:
95
  dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
96
  language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
 
98
  src_lang, tgt_lang = language_pair.split("-")
99
  epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
100
  epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
101
+ model_status_box = st.empty()
102
 
103
  with right:
104
  input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=MAX_INPUT_LEN)
105
  btn_translate = st.button("Translate")
106
+ result_container = st.empty()
107
 
108
+ model_path = load_selected_model(config, dataset, language_pair, epsilon)
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  with left:
111
+ model_status_box.write("")
112
  with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
113
  model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
114
+ model_status_box.success('Model loaded!')
115
 
116
  if btn_translate:
117
  with right:
118
  with st.spinner("Translating..."):
119
  prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
120
+ result_container.write("**Translation:**")
121
+ output_container = result_container.container(border=True)
122
+ output_container.write("".join([postprocess(prediction)]))
123
 
124
 
125
  if __name__ == '__main__':
dataset_and_model_info.json CHANGED
@@ -1,11 +1,11 @@
1
  {
2
- "WMT": {
3
  "languages pairs": {
4
  "German-English": {
5
  "epsilon": {
6
  "1": "TrustHLT/privalingo-WMT-de-en-eps1",
7
  "5": "TrustHLT/privalingo-WMT-de-en-eps5",
8
- "infinite": "TrustHLT/privalingo-WMT-de-en-Infinity"
9
  }
10
  }
11
  }
@@ -18,7 +18,7 @@
18
  "2": "TrustHLT/privalingo-BSD-jp-en-eps2",
19
  "5": "TrustHLT/privalingo-BSD-jp-en-eps5",
20
  "10": "TrustHLT/privalingo-BSD-jp-en-eps10",
21
- "infinite": "TrustHLT/privalingo-BSD-jp-en-Infinity"
22
  }
23
  }
24
  }
@@ -29,7 +29,7 @@
29
  "epsilon": {
30
  "1": "TrustHLT/privalingo-ClinSpEn-es-en-eps1",
31
  "5": "TrustHLT/privalingo-ClinSpEn-es-en-eps5",
32
- "infinite": "TrustHLT/privalingo-ClinSpEn-es-en-Infinity"
33
  }
34
  }
35
  }
 
1
  {
2
+ "WMT-16": {
3
  "languages pairs": {
4
  "German-English": {
5
  "epsilon": {
6
  "1": "TrustHLT/privalingo-WMT-de-en-eps1",
7
  "5": "TrustHLT/privalingo-WMT-de-en-eps5",
8
+ "Infinity": "TrustHLT/privalingo-WMT-de-en-Infinity"
9
  }
10
  }
11
  }
 
18
  "2": "TrustHLT/privalingo-BSD-jp-en-eps2",
19
  "5": "TrustHLT/privalingo-BSD-jp-en-eps5",
20
  "10": "TrustHLT/privalingo-BSD-jp-en-eps10",
21
+ "Infinity": "TrustHLT/privalingo-BSD-jp-en-Infinity"
22
  }
23
  }
24
  }
 
29
  "epsilon": {
30
  "1": "TrustHLT/privalingo-ClinSpEn-es-en-eps1",
31
  "5": "TrustHLT/privalingo-ClinSpEn-es-en-eps5",
32
+ "Infinity": "TrustHLT/privalingo-ClinSpEn-es-en-Infinity"
33
  }
34
  }
35
  }