Spaces:
Running
Running
File size: 11,080 Bytes
9832632 7d8708c f8d71fd 6d7ccb1 f8d71fd 78c577d c391eb7 78c577d f8d71fd 87671ab f8d71fd 6d7ccb1 87671ab f8d71fd 87671ab 9832632 f8d71fd 9523069 7d8708c f8d71fd 80218be afb7956 7d8708c 87671ab 6d7ccb1 87671ab 6d7ccb1 7d8708c 87671ab 7d8708c 87671ab 6d7ccb1 87671ab 6d7ccb1 87671ab 6d7ccb1 87671ab 6d7ccb1 7d8708c 6d7ccb1 7d8708c 6d7ccb1 78c577d 7d8708c 6d7ccb1 78c577d 87671ab 7d8708c 6d7ccb1 7d8708c f8d71fd 7d8708c 6d7ccb1 7d8708c 6d7ccb1 7d8708c 87671ab 6d7ccb1 7d8708c 6d7ccb1 87671ab 6d7ccb1 87671ab 78c577d 87671ab 6d7ccb1 87671ab 6d7ccb1 87671ab 7d8708c 6d7ccb1 7d8708c 6d7ccb1 7d8708c 6d7ccb1 7d8708c 6d7ccb1 7d8708c 6d7ccb1 c391eb7 6d7ccb1 afb7956 6d7ccb1 c391eb7 9ebf937 6d7ccb1 9ebf937 f8d71fd 6d7ccb1 9ebf937 7d8708c 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 f8d71fd 6d7ccb1 c391eb7 6d7ccb1 f8d71fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
import streamlit as st
from cerebras.cloud.sdk import Cerebras
import openai
import os
from dotenv import load_dotenv
import base64 # 画像デコード用に追加
from io import BytesIO # 画像ダウンロード用に追加
from together import Together # Together AI SDKを追加
# config
import config
import utils
# --- RECIPE_BASE_PROMPT のインポート ---
try:
from prompt import RECIPE_BASE_PROMPT
except ImportError:
st.error("Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.")
st.stop()
# --- 環境変数読み込み ---
load_dotenv()
# --- Streamlit ページ設定 ---
st.set_page_config(page_icon="🤖", layout="wide", page_title="Recipe Infographic Prompt Generator")
# --- UI 表示 ---
utils.display_icon("🧠 x 🧑🍳")
st.title("Recipe Infographic Prompt Generator")
st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics", divider="orange", anchor=False)
# --- APIキーの処理 ---
# Cerebras API Key
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
show_api_key_input = not bool(api_key_from_env)
cerebras_api_key = None
# Together AI API Key Check
together_api_key = os.getenv("TOGETHER_API_KEY")
# --- サイドバーの設定 ---
with st.sidebar:
st.title("Settings")
# Cerebras Key Input
if show_api_key_input:
st.markdown("### :red[Enter your Cerebras API Key below]")
api_key_input = st.text_input("Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
if api_key_input:
cerebras_api_key = api_key_input
else:
cerebras_api_key = api_key_from_env
st.success("✓ Cerebras API Key loaded from environment")
# Together Key Status
if not together_api_key:
st.warning("TOGETHER_API_KEY environment variable not set. Image generation will not work.", icon="⚠️")
else:
st.success("✓ Together API Key loaded from environment") # キー自体は表示しない
# Model selection
model_option = st.selectbox(
"Choose a LLM model:", # ラベルを明確化
options=list(config.MODELS.keys()),
format_func=lambda x: config.MODELS[x]["name"],
key="model_select"
)
# Max tokens slider
max_tokens_range = config.MODELS[model_option]["tokens"]
default_tokens = min(2048, max_tokens_range)
max_tokens = st.slider(
"Max Tokens (LLM):", # ラベルを明確化
min_value=512,
max_value=max_tokens_range,
value=default_tokens,
step=512,
help="Select the maximum number of tokens for the language model's response."
)
use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False) # ラベルを明確化
# --- メインアプリケーションロジック ---
# APIキー(Cerebras)が最終的に利用可能かチェック
if not cerebras_api_key:
# (以前のエラー表示ロジックと同じ)
st.markdown("...") # 省略: APIキーがない場合の説明
st.stop()
# APIクライアント初期化 (Cerebras & Together)
try:
# Cerebras Client
if use_optillm:
llm_client = openai.OpenAI(base_url=config.BASE_URL, api_key=cerebras_api_key)
else:
llm_client = Cerebras(api_key=cerebras_api_key)
# Together Client (APIキーがあれば初期化)
image_client = None
if together_api_key:
image_client = Together(api_key=together_api_key) # 明示的にキーを渡すことも可能
except Exception as e:
st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
st.stop()
# --- チャット履歴管理 ---
if "messages" not in st.session_state:
st.session_state.messages = []
if "generated_images" not in st.session_state:
st.session_state.generated_images = {} # 画像データをメッセージIDごとに保存 {msg_idx: image_bytes}
if "selected_model" not in st.session_state:
st.session_state.selected_model = None
# モデルが変更されたら履歴をクリア (画像履歴もクリアするかは要検討)
if st.session_state.selected_model != model_option:
st.session_state.messages = []
st.session_state.generated_images = {} # 画像履歴もクリア
st.session_state.selected_model = model_option
# --- チャットメッセージの表示ループ ---
# このループでは過去のメッセージを表示し、それぞれに画像生成ボタンをつける
for idx, message in enumerate(st.session_state.messages):
avatar = '🤖' if message["role"] == "assistant" else '🦔'
with st.chat_message(message["role"], avatar=avatar):
st.markdown(message["content"])
# アシスタントのメッセージで、かつ有効な形式の可能性があり、画像クライアントが利用可能な場合
if message["role"] == "assistant" and image_client:
# 簡単なチェック: 拒否メッセージではないことを確認
lower_content = message["content"].lower()
is_likely_prompt = "please provide a valid food dish name" not in lower_content
if is_likely_prompt:
button_key = f"gen_img_{idx}"
if st.button("Generate Image ✨", key=button_key):
# 画像生成関数を呼び出し、結果をセッション状態に保存
image_bytes = utils.generate_image_from_prompt(image_client, message["content"])
if image_bytes:
st.session_state.generated_images[idx] = image_bytes
# ボタンが押されたら再実行されるので、画像表示は下のブロックで行う
# 対応する画像データがセッション状態にあれば表示・ダウンロードボタンを表示
if idx in st.session_state.generated_images:
img_bytes = st.session_state.generated_images[idx]
st.image(img_bytes, caption=f"Generated Image for Prompt #{idx+1}")
st.download_button(
label="Download Image 💾",
data=img_bytes,
file_name=f"recipe_infographic_{idx+1}.png",
mime="image/png",
key=f"dl_img_{idx}"
)
# --- チャット入力と新しいメッセージの処理 ---
if prompt := st.chat_input("Enter food name/food recipe here..."):
# 入力検証
if utils.contains_injection_keywords(prompt):
st.error("Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨")
elif len(prompt) > 4000:
st.error("Input is too long. Please provide a shorter recipe or dish name.", icon="🚨")
else:
# --- 検証をパスした場合の処理 ---
st.session_state.messages.append({"role": "user", "content": prompt})
# ユーザーメッセージを表示
with st.chat_message("user", avatar='🦔'):
st.markdown(prompt)
# アシスタントの応答を生成・表示
try:
with st.chat_message("assistant", avatar="🤖"):
response_placeholder = st.empty()
full_response = ""
messages_for_api=[
{"role": "system", "content": RECIPE_BASE_PROMPT},
{"role": "user", "content": prompt}
]
stream_kwargs = {
"model": model_option, "messages": messages_for_api,
"max_tokens": max_tokens, "stream": True,
}
# LLM Client を使用
response_stream = llm_client.chat.completions.create(**stream_kwargs)
for chunk in response_stream:
chunk_content = ""
if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'):
chunk_content = chunk.choices[0].delta.content or ""
if chunk_content:
full_response += chunk_content
response_placeholder.markdown(full_response + "▌")
# 最終応答表示
response_placeholder.markdown(full_response)
# --- ここで新しいアシスタントメッセージに対する処理 ---
# 応答を履歴に追加 *してから* インデックスを取得
st.session_state.messages.append({"role": "assistant", "content": full_response})
new_message_idx = len(st.session_state.messages) - 1 # 新しいメッセージのインデックス
# 出力検証
expected_keywords = ["infographic", "step-by-step", "ingredient", "layout", "minimal style"]
lower_response = full_response.lower()
is_valid_format_check = any(keyword in lower_response for keyword in expected_keywords)
is_refusal_check = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response
if not is_valid_format_check and not is_refusal_check:
st.warning("The generated response might not contain expected keywords...", icon="⚠️")
elif is_refusal_check:
st.info("Input was determined to be invalid or unrelated...")
# 画像生成ボタンと表示エリア (新しいメッセージに対して)
# 条件: 画像クライアントがあり、拒否応答でない場合
if image_client and not is_refusal_check:
button_key = f"gen_img_{new_message_idx}"
if st.button("Generate Image ✨", key=button_key):
image_bytes = utils.generate_image_from_prompt(image_client, full_response)
if image_bytes:
st.session_state.generated_images[new_message_idx] = image_bytes
# 再実行ループで画像表示
# 対応する画像データがあれば表示
if new_message_idx in st.session_state.generated_images:
img_bytes = st.session_state.generated_images[new_message_idx]
st.image(img_bytes, caption=f"Generated Image for Prompt #{new_message_idx+1}")
st.download_button(
label="Download Image 💾",
data=img_bytes,
file_name=f"recipe_infographic_{new_message_idx+1}.png",
mime="image/png",
key=f"dl_img_{new_message_idx}"
)
except Exception as e:
st.error(f"Error generating response: {str(e)}", icon="🚨")
|