baxin commited on
Commit
496ad85
·
unverified ·
2 Parent(s): c3acbd1 6d7ccb1

Merge pull request #6 from koji/feat_add-flux-schnell

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +156 -88
  3. requirements.txt +2 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ .env
2
+ __pycache__
3
+ __pycache__/
app.py CHANGED
@@ -3,28 +3,24 @@ from cerebras.cloud.sdk import Cerebras
3
  import openai
4
  import os
5
  from dotenv import load_dotenv
 
 
 
6
 
7
  # --- RECIPE_BASE_PROMPT のインポート ---
8
- # prompt.py が存在し、RECIPE_BASE_PROMPTが定義されていると仮定
9
  try:
10
  from prompt import RECIPE_BASE_PROMPT
11
  except ImportError:
12
- # エラー処理: prompt.pyが見つからないか、変数が定義されていない場合
13
  st.error("Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.")
14
- st.stop() # 致命的なエラーなのでアプリを停止
15
- # RECIPE_BASE_PROMPT = "You are a helpful recipe assistant." # フォールバックが必要な場合
16
- # print("Warning: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' not defined. Using a default system prompt.")
17
 
18
  # --- 定数と設定 ---
19
-
20
- # モデル定義
21
  models = {
22
  "llama3.1-8b": {"name": "Llama3.1-8b", "tokens": 8192, "developer": "Meta"},
23
  "llama-3.3-70b": {"name": "Llama-3.3-70b", "tokens": 8192, "developer": "Meta"}
24
  }
25
-
26
- # Optillm用ベースURL (必要に応じて変更)
27
  BASE_URL = "http://localhost:8000/v1"
 
28
 
29
  # --- 環境変数読み込み ---
30
  load_dotenv()
