Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -34,13 +34,14 @@ The current release version of Breeze-7B is v1.0.
|
|
34 |
"""
|
35 |
|
36 |
LICENSE = """
|
37 |
-
|
38 |
"""
|
39 |
|
40 |
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."
|
41 |
|
42 |
API_URL = os.environ.get("API_URL")
|
43 |
TOKEN = os.environ.get("TOKEN")
|
|
|
|
|
44 |
|
45 |
HEADERS = {
|
46 |
"Authorization": f"Bearer {TOKEN}",
|
@@ -48,34 +49,10 @@ HEADERS = {
|
|
48 |
"accept": "application/json"
|
49 |
}
|
50 |
|
51 |
-
|
52 |
MAX_SEC = 30
|
53 |
MAX_INPUT_LENGTH = 5000
|
54 |
|
55 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
56 |
-
|
57 |
-
def insert_to_db(prompt, response, temperature, top_p):
|
58 |
-
try:
|
59 |
-
#Establishing the connection
|
60 |
-
conn = psycopg2.connect(
|
61 |
-
database=os.environ.get("DB"), user=os.environ.get("USER"), password=os.environ.get("DB_PASS"), host=os.environ.get("DB_HOST"), port= '5432'
|
62 |
-
)
|
63 |
-
#Setting auto commit false
|
64 |
-
conn.autocommit = True
|
65 |
-
|
66 |
-
#Creating a cursor object using the cursor() method
|
67 |
-
cursor = conn.cursor()
|
68 |
-
|
69 |
-
# Preparing SQL queries to INSERT a record into the database.
|
70 |
-
cursor.execute(f"INSERT INTO breezedata(prompt, response, temperature, top_p) VALUES ('{prompt}', '{response}', {temperature}, {top_p})")
|
71 |
-
|
72 |
-
# Commit your changes in the database
|
73 |
-
conn.commit()
|
74 |
-
|
75 |
-
# Closing the connection
|
76 |
-
conn.close()
|
77 |
-
except:
|
78 |
-
pass
|
79 |
|
80 |
|
81 |
def refusal_condition(query):
|
@@ -105,19 +82,20 @@ with gr.Blocks() as demo:
|
|
105 |
system_prompt = gr.Textbox(label='System prompt',
|
106 |
value=DEFAULT_SYSTEM_PROMPT,
|
107 |
lines=1)
|
108 |
-
|
109 |
with gr.Accordion(label='Advanced options', open=False):
|
|
|
110 |
max_new_tokens = gr.Slider(
|
111 |
label='Max new tokens',
|
112 |
minimum=32,
|
113 |
-
maximum=
|
114 |
step=1,
|
115 |
-
value=
|
116 |
)
|
117 |
temperature = gr.Slider(
|
118 |
label='Temperature',
|
119 |
minimum=0.01,
|
120 |
-
maximum=
|
121 |
step=0.01,
|
122 |
value=0.01,
|
123 |
)
|
@@ -128,15 +106,8 @@ with gr.Blocks() as demo:
|
|
128 |
step=0.01,
|
129 |
value=0.01,
|
130 |
)
|
131 |
-
repetition_penalty = gr.Slider(
|
132 |
-
label='Repetition Penalty',
|
133 |
-
minimum=0.1,
|
134 |
-
maximum=2,
|
135 |
-
step=0.01,
|
136 |
-
value=1.1,
|
137 |
-
)
|
138 |
|
139 |
-
chatbot = gr.Chatbot()
|
140 |
with gr.Row():
|
141 |
msg = gr.Textbox(
|
142 |
container=False,
|
@@ -157,7 +128,6 @@ with gr.Blocks() as demo:
|
|
157 |
|
158 |
saved_input = gr.State()
|
159 |
|
160 |
-
|
161 |
def user(user_message, history):
|
162 |
return "", history + [[user_message, None]]
|
163 |
|
@@ -195,7 +165,7 @@ with gr.Blocks() as demo:
|
|
195 |
# start_time = time.time()
|
196 |
|
197 |
|
198 |
-
def bot(history, max_new_tokens, temperature, top_p, system_prompt
|
199 |
chat_data = []
|
200 |
system_prompt = system_prompt.strip()
|
201 |
if system_prompt:
|
@@ -217,19 +187,13 @@ with gr.Blocks() as demo:
|
|
217 |
yield history
|
218 |
else:
|
219 |
data = {
|
220 |
-
"model_type":
|
221 |
"prompt": str(message),
|
222 |
"parameters": {
|
223 |
"temperature": float(temperature),
|
224 |
"top_p": float(top_p),
|
225 |
"max_new_tokens": int(max_new_tokens),
|
226 |
-
"repetition_penalty":
|
227 |
-
|
228 |
-
"num_beams":1, # w/o beam search
|
229 |
-
"typical_p":0.99,
|
230 |
-
"top_k":61952, # w/o top_k
|
231 |
-
"do_sample": True,
|
232 |
-
"min_length":1,
|
233 |
}
|
234 |
}
|
235 |
|
@@ -248,14 +212,13 @@ with gr.Blocks() as demo:
|
|
248 |
response = history[-1][1]
|
249 |
|
250 |
if refusal_condition(history[-1][1]):
|
251 |
-
history[-1][1] = history[-1][1] + '\n\n**[免責聲明:
|
252 |
yield history
|
253 |
else:
|
254 |
del history[-1]
|
255 |
yield history
|
256 |
|
257 |
print('== Record ==\nQuery: {query}\nResponse: {response}'.format(query=repr(message), response=repr(history[-1][1])))
|
258 |
-
insert_to_db(message, response, float(temperature), float(top_p))
|
259 |
|
260 |
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
261 |
fn=bot,
|
@@ -265,7 +228,6 @@ with gr.Blocks() as demo:
|
|
265 |
temperature,
|
266 |
top_p,
|
267 |
system_prompt,
|
268 |
-
repetition_penalty,
|
269 |
],
|
270 |
outputs=chatbot
|
271 |
)
|
@@ -279,7 +241,6 @@ with gr.Blocks() as demo:
|
|
279 |
temperature,
|
280 |
top_p,
|
281 |
system_prompt,
|
282 |
-
repetition_penalty,
|
283 |
],
|
284 |
outputs=chatbot
|
285 |
)
|
@@ -319,7 +280,6 @@ with gr.Blocks() as demo:
|
|
319 |
temperature,
|
320 |
top_p,
|
321 |
system_prompt,
|
322 |
-
repetition_penalty,
|
323 |
],
|
324 |
outputs=chatbot,
|
325 |
)
|
@@ -342,5 +302,5 @@ with gr.Blocks() as demo:
|
|
342 |
|
343 |
gr.Markdown(LICENSE)
|
344 |
|
345 |
-
demo.queue(concurrency_count=
|
346 |
-
demo.launch()
|
|
|
34 |
"""
|
35 |
|
36 |
LICENSE = """
|
|
|
37 |
"""
|
38 |
|
39 |
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant built by MediaTek Research. The user you are helping speaks Traditional Chinese and comes from Taiwan."
|
40 |
|
41 |
API_URL = os.environ.get("API_URL")
|
42 |
TOKEN = os.environ.get("TOKEN")
|
43 |
+
TOKENIZER_REPO = "MediaTek-Research/Breeze-7B-Instruct-v1_0"
|
44 |
+
API_MODEL_TYPE = "breeze-7b-instruct-v10"
|
45 |
|
46 |
HEADERS = {
|
47 |
"Authorization": f"Bearer {TOKEN}",
|
|
|
49 |
"accept": "application/json"
|
50 |
}
|
51 |
|
|
|
52 |
MAX_SEC = 30
|
53 |
MAX_INPUT_LENGTH = 5000
|
54 |
|
55 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_REPO, use_auth_token=os.environ.get("HF_TOKEN"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
def refusal_condition(query):
|
|
|
82 |
system_prompt = gr.Textbox(label='System prompt',
|
83 |
value=DEFAULT_SYSTEM_PROMPT,
|
84 |
lines=1)
|
85 |
+
|
86 |
with gr.Accordion(label='Advanced options', open=False):
|
87 |
+
|
88 |
max_new_tokens = gr.Slider(
|
89 |
label='Max new tokens',
|
90 |
minimum=32,
|
91 |
+
maximum=2048,
|
92 |
step=1,
|
93 |
+
value=1024,
|
94 |
)
|
95 |
temperature = gr.Slider(
|
96 |
label='Temperature',
|
97 |
minimum=0.01,
|
98 |
+
maximum=0.5,
|
99 |
step=0.01,
|
100 |
value=0.01,
|
101 |
)
|
|
|
106 |
step=0.01,
|
107 |
value=0.01,
|
108 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
+
chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True, )
|
111 |
with gr.Row():
|
112 |
msg = gr.Textbox(
|
113 |
container=False,
|
|
|
128 |
|
129 |
saved_input = gr.State()
|
130 |
|
|
|
131 |
def user(user_message, history):
|
132 |
return "", history + [[user_message, None]]
|
133 |
|
|
|
165 |
# start_time = time.time()
|
166 |
|
167 |
|
168 |
+
def bot(history, max_new_tokens, temperature, top_p, system_prompt):
|
169 |
chat_data = []
|
170 |
system_prompt = system_prompt.strip()
|
171 |
if system_prompt:
|
|
|
187 |
yield history
|
188 |
else:
|
189 |
data = {
|
190 |
+
"model_type": API_MODEL_TYPE,
|
191 |
"prompt": str(message),
|
192 |
"parameters": {
|
193 |
"temperature": float(temperature),
|
194 |
"top_p": float(top_p),
|
195 |
"max_new_tokens": int(max_new_tokens),
|
196 |
+
"repetition_penalty": 1.1
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
}
|
198 |
}
|
199 |
|
|
|
212 |
response = history[-1][1]
|
213 |
|
214 |
if refusal_condition(history[-1][1]):
|
215 |
+
history[-1][1] = history[-1][1] + '\n\n**[免責聲明: 此模型並未針對問答進行安全保護,因此語言模型的任何回應不代表 MediaTek Research 立場。]**'
|
216 |
yield history
|
217 |
else:
|
218 |
del history[-1]
|
219 |
yield history
|
220 |
|
221 |
print('== Record ==\nQuery: {query}\nResponse: {response}'.format(query=repr(message), response=repr(history[-1][1])))
|
|
|
222 |
|
223 |
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
224 |
fn=bot,
|
|
|
228 |
temperature,
|
229 |
top_p,
|
230 |
system_prompt,
|
|
|
231 |
],
|
232 |
outputs=chatbot
|
233 |
)
|
|
|
241 |
temperature,
|
242 |
top_p,
|
243 |
system_prompt,
|
|
|
244 |
],
|
245 |
outputs=chatbot
|
246 |
)
|
|
|
280 |
temperature,
|
281 |
top_p,
|
282 |
system_prompt,
|
|
|
283 |
],
|
284 |
outputs=chatbot,
|
285 |
)
|
|
|
302 |
|
303 |
gr.Markdown(LICENSE)
|
304 |
|
305 |
+
demo.queue(concurrency_count=4, max_size=128)
|
306 |
+
demo.launch()
|