# uu / app.py
# AkashKumarave's picture
# Update app.py
# 3e1dbb0 verified
import time
import os
import gradio as gr
import torch
from einops import rearrange
from PIL import Image
from transformers import pipeline
from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import load_ae, load_clip, load_flow_model, load_t5
from pulid.pipeline_flux import PuLIDPipeline
from pulid.utils import resize_numpy_image_long
# Cut-off for the "nsfw" classifier score: generate_image only returns the
# picture when the score is strictly below this value.
NSFW_THRESHOLD = 0.85
def get_models(name: str, device: torch.device, offload: bool):
    """Load every network used by the demo and hand them back as one tuple.

    Args:
        name: Model identifier forwarded to ``load_flow_model`` and ``load_ae``.
        device: Device every component is loaded onto.
        offload: Accepted for signature compatibility; this function does not
            read it.

    Returns:
        ``(model, ae, t5, clip, nsfw_classifier)`` in that order.
    """
    # Text encoders first, then the flow model (switched to eval mode),
    # then the autoencoder, and finally the NSFW image classifier.
    text_encoder_t5 = load_t5(device, max_length=128)
    text_encoder_clip = load_clip(device)

    flow_model = load_flow_model(name, device=device)
    flow_model.eval()

    autoencoder = load_ae(name, device=device)

    nsfw_detector = pipeline(
        "image-classification",
        model="Falconsai/nsfw_image_detection",
        device=device,
    )

    return flow_model, autoencoder, text_encoder_t5, text_encoder_clip, nsfw_detector
class FluxGenerator:
    """Container for the models the demo needs, loaded with fixed CPU settings.

    Attributes set by ``__init__``:
        device, offload, model_name: the hard-coded configuration.
        model, ae, t5, clip, nsfw_classifier: the objects returned by
            ``get_models``.
        pulid_model: a ``PuLIDPipeline`` wrapping ``model`` with its
            pretrained weights loaded.
    """

    def __init__(self):
        # Fixed configuration: CPU, no offloading, the schnell checkpoint.
        self.device = torch.device('cpu')
        self.offload = False
        self.model_name = 'flux-schnell'

        loaded = get_models(self.model_name, device=self.device, offload=self.offload)
        self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = loaded

        # The PuLID pipeline is built around the flow model in float32 on CPU.
        self.pulid_model = PuLIDPipeline(self.model, 'cpu', weight_dtype=torch.float32)
        self.pulid_model.load_pretrain()
