import os
import subprocess

# Install flash-attn from a prebuilt wheel; FLASH_ATTENTION_SKIP_CUDA_BUILD skips the
# long CUDA compile step, and the current environment is preserved so pip stays on PATH.
subprocess.run(
    'pip install flash-attn==2.7.0.post2 --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True
)
subprocess.run(
    'pip install transformers',
    shell=True
)


import spaces
import re
import logging
from typing import List
from threading import Thread
import base64

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# ----------------------------------------------------------------------
# 1. Setup Model & Tokenizer
# ----------------------------------------------------------------------
model_name = 'prithivMLmods/Raptor-X6'  # Change as needed
use_thread = True  # Unused flag; generation always runs in a background thread (see ovis_chat)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------------------------------------------------
# 2. Two-Phase Prompt Templates
# ----------------------------------------------------------------------
s1_inference_prompt_think_only = """<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<|im_start|>think
"""

# ----------------------------------------------------------------------
# 3. Generation Parameter Setup
# ----------------------------------------------------------------------
THINK_MAX_NEW_TOKENS = 12000
ANSWER_MAX_NEW_TOKENS = 12000
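# Each phase is capped independently, so a single turn may emit up to
# THINK_MAX_NEW_TOKENS + ANSWER_MAX_NEW_TOKENS (24,000) new tokens.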

def initialize_gen_kwargs():
    return {
        "max_new_tokens": 1024,  # default; will be overwritten per phase
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        # "eos_token_id": model.generation_config.eos_token_id,  # Removed to avoid premature stopping
        "pad_token_id": tokenizer.pad_token_id,
        "use_cache": True,
        "streamer": None  # dynamically added
    }
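# Each phase copies these defaults, overrides "max_new_tokens", and attaches its own
# TextIteratorStreamer before calling model.generate() (see ovis_chat below).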

# ----------------------------------------------------------------------
# 4. Helper to submit chat
# ----------------------------------------------------------------------
def submit_chat(chatbot, text_input):
    if not text_input.strip():
        return chatbot, ""
    response = ""
    # Append a list (not a tuple) so ovis_chat can assign the streamed reply to chatbot[-1][1].
    chatbot.append([text_input, response])
    return chatbot, ""
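# submit_chat only records the user turn; ovis_chat streams the assistant reply into
# the same chatbot entry via the .then() chains wired up in section 8.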

# ----------------------------------------------------------------------
# 5. Artifacts Handling
#    We parse code from the final answer and display it in an iframe
# ----------------------------------------------------------------------
def extract_html_code_block(text: str) -> str:
    """
    Look for a ```html ... ``` block in the text. 
    If found, return only that block content. 
    Otherwise, return the entire text.
    """
    pattern = r'```html\s*(.*?)\s*```'
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return text.strip()
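# Example: extract_html_code_block('text\n```html\n<p>hi</p>\n```') returns '<p>hi</p>';
# input without a fenced ```html block comes back unchanged (only stripped).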

def send_to_sandbox(html_code: str) -> str:
    """
    Convert the code to a data URI iframe so it can be rendered
    inside Gradio HTML component.
    """
    encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    return f'<iframe src="{data_uri}" width="100%" height="920px"></iframe>'
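# The returned markup can be assigned directly to a gr.HTML component, e.g.
# send_to_sandbox('<h1>Hi</h1>') -> '<iframe src="data:text/html;charset=utf-8;base64,..." ...>'.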

# ----------------------------------------------------------------------
# 6. The Two-Phase Streaming Inference
#    - Phase 1: "think" (chain-of-thought)
#    - Phase 2: "answer"
# ----------------------------------------------------------------------
@spaces.GPU
def ovis_chat(chatbot: List[List[str]]):
    # Phase 1: chain-of-thought
    last_query = chatbot[-1][0]
    formatted_think_prompt = s1_inference_prompt_think_only.format(question=last_query)
    input_ids_think = tokenizer.encode(formatted_think_prompt, return_tensors="pt").to(model.device)
    attention_mask_think = torch.ne(input_ids_think, tokenizer.pad_token_id).to(model.device)

    think_inputs = {
        "input_ids": input_ids_think,
        "attention_mask": attention_mask_think
    }
    gen_kwargs_think = initialize_gen_kwargs()
    gen_kwargs_think["max_new_tokens"] = THINK_MAX_NEW_TOKENS
    think_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_think["streamer"] = think_streamer

    full_think = ""
    with torch.inference_mode():
        thread_think = Thread(target=lambda: model.generate(**think_inputs, **gen_kwargs_think))
        thread_think.start()
        for new_text in think_streamer:
            full_think += new_text
            display_text = f"<|im_start|>think\n{full_think.strip()}"
            chatbot[-1][1] = display_text
            yield chatbot, ""  # second return is artifact placeholder
        thread_think.join()

    # Phase 2: answer
    new_prompt = formatted_think_prompt + full_think.strip() + "\n<|im_start|>answer\n"
    input_ids_answer = tokenizer.encode(new_prompt, return_tensors="pt").to(model.device)
    attention_mask_answer = torch.ne(input_ids_answer, tokenizer.pad_token_id).to(model.device)

    answer_inputs = {
        "input_ids": input_ids_answer,
        "attention_mask": attention_mask_answer
    }
    gen_kwargs_answer = initialize_gen_kwargs()
    gen_kwargs_answer["max_new_tokens"] = ANSWER_MAX_NEW_TOKENS
    answer_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs_answer["streamer"] = answer_streamer

    full_answer = ""
    with torch.inference_mode():
        thread_answer = Thread(target=lambda: model.generate(**answer_inputs, **gen_kwargs_answer))
        thread_answer.start()
        for new_text in answer_streamer:
            full_answer += new_text
            display_text = (
                f"<|im_start|>think\n{full_think.strip()}\n\n"
                f"<|im_start|>answer\n{full_answer.strip()}"
            )
            chatbot[-1][1] = display_text
            yield chatbot, ""
        thread_answer.join()

    log_conversation(chatbot)

    # Once final answer is complete, parse out HTML code block and
    # return it as an artifact (iframe).
    html_code = extract_html_code_block(full_answer)
    sandbox_iframe = send_to_sandbox(html_code)
    yield chatbot, sandbox_iframe

# ----------------------------------------------------------------------
# 7. Logging and Clearing
# ----------------------------------------------------------------------
def log_conversation(chatbot: List[List[str]]):
    logger.info("[CONVERSATION]")
    for i, (query, response) in enumerate(chatbot, 1):
        logger.info(f"Q{i}: {query}\nA{i}: {response}")

def clear_chat():
    return [], "", ""
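# The three empty values map onto [chatbot, text_input, artifact_html] in clear_btn.click.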

# ----------------------------------------------------------------------
# 8. Gradio UI Setup
# ----------------------------------------------------------------------
css_code = """
.left_header {
  display: flex;
  flex-direction: column;
  justify-content: center;
  align-items: center;
}

.right_panel {
  margin-top: 16px;
  border: 1px solid #BFBFC4;
  border-radius: 8px;
  overflow: hidden;
}

.render_header {
  height: 30px;
  width: 100%;
  padding: 5px 16px;
  background-color: #f5f5f5;
}

.header_btn {
  display: inline-block;
  height: 10px;
  width: 10px;
  border-radius: 50%;
  margin-right: 4px;
}

.render_header > .header_btn:nth-child(1) {
  background-color: #f5222d;
}

.render_header > .header_btn:nth-child(2) {
  background-color: #faad14;
}
.render_header > .header_btn:nth-child(3) {
  background-color: #52c41a;
}

.right_content {
  height: 920px; 
  display: flex; 
  flex-direction: column;
  justify-content: center;
  align-items: center;
}

.html_content {
  width: 100%;
  height: 920px;
}
"""

svg_content = """
<svg width="40" height="40" viewBox="0 0 45 45" fill="none" xmlns="http://www.w3.org/2000/svg">
  <circle cx="22.5" cy="22.5" r="22.5" fill="#5572F9"/>
  <path d="M22.5 11.25L26.25 16.875H18.75L22.5 11.25Z" fill="white"/>
  <path d="M22.5 33.75L26.25 28.125H18.75L22.5 33.75Z" fill="white"/>
  <path d="M28.125 22.5L22.5 28.125L16.875 22.5L22.5 16.875L28.125 22.5Z" fill="white"/>
</svg>
"""

with gr.Blocks(title=model_name.split('/')[-1], css=css_code) as demo:
    gr.HTML(f"""
        <div class="left_header" style="margin-bottom: 20px;">
            {svg_content}
            <h1>{model_name.split('/')[-1]} - Chat + Artifacts</h1>
            <p>(Two-phase chain-of-thought with artifact extraction)</p>
        </div>
    """)
    
    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(
                label="Chat",
                height=520,
                show_copy_button=True
            )
            with gr.Row():
                text_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your query...",
                    lines=1
                )
            with gr.Row():
                submit_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")
        with gr.Column(scale=6):
            gr.HTML('<div class="render_header"><span class="header_btn"></span><span class="header_btn"></span><span class="header_btn"></span></div>')
            artifact_html = gr.HTML(
                value="",
                elem_classes="html_content"
            )

    submit_btn.click(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(
        ovis_chat, [chatbot], [chatbot, artifact_html]
    )

    text_input.submit(
        submit_chat, [chatbot, text_input], [chatbot, text_input]
    ).then(
        ovis_chat, [chatbot], [chatbot, artifact_html]
    )

    clear_btn.click(
        clear_chat,
        outputs=[chatbot, text_input, artifact_html]
    )

demo.queue(default_concurrency_limit=1).launch(server_name="0.0.0.0", share=True)