Spaces:
Runtime error
Runtime error
import time | |
import os | |
import gradio as gr | |
import torch | |
from einops import rearrange | |
from PIL import Image | |
from transformers import pipeline | |
from flux.cli import SamplingOptions | |
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack | |
from flux.util import load_ae, load_clip, load_flow_model, load_t5 | |
from pulid.pipeline_flux import PuLIDPipeline | |
from pulid.utils import resize_numpy_image_long | |
# Images whose "nsfw" score from the safety classifier is at or above this
# value are withheld from the user (see generate_image).
NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool):
    """Load every model the demo needs onto ``device``.

    Args:
        name: FLUX variant forwarded to ``load_flow_model`` and ``load_ae``.
        device: device every model is loaded onto.
        offload: accepted for signature compatibility but not read here.
            NOTE(review): nothing is offloaded; confirm that is intended.

    Returns:
        Tuple ``(model, ae, t5, clip, nsfw_classifier)``.
    """
    # Text encoders first. T5 starts with a 128-token limit; generate_image
    # overwrites ``t5.max_length`` on every request.
    text_encoder_t5 = load_t5(device, max_length=128)
    text_encoder_clip = load_clip(device)

    flow = load_flow_model(name, device=device)
    flow.eval()
    autoencoder = load_ae(name, device=device)

    safety_checker = pipeline(
        "image-classification",
        model="Falconsai/nsfw_image_detection",
        device=device,
    )
    return flow, autoencoder, text_encoder_t5, text_encoder_clip, safety_checker
class FluxGenerator:
    """Holds the FLUX and PuLID models shared by every Gradio request.

    The default arguments reproduce the fixed setup this demo has always used
    (flux-schnell on CPU, no offload). The keyword arguments allow building the
    generator for another device or model name without editing the class.
    """

    def __init__(self, model_name: str = 'flux-schnell', device: str = 'cpu', offload: bool = False):
        """Load all models and the PuLID pretrained weights.

        Args:
            model_name: FLUX variant forwarded to ``get_models``.
            device: device (e.g. ``'cpu'``) that the models are loaded onto.
            offload: forwarded to ``get_models``.
        """
        self.device = torch.device(device)
        self.offload = offload
        self.model_name = model_name
        self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
            self.model_name,
            device=self.device,
            offload=self.offload,
        )
        # PuLID must live on the same device as the FLUX model it wraps,
        # so pass the generator's device rather than a fixed one.
        self.pulid_model = PuLIDPipeline(self.model, device, weight_dtype=torch.float32)
        self.pulid_model.load_pretrain()


# Module-level singleton: the models are loaded once at import time and are
# shared by every request handled by generate_image.
flux_generator = FluxGenerator()
def _parse_seed(seed):
    """Turn the raw seed widget value into an ``int``, or ``None`` for "random".

    The value normally comes from a ``gr.Textbox`` as a string. ``-1`` and
    blank input both mean "pick a random seed". Anything else must be accepted
    by ``int()``, otherwise ``ValueError`` is raised.
    """
    if isinstance(seed, str):
        seed = seed.strip()
        if not seed:
            return None
    value = int(seed)
    return None if value == -1 else value


def _nsfw_score(classifier, img):
    """Return the score the safety classifier assigns to the "nsfw" label.

    Raises:
        RuntimeError: if the classifier output contains no "nsfw" entry, so an
            image is never returned without having been checked.
    """
    for prediction in classifier(img):
        if prediction["label"] == "nsfw":
            return prediction["score"]
    raise RuntimeError("NSFW classifier returned no 'nsfw' label; refusing to return the image")


