tang-x commited on
Commit
0c752de
·
verified ·
1 Parent(s): e54ab22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -72
app.py CHANGED
@@ -103,35 +103,35 @@ def display_optimization_results(result_data):
103
  success = result["succeed"]
104
  prompt = result["prompt"]
105
 
106
- with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
107
- st.markdown("**Prompt:**")
108
  st.code(prompt, language="text")
109
  st.markdown("<br>", unsafe_allow_html=True)
110
 
111
  col1, col2 = st.columns(2)
112
  with col1:
113
- st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
114
  with col2:
115
- st.markdown(f"**Tokens:** {result['tokens']}")
116
 
117
- st.markdown("**Answers:**")
118
  for idx, answer in enumerate(result["answers"]):
119
- st.markdown(f"**Question {idx + 1}:**")
120
  st.text(answer["question"])
121
- st.markdown("**Answer:**")
122
  st.text(answer["answer"])
123
  st.markdown("---")
124
 
125
- # Summary
126
  success_count = sum(1 for r in result_data if r["succeed"])
127
  total_rounds = len(result_data)
128
 
129
- st.markdown("### Summary")
130
  col1, col2 = st.columns(2)
131
  with col1:
132
- st.metric("Total Rounds", total_rounds)
133
  with col2:
134
- st.metric("Successful Rounds", success_count)
135
 
136
 
137
  def main():
@@ -144,69 +144,68 @@ def main():
144
  """
145
  <div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin-bottom: 25px">
146
  <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px">
147
- <h1 style="margin: 0;">SPO | Self-Supervised Prompt Optimization 🤖</h1>
148
  </div>
149
  <div style="display: flex; gap: 20px; align-items: center">
150
  <a href="https://arxiv.org/pdf/2502.06855" target="_blank" style="text-decoration: none;">
151
- <img src="https://img.shields.io/badge/Paper-PDF-red.svg" alt="Paper">
152
  </a>
153
  <a href="https://github.com/geekan/MetaGPT/blob/main/examples/spo/README.md" target="_blank" style="text-decoration: none;">
154
- <img src="https://img.shields.io/badge/GitHub-Repository-blue.svg" alt="GitHub">
155
  </a>
156
- <span style="color: #666;">A framework for self-supervised prompt optimization</span>
157
  </div>
158
  </div>
159
  """,
160
  unsafe_allow_html=True
161
  )
162
 
163
- # Sidebar for configurations
164
  with st.sidebar:
165
- st.header("Configuration")
166
 
167
- # Template Selection/Creation
168
  settings_path = Path("metagpt/ext/spo/settings")
169
  existing_templates = [f.stem for f in settings_path.glob("*.yaml")]
170
-
171
- template_mode = st.radio("Template Mode", ["Use Existing", "Create New"])
172
 
173
  existing_templates = get_all_templates()
174
 
175
- if template_mode == "Use Existing":
176
- template_name = st.selectbox("Select Template", existing_templates)
177
  is_new_template = False
178
  else:
179
- template_name = st.text_input("New Template Name")
180
  is_new_template = True
181
 
182
- # LLM Settings
183
- st.subheader("LLM Settings")
184
 
185
- base_url = st.text_input("Base URL", value="https://api.example.com")
186
- api_key = st.text_input("API Key", type="password")
187
 
188
  opt_model = st.selectbox(
189
- "Optimization Model", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
190
  )
191
- opt_temp = st.slider("Optimization Temperature", 0.0, 1.0, 0.7)
192
 
193
  eval_model = st.selectbox(
194
- "Evaluation Model", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
195
  )
196
- eval_temp = st.slider("Evaluation Temperature", 0.0, 1.0, 0.3)
197
 
198
  exec_model = st.selectbox(
199
- "Execution Model", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
200
  )
201
- exec_temp = st.slider("Execution Temperature", 0.0, 1.0, 0.0)
202
 
