File size: 11,080 Bytes
9832632
7d8708c
 
f8d71fd
 
6d7ccb1
 
 
f8d71fd
78c577d
 
c391eb7
78c577d
f8d71fd
87671ab
 
 
f8d71fd
6d7ccb1
87671ab
 
f8d71fd
87671ab
9832632
f8d71fd
9523069
7d8708c
f8d71fd
80218be
afb7956
 
7d8708c
87671ab
6d7ccb1
87671ab
 
6d7ccb1
 
 
 
7d8708c
87671ab
7d8708c
 
87671ab
6d7ccb1
87671ab
 
6d7ccb1
87671ab
6d7ccb1
87671ab
6d7ccb1
 
7d8708c
6d7ccb1
 
 
 
 
 
 
7d8708c
6d7ccb1
78c577d
 
7d8708c
 
 
6d7ccb1
78c577d
87671ab
7d8708c
6d7ccb1
7d8708c
 
f8d71fd
7d8708c
6d7ccb1
7d8708c
 
6d7ccb1
7d8708c
87671ab
 
6d7ccb1
 
 
 
 
7d8708c
6d7ccb1
87671ab
6d7ccb1
87671ab
78c577d
87671ab
6d7ccb1
 
 
 
 
 
 
87671ab
6d7ccb1
 
87671ab
 
7d8708c
 
6d7ccb1
 
7d8708c
 
 
 
6d7ccb1
7d8708c
 
6d7ccb1
7d8708c
 
6d7ccb1
 
 
 
7d8708c
 
 
6d7ccb1
 
 
 
 
 
 
 
 
 
c391eb7
6d7ccb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afb7956
6d7ccb1
c391eb7
9ebf937
6d7ccb1
9ebf937
f8d71fd
6d7ccb1
9ebf937
7d8708c
6d7ccb1
 
f8d71fd
 
6d7ccb1
f8d71fd
6d7ccb1
f8d71fd
 
 
 
 
6d7ccb1
f8d71fd
 
6d7ccb1
 
f8d71fd
6d7ccb1
 
f8d71fd
 
 
 
 
 
 
6d7ccb1
f8d71fd
6d7ccb1
f8d71fd
 
6d7ccb1
 
 
 
 
 
f8d71fd
 
6d7ccb1
 
 
 
 
 
 
 
 
 
 
 
 
c391eb7
6d7ccb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8d71fd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import streamlit as st
from cerebras.cloud.sdk import Cerebras
import openai
import os
from dotenv import load_dotenv
import base64 # 画像デコード用に追加
from io import BytesIO # 画像ダウンロード用に追加
from together import Together # Together AI SDKを追加

# config
import config
import utils 

# --- RECIPE_BASE_PROMPT のインポート ---
try:
    from prompt import RECIPE_BASE_PROMPT
except ImportError:
    st.error("Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.")
    st.stop()


# --- 環境変数読み込み ---
load_dotenv()

# --- Streamlit ページ設定 ---
st.set_page_config(page_icon="🤖", layout="wide", page_title="Recipe Infographic Prompt Generator")

# --- UI 表示 ---
utils.display_icon("🧠 x 🧑‍🍳")
st.title("Recipe Infographic Prompt Generator")
st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics", divider="orange", anchor=False)

# --- APIキーの処理 ---
# Cerebras API Key
api_key_from_env = os.getenv("CEREBRAS_API_KEY")
show_api_key_input = not bool(api_key_from_env)
cerebras_api_key = None

# Together AI API Key Check
together_api_key = os.getenv("TOGETHER_API_KEY")

# --- サイドバーの設定 ---
with st.sidebar:
    st.title("Settings")

    # Cerebras Key Input
    if show_api_key_input:
        st.markdown("### :red[Enter your Cerebras API Key below]")
        api_key_input = st.text_input("Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
        if api_key_input:
            cerebras_api_key = api_key_input
    else:
        cerebras_api_key = api_key_from_env
        st.success("✓ Cerebras API Key loaded from environment")

    # Together Key Status
    if not together_api_key:
         st.warning("TOGETHER_API_KEY environment variable not set. Image generation will not work.", icon="⚠️")
    else:
         st.success("✓ Together API Key loaded from environment") # キー自体は表示しない

    # Model selection
    model_option = st.selectbox(
        "Choose a LLM model:", # ラベルを明確化
        options=list(config.MODELS.keys()),
        format_func=lambda x: config.MODELS[x]["name"],
        key="model_select"
    )

    # Max tokens slider
    max_tokens_range = config.MODELS[model_option]["tokens"]
    default_tokens = min(2048, max_tokens_range)
    max_tokens = st.slider(
        "Max Tokens (LLM):", # ラベルを明確化
        min_value=512,
        max_value=max_tokens_range,
        value=default_tokens,
        step=512,
        help="Select the maximum number of tokens for the language model's response."
    )

    use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False) # ラベルを明確化

# --- メインアプリケーションロジック ---

# APIキー(Cerebras)が最終的に利用可能かチェック
if not cerebras_api_key:
    # (以前のエラー表示ロジックと同じ)
    st.markdown("...") # 省略: APIキーがない場合の説明
    st.stop()

# APIクライアント初期化 (Cerebras & Together)
try:
    # Cerebras Client
    if use_optillm:
        llm_client = openai.OpenAI(base_url=config.BASE_URL, api_key=cerebras_api_key)
    else:
        llm_client = Cerebras(api_key=cerebras_api_key)

    # Together Client (APIキーがあれば初期化)
    image_client = None
    if together_api_key:
        image_client = Together(api_key=together_api_key) # 明示的にキーを渡すことも可能

