import itertools

import numpy as np
import torch
from diffusers.utils.torch_utils import randn_tensor

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided result to the std of the text-conditioned prediction
    # (counteracts overexposure at high guidance scales).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend back with the original guided result to avoid "plain looking" images.
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
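
# Hedged usage sketch (not part of the original pipeline code): `rescale_noise_cfg`
# is meant to run right after the standard classifier-free-guidance combination.
# All variable names and default values here are illustrative assumptions.
def _example_guided_prediction(noise_pred_uncond, noise_pred_text, guidance_scale=7.5, guidance_rescale=0.7):
    # Standard CFG combination of unconditional and text-conditioned predictions.
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    if guidance_rescale > 0.0:
        # Pull the guided prediction's statistics back toward the text branch.
        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
    return noise_pred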


def expand_noise(noise, shape, seed, device, dtype):
    """Expand `noise` (of latent `shape`) onto a canvas twice as large per side.

    The original noise map is kept in the center cell of a 3x3 layout; fresh
    noise from a dedicated, seeded generator fills the four corners and four
    borders. Crops taken at shifted positions therefore share the same center
    noise and remain consistent with one another.
    """
    new_generator = torch.Generator().manual_seed(seed)
    corner_shape = (shape[0], shape[1], shape[2] // 2, shape[3] // 2)
    vert_border_shape = (shape[0], shape[1], shape[2], shape[3] // 2)
    hori_border_shape = (shape[0], shape[1], shape[2] // 2, shape[3])

    corners = [randn_tensor(corner_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(4)]
    vert_borders = [
        randn_tensor(vert_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)
    ]
    hori_borders = [
        randn_tensor(hori_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)
    ]

    big_shape = (shape[0], shape[1], shape[2] * 2, shape[3] * 2)
    noise_template = randn_tensor(big_shape, generator=new_generator, device=device, dtype=dtype)

    # 3x3 grid over the doubled canvas: the center cell spans the middle half
    # (0.25..0.75) of each spatial dimension and receives the original noise.
    ticks = [(0, 0.25), (0.25, 0.75), (0.75, 1.0)]
    grid = list(itertools.product(ticks, ticks))
    noise_list = [
        corners[0], hori_borders[0], corners[1],
        vert_borders[0], noise, vert_borders[1],
        corners[2], hori_borders[1], corners[3],
    ]
    for current_noise, ((x1, x2), (y1, y2)) in zip(noise_list, grid):
        top_left = (int(x1 * big_shape[2]), int(y1 * big_shape[3]))
        bottom_right = (int(x2 * big_shape[2]), int(y2 * big_shape[3]))
        noise_template[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = current_noise

    return noise_template
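
# Hedged sketch (illustrative values, not from the original code): expanding a
# (1, 4, 128, 128) latent noise map yields a (1, 4, 256, 256) canvas whose
# central quarter is exactly the original noise.
def _example_expand_noise():
    device, dtype = torch.device("cpu"), torch.float32
    g = torch.Generator().manual_seed(0)
    noise = randn_tensor((1, 4, 128, 128), generator=g, device=device, dtype=dtype)
    big = expand_noise(noise, (1, 4, 128, 128), seed=42, device=device, dtype=dtype)
    assert big.shape == (1, 4, 256, 256)
    assert torch.equal(big[:, :, 64:192, 64:192], noise)  # center cell == original noise
    return big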


def custom_prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
    use_noise_moving=True,
    return_noise=False,
    return_image_latents=False,
    newx=0,
    newy=0,
    newr=256,
    current_seed=None,
):
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise."
            " However, either the image or the noise timestep has not been provided."
        )

    if image.shape[1] == 4:
        # `image` is already a latent tensor.
        image_latents = image.to(device=device, dtype=dtype)
    elif return_image_latents or (latents is None and not is_strength_max):
        image = image.to(device=device, dtype=dtype)
        image_latents = self._encode_vae_image(image=image, generator=generator)

    if latents is None and use_noise_moving:
        # Draw the base noise, then expand it onto a 2x canvas so the same
        # noise can be re-cropped at shifted positions ("noise moving").
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        noise = expand_noise(noise, shape, seed=current_seed, device=device, dtype=dtype)

        newys = [newy] if not isinstance(newy, list) else newy
        newxs = [newx] if not isinstance(newx, list) else newx
        big_noise = noise.clone()
        prev_noise = None
        for newy, newx in zip(newys, newxs):
            # Crop offset into the expanded noise. The center cell starts at a
            # quarter of the expanded size; 512 - 128 = 384 is the top-left of
            # a 256-pixel region centered at pixel 512.
            sy = big_noise.shape[2] // 4 + ((512 - 128) - newy) // self.vae_scale_factor
            sx = big_noise.shape[3] // 4 + ((512 - 128) - newx) // self.vae_scale_factor

            if prev_noise is not None:
                new_noise = big_noise[:, :, sy:sy + shape[2], sx:sx + shape[3]]

                # Refresh only the newr x newr square at (newy, newx); keep the
                # previous noise everywhere else.
                ball_mask = torch.zeros(shape, device=device, dtype=torch.bool)
                top_left = (newy // self.vae_scale_factor, newx // self.vae_scale_factor)
                bottom_right = (
                    top_left[0] + newr // self.vae_scale_factor,
                    top_left[1] + newr // self.vae_scale_factor,
                )
                ball_mask[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = True

                noise = prev_noise.clone()
                noise[ball_mask] = new_noise[ball_mask]
            else:
                noise = big_noise[:, :, sy:sy + shape[2], sx:sx + shape[3]]

            prev_noise = noise.clone()

        # If strength is 1, start from pure noise; otherwise noise the image latents.
        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
        # Pure noise must be scaled by the scheduler's initial sigma.
        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
    elif latents is None:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)
    else:
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma

    outputs = (latents,)

    if return_noise:
        outputs += (noise,)

    if return_image_latents:
        outputs += (image_latents,)

    return outputs
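
# Hedged sketch (all argument values are illustrative assumptions): with
# `use_noise_moving=True`, passing lists for `newx`/`newy` re-crops the
# expanded noise once per position, refreshing only the newr x newr square
# each time. `pipe` is assumed to be an inpainting pipeline onto which
# `custom_prepare_latents` has been bound (see the binding sketch at the end
# of this file).
def _example_noise_moving_latents(pipe, image, timestep, device, generator):
    latents, noise = pipe.prepare_latents(
        1, 4, 1024, 1024, torch.float16, device, generator,
        image=image, timestep=timestep, is_strength_max=False,
        use_noise_moving=True, newx=[256, 384], newy=[256, 256], newr=256,
        current_seed=42, return_noise=True,
    )
    return latents, noise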


def custom_prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    # Resize the mask to the latent resolution, since it is concatenated with
    # the latents; bilinear interpolation keeps the mask edges soft.
    mask = torch.nn.functional.interpolate(
        mask,
        size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
        mode="bilinear",
        align_corners=False,
    )
    mask = mask.to(device=device, dtype=dtype)

    # Duplicate the mask for each generation per prompt.
    if mask.shape[0] < batch_size:
        if not batch_size % mask.shape[0] == 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)

    mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask

    masked_image_latents = None
    if masked_image is not None:
        masked_image = masked_image.to(device=device, dtype=dtype)
        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
        if masked_image_latents.shape[0] < batch_size:
            if not batch_size % masked_image_latents.shape[0] == 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(
                batch_size // masked_image_latents.shape[0], 1, 1, 1
            )

        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # Align device/dtype to avoid errors when concatenating with the latent
        # model input (guarded so a missing masked image stays None).
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

    return mask, masked_image_latents
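
# Hedged sketch (the binding pattern is an assumption, not prescribed by this
# file): attach both overrides to a diffusers inpainting pipeline instance so
# its internal calls pick up the noise-moving behaviour.
def _example_patch_pipeline(pipe):
    pipe.prepare_latents = custom_prepare_latents.__get__(pipe, type(pipe))
    pipe.prepare_mask_latents = custom_prepare_mask_latents.__get__(pipe, type(pipe))
    return pipe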