203
- # Optimizer Settings
204
- st.subheader("Optimizer Settings")
205
- initial_round = st.number_input("Initial Round", 1, 100, 1)
206
- max_rounds = st.number_input("Maximum Rounds", 1, 100, 10)
207
 
208
- # Main content area
209
- st.header("Template Configuration")
210
 
211
  if template_name:
212
  template_real_name = get_template_path(template_name, is_new_template)
@@ -220,30 +219,30 @@ def main():
220
  st.session_state.current_template = template_name
221
  st.session_state.qas = template_data.get("qa", [])
222
 
223
- # Edit template sections
224
- prompt = st.text_area("Prompt", template_data.get("prompt", ""), height=100)
225
- requirements = st.text_area("Requirements", template_data.get("requirements", ""), height=100)
226
 
227
- # qa section
228
- st.subheader("Q&A Examples")
229
 
230
- # Add new qa button
231
- if st.button("Add New Q&A"):
232
  st.session_state.qas.append({"question": "", "answer": ""})
233
 
234
- # Edit qas
235
  new_qas = []
236
  for i in range(len(st.session_state.qas)):
237
- st.markdown(f"**QA #{i + 1}**")
238
  col1, col2, col3 = st.columns([45, 45, 10])
239
 
240
  with col1:
241
  question = st.text_area(
242
- f"Question {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100
243
  )
244
  with col2:
245
  answer = st.text_area(
246
- f"Answer {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100
247
  )
248
  with col3:
249
  if st.button("🗑️", key=f"delete_{i}"):
@@ -252,20 +251,20 @@ def main():
252
 
253
  new_qas.append({"question": question, "answer": answer})
254
 
255
- # Save template button
256
- if st.button("Save Template"):
257
  new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
258
 
259
  save_yaml_template(template_path, new_template_data, is_new_template)
260
 
261
  st.session_state.qas = new_qas
262
- st.success(f"Template saved to {template_path}")
263
 
264
- st.subheader("Current Template Preview")
265
  preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
266
  st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
267
 
268
- st.subheader("Optimization Logs")
269
  log_container = st.empty()
270
 
271
  class StreamlitSink:
@@ -289,8 +288,8 @@ def main():
289
  )
290
  _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
291
 
292
- # Start optimization button
293
- if st.button("Start Optimization"):
294
  try:
295
  # Initialize LLM
