|
|
|
|
|
|
|
|
|
import torch |
|
import argparse |
|
import numpy as np |
|
import torch.distributed as dist |
|
import os |
|
from PIL import Image |
|
from tqdm.auto import tqdm |
|
import json |
|
|
|
|
|
from relighting.inpainter import BallInpainter |
|
|
|
from relighting.mask_utils import MaskGenerator |
|
from relighting.ball_processor import ( |
|
get_ideal_normal_ball, |
|
crop_ball |
|
) |
|
from relighting.dataset import GeneralLoader |
|
from relighting.utils import name2hash |
|
import relighting.dist_utils as dist_util |
|
import time |
|
|
|
|
|
|
|
from relighting.argument import ( |
|
SD_MODELS, |
|
CONTROLNET_MODELS, |
|
VAE_MODELS |
|
) |
|
|
|
def create_argparser():
    """Build the command-line parser for the chrome-ball inpainting script.

    Returns:
        argparse.ArgumentParser: parser covering input/output locations,
        model selection, diffusion sampling options, LoRA settings, the
        iterative-inpainting algorithm, and multi-process sharding.
    """
    parser = argparse.ArgumentParser()

    # --- input / output -------------------------------------------------
    parser.add_argument("--dataset", type=str, required=True, help='directory that contain the image')
    parser.add_argument("--ball_size", type=int, default=256, help="size of the ball in pixel")
    parser.add_argument("--ball_dilate", type=int, default=20, help="How much pixel to dilate the ball to make a sharper edge")
    parser.add_argument("--prompt", type=str, default="a perfect mirrored reflective chrome ball sphere")
    parser.add_argument("--prompt_dark", type=str, default="a perfect black dark mirrored reflective chrome ball sphere")
    parser.add_argument("--negative_prompt", type=str, default="matte, diffuse, flat, dull")
    parser.add_argument("--model_option", default="sdxl", help='selecting fancy model option (sd15_old, sd15_new, sd21, sdxl, sdxl_turbo)')
    parser.add_argument("--output_dir", required=True, type=str, help="output directory")
    parser.add_argument("--img_height", type=int, default=1024, help="Dataset Image Height")
    parser.add_argument("--img_width", type=int, default=1024, help="Dataset Image Width")

    # --- sampling -------------------------------------------------------
    parser.add_argument("--seed", default="auto", type=str, help="Seed: right now we use single seed instead to reduce the time, (Auto will use hash file name to generate seed)")
    parser.add_argument("--denoising_step", default=30, type=int, help="number of denoising step of diffusion model")
    parser.add_argument("--control_scale", default=0.5, type=float, help="controlnet conditioning scale")
    parser.add_argument("--guidance_scale", default=5.0, type=float, help="guidance scale (also known as CFG)")

    # --- feature toggles (paired flag + explicit default) ---------------
    parser.add_argument('--no_controlnet', dest='use_controlnet', action='store_false', help='by default we using controlnet, we have option to disable to see the different')
    parser.set_defaults(use_controlnet=True)

    parser.add_argument('--no_force_square', dest='force_square', action='store_false', help='SDXL is trained for square image, we prefered the square input. but you use this option to disable reshape')
    parser.set_defaults(force_square=True)

    parser.add_argument('--no_random_loader', dest='random_loader', action='store_false', help="by default, we random how dataset load. This make us able to peak into the trend of result without waiting entire dataset. but can disable if prefereed")
    parser.set_defaults(random_loader=True)

    parser.add_argument('--cpu', dest='is_cpu', action='store_true', help="using CPU inference instead of GPU inference")
    parser.set_defaults(is_cpu=False)

    # BUGFIX: this flag previously used action='store_false' with a default
    # of False, so passing --offload could never turn offloading on.
    # store_true makes the flag actually enable it (default stays False).
    parser.add_argument('--offload', dest='offload', action='store_true', help="to enable diffusers cpu offload")
    parser.set_defaults(offload=False)

    parser.add_argument("--limit_input", default=0, type=int, help="limit number of image to process to n image (0 = no limit), useful for run smallset")

    # --- LoRA -----------------------------------------------------------
    parser.add_argument('--no_lora', dest='use_lora', action='store_false', help='by default we using lora, we have option to disable to see the different')
    parser.set_defaults(use_lora=True)

    parser.add_argument("--lora_path", default="models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500", type=str, help="LoRA Checkpoint path")
    parser.add_argument("--lora_scale", default=0.75, type=float, help="LoRA scale factor")

    parser.add_argument('--no_torch_compile', dest='use_torch_compile', action='store_false', help='by default we using torch compile for faster processing speed. disable it if your environemnt is lower than pytorch2.0')
    parser.set_defaults(use_torch_compile=True)

    # --- inpainting algorithm -------------------------------------------
    parser.add_argument("--algorithm", type=str, default="iterative", choices=["iterative", "normal"], help="Selecting between iterative or normal (single pass inpaint) algorithm")

    parser.add_argument("--agg_mode", default="median", type=str)
    parser.add_argument("--strength", default=0.8, type=float)
    parser.add_argument("--num_iteration", default=2, type=int)
    parser.add_argument("--ball_per_iteration", default=30, type=int)
    parser.add_argument('--no_save_intermediate', dest='save_intermediate', action='store_false')
    parser.set_defaults(save_intermediate=True)
    parser.add_argument("--cache_dir", default="./temp_inpaint_iterative", type=str, help="cache directory for iterative inpaint")

    # --- multi-process sharding -----------------------------------------
    parser.add_argument("--idx", default=0, type=int, help="index of the current process, useful for running on multiple node")
    parser.add_argument("--total", default=1, type=int, help="total number of process")

    # --- exposure values for LoRA-conditioned generation ----------------
    parser.add_argument("--max_negative_ev", default=-5, type=int, help="maximum negative EV for lora")
    parser.add_argument("--ev", default="0,-2.5,-5", type=str, help="EV: list of EV to generate")

    return parser
