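"""Gradio app that generates images from text prompts with the OneDiffusion pipeline
and keeps a local gallery of the generated files."""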
import spaces
import gradio as gr
import numpy as np
import random
import torch
from diffusers import (
    DiffusionPipeline, StableDiffusion3Pipeline, FluxPipeline, PixArtSigmaPipeline,
    AuraFlowPipeline, Kandinsky3Pipeline, HunyuanDiTPipeline,
    LuminaText2ImgPipeline, SanaPipeline, AutoPipelineForText2Image
)
import gc
import os
import psutil
import threading
from pathlib import Path
import shutil
import time
import glob
from datetime import datetime
from PIL import Image
from onediffusion.diffusion.pipelines.onediffusion import OneDiffusionPipeline
from onediffusion.models.denoiser.nextdit import NextDiT
from onediffusion.dataset.utils import get_closest_ratio, ASPECT_RATIO_512
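# NextDiT, get_closest_ratio and ASPECT_RATIO_512 come from the OneDiffusion package
# but are not referenced below; only OneDiffusionPipeline is used in this app.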
#import os
#cache_dir = '/workspace/hf_cache'
# Constants
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
OUTPUT_DIR = "generated_images"
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Model configurations
MODEL_CONFIGS = {
    "OneDiffusion": {
        "repo_id": "lehduong/OneDiffusion",
        "pipeline_class": OneDiffusionPipeline,
        # "cache_dir": cache_dir
    }
}
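# Additional models can be compared by adding entries to MODEL_CONFIGS above, e.g.
# (repo id shown for illustration only):
#     "FLUX.1-dev": {"repo_id": "black-forest-labs/FLUX.1-dev", "pipeline_class": FluxPipeline},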
# Dictionary to store model pipelines
pipes = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}
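# Each model gets its own lock so concurrent requests cannot load or run the same pipeline at once.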
def get_process_memory():
    """Get memory usage of the current process in GB."""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024
def clear_torch_cache():
    """Clear PyTorch's CUDA cache."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def remove_cache_dir(model_name):
    """Remove the model's cache directory."""
    cache_dir = (
        Path.home() / '.cache' / 'huggingface' / 'diffusers'
        / MODEL_CONFIGS[model_name]['repo_id'].replace('/', '--')
    )
    if cache_dir.exists():
        shutil.rmtree(cache_dir, ignore_errors=True)
def deep_cleanup(model_name, pipe):
    """Perform deep cleanup of model resources."""
    try:
        # 1. Move model to CPU first (helps prevent CUDA memory fragmentation)
        if hasattr(pipe, 'to'):
            pipe.to('cpu')
        # 2. Delete all model components explicitly
        for attr_name in list(pipe.__dict__.keys()):
            if hasattr(pipe, attr_name):
                delattr(pipe, attr_name)
        # 3. Remove from pipes dictionary
        if model_name in pipes:
            del pipes[model_name]
        # 4. Clear CUDA cache
        clear_torch_cache()
        # 5. Run garbage collection multiple times
        for _ in range(3):
            gc.collect()
        # 6. Remove cached files
        remove_cache_dir(model_name)
        # 7. Additional CUDA cleanup if available
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # 8. Wait a small amount of time to ensure cleanup
        time.sleep(1)
    except Exception as e:
        print(f"Error during cleanup of {model_name}: {str(e)}")
    finally:
        # Final garbage collection
        gc.collect()
        clear_torch_cache()
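# deep_cleanup is currently only invoked from the error path in generate_image;
# the per-generation cleanup call further below is commented out.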
def load_pipeline(model_name):
    """Load a model pipeline with memory tracking."""
    initial_memory = get_process_memory()
    config = MODEL_CONFIGS[model_name]

    if model_name == "Kandinsky":
        print("Kandinsky Special")
        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
        )
    else:
        pipe = config["pipeline_class"].from_pretrained(
            config["repo_id"],
            torch_dtype=TORCH_DTYPE,
            # cache_dir=cache_dir
        )

    pipe = pipe.to(DEVICE)
    if hasattr(pipe, 'enable_model_cpu_offload'):
        pipe.enable_model_cpu_offload()

    final_memory = get_process_memory()
    print(f"Memory used by {model_name}: {final_memory - initial_memory:.2f} GB")
    return pipe
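# Note: enable_model_cpu_offload manages device placement on its own, so the explicit
# pipe.to(DEVICE) above is normally unnecessary (and diffusers may warn) when offload is enabled.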
def save_generated_image(image, model_name, prompt):
    """Save a generated image named with a timestamp, the model name and a prompt snippet."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Create sanitized filename from prompt (first 30 chars)
    prompt_part = "".join(c for c in prompt[:30] if c.isalnum() or c in (' ', '-', '_')).strip()
    filename = f"{timestamp}_{model_name}_{prompt_part}.png"
    filepath = os.path.join(OUTPUT_DIR, filename)
    image.save(filepath)
    return filepath
def get_generated_images():
    """Get the list of generated images with their details, newest first."""
    files = glob.glob(os.path.join(OUTPUT_DIR, "*.png"))
    files.sort(key=os.path.getctime, reverse=True)  # Sort by creation time
    return [
        {
            "path": f,
            "name": os.path.basename(f),
            "date": datetime.fromtimestamp(os.path.getctime(f)).strftime("%Y-%m-%d %H:%M:%S"),
            "size": f"{os.path.getsize(f) / 1024:.1f} KB",
        }
        for f in files
    ]
def generate_image(
    model_name,
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    num_inference_steps=40,
    progress=gr.Progress(track_tqdm=True),
):
    with model_locks[model_name]:
        try:
            # progress(0, desc=f"Loading {model_name} model...")
            if model_name not in pipes:
                pipes[model_name] = load_pipeline(model_name)
            pipe = pipes[model_name]

            if randomize_seed:
                seed = random.randint(0, MAX_SEED)
            generator = torch.Generator(DEVICE).manual_seed(seed)

            print(f"Generating image with {model_name}...")
            # progress(0.3, desc=f"Generating image with {model_name}...")

            if model_name == "OneDiffusion":
                # OneDiffusion routes tasks through a text prefix;
                # "[[text2image]]" selects plain text-to-image generation.
                prompt = "[[text2image]] " + prompt

            image = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ).images[0]

            filepath = save_generated_image(image, model_name, prompt)
            print(f"Saved image to: {filepath}")

            # progress(0.9, desc=f"Cleaning up {model_name} resources...")
            # deep_cleanup(model_name, pipe)
            # progress(1.0, desc=f"Generation complete with {model_name}")
            return image, seed
        except Exception as e:
            print(f"Error with {model_name}: {str(e)}")
            if model_name in pipes:
                deep_cleanup(model_name, pipes[model_name])
            raise e
# Gradio Interface
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Multi-Model Image Generation")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Generate", scale=0, variant="primary")
with gr.Accordion("Advanced Settings", open=False):
negative_prompt = gr.Text(
label="Negative prompt",
max_lines=1,
placeholder="Enter a negative prompt",
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row():
width = gr.Slider(
label="Width",
minimum=512,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=512,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=7.5,
step=0.1,
value=4.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=50,
step=1,
value=40,
)
        memory_indicator = gr.Markdown("Current memory usage: 0 GB")

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Tabs() as tabs:
                    results = {}
                    seeds = {}
                    for model_name in MODEL_CONFIGS.keys():
                        with gr.Tab(model_name):
                            results[model_name] = gr.Image(label=f"{model_name} Result")
                            seeds[model_name] = gr.Number(label="Seed used", visible=True)

            with gr.Column(scale=1):
                gr.Markdown("### Generated Images")
                file_gallery = gr.Gallery(
                    label="Generated Images",
                    show_label=False,
                    elem_id="file_gallery",
                    columns=3,
                    height=800,
                    visible=True,
                )
                refresh_button = gr.Button("Refresh Gallery")
    def update_gallery():
        """Update the file gallery."""
        files = get_generated_images()
        return [
            (f["path"], f"{f['name']}\n{f['date']}")
            for f in files
        ]
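    # gr.Gallery accepts (image, caption) tuples, so each thumbnail shows its filename and date.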
    @spaces.GPU(duration=400)
    def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale,
                     num_inference_steps, progress=gr.Progress()):
        outputs = [None] * (len(MODEL_CONFIGS) * 2)
        for idx, model_name in enumerate(MODEL_CONFIGS.keys()):
            try:
                # progress(0, desc=f"Starting generation for {model_name}...")
                print(f"IMAGE GENERATING {model_name}")
                image, used_seed = generate_image(
                    model_name, prompt, negative_prompt, seed,
                    randomize_seed, width, height, guidance_scale,
                    num_inference_steps, progress
                )
                print(f"IMAGE GENERATED {model_name}")
                # Update the respective model's tab with the generated image and seed
                outputs[idx * 2] = image          # Image slot
                outputs[idx * 2 + 1] = used_seed  # Seed slot (the seed actually used)
                # Yield intermediate results so each tab updates as soon as its model finishes
                yield outputs
            except Exception as e:
                print(f"Error generating with {model_name}: {str(e)}")
                outputs[idx * 2] = None
                outputs[idx * 2 + 1] = None
        # Update the gallery after generation
        gallery_images = update_gallery()
        # file_gallery.update(value=gallery_images)
        yield outputs
    output_components = []
    for model_name in MODEL_CONFIGS.keys():
        output_components.extend([results[model_name], seeds[model_name]])

    run_button.click(
        fn=generate_all,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=output_components,
    )

    refresh_button.click(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )

    demo.load(
        fn=update_gallery,
        inputs=[],
        outputs=[file_gallery],
    )
if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0')