# Single shared instance created at import time; generate_image reads its
# models through this module-level name.
flux_generator = FluxGenerator()
@torch.inference_mode()
def generate_image(
    prompt,
    id_image,
    start_step,
    guidance,
    seed,
    true_cfg,
    width=896,
    height=1152,
    num_steps=4,
    id_weight=1.0,
    neg_prompt="bad quality, worst quality, text, signature, watermark, extra limbs",
    timestep_to_start_cfg=1,
    max_sequence_length=128,
):
    """Generate one image, optionally conditioned on an identity image.

    All models are taken from the module-level ``flux_generator``.

    Args:
        prompt: Text prompt.
        id_image: Identity reference image, or ``None`` to skip ID
            conditioning. It is handed to ``resize_numpy_image_long``
            (presumably a numpy array -- confirm against the caller).
        start_step: Forwarded to ``denoise`` as ``start_step``.
        guidance: Stored in the sampling options and forwarded to ``denoise``.
        seed: Integer or integer-like string. ``-1``, ``None`` or a blank
            string means "pick a random seed".
        true_cfg: True-CFG scale. When it is within 0.01 of 1.0 the negative
            prompt is not encoded and no unconditional ID embedding is
            requested.
        width: Output width passed to the sampling options.
        height: Output height passed to the sampling options.
        num_steps: Number of steps passed to ``get_schedule``.
        id_weight: Forwarded to ``denoise``.
        neg_prompt: Negative prompt, only used when true CFG is active.
        timestep_to_start_cfg: Forwarded to ``denoise``.
        max_sequence_length: Written to ``flux_generator.t5.max_length``
            before the prompt is encoded.

    Returns:
        A 3-tuple ``(image, text, debug_images)``. When the NSFW score is
        below ``NSFW_THRESHOLD``, ``image`` is a PIL image and ``text`` is the
        seed that was used. Otherwise ``image`` is ``None`` and ``text`` is a
        warning message. ``debug_images`` is always
        ``flux_generator.pulid_model.debug_img_list``.

    Raises:
        ValueError: If ``seed`` is non-blank but cannot be parsed as an int.
    """
    flux_generator.t5.max_length = max_sequence_length

    # The seed comes from a free-text box, so it may be empty; treat a blank
    # (or missing) value like -1 instead of failing inside int().
    if seed is None or (isinstance(seed, str) and not seed.strip()):
        seed = -1
    seed = int(seed)
    if seed == -1:
        seed = None

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )
    if opts.seed is None:
        opts.seed = torch.Generator(device="cpu").seed()

    print(f"Generating '{opts.prompt}' with seed {opts.seed}")
    t0 = time.perf_counter()

    # A scale of (approximately) 1.0 means the negative branch is skipped.
    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    if id_image is not None:
        id_image = resize_numpy_image_long(id_image, 1024)
        id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(
            id_image, cal_uncond=use_true_cfg
        )
    else:
        id_embeddings = None
        uncond_id_embeddings = None

    # Initial noise and the timestep schedule derived from its spatial size.
    x = get_noise(
        1,
        opts.height,
        opts.width,
        device=flux_generator.device,
        dtype=torch.float32,
        seed=opts.seed,
    )
    timesteps = get_schedule(
        opts.num_steps,
        x.shape[-1] * x.shape[-2] // 4,
        shift=True,
    )

    # Encode the positive prompt, and the negative one only for true CFG.
    inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
    if use_true_cfg:
        inp_neg = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt)
        neg_kwargs = {
            "neg_txt": inp_neg["txt"],
            "neg_txt_ids": inp_neg["txt_ids"],
            "neg_vec": inp_neg["vec"],
        }
    else:
        neg_kwargs = {"neg_txt": None, "neg_txt_ids": None, "neg_vec": None}

    x = denoise(
        flux_generator.model,
        **inp,
        timesteps=timesteps,
        guidance=opts.guidance,
        id=id_embeddings,
        id_weight=id_weight,
        start_step=start_step,
        uncond_id=uncond_id_embeddings,
        true_cfg=true_cfg,
        timestep_to_start_cfg=timestep_to_start_cfg,
        **neg_kwargs,
    )

    x = unpack(x.float(), opts.height, opts.width)
    x = flux_generator.ae.decode(x)

    t1 = time.perf_counter()
    print(f"Done in {t1 - t0:.1f}s.")

    # Map the decoded tensor from [-1, 1] to 8-bit values in channel-last
    # layout and wrap it in a PIL image.
    x = x.clamp(-1, 1)
    x = rearrange(x[0], "c h w -> h w c")
    img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    # Use a distinct loop name so the latent tensor ``x`` is not shadowed.
    nsfw_score = [
        result["score"]
        for result in flux_generator.nsfw_classifier(img)
        if result["label"] == "nsfw"
    ][0]

    debug_images = flux_generator.pulid_model.debug_img_list
    if nsfw_score < NSFW_THRESHOLD:
        return img, str(opts.seed), debug_images
    return (
        None,
        f"Your generated image may contain NSFW (with nsfw_score: {nsfw_score}) content",
        debug_images,
    )
