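"""Local LLM utilities: wraps TinyLlama-1.1B-Chat for text generation and creative prompt expansion."""
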
import logging
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
logger = logging.getLogger(__name__)
# Constants for TinyLlama model
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Default system prompt for creative expansion
DEFAULT_CREATIVE_SYSTEM_PROMPT = """You are an expert assistant that helps expand creative prompts for image generation.
Your goal is to enrich the original prompt with vivid visual details, artistic style suggestions, and composition elements.
Focus on visual enhancement only. Keep your response concise (under 40 words) and focus entirely on the expanded prompt.
Do not include explanations or comments - only return the enhanced prompt text."""
# Default user template for creative expansion
DEFAULT_CREATIVE_USER_TEMPLATE = """Original prompt: {original_prompt}
Please expand this into a detailed, vivid prompt for image generation with rich visual elements, mood, and style."""
class LocalLLM:
"""
A local LLM implementation using TinyLlama-1.1B-Chat.
Provides methods to generate text and expand creative prompts.
"""
def __init__(self, model_id: str = MODEL_ID):
"""
Initialize the local LLM.
Args:
model_id: The model ID to load
"""
self.model_id = model_id
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Configure quantization for efficient memory usage
quantization_config = None
if self.device == "cuda":
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
)
# Load model and tokenizer
logger.info(f"Loading TinyLlama model on {self.device}...")
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_id,
quantization_config=quantization_config,
device_map="auto" if self.device == "cuda" else None,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
)
logger.info(f"TinyLlama model loaded successfully")
def generate(
self,
prompt: str,
system_prompt: Optional[str] = None,
max_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.9,
) -> str:
"""
Generate text based on a prompt.
Args:
prompt: The user prompt to generate from
system_prompt: Optional system prompt to guide the generation
max_tokens: Maximum number of tokens to generate
temperature: Sampling temperature (higher = more creative)
top_p: Top-p sampling parameter
Returns:
The generated text
"""
messages = []
# Add system message if provided
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
# Add user message
messages.append({"role": "user", "content": prompt})
# Format messages for the model
prompt_text = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Generate response
        inputs = self.tokenizer(prompt_text, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=self.tokenizer.eos_token_id,  # avoid pad-token warnings for models without a pad token
        )
        # Decode only the newly generated tokens so the prompt is not included.
        # Slicing the decoded string by len(prompt_text) is unreliable because
        # special tokens in the chat template are stripped during decoding.
        prompt_length = inputs["input_ids"].shape[1]
        assistant_response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
        return assistant_response
def expand_creative_prompt(self, original_prompt: str) -> str:
"""
Expand a creative prompt with rich details for better image generation.
Args:
original_prompt: The user's original prompt
Returns:
An expanded, detailed creative prompt
"""
# Format the prompt for creative expansion
prompt = DEFAULT_CREATIVE_USER_TEMPLATE.format(original_prompt=original_prompt)
# Generate expanded prompt
expanded = self.generate(
prompt=prompt,
system_prompt=DEFAULT_CREATIVE_SYSTEM_PROMPT,
max_tokens=150, # Limit to ensure concise responses
temperature=0.7, # Balanced creativity
)
return expanded
def get_llm_instance() -> LocalLLM:
"""
Get an instance of the local LLM.
Returns:
An initialized LocalLLM instance
"""
return LocalLLM(model_id=MODEL_ID)
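

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # loads the model and expands one sample prompt. Running this downloads the
    # TinyLlama weights from the Hugging Face Hub and requires the transformers
    # and accelerate packages (plus bitsandbytes when CUDA is available).
    logging.basicConfig(level=logging.INFO)
    llm = get_llm_instance()
    print(llm.expand_creative_prompt("a lighthouse on a cliff at dusk"))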