baxin commited on
Commit
f3bc8db
·
1 Parent(s): 492dda1

change app title

Browse files
Files changed (1) hide show
  1. app.py +86 -66
app.py CHANGED
@@ -3,19 +3,20 @@ from cerebras.cloud.sdk import Cerebras
3
  import openai
4
  import os
5
  from dotenv import load_dotenv
6
- import base64 # 画像デコード用に追加
7
- from io import BytesIO # 画像ダウンロード用に追加
8
- from together import Together # Together AI SDKを追加
9
 
10
  # config
11
  import config
12
- import utils
13
 
14
  # --- RECIPE_BASE_PROMPT のインポート ---
15
  try:
16
  from prompt import RECIPE_BASE_PROMPT
17
  except ImportError:
18
- st.error("Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.")
 
19
  st.stop()
20
 
21
 
@@ -23,12 +24,14 @@ except ImportError:
23
  load_dotenv()
24
 
25
  # --- Streamlit ページ設定 ---
26
- st.set_page_config(page_icon="🤖", layout="wide", page_title="Recipe Infographic Prompt Generator")
 
27
 
28
  # --- UI 表示 ---
29
- utils.display_icon("🧠 x 🧑‍🍳")
30
- st.title("Recipe Infographic Prompt Generator")
31
- st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics", divider="orange", anchor=False)
 
32
 
33
  # --- APIキーの処理 ---
34
  # Cerebras API Key
@@ -46,7 +49,8 @@ with st.sidebar:
46
  # Cerebras Key Input
47
  if show_api_key_input:
48
  st.markdown("### :red[Enter your Cerebras API Key below]")
