import os
import random
import sys
import subprocess
from typing import Sequence, Mapping, Any, Union

import torch
import gradio as gr
from PIL import Image, ImageChops
from huggingface_hub import hf_hub_download

# Set up ComfyUI if it is not already present
if not os.path.exists("ComfyUI"):
    print("Setting up ComfyUI...")
    subprocess.run(["bash", "setup_comfyui.sh"], check=True)

# Ensure the output directory exists
os.makedirs("output", exist_ok=True)

# Download models if not already present
print("Checking and downloading models...")
hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
hf_hub_download(repo_id="black-forest-labs/FLUX.1-Canny-dev", filename="flux1-canny-dev.safetensors", local_dir="models/controlnet")
hf_hub_download(repo_id="XLabs-AI/flux-controlnet-collections", filename="flux-canny-controlnet-v3.safetensors", local_dir="models/controlnet")
hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return the value at `index` from a ComfyUI node output (sequence or result mapping)."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]


def find_path(name: str, path: str = None) -> str:
    """Walk up from `path` (default: the current working directory) until `name` is found."""
    if path is None:
        path = os.getcwd()
    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None
    return find_path(name, parent_directory)


def add_comfyui_directory_to_sys_path() -> None:
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")


def add_extra_model_paths() -> None:
    try:
        from main import load_extra_path_config
    except ImportError:
        from utils.extra_config import load_extra_path_config
    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")


# Initialize paths (this must happen before importing ComfyUI modules such as folder_paths)
add_comfyui_directory_to_sys_path()
add_extra_model_paths()

import folder_paths


def import_custom_nodes() -> None:
    import asyncio

    import execution
    import server
    from nodes import init_extra_nodes

    # Create a new event loop if running in a new thread
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()


# Import all necessary nodes
print("Importing ComfyUI nodes...")
try:
    from nodes import (
        StyleModelLoader,
        VAEEncode,
        NODE_CLASS_MAPPINGS,
        LoadImage,
        CLIPVisionLoader,
        SaveImage,
        VAELoader,
        CLIPVisionEncode,
        DualCLIPLoader,
        EmptyLatentImage,
        VAEDecode,
        UNETLoader,
        CLIPTextEncode,
    )

    # Initialize all constant nodes and models in global context
    import_custom_nodes()
except Exception as e:
    print(f"Error importing ComfyUI nodes: {e}")
    raise
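# Note (added context, an assumption rather than a guarantee): several NODE_CLASS_MAPPINGS entries
# used below ("DepthAnything_V2", "ImageResize+", "GetImageSizeAndCount", the "CR ..." nodes) come
# from custom-node packs rather than ComfyUI core; import_custom_nodes() registers them via
# init_extra_nodes(), assuming setup_comfyui.sh has installed those packs.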
print("Setting up models...")

# Global constants and preloaded models
intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
CONST_1024 = intconstant.get_value(value=1024)

# Load CLIP
dualcliploader = DualCLIPLoader()
CLIP_MODEL = dualcliploader.load_clip(
    clip_name1="t5/t5xxl_fp16.safetensors",
    clip_name2="clip_l.safetensors",
    type="flux",
)

# Load VAE
vaeloader = VAELoader()
VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")

# Load UNET
unetloader = UNETLoader()
UNET_MODEL = unetloader.load_unet(
    unet_name="flux1-depth-dev.safetensors",
    weight_dtype="default",
)

# Load CLIP Vision
clipvisionloader = CLIPVisionLoader()
CLIP_VISION_MODEL = clipvisionloader.load_clip(
    clip_name="sigclip_vision_patch14_384.safetensors"
)

# Load Style Model
stylemodelloader = StyleModelLoader()
STYLE_MODEL = stylemodelloader.load_style_model(
    style_model_name="flux1-redux-dev.safetensors"
)

# Initialize sampler
ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")

# CLIP input switch
cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()

# Initialize depth model
downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
    model="depth_anything_v2_vitl_fp32.safetensors"
)

# Load the XLabs Canny ControlNet
controlnetloader = NODE_CLASS_MAPPINGS["ControlNetLoader"]()
CANNY_XLABS_MODEL = controlnetloader.load_controlnet(
    control_net_name="flux-canny-controlnet-v3.safetensors"
)

# Initialize nodes
cliptextencode = CLIPTextEncode()
loadimage = LoadImage()
vaeencode = VAEEncode()
fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
controlnetapplyadvanced = NODE_CLASS_MAPPINGS["ControlNetApplyAdvanced"]()
instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
clipvisionencode = CLIPVisionEncode()
stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
emptylatentimage = EmptyLatentImage()
basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
vaedecode = VAEDecode()
cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
saveimage = SaveImage()
getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
canny_processor = NODE_CLASS_MAPPINGS["Canny"]()
imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()

from comfy import model_management

model_loaders = [CLIP_MODEL, VAE_MODEL, UNET_MODEL, CLIP_VISION_MODEL]

print("Loading models to GPU...")
model_management.load_models_gpu([
    loader[0].patcher if hasattr(loader[0], "patcher") else loader[0]
    for loader in model_loaders
])

print("Setup complete!")
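# Every loader/node call above returns a ComfyUI output tuple (or a dict carrying a "result" key);
# get_value_at_index() extracts the payload. A minimal usage sketch of that convention:
#   clip = get_value_at_index(CLIP_MODEL, 0)   # the CLIP object itself
#   vae = get_value_at_index(VAE_MODEL, 0)     # the VAE object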
def generate_image(prompt, structure_image, style_image, depth_strength=15,
                   canny_strength=0.30, style_strength=0.5, steps=28,
                   progress=gr.Progress(track_tqdm=True)):
    """Main generation function: processes the inputs and returns the path to the generated image."""
    timestamp = random.randint(10000, 99999)
    output_filename = f"flux_zen_{timestamp}.png"

    with torch.inference_mode():
        # Set up CLIP
        clip_switch = cr_clip_input_switch.switch(
            Input=1,
            clip1=get_value_at_index(CLIP_MODEL, 0),
            clip2=get_value_at_index(CLIP_MODEL, 0),
        )

        # Encode text
        text_encoded = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(clip_switch, 0),
        )
        empty_text = cliptextencode.encode(
            text="",
            clip=get_value_at_index(clip_switch, 0),
        )

        # Process structure image
        structure_img = loadimage.load_image(image=structure_image)

        # Resize image
        resized_img = imageresize.execute(
            width=get_value_at_index(CONST_1024, 0),
            height=get_value_at_index(CONST_1024, 0),
            interpolation="bicubic",
            method="keep proportion",
            condition="always",
            multiple_of=16,
            image=get_value_at_index(structure_img, 0),
        )

        # Get image size
        size_info = getimagesizeandcount.getsize(
            image=get_value_at_index(resized_img, 0)
        )

        # Encode VAE
        vae_encoded = vaeencode.encode(
            pixels=get_value_at_index(size_info, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Process Canny edges
        canny_processed = canny_processor.detect_edge(
            image=get_value_at_index(size_info, 0),
            low_threshold=0.4,
            high_threshold=0.8,
        )

        # Apply the Canny ControlNet (advanced)
        canny_conditions = controlnetapplyadvanced.apply_controlnet(
            positive=get_value_at_index(text_encoded, 0),
            negative=get_value_at_index(empty_text, 0),
            control_net=get_value_at_index(CANNY_XLABS_MODEL, 0),
            image=get_value_at_index(canny_processed, 0),
            strength=canny_strength,
            start_percent=0.0,
            end_percent=0.5,
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Process depth
        depth_processed = depthanything_v2.process(
            da_model=get_value_at_index(DEPTH_MODEL, 0),
            images=get_value_at_index(size_info, 0),
        )

        # Apply Flux guidance
        flux_guided = fluxguidance.append(
            guidance=depth_strength,
            conditioning=get_value_at_index(canny_conditions, 0),
        )

        # Process style image
        style_img = loadimage.load_image(image=style_image)

        # Encode style with CLIP Vision
        style_encoded = clipvisionencode.encode(
            crop="center",
            clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
            image=get_value_at_index(style_img, 0),
        )

        # Set up conditioning
        conditioning = instructpixtopixconditioning.encode(
            positive=get_value_at_index(flux_guided, 0),
            negative=get_value_at_index(canny_conditions, 1),
            vae=get_value_at_index(VAE_MODEL, 0),
            pixels=get_value_at_index(depth_processed, 0),
        )

        # Apply style
        style_applied = stylemodelapplyadvanced.apply_stylemodel(
            strength=style_strength,
            conditioning=get_value_at_index(conditioning, 0),
            style_model=get_value_at_index(STYLE_MODEL, 0),
            clip_vision_output=get_value_at_index(style_encoded, 0),
        )

        # Set up empty latent
        empty_latent = emptylatentimage.generate(
            width=get_value_at_index(resized_img, 1),
            height=get_value_at_index(resized_img, 2),
            batch_size=1,
        )

        # Set up guidance
        guided = basicguider.get_guider(
            model=get_value_at_index(UNET_MODEL, 0),
            conditioning=get_value_at_index(style_applied, 0),
        )

        # Set up scheduler
        schedule = basicscheduler.get_sigmas(
            scheduler="simple",
            steps=steps,
            denoise=1,
            model=get_value_at_index(UNET_MODEL, 0),
        )

        # Generate random noise
        noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64 - 1))

        # Sample
        sampled = samplercustomadvanced.sample(
            noise=get_value_at_index(noise, 0),
            guider=get_value_at_index(guided, 0),
            sampler=get_value_at_index(SAMPLER, 0),
            sigmas=get_value_at_index(schedule, 0),
            latent_image=get_value_at_index(empty_latent, 0),
        )

        # Decode VAE
        decoded = vaedecode.decode(
            samples=get_value_at_index(sampled, 0),
            vae=get_value_at_index(VAE_MODEL, 0),
        )

        # Create text node for the filename prefix
        prefix = cr_text.text_multiline(text=f"flux_zen_{timestamp}")

        # Use the SaveImage node to save the image
        saved_data = saveimage.save_images(
            filename_prefix=get_value_at_index(prefix, 0),
            images=get_value_at_index(decoded, 0),
        )

    try:
        # Get the saved file path
        saved_filename = saved_data["ui"]["images"][0]["filename"]
        saved_subfolder = saved_data["ui"]["images"][0]["subfolder"]
        output_dir = folder_paths.get_output_directory()

        # Construct the full path
        if saved_subfolder:
            full_path = os.path.join(output_dir, saved_subfolder, saved_filename)
        else:
            full_path = os.path.join(output_dir, saved_filename)

        return full_path
    except Exception as e:
        print(f"Error getting saved image path: {e}")
        # Fall back to the expected path
        return os.path.join("output", output_filename)
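# A minimal sketch of calling the pipeline without the Gradio UI. The image paths are placeholders
# (hypothetical files that must exist on disk); the other parameters mirror the UI defaults.
#   result_path = generate_image(
#       prompt="A beautiful landscape with mountains and a lake",
#       structure_image="examples/structure1.jpg",
#       style_image="examples/style1.jpg",
#       depth_strength=15,
#       canny_strength=0.30,
#       style_strength=0.5,
#       steps=28,
#   )
#   print(result_path)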
with gr.Blocks(css="footer {visibility: hidden}") as app:
    gr.Markdown("# 🎨 FLUX Zen Style Depth+Canny")
    gr.Markdown("FLUX.1[dev] Redux + FLUX.1[dev] Depth with XLabs Canny, based on the FLUX Style Shaping Space.")

    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                info="Describe the image you want to generate",
            )

            with gr.Row():
                with gr.Column(scale=1):
                    structure_image = gr.Image(
                        image_mode="RGB",
                        label="Structure Image (provides the structure/composition)",
                        type="filepath",
                    )
                    depth_strength = gr.Slider(
                        minimum=0,
                        maximum=50,
                        value=15,
                        label="Depth Strength",
                        info="Controls how much the depth map influences the result",
                    )
                    canny_strength = gr.Slider(
                        minimum=0,
                        maximum=1.0,
                        value=0.30,
                        label="Canny Strength",
                        info="Controls how much the edge detection influences the result",
                    )
                    steps = gr.Slider(
                        minimum=10,
                        maximum=50,
                        value=28,
                        label="Steps",
                        info="More steps = better quality but slower generation",
                    )
                with gr.Column(scale=1):
                    style_image = gr.Image(
                        label="Style Image (influences the visual style)",
                        type="filepath",
                    )
                    style_strength = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.5,
                        label="Style Strength",
                        info="Controls how much the style image influences the result",
                    )

            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(label="Generated Image")

    # Note: with cache_examples=True the example images must exist under examples/ before launch.
    gr.Examples(
        examples=[
            ["A beautiful landscape with mountains and a lake", "examples/structure1.jpg", "examples/style1.jpg", 20, 0.4, 0.6, 30],
            ["A cyberpunk cityscape at night", "examples/structure2.jpg", "examples/style2.jpg", 15, 0.35, 0.7, 28],
        ],
        inputs=[prompt_input, structure_image, style_image, depth_strength, canny_strength, style_strength, steps],
        outputs=output_image,
        fn=generate_image,
        cache_examples=True,
    )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, structure_image, style_image, depth_strength, canny_strength, style_strength, steps],
        outputs=output_image,
    )

    gr.Markdown("""
    ## How to use

    1. Enter a prompt describing the image you want to generate
    2. Upload a structure image to provide the basic shape/composition
    3. Upload a style image to influence the visual style
    4. Adjust the sliders to control the effect strength
    5. Click "Generate" to create your image

    ## About

    This demo uses FLUX.1-Redux-dev for style transfer, FLUX.1-Depth-dev for depth-guided generation,
    and the XLabs Canny ControlNet for edge detection and structure preservation.
    """)

if __name__ == "__main__":
    # Create an examples directory if it doesn't exist
    os.makedirs("examples", exist_ok=True)

    # Launch the app
    app.launch(share=True)
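# Alternative launch configurations (assumptions, adjust to your deployment; the script above
# defaults to share=True, which creates a temporary public link):
#   app.queue().launch(server_name="0.0.0.0", server_port=7860)  # bind all interfaces, no public link
#   app.launch()                                                 # local-only, default port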