import logging
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logger = logging.getLogger(__name__)

# Constants for TinyLlama model
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Default system prompt for creative expansion
DEFAULT_CREATIVE_SYSTEM_PROMPT = """You are an expert assistant that helps expand creative prompts for image generation.
Your goal is to enrich the original prompt with vivid visual details, artistic style suggestions, and composition elements.
Focus on visual enhancement only. Keep your response concise (under 40 words) and focus entirely on the expanded prompt.
Do not include explanations or comments - only return the enhanced prompt text."""

# Default user template for creative expansion
DEFAULT_CREATIVE_USER_TEMPLATE = """Original prompt: {original_prompt}

Please expand this into a detailed, vivid prompt for image generation with rich visual elements, mood, and style."""


class LocalLLM:
    """
    A local LLM implementation using TinyLlama-1.1B-Chat.

    Provides methods to generate text and expand creative prompts.
    """

    def __init__(self, model_id: str = MODEL_ID):
        """
        Initialize the local LLM.

        Args:
            model_id: The model ID to load
        """
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Configure 4-bit quantization for efficient GPU memory usage
        quantization_config = None
        if self.device == "cuda":
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )

        # Load model and tokenizer
        logger.info(f"Loading TinyLlama model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=quantization_config,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        logger.info("TinyLlama model loaded successfully")

    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate text based on a prompt.

        Args:
            prompt: The user prompt to generate from
            system_prompt: Optional system prompt to guide the generation
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more creative)
            top_p: Top-p sampling parameter

        Returns:
            The generated text
        """
        messages = []

        # Add system message if provided
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add user message
        messages.append({"role": "user", "content": prompt})

        # Format messages with the model's chat template
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate response
        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )

        # Decode only the newly generated tokens so the prompt text (and any
        # special tokens stripped during decoding) is excluded from the response
        input_length = inputs["input_ids"].shape[1]
        assistant_response = self.tokenizer.decode(
            outputs[0][input_length:], skip_special_tokens=True
        ).strip()

        return assistant_response

    def expand_creative_prompt(self, original_prompt: str) -> str:
        """
        Expand a creative prompt with rich details for better image generation.

        Args:
            original_prompt: The user's original prompt

        Returns:
            An expanded, detailed creative prompt
        """
        # Format the prompt for creative expansion
        prompt = DEFAULT_CREATIVE_USER_TEMPLATE.format(original_prompt=original_prompt)

        # Generate expanded prompt
        expanded = self.generate(
            prompt=prompt,
            system_prompt=DEFAULT_CREATIVE_SYSTEM_PROMPT,
            max_tokens=150,  # Limit to ensure concise responses
            temperature=0.7,  # Balanced creativity
        )

        return expanded


def get_llm_instance() -> LocalLLM:
    """
    Get an instance of the local LLM.

    Returns:
        An initialized LocalLLM instance
    """
    return LocalLLM(model_id=MODEL_ID)
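

# Minimal usage sketch (assumption: run as a script on a machine with either a
# CUDA-capable GPU or enough RAM to hold the 1.1B-parameter model; the example
# prompt below is illustrative only, not part of the original module).
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    llm = get_llm_instance()
    expanded = llm.expand_creative_prompt("a lighthouse on a cliff at dusk")
    print(expanded)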