# relighting/pipeline_utils.py
import torch
import itertools
from diffusers.utils.torch_utils import randn_tensor
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
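# Hedged usage sketch (not part of this file's original code): inside a
# classifier-free-guidance denoising step, the rescale is applied after the
# conditional and unconditional predictions are combined. The local names
# `noise_pred_uncond`, `noise_pred_text`, `guidance_scale`, and
# `guidance_rescale` are illustrative.
#
#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
#   if guidance_rescale > 0.0:
#       noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)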
def expand_noise(noise, shape, seed, device, dtype):
    # a dedicated, explicitly seeded CPU generator keeps the surrounding tiles reproducible
    new_generator = torch.Generator().manual_seed(seed)
    corner_shape = (shape[0], shape[1], shape[2] // 2, shape[3] // 2)  # four corner tiles
    vert_border_shape = (shape[0], shape[1], shape[2], shape[3] // 2)  # left/right border tiles
    hori_border_shape = (shape[0], shape[1], shape[2] // 2, shape[3])  # top/bottom border tiles
corners = [randn_tensor(corner_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(4)]
vert_borders = [randn_tensor(vert_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)]
hori_borders = [randn_tensor(hori_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)]
    # combine: paste the nine tiles into a 2x-size template as a 3x3 grid (the center keeps the original noise)
big_shape = (shape[0], shape[1], shape[2] * 2, shape[3] * 2)
noise_template = randn_tensor(big_shape, generator=new_generator, device=device, dtype=dtype)
    ticks = [(0, 0.25), (0.25, 0.75), (0.75, 1.0)]  # fractional tile boundaries of the 3x3 grid
    grid = list(itertools.product(ticks, ticks))  # row-major, matching the order of noise_list below
noise_list = [
corners[0], hori_borders[0], corners[1],
vert_borders[0], noise, vert_borders[1],
corners[2], hori_borders[1], corners[3],
]
for current_noise, ((x1, x2), (y1, y2)) in zip(noise_list, grid):
top_left = (int(x1 * big_shape[2]), int(y1 * big_shape[3]))
bottom_right = (int(x2 * big_shape[2]), int(y2 * big_shape[3]))
noise_template[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = current_noise
return noise_template
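# Minimal sketch of how expand_noise behaves (values assumed for illustration):
# a (1, 4, 64, 64) noise map becomes (1, 4, 128, 128), with the original noise
# preserved in the center tile and freshly seeded noise in the 8 surrounding tiles.
#
#   noise = randn_tensor((1, 4, 64, 64), generator=torch.Generator().manual_seed(0))
#   big = expand_noise(noise, (1, 4, 64, 64), seed=0, device=noise.device, dtype=noise.dtype)
#   assert big.shape == (1, 4, 128, 128)
#   assert torch.equal(big[:, :, 32:96, 32:96], noise)  # center tile is the input noise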
def custom_prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
image=None,
timestep=None,
is_strength_max=True,
use_noise_moving=True,
return_noise=False,
return_image_latents=False,
newx=0,
newy=0,
newr=256,
current_seed=None,
):
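    """
    Variant of an inpainting pipeline's `prepare_latents` with optional "noise moving":
    when `use_noise_moving` is True, the initial noise is cropped from a 2x-size
    expanded noise map (see `expand_noise`) so the noise stays consistent regardless
    of where the ball is inpainted. `newx`/`newy` give the top-left pixel location(s)
    of the ball, `newr` its size in pixels, and `current_seed` seeds the expanded map.
    The remaining arguments follow the diffusers inpainting `prepare_latents` convention.
    """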
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1, initial latents are to be initialised as a combination of image + noise."
            " However, either the image or the noise timestep has not been provided."
        )
    if image is not None and image.shape[1] == 4:
        # the image is already a latent tensor, no VAE encoding needed
        image_latents = image.to(device=device, dtype=dtype)
    elif return_image_latents or (latents is None and not is_strength_max):
        image = image.to(device=device, dtype=dtype)
        image_latents = self._encode_vae_image(image=image, generator=generator)
    if latents is None and use_noise_moving:
        # random big noise map: expand the latent-sized noise to 2x so a crop can follow the ball location
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        noise = expand_noise(noise, shape, seed=current_seed, device=device, dtype=dtype)  # current_seed must be set here
# ensure noise is the same regardless of inpainting location (top-left corner notation)
newys = [newy] if not isinstance(newy, list) else newy
newxs = [newx] if not isinstance(newx, list) else newx
big_noise = noise.clone()
prev_noise = None
for newy, newx in zip(newys, newxs):
# find patch location within big noise map
            # the quarter offset centers the crop in the 2x map; (512 - 128) appears to offset
            # for the default ball placement in pixel space before converting to latent units
            sy = big_noise.shape[2] // 4 + ((512 - 128) - newy) // self.vae_scale_factor
            sx = big_noise.shape[3] // 4 + ((512 - 128) - newx) // self.vae_scale_factor
if prev_noise is not None:
                new_noise = big_noise[:, :, sy:sy+shape[2], sx:sx+shape[3]]
                # refresh the noise only inside the ball region; keep the rest from the previous crop
                ball_mask = torch.zeros(shape, device=device, dtype=torch.bool)
                top_left = (newy // self.vae_scale_factor, newx // self.vae_scale_factor)
                bottom_right = (top_left[0] + newr // self.vae_scale_factor, top_left[1] + newr // self.vae_scale_factor)  # fixed ball size r = 256
                ball_mask[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = True
                noise = prev_noise.clone()
                noise[ball_mask] = new_noise[ball_mask]
else:
noise = big_noise[:, :, sy:sy+shape[2], sx:sx+shape[3]]
prev_noise = noise.clone()
        # if strength is 1. then initialise the latents to noise, else initialise to image + noise
        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
        # if pure noise then scale the initial latents by the scheduler's init sigma
        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
    elif latents is None:
        # no noise moving: draw plain noise and start from the encoded image latents
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)
    else:
        # latents were supplied directly; treat them as noise and scale by the initial sigma
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_image_latents:
outputs += (image_latents,)
return outputs
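# Hedged sketch of how this helper might be bound onto a diffusers inpainting
# pipeline instance so it overrides the stock `prepare_latents` (the actual
# wiring lives elsewhere in this repo; `pipe` is an illustrative name):
#
#   import types
#   pipe.prepare_latents = types.MethodType(custom_prepare_latents, pipe)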
def custom_prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
    mask = torch.nn.functional.interpolate(
        mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
        mode="bilinear", align_corners=False  # PURE: bilinear interpolation avoids a sharp border around the ball
    )
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
        if batch_size % mask.shape[0] != 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = None
if masked_image is not None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
            if batch_size % masked_image_latents.shape[0] != 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
        # align device and dtype to prevent errors when concatenating with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
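# Hedged illustration of the bilinear-resize choice above: downsampling a hard
# 0/1 ball mask with mode="bilinear" produces fractional values along the rim
# (wherever the mask edge does not line up with the latent grid), so the
# inpainted ball blends smoothly instead of showing a hard stair-stepped edge.
# Values below are assumed for illustration (vae_scale_factor = 8).
#
#   m = torch.zeros(1, 1, 512, 512)
#   m[:, :, 196:324, 196:324] = 1.0  # ball edge deliberately off the 8-px latent grid
#   soft = torch.nn.functional.interpolate(m, size=(64, 64), mode="bilinear", align_corners=False)
#   # soft now contains values strictly between 0 and 1 along the ball's border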