# Markdown/HTML banner (title, paper and code links, update notes, usage tips)
# rendered by create_demo through gr.Markdown.
_HEADER_ = '''
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1>
<p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
</div>
🚩 Updates:
- 2024.11.01: update PuLID-FLUX-v0.9.1, please refer to <a href='https://github.com/ToTheBeginning/PuLID?tab=readme-ov-file#triangular_flag_on_post-updates'>our github repo</a> for more details.
❗️❗️❗️**Tips:**
- `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1.
- `true CFG scale:` In most scenarios, it is recommended to use a fake CFG, i.e., setting the true CFG scale to 1, and just adjusting the guidance scale.
- `Learn more about the model:` please refer to the <a href='https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md' target='_blank'>github doc</a> for more details.
'''
# Markdown footer (star request and contact address) rendered by create_demo
# through gr.Markdown.
_CITE_ = r"""
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks!
---
📧 **Contact**
If you have any questions, feel free to contact <b>wuyanze123@gmail.com</b>.
"""
# Markdown note for developers, rendered by create_demo next to the
# "For Developers" accordion.
_DEV_DES = '''
* Please refer to our repo for instructions on running gradio demo [locally](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo)
'''
def create_demo(args, model_name: str = 'flux-schnell', device: str = "cpu", offload: bool = False):
    """Build the Gradio Blocks interface and wire the button to ``generate_image``.

    Args:
        args: Parsed command-line namespace. Only ``args.dev`` is read here;
            it controls the visibility of the developer-only widgets.
        model_name: Accepted for signature compatibility; not read here.
        device: Accepted for signature compatibility; not read here.
        offload: Accepted for signature compatibility; not read here.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    show_dev_widgets = args.dev

    with gr.Blocks() as demo:
        with gr.Accordion("For Developers", open=False):
            gr.Markdown(_DEV_DES)
        gr.Markdown(_HEADER_)

        with gr.Row():
            # Left column: every input control plus the trigger button.
            with gr.Column():
                prompt_box = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
                id_image_input = gr.Image(label="ID Image")
                id_weight_slider = gr.Slider(minimum=0.0, maximum=3.0, value=1, step=0.05, label="id weight")
                width_slider = gr.Slider(minimum=256, maximum=1536, value=896, step=16, label="Width")
                height_slider = gr.Slider(minimum=256, maximum=1536, value=1152, step=16, label="Height")
                # Default to 4 for schnell
                steps_slider = gr.Slider(minimum=1, maximum=20, value=4, step=1, label="Number of steps")
                start_step_slider = gr.Slider(
                    minimum=0, maximum=10, value=0, step=1, label="timestep to start inserting ID"
                )
                # Default to 0.0 for schnell
                guidance_slider = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, label="Guidance")
                seed_box = gr.Textbox(value=-1, label="Seed (-1 for random)")
                seq_len_slider = gr.Slider(
                    minimum=128,
                    maximum=512,
                    value=128,
                    step=128,
                    label="max_sequence_length for prompt (T5), small will be faster",
                )

                with gr.Accordion(
                    "Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG)",
                    open=False,
                ):
                    neg_prompt_box = gr.Textbox(
                        label="Negative Prompt",
                        value="bad quality, worst quality, text, signature, watermark, extra limbs",
                    )
                    true_cfg_slider = gr.Slider(minimum=1.0, maximum=10.0, value=1, step=0.1, label="true CFG scale")
                    cfg_start_slider = gr.Slider(
                        minimum=0,
                        maximum=20,
                        value=1,
                        step=1,
                        label="timestep to start cfg",
                        visible=show_dev_widgets,
                    )

                generate_btn = gr.Button("Generate")

            # Right column: results and the citation footer.
            with gr.Column():
                output_image = gr.Image(label="Generated Image", format='png')
                seed_output = gr.Textbox(label="Used Seed")
                intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=show_dev_widgets)
                gr.Markdown(_CITE_)

        # The order below must match generate_image's positional parameters.
        click_inputs = [
            prompt_box,
            id_image_input,
            start_step_slider,
            guidance_slider,
            seed_box,
            true_cfg_slider,
            width_slider,
            height_slider,
            steps_slider,
            id_weight_slider,
            neg_prompt_box,
            cfg_start_slider,
            seq_len_slider,
        ]
        click_outputs = [output_image, seed_output, intermediate_output]
        generate_btn.click(fn=generate_image, inputs=click_inputs, outputs=click_outputs)

    return demo
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="PuLID for FLUX.1-schnell")
    parser.add_argument("--name", type=str, default="flux-schnell", choices=['flux-schnell'],
                        help="currently only support flux-schnell")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to use")
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    # No default port: when the flag is omitted, server_port=None leaves the
    # choice to Gradio, which is what a bare launch() does. When the flag is
    # given it is now actually honoured.
    parser.add_argument("--port", type=int, default=None,
                        help="Port to use (default: chosen by Gradio)")
    parser.add_argument("--dev", action='store_true', help="Development mode")
    parser.add_argument("--pretrained_model", type=str, help='for development')
    args = parser.parse_args()

    # NOTE(review): --name/--device/--offload are forwarded to create_demo,
    # but the models are loaded by the module-level FluxGenerator() with fixed
    # settings, so these flags do not change what is loaded.
    demo = create_demo(args, args.name, args.device, args.offload)
    demo.launch(server_port=args.port)