ruslanmv committed (verified)
Commit a65dc3e · 1 Parent(s): aec88ad

Update src/app.py

Files changed (1)
  1. src/app.py +189 -150
src/app.py CHANGED
@@ -1,29 +1,51 @@
  """Developed by Ruslan Magana Vsevolodovna"""
+
  from collections.abc import Iterator
  from datetime import datetime
  from pathlib import Path
  from threading import Thread
+ import io
+ import base64
+ import random
+
  import gradio as gr
  import spaces
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
- import random
+
  from themes.research_monochrome import theme

  # =============================================================================
  # Constants & Prompts
  # =============================================================================
- today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002
- SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.Today's Date: {today_date}.You are Granite, developed by IBM. You are a helpful AI assistant. Respond in the following format:<reasoning>Step-by-step reasoning to arrive at the answer.</reasoning><answer>The final answer to the user's query.</answer> If reasoning is not applicable, you can directly provide the <answer>."""
+ today_date = datetime.today().strftime("%B %-d, %Y")
+ SYS_PROMPT = """
+ Respond in the following format:
+ <reasoning>
+ ...
+ </reasoning>
+ <answer>
+ ...
+ </answer>
+ """
+
  TITLE = "IBM Granite 3.1 8b Reasoning & Vision Preview"
- DESCRIPTION = """<p>Granite 3.1 8b Reasoning is an open‐source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision‐language capabilities. Start with one of the sample promptsor enter your own. Keep in mind that AI can occasionally make mistakes.<span class="gr_docs_link"><a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a></span></p>"""
+ DESCRIPTION = """
+ <p>Granite 3.1 8b Reasoning is an open‐source LLM supporting a 128k context window and Granite Vision 3.1 2B Preview for vision‐language capabilities. Start with one of the sample prompts
+ or enter your own. Keep in mind that AI can occasionally make mistakes.
+ <span class="gr_docs_link">
+ <a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
+ </span>
+ </p>
+ """
  MAX_INPUT_TOKEN_LENGTH = 128_000
  MAX_NEW_TOKENS = 1024
- TEMPERATURE = 0.7
+ TEMPERATURE = 0.5
  TOP_P = 0.85
  TOP_K = 50
  REPETITION_PENALTY = 1.05
+
  # Vision defaults (advanced settings)
  VISION_TEMPERATURE = 0.2
  VISION_TOP_P = 0.95
@@ -32,13 +54,13 @@ VISION_MAX_TOKENS = 128

  if not torch.cuda.is_available():
      print("This demo may not work on CPU.")
+
  # =============================================================================
  # Text Model Loading
  # =============================================================================
- #Standard Model
- #granite_text_model="ibm-granite/granite-3.1-8b-instruct"
- #With Reasoning
- granite_text_model="ruslanmv/granite-3.1-8b-Reasoning"
+
+ granite_text_model = "ruslanmv/granite-3.1-8b-Reasoning"
+
  text_model = AutoModelForCausalLM.from_pretrained(
      granite_text_model,
      torch_dtype=torch.float16,
@@ -46,6 +68,7 @@ text_model = AutoModelForCausalLM.from_pretrained(
  )
  tokenizer = AutoTokenizer.from_pretrained(granite_text_model)
  tokenizer.use_default_system_prompt = False
+
  # =============================================================================
  # Vision Model Loading
  # =============================================================================
@@ -55,8 +78,63 @@ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
      vision_model_path,
      torch_dtype=torch.float16,
      device_map="auto",
-     trust_remote_code=True # Ensure the custom code is used so that weight shapes match.)
+     trust_remote_code=True # Ensure the custom code is used so that weight shapes match.
  )
+
+ # =============================================================================
+ # Unified Display Function
+ # =============================================================================
+ def get_text_from_content(content):
+     """Helper to extract text from a list of content items."""
+     texts = []
+     for item in content:
+         if isinstance(item, dict):
+             if item.get("type") == "text":
+                 texts.append(item.get("text", ""))
+             elif item.get("type") == "image":
+                 image = item.get("image")
+                 if image is not None:
+                     buffered = io.BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_str = base64.b64encode(buffered.getvalue()).decode()
+                     texts.append(f'<img src="data:image/jpeg;base64,{img_str}" style="max-width: 200px; max-height: 200px;">')
+                 else:
+                     texts.append("<image>")
+         else:
+             texts.append(str(item))
+     return " ".join(texts)
+
+ def display_unified_conversation(conversation):
+     """
+     Combine both text-only and vision messages.
+     Each conversation entry is expected to be a dict with keys:
+       - role: "user" or "assistant"
+       - content: either a string (for text) or a list of content items (for vision)
+     """
+     chat_history = []
+     i = 0
+     while i < len(conversation):
+         if conversation[i]["role"] == "user":
+             user_content = conversation[i]["content"]
+             if isinstance(user_content, list):
+                 user_msg = get_text_from_content(user_content)
+             else:
+                 user_msg = user_content
+             assistant_msg = ""
+             if i + 1 < len(conversation) and conversation[i+1]["role"] == "assistant":
+                 asst_content = conversation[i+1]["content"]
+                 if isinstance(asst_content, list):
+                     assistant_msg = get_text_from_content(asst_content)
+                 else:
+                     assistant_msg = asst_content
+                 i += 2
+             else:
+                 i += 1
+             chat_history.append((user_msg, assistant_msg))
+         else:
+             i += 1
+     return chat_history
+
  # =============================================================================
  # Text Generation Function (for text-only chat)
  # =============================================================================
@@ -70,7 +148,10 @@ def generate(
      top_k: float = TOP_K,
      max_new_tokens: int = MAX_NEW_TOKENS,
  ) -> Iterator[str]:
-     """Generate function for text chat demo with chain of thought display."""
+     """
+     Generate function for text chat. It streams tokens and stops once the generated answer
+     contains the closing </answer> tag.
+     """
      conversation = []
      conversation.append({"role": "system", "content": SYS_PROMPT})
      conversation.extend(chat_history)
@@ -84,17 +165,17 @@ def generate(
      )
      input_ids = input_ids.to(text_model.device)
      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-     generate_kwargs = dict(
-         {"input_ids": input_ids},
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=1,
-         repetition_penalty=repetition_penalty,
-     )
+     generate_kwargs = {
+         "input_ids": input_ids,
+         "streamer": streamer,
+         "max_new_tokens": max_new_tokens,
+         "do_sample": True,
+         "top_p": top_p,
+         "top_k": top_k,
+         "temperature": temperature,
+         "num_beams": 1,
+         "repetition_penalty": repetition_penalty,
+     }
      t = Thread(target=text_model.generate, kwargs=generate_kwargs)
      t.start()

@@ -112,43 +193,36 @@ def generate(
              reasoning_started = True
              reasoning_start_index = current_output.find("<reasoning>") + len("<reasoning>")
              collected_reasoning = current_output[reasoning_start_index:]
-             yield "[Reasoning]: " # Indicate start of reasoning in chatbot
-             outputs = [collected_reasoning] # Reset outputs to only include reasoning part
+             yield "[Reasoning]: "
+             outputs = [collected_reasoning]

          elif reasoning_started and "<answer>" in current_output and not answer_started:
              answer_started = True
              reasoning_end_index = current_output.find("<answer>")
-             collected_reasoning = current_output[len("<reasoning>"):reasoning_end_index] # Correctly extract reasoning part
+             collected_reasoning = current_output[len("<reasoning>"):reasoning_end_index]

              answer_start_index = current_output.find("<answer>") + len("<answer>")
              collected_answer = current_output[answer_start_index:]
-             yield "\n[Answer]: " # Indicate start of answer in chatbot
-             outputs = [collected_answer] # Reset outputs to only include answer part
-             yield collected_answer # Yield initial part of answer
+             yield "\n[Answer]: "
+             outputs = [collected_answer]
+             yield collected_answer

          elif reasoning_started and not answer_started:
-             collected_reasoning = text # Accumulate reasoning tokens
-             yield text # Stream reasoning tokens
+             collected_reasoning += text
+             yield text

          elif answer_started:
-             collected_answer += text # Accumulate answer tokens
-             yield text # Stream answer tokens
-         else:
-             yield text # In case no tags are found, stream as before
+             collected_answer += text
+             yield text
+             if "</answer>" in collected_answer:
+                 break

+         else:
+             yield text

  # =============================================================================
  # Vision Chat Inference Function (for image+text chat)
  # =============================================================================
- def get_text_from_content(content):
-     texts = []
-     for item in content:
-         if item["type"] == "text":
-             texts.append(item["text"])
-         elif item["type"] == "image":
-             texts.append("<Image>")
-     return " ".join(texts)
-
  @spaces.GPU
  def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, top_p=VISION_TOP_P, top_k=VISION_TOP_K, max_tokens=VISION_MAX_TOKENS):
      if conversation is None:
@@ -159,7 +233,7 @@ def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, to
      if text and text.strip():
          user_content.append({"type": "text", "text": text.strip()})
      if not user_content:
-         return display_vision_conversation(conversation), conversation
+         return display_unified_conversation(conversation), conversation
      conversation.append({"role": "user", "content": user_content})
      inputs = vision_processor.apply_chat_template(
          conversation,
@@ -179,131 +253,89 @@ def chat_inference(image, text, conversation, temperature=VISION_TEMPERATURE, to
      output = vision_model.generate(**inputs, **generation_kwargs)
      assistant_response = vision_processor.decode(output[0], skip_special_tokens=True)

-
-     ### For future versions of Vision with Reasoning
-     vision_reasoning=False
-     if vision_reasoning:
-         reasoning = ""
-         answer = ""
-         if "<reasoning>" in assistant_response and "<answer>" in assistant_response:
-             reasoning_start = assistant_response.find("<reasoning>") + len("<reasoning>")
-             reasoning_end = assistant_response.find("</reasoning>")
-             reasoning = assistant_response[reasoning_start:reasoning_end].strip()
-
-             answer_start = assistant_response.find("<answer>") + len("<answer>")
-             answer_end = assistant_response.find("</answer>")
-
-             if answer_end != -1: # Handle cases where answer end tag is present
-                 answer = assistant_response[answer_start:answer_end].strip()
-             else: # Fallback if answer end tag is missing (less robust)
-                 answer = assistant_response[answer_start:].strip()
-             formatted_response_content = []
-             if reasoning:
-                 formatted_response_content.append({"type": "text", "text": f"[Reasoning]: {reasoning}"})
-             formatted_response_content.append({"type": "text", "text": f"[Answer]: {answer}"})
-             conversation.append({"role": "assistant", "content": formatted_response_content})
+     if "<|assistant|>" in assistant_response:
+         assistant_response_parts = assistant_response.split("<|assistant|>")
+         assistant_response_text = assistant_response_parts[-1].strip()
      else:
-         conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_response.strip()}]})
-
-     return display_vision_conversation(conversation), conversation
+         assistant_response_text = assistant_response.strip()

+     conversation.append({"role": "assistant", "content": [{"type": "text", "text": assistant_response_text.strip()}]})
+     return display_unified_conversation(conversation), conversation

- # =============================================================================
- # Helper Functions to Format Conversation for Display
- # =============================================================================
- def display_text_conversation(conversation):
-     """Convert a text conversation (list of dicts) into a list of (user, assistant) tuples."""
-     chat_history = []
-     i = 0
-     while i < len(conversation):
-         if conversation[i]["role"] == "user":
-             user_msg = conversation[i]["content"]
-             assistant_msg = ""
-             if i + 1 < len(conversation) and conversation[i+1]["role"] == "assistant":
-                 assistant_msg = conversation[i+1]["content"]
-                 i += 2
-             else:
-                 i += 1
-             chat_history.append((user_msg, assistant_msg))
-         else:
-             i += 1
-     return chat_history
-
- def display_vision_conversation(conversation):
-     """Convert a vision conversation (with mixed content types) into a list of (user, assistant) tuples."""
-     chat_history = []
-     i = 0
-     while i < len(conversation):
-         if conversation[i]["role"] == "user":
-             user_msg = get_text_from_content(conversation[i]["content"])
-             assistant_msg = ""
-             if i + 1 < len(conversation) and conversation[i+1]["role"] == "assistant":
-                 # Extract assistant text; remove any special tokens if present.
-                 assistant_content = conversation[i+1]["content"]
-                 assistant_text_parts = []
-                 for item in assistant_content:
-                     if item["type"] == "text":
-                         assistant_text_parts.append(item["text"])
-                 assistant_msg = "\n".join(assistant_text_parts).strip()
-                 i += 2
-             else:
-                 i += 1
-             chat_history.append((user_msg, assistant_msg))
-         else:
-             i += 1
-     return chat_history
  # =============================================================================
  # Unified Send-Message Function
+ #
+ # We now maintain two histories:
+ #   - unified_state: complete conversation (for display)
+ #   - internal_text_state: only text turns (for text generation)
+ # Vision turns update only unified_state.
  # =============================================================================
  def send_message(image, text,
                   text_temperature, text_repetition_penalty, text_top_p, text_top_k, text_max_new_tokens,
                   vision_temperature, vision_top_p, vision_top_k, vision_max_tokens,
-                  text_state, vision_state):
-     """
-     If an image is uploaded, use the vision model; otherwise, use the text model.
-     Returns updated conversation (as a list of tuples) and state for each branch.
-     """
+                  unified_state, vision_state, internal_text_state):
+     # Initialize states if empty
+     if unified_state is None:
+         unified_state = []
+     if internal_text_state is None:
+         internal_text_state = []
+
      if image is not None:
-         # Vision branch
-         conv = vision_state if vision_state is not None else []
-         chat_history, updated_conv = chat_inference(
-             image, text, conv,
-             temperature=vision_temperature,
-             top_p=vision_top_p,
-             top_k=vision_top_k,
-             max_tokens=vision_max_tokens
-         )
-         vision_state = updated_conv
-         # In vision mode, the conversation display is produced from the vision branch.
-         return chat_history, text_state, vision_state
+         # Use vision inference.
+         user_msg = []
+         user_msg.append({"type": "image", "image": image})
+         if text and text.strip():
+             user_msg.append({"type": "text", "text": text.strip()})
+         unified_state.append({"role": "user", "content": user_msg})
+         chat_history, updated_vision_conv = chat_inference(image, text, vision_state,
+                                                            temperature=vision_temperature,
+                                                            top_p=vision_top_p,
+                                                            top_k=vision_top_k,
+                                                            max_tokens=vision_max_tokens)
+         vision_state = updated_vision_conv
+         if updated_vision_conv and updated_vision_conv[-1]["role"] == "assistant":
+             unified_state.append(updated_vision_conv[-1])
+         yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
+
      else:
-         # Text branch
-         conv = text_state if text_state is not None else []
-         output_text = ""
+         # Text-only mode: update both unified and internal text states.
+         unified_state.append({"role": "user", "content": text})
+         internal_text_state.append({"role": "user", "content": text})
+         unified_state.append({"role": "assistant", "content": ""})
+         internal_text_state.append({"role": "assistant", "content": ""})
+         yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state
+
+         base_conv = internal_text_state[:-1]
+         assistant_text = ""
          for chunk in generate(
-             text, conv,
+             text, base_conv,
              temperature=text_temperature,
              repetition_penalty=text_repetition_penalty,
              top_p=text_top_p,
              top_k=text_top_k,
              max_new_tokens=text_max_new_tokens
          ):
-             output_text += chunk # Accumulate for display function to process correctly.
+             assistant_text += chunk
+             unified_state[-1]["content"] = assistant_text
+             internal_text_state[-1]["content"] = assistant_text
+             yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state

-         conv.append({"role": "user", "content": text})
-         conv.append({"role": "assistant", "content": output_text}) # Store full output with tags
-         text_state = conv
-         chat_history = display_text_conversation(text_state) # Display function handles tag parsing now.
-         return chat_history, text_state, vision_state
+         yield display_unified_conversation(unified_state), unified_state, vision_state, internal_text_state

+ # =============================================================================
+ # Clear Chat Function
+ # =============================================================================
  def clear_chat():
-     # Clear the conversation and input fields.
-     return [], [], [], None # (chat_history, text_state, vision_state, cleared text and image inputs)
+     # Clear unified conversation, vision state, and internal text state.
+     return [], [], [], "", None
+
  # =============================================================================
  # UI Layout with Gradio
  # =============================================================================
  css_file_path = Path(Path(__file__).parent / "app.css")
  head_file_path = Path(Path(__file__).parent / "app_head.html")
- with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
+
+ with gr.Blocks(fill_height=True, css_paths=[str(css_file_path)], head_paths=[str(head_file_path)], theme=theme, title=TITLE) as demo:
      gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
      gr.HTML(DESCRIPTION)

@@ -325,12 +357,17 @@ with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_p
          vision_top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, value=VISION_TOP_P, step=0.01, label="Vision Top p", elem_classes=["gr_accordion_element"])
          vision_top_k_slider = gr.Slider(minimum=0, maximum=100, value=VISION_TOP_K, step=1, label="Vision Top k", elem_classes=["gr_accordion_element"])
          vision_max_tokens_slider = gr.Slider(minimum=10, maximum=300, value=VISION_MAX_TOKENS, step=1, label="Vision Max Tokens", elem_classes=["gr_accordion_element"])
-     send_button = gr.Button("Send Message")
+
+     send_button = gr.Button("Send Message")
      clear_button = gr.Button("Clear Chat")

-     # Conversation state variables for each branch.
-     text_state = gr.State([])
+     # Conversation state variables:
+     #   - unified_state: complete conversation for display (text and vision)
+     #   - vision_state: state for vision turns
+     #   - internal_text_state: only text turns (for text-generation)
+     unified_state = gr.State([])
      vision_state = gr.State([])
+     internal_text_state = gr.State([])

      send_button.click(
          send_message,
@@ -338,20 +375,21 @@ with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_p
              image_input, text_input,
              text_temperature_slider, repetition_penalty_slider, top_p_slider, top_k_slider, max_new_tokens_slider,
              vision_temperature_slider, vision_top_p_slider, vision_top_k_slider, vision_max_tokens_slider,
-             text_state, vision_state
+             unified_state, vision_state, internal_text_state
          ],
-         outputs=[chatbot, text_state, vision_state]
+         outputs=[chatbot, unified_state, vision_state, internal_text_state],
      )

      clear_button.click(
          clear_chat,
          inputs=None,
-         outputs=[chatbot, text_state, vision_state, text_input, image_input]
+         outputs=[chatbot, unified_state, vision_state, internal_text_state, text_input, image_input]
      )

      gr.Examples(
          examples=[
              ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/cheetah1.jpg", "What is in this image?"],
+             [None, "Compute Pi."],
              [None, "Explain quantum computing to a beginner."],
              [None, "What is OpenShift?"],
              [None, "Importance of low latency inference"],
@@ -362,6 +400,7 @@ with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_p
          inputs=[image_input, text_input],
          example_labels=[
              "Vision Example: What is in this image?",
+             "Compute Pi.",
              "Explain quantum computing",
              "What is OpenShift?",
              "Importance of low latency inference",
@@ -373,4 +412,4 @@ with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_p
      )

  if __name__ == "__main__":
-     demo.queue().launch()
+     demo.queue().launch(debug=True, share=False)
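
A minimal usage sketch (not part of this commit), assuming src/app.py is importable as `app` and the Granite checkpoints can actually be loaded on the machine: the updated generate() is a plain Python generator, so its "[Reasoning]: " / "[Answer]: " stream can be consumed outside Gradio as well.

# Hypothetical example, not from the commit: stream a text-only reply.
# Importing app executes the module, which downloads and loads both models.
from app import generate

history = []   # prior turns as {"role": ..., "content": ...} dicts
reply = ""
for chunk in generate("What is OpenShift?", history, temperature=0.5):
    reply += chunk                      # chunks include the "[Reasoning]: " / "[Answer]: " markers
    print(chunk, end="", flush=True)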