296
  SPO_LLM.initialize(
@@ -315,37 +314,35 @@ def main():
315
  with st.spinner("Optimizing prompts..."):
316
  optimizer.optimize()
317
 
318
- st.success("Optimization completed!")
319
-
320
- st.header("Optimization Results")
321
-
322
  prompt_path = optimizer.root_path / "prompts"
323
  result_data = optimizer.data_utils.load_results(prompt_path)
324
 
325
  st.session_state.optimization_results = result_data
326
 
327
  except Exception as e:
328
- st.error(f"An error occurred: {str(e)}")
329
- _logger.error(f"Error during optimization: {str(e)}")
330
 
331
  if st.session_state.optimization_results:
332
- st.header("Optimization Results")
333
  display_optimization_results(st.session_state.optimization_results)
334
 
335
  st.markdown("---")
336
- st.subheader("Test Optimized Prompt")
337
  col1, col2 = st.columns(2)
338
 
339
  with col1:
340
- test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt")
341
 
342
  with col2:
343
- test_question = st.text_area("Your Question", value="", height=200, key="test_question")
344
 
345
- if st.button("Test Prompt"):
346
  if test_prompt and test_question:
347
  try:
348
- with st.spinner("Generating response..."):
349
  SPO_LLM.initialize(
350
  optimize_kwargs={"model": opt_model, "temperature": opt_temp, "base_url": base_url,
351
  "api_key": api_key},
@@ -368,13 +365,13 @@ def main():
368
  finally:
369
  loop.close()
370
 
371
- st.subheader("Response:")
372
  st.markdown(response)
373
 
374
  except Exception as e:
375
- st.error(f"Error generating response: {str(e)}")
376
  else:
377
- st.warning("Please enter both prompt and question.")
378
 
379
 
380
  if __name__ == "__main__":
 
103
  success = result["succeed"]
104
  prompt = result["prompt"]
105
 
106
+ with st.expander(f"轮次 {round_num} {':white_check_mark:' if success else ':x:'}"):
107
+ st.markdown("**提示词:**")
108
  st.code(prompt, language="text")
109
  st.markdown("<br>", unsafe_allow_html=True)
110
 
111
  col1, col2 = st.columns(2)
112
  with col1:
113
+ st.markdown(f"**状态:** {'成功 ✅ ' if success else '失败 ❌ '}")
114
  with col2:
115
+ st.markdown(f"**令牌数:** {result['tokens']}")
116
 
117
+ st.markdown("**回答:**")
118
  for idx, answer in enumerate(result["answers"]):
119
+ st.markdown(f"**问题 {idx + 1}:**")
120
  st.text(answer["question"])
121
+ st.markdown("**答案:**")
122
  st.text(answer["answer"])
123
  st.markdown("---")
124
 
125
+ # 总结
126
  success_count = sum(1 for r in result_data if r["succeed"])
127
  total_rounds = len(result_data)
128
 
129
+ st.markdown("### 总结")
130
  col1, col2 = st.columns(2)
131
  with col1:
132
+ st.metric("总轮次", total_rounds)
133
  with col2:
134
+ st.metric("成功轮次", success_count)
135
 
136
 
137
  def main():
 
144
  """
145
  <div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin-bottom: 25px">
146
  <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px">
147
+ <h1 style="margin: 0;">SPO | 自监督提示词优化 🤖</h1>
148
  </div>
149
  <div style="display: flex; gap: 20px; align-items: center">
150
  <a href="https://arxiv.org/pdf/2502.06855" target="_blank" style="text-decoration: none;">
151
+ <img src="https://img.shields.io/badge/论文-PDF-red.svg" alt="论文">
152
  </a>
153
  <a href="https://github.com/geekan/MetaGPT/blob/main/examples/spo/README.md" target="_blank" style="text-decoration: none;">
154
+ <img src="https://img.shields.io/badge/GitHub-仓库-blue.svg" alt="GitHub">
155
  </a>
156
+ <span style="color: #666;">一个自监督提示词优化框架</span>
157
  </div>
158
  </div>
159
  """,
160
  unsafe_allow_html=True
161
  )
162
 
163
+ # 侧边栏配置
164
  with st.sidebar:
165
+ st.header("配置")
166
 
167
+ # 模板选择/创建
168
  settings_path = Path("metagpt/ext/spo/settings")
169
  existing_templates = [f.stem for f in settings_path.glob("*.yaml")]
170
+ template_mode = st.radio("模板模式", ["使用现有", "创建新模板"])
 
171
 
172
  existing_templates = get_all_templates()
173
 
174
+ if template_mode == "使用现有":
175
+ template_name = st.selectbox("选择模板", existing_templates)
176
  is_new_template = False
177
  else:
178
+ template_name = st.text_input("新模板名称")
179
  is_new_template = True
180
 
181
+ # LLM 设置
182
+ st.subheader("LLM 设置")
183
 
184
+ base_url = st.text_input("基础 URL", value="https://api.example.com")
185
+ api_key = st.text_input("API 密钥", type="password")
186
 
187
  opt_model = st.selectbox(
188
+ "优化模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
189
  )
190
+ opt_temp = st.slider("优化温度", 0.0, 1.0, 0.7)
191
 
192
  eval_model = st.selectbox(
193
+ "评估模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
194
  )
195
+ eval_temp = st.slider("评估温度", 0.0, 1.0, 0.3)
196
 
197
  exec_model = st.selectbox(
198
+ "执行模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
199
  )
200
+ exec_temp = st.slider("执行温度", 0.0, 1.0, 0.0)
201
 
202
+ # 优化器设置
203
+ st.subheader("优化器设置")
204
+ initial_round = st.number_input("初始轮次", 1, 100, 1)
205
+ max_rounds = st.number_input("最大轮次", 1, 100, 10)
206
 
207
+ # 主要内容区域
208
+ st.header("模板配置")
209
 
210
  if template_name:
211
  template_real_name = get_template_path(template_name, is_new_template)
 
219
  st.session_state.current_template = template_name
220
  st.session_state.qas = template_data.get("qa", [])
221
 
222
+ # 编辑模板部分
223
+ prompt = st.text_area("提示词", template_data.get("prompt", ""), height=100)
224
+ requirements = st.text_area("要求", template_data.get("requirements", ""), height=100)
225
 
226
+ # 问答部分
227
+ st.subheader("问答示例")
228
 
229
+ # 添加新问答按钮
230
+ if st.button("添加新问答"):
231
  st.session_state.qas.append({"question": "", "answer": ""})
232
 
233
+ # 编辑问答
234
  new_qas = []
235
  for i in range(len(st.session_state.qas)):
236
+ st.markdown(f"**问答 #{i + 1}**")
237
  col1, col2, col3 = st.columns([45, 45, 10])
238
 
239
  with col1:
240
  question = st.text_area(
241
+ f"问题 {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100
242
  )
243
  with col2:
244
  answer = st.text_area(
245
+ f"答案 {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100
246
  )
247
  with col3:
248
  if st.button("🗑️", key=f"delete_{i}"):
 
251
 
252
  new_qas.append({"question": question, "answer": answer})
253
 
254
+ # 保存模板按钮
255
+ if st.button("保存模板"):
256
  new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
257
 
258
  save_yaml_template(template_path, new_template_data, is_new_template)
259
 
260
  st.session_state.qas = new_qas
261
+ st.success(f"模板已保存到 {template_path}")
262
 
263
+ st.subheader("当前模板预览")
264
  preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
265
  st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
266
 
267
+ st.subheader("优化日志")
268
  log_container = st.empty()
269
 
270
  class StreamlitSink:
 
288
  )
