Update app.py
Browse files
app.py
CHANGED
@@ -103,35 +103,35 @@ def display_optimization_results(result_data):
|
|
103 |
success = result["succeed"]
|
104 |
prompt = result["prompt"]
|
105 |
|
106 |
-
with st.expander(f"
|
107 |
-
st.markdown("
|
108 |
st.code(prompt, language="text")
|
109 |
st.markdown("<br>", unsafe_allow_html=True)
|
110 |
|
111 |
col1, col2 = st.columns(2)
|
112 |
with col1:
|
113 |
-
st.markdown(f"
|
114 |
with col2:
|
115 |
-
st.markdown(f"
|
116 |
|
117 |
-
st.markdown("
|
118 |
for idx, answer in enumerate(result["answers"]):
|
119 |
-
st.markdown(f"
|
120 |
st.text(answer["question"])
|
121 |
-
st.markdown("
|
122 |
st.text(answer["answer"])
|
123 |
st.markdown("---")
|
124 |
|
125 |
-
#
|
126 |
success_count = sum(1 for r in result_data if r["succeed"])
|
127 |
total_rounds = len(result_data)
|
128 |
|
129 |
-
st.markdown("###
|
130 |
col1, col2 = st.columns(2)
|
131 |
with col1:
|
132 |
-
st.metric("
|
133 |
with col2:
|
134 |
-
st.metric("
|
135 |
|
136 |
|
137 |
def main():
|
@@ -144,69 +144,68 @@ def main():
|
|
144 |
"""
|
145 |
<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin-bottom: 25px">
|
146 |
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px">
|
147 |
-
<h1 style="margin: 0;">SPO |
|
148 |
</div>
|
149 |
<div style="display: flex; gap: 20px; align-items: center">
|
150 |
<a href="https://arxiv.org/pdf/2502.06855" target="_blank" style="text-decoration: none;">
|
151 |
-
<img src="https://img.shields.io/badge
|
152 |
</a>
|
153 |
<a href="https://github.com/geekan/MetaGPT/blob/main/examples/spo/README.md" target="_blank" style="text-decoration: none;">
|
154 |
-
<img src="https://img.shields.io/badge/GitHub
|
155 |
</a>
|
156 |
-
<span style="color: #666;"
|
157 |
</div>
|
158 |
</div>
|
159 |
""",
|
160 |
unsafe_allow_html=True
|
161 |
)
|
162 |
|
163 |
-
#
|
164 |
with st.sidebar:
|
165 |
-
st.header("
|
166 |
|
167 |
-
#
|
168 |
settings_path = Path("metagpt/ext/spo/settings")
|
169 |
existing_templates = [f.stem for f in settings_path.glob("*.yaml")]
|
170 |
-
|
171 |
-
template_mode = st.radio("Template Mode", ["Use Existing", "Create New"])
|
172 |
|
173 |
existing_templates = get_all_templates()
|
174 |
|
175 |
-
if template_mode == "
|
176 |
-
template_name = st.selectbox("
|
177 |
is_new_template = False
|
178 |
else:
|
179 |
-
template_name = st.text_input("
|
180 |
is_new_template = True
|
181 |
|
182 |
-
# LLM
|
183 |
-
st.subheader("LLM
|
184 |
|
185 |
-
base_url = st.text_input("
|
186 |
-
api_key = st.text_input("API
|
187 |
|
188 |
opt_model = st.selectbox(
|
189 |
-
"
|
190 |
)
|
191 |
-
opt_temp = st.slider("
|
192 |
|
193 |
eval_model = st.selectbox(
|
194 |
-
"
|
195 |
)
|
196 |
-
eval_temp = st.slider("
|
197 |
|
198 |
exec_model = st.selectbox(
|
199 |
-
"
|
200 |
)
|
201 |
-
exec_temp = st.slider("
|
202 |
|
203 |
-
#
|
204 |
-
st.subheader("
|
205 |
-
initial_round = st.number_input("
|
206 |
-
max_rounds = st.number_input("
|
207 |
|
208 |
-
#
|
209 |
-
st.header("
|
210 |
|
211 |
if template_name:
|
212 |
template_real_name = get_template_path(template_name, is_new_template)
|
@@ -220,30 +219,30 @@ def main():
|
|
220 |
st.session_state.current_template = template_name
|
221 |
st.session_state.qas = template_data.get("qa", [])
|
222 |
|
223 |
-
#
|
224 |
-
prompt = st.text_area("
|
225 |
-
requirements = st.text_area("
|
226 |
|
227 |
-
#
|
228 |
-
st.subheader("
|
229 |
|
230 |
-
#
|
231 |
-
if st.button("
|
232 |
st.session_state.qas.append({"question": "", "answer": ""})
|
233 |
|
234 |
-
#
|
235 |
new_qas = []
|
236 |
for i in range(len(st.session_state.qas)):
|
237 |
-
st.markdown(f"
|
238 |
col1, col2, col3 = st.columns([45, 45, 10])
|
239 |
|
240 |
with col1:
|
241 |
question = st.text_area(
|
242 |
-
f"
|
243 |
)
|
244 |
with col2:
|
245 |
answer = st.text_area(
|
246 |
-
f"
|
247 |
)
|
248 |
with col3:
|
249 |
if st.button("🗑️", key=f"delete_{i}"):
|
@@ -252,20 +251,20 @@ def main():
|
|
252 |
|
253 |
new_qas.append({"question": question, "answer": answer})
|
254 |
|
255 |
-
#
|
256 |
-
if st.button("
|
257 |
new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
|
258 |
|
259 |
save_yaml_template(template_path, new_template_data, is_new_template)
|
260 |
|
261 |
st.session_state.qas = new_qas
|
262 |
-
st.success(f"
|
263 |
|
264 |
-
st.subheader("
|
265 |
preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
|
266 |
st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
|
267 |
|
268 |
-
st.subheader("
|
269 |
log_container = st.empty()
|
270 |
|
271 |
class StreamlitSink:
|
@@ -289,8 +288,8 @@ def main():
|
|
289 |
)
|
290 |
_logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
|
291 |
|
292 |
-
#
|
293 |
-
if st.button("
|
294 |
try:
|
295 |
# Initialize LLM
|
296 |
SPO_LLM.initialize(
|
@@ -315,37 +314,35 @@ def main():
|
|
315 |
with st.spinner("Optimizing prompts..."):
|
316 |
optimizer.optimize()
|
317 |
|
318 |
-
st.success("
|
319 |
-
|
320 |
-
st.header("Optimization Results")
|
321 |
-
|
322 |
prompt_path = optimizer.root_path / "prompts"
|
323 |
result_data = optimizer.data_utils.load_results(prompt_path)
|
324 |
|
325 |
st.session_state.optimization_results = result_data
|
326 |
|
327 |
except Exception as e:
|
328 |
-
st.error(f"
|
329 |
-
_logger.error(f"
|
330 |
|
331 |
if st.session_state.optimization_results:
|
332 |
-
st.header("
|
333 |
display_optimization_results(st.session_state.optimization_results)
|
334 |
|
335 |
st.markdown("---")
|
336 |
-
st.subheader("
|
337 |
col1, col2 = st.columns(2)
|
338 |
|
339 |
with col1:
|
340 |
-
test_prompt = st.text_area("
|
341 |
|
342 |
with col2:
|
343 |
-
test_question = st.text_area("
|
344 |
|
345 |
-
if st.button("
|
346 |
if test_prompt and test_question:
|
347 |
try:
|
348 |
-
with st.spinner("
|
349 |
SPO_LLM.initialize(
|
350 |
optimize_kwargs={"model": opt_model, "temperature": opt_temp, "base_url": base_url,
|
351 |
"api_key": api_key},
|
@@ -368,13 +365,13 @@ def main():
|
|
368 |
finally:
|
369 |
loop.close()
|
370 |
|
371 |
-
st.subheader("
|
372 |
st.markdown(response)
|
373 |
|
374 |
except Exception as e:
|
375 |
-
st.error(f"
|
376 |
else:
|
377 |
-
st.warning("
|
378 |
|
379 |
|
380 |
if __name__ == "__main__":
|
|
|
103 |
success = result["succeed"]
|
104 |
prompt = result["prompt"]
|
105 |
|
106 |
+
with st.expander(f"轮次 {round_num} {':white_check_mark:' if success else ':x:'}"):
|
107 |
+
st.markdown("**提示词:**")
|
108 |
st.code(prompt, language="text")
|
109 |
st.markdown("<br>", unsafe_allow_html=True)
|
110 |
|
111 |
col1, col2 = st.columns(2)
|
112 |
with col1:
|
113 |
+
st.markdown(f"**状态:** {'成功 ✅ ' if success else '失败 ❌ '}")
|
114 |
with col2:
|
115 |
+
st.markdown(f"**令牌数:** {result['tokens']}")
|
116 |
|
117 |
+
st.markdown("**回答:**")
|
118 |
for idx, answer in enumerate(result["answers"]):
|
119 |
+
st.markdown(f"**问题 {idx + 1}:**")
|
120 |
st.text(answer["question"])
|
121 |
+
st.markdown("**答案:**")
|
122 |
st.text(answer["answer"])
|
123 |
st.markdown("---")
|
124 |
|
125 |
+
# 总结
|
126 |
success_count = sum(1 for r in result_data if r["succeed"])
|
127 |
total_rounds = len(result_data)
|
128 |
|
129 |
+
st.markdown("### 总结")
|
130 |
col1, col2 = st.columns(2)
|
131 |
with col1:
|
132 |
+
st.metric("总轮次", total_rounds)
|
133 |
with col2:
|
134 |
+
st.metric("成功轮次", success_count)
|
135 |
|
136 |
|
137 |
def main():
|
|
|
144 |
"""
|
145 |
<div style="background-color: #f0f2f6; padding: 20px; border-radius: 10px; margin-bottom: 25px">
|
146 |
<div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px">
|
147 |
+
<h1 style="margin: 0;">SPO | 自监督提示词优化 🤖</h1>
|
148 |
</div>
|
149 |
<div style="display: flex; gap: 20px; align-items: center">
|
150 |
<a href="https://arxiv.org/pdf/2502.06855" target="_blank" style="text-decoration: none;">
|
151 |
+
<img src="https://img.shields.io/badge/论文-PDF-red.svg" alt="论文">
|
152 |
</a>
|
153 |
<a href="https://github.com/geekan/MetaGPT/blob/main/examples/spo/README.md" target="_blank" style="text-decoration: none;">
|
154 |
+
<img src="https://img.shields.io/badge/GitHub-仓库-blue.svg" alt="GitHub">
|
155 |
</a>
|
156 |
+
<span style="color: #666;">一个自监督提示词优化框架</span>
|
157 |
</div>
|
158 |
</div>
|
159 |
""",
|
160 |
unsafe_allow_html=True
|
161 |
)
|
162 |
|
163 |
+
# 侧边栏配置
|
164 |
with st.sidebar:
|
165 |
+
st.header("配置")
|
166 |
|
167 |
+
# 模板选择/创建
|
168 |
settings_path = Path("metagpt/ext/spo/settings")
|
169 |
existing_templates = [f.stem for f in settings_path.glob("*.yaml")]
|
170 |
+
template_mode = st.radio("模板模式", ["使用现有", "创建新模板"])
|
|
|
171 |
|
172 |
existing_templates = get_all_templates()
|
173 |
|
174 |
+
if template_mode == "使用现有":
|
175 |
+
template_name = st.selectbox("选择模板", existing_templates)
|
176 |
is_new_template = False
|
177 |
else:
|
178 |
+
template_name = st.text_input("新模板名称")
|
179 |
is_new_template = True
|
180 |
|
181 |
+
# LLM 设置
|
182 |
+
st.subheader("LLM 设置")
|
183 |
|
184 |
+
base_url = st.text_input("基础 URL", value="https://api.example.com")
|
185 |
+
api_key = st.text_input("API 密钥", type="password")
|
186 |
|
187 |
opt_model = st.selectbox(
|
188 |
+
"优化模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
|
189 |
)
|
190 |
+
opt_temp = st.slider("优化温度", 0.0, 1.0, 0.7)
|
191 |
|
192 |
eval_model = st.selectbox(
|
193 |
+
"评估模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
|
194 |
)
|
195 |
+
eval_temp = st.slider("评估温度", 0.0, 1.0, 0.3)
|
196 |
|
197 |
exec_model = st.selectbox(
|
198 |
+
"执行模型", ["gpt-4o-mini", "gpt-4o", "deepseek-chat", "claude-3-5-sonnet-20240620"], index=0
|
199 |
)
|
200 |
+
exec_temp = st.slider("执行温度", 0.0, 1.0, 0.0)
|
201 |
|
202 |
+
# 优化器设置
|
203 |
+
st.subheader("优化器设置")
|
204 |
+
initial_round = st.number_input("初始轮次", 1, 100, 1)
|
205 |
+
max_rounds = st.number_input("最大轮次", 1, 100, 10)
|
206 |
|
207 |
+
# 主要内容区域
|
208 |
+
st.header("模板配置")
|
209 |
|
210 |
if template_name:
|
211 |
template_real_name = get_template_path(template_name, is_new_template)
|
|
|
219 |
st.session_state.current_template = template_name
|
220 |
st.session_state.qas = template_data.get("qa", [])
|
221 |
|
222 |
+
# 编辑模板部分
|
223 |
+
prompt = st.text_area("提示词", template_data.get("prompt", ""), height=100)
|
224 |
+
requirements = st.text_area("要求", template_data.get("requirements", ""), height=100)
|
225 |
|
226 |
+
# 问答部分
|
227 |
+
st.subheader("问答示例")
|
228 |
|
229 |
+
# 添加新问答按钮
|
230 |
+
if st.button("添加新问答"):
|
231 |
st.session_state.qas.append({"question": "", "answer": ""})
|
232 |
|
233 |
+
# 编辑问答
|
234 |
new_qas = []
|
235 |
for i in range(len(st.session_state.qas)):
|
236 |
+
st.markdown(f"**问答 #{i + 1}**")
|
237 |
col1, col2, col3 = st.columns([45, 45, 10])
|
238 |
|
239 |
with col1:
|
240 |
question = st.text_area(
|
241 |
+
f"问题 {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100
|
242 |
)
|
243 |
with col2:
|
244 |
answer = st.text_area(
|
245 |
+
f"答案 {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100
|
246 |
)
|
247 |
with col3:
|
248 |
if st.button("🗑️", key=f"delete_{i}"):
|
|
|
251 |
|
252 |
new_qas.append({"question": question, "answer": answer})
|
253 |
|
254 |
+
# 保存模板按钮
|
255 |
+
if st.button("保存模板"):
|
256 |
new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
|
257 |
|
258 |
save_yaml_template(template_path, new_template_data, is_new_template)
|
259 |
|
260 |
st.session_state.qas = new_qas
|
261 |
+
st.success(f"模板已保存到 {template_path}")
|
262 |
|
263 |
+
st.subheader("当前模板预览")
|
264 |
preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
|
265 |
st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
|
266 |
|
267 |
+
st.subheader("优化日志")
|
268 |
log_container = st.empty()
|
269 |
|
270 |
class StreamlitSink:
|
|
|
288 |
)
|
289 |
_logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
|
290 |
|
291 |
+
# 开始优化按钮
|
292 |
+
if st.button("开始优化"):
|
293 |
try:
|
294 |
# Initialize LLM
|
295 |
SPO_LLM.initialize(
|
|
|
314 |
with st.spinner("Optimizing prompts..."):
|
315 |
optimizer.optimize()
|
316 |
|
317 |
+
st.success("优化完成!")
|
318 |
+
st.header("优化结果")
|
|
|
|
|
319 |
prompt_path = optimizer.root_path / "prompts"
|
320 |
result_data = optimizer.data_utils.load_results(prompt_path)
|
321 |
|
322 |
st.session_state.optimization_results = result_data
|
323 |
|
324 |
except Exception as e:
|
325 |
+
st.error(f"发生错误:{str(e)}")
|
326 |
+
_logger.error(f"优化过程中出错:{str(e)}")
|
327 |
|
328 |
if st.session_state.optimization_results:
|
329 |
+
st.header("优化结果")
|
330 |
display_optimization_results(st.session_state.optimization_results)
|
331 |
|
332 |
st.markdown("---")
|
333 |
+
st.subheader("测试优化后的提示词")
|
334 |
col1, col2 = st.columns(2)
|
335 |
|
336 |
with col1:
|
337 |
+
test_prompt = st.text_area("优化后的提示词", value="", height=200, key="test_prompt")
|
338 |
|
339 |
with col2:
|
340 |
+
test_question = st.text_area("你的问题", value="", height=200, key="test_question")
|
341 |
|
342 |
+
if st.button("测试提示词"):
|
343 |
if test_prompt and test_question:
|
344 |
try:
|
345 |
+
with st.spinner("正在生成回答..."):
|
346 |
SPO_LLM.initialize(
|
347 |
optimize_kwargs={"model": opt_model, "temperature": opt_temp, "base_url": base_url,
|
348 |
"api_key": api_key},
|
|
|
365 |
finally:
|
366 |
loop.close()
|
367 |
|
368 |
+
st.subheader("回答:")
|
369 |
st.markdown(response)
|
370 |
|
371 |
except Exception as e:
|
372 |
+
st.error(f"生成回答时出错:{str(e)}")
|
373 |
else:
|
374 |
+
st.warning("请输入提示词和问题。")
|
375 |
|
376 |
|
377 |
if __name__ == "__main__":
|