ginipick's picture
Update app.py
f5693f2 verified
raw
history blame
7.87 kB
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
import spaces
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
# Device every model below is moved to at import time.
device = "cuda"
# Weights are stored under <parent of this file's directory>/weights.
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
ckpt_dir = f'{root_dir}/weights/Kolors'
# Single definition of the IP-Adapter Plus directory, reused for the download,
# the CLIP image encoder and the adapter weights below.
ip_adapter_dir = f"{root_dir}/weights/Kolors-IP-Adapter-Plus"

# Fetch the base Kolors checkpoint and the IP-Adapter Plus weights locally.
snapshot_download(repo_id="Kwai-Kolors/Kolors", local_dir=ckpt_dir)
snapshot_download(repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus", local_dir=ip_adapter_dir)

# Load models (fp16 where converted) and place them on `device`.
text_encoder = ChatGLMModel.from_pretrained(f'{ckpt_dir}/text_encoder', torch_dtype=torch.float16).half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to(device)
scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to(device)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f'{ip_adapter_dir}/image_encoder',
    ignore_mismatched_sizes=True,
).to(dtype=torch.float16, device=device)

# Reference images are preprocessed at 336 px (both resize size and crop size).
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)

pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=clip_image_processor,
    force_zeros_for_empty_prompt=False,
).to(device)

# Keep a reference to the UNet's existing `encoder_hid_proj` under
# `text_encoder_hid_proj` BEFORE `load_ip_adapter` runs.
# NOTE(review): presumably `load_ip_adapter` replaces `encoder_hid_proj` with the
# image projection while the Kolors pipeline still needs the text projection
# via `text_encoder_hid_proj` — confirm against the Kolors pipeline source.
if hasattr(pipe.unet, 'encoder_hid_proj'):
    pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

pipe.load_ip_adapter(ip_adapter_dir, subfolder="", weight_name=["ip_adapter_plus_general.bin"])
# Largest value a seed may take: the maximum signed 32-bit integer.
MAX_SEED = 2**31 - 1
# Upper bound (in pixels) for the width/height controls.
MAX_IMAGE_SIZE = 1024
# ----------------------------------------------
# ์ˆ˜์ •๋œ ๋ถ€๋ถ„: infer ํ•จ์ˆ˜ ๋‚ด์—์„œ hidden_prompt๋ฅผ ์•ž์— ์ถ”๊ฐ€
# ----------------------------------------------
@spaces.GPU(duration=80)
def infer(
user_prompt,
ip_adapter_image,
ip_adapter_scale=0.5,
negative_prompt="",
seed=100,
randomize_seed=False,
width=1024,
height=1024,
guidance_scale=5.0,
num_inference_steps=50,
progress=gr.Progress(track_tqdm=True)
):
# ์ˆจ๊ฒจ์ง„(๊ธฐ๋ณธ/ํ•„์ˆ˜) ํ”„๋กฌํ”„ํŠธ
hidden_prompt = (
"Studio Ghibli animation style, featuring whimsical characters with expressive eyes "
"and fluid movements. Lush, detailed natural environments with ethereal lighting "
"and soft color palettes of blues, greens, and warm earth tones."
)
# ์‹ค์ œ๋กœ ํŒŒ์ดํ”„๋ผ์ธ์— ์ „๋‹ฌํ•  ์ตœ์ข… ํ”„๋กฌํ”„ํŠธ
prompt = f"{hidden_prompt}, {user_prompt}"
if randomize_seed:
seed = random.randint(0, MAX_SEED)
generator = torch.Generator(device="cuda").manual_seed(seed)
pipe.to("cuda")
image_encoder.to("cuda")
pipe.image_encoder = image_encoder
pipe.set_ip_adapter_scale([ip_adapter_scale])
image = pipe(
prompt=prompt,
ip_adapter_image=[ip_adapter_image],
negative_prompt=negative_prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
generator=generator,
).images[0]
return image, seed
# Rows for gr.Examples, in input order: [prompt, reference image file, image influence scale].
examples = [
    [example_prompt, example_image, 0.5]
    for example_prompt, example_image in (
        ("dancing", "gh1.jpg"),
        ("studio ghibli style", "gh2.jpg"),
        ("studio ghibli style", "gh3.webp"),
        ("studio ghibli style", "gh4.jpg"),
    )
]
css = """
#col-container {
margin: 0 auto;
max-width: 720px;
}
#result img{
object-position: top;
}
#result .image-container{
height: 100%
}
"""
# Gradio UI: prompt + reference image in, generated image out. Both the Run
# button and pressing Enter in the prompt box call `infer`; the seed slider is
# updated with the seed `infer` actually used.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Beyond Ghibli Reimagined")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)

        with gr.Row():
            with gr.Column():
                ip_adapter_image = gr.Image(label="IP-Adapter Image", type="pil")
                ip_adapter_scale = gr.Slider(
                    label="Image influence scale",
                    info="Use 1 for creating variations",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
            result = gr.Image(label="Result", elem_id="result")

        with gr.Accordion("Advanced Settings", open=False):
            # The placeholder is only a hint; an empty box sends "" to `infer`.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder=(
                    "(worst quality, low quality:1.4), bad anatomy, bad hands, text, error, "
                    "missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, "
                    "normal quality, jpeg artifacts, signature, watermark, username, blurry, "
                    "artist name, (deformed iris, deformed pupils:1.2), (semi-realistic, cgi, "
                    "3d, render:1.1), amateur, (poorly drawn hands, poorly drawn face:1.2)"
                ),
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )

        # Examples supply only prompt, image and scale; `infer` uses its own
        # defaults for the remaining parameters. Results are cached lazily.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, ip_adapter_image, ip_adapter_scale],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            ip_adapter_image,
            ip_adapter_scale,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

demo.queue().launch()