baxin's picture
move icon as display_icon
80218be
raw
history blame
11.1 kB
import streamlit as st
from cerebras.cloud.sdk import Cerebras
import openai
import os
from dotenv import load_dotenv
import base64 # 画像デコード用に追加
from io import BytesIO # 画像ダウンロード用に追加
from together import Together # Together AI SDKを追加
# config
import config
import utils
# --- RECIPE_BASE_PROMPT のインポート ---
try:
from prompt import RECIPE_BASE_PROMPT
except ImportError:
st.error("Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.")
st.stop()
# --- 環境変数読み込み ---
load_dotenv()
# --- Streamlit ページ設定 ---
st.set_page_config(page_icon="🤖", layout="wide", page_title="Recipe Infographic Prompt Generator")
# --- UI 表示 ---
utils.display_icon("🧠 x 🧑‍🍳")
st.title("Recipe Infographic Prompt Generator")
st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics", divider="orange", anchor=False)
# --- APIキーの処理 ---
# Cerebras API Key
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
show_api_key_input = not bool(api_key_from_env)
cerebras_api_key = None
# Together AI API Key Check
together_api_key = os.getenv("TOGETHER_API_KEY")
# --- サイドバーの設定 ---
with st.sidebar:
st.title("Settings")
# Cerebras Key Input
if show_api_key_input:
st.markdown("### :red[Enter your Cerebras API Key below]")
api_key_input = st.text_input("Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
if api_key_input:
cerebras_api_key = api_key_input
else:
cerebras_api_key = api_key_from_env
st.success("✓ Cerebras API Key loaded from environment")
# Together Key Status
if not together_api_key:
st.warning("TOGETHER_API_KEY environment variable not set. Image generation will not work.", icon="⚠️")
else:
st.success("✓ Together API Key loaded from environment") # キー自体は表示しない
# Model selection
model_option = st.selectbox(
"Choose a LLM model:", # ラベルを明確化
options=list(config.MODELS.keys()),
format_func=lambda x: config.MODELS[x]["name"],
key="model_select"
)
# Max tokens slider
max_tokens_range = config.MODELS[model_option]["tokens"]
default_tokens = min(2048, max_tokens_range)
max_tokens = st.slider(
"Max Tokens (LLM):", # ラベルを明確化
min_value=512,
max_value=max_tokens_range,
value=default_tokens,
step=512,
help="Select the maximum number of tokens for the language model's response."
)
use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False) # ラベルを明確化
# --- メインアプリケーションロジック ---
# APIキー(Cerebras)が最終的に利用可能かチェック
if not cerebras_api_key:
# (以前のエラー表示ロジックと同じ)
st.markdown("...") # 省略: APIキーがない場合の説明
st.stop()
# APIクライアント初期化 (Cerebras & Together)
try:
# Cerebras Client
if use_optillm:
llm_client = openai.OpenAI(base_url=config.BASE_URL, api_key=cerebras_api_key)
else:
llm_client = Cerebras(api_key=cerebras_api_key)
# Together Client (APIキーがあれば初期化)
image_client = None
if together_api_key:
image_client = Together(api_key=together_api_key) # 明示的にキーを渡すことも可能
except Exception as e:
st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
st.stop()
# --- チャット履歴管理 ---
if "messages" not in st.session_state:
st.session_state.messages = []
if "generated_images" not in st.session_state:
st.session_state.generated_images = {} # 画像データをメッセージIDごとに保存 {msg_idx: image_bytes}
if "selected_model" not in st.session_state:
st.session_state.selected_model = None
# モデルが変更されたら履歴をクリア (画像履歴もクリアするかは要検討)
if st.session_state.selected_model != model_option:
st.session_state.messages = []
st.session_state.generated_images = {} # 画像履歴もクリア
st.session_state.selected_model = model_option
# --- チャットメッセージの表示ループ ---
# このループでは過去のメッセージを表示し、それぞれに画像生成ボタンをつける
for idx, message in enumerate(st.session_state.messages):
avatar = '🤖' if message["role"] == "assistant" else '🦔'
with st.chat_message(message["role"], avatar=avatar):
st.markdown(message["content"])
# アシスタントのメッセージで、かつ有効な形式の可能性があり、画像クライアントが利用可能な場合
if message["role"] == "assistant" and image_client:
# 簡単なチェック: 拒否メッセージではないことを確認
lower_content = message["content"].lower()
is_likely_prompt = "please provide a valid food dish name" not in lower_content
if is_likely_prompt:
button_key = f"gen_img_{idx}"
if st.button("Generate Image ✨", key=button_key):
# 画像生成関数を呼び出し、結果をセッション状態に保存
image_bytes = utils.generate_image_from_prompt(image_client, message["content"])
if image_bytes:
st.session_state.generated_images[idx] = image_bytes
# ボタンが押されたら再実行されるので、画像表示は下のブロックで行う
# 対応する画像データがセッション状態にあれば表示・ダウンロードボタンを表示
if idx in st.session_state.generated_images:
img_bytes = st.session_state.generated_images[idx]
st.image(img_bytes, caption=f"Generated Image for Prompt #{idx+1}")
st.download_button(
label="Download Image 💾",
data=img_bytes,
file_name=f"recipe_infographic_{idx+1}.png",
mime="image/png",
key=f"dl_img_{idx}"
)
# --- チャット入力と新しいメッセージの処理 ---
if prompt := st.chat_input("Enter food name/food recipe here..."):
# 入力検証
if utils.contains_injection_keywords(prompt):
st.error("Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨")
elif len(prompt) > 4000:
st.error("Input is too long. Please provide a shorter recipe or dish name.", icon="🚨")
else:
# --- 検証をパスした場合の処理 ---
st.session_state.messages.append({"role": "user", "content": prompt})
# ユーザーメッセージを表示
with st.chat_message("user", avatar='🦔'):
st.markdown(prompt)
# アシスタントの応答を生成・表示
try:
with st.chat_message("assistant", avatar="🤖"):
response_placeholder = st.empty()
full_response = ""
messages_for_api=[
{"role": "system", "content": RECIPE_BASE_PROMPT},
{"role": "user", "content": prompt}
]
stream_kwargs = {
"model": model_option, "messages": messages_for_api,
"max_tokens": max_tokens, "stream": True,
}
# LLM Client を使用
response_stream = llm_client.chat.completions.create(**stream_kwargs)
for chunk in response_stream:
chunk_content = ""
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'):
chunk_content = chunk.choices[0].delta.content or ""
if chunk_content:
full_response += chunk_content
response_placeholder.markdown(full_response + "▌")
# 最終応答表示
response_placeholder.markdown(full_response)
# --- ここで新しいアシスタントメッセージに対する処理 ---
# 応答を履歴に追加 *してから* インデックスを取得
st.session_state.messages.append({"role": "assistant", "content": full_response})
new_message_idx = len(st.session_state.messages) - 1 # 新しいメッセージのインデックス
# 出力検証
expected_keywords = ["infographic", "step-by-step", "ingredient", "layout", "minimal style"]
lower_response = full_response.lower()
is_valid_format_check = any(keyword in lower_response for keyword in expected_keywords)
is_refusal_check = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response
if not is_valid_format_check and not is_refusal_check:
st.warning("The generated response might not contain expected keywords...", icon="⚠️")
elif is_refusal_check:
st.info("Input was determined to be invalid or unrelated...")
# 画像生成ボタンと表示エリア (新しいメッセージに対して)
# 条件: 画像クライアントがあり、拒否応答でない場合
if image_client and not is_refusal_check:
button_key = f"gen_img_{new_message_idx}"
if st.button("Generate Image ✨", key=button_key):
image_bytes = utils.generate_image_from_prompt(image_client, full_response)
if image_bytes:
st.session_state.generated_images[new_message_idx] = image_bytes
# 再実行ループで画像表示
# 対応する画像データがあれば表示
if new_message_idx in st.session_state.generated_images:
img_bytes = st.session_state.generated_images[new_message_idx]
st.image(img_bytes, caption=f"Generated Image for Prompt #{new_message_idx+1}")
st.download_button(
label="Download Image 💾",
data=img_bytes,
file_name=f"recipe_infographic_{new_message_idx+1}.png",
mime="image/png",
key=f"dl_img_{new_message_idx}"
)
except Exception as e:
st.error(f"Error generating response: {str(e)}", icon="🚨")