import copy |
|
import json |
|
import os |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from diffusers.models import UNet2DConditionModel |
|
from diffusers.models.attention_processor import Attention |
|
from diffusers.models.transformers.transformer_2d import BasicTransformerBlock
from diffusers.utils import deprecate
|
from einops import rearrange |
|
|
|
|
|
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): |
|
|
|
if hidden_states.shape[chunk_dim] % chunk_size != 0: |
|
raise ValueError( |
|
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." |
|
) |
|
|
|
num_chunks = hidden_states.shape[chunk_dim] // chunk_size |
|
ff_output = torch.cat( |
|
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], |
|
dim=chunk_dim, |
|
) |
|
return ff_output |
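

# Hedged usage sketch (not part of the original module): a minimal example of how
# `_chunked_feed_forward` splits the token dimension into `chunk_size` pieces before applying
# the feed-forward module. The shapes and the stand-in `nn.Linear` are illustrative assumptions.
def _example_chunked_feed_forward():
    ff = nn.Linear(320, 320)                   # stand-in for a transformer feed-forward block
    hidden_states = torch.randn(2, 4096, 320)  # (batch, tokens, channels)
    out = _chunked_feed_forward(ff, hidden_states, chunk_dim=1, chunk_size=1024)
    assert out.shape == hidden_states.shape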
|
|
|
class PoseRoPEAttnProcessor2_0: |
|
    r"""
    Processor for scaled dot-product attention (PyTorch 2.0) that applies 3D rotary position
    embeddings (RoPE), derived from discrete voxel indices, to the query and key tensors.
    """
|
|
|
def __init__(self): |
|
if not hasattr(F, "scaled_dot_product_attention"): |
|
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") |
|
|
|
def get_1d_rotary_pos_embed( |
|
self, |
|
dim: int, |
|
pos: torch.Tensor, |
|
theta: float = 10000.0, |
|
linear_factor=1.0, |
|
ntk_factor=1.0, |
|
): |
|
assert dim % 2 == 0 |
|
|
|
theta = theta * ntk_factor |
|
freqs = ( |
|
1.0 |
|
/ (theta ** (torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device)[: (dim // 2)] / dim)) |
|
/ linear_factor |
|
) |
|
freqs = torch.outer(pos, freqs) |
|
|
|
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() |
|
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() |
|
return freqs_cos, freqs_sin |
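
    # Note: `repeat_interleave(2, dim=1)` lays the frequencies out as [f0, f0, f1, f1, ...],
    # which matches how `apply_rotary_emb` below groups the last dimension into
    # (real, imaginary) pairs via `reshape(..., -1, 2)`.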
|
|
|
|
|
def get_3d_rotary_pos_embed( |
|
self, |
|
position, |
|
embed_dim, |
|
voxel_resolution, |
|
        theta: float = 10000.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        RoPE for tokens arranged on a 3D (x, y, z) voxel grid.

        Args:
            position (`torch.Tensor`):
                Integer voxel coordinates with shape `(..., 3)`, each entry in `[0, voxel_resolution)`.
            embed_dim (`int`):
                Per-head embedding dimension; must be a multiple of 16 so it splits evenly across the x, y and z axes.
            voxel_resolution (`int`):
                Number of discrete positions along each voxel axis.
            theta (`float`):
                Scaling factor for frequency computation.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`: cosine and sine embeddings, each of shape
            `(*position.shape[:-1], embed_dim)`.
        """
        assert position.shape[-1] == 3
|
|
|
|
|
dim_xy = embed_dim // 8 * 3 |
|
dim_z = embed_dim // 8 * 2 |
|
|
|
|
|
grid = torch.arange(voxel_resolution, dtype=torch.float32, device=position.device) |
|
freqs_xy = self.get_1d_rotary_pos_embed(dim_xy, grid, theta=theta) |
|
freqs_z = self.get_1d_rotary_pos_embed(dim_z, grid, theta=theta) |
|
|
|
xy_cos, xy_sin = freqs_xy |
|
z_cos, z_sin = freqs_z |
|
|
|
        # Flatten the positions and gather per-axis embeddings by integer coordinate.
        pos_flat = position.view(-1, position.shape[-1])
        x_cos = xy_cos[pos_flat[:, 0], :]
        x_sin = xy_sin[pos_flat[:, 0], :]
        y_cos = xy_cos[pos_flat[:, 1], :]
        y_sin = xy_sin[pos_flat[:, 1], :]
        z_cos = z_cos[pos_flat[:, 2], :]
        z_sin = z_sin[pos_flat[:, 2], :]
|
|
|
cos = torch.cat((x_cos, y_cos, z_cos), dim=-1) |
|
sin = torch.cat((x_sin, y_sin, z_sin), dim=-1) |
|
|
|
cos = cos.view(*position.shape[:-1], embed_dim) |
|
sin = sin.view(*position.shape[:-1], embed_dim) |
|
return cos, sin |
|
|
|
def apply_rotary_emb( |
|
self, |
|
x: torch.Tensor, |
|
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] |
|
): |
|
cos, sin = freqs_cis |
|
cos, sin = cos.to(x.device), sin.to(x.device) |
|
cos = cos.unsqueeze(1) |
|
sin = sin.unsqueeze(1) |
|
|
|
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) |
|
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) |
|
|
|
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) |
|
|
|
return out |
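
    # Shape walkthrough (informal, inferred from how `__call__` below uses these helpers):
    #   position_indices['voxel_indices'] : (B, S, 3) integer voxel coordinates
    #   get_3d_rotary_pos_embed(...)      : cos, sin each of shape (B, S, head_dim)
    #   apply_rotary_emb(q, (cos, sin))   : q is (B, heads, S, head_dim); cos/sin are
    #                                       broadcast over the head axis via unsqueeze(1).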
|
|
|
def __call__( |
|
self, |
|
attn: Attention, |
|
hidden_states: torch.Tensor, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
        position_indices: Optional[Dict] = None,
|
temb: Optional[torch.Tensor] = None, |
|
*args, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
if len(args) > 0 or kwargs.get("scale", None) is not None: |
|
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." |
|
deprecate("scale", "1.0.0", deprecation_message) |
|
|
|
residual = hidden_states |
|
if attn.spatial_norm is not None: |
|
hidden_states = attn.spatial_norm(hidden_states, temb) |
|
|
|
input_ndim = hidden_states.ndim |
|
|
|
if input_ndim == 4: |
|
batch_size, channel, height, width = hidden_states.shape |
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) |
|
|
|
batch_size, sequence_length, _ = ( |
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape |
|
) |
|
|
|
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask shaped as
            # (batch, heads, query_length, key_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
|
if attn.group_norm is not None: |
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) |
|
|
|
query = attn.to_q(hidden_states) |
|
|
|
if encoder_hidden_states is None: |
|
encoder_hidden_states = hidden_states |
|
elif attn.norm_cross: |
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) |
|
|
|
key = attn.to_k(encoder_hidden_states) |
|
value = attn.to_v(encoder_hidden_states) |
|
|
|
inner_dim = key.shape[-1] |
|
head_dim = inner_dim // attn.heads |
|
|
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
if attn.norm_q is not None: |
|
query = attn.norm_q(query) |
|
if attn.norm_k is not None: |
|
key = attn.norm_k(key) |
|
|
|
if position_indices is not None: |
|
if head_dim in position_indices: |
|
image_rotary_emb = position_indices[head_dim] |
|
else: |
|
image_rotary_emb = self.get_3d_rotary_pos_embed(position_indices['voxel_indices'], head_dim, voxel_resolution=position_indices['voxel_resolution']) |
|
position_indices[head_dim] = image_rotary_emb |
|
query = self.apply_rotary_emb(query, image_rotary_emb) |
|
key = self.apply_rotary_emb(key, image_rotary_emb) |
|
|
|
|
|
|
|
hidden_states = F.scaled_dot_product_attention( |
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False |
|
) |
|
|
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
|
hidden_states = hidden_states.to(query.dtype) |
|
|
|
|
|
hidden_states = attn.to_out[0](hidden_states) |
|
|
|
hidden_states = attn.to_out[1](hidden_states) |
|
|
|
if input_ndim == 4: |
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) |
|
|
|
if attn.residual_connection: |
|
hidden_states = hidden_states + residual |
|
|
|
hidden_states = hidden_states / attn.rescale_output_factor |
|
|
|
return hidden_states |
|
|
|
class IPAttnProcessor2_0: |
|
    r"""
    Processor for scaled dot-product attention (PyTorch 2.0) with an additional IP-Adapter style
    image-prompt branch: `ip_hidden_states` are projected through the extra `to_k_ip`/`to_v_ip`
    layers and blended into the output with weight `scale`.
    """
|
|
|
def __init__(self, scale=0.0): |
|
if not hasattr(F, "scaled_dot_product_attention"): |
|
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") |
|
|
|
self.scale = scale |
|
|
|
def __call__( |
|
self, |
|
attn: Attention, |
|
hidden_states: torch.Tensor, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
ip_hidden_states: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
temb: Optional[torch.Tensor] = None, |
|
*args, |
|
**kwargs, |
|
) -> torch.Tensor: |
|
if len(args) > 0 or kwargs.get("scale", None) is not None: |
|
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." |
|
deprecate("scale", "1.0.0", deprecation_message) |
|
|
|
residual = hidden_states |
|
if attn.spatial_norm is not None: |
|
hidden_states = attn.spatial_norm(hidden_states, temb) |
|
|
|
input_ndim = hidden_states.ndim |
|
|
|
if input_ndim == 4: |
|
batch_size, channel, height, width = hidden_states.shape |
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) |
|
|
|
batch_size, sequence_length, _ = ( |
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape |
|
) |
|
|
|
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects the mask shaped as
            # (batch, heads, query_length, key_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
|
if attn.group_norm is not None: |
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) |
|
|
|
query = attn.to_q(hidden_states) |
|
|
|
if encoder_hidden_states is None: |
|
encoder_hidden_states = hidden_states |
|
elif attn.norm_cross: |
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) |
|
|
|
key = attn.to_k(encoder_hidden_states) |
|
value = attn.to_v(encoder_hidden_states) |
|
|
|
inner_dim = key.shape[-1] |
|
head_dim = inner_dim // attn.heads |
|
|
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
if attn.norm_q is not None: |
|
query = attn.norm_q(query) |
|
if attn.norm_k is not None: |
|
key = attn.norm_k(key) |
|
|
|
|
|
|
|
|
|
hidden_states = F.scaled_dot_product_attention( |
|
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False |
|
) |
|
|
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
|
hidden_states = hidden_states.to(query.dtype) |
|
|
|
|
|
if ip_hidden_states is not None: |
|
|
|
ip_key = attn.to_k_ip(ip_hidden_states) |
|
ip_value = attn.to_v_ip(ip_hidden_states) |
|
|
|
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
|
|
|
ip_hidden_states = F.scaled_dot_product_attention( |
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False |
|
) |
|
|
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
|
ip_hidden_states = ip_hidden_states.to(query.dtype) |
|
|
|
hidden_states = hidden_states + self.scale * ip_hidden_states |
|
|
|
|
|
hidden_states = attn.to_out[0](hidden_states) |
|
|
|
hidden_states = attn.to_out[1](hidden_states) |
|
|
|
if input_ndim == 4: |
|
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) |
|
|
|
if attn.residual_connection: |
|
hidden_states = hidden_states + residual |
|
|
|
hidden_states = hidden_states / attn.rescale_output_factor |
|
|
|
return hidden_states |
|
|
|
|
|
class Basic2p5DTransformerBlock(torch.nn.Module): |
|
def __init__(self, transformer: BasicTransformerBlock, layer_name, use_ipa=True, use_ma=True, use_ra=True) -> None: |
|
super().__init__() |
|
self.transformer = transformer |
|
self.layer_name = layer_name |
|
self.use_ipa = use_ipa |
|
self.use_ma = use_ma |
|
self.use_ra = use_ra |
|
|
|
if use_ipa: |
|
self.attn2.set_processor(IPAttnProcessor2_0()) |
|
cross_attention_dim = 1024 |
|
self.attn2.to_k_ip = nn.Linear(cross_attention_dim, self.dim, bias=False) |
|
self.attn2.to_v_ip = nn.Linear(cross_attention_dim, self.dim, bias=False) |
|
|
|
|
|
if self.use_ma: |
|
self.attn_multiview = Attention( |
|
query_dim=self.dim, |
|
heads=self.num_attention_heads, |
|
dim_head=self.attention_head_dim, |
|
dropout=self.dropout, |
|
bias=self.attention_bias, |
|
cross_attention_dim=None, |
|
upcast_attention=self.attn1.upcast_attention, |
|
out_bias=True, |
|
processor=PoseRoPEAttnProcessor2_0(), |
|
) |
|
|
|
|
|
if self.use_ra: |
|
self.attn_refview = Attention( |
|
query_dim=self.dim, |
|
heads=self.num_attention_heads, |
|
dim_head=self.attention_head_dim, |
|
dropout=self.dropout, |
|
bias=self.attention_bias, |
|
cross_attention_dim=None, |
|
upcast_attention=self.attn1.upcast_attention, |
|
out_bias=True, |
|
) |
|
|
|
self._initialize_attn_weights() |
|
|
|
def _initialize_attn_weights(self): |
|
|
|
if self.use_ma: |
|
self.attn_multiview.load_state_dict(self.attn1.state_dict()) |
|
with torch.no_grad(): |
|
for layer in self.attn_multiview.to_out: |
|
for param in layer.parameters(): |
|
param.zero_() |
|
if self.use_ra: |
|
self.attn_refview.load_state_dict(self.attn1.state_dict()) |
|
with torch.no_grad(): |
|
for layer in self.attn_refview.to_out: |
|
for param in layer.parameters(): |
|
param.zero_() |
|
|
|
if self.use_ipa: |
|
self.attn2.to_k_ip.load_state_dict(self.attn2.to_k.state_dict()) |
|
self.attn2.to_v_ip.load_state_dict(self.attn2.to_v.state_dict()) |
|
|
|
def __getattr__(self, name: str): |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.transformer, name) |
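
    # Keys consumed from `cross_attention_kwargs` by `forward` (informal summary of the code below):
    #   num_in_batch           : number of views packed into each batch element
    #   mode                   : 'w' caches reference features per layer, 'r' attends to them
    #   condition_embed_dict   : per-layer cache written in 'w' mode and read in 'r' mode (use_ra)
    #   ip_hidden_states       : image-prompt tokens for the IP-Adapter branch (use_ipa)
    #   position_attn_mask     : optional multiview attention masks keyed by sequence length
    #   position_voxel_indices : optional voxel indices for the RoPE multiview attention (use_ma)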
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
encoder_attention_mask: Optional[torch.Tensor] = None, |
|
timestep: Optional[torch.LongTensor] = None, |
|
cross_attention_kwargs: Dict[str, Any] = None, |
|
class_labels: Optional[torch.LongTensor] = None, |
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
batch_size = hidden_states.shape[0] |
|
|
|
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} |
|
num_in_batch = cross_attention_kwargs.pop('num_in_batch', 1) |
|
mode = cross_attention_kwargs.pop('mode', None) |
|
condition_embed_dict = cross_attention_kwargs.pop("condition_embed_dict", None) |
|
ip_hidden_states = cross_attention_kwargs.pop("ip_hidden_states", None) |
|
position_attn_mask = cross_attention_kwargs.pop("position_attn_mask", None) |
|
position_voxel_indices = cross_attention_kwargs.pop("position_voxel_indices", None) |
|
|
|
if self.norm_type == "ada_norm": |
|
norm_hidden_states = self.norm1(hidden_states, timestep) |
|
elif self.norm_type == "ada_norm_zero": |
|
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( |
|
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype |
|
) |
|
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]: |
|
norm_hidden_states = self.norm1(hidden_states) |
|
elif self.norm_type == "ada_norm_continuous": |
|
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) |
|
elif self.norm_type == "ada_norm_single": |
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( |
|
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) |
|
).chunk(6, dim=1) |
|
norm_hidden_states = self.norm1(hidden_states) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa |
|
else: |
|
raise ValueError("Incorrect norm used") |
|
|
|
if self.pos_embed is not None: |
|
norm_hidden_states = self.pos_embed(norm_hidden_states) |
|
|
|
|
|
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} |
|
gligen_kwargs = cross_attention_kwargs.pop("gligen", None) |
|
|
|
attn_output = self.attn1( |
|
norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, |
|
attention_mask=attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
if self.norm_type == "ada_norm_zero": |
|
attn_output = gate_msa.unsqueeze(1) * attn_output |
|
elif self.norm_type == "ada_norm_single": |
|
attn_output = gate_msa * attn_output |
|
|
|
hidden_states = attn_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
|
|
        # Reference pass ('w'): cache this layer's normalized hidden states for later reuse.
        if mode is not None and 'w' in mode:
            condition_embed_dict[self.layer_name] = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)

        # Generation pass ('r'): cross-attend to the cached reference features of this layer.
        if mode is not None and 'r' in mode:
            condition_embed = condition_embed_dict[self.layer_name].unsqueeze(1).repeat(1, num_in_batch, 1, 1)
            condition_embed = rearrange(condition_embed, 'b n l c -> (b n) l c')
|
|
|
attn_output = self.attn_refview( |
|
norm_hidden_states, |
|
encoder_hidden_states=condition_embed, |
|
attention_mask=None, |
|
**cross_attention_kwargs |
|
) |
|
|
|
hidden_states = attn_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
|
|
|
|
        # Multiview attention: all views of a batch element attend to each other jointly.
        if num_in_batch > 1 and self.use_ma:
            multiview_hidden_states = rearrange(norm_hidden_states, '(b n) l c -> b (n l) c', n=num_in_batch)
            seq_len = multiview_hidden_states.shape[1]
            position_mask = None
            if position_attn_mask is not None and seq_len in position_attn_mask:
                position_mask = position_attn_mask[seq_len]
            position_indices = None
            if position_voxel_indices is not None and seq_len in position_voxel_indices:
                position_indices = position_voxel_indices[seq_len]

            attn_output = self.attn_multiview(
                multiview_hidden_states,
                encoder_hidden_states=multiview_hidden_states,
                attention_mask=position_mask,
                position_indices=position_indices,
                **cross_attention_kwargs,
            )
|
|
|
attn_output = rearrange(attn_output, 'b (n l) c -> (b n) l c', n=num_in_batch) |
|
|
|
hidden_states = attn_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
|
|
if gligen_kwargs is not None: |
|
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) |
|
|
|
|
|
if self.attn2 is not None: |
|
if self.norm_type == "ada_norm": |
|
norm_hidden_states = self.norm2(hidden_states, timestep) |
|
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]: |
|
norm_hidden_states = self.norm2(hidden_states) |
|
elif self.norm_type == "ada_norm_single": |
|
|
|
|
|
norm_hidden_states = hidden_states |
|
elif self.norm_type == "ada_norm_continuous": |
|
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) |
|
else: |
|
raise ValueError("Incorrect norm") |
|
|
|
if self.pos_embed is not None and self.norm_type != "ada_norm_single": |
|
norm_hidden_states = self.pos_embed(norm_hidden_states) |
|
|
|
if ip_hidden_states is not None: |
|
ip_hidden_states = ip_hidden_states.unsqueeze(1).repeat(1,num_in_batch,1,1) |
|
ip_hidden_states = rearrange(ip_hidden_states, 'b n l c -> (b n) l c') |
|
|
|
if self.use_ipa: |
|
attn_output = self.attn2( |
|
norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
ip_hidden_states=ip_hidden_states, |
|
attention_mask=encoder_attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
else: |
|
attn_output = self.attn2( |
|
norm_hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=encoder_attention_mask, |
|
**cross_attention_kwargs, |
|
) |
|
|
|
hidden_states = attn_output + hidden_states |
|
|
|
|
|
|
|
if self.norm_type == "ada_norm_continuous": |
|
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) |
|
elif not self.norm_type == "ada_norm_single": |
|
norm_hidden_states = self.norm3(hidden_states) |
|
|
|
if self.norm_type == "ada_norm_zero": |
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] |
|
|
|
if self.norm_type == "ada_norm_single": |
|
norm_hidden_states = self.norm2(hidden_states) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp |
|
|
|
if self._chunk_size is not None: |
|
|
|
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) |
|
else: |
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
if self.norm_type == "ada_norm_zero": |
|
ff_output = gate_mlp.unsqueeze(1) * ff_output |
|
elif self.norm_type == "ada_norm_single": |
|
ff_output = gate_mlp * ff_output |
|
|
|
hidden_states = ff_output + hidden_states |
|
if hidden_states.ndim == 4: |
|
hidden_states = hidden_states.squeeze(1) |
|
|
|
return hidden_states |
|
|
|
@torch.no_grad() |
|
def compute_voxel_grid_mask(position, grid_resolution=8): |
|
|
|
position = position.half() |
|
B,N,_,H,W = position.shape |
|
assert H%grid_resolution==0 and W%grid_resolution==0 |
|
|
|
    # Pixels whose position equals 1 on every channel are treated as background; zero them out
    # so they do not contribute to the per-cell averages below.
    valid_mask = (position != 1).all(dim=2, keepdim=True)
    valid_mask = valid_mask.expand_as(position)
    position[~valid_mask] = 0
|
|
|
|
|
position = rearrange(position, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution) |
|
valid_mask = rearrange(valid_mask, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution) |
|
|
|
grid_position = position.sum(dim=(-2, -1)) |
|
count_masked = valid_mask.sum(dim=(-2, -1)) |
|
|
|
grid_position = grid_position / count_masked.clamp(min=1) |
|
grid_position[count_masked<5] = 0 |
|
|
|
grid_position = grid_position.permute(0,1,4,2,3) |
|
grid_position = rearrange(grid_position, 'b n c h w -> b n (h w) c') |
|
|
|
grid_position_expanded_1 = grid_position.unsqueeze(2).unsqueeze(4) |
|
grid_position_expanded_2 = grid_position.unsqueeze(1).unsqueeze(3) |
|
|
|
|
|
distances = torch.norm(grid_position_expanded_1 - grid_position_expanded_2, dim=-1) |
|
|
|
    # Two grid cells may attend to each other when their mean positions are closer than
    # roughly one voxel diagonal (1.73 ~= sqrt(3)) at this grid resolution.
    grid_distance = 1.73 / grid_resolution
    weights = distances < grid_distance
|
|
|
return weights |
|
|
|
def compute_multi_resolution_mask(position_maps, grid_resolutions=[32, 16, 8]): |
|
position_attn_mask = {} |
|
with torch.no_grad(): |
|
for grid_resolution in grid_resolutions: |
|
position_mask = compute_voxel_grid_mask(position_maps, grid_resolution) |
|
position_mask = rearrange(position_mask, 'b ni nj li lj -> b (ni li) (nj lj)') |
|
position_attn_mask[position_mask.shape[1]] = position_mask |
|
return position_attn_mask |
|
|
|
@torch.no_grad() |
|
def compute_discrete_voxel_indice(position, grid_resolution=8, voxel_resolution=128): |
|
|
|
position = position.half() |
|
B,N,_,H,W = position.shape |
|
assert H%grid_resolution==0 and W%grid_resolution==0 |
|
|
|
    # Same background handling as in `compute_voxel_grid_mask`: zero out invalid pixels.
    valid_mask = (position != 1).all(dim=2, keepdim=True)
    valid_mask = valid_mask.expand_as(position)
    position[~valid_mask] = 0
|
|
|
position = rearrange(position, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution) |
|
valid_mask = rearrange(valid_mask, 'b n c (num_h grid_h) (num_w grid_w) -> b n num_h num_w c grid_h grid_w', num_h=grid_resolution, num_w=grid_resolution) |
|
|
|
grid_position = position.sum(dim=(-2, -1)) |
|
count_masked = valid_mask.sum(dim=(-2, -1)) |
|
|
|
grid_position = grid_position / count_masked.clamp(min=1) |
|
grid_position[count_masked<5] = 0 |
|
|
|
grid_position = grid_position.permute(0,1,4,2,3).clamp(0, 1) |
|
voxel_indices = grid_position * (voxel_resolution - 1) |
|
voxel_indices = torch.round(voxel_indices).long() |
|
return voxel_indices |
|
|
|
def compute_multi_resolution_discrete_voxel_indice(position_maps, grid_resolutions=[64, 32, 16, 8], voxel_resolutions=[512, 256, 128, 64]): |
|
voxel_indices = {} |
|
with torch.no_grad(): |
|
for grid_resolution, voxel_resolution in zip(grid_resolutions, voxel_resolutions): |
|
voxel_indice = compute_discrete_voxel_indice(position_maps, grid_resolution, voxel_resolution) |
|
voxel_indice = rearrange(voxel_indice, 'b n c h w -> b (n h w) c') |
|
voxel_indices[voxel_indice.shape[1]] = {'voxel_indices':voxel_indice, 'voxel_resolution':voxel_resolution} |
|
return voxel_indices |
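

# Hedged usage sketch (not part of the original module): shows the expected layout of
# `position_maps` and the dictionary returned above. The shapes are assumptions for
# illustration; the real pipeline supplies its own position maps.
def _example_voxel_indices():
    position_maps = torch.rand(1, 6, 3, 64, 64)   # (B, N_views, xyz, H, W), values in [0, 1]
    indices = compute_multi_resolution_discrete_voxel_indice(position_maps)
    # Keys are the multiview token counts N_views * grid_resolution**2 for each level.
    for seq_len, entry in indices.items():
        print(seq_len, entry['voxel_indices'].shape, entry['voxel_resolution'])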
|
|
|
class ImageProjModel(torch.nn.Module): |
|
"""Projection Model""" |
|
|
|
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): |
|
super().__init__() |
|
|
|
self.generator = None |
|
self.cross_attention_dim = cross_attention_dim |
|
self.clip_extra_context_tokens = clip_extra_context_tokens |
|
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) |
|
self.norm = torch.nn.LayerNorm(cross_attention_dim) |
|
|
|
def forward(self, image_embeds): |
|
embeds = image_embeds |
|
clip_extra_context_tokens = self.proj(embeds).reshape( |
|
-1, self.clip_extra_context_tokens, self.cross_attention_dim |
|
) |
|
clip_extra_context_tokens = self.norm(clip_extra_context_tokens) |
|
return clip_extra_context_tokens |
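

# Hedged usage sketch (not part of the original module): `ImageProjModel` maps one CLIP image
# embedding to `clip_extra_context_tokens` cross-attention tokens. Shapes are assumptions
# chosen for illustration.
def _example_image_proj_model():
    proj = ImageProjModel(cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
    image_embeds = torch.randn(2, 1024)   # (batch, clip_embeddings_dim)
    tokens = proj(image_embeds)
    assert tokens.shape == (2, 4, 1024)   # (batch, clip_extra_context_tokens, cross_attention_dim)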
|
|
|
class UNet2p5DConditionModel(torch.nn.Module): |
|
def __init__(self, unet: UNet2DConditionModel) -> None: |
|
super().__init__() |
|
self.unet = unet |
|
self.unet_dual = copy.deepcopy(unet) |
|
|
|
self.init_camera_embedding() |
|
self.init_attention(self.unet, use_ipa=True, use_ma=True, use_ra=True) |
|
self.init_attention(self.unet_dual, use_ipa=False, use_ma=False, use_ra=False) |
|
self.init_condition() |
|
|
|
@staticmethod |
|
def from_pretrained(pretrained_model_name_or_path, **kwargs): |
|
torch_dtype = kwargs.pop('torch_dtype', torch.float32) |
|
config_path = os.path.join(pretrained_model_name_or_path, 'config.json') |
|
unet_ckpt_path = os.path.join(pretrained_model_name_or_path, 'diffusion_pytorch_model.bin') |
|
with open(config_path, 'r', encoding='utf-8') as file: |
|
config = json.load(file) |
|
unet = UNet2DConditionModel(**config) |
|
unet = UNet2p5DConditionModel(unet) |
|
|
|
unet.unet.conv_in = torch.nn.Conv2d( |
|
12, |
|
unet.unet.conv_in.out_channels, |
|
kernel_size=unet.unet.conv_in.kernel_size, |
|
stride=unet.unet.conv_in.stride, |
|
padding=unet.unet.conv_in.padding, |
|
dilation=unet.unet.conv_in.dilation, |
|
groups=unet.unet.conv_in.groups, |
|
bias=unet.unet.conv_in.bias is not None) |
|
|
|
unet_ckpt = torch.load(unet_ckpt_path, map_location='cpu', weights_only=True) |
|
unet.load_state_dict(unet_ckpt, strict=True) |
|
unet = unet.to(torch_dtype) |
|
return unet |
|
|
|
def init_condition(self): |
|
self.unet.learned_text_clip_gen = nn.Parameter(torch.randn(1,77,1024)) |
|
self.unet.learned_text_clip_ref = nn.Parameter(torch.randn(1,77,1024)) |
|
|
|
self.unet.image_proj_model = ImageProjModel( |
|
cross_attention_dim=self.unet.config.cross_attention_dim, |
|
clip_embeddings_dim=1024, |
|
) |
|
|
|
|
|
def init_camera_embedding(self): |
|
self.max_num_ref_image = 5 |
|
self.max_num_gen_image = 12*3+4*2 |
|
|
|
time_embed_dim = 1280 |
|
self.unet.class_embedding = nn.Embedding(self.max_num_ref_image+self.max_num_gen_image, time_embed_dim) |
|
|
|
nn.init.zeros_(self.unet.class_embedding.weight) |
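
        # The class embedding is repurposed as a learned per-view camera embedding that the UNet
        # adds to its time embedding. `forward` offsets generated-view indices by
        # `max_num_ref_image`, reserving the first embedding slots for reference views.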
|
|
|
def init_attention(self, unet, use_ipa=True, use_ma=True, use_ra=True): |
|
|
|
for down_block_i, down_block in enumerate(unet.down_blocks): |
|
if hasattr(down_block, "has_cross_attention") and down_block.has_cross_attention: |
|
for attn_i, attn in enumerate(down_block.attentions): |
|
for transformer_i, transformer in enumerate(attn.transformer_blocks): |
|
if isinstance(transformer, BasicTransformerBlock): |
|
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'down_{down_block_i}_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra) |
|
|
|
if hasattr(unet.mid_block, "has_cross_attention") and unet.mid_block.has_cross_attention: |
|
for attn_i, attn in enumerate(unet.mid_block.attentions): |
|
for transformer_i, transformer in enumerate(attn.transformer_blocks): |
|
if isinstance(transformer, BasicTransformerBlock): |
|
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'mid_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra) |
|
|
|
for up_block_i, up_block in enumerate(unet.up_blocks): |
|
if hasattr(up_block, "has_cross_attention") and up_block.has_cross_attention: |
|
for attn_i, attn in enumerate(up_block.attentions): |
|
for transformer_i, transformer in enumerate(attn.transformer_blocks): |
|
if isinstance(transformer, BasicTransformerBlock): |
|
attn.transformer_blocks[transformer_i] = Basic2p5DTransformerBlock(transformer, f'up_{up_block_i}_{attn_i}_{transformer_i}',use_ipa,use_ma,use_ra) |
|
|
|
|
|
def __getattr__(self, name: str): |
|
try: |
|
return super().__getattr__(name) |
|
except AttributeError: |
|
return getattr(self.unet, name) |
|
|
|
def forward( |
|
self, sample, timestep, encoder_hidden_states, class_labels=None, |
|
*args, cross_attention_kwargs=None, down_intrablock_additional_residuals=None, |
|
down_block_res_samples=None, mid_block_res_sample=None, |
|
**cached_condition, |
|
): |
|
B, N_gen, _, H, W = sample.shape |
|
camera_info_gen = cached_condition['camera_info_gen'] + self.max_num_ref_image |
|
camera_info_gen = rearrange(camera_info_gen, 'b n -> (b n)') |
|
sample = [sample] |
|
|
|
if 'normal_imgs' in cached_condition: |
|
sample.append(cached_condition["normal_imgs"]) |
|
if 'position_imgs' in cached_condition: |
|
sample.append(cached_condition["position_imgs"]) |
|
|
|
sample = torch.cat(sample, dim=2) |
|
sample = rearrange(sample, 'b n c h w -> (b n) c h w') |
|
|
|
encoder_hidden_states_gen = encoder_hidden_states.unsqueeze(1).repeat(1, N_gen, 1, 1) |
|
encoder_hidden_states_gen = rearrange(encoder_hidden_states_gen, 'b n l c -> (b n) l c') |
|
|
|
|
|
use_position_mask = False |
|
use_position_rope = True |
|
|
|
position_attn_mask = None |
|
if use_position_mask: |
|
if 'position_attn_mask' in cached_condition: |
|
position_attn_mask = cached_condition['position_attn_mask'] |
|
else: |
|
if 'position_maps' in cached_condition: |
|
position_attn_mask = compute_multi_resolution_mask(cached_condition['position_maps']) |
|
|
|
position_voxel_indices = None |
|
if use_position_rope: |
|
if 'position_voxel_indices' in cached_condition: |
|
position_voxel_indices = cached_condition['position_voxel_indices'] |
|
else: |
|
if 'position_maps' in cached_condition: |
|
position_voxel_indices = compute_multi_resolution_discrete_voxel_indice(cached_condition['position_maps']) |
|
|
|
if 'ip_hidden_states' in cached_condition: |
|
ip_hidden_states = cached_condition['ip_hidden_states'] |
|
else: |
|
if 'clip_embeds' in cached_condition: |
|
ip_hidden_states = self.image_proj_model(cached_condition['clip_embeds']) |
|
else: |
|
ip_hidden_states = None |
|
cached_condition['ip_hidden_states'] = ip_hidden_states |
|
|
|
if 'condition_embed_dict' in cached_condition: |
|
condition_embed_dict = cached_condition['condition_embed_dict'] |
|
else: |
|
condition_embed_dict = {} |
|
ref_latents = cached_condition['ref_latents'] |
|
N_ref = ref_latents.shape[1] |
|
camera_info_ref = cached_condition['camera_info_ref'] |
|
camera_info_ref = rearrange(camera_info_ref, 'b n -> (b n)') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ref_latents = rearrange(ref_latents, 'b n c h w -> (b n) c h w') |
|
|
|
encoder_hidden_states_ref = self.learned_text_clip_ref.unsqueeze(1).repeat(B, N_ref, 1, 1) |
|
encoder_hidden_states_ref = rearrange(encoder_hidden_states_ref, 'b n l c -> (b n) l c') |
|
|
|
noisy_ref_latents = ref_latents |
|
timestep_ref = 0 |
|
''' |
|
if timestep.dim()>0: |
|
timestep_ref = rearrange(timestep, '(b n) -> b n', b=B)[:,:1].repeat(1, N_ref) |
|
timestep_ref = rearrange(timestep_ref, 'b n -> (b n)') |
|
else: |
|
timestep_ref = timestep |
|
noise = torch.randn_like(noisy_ref_latents[:,:4,...]) |
|
if self.training: |
|
noisy_ref_latents[:,:4,...] = self.train_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref) |
|
noisy_ref_latents[:,:4,...] = self.train_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref) |
|
else: |
|
noisy_ref_latents[:,:4,...] = self.val_sched.add_noise(noisy_ref_latents[:,:4,...], noise, timestep_ref.reshape(-1)) |
|
noisy_ref_latents[:,:4,...] = self.val_sched.scale_model_input(noisy_ref_latents[:,:4,...], timestep_ref.reshape(-1)) |
|
''' |
|
self.unet_dual( |
|
noisy_ref_latents, timestep_ref, |
|
encoder_hidden_states=encoder_hidden_states_ref, |
|
|
|
|
|
return_dict=False, |
|
cross_attention_kwargs={ |
|
'mode':'w', 'num_in_batch':N_ref, |
|
'condition_embed_dict':condition_embed_dict}, |
|
) |
|
cached_condition['condition_embed_dict'] = condition_embed_dict |
|
|
|
return self.unet( |
|
sample, timestep, |
|
encoder_hidden_states_gen, *args, |
|
class_labels=camera_info_gen, |
|
down_intrablock_additional_residuals=[ |
|
sample.to(dtype=self.unet.dtype) for sample in down_intrablock_additional_residuals |
|
] if down_intrablock_additional_residuals is not None else None, |
|
down_block_additional_residuals=[ |
|
sample.to(dtype=self.unet.dtype) for sample in down_block_res_samples |
|
] if down_block_res_samples is not None else None, |
|
mid_block_additional_residual=( |
|
mid_block_res_sample.to(dtype=self.unet.dtype) |
|
if mid_block_res_sample is not None else None |
|
), |
|
return_dict=False, |
|
cross_attention_kwargs={ |
|
'mode':'r', 'num_in_batch':N_gen, |
|
'ip_hidden_states':ip_hidden_states, |
|
'condition_embed_dict':condition_embed_dict, |
|
'position_attn_mask':position_attn_mask, |
|
'position_voxel_indices':position_voxel_indices |
|
}, |
|
) |
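

# Informal summary of the `cached_condition` keys read by `UNet2p5DConditionModel.forward`
# (derived from the code above; all tensors are batch-first):
#   camera_info_gen            : (B, N_gen) camera/view indices for the generated views (required)
#   normal_imgs, position_imgs : optional extra latents concatenated to `sample` along channels
#   ref_latents                : (B, N_ref, C, H, W) reference-view latents, encoded by `unet_dual`
#                                when no `condition_embed_dict` is cached
#   camera_info_ref            : (B, N_ref) camera indices for the reference views
#   clip_embeds                : CLIP image embedding, projected to `ip_hidden_states` when the
#                                latter is not provided directly
#   position_maps              : (B, N, 3, H, W) per-view position maps used to build the
#                                multi-resolution voxel indices for the RoPE multiview attention
#   condition_embed_dict, ip_hidden_states, position_attn_mask and position_voxel_indices may be
#   supplied pre-computed to skip the corresponding recomputation.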
|
|