289
  _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
290
 
291
+ # 开始优化按钮
292
+ if st.button("开始优化"):
293
  try:
294
  # Initialize LLM
295
  SPO_LLM.initialize(
 
314
  with st.spinner("Optimizing prompts..."):
315
  optimizer.optimize()
316
 
317
+ st.success("优化完成!")
318
+ st.header("优化结果")
 
 
319
  prompt_path = optimizer.root_path / "prompts"
320
  result_data = optimizer.data_utils.load_results(prompt_path)
321
 
322
  st.session_state.optimization_results = result_data
323
 
324
  except Exception as e:
325
+ st.error(f"发生错误:{str(e)}")
326
+ _logger.error(f"优化过程中出错:{str(e)}")
327
 
328
  if st.session_state.optimization_results:
329
+ st.header("优化结果")
330
  display_optimization_results(st.session_state.optimization_results)
331
 
332
  st.markdown("---")
333
+ st.subheader("测试优化后的提示词")
334
  col1, col2 = st.columns(2)
335
 
336
  with col1:
337
+ test_prompt = st.text_area("优化后的提示词", value="", height=200, key="test_prompt")
338
 
339
  with col2:
340
+ test_question = st.text_area("你的问题", value="", height=200, key="test_question")
341
 
342
+ if st.button("测试提示词"):
343
  if test_prompt and test_question:
344
  try:
345
+ with st.spinner("正在生成回答..."):
346
  SPO_LLM.initialize(
347
  optimize_kwargs={"model": opt_model, "temperature": opt_temp, "base_url": base_url,
348
  "api_key": api_key},
 
365
  finally:
366
  loop.close()
367
 
368
+ st.subheader("回答:")
369
  st.markdown(response)
370
 
371
  except Exception as e:
372
+ st.error(f"生成回答时出错:{str(e)}")
373
  else:
374
+ st.warning("请输入提示词和问题。")
375
 
376
 
377
  if __name__ == "__main__":