49
- api_key_input = st.text_input("Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
 
50
  if api_key_input:
51
  cerebras_api_key = api_key_input
52
  else:
@@ -55,13 +59,14 @@ with st.sidebar:
55
 
56
  # Together Key Status
57
  if not together_api_key:
58
- st.warning("TOGETHER_API_KEY environment variable not set. Image generation will not work.", icon="⚠️")
 
59
  else:
60
- st.success("✓ Together API Key loaded from environment") # キー自体は表示しない
61
 
62
  # Model selection
63
  model_option = st.selectbox(
64
- "Choose a LLM model:", # ラベルを明確化
65
  options=list(config.MODELS.keys()),
66
  format_func=lambda x: config.MODELS[x]["name"],
67
  key="model_select"
@@ -71,7 +76,7 @@ with st.sidebar:
71
  max_tokens_range = config.MODELS[model_option]["tokens"]
72
  default_tokens = min(2048, max_tokens_range)
73
  max_tokens = st.slider(
74
- "Max Tokens (LLM):", # ラベルを明確化
75
  min_value=512,
76
  max_value=max_tokens_range,
77
  value=default_tokens,
@@ -79,28 +84,30 @@ with st.sidebar:
79
  help="Select the maximum number of tokens for the language model's response."
80
  )
81
 
82
- use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False) # ラベルを明確化
 
83
 
84
  # --- メインアプリケーションロジック ---
85
 
86
  # APIキー(Cerebras)が最終的に利用可能かチェック
87
  if not cerebras_api_key:
88
  # (以前のエラー表示ロジックと同じ)
89
- st.markdown("...") # 省略: APIキーがない場合の説明
90
  st.stop()
91
 
92
  # APIクライアント初期化 (Cerebras & Together)
93
  try:
94
  # Cerebras Client
95
  if use_optillm:
96
- llm_client = openai.OpenAI(base_url=config.BASE_URL, api_key=cerebras_api_key)
 
97
  else:
98
  llm_client = Cerebras(api_key=cerebras_api_key)
99
 
100
  # Together Client (APIキーがあれば初期化)
101
  image_client = None
102
  if together_api_key:
103
- image_client = Together(api_key=together_api_key) # 明示的にキーを渡すことも可能
104
 
105
  except Exception as e:
106
  st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
@@ -110,7 +117,8 @@ except Exception as e:
110
  if "messages" not in st.session_state:
111
  st.session_state.messages = []
112
  if "generated_images" not in st.session_state:
113
- st.session_state.generated_images = {} # 画像データをメッセージIDごとに保存 {msg_idx: image_bytes}
 
114
 
115
  if "selected_model" not in st.session_state:
116
  st.session_state.selected_model = None
@@ -118,7 +126,7 @@ if "selected_model" not in st.session_state:
118
  # モデルが変更されたら履歴をクリア (画像履歴もクリアするかは要検討)
119
  if st.session_state.selected_model != model_option:
120
  st.session_state.messages = []
121
- st.session_state.generated_images = {} # 画像履歴もクリア
122
  st.session_state.selected_model = model_option
123
 
124
  # --- チャットメッセージの表示ループ ---
@@ -130,38 +138,42 @@ for idx, message in enumerate(st.session_state.messages):
130
 
131
  # アシスタントのメッセージで、かつ有効な形式の可能性があり、画像クライアントが利用可能な場合
132
  if message["role"] == "assistant" and image_client:
133
- # 簡単なチェック: 拒否メッセージではないことを確認
134
- lower_content = message["content"].lower()
135
- is_likely_prompt = "please provide a valid food dish name" not in lower_content
136
-
137
- if is_likely_prompt:
138
- button_key = f"gen_img_{idx}"
139
- if st.button("Generate Image ✨", key=button_key):
140
- # 画像生成関数を呼び出し、結果をセッション状態に保存
141
- image_bytes = utils.generate_image_from_prompt(image_client, message["content"])
142
- if image_bytes:
143
- st.session_state.generated_images[idx] = image_bytes
144
- # ボタンが押されたら再実行されるので、画像表示は下のブロックで行う
145
-
146
- # 対応する画像データがセッション状態にあれば表示・ダウンロードボタンを表示
147
- if idx in st.session_state.generated_images:
148
- img_bytes = st.session_state.generated_images[idx]
149
- st.image(img_bytes, caption=f"Generated Image for Prompt #{idx+1}")
150
- st.download_button(
151
- label="Download Image 💾",
152
- data=img_bytes,
153
- file_name=f"recipe_infographic_{idx+1}.png",
154
- mime="image/png",
155
- key=f"dl_img_{idx}"
156
- )
 
 
157
 
158
  # --- チャット入力と新しいメッセージの処理 ---
159
  if prompt := st.chat_input("Enter food name/food recipe here..."):
160
  # 入力検証
161
  if utils.contains_injection_keywords(prompt):
162
- st.error("Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨")
 
163
  elif len(prompt) > 4000:
164
- st.error("Input is too long. Please provide a shorter recipe or dish name.", icon="🚨")
 
165
  else:
166
  # --- 検証をパスした場合の処理 ---
167
  st.session_state.messages.append({"role": "user", "content": prompt})
@@ -176,16 +188,17 @@ if prompt := st.chat_input("Enter food name/food recipe here..."):
176
  response_placeholder = st.empty()
177
  full_response = ""
178
 
179
- messages_for_api=[
180
  {"role": "system", "content": RECIPE_BASE_PROMPT},
181
  {"role": "user", "content": prompt}
182
  ]
183
  stream_kwargs = {
184
- "model": model_option, "messages": messages_for_api,
185
- "max_tokens": max_tokens, "stream": True,
186
  }
187
  # LLM Client を使用
188
- response_stream = llm_client.chat.completions.create(**stream_kwargs)
 
189
 
190
  for chunk in response_stream:
191
  chunk_content = ""
@@ -200,41 +213,48 @@ if prompt := st.chat_input("Enter food name/food recipe here..."):
200
 
201
  # --- ここで新しいアシスタントメッセージに対する処理 ---
202
  # 応答を履歴に追加 *してから* インデックスを取得
203
- st.session_state.messages.append({"role": "assistant", "content": full_response})
204
- new_message_idx = len(st.session_state.messages) - 1 # 新しいメッセージのインデックス
 
 
205
 
206
  # 出力検証
207
- expected_keywords = ["infographic", "step-by-step", "ingredient", "layout", "minimal style"]
 
208
  lower_response = full_response.lower()
209
- is_valid_format_check = any(keyword in lower_response for keyword in expected_keywords)
 
210
  is_refusal_check = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response
211
 
212
  if not is_valid_format_check and not is_refusal_check:
213
- st.warning("The generated response might not contain expected keywords...", icon="⚠️")
 
214
  elif is_refusal_check:
215
- st.info("Input was determined to be invalid or unrelated...")
216
 
217
  # 画像生成ボタンと表示エリア (新しいメッセージに対して)
218
  # 条件: 画像クライアントがあり、拒否応答でない場合
219
  if image_client and not is_refusal_check:
220
  button_key = f"gen_img_{new_message_idx}"
221
  if st.button("Generate Image ✨", key=button_key):
222
- image_bytes = utils.generate_image_from_prompt(image_client, full_response)
 
223
  if image_bytes:
224
  st.session_state.generated_images[new_message_idx] = image_bytes
225
  # 再実行ループで画像表示
226
 
227
  # 対応する画像データがあれば表示
228
  if new_message_idx in st.session_state.generated_images:
229
- img_bytes = st.session_state.generated_images[new_message_idx]
230
- st.image(img_bytes, caption=f"Generated Image for Prompt #{new_message_idx+1}")
231
- st.download_button(
232
- label="Download Image 💾",
233
- data=img_bytes,
234
- file_name=f"recipe_infographic_{new_message_idx+1}.png",
235
- mime="image/png",
236
- key=f"dl_img_{new_message_idx}"
237
- )
 
238
 
239
  except Exception as e:
240
  st.error(f"Error generating response: {str(e)}", icon="🚨")
 
3
  import openai
4
  import os
5
  from dotenv import load_dotenv
6
+ import base64 # 画像デコード用に追加
7
+ from io import BytesIO # 画像ダウンロード用に追加
8
+ from together import Together # Together AI SDKを追加
9
 
10
  # config
11
  import config
12
+ import utils
13
 
14
  # --- RECIPE_BASE_PROMPT のインポート ---
15
  try:
16
  from prompt import RECIPE_BASE_PROMPT
17
  except ImportError:
18
+ st.error(
19
+ "Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.")
20
  st.stop()
21
 
22
 
 
24
  load_dotenv()
25
 
26
  # --- Streamlit ページ設定 ---
27
+ st.set_page_config(page_icon="🤖", layout="wide",
28
+ page_title="Recipe Infographic Prompt Generator")
29
 
30
  # --- UI 表示 ---
31
+ utils.display_icon("🤖 x 🧑‍🍳")
32
+ st.title("Recipe Infographic Generator")
33
+ st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics",
34
+ divider="orange", anchor=False)
35
 
