Switch to google/flan-t5-base and improve prompt expansion for better quality

Files changed:
- app.py (+4 -4)
- app/llm/model.py (+46 -39)
app.py
@@ -28,15 +28,15 @@ os.environ["HF_SPACES"] = "1"  # Flag to indicate we're running in Spaces
 
 # Set model environment variables explicitly for Hugging Face Spaces
 # These will override any variables loaded from .env.spaces
-os.environ["MODEL_ID"] = (
-    "distilgpt2"  # Use DistilGPT2 model which is publicly available
-)
+os.environ["MODEL_ID"] = "google/flan-t5-base"  # Use flan-t5-base model
 os.environ["USE_LOCAL_MODEL"] = "true"
 os.environ["MODEL_TYPE"] = "transformers"
 os.environ["MODEL_QUANTIZED"] = (
     "false"  # Disable quantization to avoid bitsandbytes dependency
 )
-os.environ["MODEL_ARCHITECTURE"] = …
+os.environ["MODEL_ARCHITECTURE"] = (
+    "seq2seq"  # T5 models are sequence-to-sequence models
+)
 
 # Import UI module directly
 try:
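This hunk only sets environment variables; the loader that consumes them lives in app/llm/model.py and is not part of this diff. As a rough sketch of how MODEL_ID and MODEL_ARCHITECTURE might drive model loading, assuming the standard transformers pipeline API (the function name and defaults below are illustrative, not code from this repository):

# Hypothetical sketch of a loader driven by the env vars set above.
# The real loader in app/llm/model.py is not shown in this diff.
import os
from transformers import pipeline

def load_pipeline():
    model_id = os.environ.get("MODEL_ID", "google/flan-t5-base")
    architecture = os.environ.get("MODEL_ARCHITECTURE", "seq2seq")
    # Seq2seq checkpoints (the T5 family) use the text2text-generation task;
    # causal models such as the previous distilgpt2 use plain text-generation.
    task = "text2text-generation" if architecture == "seq2seq" else "text-generation"
    return pipeline(task, model=model_id)

pipe = load_pipeline()
print(pipe("Enhance this image prompt: a cat on a hill")[0]["generated_text"])

The substantive part of the switch is the task name: loading a T5 checkpoint under text-generation would fail, which is why the commit adds MODEL_ARCHITECTURE alongside MODEL_ID.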
app/llm/model.py
@@ -455,54 +455,61 @@ class LocalLLM:
         """
         # For seq2seq models like T5, use a format that works better with their training
         if self.model_architecture == "seq2seq":
-            # Special …
-            if "flan" in self.model_path.lower():
-                task_prompt = (
-                    f"Enhance this image prompt with artistic details: {prompt}"
-                )
-
-            # Generate with higher max tokens for T5 models
+            # Special handling for FLAN-T5 models
+            if "flan-t5" in self.model_path.lower():
                 try:
                     logger.info(
-                        f"Using …
+                        f"Using optimized T5 format for prompt expansion with {self.model_path}"
                     )
-                    expanded = self.generate(
-                        prompt=task_prompt,
-                        system_prompt=None,  # T5 doesn't use system prompts the same way
-                        max_tokens=512,
-                        temperature=0.9,  # Higher temperature for more creative outputs
-                    )
 
-                    # …
+                    # Try different instruction formats that work well with FLAN-T5
+                    prompts_to_try = [
+                        f"Create a detailed, professional-quality image description for: {prompt}",
+                        f"Turn this simple prompt into a detailed, vivid scene description: {prompt}",
+                        f"Enhance this image prompt with artistic details, lighting, and style: {prompt}",
+                    ]
+
+                    # Try each prompt format until we get a good result
+                    for task_prompt in prompts_to_try:
                         expanded = self.pipe(
-                            …
-                            max_length=…
+                            task_prompt,
+                            max_length=150,  # Allow longer expansions
                             do_sample=True,
-                            temperature=0.…
+                            temperature=0.8,  # Slightly more focused than previous attempts
                             top_p=0.92,
-                            …
+                            repetition_penalty=1.2,  # Discourage repetition
+                        )[0]["generated_text"].strip()
 
-                    logger.info(f"Expanded prompt: {expanded[:100]}...")
-                    return expanded
+                        # Check if the result is good
+                        if expanded and len(expanded) > len(prompt) + 10:
+                            logger.info(f"Expanded prompt: {expanded[:100]}...")
+
+                            # For longer generations, check if we need to clean it up
+                            if len(expanded) > 200:
+                                sentences = expanded.split(".")
+                                # Keep first 3-4 meaningful sentences
+                                meaningful_sentences = [
+                                    s for s in sentences if len(s.strip()) > 5
+                                ][:4]
+                                expanded = ". ".join(meaningful_sentences)
+                                if not expanded.endswith("."):
+                                    expanded += "."
+
+                            return expanded
+
+                    # If all attempts failed, use a template-based expansion
+                    fallback = f"{prompt}, high resolution, professional photography, detailed, vivid colors, dramatic lighting"
+                    logger.info(f"Using template fallback: {fallback}")
+                    return fallback
 
                 except Exception as e:
-                    logger.error(f"Error expanding prompt with T5: {str(e)}")
-                    # Fall back to original prompt with …
-                    adjectives = [
-                        "…
-                        "professional",
-                    ]
-                    import random
-
-                    enhanced = f"{random.choice(adjectives)} {prompt}, {random.choice(adjectives)} artwork, highly detailed"
-                    logger.info(f"Fallback prompt expansion: {enhanced}")
-                    return enhanced
+                    logger.error(f"Error expanding prompt with FLAN-T5: {str(e)}")
+                    # Fall back to original prompt with enhancements
+                    fallback = (
+                        f"{prompt}, high quality, detailed, 4k, professional, artistic"
+                    )
+                    logger.info(f"Using error fallback: {fallback}")
+                    return fallback
 
         # Standard approach for causal LMs like GPT-2 or Llama
         system_prompt = """You are a creative assistant specializing in enhancing text prompts for image and 3D model generation.
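For reference, the retry-and-fallback expansion introduced in this hunk can be exercised outside the class. The sketch below is a minimal, self-contained approximation of the logic above, assuming a transformers text2text-generation pipeline; the free function expand_prompt is illustrative and not part of the repository:

# Standalone sketch of the diff's retry-with-fallback prompt expansion.
from transformers import pipeline

pipe = pipeline("text2text-generation", model="google/flan-t5-base")

def expand_prompt(prompt: str) -> str:
    # Instruction formats mirrored from the diff; tried in order.
    prompts_to_try = [
        f"Create a detailed, professional-quality image description for: {prompt}",
        f"Turn this simple prompt into a detailed, vivid scene description: {prompt}",
        f"Enhance this image prompt with artistic details, lighting, and style: {prompt}",
    ]
    for task_prompt in prompts_to_try:
        expanded = pipe(
            task_prompt,
            max_length=150,
            do_sample=True,
            temperature=0.8,
            top_p=0.92,
            repetition_penalty=1.2,
        )[0]["generated_text"].strip()
        # Accept only results meaningfully longer than the input prompt.
        if expanded and len(expanded) > len(prompt) + 10:
            # Trim rambling outputs to the first few meaningful sentences.
            if len(expanded) > 200:
                sentences = [s for s in expanded.split(".") if len(s.strip()) > 5][:4]
                expanded = ". ".join(sentences)
                if not expanded.endswith("."):
                    expanded += "."
            return expanded
    # Template fallback when no format produced a usable expansion.
    return f"{prompt}, high resolution, professional photography, detailed, vivid colors, dramatic lighting"

print(expand_prompt("a castle at sunset"))

The length check is what makes cycling through several instruction formats worthwhile: FLAN-T5 sometimes echoes the input or answers with a fragment, and a result barely longer than the prompt is treated as a failed attempt rather than returned as an "expansion".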