import subprocess
import os
import re
import logging
from typing import List
from threading import Thread
import base64

# --- Install Dependencies ---
print("Attempting to install Flash Attention 2...")
_flash_attn_2_available = False
try:
    # Note: On A100, building flash-attn usually works without skipping the CUDA build
    subprocess.run(
        'pip install flash-attn==2.5.8 --no-build-isolation',
        shell=True, check=True
    )
    print("Flash Attention installed successfully.")
    _flash_attn_2_available = True
except Exception as e:
    print(f"Flash Attention installation failed: {e}. Will attempt fallback.")
    try:
        print("Retrying Flash Attention install without --no-build-isolation...")
        subprocess.run('pip install flash-attn==2.5.8', shell=True, check=True)
        print("Flash Attention installed successfully on retry.")
        _flash_attn_2_available = True
    except Exception as e2:
        print(f"Flash Attention installation failed on retry: {e2}")
        print("Proceeding without Flash Attention 2. Performance may be impacted.")
        _flash_attn_2_available = False

print("Installing transformers, accelerate, and bitsandbytes...")
subprocess.run(
    # bitsandbytes is needed for 8-bit quantization; the version specifiers are quoted
    # so the shell does not treat '>' as a redirection.
    'pip install --upgrade "transformers>=4.40.0" "accelerate>=0.25.0" "bitsandbytes>=0.41.1"',
    shell=True, check=True
)
print("transformers, accelerate, and bitsandbytes installed successfully.")

print("Installing Gradio...")
subprocess.run('pip install --upgrade gradio', shell=True, check=True)
print("Gradio installed successfully.")

# --- Import AFTER potential installs ---
import spaces
import torch
import gradio as gr
# BitsAndBytesConfig is needed for 8-bit quantization
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
# ----------------------------------------------------------------------
# 1. Setup Model & Tokenizer (Using 8-bit Quantization from Script 2)
# ----------------------------------------------------------------------
model_name = 'Tesslate/UIGEN-T3-4B-Preview'  # Or use Tesslate/UIGEN-T2-7B-3600 if preferred
use_thread = True  # For streaming UI updates

# Determine the optimal attention implementation
attn_implementation = "flash_attention_2" if _flash_attn_2_available else "sdpa"
print(f"Attempting to use attention implementation: {attn_implementation}")

# --- 8-bit Quantization Config (from Script 2) ---
print(f"Loading Model: {model_name} with 8-bit quantization")
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# --- Model Loading (from Script 2) ---
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # token=hf_token,  # Usually not needed for public models on Spaces
        device_map="auto",                        # Let accelerate place the weights
        quantization_config=quantization_config,  # Apply the 8-bit config
        attn_implementation=attn_implementation,
        trust_remote_code=True
    )
    print(f"Model loaded successfully with 8-bit quantization and {attn_implementation} attention.")
except Exception as e:
    print(f"Error loading model with {attn_implementation}: {e}")
    if attn_implementation == "flash_attention_2":
        print("Falling back to 'sdpa' attention implementation...")
        attn_implementation = "sdpa"
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                # token=hf_token,
                device_map="auto",
                quantization_config=quantization_config,
                attn_implementation=attn_implementation,
                trust_remote_code=True
            )
            print("Model loaded successfully with 8-bit quantization and SDPA attention.")
        except Exception as e2:
            print(f"Fallback to SDPA attention also failed: {e2}")
            # Try without an explicit attention implementation as a last resort
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    # token=hf_token,
                    device_map="auto",
                    quantization_config=quantization_config,
                    trust_remote_code=True
                )
                print("Model loaded successfully with 8-bit quantization (default attention).")
            except Exception as e3:
                raise RuntimeError(f"Failed to load model: {e3}") from e3
    else:
        raise RuntimeError(f"Failed to load model: {e}") from e

print(f"Loading tokenizer: {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)  # token=hf_token removed
print("Tokenizer loaded.")

# Set pad token if it's not set (important for generation)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

# System Prompt
SYSTEM_PROMPT = "you are tesslate, a UI engine"

# Reduce logging verbosity from noisy libraries
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ----------------------------------------------------------------------
# 2. Generation Parameter Setup (Values can be adjusted)
# ----------------------------------------------------------------------
# Similar defaults to Script 2; tune as needed.
MAX_NEW_TOKENS = 20000       # Can be increased if needed
TEMPERATURE = 0.6
TOP_P = 0.95                 # Slightly lower than the Script 2 default
TOP_K = 20                   # Added based on Script 2 defaults
REPETITION_PENALTY = 1.0     # Based on Script 2 defaults
# Note: the initialize_gen_kwargs function was removed; the parameters are used directly now.
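
# Optional smoke test (illustrative sketch; not required for the app to run): uncomment to
# verify end-to-end generation with the 8-bit model before launching the UI.
# _probe = tokenizer("Hello", return_tensors="pt").to(model.device)
# print(tokenizer.decode(model.generate(**_probe, max_new_tokens=8)[0], skip_special_tokens=True))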
# ----------------------------------------------------------------------
# 3. Helper to submit chat (Unchanged)
# ----------------------------------------------------------------------
def submit_chat(chatbot, text_input):
    if not text_input.strip():
        gr.Warning("Please enter a prompt.")
        return chatbot, ""
    response = ""
    chatbot.append([text_input, response])
    return chatbot, ""

# ----------------------------------------------------------------------
# 4. Artifacts Handling (Unchanged)
# ----------------------------------------------------------------------
def extract_html_code_block(text: str) -> str:
    pattern = r'```html\s*(.*?)\s*```'
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()
    else:
        # Heuristic fallback: treat the whole response as HTML if it looks like a raw document
        trimmed_text = text.strip()
        if trimmed_text.startswith(('<!DOCTYPE', '<html')) and len(trimmed_text) > 5:
            return trimmed_text
        return ""

def send_to_sandbox(html_code: str) -> str:
    if not html_code:
        return "<div style='padding: 20px; text-align: center; color: #888;'>No HTML content generated or extracted.</div>"
    try:
        encoded_html = base64.b64encode(html_code.encode('utf-8')).decode('utf-8')
        data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
        return f'<iframe src="{data_uri}" style="width: 100%; height: 600px; border: none;" sandbox="allow-scripts"></iframe>'
    except Exception as e:
        logger.error(f"Error encoding HTML for sandbox: {e}")
        escaped_code = html_code.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
        return (
            "<div style='padding: 20px; color: #b00020;'>"
            "<p>Error displaying HTML preview:</p>"
            f"<pre>{e}</pre>"
            f"<pre>{escaped_code[:500]}...</pre>"
            "</div>"
        )
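
# Note (illustrative): send_to_sandbox delivers the preview as a base64 "data:" URI inside
# an <iframe>, so the generated page renders in isolation from the Gradio app; roughly,
#   send_to_sandbox("<html><body>Hello</body></html>")
# yields '<iframe src="data:text/html;charset=utf-8;base64,..."></iframe>'.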
" # ---------------------------------------------------------------------- # 5. Single-Phase Streaming Inference (Using Script 2's generation logic) # ---------------------------------------------------------------------- @spaces.GPU() def single_pass_chat(chatbot: List[List[str]]): if not chatbot or not chatbot[-1][0]: yield chatbot, "

Enter a prompt to start.

" return last_query = chatbot[-1][0] # Use the SYSTEM_PROMPT defined globally messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": last_query} ] try: # --- Prompt Formatting (from Script 2) --- try: # Get the full prompt string formatted by the tokenizer full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) except Exception as e: logger.warning(f"Using fallback prompt format due to error: {e}") prompt_parts = [f"System: {SYSTEM_PROMPT}"] if SYSTEM_PROMPT else [] prompt_parts.append(f"\nUser: {last_query}\nAssistant:") full_prompt = "\n".join(prompt_parts) # --- Tokenization (from Script 2) --- # Tokenize the formatted string and move to device inputs = tokenizer( full_prompt, return_tensors="pt", truncation=True, # Add truncation max_length=4096 # Set max length based on model limits / desired context ).to(model.device) # Use model.device set by device_map="auto" logger.info(f"Input tensors moved to device: {model.device}") # --- Streamer Setup (Similar in both scripts) --- streamer = TextIteratorStreamer( tokenizer, timeout=15.0, skip_prompt=True, skip_special_tokens=True ) # --- Generation Kwargs (from Script 2, using constants defined above) --- # Note: 'inputs' are passed separately in model.generate call generation_kwargs = dict( # Removed 'inputs' dict here, will pass separately streamer=streamer, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE if TEMPERATURE > 0 else None, top_p=TOP_P, top_k=TOP_K, repetition_penalty=REPETITION_PENALTY, do_sample=True if TEMPERATURE > 0 else False, pad_token_id=tokenizer.pad_token_id, # Ensure pad_token_id is set eos_token_id=tokenizer.eos_token_id ) # Adjust for deterministic generation (temp=0) if TEMPERATURE == 0: generation_kwargs.pop('top_p', None) generation_kwargs.pop('top_k', None) generation_kwargs['do_sample'] = False # --- Generation Thread --- full_response = "" # Pass inputs explicitly along with generation_kwargs thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs}) thread.start() # --- Streaming Output --- logger.info("Starting generation stream...") artifact_html_content = "

Generating UI...

" for new_text in streamer: full_response += new_text chatbot[-1][1] = full_response.strip() yield chatbot, artifact_html_content thread.join() logger.info("Generation stream finished.") chatbot[-1][1] = full_response.strip() log_conversation(chatbot) html_code = extract_html_code_block(full_response) sandbox_iframe = send_to_sandbox(html_code) yield chatbot, sandbox_iframe except torch.cuda.OutOfMemoryError as e: logger.error(f"CUDA OutOfMemoryError during generation: {e}", exc_info=True) torch.cuda.empty_cache() chatbot[-1][1] = f"Error: Ran out of GPU memory. Try a shorter prompt or clear the chat." yield chatbot, "

Generation failed due to OOM Error.

" except Exception as e: logger.error(f"Error during generation: {e}", exc_info=True) if 'cuda' in str(e).lower(): torch.cuda.empty_cache() chatbot[-1][1] = f"An error occurred during generation: {str(e)}" yield chatbot, f"

Generation failed: {str(e)}
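
# Note: single_pass_chat is a generator; Gradio streams each yielded
# (chatbot, artifact_html) pair to the UI, which is what produces the
# incremental "typing" effect while the model is still generating.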

" # ---------------------------------------------------------------------- # 6. Logging and Clearing (Unchanged) # ---------------------------------------------------------------------- def log_conversation(chatbot: List[List[str]]): logger.info("[CONVERSATION TURN]") if not chatbot: logger.info("Conversation is empty."); return i = len(chatbot) query, response = chatbot[-1] logger.info(f"Q{i}: {query}") logger.info(f"A{i} (Preview): {response[:500]}...") def clear_chat(): logger.info("Chat cleared.") torch.cuda.empty_cache() # Clear GPU cache return [], "", "

Chat cleared. Enter a new prompt.

" # ---------------------------------------------------------------------- # 7. Gradio UI Setup (Unchanged) # ---------------------------------------------------------------------- css_code = """ body, .gradio-container { font-family: sans-serif; } .left_header { display: flex; flex-direction: column; justify-content: center; align-items: center; padding: 10px 0; } .right_panel { margin-top: 16px; border: 1px solid #E0E0E0; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 4px rgba(0,0,0,0.05); } .render_header { height: 30px; width: 100%; padding: 5px 16px; background-color: #f0f0f0; border-bottom: 1px solid #E0E0E0; display: flex; align-items: center; } .header_btn { display: inline-block; height: 12px; width: 12px; border-radius: 50%; margin-right: 6px; } .render_header > .header_btn:nth-child(1) { background-color: #ff5f57; } .render_header > .header_btn:nth-child(2) { background-color: #ffbd2e; } .render_header > .header_btn:nth-child(3) { background-color: #28c940; } .right_content { height: 600px; display: flex; flex-direction: column; justify-content: center; align-items: center; background-color: #ffffff; } .html_content { width: 100%; height: 100%; padding: 5px; } .html_content iframe { border: none; } .gr-chatbot { border-radius: 8px; box-shadow: 0 1px 3px rgba(0,0,0,0.05); } footer {display: none !important;} """ svg_content = """ """ model_display_name = model_name.split('/')[-1] + " (8-bit)" # Indicate quantization with gr.Blocks(title=f"{model_display_name} UI Engine", css=css_code, theme=gr.themes.Soft()) as demo: with gr.Row(): with gr.Column(scale=1): pass with gr.Column(scale=3): gr.HTML(f"""
{svg_content}

{model_display_name}

UI Generation Engine (8-bit Quantization)

""") with gr.Column(scale=1): pass with gr.Row(): with gr.Column(scale=4): chatbot = gr.Chatbot( label="Chat History", height=550, show_copy_button=True, bubble_full_width=False ) with gr.Row(): text_input = gr.Textbox( label="Prompt", placeholder="Enter your UI request...", lines=2, scale=5 ) submit_btn = gr.Button("Send", variant="primary", scale=1) with gr.Row(): clear_btn = gr.Button("Clear Chat", variant="secondary") with gr.Column(scale=6): with gr.Group(elem_classes="right_panel"): gr.HTML('
                with gr.Column(elem_classes="right_content"):
                    artifact_html = gr.HTML(
                        value="<div style='padding: 20px; text-align: center; color: #888;'>Generated UI will appear here...</div>",
                        elem_classes="html_content"
                    )

    # --- Event Listeners ---
    submit_btn.click(
        submit_chat, inputs=[chatbot, text_input], outputs=[chatbot, text_input]
    ).then(
        single_pass_chat, inputs=[chatbot], outputs=[chatbot, artifact_html],
        api_name="generate_ui_8bit"  # Renamed API
    )

    text_input.submit(
        submit_chat, inputs=[chatbot, text_input], outputs=[chatbot, text_input]
    ).then(
        single_pass_chat, inputs=[chatbot], outputs=[chatbot, artifact_html],
        api_name="generate_ui_8bit_enter"  # Renamed API
    )

    clear_btn.click(
        clear_chat, outputs=[chatbot, text_input, artifact_html],
        api_name="clear"
    )

# --- Launch the app ---
print("Launching Gradio Interface (8-bit Quantization)...")
# Using queue settings from the Script 2 example
if __name__ == "__main__":
    demo.queue().launch(debug=True, share=False)  # Use share=True for a public link if needed