36
  # --- APIキーの処理 ---
37
  # Cerebras API Key
 
49
  # Cerebras Key Input
50
  if show_api_key_input:
51
  st.markdown("### :red[Enter your Cerebras API Key below]")
52
+ api_key_input = st.text_input(
53
+ "Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
54
  if api_key_input:
55
  cerebras_api_key = api_key_input
56
  else:
 
59
 
60
  # Together Key Status
61
  if not together_api_key:
62
+ st.warning(
63
+ "TOGETHER_API_KEY environment variable not set. Image generation will not work.", icon="⚠️")
64
  else:
65
+ st.success("✓ Together API Key loaded from environment") # キー自体は表示しない
66
 
67
  # Model selection
68
  model_option = st.selectbox(
69
+ "Choose a LLM model:", # ラベルを明確化
70
  options=list(config.MODELS.keys()),
71
  format_func=lambda x: config.MODELS[x]["name"],
72
  key="model_select"
 
76
  max_tokens_range = config.MODELS[model_option]["tokens"]
77
  default_tokens = min(2048, max_tokens_range)
78
  max_tokens = st.slider(
79
+ "Max Tokens (LLM):", # ラベルを明確化
80
  min_value=512,
81
  max_value=max_tokens_range,
82
  value=default_tokens,
 
84
  help="Select the maximum number of tokens for the language model's response."
85
  )
86
 
87
+ use_optillm = st.toggle(
88
+ "Use Optillm (for Cerebras)", value=False) # ラベルを明確化
89
 
90
  # --- メインアプリケーションロジック ---
91
 
92
  # APIキー(Cerebras)が最終的に利用可能かチェック
93
  if not cerebras_api_key:
94
  # (以前のエラー表示ロジックと同じ)
95
+ st.markdown("...") # 省略: APIキーがない場合の説明
96
  st.stop()
97
 
98
  # APIクライアント初期化 (Cerebras & Together)
99
  try:
100
  # Cerebras Client
101
  if use_optillm:
102
+ llm_client = openai.OpenAI(
103
+ base_url=config.BASE_URL, api_key=cerebras_api_key)
104
  else:
105
  llm_client = Cerebras(api_key=cerebras_api_key)
106
 
107
  # Together Client (APIキーがあれば初期化)
108
  image_client = None
109
  if together_api_key:
110
+ image_client = Together(api_key=together_api_key) # 明示的にキーを渡すことも可能
111
 
112
  except Exception as e:
113
  st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
 
117
  if "messages" not in st.session_state:
118
  st.session_state.messages = []
119
  if "generated_images" not in st.session_state:
120
+ # 画像データをメッセージIDごとに保存 {msg_idx: image_bytes}
121
+ st.session_state.generated_images = {}
122
 
123
  if "selected_model" not in st.session_state:
124
  st.session_state.selected_model = None
 
126
  # モデルが変更されたら履歴をクリア (画像履歴もクリアするかは要検討)
127
  if st.session_state.selected_model != model_option:
128
  st.session_state.messages = []
129
+ st.session_state.generated_images = {} # 画像履歴もクリア
130
  st.session_state.selected_model = model_option
131
 
132
  # --- チャットメッセージの表示ループ ---
 
138
 
139
  # アシスタントのメッセージで、かつ有効な形式の可能性があり、画像クライアントが利用可能な場合
140
  if message["role"] == "assistant" and image_client:
141
+ # 簡単なチェック: 拒否メッセージではないことを確認
142
+ lower_content = message["content"].lower()
143
+ is_likely_prompt = "please provide a valid food dish name" not in lower_content
144
+
145
+ if is_likely_prompt:
146
+ button_key = f"gen_img_{idx}"
147
+ if st.button("Generate Image ✨", key=button_key):
148
+ # 画像生成関数を呼び出し、結果をセッション状態に保存
149
+ image_bytes = utils.generate_image_from_prompt(
150
+ image_client, message["content"])
151
+ if image_bytes:
152
+ st.session_state.generated_images[idx] = image_bytes
153
+ # ボタンが押されたら再実行されるので、画像表示は下のブロックで行う
154
+
155
+ # 対応する画像データがセッション状態にあれば表示・ダウンロードボタンを表示
156
+ if idx in st.session_state.generated_images:
157
+ img_bytes = st.session_state.generated_images[idx]
158
+ st.image(
159
+ img_bytes, caption=f"Generated Image for Prompt #{idx+1}")
160
+ st.download_button(
161
+ label="Download Image 💾",
162
+ data=img_bytes,
163
+ file_name=f"recipe_infographic_{idx+1}.png",
164
+ mime="image/png",
165
+ key=f"dl_img_{idx}"
166
+ )
167
 
168
  # --- チャット入力と新しいメッセージの処理 ---
169
  if prompt := st.chat_input("Enter food name/food recipe here..."):
170
  # 入力検証
171
  if utils.contains_injection_keywords(prompt):
172
+ st.error(
173
+ "Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨")
174
  elif len(prompt) > 4000:
175
+ st.error(
176
+ "Input is too long. Please provide a shorter recipe or dish name.", icon="🚨")
177
  else:
178
  # --- 検証をパスした場合の処理 ---
179
  st.session_state.messages.append({"role": "user", "content": prompt})
 
188
  response_placeholder = st.empty()
189
  full_response = ""
190
 
191
+ messages_for_api = [
192
  {"role": "system", "content": RECIPE_BASE_PROMPT},
193
  {"role": "user", "content": prompt}
194
  ]
195
  stream_kwargs = {
196
+ "model": model_option, "messages": messages_for_api,
197
+ "max_tokens": max_tokens, "stream": True,
198
  }
199
  # LLM Client を使用
200
+ response_stream = llm_client.chat.completions.create(
201
+ **stream_kwargs)
202
 
203
  for chunk in response_stream:
204
  chunk_content = ""
 
213
 
214
  # --- ここで新しいアシスタントメッセージに対する処理 ---
215
  # 応答を履歴に追加 *してから* インデックスを取得
216
+ st.session_state.messages.append(
217
+ {"role": "assistant", "content": full_response})
218
+ new_message_idx = len(
219
+ st.session_state.messages) - 1 # 新しいメッセージのインデックス
220
 
221
  # 出力検証
222
+ expected_keywords = [
223
+ "infographic", "step-by-step", "ingredient", "layout", "minimal style"]
224
  lower_response = full_response.lower()
225
+ is_valid_format_check = any(
226
+ keyword in lower_response for keyword in expected_keywords)
227
  is_refusal_check = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response
228
 
229
  if not is_valid_format_check and not is_refusal_check:
230
+ st.warning(
231
+ "The generated response might not contain expected keywords...", icon="⚠️")
232
  elif is_refusal_check:
233
+ st.info("Input was determined to be invalid or unrelated...")
234
 
235
  # 画像生成ボタンと表示エリア (新しいメッセージに対して)
236
  # 条件: 画像クライアントがあり、拒否応答でない場合
237
  if image_client and not is_refusal_check:
238
  button_key = f"gen_img_{new_message_idx}"
239
  if st.button("Generate Image ✨", key=button_key):
240
+ image_bytes = utils.generate_image_from_prompt(
241
+ image_client, full_response)
242
  if image_bytes:
243
  st.session_state.generated_images[new_message_idx] = image_bytes
244
  # 再実行ループで画像表示
245
 
246
  # 対応する画像データがあれば表示
247
  if new_message_idx in st.session_state.generated_images:
248
+ img_bytes = st.session_state.generated_images[new_message_idx]
249
+ st.image(
250
+ img_bytes, caption=f"Generated Image for Prompt #{new_message_idx+1}")
251
+ st.download_button(
252
+ label="Download Image 💾",
253
+ data=img_bytes,
254
+ file_name=f"recipe_infographic_{new_message_idx+1}.png",
255
+ mime="image/png",
256
+ key=f"dl_img_{new_message_idx}"
257
+ )
258
 
259
  except Exception as e:
260
  st.error(f"Error generating response: {str(e)}", icon="🚨")