|
|
|
def get_ball_location(image_data, args):
    """Pick the chrome-ball placement ``(x, y, r)`` for one image.

    When the dataset entry carries a "boundary" annotation, use it, nudging
    the ball inward so the dilated inpainting mask stays inside the frame.
    Otherwise center a ball of ``args.ball_size`` pixels in the image.
    """
    if 'boundary' not in image_data:
        # No annotation: center a default-size ball in the frame.
        r = args.ball_size
        x = (args.img_width // 2) - (r // 2)
        y = (args.img_height // 2) - (r // 2)
        return x, y, r

    boundary = image_data["boundary"]
    x, y, r = boundary["x"], boundary["y"], boundary["size"]

    # Half of the dilation margin must fit on every side of the ball.
    pad = args.ball_dilate // 2

    # Dilated mask would spill past the left/top edge: nudge right/down.
    if x - pad < 0:
        x = x + pad
    if y - pad < 0:
        y = y + pad

    # Dilated mask would spill past the right/bottom edge: nudge left/up.
    if x + r + pad > args.img_width:
        x = x - pad
    if y + r + pad > args.img_height:
        y = y - pad

    return x, y, r
|
|
|
def interpolate_embedding(pipe, args):
    """Precompute per-EV prompt embeddings by linear interpolation.

    Encodes the normal and dark prompts once, then for every EV in
    ``args.ev`` lerps between the two embedding pairs with weight
    ``t = ev / args.max_negative_ev`` (t=0 -> normal prompt, t=1 -> dark).

    Returns:
        dict mapping each EV (float) to a tuple
        ``(prompt_embeds, pooled_prompt_embeds)``.
    """
    print("interpolate embedding...")

    # Parse the comma-separated EV list and derive one blend weight per EV.
    ev_list = [float(v) for v in args.ev.split(",")]
    interpolants = [ev / args.max_negative_ev for ev in ev_list]

    print("EV : ", ev_list)
    print("EV : ", interpolants)

    # Encode both endpoint prompts a single time.
    embeds_normal, _, pooled_normal, _ = pipe.pipeline.encode_prompt(args.prompt)
    embeds_dark, _, pooled_dark, _ = pipe.pipeline.encode_prompt(args.prompt_dark)

    # Lerp both the full and the pooled embeddings for every requested EV.
    blended = [
        (
            embeds_normal + t * (embeds_dark - embeds_normal),
            pooled_normal + t * (pooled_dark - pooled_normal),
        )
        for t in interpolants
    ]

    return dict(zip(ev_list, blended))
|
|
|
def main():
    """Run chrome-ball inpainting over a dataset of images.

    Pipeline: parse CLI args, build the (SD/SDXL, optionally ControlNet)
    inpainting pipeline, optionally fuse LoRA weights and torch.compile the
    UNet, then for every image x EV x seed combination inpaint a chrome
    ball and save the raw, control, and square-cropped outputs under
    ``args.output_dir``.
    """
    args = create_argparser().parse_args()

    # Device / dtype selection: CPU forces fp32; GPU path uses fp16 and the
    # distributed helper to pick the device for this process.
    if args.is_cpu:
        device = torch.device("cpu")
        torch_dtype = torch.float32
    else:
        device = dist_util.dev()
        torch_dtype = torch.float16

    # The dilation margin is split in half on each side of the ball
    # (see get_ball_location / mask generation), so it must be even.
    assert args.ball_dilate % 2 == 0

    # Build the inpainting pipeline: SDXL-family vs. SD, with or without
    # ControlNet, depending on the model option and --no_controlnet.
    if args.model_option in ["sdxl", "sdxl_fast", "sdxl_turbo"] and args.use_controlnet:
        model, controlnet = SD_MODELS[args.model_option], CONTROLNET_MODELS[args.model_option]
        pipe = BallInpainter.from_sdxl(
            model=model,
            controlnet=controlnet,
            device=device,
            torch_dtype = torch_dtype,
            offload = args.offload
        )
    elif args.model_option in ["sdxl", "sdxl_fast", "sdxl_turbo"] and not args.use_controlnet:
        model = SD_MODELS[args.model_option]
        pipe = BallInpainter.from_sdxl(
            model=model,
            controlnet=None,
            device=device,
            torch_dtype = torch_dtype,
            offload = args.offload
        )
    elif args.use_controlnet:
        model, controlnet = SD_MODELS[args.model_option], CONTROLNET_MODELS[args.model_option]
        pipe = BallInpainter.from_sd(
            model=model,
            controlnet=controlnet,
            device=device,
            torch_dtype = torch_dtype,
            offload = args.offload
        )
    else:
        model = SD_MODELS[args.model_option]
        pipe = BallInpainter.from_sd(
            model=model,
            controlnet=None,
            device=device,
            torch_dtype = torch_dtype,
            offload = args.offload
        )

    # SDXL-Turbo is distilled for guidance-free sampling; override CFG.
    if args.model_option in ["sdxl_turbo"]:
        args.guidance_scale = 0.0

    if args.lora_scale > 0 and args.lora_path is None:
        raise ValueError("lora scale is not 0 but lora path is not set")

    # Load and fuse LoRA weights into the pipeline when enabled.
    if (args.lora_path is not None) and (args.use_lora):
        print(f"using lora path {args.lora_path}")
        print(f"using lora scale {args.lora_scale}")
        pipe.pipeline.load_lora_weights(args.lora_path)
        pipe.pipeline.fuse_lora(lora_scale=args.lora_scale)
        enabled_lora = True
    else:
        enabled_lora = False

    # Best-effort torch.compile of the UNet for speed.
    # NOTE(review): the bare `except: pass` silently swallows every failure
    # (including KeyboardInterrupt) — presumably meant to tolerate
    # pre-PyTorch-2.0 environments; consider narrowing to Exception.
    if args.use_torch_compile:
        try:
            print("compiling unet model")
            start_time = time.time()
            pipe.pipeline.unet = torch.compile(pipe.pipeline.unet, mode="reduce-overhead", fullgraph=True)
            print("Model compilation time: ", time.time() - start_time)
        except:
            pass

    # SDXL's native resolution fallback when both dimensions are 0.
    if args.model_option == "sdxl" and args.img_height == 0 and args.img_width == 0:
        args.img_height = 1024
        args.img_width = 1024

    # Dataset loader, sharded across processes via (idx, total).
    dataset = GeneralLoader(
        root=args.dataset,
        resolution=(args.img_width, args.img_height),
        force_square=args.force_square,
        return_dict=True,
        random_shuffle=args.random_loader,
        process_id=args.idx,
        process_total=args.total,
        limit_input=args.limit_input,
    )

    # ev -> (prompt_embeds, pooled_prompt_embeds), one entry per requested EV.
    embedding_dict = interpolate_embedding(pipe, args)

    # Ball geometry: the dilated ball/mask used for inpainting, plus an
    # undilated mask at the final ball size.
    # NOTE(review): `mask_ball_for_crop` (and the imported `crop_ball`) are
    # not used anywhere in this visible code — confirm whether dead or used
    # elsewhere.
    mask_generator = MaskGenerator()
    normal_ball, mask_ball = get_ideal_normal_ball(size=args.ball_size+args.ball_dilate)
    _, mask_ball_for_crop = get_ideal_normal_ball(size=args.ball_size)

    # Output layout: raw full images, controlnet conditioning images, and
    # square crops around the ball.
    raw_output_dir = os.path.join(args.output_dir, "raw")
    control_output_dir = os.path.join(args.output_dir, "control")
    square_output_dir = os.path.join(args.output_dir, "square")
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(raw_output_dir, exist_ok=True)
    os.makedirs(control_output_dir, exist_ok=True)
    os.makedirs(square_output_dir, exist_ok=True)

    # "auto" stays a string here and is resolved per-image below.
    seeds = args.seed.split(",")

    for image_data in tqdm(dataset):
        input_image = image_data["image"]
        image_path = image_data["path"]

        for ev, (prompt_embeds, pooled_prompt_embeds) in embedding_dict.items():

            # EV tag used in output filenames; 0 is spelled "-00" so sorted
            # listings group with the negative EVs.
            ev_str = str(ev).replace(".", "") if ev != 0 else "-00"
            outname = os.path.basename(image_path).split(".")[0] + f"_ev{ev_str}"

            x, y, r = get_ball_location(image_data, args)

            # Inpainting mask covers the dilated ball, shifted so the final
            # (undilated) ball sits centered inside it.
            mask = mask_generator.generate_single(
                input_image, mask_ball,
                x - (args.ball_dilate // 2),
                y - (args.ball_dilate // 2),
                r + args.ball_dilate
            )

            # NOTE(review): this rebinds the outer `seeds` list to a tqdm
            # wrapper and persists across images/EVs once len > 10 — confirm
            # intended (a fresh local name would avoid re-wrapping).
            seeds = tqdm(seeds, desc="seeds") if len(seeds) > 10 else seeds

            for seed in seeds:
                start_time = time.time()

                # "auto" derives a deterministic seed from the file name;
                # explicit seeds are appended to the output name.
                if seed == "auto":
                    filename = os.path.basename(image_path).split(".")[0]
                    seed = name2hash(filename)
                    outpng = f"{outname}.png"
                    cache_name = f"{outname}"
                else:
                    seed = int(seed)
                    outpng = f"{outname}_seed{seed}.png"
                    cache_name = f"{outname}_seed{seed}"

                # Skip work already done in a previous run (resume support).
                if os.path.exists(os.path.join(square_output_dir, outpng)):
                    continue
                generator = torch.Generator().manual_seed(seed)
                # Arguments shared by both inpainting algorithms; note
                # strength is fixed at 1.0 for the single-pass path and
                # overridden below for the iterative one.
                kwargs = {
                    "prompt_embeds": prompt_embeds,
                    "pooled_prompt_embeds": pooled_prompt_embeds,
                    'negative_prompt': args.negative_prompt,
                    'num_inference_steps': args.denoising_step,
                    'generator': generator,
                    'image': input_image,
                    'mask_image': mask,
                    'strength': 1.0,
                    'current_seed': seed,
                    'controlnet_conditioning_scale': args.control_scale,
                    'height': args.img_height,
                    'width': args.img_width,
                    'normal_ball': normal_ball,
                    'mask_ball': mask_ball,
                    'x': x,
                    'y': y,
                    'r': r,
                    'guidance_scale': args.guidance_scale,
                }

                if enabled_lora:
                    kwargs["cross_attention_kwargs"] = {"scale": args.lora_scale}

                if args.algorithm == "normal":
                    output_image = pipe.inpaint(**kwargs).images[0]
                elif args.algorithm == "iterative":
                    # Iterative refinement: repeated inpaints aggregated per
                    # --agg_mode, with intermediates cached per output name.
                    print("using inpainting iterative, this is going to take a while...")
                    kwargs.update({
                        "strength": args.strength,
                        "num_iteration": args.num_iteration,
                        "ball_per_iteration": args.ball_per_iteration,
                        "agg_mode": args.agg_mode,
                        "save_intermediate": args.save_intermediate,
                        "cache_dir": os.path.join(args.cache_dir, cache_name),
                    })
                    output_image = pipe.inpaint_iterative(**kwargs)
                else:
                    raise NotImplementedError(f"Unknown algorithm {args.algorithm}")

                # Crop the r x r square containing the generated ball.
                square_image = output_image.crop((x, y, x+r, y+r))

                # Save the controlnet conditioning image when one was cached.
                control_image = pipe.get_cache_control_image()
                if control_image is not None:
                    control_image.save(os.path.join(control_output_dir, outpng))

                output_image.save(os.path.join(raw_output_dir, outpng))
                square_image.save(os.path.join(square_output_dir, outpng))
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()