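"""Gradio web UI for the LightDiffusion-Next image generation pipeline.

Acts as the Hugging Face Spaces entry point and also supports Docker and
plain local launches (see the __main__ block at the bottom of this file).
"""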
import glob
import gradio as gr
import sys
import os
from PIL import Image
import numpy as np
import spaces
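# Make the repository root importable so the local "modules" package resolves.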
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from modules.user.pipeline import pipeline
import torch
def load_generated_images():
"""Load generated images with given prefix from disk"""
image_files = glob.glob("./_internal/output/**/*.png")
# If there are no image files, return
if not image_files:
return []
# Sort files by modification time in descending order
image_files.sort(key=os.path.getmtime, reverse=True)
# Get most recent timestamp
latest_time = os.path.getmtime(image_files[0])
# Get all images from same batch (within 1 second of most recent)
batch_images = []
for file in image_files:
if abs(os.path.getmtime(file) - latest_time) < 1.0:
try:
img = Image.open(file)
batch_images.append(img)
            except Exception:
                # Skip files that cannot be opened (e.g. partially written)
                continue
if not batch_images:
return []
return batch_images
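# spaces.GPU requests a GPU (ZeroGPU) allocation on Hugging Face Spaces for
# the duration of each call to the decorated function.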
@spaces.GPU
def generate_images(
prompt: str,
width: int = 512,
height: int = 512,
num_images: int = 1,
batch_size: int = 1,
hires_fix: bool = False,
adetailer: bool = False,
enhance_prompt: bool = False,
img2img_enabled: bool = False,
    img2img_image: np.ndarray | None = None,
stable_fast: bool = False,
reuse_seed: bool = False,
flux_enabled: bool = False,
prio_speed: bool = False,
realistic_model: bool = False,
multiscale_enabled: bool = True,
multiscale_intermittent: bool = False,
multiscale_factor: float = 0.5,
multiscale_fullres_start: int = 3,
multiscale_fullres_end: int = 8,
keep_models_loaded: bool = True,
progress=gr.Progress(),
):
"""Generate images using the LightDiffusion pipeline"""
try:
# Set model persistence preference
from modules.Device.ModelCache import set_keep_models_loaded
set_keep_models_loaded(keep_models_loaded)
if img2img_enabled and img2img_image is not None:
# Convert numpy array to PIL Image
if isinstance(img2img_image, np.ndarray):
img_pil = Image.fromarray(img2img_image)
img_pil.save("temp_img2img.png")
prompt = "temp_img2img.png"
        # Run the pipeline; results are saved to ./_internal/output and
        # collected afterwards by load_generated_images()
with torch.inference_mode():
pipeline(
prompt=prompt,
w=width,
h=height,
number=num_images,
batch=batch_size,
hires_fix=hires_fix,
adetailer=adetailer,
enhance_prompt=enhance_prompt,
img2img=img2img_enabled,
stable_fast=stable_fast,
reuse_seed=reuse_seed,
flux_enabled=flux_enabled,
prio_speed=prio_speed,
autohdr=True,
realistic_model=realistic_model,
enable_multiscale=multiscale_enabled,
multiscale_intermittent_fullres=multiscale_intermittent,
multiscale_factor=multiscale_factor,
multiscale_fullres_start=multiscale_fullres_start,
multiscale_fullres_end=multiscale_fullres_end,
)
        return load_generated_images()
    except Exception:
        import traceback
        print(traceback.format_exc())
        # Return a black placeholder so the gallery still renders on failure
        return [Image.new("RGB", (512, 512), color="black")]
    finally:
        # Remove the temporary img2img file on both success and failure
        if os.path.exists("temp_img2img.png"):
            os.remove("temp_img2img.png")
def get_vram_info():
"""Get VRAM usage information"""
try:
from modules.Device.ModelCache import get_memory_info
info = get_memory_info()
return f"""
**VRAM Usage:**
- Total: {info["total_vram"]:.1f} GB
- Used: {info["used_vram"]:.1f} GB
- Free: {info["free_vram"]:.1f} GB
- Keep Models Loaded: {info["keep_loaded"]}
- Has Cached Checkpoint: {info["has_cached_checkpoint"]}
"""
except Exception as e:
return f"Error getting VRAM info: {e}"
def clear_model_cache_ui():
"""Clear model cache from UI"""
try:
from modules.Device.ModelCache import clear_model_cache
clear_model_cache()
return "โœ… Model cache cleared successfully!"
except Exception as e:
return f"โŒ Error clearing cache: {e}"
def apply_multiscale_preset(preset_name):
"""Apply multiscale preset values to the UI components"""
if preset_name == "None":
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
try:
from modules.sample.multiscale_presets import get_preset_parameters
params = get_preset_parameters(preset_name)
return (
gr.update(value=params["enable_multiscale"]),
gr.update(value=params["multiscale_factor"]),
gr.update(value=params["multiscale_fullres_start"]),
gr.update(value=params["multiscale_fullres_end"]),
gr.update(value=params["multiscale_intermittent_fullres"]),
)
except Exception as e:
print(f"Error applying preset {preset_name}: {e}")
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
# Create Gradio interface
with gr.Blocks(title="LightDiffusion Web UI") as demo:
gr.Markdown("# LightDiffusion Web UI")
gr.Markdown("Generate AI images using LightDiffusion")
gr.Markdown(
"This is the demo for LightDiffusion, the fastest diffusion backend for generating images. https://github.com/LightDiffusion/LightDiffusion-Next"
)
with gr.Row():
with gr.Column():
# Input components
prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
with gr.Row():
width = gr.Slider(
minimum=64, maximum=2048, value=512, step=64, label="Width"
)
height = gr.Slider(
minimum=64, maximum=2048, value=512, step=64, label="Height"
)
with gr.Row():
num_images = gr.Slider(
minimum=1, maximum=10, value=1, step=1, label="Number of Images"
)
batch_size = gr.Slider(
minimum=1, maximum=4, value=1, step=1, label="Batch Size"
)
with gr.Row():
hires_fix = gr.Checkbox(label="HiRes Fix")
adetailer = gr.Checkbox(label="Auto Face/Body Enhancement")
enhance_prompt = gr.Checkbox(label="Enhance Prompt")
stable_fast = gr.Checkbox(label="Stable Fast Mode")
with gr.Row():
reuse_seed = gr.Checkbox(label="Reuse Seed")
flux_enabled = gr.Checkbox(label="Flux Mode")
prio_speed = gr.Checkbox(label="Prioritize Speed")
realistic_model = gr.Checkbox(label="Realistic Model")
with gr.Row():
multiscale_enabled = gr.Checkbox(
label="Multi-Scale Diffusion", value=True
)
img2img_enabled = gr.Checkbox(label="Image to Image Mode")
keep_models_loaded = gr.Checkbox(
label="Keep Models in VRAM",
value=True,
info="Keep models loaded for instant reuse (faster but uses more VRAM)",
)
img2img_image = gr.Image(label="Input Image for img2img", visible=False)
# Multi-scale preset selection
with gr.Row():
multiscale_preset = gr.Dropdown(
label="Multi-Scale Preset",
choices=["None", "quality", "performance", "balanced", "disabled"],
value="None",
info="Select a preset to automatically configure multi-scale settings",
)
multiscale_intermittent = gr.Checkbox(
label="Intermittent Full-Res",
value=False,
info="Enable intermittent full-resolution rendering in low-res region",
)
with gr.Row():
multiscale_factor = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.5,
step=0.1,
label="Multi-Scale Factor",
)
multiscale_fullres_start = gr.Slider(
minimum=0, maximum=10, value=3, step=1, label="Full-Res Start Steps"
)
multiscale_fullres_end = gr.Slider(
minimum=0, maximum=20, value=8, step=1, label="Full-Res End Steps"
)
# Make input image visible only when img2img is enabled
img2img_enabled.change(
fn=lambda x: gr.update(visible=x),
inputs=[img2img_enabled],
outputs=[img2img_image],
)
# Handle preset changes
multiscale_preset.change(
fn=apply_multiscale_preset,
inputs=[multiscale_preset],
outputs=[
multiscale_enabled,
multiscale_factor,
multiscale_fullres_start,
multiscale_fullres_end,
multiscale_intermittent,
],
)
generate_btn = gr.Button("Generate")
# Model Cache Management
with gr.Accordion("Model Cache Management", open=False):
with gr.Row():
                    vram_info_btn = gr.Button("🔍 Check VRAM Usage")
                    clear_cache_btn = gr.Button("🗑️ Clear Model Cache")
vram_info_display = gr.Markdown("")
cache_status_display = gr.Markdown("")
# Output gallery
gallery = gr.Gallery(
label="Generated Images",
show_label=True,
elem_id="gallery",
columns=[2],
rows=[2],
object_fit="contain",
height="auto",
)
# Connect generate button to pipeline
generate_btn.click(
fn=generate_images,
inputs=[
prompt,
width,
height,
num_images,
batch_size,
hires_fix,
adetailer,
enhance_prompt,
img2img_enabled,
img2img_image,
stable_fast,
reuse_seed,
flux_enabled,
prio_speed,
realistic_model,
multiscale_enabled,
multiscale_intermittent,
multiscale_factor,
multiscale_fullres_start,
multiscale_fullres_end,
keep_models_loaded,
],
outputs=gallery,
)
# Connect VRAM info and cache management buttons
vram_info_btn.click(
fn=get_vram_info,
outputs=vram_info_display,
)
clear_cache_btn.click(
fn=clear_model_cache_ui,
outputs=cache_status_display,
)
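# Environment detection helpers used to choose launch settings below.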
def is_huggingface_space():
return "SPACE_ID" in os.environ
def is_docker_environment():
return "GRADIO_SERVER_PORT" in os.environ and "GRADIO_SERVER_NAME" in os.environ
# Entry point: pick launch settings for Spaces, Docker, or local runs
if __name__ == "__main__":
if is_huggingface_space():
demo.launch(
debug=False,
server_name="0.0.0.0",
server_port=7860, # Standard HF Spaces port
)
elif is_docker_environment():
# Docker environment - use environment variables
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
server_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
demo.launch(
debug=False,
server_name=server_name,
server_port=server_port,
)
else:
demo.launch(
server_name="0.0.0.0",
server_port=8000,
auth=None,
share=True, # Only enable sharing locally
debug=True,
)