@@ -34,180 +30,252 @@ st.set_page_config(page_icon="🤖", layout="wide", page_title="Recipe Infograph
34
 
35
  # --- ヘルパー関数 ---
36
  def contains_injection_keywords(text):
37
- """Checks for basic prompt injection keywords."""
38
  keywords = ["ignore previous", "ignore instructions", "disregard", "forget your instructions", "act as", "you must", "system prompt:"]
39
  lower_text = text.lower()
40
  return any(keyword in lower_text for keyword in keywords)
41
 
42
  def icon(emoji: str):
43
- """Shows an emoji as a Notion-style page icon."""
44
  st.write(
45
  f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
46
  unsafe_allow_html=True,
47
  )
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # --- UI 表示 ---
50
- icon("🧠 x 🧑‍🍳") # アイコンを修正
51
  st.title("Recipe Infographic Prompt Generator")
52
  st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics", divider="orange", anchor=False)
53
 
54
  # --- APIキーの処理 ---
 
55
  api_key_from_env = os.getenv("CEREBRAS_API_KEY")
56
  show_api_key_input = not bool(api_key_from_env)
57
- api_key = None
 
 
 
58
 
59
  # --- サイドバーの設定 ---
60
  with st.sidebar:
61
  st.title("Settings")
62
 
 
63
  if show_api_key_input:
64
  st.markdown("### :red[Enter your Cerebras API Key below]")
65
- api_key_input = st.text_input("Cerebras API Key:", type="password", key="api_key_input_field")
66
  if api_key_input:
67
- api_key = api_key_input
68
  else:
69
- api_key = api_key_from_env
70
- st.success("✓ API Key loaded from environment")
71
 
 
 
 
 
 
 
 
72
  model_option = st.selectbox(
73
- "Choose a model:",
74
  options=list(models.keys()),
75
  format_func=lambda x: models[x]["name"],
76
  key="model_select"
77
  )
78
 
 
79
  max_tokens_range = models[model_option]["tokens"]
80
  default_tokens = min(2048, max_tokens_range)
81
  max_tokens = st.slider(
82
- "Max Tokens:",
83
  min_value=512,
84
  max_value=max_tokens_range,
85
  value=default_tokens,
86
  step=512,
87
- help="Select the maximum number of tokens for the model's response." # helpテキストを修正
88
  )
89
 
90
- use_optillm = st.toggle("Use Optillm", value=False)
91
 
92
  # --- メインアプリケーションロジック ---
93
 
94
- # APIキーが最終的に利用可能かチェック (サイドバーの処理後)
95
- if not api_key:
96
- st.markdown("""
97
- ## Cerebras API x Streamlit Demo!
98
-
99
- This simple chatbot app demonstrates how to use Cerebras with Streamlit.
100
 
101
- To get started:
102
- """)
103
- if show_api_key_input:
104
- # サイドバー入力が表示されている場合
105
- st.warning("1. :red[Enter your Cerebras API Key in the sidebar.]")
106
- else:
107
- # 環境変数から読み込むべきだったが、見つからなかった/空だった場合
108
- st.error("1. :red[CEREBRAS_API_KEY environment variable not found or empty.] Please set it in your environment (e.g., in a .env file).")
109
- st.markdown("2. Configure your settings and start chatting.") # メッセージを少し変更
110
- st.stop() # APIキーがない場合はここで停止
111
-
112
- # APIキーが利用可能な場合のみクライアントを初期化
113
  try:
 
114
  if use_optillm:
115
- client = openai.OpenAI(
116
- base_url=BASE_URL, # Optillmがlocalhostを使用する場合
117
- api_key=api_key
118
- )
119
  else:
120
- # Cerebras SDKがapi_keyだけで初期化可能か確認
121
- # SDKのバージョンや使い方によってはendpoint等の追加設定が必要な場合あり
122
- client = Cerebras(api_key=api_key)
123
- # st.success("API Client Initialized.") # 任意:初期化成功メッセージ
 
 
 
124
  except Exception as e:
125
- st.error(f"Failed to initialize API client: {str(e)}", icon="🚨")
126
- st.stop() # クライアント初期化失敗時も停止
127
 
128
  # --- チャット履歴管理 ---
129
  if "messages" not in st.session_state:
130
  st.session_state.messages = []
 
 
131
 
132
  if "selected_model" not in st.session_state:
133
  st.session_state.selected_model = None
134
 
135
- # モデルが変更されたら履歴をクリア
136
  if st.session_state.selected_model != model_option:
137
  st.session_state.messages = []
 
138
  st.session_state.selected_model = model_option
139
 
140
- # チャットメッセージを表示
141
- for message in st.session_state.messages:
142
- avatar = '🤖' if message["role"] == "assistant" else '🦔' # アバターを調整 (ユーザーはハリネズミ?)
 
143
  with st.chat_message(message["role"], avatar=avatar):
144
  st.markdown(message["content"])
145
 
146
- # --- チャット入力と処理 (インデント修正済み) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  if prompt := st.chat_input("Enter food name/food recipe here..."):
148
- # ☆★☆ 入力検証 ☆★☆
149
  if contains_injection_keywords(prompt):
150
  st.error("Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨")
151
- elif len(prompt) > 4000: # 文字数制限は適切に調整してください
152
  st.error("Input is too long. Please provide a shorter recipe or dish name.", icon="🚨")
153
  else:
154
- # ↓↓↓ --- 検証をパスした場合の処理 (ここからインデント) --- ↓↓↓
155
  st.session_state.messages.append({"role": "user", "content": prompt})
156
 
157
- with st.chat_message("user", avatar='🦔'): # ユーザーアバター
 
158
  st.markdown(prompt)
159
 
 
160
  try:
161
- with st.chat_message("assistant", avatar="🤖"): # アシスタントアバター
162
  response_placeholder = st.empty()
163
  full_response = ""
164
 
165
- # APIに送信するメッセージリストを作成
166
  messages_for_api=[
167
  {"role": "system", "content": RECIPE_BASE_PROMPT},
168
- {"role": "user", "content": prompt} # 最新のユーザープロンプトのみ
169
  ]
170
-
171
- # ストリーミングで応答を取得
172
  stream_kwargs = {
173
- "model": model_option,
174
- "messages": messages_for_api,
175
- "max_tokens": max_tokens,
176
- "stream": True,
177
  }
178
- response_stream = client.chat.completions.create(**stream_kwargs)
 
179
 
180
  for chunk in response_stream:
181
  chunk_content = ""
182
- # API応答の構造に合わせて調整が必要な場合あり
183
  if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'):
184
  chunk_content = chunk.choices[0].delta.content or ""
185
-
186
  if chunk_content:
187
  full_response += chunk_content
188
- response_placeholder.markdown(full_response + "▌") # カーソル表示
189
 
190
- # 最終的な応答を表示(カーソルなし)
191
  response_placeholder.markdown(full_response)
192
 
193
- # ☆★☆ 出力検証 ☆★☆
 
 
 
 
 
194
  expected_keywords = ["infographic", "step-by-step", "ingredient", "layout", "minimal style"]
195
  lower_response = full_response.lower()
196
- is_valid_format = any(keyword in lower_response for keyword in expected_keywords)
197
- # システムプロンプトで定義した拒否応答の文字列と一致させる
198
- is_refusal = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response
199
-
200
- if not is_valid_format and not is_refusal:
201
- # 期待される形式でもなく、意図した拒否応答でもない場合
202
- st.warning("The generated response might not contain expected keywords or could indicate an issue.", icon="⚠️")
203
- elif is_refusal:
204
- # 意図した拒否応答の場合 (infoレベルで表示)
205
- st.info("Input was determined to be invalid or unrelated. Please provide a valid food dish/recipe.") # メッセージを少し調整
206
-
207
- # アシスタントの応答を履歴に追加
208
- st.session_state.messages.append(
209
- {"role": "assistant", "content": full_response})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
 
211
  except Exception as e:
212
  st.error(f"Error generating response: {str(e)}", icon="🚨")
213
- # ↑↑↑ --- ここまでが else 節のインデント内 --- ↑↑↑
 
3
  import openai
4
  import os
5
  from dotenv import load_dotenv
6
+ import base64 # 画像デコード用に追加
7
+ from io import BytesIO # 画像ダウンロード用に追加
8
+ from together import Together # Together AI SDKを追加
9
 
10
  # --- RECIPE_BASE_PROMPT のインポート ---
 
11
  try:
12
  from prompt import RECIPE_BASE_PROMPT
13
  except ImportError:
 
14
  st.error("Error: 'prompt.py' not found or 'RECIPE_BASE_PROMPT' is not defined within it.")
15
+ st.stop()
 
 
16
 
17
  # --- 定数と設定 ---
 
 
18
  models = {
19
  "llama3.1-8b": {"name": "Llama3.1-8b", "tokens": 8192, "developer": "Meta"},
20
  "llama-3.3-70b": {"name": "Llama-3.3-70b", "tokens": 8192, "developer": "Meta"}
21
  }
 
 
22
  BASE_URL = "http://localhost:8000/v1"
23
+ IMAGE_MODEL = "black-forest-labs/FLUX.1-schnell-Free" # 使用する画像生成モデル
24
 
25
  # --- 環境変数読み込み ---
26
  load_dotenv()
 
30
 
31
  # --- ヘルパー関数 ---
32
  def contains_injection_keywords(text):
 
33
  keywords = ["ignore previous", "ignore instructions", "disregard", "forget your instructions", "act as", "you must", "system prompt:"]
34
  lower_text = text.lower()
35
  return any(keyword in lower_text for keyword in keywords)
36
 
37
  def icon(emoji: str):
 
38
  st.write(
39
  f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
40
  unsafe_allow_html=True,
41
  )
42
 
43
+ # --- 画像生成関数 ---
44
+ @st.cache_data(show_spinner="Generating image...") # 結果をキャッシュ & スピナー表示
45
+ def generate_image_from_prompt(_together_client, prompt_text):
46
+ """Generates an image using Together AI and returns image bytes."""
47
+ try:
48
+ response = _together_client.images.generate(
49
+ prompt=prompt_text,
50
+ model=IMAGE_MODEL,
51
+ width=1024,
52
+ height=768, # モデルに合わせて調整が必要な場合あり
53
+ steps=4, # モデルに合わせて調整が必要な場合あり
54
+ n=1,
55
+ response_format="b64_json",
56
+ # stop=[] # stopは通常不要
57
+ )
58
+ if response.data and response.data[0].b64_json:
59
+ b64_data = response.data[0].b64_json
60
+ image_bytes = base64.b64decode(b64_data)
61
+ return image_bytes
62
+ else:
63
+ st.error("Image generation failed: No image data received.")
64
+ return None
65
+ except Exception as e:
66
+ st.error(f"Image generation error: {e}", icon="🚨")
67
+ return None
68
+
69
  # --- UI 表示 ---
70
+ icon("🧠 x 🧑‍🍳")
71
  st.title("Recipe Infographic Prompt Generator")
72
  st.subheader("Simply enter a dish name or recipe to easily generate image prompts for stunning recipe infographics", divider="orange", anchor=False)
73
 
74
  # --- APIキーの処理 ---
75
+ # Cerebras API Key
76
  api_key_from_env = os.getenv("CEREBRAS_API_KEY")
77
  show_api_key_input = not bool(api_key_from_env)
78
+ cerebras_api_key = None
79
+
80
+ # Together AI API Key Check
81
+ together_api_key = os.getenv("TOGETHER_API_KEY")
82
 
83
  # --- サイドバーの設定 ---
84
  with st.sidebar:
85
  st.title("Settings")
86
 
87
+ # Cerebras Key Input
88
  if show_api_key_input:
89
  st.markdown("### :red[Enter your Cerebras API Key below]")
90
+ api_key_input = st.text_input("Cerebras API Key:", type="password", key="cerebras_api_key_input_field")
91
  if api_key_input:
92
+ cerebras_api_key = api_key_input
93
  else:
94
+ cerebras_api_key = api_key_from_env
95
+ st.success("✓ Cerebras API Key loaded from environment")
96
 
97
+ # Together Key Status
98
+ if not together_api_key:
99
+ st.warning("TOGETHER_API_KEY environment variable not set. Image generation will not work.", icon="⚠️")
100
+ else:
101
+ st.success("✓ Together API Key loaded from environment") # キー自体は表示しない
102
+
103
+ # Model selection
104
  model_option = st.selectbox(
105
+ "Choose a LLM model:", # ラベルを明確化
106
  options=list(models.keys()),
107
  format_func=lambda x: models[x]["name"],
108
  key="model_select"
109
  )
110
 
111
+ # Max tokens slider
112
  max_tokens_range = models[model_option]["tokens"]
113
  default_tokens = min(2048, max_tokens_range)
114
  max_tokens = st.slider(
115
+ "Max Tokens (LLM):", # ラベルを明確化
116
  min_value=512,
117
  max_value=max_tokens_range,
118
  value=default_tokens,
119
  step=512,
120
+ help="Select the maximum number of tokens for the language model's response."
121
  )
122
 
123
+ use_optillm = st.toggle("Use Optillm (for Cerebras)", value=False) # ラベルを明確化
124
 
125
  # --- メインアプリケーションロジック ---
126
 
127
+ # APIキー(Cerebras)が最終的に利用可能かチェック
128
+ if not cerebras_api_key:
129
+ # (以前のエラー表示ロジックと同じ)
130
+ st.markdown("...") # 省略: APIキーがない場合の説明
131
+ st.stop()
 
132
 
133
+ # APIクライアント初期化 (Cerebras & Together)
 
 
 
 
 
 
 
 
 
 
 
134
  try:
135
+ # Cerebras Client
136
  if use_optillm:
137
+ llm_client = openai.OpenAI(base_url=BASE_URL, api_key=cerebras_api_key)
 
 
 
138
  else:
139
+ llm_client = Cerebras(api_key=cerebras_api_key)
140
+
141
+ # Together Client (APIキーがあれば初期化)
142
+ image_client = None
143
+ if together_api_key:
144
+ image_client = Together(api_key=together_api_key) # 明示的にキーを渡すことも可能
145
+
146
  except Exception as e:
147
+ st.error(f"Failed to initialize API client(s): {str(e)}", icon="🚨")
148
+ st.stop()
149
 
150
  # --- チャット履歴管理 ---
151
  if "messages" not in st.session_state:
152
  st.session_state.messages = []
153
+ if "generated_images" not in st.session_state:
154
+ st.session_state.generated_images = {} # 画像データをメッセージIDごとに保存 {msg_idx: image_bytes}
155
 
156
  if "selected_model" not in st.session_state:
157
  st.session_state.selected_model = None
158
 
159
+ # モデルが変更されたら履歴をクリア (画像履歴もクリアするかは要検討)
160
  if st.session_state.selected_model != model_option:
161
  st.session_state.messages = []
162
+ st.session_state.generated_images = {} # 画像履歴もクリア
163
  st.session_state.selected_model = model_option
164
 
165
+ # --- チャットメッセージの表示ループ ---
166
+ # このループでは過去のメッセージを表示し、それぞれに画像生成ボタンをつける
167
+ for idx, message in enumerate(st.session_state.messages):
168
+ avatar = '🤖' if message["role"] == "assistant" else '🦔'
169
  with st.chat_message(message["role"], avatar=avatar):
170
  st.markdown(message["content"])
171
 
172
+ # アシスタントのメッセージで、かつ有効な形式の可能性があり、画像クライアントが利用可能な場合
173
+ if message["role"] == "assistant" and image_client:
174
+ # 簡単なチェック: 拒否メッセージではないことを確認
175
+ lower_content = message["content"].lower()
176
+ is_likely_prompt = "please provide a valid food dish name" not in lower_content
177
+
178
+ if is_likely_prompt:
179
+ button_key = f"gen_img_{idx}"
180
+ if st.button("Generate Image ✨", key=button_key):
181
+ # 画像生成関数を呼び出し、結果をセッション状態に保存
182
+ image_bytes = generate_image_from_prompt(image_client, message["content"])
183
+ if image_bytes:
184
+ st.session_state.generated_images[idx] = image_bytes
185
+ # ボタンが押されたら再実行されるので、画像表示は下のブロックで行う
186
+
187
+ # 対応する画像データがセッション状態にあれば表示・ダウンロードボタンを表示
188
+ if idx in st.session_state.generated_images:
189
+ img_bytes = st.session_state.generated_images[idx]
190
+ st.image(img_bytes, caption=f"Generated Image for Prompt #{idx+1}")
191
+ st.download_button(
192
+ label="Download Image 💾",
193
+ data=img_bytes,
194
+ file_name=f"recipe_infographic_{idx+1}.png",
195
+ mime="image/png",
196
+ key=f"dl_img_{idx}"
197
+ )
198
+
199
+ # --- チャット入力と新しいメッセージの処理 ---
200
  if prompt := st.chat_input("Enter food name/food recipe here..."):
201
+ # 入力検証
202
  if contains_injection_keywords(prompt):
203
  st.error("Your input seems to contain instructions. Please provide only the dish name or recipe.", icon="🚨")
204
+ elif len(prompt) > 4000:
205
  st.error("Input is too long. Please provide a shorter recipe or dish name.", icon="🚨")
206
  else:
207
+ # --- 検証をパスした場合の処理 ---
208
  st.session_state.messages.append({"role": "user", "content": prompt})
209
 
210
+ # ユーザーメッセージを表示
211
+ with st.chat_message("user", avatar='🦔'):
212
  st.markdown(prompt)
213
 
214
+ # アシスタントの応答を生成・表示
215
  try:
216
+ with st.chat_message("assistant", avatar="🤖"):
217
  response_placeholder = st.empty()
218
  full_response = ""
219
 
 
220
  messages_for_api=[
221
  {"role": "system", "content": RECIPE_BASE_PROMPT},
222
+ {"role": "user", "content": prompt}
223
  ]
 
 
224
  stream_kwargs = {
225
+ "model": model_option, "messages": messages_for_api,
226
+ "max_tokens": max_tokens, "stream": True,
 
 
227
  }
228
+ # LLM Client を使用
229
+ response_stream = llm_client.chat.completions.create(**stream_kwargs)
230
 
231
  for chunk in response_stream:
232
  chunk_content = ""
 
233
  if hasattr(chunk, 'choices') and chunk.choices and hasattr(chunk.choices[0], 'delta') and chunk.choices[0].delta and hasattr(chunk.choices[0].delta, 'content'):
234
  chunk_content = chunk.choices[0].delta.content or ""
 
235
  if chunk_content:
236
  full_response += chunk_content
237
+ response_placeholder.markdown(full_response + "▌")
238
 
239
+ # 最終応答表示
240
  response_placeholder.markdown(full_response)
241
 
242
+ # --- ここで新しいアシスタントメッセージに対する処理 ---
243
+ # 応答を履歴に追加 *してから* インデックスを取得
244
+ st.session_state.messages.append({"role": "assistant", "content": full_response})
245
+ new_message_idx = len(st.session_state.messages) - 1 # 新しいメッセージのインデックス
246
+
247
+ # 出力検証
248
  expected_keywords = ["infographic", "step-by-step", "ingredient", "layout", "minimal style"]
249
  lower_response = full_response.lower()
250
+ is_valid_format_check = any(keyword in lower_response for keyword in expected_keywords)
251
+ is_refusal_check = "please provide a valid food dish name or recipe for infographic prompt generation" in lower_response
252
+
253
+ if not is_valid_format_check and not is_refusal_check:
254
+ st.warning("The generated response might not contain expected keywords...", icon="⚠️")
255
+ elif is_refusal_check:
256
+ st.info("Input was determined to be invalid or unrelated...")
257
+
258
+ # 画像生成ボタンと表示エリア (新しいメッセージに対して)
259
+ # 条件: 画像クライアントがあり、拒否応答でない場合
260
+ if image_client and not is_refusal_check:
261
+ button_key = f"gen_img_{new_message_idx}"
262
+ if st.button("Generate Image ✨", key=button_key):
263
+ image_bytes = generate_image_from_prompt(image_client, full_response)
264
+ if image_bytes:
265
+ st.session_state.generated_images[new_message_idx] = image_bytes
266
+ # 再実行ループで画像表示
267
+
268
+ # 対応する画像データがあれば表示
269
+ if new_message_idx in st.session_state.generated_images:
270
+ img_bytes = st.session_state.generated_images[new_message_idx]
271
+ st.image(img_bytes, caption=f"Generated Image for Prompt #{new_message_idx+1}")
272
+ st.download_button(
273
+ label="Download Image 💾",
274
+ data=img_bytes,
275
+ file_name=f"recipe_infographic_{new_message_idx+1}.png",
276
+ mime="image/png",
277
+ key=f"dl_img_{new_message_idx}"
278
+ )
279
 
280
  except Exception as e:
281
  st.error(f"Error generating response: {str(e)}", icon="🚨")
 
requirements.txt CHANGED
@@ -1,3 +1,5 @@
1
  cerebras_cloud_sdk
2
  openai
3
  python-dotenv
 
 
 
1
  cerebras_cloud_sdk
2
  openai
3
  python-dotenv
4
+ together
5
+ Pillow