 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import json
import os
import random

import numpy as np
import streamlit as st
import torch
from transformers import FlaxAutoModelForSeq2SeqLM, AutoTokenizer


@st.cache_resource
def load_model(model_name, tokenizer_name):
    """Load a Flax seq2seq model and its tokenizer.

    Cached by Streamlit across reruns so the (slow) checkpoint load
    happens once per (model_name, tokenizer_name) pair.
    """
    return (
        FlaxAutoModelForSeq2SeqLM.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(tokenizer_name),
    )


def load_json(file_path):
    """Read the UTF-8 JSON document at *file_path* and return the parsed object."""
    with open(file_path, encoding='utf-8') as fh:
        return json.load(fh)


def preprocess(input_text, tokenizer, src_lang, tgt_lang):
    """Tokenize *input_text* into fixed-length numpy tensors for generation.

    ``src_lang``/``tgt_lang`` are currently unused — the task-prefix scheme
    ("translate {src} to {tgt}: ") is disabled — but are kept so the
    signature stays stable for callers.
    """
    # Pad/truncate to the module-level MAX_SEQ_LEN so batch shapes are fixed.
    return tokenizer(
        input_text,
        max_length=MAX_SEQ_LEN,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )


def translate(input_text, model, tokenizer, src_lang, tgt_lang):
    """Translate *input_text* with beam search and return the decoded string.

    Runs preprocess -> model.generate -> batch_decode and returns the first
    (only) decoded sequence, with special tokens stripped.
    """
    encoded = preprocess(input_text, tokenizer, src_lang, tgt_lang)
    generated = model.generate(**encoded, num_beams=NUM_BEAMS)
    decoded = tokenizer.batch_decode(generated.sequences, skip_special_tokens=True)
    return decoded[0]


def hold_deterministic(seed):
    """Seed every RNG source used here (Python, NumPy, PyTorch CPU and CUDA).

    Uses ``torch.cuda.manual_seed_all`` instead of ``manual_seed`` so that
    ALL CUDA devices are seeded, not just the current one; the call is a
    safe no-op when CUDA is unavailable.

    :param seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    """Render the Streamlit UI: pick dataset/language-pair/epsilon, load the
    matching DP-trained checkpoint, and translate user-entered text.

    Reads module globals set in the ``__main__`` guard: SEED,
    DATASETS_MODEL_INFO, CHECKPOINTS_DIR, MODEL — NOTE(review): these only
    exist when the file is run as a script, not when imported.
    """
    # st.title("Privalingo Playground Demo")
    # html_temp = """
    # <div style="background:#025246 ;padding:10px">
    # <h2 style="color:white;text-align:center;">Playground Demo </h2>
    # </div>
    # """
    # st.markdown(html_temp, unsafe_allow_html=True)
    # Seed all RNGs up front so generation is reproducible across reruns.
    hold_deterministic(SEED)

    st.title("Neural Machine Translation with DP-SGD")

    st.write("This is a demo for private neural machine translation with DP-SGD. More detail can be found in the [repository](https://github.com/trusthlt/dp-nmt)")

    # Checkpoint selection is keyed dataset -> language pair -> epsilon.
    dataset = st.selectbox("Choose a dataset used for fine-tuning", list(DATASETS_MODEL_INFO.keys()))

    language_pairs_list = list(DATASETS_MODEL_INFO[dataset]["languages pairs"].keys())

    language_pair = st.selectbox("Language pair for translation", language_pairs_list)

    # Pair keys are of the form "src-tgt", e.g. "de-en".
    src_lang, tgt_lang = language_pair.split("-")

    epsilon_options = list(DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'].keys())
    epsilon = st.radio("Select a privacy budget epsilon", epsilon_options)

    # Placeholder text cleared once the model has finished loading.
    st_model_load = st.text(f'Loading model trained on {dataset} with epsilon {epsilon}...')

    ckpt = DATASETS_MODEL_INFO[dataset]['languages pairs'][language_pair]['epsilon'][str(epsilon)]
    model_name = MODEL.split('/')[-1]
    model_path = os.path.join(CHECKPOINTS_DIR, ckpt, model_name)

    # Fall back to the base pretrained model when the checkpoint is absent.
    if not os.path.exists(model_path):
        st.error(f"Model not found. Use {MODEL} instead")
        model_path = MODEL

    model, tokenizer = load_model(model_path, tokenizer_name=MODEL)
    st.success('Model loaded!')
    st_model_load.text("")

    input_text = st.text_area("Enter Text", "Enter Text Here", max_chars=200)

    if st.button("Translate"):
        st.write("Translation")
        prediction = translate(input_text, model, tokenizer, src_lang, tgt_lang)
        st.success("".join([prediction]))


if __name__ == '__main__':
    # Module-level configuration read by the functions above.
    # Paths are built with os.path.join (portable) instead of string
    # concatenation; both are resolved relative to the working directory.
    DATASETS_MODEL_INFO_PATH = os.path.join(os.getcwd(), "app", "dataset_and_model_info.json")
    CHECKPOINTS_DIR = os.path.join(os.getcwd(), "checkpoints")
    DATASETS_MODEL_INFO = load_json(DATASETS_MODEL_INFO_PATH)
    MODEL = 'google/mt5-small'

    MAX_SEQ_LEN = 512  # tokenizer pad/truncation length (see preprocess)
    NUM_BEAMS = 3      # beam-search width (see translate)
    SEED = 2023        # RNG seed (see hold_deterministic)
    main()