def generate_image(
    prompt,
    id_image,
    start_step,
    guidance,
    seed,
    true_cfg,
    width=896,
    height=1152,
    num_steps=4,
    id_weight=1.0,
    neg_prompt="bad quality, worst quality, text, signature, watermark, extra limbs",
    timestep_to_start_cfg=1,
    max_sequence_length=128,
):
    """Gradio callback: generate one image, optionally conditioned on an ID photo.

    Args:
        prompt: text prompt.
        id_image: value from the ``gr.Image`` ID input (presumably an RGB numpy
            array, TODO confirm), or ``None`` to generate without an identity.
        start_step: timestep at which ID insertion starts (forwarded to ``denoise``).
        guidance: guidance value stored in ``SamplingOptions`` and forwarded to ``denoise``.
        seed: seed from the UI; ``-1`` or blank selects a random seed.
        true_cfg: true-CFG scale; values within 0.01 of 1.0 disable true CFG.
        width, height: requested output size.
        num_steps: number of denoising steps.
        id_weight: weight of the ID embedding (forwarded to ``denoise``).
        neg_prompt: negative prompt, only used when true CFG is enabled.
        timestep_to_start_cfg: forwarded to ``denoise``.
        max_sequence_length: T5 token limit for the prompt.

    Returns:
        ``(image, seed_text, debug_images)`` where ``image`` is a ``PIL.Image``.
        When the safety classifier scores the image at or above
        ``NSFW_THRESHOLD``, it returns ``(None, message, debug_images)`` instead.
        ``debug_images`` is PuLID's debug list, or ``[]`` when no ID image was given.
    """
    # Slider values may arrive as floats. The step count, the latent size and
    # the T5 tokenizer limit all need ints.
    width, height, num_steps = int(width), int(height), int(num_steps)
    flux_generator.t5.max_length = int(max_sequence_length)

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=_parse_seed(seed),
    )
    if opts.seed is None:
        opts.seed = torch.Generator(device="cpu").seed()
    print(f"Generating '{opts.prompt}' with seed {opts.seed}")
    t0 = time.perf_counter()

    # True CFG (an extra negative-prompt pass) only when the scale is not ~1.0.
    use_true_cfg = abs(true_cfg - 1.0) > 1e-2

    # Nothing in this module disables autograd (``model.eval()`` does not), so
    # without this every op of the denoising loop could be recorded for backward.
    with torch.inference_mode():
        if id_image is not None:
            id_image = resize_numpy_image_long(id_image, 1024)
            id_embeddings, uncond_id_embeddings = flux_generator.pulid_model.get_id_embedding(
                id_image, cal_uncond=use_true_cfg
            )
        else:
            id_embeddings = None
            uncond_id_embeddings = None

        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=flux_generator.device,
            dtype=torch.float32,
            seed=opts.seed,
        )
        timesteps = get_schedule(
            opts.num_steps,
            x.shape[-1] * x.shape[-2] // 4,
            shift=True,
        )
        inp = prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=opts.prompt)
        inp_neg = (
            prepare(t5=flux_generator.t5, clip=flux_generator.clip, img=x, prompt=neg_prompt)
            if use_true_cfg
            else None
        )
        x = denoise(
            flux_generator.model,
            **inp,
            timesteps=timesteps,
            guidance=opts.guidance,
            id=id_embeddings,
            id_weight=id_weight,
            start_step=start_step,
            uncond_id=uncond_id_embeddings,
            true_cfg=true_cfg,
            timestep_to_start_cfg=timestep_to_start_cfg,
            neg_txt=inp_neg["txt"] if use_true_cfg else None,
            neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
            neg_vec=inp_neg["vec"] if use_true_cfg else None,
        )
        x = unpack(x.float(), opts.height, opts.width)
        x = flux_generator.ae.decode(x)
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s.")

        # Map the decoder output from [-1, 1] to 8-bit HWC pixels.
        x = x.clamp(-1, 1)
        x = rearrange(x[0], "c h w -> h w c")
        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

    # PuLID's debug list is presumably only refreshed by get_id_embedding. When
    # no ID image was given it may still hold crops from an earlier request, so
    # do not return it in that case.
    debug_images = flux_generator.pulid_model.debug_img_list if id_image is not None else []

    nsfw_score = _nsfw_score(flux_generator.nsfw_classifier, img)
    if nsfw_score < NSFW_THRESHOLD:
        return img, str(opts.seed), debug_images
    return (None, f"Your generated image may contain NSFW (with nsfw_score: {nsfw_score}) content",
            debug_images)
