Zoro-chi committed on
Commit b26638e · 1 Parent(s): 6c28b5c

Switch to google/flan-t5-base and improve prompt expansion for better quality

Files changed (2):
  1. app.py +4 -4
  2. app/llm/model.py +46 -39
app.py CHANGED
@@ -28,15 +28,15 @@ os.environ["HF_SPACES"] = "1"  # Flag to indicate we're running in Spaces
 
 # Set model environment variables explicitly for Hugging Face Spaces
 # These will override any variables loaded from .env.spaces
-os.environ["MODEL_ID"] = (
-    "distilgpt2"  # Use DistilGPT2 model which is publicly available
-)
+os.environ["MODEL_ID"] = "google/flan-t5-base"  # Use flan-t5-base model
 os.environ["USE_LOCAL_MODEL"] = "true"
 os.environ["MODEL_TYPE"] = "transformers"
 os.environ["MODEL_QUANTIZED"] = (
     "false"  # Disable quantization to avoid bitsandbytes dependency
 )
-os.environ["MODEL_ARCHITECTURE"] = "causal"  # GPT2 is a causal language model
+os.environ["MODEL_ARCHITECTURE"] = (
+    "seq2seq"  # T5 models are sequence-to-sequence models
+)
 
 # Import UI module directly
 try:
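The loader that consumes these variables lives elsewhere in the repo and is not shown in this commit. As a rough sketch of why MODEL_ARCHITECTURE has to change along with MODEL_ID (load_pipeline is a hypothetical helper, not the repo's actual function): T5-family models are encoder-decoder and are served through the transformers text2text-generation task, while GPT-2-style decoder-only models use text-generation.

# Hypothetical sketch of how MODEL_ID / MODEL_ARCHITECTURE are presumably
# consumed downstream; not code from this commit.
import os

from transformers import pipeline


def load_pipeline():  # hypothetical helper name
    model_id = os.environ["MODEL_ID"]  # e.g. "google/flan-t5-base"
    architecture = os.environ.get("MODEL_ARCHITECTURE", "causal")

    # Encoder-decoder (seq2seq) models like T5 need the text2text-generation
    # task; decoder-only models like GPT-2 need text-generation.
    task = "text2text-generation" if architecture == "seq2seq" else "text-generation"
    return pipeline(task, model=model_id)

Loading a T5 checkpoint under the causal-LM head would typically fail outright, which is why the two variables are switched together here.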
app/llm/model.py CHANGED
@@ -455,54 +455,61 @@ class LocalLLM:
         """
         # For seq2seq models like T5, use a format that works better with their training
         if self.model_architecture == "seq2seq":
-            # Special formatting for T5 models which work better with task-specific prefixes
-            if "flan" in self.model_path.lower():
-                task_prompt = (
-                    f"Enhance this image prompt with artistic details: {prompt}"
-                )
-
-            # Generate with higher max tokens for T5 models
+            # Special handling for FLAN-T5 models
+            if "flan-t5" in self.model_path.lower():
                 try:
                     logger.info(
-                        f"Using seq2seq format for prompt expansion with model: {self.model_path}"
-                    )
-                    expanded = self.generate(
-                        prompt=task_prompt,
-                        system_prompt=None,  # T5 doesn't use system prompts the same way
-                        max_tokens=512,
-                        temperature=0.9,  # Higher temperature for more creative outputs
+                        f"Using optimized T5 format for prompt expansion with {self.model_path}"
                     )
 
-                    # If the model returns the input, try a different approach
-                    if (
-                        expanded.strip() == task_prompt.strip()
-                        or expanded.strip() == prompt.strip()
-                    ):
+                    # Try different instruction formats that work well with FLAN-T5
+                    prompts_to_try = [
+                        f"Create a detailed, professional-quality image description for: {prompt}",
+                        f"Turn this simple prompt into a detailed, vivid scene description: {prompt}",
+                        f"Enhance this image prompt with artistic details, lighting, and style: {prompt}",
+                    ]
+
+                    # Try each prompt format until we get a good result
+                    for task_prompt in prompts_to_try:
                         expanded = self.pipe(
-                            f"Generate a detailed visual description of: {prompt}",
-                            max_length=256,
+                            task_prompt,
+                            max_length=150,  # Allow longer expansions
                             do_sample=True,
-                            temperature=0.9,
+                            temperature=0.8,  # Slightly more focused than previous attempts
                             top_p=0.92,
-                        )[0]["generated_text"]
+                            repetition_penalty=1.2,  # Discourage repetition
+                        )[0]["generated_text"].strip()
+
+                        # Check if the result is good
+                        if expanded and len(expanded) > len(prompt) + 10:
+                            logger.info(f"Expanded prompt: {expanded[:100]}...")
+
+                            # For longer generations, check if we need to clean it up
+                            if len(expanded) > 200:
+                                sentences = expanded.split(".")
+                                # Keep first 3-4 meaningful sentences
+                                meaningful_sentences = [
+                                    s for s in sentences if len(s.strip()) > 5
+                                ][:4]
+                                expanded = ". ".join(meaningful_sentences)
+                                if not expanded.endswith("."):
+                                    expanded += "."
+
+                            return expanded
+
+                    # If all attempts failed, use a template-based expansion
+                    fallback = f"{prompt}, high resolution, professional photography, detailed, vivid colors, dramatic lighting"
+                    logger.info(f"Using template fallback: {fallback}")
+                    return fallback
 
-                    logger.info(f"Expanded prompt: {expanded[:100]}...")
-                    return expanded
                 except Exception as e:
-                    logger.error(f"Error expanding prompt with T5: {str(e)}")
-                    # Fall back to original prompt with some basic additions
-                    adjectives = [
-                        "vibrant",
-                        "detailed",
-                        "high-quality",
-                        "stunning",
-                        "professional",
-                    ]
-                    import random
-
-                    enhanced = f"{random.choice(adjectives)} {prompt}, {random.choice(adjectives)} artwork, highly detailed"
-                    logger.info(f"Fallback prompt expansion: {enhanced}")
-                    return enhanced
+                    logger.error(f"Error expanding prompt with FLAN-T5: {str(e)}")
+                    # Fall back to original prompt with enhancements
+                    fallback = (
+                        f"{prompt}, high quality, detailed, 4k, professional, artistic"
+                    )
+                    logger.info(f"Using error fallback: {fallback}")
+                    return fallback
 
         # Standard approach for causal LMs like GPT-2 or Llama
         system_prompt = """You are a creative assistant specializing in enhancing text prompts for image and 3D model generation.