zyu committed on
Commit
cb7cdd4
·
1 Parent(s): 499a604

update app layout

Browse files
Files changed (2) hide show
  1. app.py +54 -26
  2. dataset_and_model_info.json +3 -3
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import json
2
  import os
3
  import random
 
4
 
5
  import numpy as np
6
  import streamlit as st
@@ -51,38 +52,65 @@ def hold_deterministic(seed):
51
  random.seed(seed)
52
 
53
 
54
- def main():
55
- hold_deterministic(SEED)
56
-
57
- st.title("Neural Machine Translation with DP-SGD")
58
-
59
- st.write("This is a demo for private neural machine translation with DP-SGD. More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)")
60
-
61
- dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
62
-
63
- language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
64
 
65
- language_pair = st.selectbox("Language pair for translation", language_pairs_list)
66
 
67
- src_lang, tgt_lang = language_pair.split("-")
68
-
69
- epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
70
- epsilon = st.radio("Select a privacy budget epsilon", epsilon_options)
71
 
72
- st_model_load = st.text(f'Loading model trained on {dataset} with epsilon {epsilon}...')
73
 
74
- model_path = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
75
 
76
- model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
77
- st.success('Model loaded!')
78
- st_model_load.text("")
 
 
79
 
80
- input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
- if st.button("Translate"):
83
- st.write("Translation")
84
- prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
85
- st.success("".join([prediction]))
86
 
87
  if __name__ == '__main__':
88
  DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
@@ -93,5 +121,5 @@ if __name__ == '__main__':
93
  MAX_SEQ_LEN = 512
94
  NUM_BEAMS = 3
95
  SEED = 2023
 
96
  main()
97
-
 
1
  import json
2
  import os
3
  import random
4
+ import re
5
 
6
  import numpy as np
7
  import streamlit as st
 
52
  random.seed(seed)
53
 
54
 
55
+ def postprocess(output_text):
56
+ output = re.sub(r"<extra_id[^>]*>", "", output_text)
57
+ return output
 
 
 
 
 
 
 
58
 
 
59
 
60
+ def main():
61
+ hold_deterministic(SEED)
 
 
62
 
63
+ st.set_page_config(page_title="DP-NMT DEMO", layout="wide")
64
 
65
+ st.title("Neural Machine Translation with DP-SGD")
66
 
67
+ st.write(
68
+ "[![Star](https://img.shields.io/github/stars/trusthlt/dp-nmt.svg?logo=github&style=social)](https://github.com/trusthlt/dp-nmt)"
69
+ "&nbsp;&nbsp;&nbsp;"
70
+ "[![ACL](https://img.shields.io/badge/ACL-Link-red.svg?logo=data:image/svg%2bxml;base64,PHN2ZyBoZWlnaHQ9IjI2MC4wOTA0ODIiIHZpZXdCb3g9IjAgMCA2OCA0NiIgd2lkdGg9IjM4NC40ODE1ODIiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyI+PHBhdGggZD0ibTQxLjk3NzU1MyAwdjMuMDE1OGgtMzQuNDkwNjQ3Ni03LjQ4NjkwNTR2Ny40ODQ5OSAyNy45Nzc4OCA3LjUyMTMzaDcuNDg2OTA1NCA0Mi4wMTM4OTY2IDcuNDg2OTA2IDExLjAxMjI5MnYtMTUuMDA2MzJoLTExLjAxMjI5MnYtMjAuNDkyODktNy40ODQ5OWMwLTEuNTczNjkgMC0xLjI1NDAyIDAtMy4wMTU4em0tMjYuOTY3Mzk4IDE3Ljk4NTc4aDI2Ljk2NzM5OHYxMy4wMDc5aC0yNi45NjczOTh6IiBmaWxsPSIjZWQxYzI0IiBmaWxsLXJ1bGU9ImV2ZW5vZGQiLz48L3N2Zz4=&link=https%3A%2F%2Faclanthology.org%2F2024.eacl-demo.11%2F)](https://aclanthology.org/2024.eacl-demo.11/)"
71
+ )
72
 
