import logging

import torch
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logger = logging.getLogger(__name__)
# Constants for the TinyLlama model
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Default system prompt for creative expansion
DEFAULT_CREATIVE_SYSTEM_PROMPT = """You are an expert assistant that helps expand creative prompts for image generation.
Your goal is to enrich the original prompt with vivid visual details, artistic style suggestions, and composition elements.
Focus on visual enhancement only. Keep your response concise (under 40 words) and focus entirely on the expanded prompt.
Do not include explanations or comments - only return the enhanced prompt text."""

# Default user template for creative expansion
DEFAULT_CREATIVE_USER_TEMPLATE = """Original prompt: {original_prompt}
Please expand this into a detailed, vivid prompt for image generation with rich visual elements, mood, and style."""
class LocalLLM:
    """
    A local LLM implementation using TinyLlama-1.1B-Chat.

    Provides methods to generate text and expand creative prompts.
    """

    def __init__(self, model_id: str = MODEL_ID):
        """
        Initialize the local LLM.

        Args:
            model_id: The model ID to load
        """
        self.model_id = model_id
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Configure 4-bit quantization for efficient GPU memory usage
        quantization_config = None
        if self.device == "cuda":
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
            )

        # Load model and tokenizer
        logger.info(f"Loading TinyLlama model on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=quantization_config,
            device_map="auto" if self.device == "cuda" else None,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )
        logger.info("TinyLlama model loaded successfully")
    def generate(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        max_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
    ) -> str:
        """
        Generate text based on a prompt.

        Args:
            prompt: The user prompt to generate from
            system_prompt: Optional system prompt to guide the generation
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (higher = more creative)
            top_p: Top-p (nucleus) sampling parameter

        Returns:
            The generated text
        """
        messages = []

        # Add system message if provided
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})

        # Add user message
        messages.append({"role": "user", "content": prompt})

        # Format messages with the model's chat template
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Generate response
        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )

        # Decode only the newly generated tokens; slicing the decoded string by
        # len(prompt_text) is unreliable because skip_special_tokens changes its length.
        generated_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        assistant_response = self.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        ).strip()
        return assistant_response
    def expand_creative_prompt(self, original_prompt: str) -> str:
        """
        Expand a creative prompt with rich details for better image generation.

        Args:
            original_prompt: The user's original prompt

        Returns:
            An expanded, detailed creative prompt
        """
        # Format the prompt for creative expansion
        prompt = DEFAULT_CREATIVE_USER_TEMPLATE.format(original_prompt=original_prompt)

        # Generate the expanded prompt
        expanded = self.generate(
            prompt=prompt,
            system_prompt=DEFAULT_CREATIVE_SYSTEM_PROMPT,
            max_tokens=150,  # Limit to keep responses concise
            temperature=0.7,  # Balanced creativity
        )
        return expanded
def get_llm_instance() -> LocalLLM:
    """
    Get an instance of the local LLM.

    Returns:
        An initialized LocalLLM instance
    """
    return LocalLLM(model_id=MODEL_ID)
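

# --- Usage sketch (illustrative only, not part of the module API) ---
# A minimal example of how this module might be driven from a script or a
# Space entry point, assuming enough memory is available to load
# TinyLlama-1.1B. The prompt strings below are hypothetical placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    llm = get_llm_instance()

    # Expand a short user prompt into a richer image-generation prompt.
    expanded = llm.expand_creative_prompt("a lighthouse at dusk")
    print("Expanded prompt:", expanded)

    # Direct generation with a custom system prompt.
    reply = llm.generate(
        prompt="Describe a color palette for a misty forest scene.",
        system_prompt="You are a concise art direction assistant.",
        max_tokens=100,
    )
    print("Generated:", reply)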