zyu commited on
Commit
ce4c2ae
·
1 Parent(s): 77d7c10
Files changed (1) hide show
  1. app.py +9 -15
app.py CHANGED
@@ -81,20 +81,14 @@ def main():
81
  prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
82
  st.success("".join([prediction]))
83
 
 
 
 
 
 
84
 
85
- st.title("Neural Machine Translation with DP-SGD")
 
 
 
86
 
87
- st.write(
88
- "This is a demo for private neural machine translation with DP-SGD. More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)")
89
-
90
- dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
91
-
92
- # DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
93
- # print(DATASETS_MODEL_INFO_PATH)
94
- # DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
95
- # DEFAULT_MODEL = 'google/mt5-small'
96
- #
97
- # MAX_SEQ_LEN = 512
98
- # NUM_BEAMS = 3
99
- # SEED = 2023
100
- # main()
 
81
  prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
82
  st.success("".join([prediction]))
83
 
84
+ if __name__ == '__main__':
85
+ DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
86
+ print(DATASETS_MODEL_INFO_PATH)
87
+ DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
88
+ DEFAULT_MODEL = 'google/mt5-small'
89
 
90
+ MAX_SEQ_LEN = 512
91
+ NUM_BEAMS = 3
92
+ SEED = 2023
93
+ main()
94