73
+ st.write("This is a demo for private neural machine translation with DP-SGD.")
74
+
75
+ left, right = st.columns(2)
76
+ with left:
77
+ dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
78
+ language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
79
+ language_pair = st.selectbox("Language pair for translation", language_pairs_list)
80
+ src_lang, tgt_lang = language_pair.split("-")
81
+ epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
82
+ epsilon = st.radio("Select a privacy budget epsilon", epsilon_options, horizontal=True)
83
+
84
+ with right:
85
+ input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=MAX_INPUT_LEN)
86
+ btn_translate = st.button("Translate")
87
+
88
+ ckpt = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
89
+
90
+ if "privalingo" in ckpt:
91
+ # that means the model is loaded from the Hugging Face Hub rather than from local checkpoints
92
+ model_path = ckpt
93
+ else:
94
+ model_name = DEFAULT_MODEL.split('/')[-1]
95
+ model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
96
+ if not os.path.exists(model_path):
97
+ with left:
98
+ st.error(f"Model not found. Use {DEFAULT_MODEL} instead")
99
+ model_path = DEFAULT_MODEL
100
+
101
+ with left:
102
+ with st.spinner(f'Loading model trained on {dataset} with epsilon {epsilon}...'):
103
+ model, tokenizer = load_model(model_path, tokenizer_name=DEFAULT_MODEL)
104
+ sc_box_model_loaded = st.success('Model loaded!')
105
+
106
+ if btn_translate:
107
+ with right:
108
+ with st.spinner("Translating..."):
109
+ prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
110
+ st.write("**Translation:**")
111
+ result_container = st.container(border=True)
112
+ result_container.write("".join([postprocess(prediction)]))
113
 
 
 
 
 
114
 
115
  if __name__ == '__main__':
116
  DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "dataset_and_model_info.json")
 
121
  MAX_SEQ_LEN = 512
122
  NUM_BEAMS = 3
123
  SEED = 2023
124
+ MAX_INPUT_LEN = 500
125
  main()
 
dataset_and_model_info.json CHANGED
@@ -5,7 +5,7 @@
5
  "epsilon": {
6
  "1": "MarcoYuTono/privalingo-WMT-de-en-eps1",
7
  "5": "MarcoYuTono/privalingo-WMT-de-en-eps5",
8
- "non": "MarcoYuTono/privalingo-WMT-de-en-infinite"
9
  }
10
  }
11
  }
@@ -18,7 +18,7 @@
18
  "2": "2023_09_04-16_23_41",
19
  "5": "2023_09_04-16_51_06",
20
  "10": "2023_09_04-17_17_44",
21
- "non": "2023_10_22-19_08_23"
22
  }
23
  }
24
  }
@@ -29,7 +29,7 @@
29
  "epsilon": {
30
  "1": "2023_10_01-00_38_27",
31
  "5": "2023_10_01-01_11_49",
32
- "non": "2023_10_23-15_46_22"
33
  }
34
  }
35
  }
 
5
  "epsilon": {
6
  "1": "MarcoYuTono/privalingo-WMT-de-en-eps1",
7
  "5": "MarcoYuTono/privalingo-WMT-de-en-eps5",
8
+ "infinite": "MarcoYuTono/privalingo-WMT-de-en-infinite"
9
  }
10
  }
11
  }
 
18
  "2": "2023_09_04-16_23_41",
19
  "5": "2023_09_04-16_51_06",
20
  "10": "2023_09_04-17_17_44",
21
+ "infinite": "2023_10_22-19_08_23"
22
  }
23
  }
24
  }
 
29
  "epsilon": {
30
  "1": "2023_10_01-00_38_27",
31
  "5": "2023_10_01-01_11_49",
32
+ "infinite": "2023_10_23-15_46_22"
33
  }
34
  }
35
  }