import itertools

import torch
from diffusers.utils.torch_utils import randn_tensor

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
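
# Example (sketch): typical classifier-free guidance usage, as in the diffusers
# pipelines this helper is copied from; `noise_pred_uncond`, `noise_pred_text`,
# and `guidance_scale` are assumed to come from the usual chunked UNet output.
#
#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
#   if guidance_rescale > 0.0:
#       noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
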
def expand_noise(noise, shape, seed, device, dtype):
    # Expand the latent noise map to double height and width, keeping the
    # original noise in the centre, so a crop window can move around without
    # changing the noise underneath it.
    new_generator = torch.Generator().manual_seed(seed)
    corner_shape = (shape[0], shape[1], shape[2] // 2, shape[3] // 2)
    vert_border_shape = (shape[0], shape[1], shape[2], shape[3] // 2)
    hori_border_shape = (shape[0], shape[1], shape[2] // 2, shape[3])
    corners = [randn_tensor(corner_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(4)]
    vert_borders = [
        randn_tensor(vert_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)
    ]
    hori_borders = [
        randn_tensor(hori_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)
    ]

    # combine the pieces into a 3x3 grid, with the original noise in the middle cell
    big_shape = (shape[0], shape[1], shape[2] * 2, shape[3] * 2)
    noise_template = randn_tensor(big_shape, generator=new_generator, device=device, dtype=dtype)
    ticks = [(0, 0.25), (0.25, 0.75), (0.75, 1.0)]
    grid = list(itertools.product(ticks, ticks))
    noise_list = [
        corners[0], hori_borders[0], corners[1],
        vert_borders[0], noise, vert_borders[1],
        corners[2], hori_borders[1], corners[3],
    ]
    for current_noise, ((x1, x2), (y1, y2)) in zip(noise_list, grid):
        top_left = (int(x1 * big_shape[2]), int(y1 * big_shape[3]))
        bottom_right = (int(x2 * big_shape[2]), int(y2 * big_shape[3]))
        noise_template[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = current_noise
    return noise_template
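
# Example (sketch): expanding a 128x128 latent noise map to 256x256, with the
# original noise preserved in the central half; the seed and shape below are
# illustrative assumptions.
#
#   shape = (1, 4, 128, 128)
#   device, dtype = torch.device("cpu"), torch.float32
#   noise = randn_tensor(shape, generator=torch.Generator().manual_seed(0), device=device, dtype=dtype)
#   big_noise = expand_noise(noise, shape, seed=0, device=device, dtype=dtype)
#   assert big_noise.shape == (1, 4, 256, 256)
#   assert torch.equal(big_noise[:, :, 64:192, 64:192], noise)
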
def custom_prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    image=None,
    timestep=None,
    is_strength_max=True,
    use_noise_moving=True,
    return_noise=False,
    return_image_latents=False,
    newx=0,
    newy=0,
    newr=256,
    current_seed=None,
):
    shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if (image is None or timestep is None) and not is_strength_max:
        raise ValueError(
            "Since strength < 1, initial latents are to be initialised as a combination of image + noise."
            " However, either the image or the noise timestep has not been provided."
        )

    if image.shape[1] == 4:
        image_latents = image.to(device=device, dtype=dtype)
    elif return_image_latents or (latents is None and not is_strength_max):
        image = image.to(device=device, dtype=dtype)
        image_latents = self._encode_vae_image(image=image, generator=generator)

    if latents is None and use_noise_moving:
        # random big noise map
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        noise = expand_noise(noise, shape, seed=current_seed, device=device, dtype=dtype)

        # ensure noise is the same regardless of inpainting location (top-left corner notation)
        newys = [newy] if not isinstance(newy, list) else newy
        newxs = [newx] if not isinstance(newx, list) else newx
        big_noise = noise.clone()
        prev_noise = None
        for newy, newx in zip(newys, newxs):
            # find the crop-window location within the big noise map
            sy = big_noise.shape[2] // 4 + ((512 - 128) - newy) // self.vae_scale_factor
            sx = big_noise.shape[3] // 4 + ((512 - 128) - newx) // self.vae_scale_factor
            if prev_noise is not None:
                new_noise = big_noise[:, :, sy:sy + shape[2], sx:sx + shape[3]]
                ball_mask = torch.zeros(shape, device=device, dtype=torch.bool)
                top_left = (newy // self.vae_scale_factor, newx // self.vae_scale_factor)
                bottom_right = (
                    top_left[0] + newr // self.vae_scale_factor,
                    top_left[1] + newr // self.vae_scale_factor,
                )  # fixed ball size r = 256
                ball_mask[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = True
                noise = prev_noise.clone()
                noise[ball_mask] = new_noise[ball_mask]
            else:
                noise = big_noise[:, :, sy:sy + shape[2], sx:sx + shape[3]]
            prev_noise = noise.clone()

        # if strength is 1. then initialise the latents to noise, else initialise to image + noise
        latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
        # if pure noise then scale the initial latents by the scheduler's init sigma
        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
    elif latents is None:
        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        latents = image_latents.to(device)
    else:
        noise = latents.to(device)
        latents = noise * self.scheduler.init_noise_sigma

    outputs = (latents,)
    if return_noise:
        outputs += (noise,)
    if return_image_latents:
        outputs += (image_latents,)
    return outputs
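
# Example (sketch): once bound onto an inpainting pipeline (see the binding
# sketch at the end of this file), passing lists for `newx`/`newy` re-crops the
# same expanded noise map for each ball position, keeping the noise under the
# fixed-size ball consistent across positions. `image` is assumed to be a
# preprocessed pixel tensor in [-1, 1]; all concrete values are illustrative.
#
#   latents, noise = pipe.prepare_latents(
#       1, pipe.vae.config.latent_channels, 1024, 1024,
#       torch.float16, pipe.device, generator,
#       image=image, is_strength_max=True, use_noise_moving=True,
#       return_noise=True, newx=384, newy=384, newr=256, current_seed=seed,
#   )
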
def custom_prepare_mask_latents(
    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
    # resize the mask to latents shape as we concatenate the mask to the latents;
    # we do that before converting to dtype to avoid breaking in case we're using
    # cpu_offload and half precision
    mask = torch.nn.functional.interpolate(
        mask,
        size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
        mode="bilinear",
        align_corners=False,  # PURE: we add this to avoid a sharp border of the ball
    )
    mask = mask.to(device=device, dtype=dtype)

    # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
    if mask.shape[0] < batch_size:
        if batch_size % mask.shape[0] != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
    mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask

    masked_image_latents = None
    if masked_image is not None:
        masked_image = masked_image.to(device=device, dtype=dtype)
        masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
        if masked_image_latents.shape[0] < batch_size:
            if batch_size % masked_image_latents.shape[0] != 0:
                raise ValueError(
                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                    " Make sure the number of images that you pass is divisible by the total requested batch size."
                )
            masked_image_latents = masked_image_latents.repeat(
                batch_size // masked_image_latents.shape[0], 1, 1, 1
            )
        masked_image_latents = (
            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
        )

        # aligning device to prevent device errors when concatenating it with the latent model input
        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

    return mask, masked_image_latents
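
# Example (sketch): these helpers are written as unbound pipeline methods; one
# assumed way to wire them up is to bind them onto a diffusers inpainting
# pipeline instance in place of its stock `prepare_latents` /
# `prepare_mask_latents` (`model_id` is a placeholder).
#
#   from diffusers import StableDiffusionXLInpaintPipeline
#
#   pipe = StableDiffusionXLInpaintPipeline.from_pretrained(model_id)
#   pipe.prepare_latents = custom_prepare_latents.__get__(pipe, pipe.__class__)
#   pipe.prepare_mask_latents = custom_prepare_mask_latents.__get__(pipe, pipe.__class__)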