# Chroma / app.py
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union
import numpy as np
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
# Download required models
t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp8_e4m3fn.safetensors", local_dir="models/text_encoders/")
vae_path = hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae")
unet_path = hf_hub_download(repo_id="lodestones/Chroma", filename="chroma-unlocked-v31.safetensors", local_dir="models/unet")
# Import the workflow functions
from my_workflow import (
    get_value_at_index,
    add_comfyui_directory_to_sys_path,
    add_extra_model_paths,
    import_custom_nodes,
    NODE_CLASS_MAPPINGS,
    CLIPTextEncode,
    CLIPLoader,
    VAEDecode,
    UNETLoader,
    VAELoader,
    SaveImage,
)
# Initialize ComfyUI
add_comfyui_directory_to_sys_path()
add_extra_model_paths()
import_custom_nodes()
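
# Run the exported ComfyUI Chroma workflow once per request: encode the prompts
# with T5, sample a latent under CFG guidance, then decode it with the VAE and
# return the resulting image to Gradio.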
def generate_image(prompt, negative_prompt, width, height, steps, cfg, seed):
    with torch.inference_mode():
        # Use the provided seed, or draw a random 64-bit seed when seed == -1
        if seed == -1:
            seed = random.randint(0, 2**64 - 1)
        random.seed(seed)
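
        # Prepare the sampling inputs: seeded noise, an empty SD3-style latent at
        # the requested resolution, and the Euler sampler.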
        randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
        randomnoise_68 = randomnoise.get_noise(noise_seed=seed)

        emptysd3latentimage = NODE_CLASS_MAPPINGS["EmptySD3LatentImage"]()
        emptysd3latentimage_69 = emptysd3latentimage.generate(
            width=width, height=height, batch_size=1
        )

        ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
        ksamplerselect_72 = ksamplerselect.get_sampler(sampler_name="euler")
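
        # Load the T5-XXL text encoder in Chroma mode and encode the positive and
        # negative prompts with it.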
        cliploader = CLIPLoader()
        cliploader_78 = cliploader.load_clip(
            clip_name="t5xxl_fp8_e4m3fn.safetensors", type="chroma", device="default"
        )

        t5tokenizeroptions = NODE_CLASS_MAPPINGS["T5TokenizerOptions"]()
        t5tokenizeroptions_82 = t5tokenizeroptions.set_options(
            min_padding=1, min_length=0, clip=get_value_at_index(cliploader_78, 0)
        )

        cliptextencode = CLIPTextEncode()
        cliptextencode_74 = cliptextencode.encode(
            text=prompt,
            clip=get_value_at_index(t5tokenizeroptions_82, 0),
        )
        cliptextencode_75 = cliptextencode.encode(
            text=negative_prompt,
            clip=get_value_at_index(t5tokenizeroptions_82, 0),
        )
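
        # Load the Chroma UNet in fp8 and the FLUX autoencoder, then instantiate
        # the remaining workflow nodes.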
        unetloader = UNETLoader()
        unetloader_76 = unetloader.load_unet(
            unet_name="chroma-unlocked-v31.safetensors", weight_dtype="fp8_e4m3fn"
        )

        vaeloader = VAELoader()
        vaeloader_80 = vaeloader.load_vae(vae_name="ae.safetensors")

        cfgguider = NODE_CLASS_MAPPINGS["CFGGuider"]()
        basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
        samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
        vaedecode = VAEDecode()
        saveimage = SaveImage()  # retained from the exported workflow; unused because the image is returned directly
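
        # Build the CFG guider and the beta sigma schedule, then run the sampler on
        # the empty latent.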
        cfgguider_73 = cfgguider.get_guider(
            cfg=cfg,
            model=get_value_at_index(unetloader_76, 0),
            positive=get_value_at_index(cliptextencode_74, 0),
            negative=get_value_at_index(cliptextencode_75, 0),
        )

        basicscheduler_84 = basicscheduler.get_sigmas(
            scheduler="beta",
            steps=steps,
            denoise=1,
            model=get_value_at_index(unetloader_76, 0),
        )

        samplercustomadvanced_67 = samplercustomadvanced.sample(
            noise=get_value_at_index(randomnoise_68, 0),
            guider=get_value_at_index(cfgguider_73, 0),
            sampler=get_value_at_index(ksamplerselect_72, 0),
            sigmas=get_value_at_index(basicscheduler_84, 0),
            latent_image=get_value_at_index(emptysd3latentimage_69, 0),
        )
        vaedecode_79 = vaedecode.decode(
            samples=get_value_at_index(samplercustomadvanced_67, 0),
            vae=get_value_at_index(vaeloader_80, 0),
        )

        # Instead of saving to file, return the image directly. VAEDecode yields a
        # float tensor of shape (batch, height, width, channels) in [0, 1], so
        # convert it to a uint8 array that Gradio's Image component can display.
        image = get_value_at_index(vaedecode_79, 0)
        return (255.0 * image[0].cpu().numpy()).clip(0, 255).astype(np.uint8)
# Create Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Chroma Image Generator")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Enter your prompt here...",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                placeholder="Enter negative prompt here...",
                value="low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors",
                lines=2
            )
            with gr.Row():
                width = gr.Slider(
                    minimum=512,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Width"
                )
                height = gr.Slider(
                    minimum=512,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Height"
                )

            with gr.Row():
                steps = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=26,
                    step=1,
                    label="Steps"
                )
                cfg = gr.Slider(
                    minimum=1,
                    maximum=20,
                    value=4,
                    step=0.5,
                    label="CFG Scale"
                )

            seed = gr.Number(
                value=-1,
                label="Seed (-1 for random)"
            )
            generate_btn = gr.Button("Generate")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, width, height, steps, cfg, seed],
        outputs=[output_image]
    )

if __name__ == "__main__":
    app.launch(share=True)