except Exception as e:
    st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
    st.stop()

# --- チャット履歴管理 ---
if "messages" not in st.session_state:
    st.session_state.messages = []
if "generated_images" not in st.session_state:
     st.session_state.generated_images = {} # 画像データをメッセージIDごとに保存 {msg_idx: image_bytes}

if "selected_model" not in st.session_state:
    st.session_state.selected_model = None

# モデルが変更されたら履歴をクリア (画像履歴もクリアするかは要検討)
if st.session_state.selected_model != model_option:
    st.session_state.messages = []
    st.session_state.generated_images = {} # 画像履歴もクリア
    st.session_state.selected_model = model_option

# --- チャットメッセージの表示ループ ---
# このループでは過去のメッセージを表示し、それぞれに画像生成ボタンをつける
for idx, message in enumerate(st.session_state.messages):
    avatar = '🤖' if message["role"] == "assistant" else '🦔'
    with st.chat_message(message["role"], avatar=avatar):
        st.markdown(message["content"])

        # アシスタントのメッセージで、かつ有効な形式の可能性があり、画像クライアントが利用可能な場合
        if message["role"] == "assistant" and image_client:
             # 簡単なチェック: 拒否メッセージではないことを確認
             lower_content = message["content"].lower()
             is_likely_prompt = "please provide a valid food dish name" not in lower_content

             if is_likely_prompt:
                 button_key = f"gen_img_{idx}"
                 if st.button("Generate Image ✨", key=button_key):
                     # 画像生成関数を呼び出し、結果をセッション状態に保存
                     image_bytes = utils.generate_image_from_prompt(image_client, message["content"])
                     if image_bytes:
                         st.session_state.generated_images[idx] = image_bytes
                     # ボタンが押されたら再実行されるので、画像表示は下のブロックで行う

                 # 対応する画像データがセッション状態にあれば表示・ダウンロードボタンを表示
                 if idx in st.session_state.generated_images:
                     img_bytes = st.session_state.generated_images[idx]
                     st.image(img_bytes, caption=f"Generated Image for Prompt #{idx+1}")
                     st.download_button(
                         label="Download Image 💾",
                         data=img_bytes,
                         file_name=f"recipe_infographic_{idx+1}.png",
                         mime="image/png",
                         key=f"dl_img_{idx}"
                     )

# --- チャット入力と新しいメッセージの処理 ---
if prompt := st.chat_input("Enter food name/food recipe here..."):
    # 入力検証
    if utils.contains_injection_keywords(prompt):
        st.error("Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨")
    elif len(prompt) > 4000:
        st.error("Input is too long. Please provide a shorter recipe or dish name.", icon="🚨")
    else:
        # --- 検証をパスした場合の処理 ---
        st.session_state.messages.append({"role": "user", "content": prompt})

        # ユーザーメッセージを表示
        with st.chat_message("user", avatar='🦔'):
            st.markdown(prompt)

        # アシスタントの応答を生成・表示
        try:
            with st.chat_message("assistant", avatar="🤖"):
                response_placeholder = st.empty()
                full_response = ""

                messages_for_api=[
                    {"role": "system", "content": RECIPE_BASE_PROMPT},
                    {"role": "user", "content": prompt}
                ]
                stream_kwargs = {
                   "model": model_option, "messages": messages_for_api,
                   "max_tokens": max_tokens, "stream": True,
                }
                # LLM Client を使用
                response_stream = llm_client.chat.completions.create(**stream_kwargs)

                for chunk in response_stream:
                    chunk_content = ""
                    if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'):
                        chunk_content = chunk.choices[0].delta.content or ""
                    if chunk_content:
                        full_response += chunk_content
                        response_placeholder.markdown(full_response + "▌")

                # 最終応答表示
                response_placeholder.markdown(full_response)

                # --- ここで新しいアシスタントメッセージに対する処理 ---
                # 応答を履歴に追加 *してから* インデックスを取得
                st.session_state.messages.append({"role": "assistant", "content": full_response})
                new_message_idx = len(st.session_state.messages) - 1 # 新しいメッセージのインデックス

                # 出力検証
                expected_keywords = ["infographic", "step-by-step", "ingredient", "layout", "minimal style"]
                lower_response = full_response.lower()
                is_valid_format_check = any(keyword in lower_response for keyword in expected_keywords)
                is_refusal_check = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response

                if not is_valid_format_check and not is_refusal_check:
                    st.warning("The generated response might not contain expected keywords...", icon="⚠️")
                elif is_refusal_check:
                     st.info("Input was determined to be invalid or unrelated...")

                # 画像生成ボタンと表示エリア (新しいメッセージに対して)
                # 条件: 画像クライアントがあり、拒否応答でない場合
                if image_client and not is_refusal_check:
                    button_key = f"gen_img_{new_message_idx}"
                    if st.button("Generate Image ✨", key=button_key):
                        image_bytes = utils.generate_image_from_prompt(image_client, full_response)
                        if image_bytes:
                            st.session_state.generated_images[new_message_idx] = image_bytes
                        # 再実行ループで画像表示

                    # 対応する画像データがあれば表示
                    if new_message_idx in st.session_state.generated_images:
                         img_bytes = st.session_state.generated_images[new_message_idx]
                         st.image(img_bytes, caption=f"Generated Image for Prompt #{new_message_idx+1}")
                         st.download_button(
                             label="Download Image 💾",
                             data=img_bytes,
                             file_name=f"recipe_infographic_{new_message_idx+1}.png",
                             mime="image/png",
                             key=f"dl_img_{new_message_idx}"
                         )

        except Exception as e:
            st.error(f"Error generating response: {str(e)}", icon="🚨")