Spaces:
Running
Running
import streamlit as st | |
from cerebras.cloud.sdk import Cerebras | |
import openai | |
import os | |
from dotenv import load_dotenv | |
import base64 # 画像デコード用に追加 | |
from io import BytesIO # 画像ダウンロード用に追加 | |
from together import Together # Together AI SDKを追加 | |
# config | |
import config | |
import utils | |
# --- RECIPE_BASE_PROMPT のインポート --- | |
try: | |
from prompt import RECIPE_BASE_PROMPT | |
except ImportError: | |
st.error("Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.") | |
st.stop() | |
# --- 環境変数読み込み --- | |
load_dotenv() | |
# --- Streamlit ページ設定 --- | |
st.set_page_config(page_icon="🤖", layout="wide", page_title="Recipe Infographic Prompt Generator") | |
# --- UI 表示 --- | |
utils.display_icon("🧠 x 🧑🍳") | |
st.title("Recipe Infographic Prompt Generator") | |
st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics", divider="orange", anchor=False) | |
# --- APIキーの処理 --- | |
# Cerebras API Key | |
api_key_from_env = os.getenv("CEREBRAS_API_KEY") | |
show_api_key_input = not bool(api_key_from_env) | |
cerebras_api_key = None | |
# Together AI API Key Check | |
together_api_key = os.getenv("TOGETHER_API_KEY") | |
# --- サイドバーの設定 --- | |
with st.sidebar: | |
st.title("Settings") | |
# Cerebras Key Input | |
if show_api_key_input: | |
st.markdown("### :red[Enter your Cerebras API Key below]") | |
api_key_input = st.text_input("Cerebras API Key:", type="password", key="cerebras_api_key_input_field") | |
if api_key_input: | |
cerebras_api_key = api_key_input | |
else: | |
cerebras_api_key = api_key_from_env | |
st.success("✓ Cerebras API Key loaded from environment") | |
# Together Key Status | |
if not together_api_key: | |
st.warning("TOGETHER_API_KEY environment variable not set. Image generation will not work.", icon="⚠️") | |
else: | |
st.success("✓ Together API Key loaded from environment") # キー自体は表示しない | |
# Model selection | |
model_option = st.selectbox( | |
"Choose a LLM model:", # ラベルを明確化 | |
options=list(config.MODELS.keys()), | |
format_func=lambda x: config.MODELS[x]["name"], | |
key="model_select" | |
) | |
# Max tokens slider | |
max_tokens_range = config.MODELS[model_option]["tokens"] | |
default_tokens = min(2048, max_tokens_range) | |
max_tokens = st.slider( | |
"Max Tokens (LLM):", # ラベルを明確化 | |
min_value=512, | |
max_value=max_tokens_range, | |
value=default_tokens, | |
step=512, | |
help="Select the maximum number of tokens for the language model's response." | |
) | |
use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False) # ラベルを明確化 | |
# --- メインアプリケーションロジック --- | |
# APIキー(Cerebras)が最終的に利用可能かチェック | |
if not cerebras_api_key: | |
# (以前のエラー表示ロジックと同じ) | |
st.markdown("...") # 省略: APIキーがない場合の説明 | |
st.stop() | |
# APIクライアント初期化 (Cerebras & Together) | |
try: | |
# Cerebras Client | |
if use_optillm: | |
llm_client = openai.OpenAI(base_url=config.BASE_URL, api_key=cerebras_api_key) | |
else: | |
llm_client = Cerebras(api_key=cerebras_api_key) | |
# Together Client (APIキーがあれば初期化) | |
image_client = None | |
if together_api_key: | |
image_client = Together(api_key=together_api_key) # 明示的にキーを渡すことも可能 | |
except Exception as e: | |
st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨") | |
st.stop() | |
# --- チャット履歴管理 --- | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "generated_images" not in st.session_state: | |
st.session_state.generated_images = {} # 画像データをメッセージIDごとに保存 {msg_idx: image_bytes} | |
if "selected_model" not in st.session_state: | |
st.session_state.selected_model = None | |
# モデルが変更されたら履歴をクリア (画像履歴もクリアするかは要検討) | |
if st.session_state.selected_model != model_option: | |
st.session_state.messages = [] | |
st.session_state.generated_images = {} # 画像履歴もクリア | |
st.session_state.selected_model = model_option | |
# --- チャットメッセージの表示ループ --- | |
# このループでは過去のメッセージを表示し、それぞれに画像生成ボタンをつける | |
for idx, message in enumerate(st.session_state.messages): | |
avatar = '🤖' if message["role"] == "assistant" else '🦔' | |
with st.chat_message(message["role"], avatar=avatar): | |
st.markdown(message["content"]) | |
# アシスタントのメッセージで、かつ有効な形式の可能性があり、画像クライアントが利用可能な場合 | |
if message["role"] == "assistant" and image_client: | |
# 簡単なチェック: 拒否メッセージではないことを確認 | |
lower_content = message["content"].lower() | |
is_likely_prompt = "please provide a valid food dish name" not in lower_content | |
if is_likely_prompt: | |
button_key = f"gen_img_{idx}" | |
if st.button("Generate Image ✨", key=button_key): | |
# 画像生成関数を呼び出し、結果をセッション状態に保存 | |
image_bytes = utils.generate_image_from_prompt(image_client, message["content"]) | |
if image_bytes: | |
st.session_state.generated_images[idx] = image_bytes | |
# ボタンが押されたら再実行されるので、画像表示は下のブロックで行う | |
# 対応する画像データがセッション状態にあれば表示・ダウンロードボタンを表示 | |
if idx in st.session_state.generated_images: | |
img_bytes = st.session_state.generated_images[idx] | |
st.image(img_bytes, caption=f"Generated Image for Prompt #{idx+1}") | |
st.download_button( | |
label="Download Image 💾", | |
data=img_bytes, | |
file_name=f"recipe_infographic_{idx+1}.png", | |
mime="image/png", | |
key=f"dl_img_{idx}" | |
) | |
# --- チャット入力と新しいメッセージの処理 --- | |
if prompt := st.chat_input("Enter food name/food recipe here..."): | |
# 入力検証 | |
if utils.contains_injection_keywords(prompt): | |
st.error("Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨") | |
elif len(prompt) > 4000: | |
st.error("Input is too long. Please provide a shorter recipe or dish name.", icon="🚨") | |
else: | |
# --- 検証をパスした場合の処理 --- | |
st.session_state.messages.append({"role": "user", "content": prompt}) | |
# ユーザーメッセージを表示 | |
with st.chat_message("user", avatar='🦔'): | |
st.markdown(prompt) | |
# アシスタントの応答を生成・表示 | |
try: | |
with st.chat_message("assistant", avatar="🤖"): | |
response_placeholder = st.empty() | |
full_response = "" | |
messages_for_api=[ | |
{"role": "system", "content": RECIPE_BASE_PROMPT}, | |
{"role": "user", "content": prompt} | |
] | |
stream_kwargs = { | |
"model": model_option, "messages": messages_for_api, | |
"max_tokens": max_tokens, "stream": True, | |
} | |
# LLM Client を使用 | |
response_stream = llm_client.chat.completions.create(**stream_kwargs) | |
for chunk in response_stream: | |
chunk_content = "" | |
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'): | |
chunk_content = chunk.choices[0].delta.content or "" | |
if chunk_content: | |
full_response += chunk_content | |
response_placeholder.markdown(full_response + "▌") | |
# 最終応答表示 | |
response_placeholder.markdown(full_response) | |
# --- ここで新しいアシスタントメッセージに対する処理 --- | |
# 応答を履歴に追加 *してから* インデックスを取得 | |
st.session_state.messages.append({"role": "assistant", "content": full_response}) | |
new_message_idx = len(st.session_state.messages) - 1 # 新しいメッセージのインデックス | |
# 出力検証 | |
expected_keywords = ["infographic", "step-by-step", "ingredient", "layout", "minimal style"] | |
lower_response = full_response.lower() | |
is_valid_format_check = any(keyword in lower_response for keyword in expected_keywords) | |
is_refusal_check = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response | |
if not is_valid_format_check and not is_refusal_check: | |
st.warning("The generated response might not contain expected keywords...", icon="⚠️") | |
elif is_refusal_check: | |
st.info("Input was determined to be invalid or unrelated...") | |
# 画像生成ボタンと表示エリア (新しいメッセージに対して) | |
# 条件: 画像クライアントがあり、拒否応答でない場合 | |
if image_client and not is_refusal_check: | |
button_key = f"gen_img_{new_message_idx}" | |
if st.button("Generate Image ✨", key=button_key): | |
image_bytes = utils.generate_image_from_prompt(image_client, full_response) | |
if image_bytes: | |
st.session_state.generated_images[new_message_idx] = image_bytes | |
# 再実行ループで画像表示 | |
# 対応する画像データがあれば表示 | |
if new_message_idx in st.session_state.generated_images: | |
img_bytes = st.session_state.generated_images[new_message_idx] | |
st.image(img_bytes, caption=f"Generated Image for Prompt #{new_message_idx+1}") | |
st.download_button( | |
label="Download Image 💾", | |
data=img_bytes, | |
file_name=f"recipe_infographic_{new_message_idx+1}.png", | |
mime="image/png", | |
key=f"dl_img_{new_message_idx}" | |
) | |
except Exception as e: | |
st.error(f"Error generating response: {str(e)}", icon="🚨") | |