Zhuo committed on
Commit 10a832a · 1 Parent(s): 31b8a7b
Files changed (2)
  1. app.py +103 -0
  2. dataset_and_model_info.json +37 -0
app.py ADDED
@@ -0,0 +1,103 @@
+ import json
+ import os
+ import random
+
+ import numpy as np
+ import streamlit as st
+ import torch
+ from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer
+
+
+ @st.cache_resource
+ def load_model(model_name, tokenizer_name):
+     model = FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name)
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+     return model, tokenizer
+
+
+ def load_json(file_path):
+     with open(file_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+     return data
+
+
+ def preprocess(input_text, tokenizer, src_lang, tgt_lang):
+     # task_prefix = f"translate {src_lang} to {tgt_lang}: "
+     # input_text = task_prefix + input_text
+     model_inputs = tokenizer(
+         input_text, max_length=MAX_SEQ_LEN, padding="max_length", truncation=True, return_tensors="np"
+     )
+     return model_inputs
+
+
+ def translate(input_text, model, tokenizer, src_lang, tgt_lang):
+     model_inputs = preprocess(input_text, tokenizer, src_lang, tgt_lang)
+     model_outputs = model.generate(**model_inputs, num_beams=NUM_BEAMS)
+     prediction = tokenizer.batch_decode(model_outputs.sequences, skip_special_tokens=True)
+     return prediction[0]
+
+
+ def hold_deterministic(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     random.seed(seed)
+
+
+ def main():
+     # st.title("Privalingo Playground Demo")
+     # html_temp = """
+     # <div style="background:#025246 ;padding:10px">
+     # <h2 style="color:white;text-align:center;">Playground Demo </h2>
+     # </div>
+     # """
+     # st.markdown(html_temp, unsafe_allow_html=True)
+     hold_deterministic(SEED)
+
+     st.title("Neural Machine Translation with DP-SGD")
+
+     st.write("This is a demo for private neural machine translation with DP-SGD. More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)")
+
+     dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))
+
+     language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())
+
+     language_pair = st.selectbox("Language pair for translation", language_pairs_list)
+
+     src_lang, tgt_lang = language_pair.split("-")
+
+     epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
+     epsilon = st.radio("Select a privacy budget epsilon", epsilon_options)
+
+     st_model_load = st.text(f'Loading model trained on {dataset} with epsilon {epsilon}...')
+
+     ckpt = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
+     model_name = MODEL.split('/')[-1]
+     model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)
+
+     if not os.path.exists(model_path):
+         st.error(f"Model not found. Use {MODEL} instead")
+         model_path = MODEL
+
+     model, tokenizer = load_model(model_path, tokenizer_name=MODEL)
+     st.success('Model loaded!')
+     st_model_load.text("")
+
+     input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=200)
+
+     if st.button("Translate"):
+         st.write("Translation")
+         prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
+         st.success("".join([prediction]))
+
+
+ if __name__ == '__main__':
+     DATASETS_MODEL_INFO_PATH = os.getcwd() + "/app/dataset_and_model_info.json"
+     CHECKPOINTS_DIR = os.getcwd() + "/checkpoints"
+     DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
+     MODEL = 'google/mt5-small'
+
+     MAX_SEQ_LEN = 512
+     NUM_BEAMS = 3
+     SEED = 2023
+     main()
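
Note (not part of the commit): the translation path above can be exercised outside Streamlit. The sketch below mirrors the preprocess() and translate() calls in app.py, but assumes the public google/mt5-small base checkpoint rather than the fine-tuned DP checkpoints, which the app loads from a local checkpoints/ directory.

# Minimal sketch, assuming the public google/mt5-small checkpoint (no DP fine-tuning);
# it follows the same tokenize -> generate -> batch_decode path as app.py.
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer

MAX_SEQ_LEN = 512
NUM_BEAMS = 3

model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")

# Tokenize to NumPy arrays, exactly as preprocess() does.
inputs = tokenizer(
    "Guten Morgen!", max_length=MAX_SEQ_LEN, padding="max_length",
    truncation=True, return_tensors="np",
)
# Flax generate() returns an output object whose .sequences field holds the token ids.
outputs = model.generate(**inputs, num_beams=NUM_BEAMS)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0])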
dataset_and_model_info.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "WMT": {
+     "languages pairs": {
+       "German-English": {
+         "epsilon": {
+           "1": "2023_10_07-05_50_17",
+           "5": "2023_10_07-16_40_49",
+           "non": "2023_11_24-20_58_27"
+         }
+       }
+     }
+   },
+   "BSD": {
+     "languages pairs": {
+       "Japanese-English": {
+         "epsilon": {
+           "1": "2023_09_04-15_54_22",
+           "2": "2023_09_04-16_23_41",
+           "5": "2023_09_04-16_51_06",
+           "10": "2023_09_04-17_17_44",
+           "non": "2023_10_22-19_08_23"
+         }
+       }
+     }
+   },
+   "ClinSpEn-CC": {
+     "languages pairs": {
+       "Spanish-English": {
+         "epsilon": {
+           "1": "2023_10_01-00_38_27",
+           "5": "2023_10_01-01_11_49",
+           "non": "2023_10_23-15_46_22"
+         }
+       }
+     }
+   }
+ }
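
Note (not part of the commit): app.py resolves a checkpoint directory by walking this JSON with dataset → "languages pairs" → language pair → "epsilon" keys. A small sketch of that lookup, assuming the file sits in the working directory and the local checkpoints/ layout used by the app:

# Sketch of the lookup app.py performs; the checkpoints/ layout is an assumption,
# only the JSON structure comes from this commit.
import json
import os

with open("dataset_and_model_info.json", "r", encoding="utf-8") as f:
    info = json.load(f)

dataset, pair, eps = "BSD", "Japanese-English", "5"
ckpt = info[dataset]["languages pairs"][pair]["epsilon"][eps]  # "2023_09_04-16_51_06"
model_path = os.path.join(os.getcwd(), "checkpoints", ckpt, "mt5-small")
print(model_path)  # app.py falls back to google/mt5-small if this path does not exist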