Nyanfa commited on
Commit
37d6487
·
verified ·
1 Parent(s): f04938e

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +338 -0
  2. requirements.txt +2 -0
  3. style.css +8 -0
app.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cohere
2
+ import streamlit as st
3
+ from streamlit.components.v1 import html
4
+ from streamlit_extras.stylable_container import stylable_container
5
+ import re
6
+ import urllib.parse
7
+
8
+ st.title("Cohere Chat UI")
9
+
10
+ if "api_key" not in st.session_state:
11
+ api_key = st.text_input("Enter your API Key", type="password")
12
+ if api_key:
13
+ if api_key.isascii():
14
+ st.session_state.api_key = api_key
15
+ client = cohere.Client(api_key=api_key)
16
+ st.rerun()
17
+ else:
18
+ st.warning("Please enter your API key correctly.")
19
+ st.stop()
20
+ else:
21
+ st.warning("Please enter your API key to use the app. You can obtain your API key from here: https://dashboard.cohere.com/api-keys")
22
+ st.stop()
23
+ else:
24
+ client = cohere.Client(api_key=st.session_state.api_key)
25
+
26
+ if "messages" not in st.session_state:
27
+ st.session_state.messages = []
28
+
29
+ def get_ai_response(prompt, chat_history):
30
+ st.session_state.is_streaming = True
31
+ st.session_state.response = ""
32
+
33
+ with st.chat_message("ai", avatar=st.session_state.assistant_avatar):
34
+ penalty_kwargs = {
35
+ "frequency_penalty" if penalty_type == "Frequency Penalty" else "presence_penalty": penalty_value
36
+ }
37
+
38
+ stream = client.chat_stream(
39
+ message=prompt,
40
+ model=model,
41
+ preamble=preamble,
42
+ chat_history=chat_history,
43
+ temperature=temperature,
44
+ k=k,
45
+ p=p,
46
+ **penalty_kwargs
47
+ )
48
+
49
+ placeholder = st.empty()
50
+
51
+ with stylable_container(
52
+ key="stop_generating",
53
+ css_styles="""
54
+ button {
55
+ position: fixed;
56
+ bottom: 100px;
57
+ left: 50%;
58
+ transform: translateX(-50%);
59
+ z-index: 1;
60
+ }
61
+ """,
62
+ ):
63
+ st.button("Stop generating")
64
+
65
+ shown_message = ""
66
+
67
+ for event in stream:
68
+ if event.event_type == "text-generation":
69
+ content = event.text
70
+ st.session_state.response += content
71
+ shown_message += content.replace("\n", " \n")\
72
+ .replace("<", "\\<")\
73
+ .replace(">", "\\>")
74
+ placeholder.markdown(shown_message)
75
+
76
+ st.session_state.is_streaming = False
77
+ return st.session_state.response
78
+
79
+ def normalize_code_block(match):
80
+ return match.group(0).replace(" \n", "\n")\
81
+ .replace("\\<", "<")\
82
+ .replace("\\>", ">")
83
+
84
+ def normalize_inline(match):
85
+ return match.group(0).replace("\\<", "<")\
86
+ .replace("\\>", ">")
87
+
88
+ code_block_pattern = r"(```.*?```)"
89
+ inline_pattern = r"`([^`\n]+?)`"
90
+
91
+ def display_messages():
92
+ for i, message in enumerate(st.session_state.messages):
93
+ name = "user" if message["role"] == "USER" else "ai"
94
+ avatar = st.session_state.user_avatar if message["role"] == "USER" else st.session_state.assistant_avatar
95
+ with st.chat_message(name, avatar=avatar):
96
+ shown_message = message["text"].replace("\n", " \n")\
97
+ .replace("<", "\\<")\
98
+ .replace(">", "\\>")
99
+ if "```" in shown_message:
100
+ # Replace " \n" with "\n" within code blocks
101
+ shown_message = re.sub(code_block_pattern, normalize_code_block, shown_message, flags=re.DOTALL)
102
+ if "`" in shown_message:
103
+ shown_message = re.sub(inline_pattern, normalize_inline, shown_message)
104
+ st.markdown(shown_message)
105
+
106
+ col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
107
+ with col1:
108
+ if st.button("Edit", key=f"edit_{i}_{len(st.session_state.messages)}"):
109
+ st.session_state.edit_index = i
110
+ st.rerun()
111
+ with col2:
112
+ if st.session_state.is_delete_mode and st.button("Delete", key=f"delete_{i}_{len(st.session_state.messages)}"):
113
+ del st.session_state.messages[i]
114
+ st.rerun()
115
+ with col3:
116
+ text_to_copy = message["text"]
117
+ # Encode the string to escape
118
+ text_to_copy_escaped = urllib.parse.quote(text_to_copy)
119
+
120
+ copy_button_html = f"""
121
+ <button id="copy-msg-btn-{i}" style='font-size: 1em; padding: 0.5em;' onclick='copyMessage("{i}")'>Copy</button>
122
+
123
+ <script>
124
+ function copyMessage(index) {{
125
+ navigator.clipboard.writeText(decodeURIComponent("{text_to_copy_escaped}"));
126
+ let copyBtn = document.getElementById("copy-msg-btn-" + index);
127
+ copyBtn.innerHTML = "Copied!";
128
+ setTimeout(function(){{ copyBtn.innerHTML = "Copy"; }}, 2000);
129
+ }}
130
+ </script>
131
+ """
132
+ html(copy_button_html, height=50)
133
+
134
+ if i == len(st.session_state.messages) - 1 and message["role"] == "CHATBOT":
135
+ with col4:
136
+ if st.button("Retry", key=f"retry_{i}_{len(st.session_state.messages)}"):
137
+ if len(st.session_state.messages) >= 2:
138
+ del st.session_state.messages[-1]
139
+ st.session_state.retry_flag = True
140
+ st.rerun()
141
+
142
+ if "edit_index" in st.session_state and st.session_state.edit_index == i:
143
+ with st.form(key=f"edit_form_{i}_{len(st.session_state.messages)}"):
144
+ new_content = st.text_area("Edit message", height=200, value=st.session_state.messages[i]["text"])
145
+ col1, col2 = st.columns([1, 1])
146
+ with col1:
147
+ if st.form_submit_button("Save"):
148
+ st.session_state.messages[i]["text"] = new_content
149
+ del st.session_state.edit_index
150
+ st.rerun()
151
+ with col2:
152
+ if st.form_submit_button("Cancel"):
153
+ del st.session_state.edit_index
154
+ st.rerun()
155
+
156
+ # Add sidebar for advanced settings
157
+ with st.sidebar:
158
+ settings_tab, appearance_tab = st.tabs(["Settings", "Appearance"])
159
+
160
+ with settings_tab:
161
+ st.markdown("Help (Japanese): https://rentry.org/9hgneofz")
162
+
163
+ # Copy Conversation History button
164
+ log_text = ""
165
+ for message in st.session_state.messages:
166
+ if message["role"] == "USER":
167
+ log_text += "<USER>\n"
168
+ log_text += message["text"] + "\n\n"
169
+ else:
170
+ log_text += "<ASSISTANT>\n"
171
+ log_text += message["text"] + "\n\n"
172
+ log_text = log_text.rstrip("\n")
173
+
174
+ # Encode the string to escape
175
+ log_text_escaped = urllib.parse.quote(log_text)
176
+
177
+ copy_log_button_html = f"""
178
+ <button id="copy-log-btn" style='font-size: 1em; padding: 0.5em;' onclick='copyLog()'>Copy Conversation History</button>
179
+
180
+ <script>
181
+ function copyLog() {{
182
+ navigator.clipboard.writeText(decodeURIComponent("{log_text_escaped}"));
183
+ let copyBtn = document.getElementById("copy-log-btn");
184
+ copyBtn.innerHTML = "Copied!";
185
+ setTimeout(function(){{ copyBtn.innerHTML = "Copy Conversation History"; }}, 2000);
186
+ }}
187
+ </script>
188
+ """
189
+ html(copy_log_button_html, height=50)
190
+
191
+ if st.session_state.get("is_history_shown") != True:
192
+ if st.button("Display History as Code Block"):
193
+ st.session_state.is_history_shown = True
194
+ st.rerun()
195
+ else:
196
+ if st.button("Hide History"):
197
+ st.session_state.is_history_shown = False
198
+ st.rerun()
199
+ st.code(log_text)
200
+
201
+ st.session_state.is_delete_mode = st.toggle("Enable Delete button")
202
+
203
+ st.header("Advanced Settings")
204
+ model = st.selectbox("Model", options=["command-r-plus", "command-r"], index=0)
205
+ preamble = st.text_area("Preamble", height=200)
206
+ temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.3, step=0.1)
207
+ k = st.slider("Top-K", min_value=0, max_value=500, value=0, step=1)
208
+ p = st.slider("Top-P", min_value=0.01, max_value=0.99, value=0.75, step=0.01)
209
+ penalty_type = st.selectbox("Penalty Type", options=["Frequency Penalty", "Presence Penalty"])
210
+ penalty_value = st.slider("Penalty Value", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
211
+
212
+ st.header("Restore History")
213
+ history_input = st.text_area("Paste conversation history:", height=200)
214
+ if st.button("Restore History"):
215
+ st.session_state.messages = []
216
+ messages = re.split(r"^(<USER>|<ASSISTANT>)\n", history_input, flags=re.MULTILINE)
217
+ role = None
218
+ text = ""
219
+ for message in messages:
220
+ if message.strip() in ["<USER>", "<ASSISTANT>"]:
221
+ if role and text:
222
+ st.session_state.messages.append({"role": role, "text": text.strip()})
223
+ text = ""
224
+ role = "USER" if message.strip() == "<USER>" else "CHATBOT"
225
+ else:
226
+ text += message
227
+ if role and text:
228
+ st.session_state.messages.append({"role": role, "text": text.strip()})
229
+ st.rerun()
230
+
231
+ st.header("Clear History")
232
+ if st.button("Clear Chat History"):
233
+ st.session_state.messages = []
234
+ st.rerun()
235
+
236
+ st.header("Change API Key")
237
+ new_api_key = st.text_input("Enter new API Key", type="password")
238
+ if st.button("Update API Key"):
239
+ if new_api_key and new_api_key.isascii():
240
+ st.session_state.api_key = new_api_key
241
+ client = cohere.Client(api_key=new_api_key)
242
+ st.success("API Key updated successfully!")
243
+ else:
244
+ st.warning("Please enter a valid API Key.")
245
+
246
+ with appearance_tab:
247
+ st.header("Font Selection")
248
+ font_options = {
249
+ "Zen Maru Gothic": "Zen Maru Gothic",
250
+ "Noto Sans JP": "Noto Sans JP",
251
+ "Sawarabi Mincho": "Sawarabi Mincho"
252
+ }
253
+ selected_font = st.selectbox("Choose a font", ["Default"] + list(font_options.keys()))
254
+
255
+ st.header("Change the font size")
256
+ st.session_state.font_size = st.slider("Font size", min_value=16.0, max_value=50.0, value=16.0, step=1.0)
257
+
258
+ st.header("Change the user's icon")
259
+ st.session_state.user_avatar = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg", "webp", "gif", "bmp", "svg",], key="user_avatar_uploader")
260
+
261
+ st.header("Change the assistant's icon")
262
+ st.session_state.assistant_avatar = st.file_uploader("Choose an image", type=["png", "jpg", "jpeg", "webp", "gif", "bmp", "svg",], key="assistant_avatar_uploader")
263
+
264
+ st.header("Change the icon size")
265
+ st.session_state.avatar_size = st.slider("Icon size", min_value=2.0, max_value=20.0, value=2.0, step=0.2)
266
+
267
+
268
+ # After Stop generating
269
+ if st.session_state.get("is_streaming"):
270
+ st.session_state.messages.append({"role": "CHATBOT", "text": st.session_state.response})
271
+ st.session_state.is_streaming = False
272
+ if "retry_flag" in st.session_state and st.session_state.retry_flag:
273
+ st.session_state.retry_flag = False
274
+ st.rerun()
275
+
276
+ if selected_font != "Default":
277
+ with open("style.css") as css:
278
+ st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)
279
+ st.markdown(f'<style>body * {{ font-family: "{font_options[selected_font]}", serif !important; }}</style>', unsafe_allow_html=True)
280
+
281
+ # Change font size
282
+ st.markdown(f'<style>[data-testid="stChatMessageContent"] .st-emotion-cache-cnbvxy p{{font-size: {st.session_state.font_size}px;}}</style>', unsafe_allow_html=True)
283
+
284
+ # Change icon size
285
+ # (CSS element names may be subject to change.)
286
+ # (Contributor: ★31 >>538)
287
+ AVATAR_SIZE_STYLE = f"""
288
+ <style>
289
+ [data-testid="chatAvatarIcon-user"] {{
290
+ width: {st.session_state.avatar_size}rem;
291
+ height: {st.session_state.avatar_size}rem;
292
+ }}
293
+ [data-testid="chatAvatarIcon-assistant"] {{
294
+ width: {st.session_state.avatar_size}rem;
295
+ height: {st.session_state.avatar_size}rem;
296
+ }}
297
+ [data-testid="stChatMessage"] .st-emotion-cache-1pbsqtx {{
298
+ width: {st.session_state.avatar_size / 1.6}rem;
299
+ height: {st.session_state.avatar_size / 1.6}rem;
300
+ }}
301
+ [data-testid="stChatMessage"] .st-emotion-cache-p4micv {{
302
+ width: {st.session_state.avatar_size}rem;
303
+ height: {st.session_state.avatar_size}rem;
304
+ }}
305
+ </style>
306
+ """
307
+ st.markdown(AVATAR_SIZE_STYLE, unsafe_allow_html=True)
308
+
309
+ display_messages()
310
+
311
+ # After Retry
312
+ if st.session_state.get("retry_flag"):
313
+ if len(st.session_state.messages) > 0:
314
+ prompt = st.session_state.messages[-1]["text"]
315
+ messages = st.session_state.messages[:-1].copy()
316
+ response = get_ai_response(prompt, messages)
317
+ st.session_state.messages.append({"role": "CHATBOT", "text": response})
318
+ st.session_state.retry_flag = False
319
+ st.rerun()
320
+ else:
321
+ st.session_state.retry_flag = False
322
+
323
+ if prompt := st.chat_input("Enter your message here..."):
324
+ chat_history = st.session_state.messages.copy()
325
+
326
+ shown_message = prompt.replace("\n", " \n")\
327
+ .replace("<", "\\<")\
328
+ .replace(">", "\\>")
329
+
330
+ with st.chat_message("user", avatar=st.session_state.user_avatar):
331
+ st.write(shown_message)
332
+
333
+ st.session_state.messages.append({"role": "USER", "text": prompt})
334
+
335
+ response = get_ai_response(prompt, chat_history)
336
+
337
+ st.session_state.messages.append({"role": "CHATBOT", "text": response})
338
+ st.rerun()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cohere
2
+ streamlit-extras
style.css ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ @import url('https://fonts.googleapis.com/css2?family=Zen+Maru+Gothic&display=swap');
2
+ @import url('https://fonts.googleapis.com/css2?family=Noto+Sans+JP&display=swap');
3
+ @import url('https://fonts.googleapis.com/css2?family=Sawarabi+Mincho&display=swap');
4
+
5
+ body * {
6
+ font-weight: 400;
7
+ font-style: normal;
8
+ }