_HEADER_ = ''' | |
<div style="text-align: center; max-width: 650px; margin: 0 auto;"> | |
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1> | |
<p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p> | |
</div> | |
π© Updates: | |
- 2024.11.01: update PuLID-FLUX-v0.9.1, please refer to <a href='https://github.com/ToTheBeginning/PuLID?tab=readme-ov-file#triangular_flag_on_post-updates'>our github repo</a> for more details. | |
βοΈβοΈβοΈ**Tips:** | |
- `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1. | |
- `true CFG scale:` In most scenarios, it is recommended to use a fake CFG, i.e., setting the true CFG scale to 1, and just adjusting the guidance scale. | |
- `Learn more about the model:` please refer to the <a href='https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md' target='_blank'>github doc</a> for more details. | |
''' | |
_CITE_ = r""" | |
If PuLID is helpful, please help to β the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks! | |
--- | |
π§ **Contact** | |
If you have any questions, feel free to contact <b>wuyanze123@gmail.com</b>. | |
""" | |
_DEV_DES = ''' | |
* Please refer to our repo for instructions on running gradio demo [locally](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo) | |
''' | |
def create_demo(args, model_name: str = 'flux-schnell', device: str = "cpu", offload: bool = False):
    """Build the Gradio Blocks UI wired to ``generate_image``.

    Args:
        args: parsed CLI namespace. Only ``args.dev`` is read; it toggles the
            developer-only widgets.
        model_name, device, offload: requested configuration. The models were
            already loaded at import time by the module-level ``flux_generator``,
            so these values cannot change anything here. A warning is printed
            when they differ from what was loaded, instead of ignoring them silently.

    Returns:
        The ``gr.Blocks`` demo (not launched).
    """
    requested = (model_name, str(device), offload)
    loaded = (flux_generator.model_name, str(flux_generator.device), flux_generator.offload)
    if requested != loaded:
        print(f"Warning: requested (model_name, device, offload)={requested} is ignored; "
              f"models were loaded at import time with {loaded}.")

    with gr.Blocks() as demo:
        with gr.Accordion("For Developers", open=False):
            gr.Markdown(_DEV_DES)
        gr.Markdown(_HEADER_)

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
                id_image = gr.Image(label="ID Image")
                id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
                width = gr.Slider(256, 1536, 896, step=16, label="Width")
                height = gr.Slider(256, 1536, 1152, step=16, label="Height")
                num_steps = gr.Slider(1, 20, 4, step=1, label="Number of steps")  # Default to 4 for schnell
                start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
                guidance = gr.Slider(0.0, 10.0, 0.0, step=0.1, label="Guidance")  # Default to 0.0 for schnell
                seed = gr.Textbox(-1, label="Seed (-1 for random)")
                max_sequence_length = gr.Slider(128, 512, 128, step=128,
                                                label="max_sequence_length for prompt (T5), small will be faster")

                with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG)", open=False):
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="bad quality, worst quality, text, signature, watermark, extra limbs")
                    true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
                    timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image", format='png')
                seed_output = gr.Textbox(label="Used Seed")
                intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)

        gr.Markdown(_CITE_)

        # The order of ``inputs`` must match generate_image's positional
        # parameters, and ``outputs`` must match its 3-tuple return value.
        generate_btn.click(
            fn=generate_image,
            inputs=[prompt, id_image, start_step, guidance, seed, true_cfg, width, height, num_steps, id_weight,
                    neg_prompt, timestep_to_start_cfg, max_sequence_length],
            outputs=[output_image, seed_output, intermediate_output],
        )

    return demo
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="PuLID for FLUX.1-schnell")
    parser.add_argument("--name", type=str, default="flux-schnell", choices=['flux-schnell'],
                        help="currently only support flux-schnell")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to use")
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    # Default None lets Gradio choose its own port, so running without --port
    # behaves exactly as before; an explicit --port is now actually honored.
    parser.add_argument("--port", type=int, default=None, help="Port to use")
    parser.add_argument("--dev", action='store_true', help="Development mode")
    parser.add_argument("--pretrained_model", type=str, help='for development')
    args = parser.parse_args()

    demo = create_demo(args, args.name, args.device, args.offload)
    demo.launch(server_port=args.port)