D1-3105 committed on
Commit 729964f · verified · 1 parent: bd72f29

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ teasers/0.png filter=lfs diff=lfs merge=lfs -text
+ teasers/1.png filter=lfs diff=lfs merge=lfs -text
+ ip-adapter.safetensors filter=lfs diff=lfs merge=lfs -text
+ assets/1.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/2.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,96 @@
---
license: other
license_name: stabilityai-ai-community
license_link: >-
  https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- Text-to-Image
- IP-Adapter
- StableDiffusion3Pipeline
- image-generation
- Stable Diffusion
base_model:
- stabilityai/stable-diffusion-3.5-large
---

# SD3.5-Large-IP-Adapter

This repository contains an IP-Adapter for the SD3.5-Large model, released by researchers from the [InstantX Team](https://huggingface.co/InstantX). The reference image conditions generation much like a text prompt does, so it may occasionally be weakly responsive or interfere with the text prompt. We hope you enjoy the model, have fun, and share your creations with us [on Twitter](https://x.com/instantx_ai).

# Model Card

This is a regular IP-Adapter, with new layers added into all 38 transformer blocks. We use [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) to encode the reference image, chosen for its superior performance, and adopt a TimeResampler to project the image features; the number of image tokens is set to 64.
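
For intuition, here is a minimal, illustrative sketch of a perceiver-style resampler that maps a variable number of SigLIP patch embeddings down to a fixed set of 64 image tokens. It is not the actual TimeResampler shipped in `models/resampler.py` (which, judging by its name, is presumably also conditioned on the diffusion timestep); the class name, dimensions, and layer layout below are assumptions made for illustration only.

```python
import torch
from torch import nn


class ResamplerSketch(nn.Module):
    """Illustrative stand-in for the repo's TimeResampler, not the actual implementation."""

    def __init__(self, image_dim=1152, dim=1024, num_tokens=64, num_heads=16):
        super().__init__()
        # 64 learnable query tokens define the fixed-length output sequence.
        self.queries = nn.Parameter(torch.randn(1, num_tokens, dim) * 0.02)
        self.proj_in = nn.Linear(image_dim, dim)  # SigLIP hidden size -> adapter width (assumed)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_kv = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, image_embeds):
        # image_embeds: (batch, num_patches, image_dim) from the SigLIP vision tower.
        kv = self.norm_kv(self.proj_in(image_embeds))
        q = self.norm_q(self.queries.expand(image_embeds.shape[0], -1, -1))
        tokens, _ = self.attn(q, kv, kv)  # cross-attention: queries attend to image patches
        return tokens + self.ff(tokens)   # (batch, 64, dim) image tokens


# Toy check: a SigLIP-so400m-sized feature map (hidden size 1152, e.g. 729 patches) -> 64 tokens.
dummy = torch.randn(2, 729, 1152)
print(ResamplerSketch()(dummy).shape)  # torch.Size([2, 64, 1024])
```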

# Showcases

<div class="container">
  <img src="./teasers/0.png" width="1024"/>
  <img src="./teasers/1.png" width="1024"/>
</div>

# Inference

The code has not been integrated into diffusers yet; please use the local files in this repository for now.

```python
import torch
from PIL import Image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
image_encoder_path = "google/siglip-so400m-patch14-384"

transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)

ref_img = Image.open('./assets/1.jpg').convert('RGB')

# Please note that SD3.5 Large is sensitive to high-resolution generation such as 1536x1536.
image = pipe(
    width=1024,
    height=1024,
    prompt='a cat',
    negative_prompt="lowres, low quality, worst quality",
    num_inference_steps=24,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(42),
    clip_image=ref_img,
    ipadapter_scale=0.5,
).images[0]
image.save('./result.jpg')
```
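
If you are not working from a local clone, one convenient way to obtain both the adapter weights and the helper modules (`models/`, `pipeline_stable_diffusion_3_ipa.py`) is to pull the whole repository with `huggingface_hub`. This is only a suggested sketch; the repository id `InstantX/SD3.5-Large-IP-Adapter` is an assumption, so substitute the actual id if it differs.

```python
from huggingface_hub import snapshot_download

# Download every file in the repo (adapter weights, local modules, assets).
local_dir = snapshot_download(
    repo_id="InstantX/SD3.5-Large-IP-Adapter",  # assumed repo id
    local_dir="./SD3.5-Large-IP-Adapter",
)
print("Files downloaded to:", local_dir)
```

Run the inference snippet from inside that directory so the relative imports and the `./ip-adapter.bin` and `./assets/1.jpg` paths resolve.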

# Community ComfyUI Support

Please refer to [Slickytail/ComfyUI-InstantX-IPAdapter-SD3](https://github.com/Slickytail/ComfyUI-InstantX-IPAdapter-SD3).

# License

The model is released under the [stabilityai-ai-community](https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md) license. All rights reserved.

# Acknowledgements

This project is sponsored by [HuggingFace](https://huggingface.co/) and [fal.ai](https://fal.ai/). Thanks to [Slickytail](https://github.com/Slickytail) for contributing the ComfyUI node.

# Citation

If you find this project useful in your research, please cite us via

```
@misc{sd35-large-ipa,
  author = {InstantX Team},
  title  = {InstantX SD3.5-Large IP-Adapter Page},
  year   = {2024},
}
```
assets/1.jpg ADDED

Git LFS Details

  • SHA256: e03aa4626b895b3182b2d1d635c16b06b56086a0fb59952977d78cd4b697f1f8
  • Pointer size: 131 Bytes
  • Size of remote file: 973 kB
assets/2.jpg ADDED

Git LFS Details

  • SHA256: 14d5615e2c52e2b24689eb0a95941592a885fcc5270561e23b01ccfd2d91290c
  • Pointer size: 131 Bytes
  • Size of remote file: 655 kB
infer_sd35_large_ipa.py ADDED
@@ -0,0 +1,40 @@
import torch
from PIL import Image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline


if __name__ == '__main__':

    model_path = 'stabilityai/stable-diffusion-3.5-large'
    ip_adapter_path = './ip-adapter.bin'
    image_encoder_path = "google/siglip-so400m-patch14-384"

    # Load the SD3.5-Large transformer in bfloat16.
    transformer = SD3Transformer2DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )

    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_path, transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    # Attach the IP-Adapter weights and the SigLIP image encoder (64 image tokens).
    pipe.init_ipadapter(
        ip_adapter_path=ip_adapter_path,
        image_encoder_path=image_encoder_path,
        nb_token=64,
    )

    ref_img = Image.open('./assets/1.jpg').convert('RGB')
    image = pipe(
        width=1024,
        height=1024,
        prompt='a cat',
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=torch.Generator("cuda").manual_seed(42),
        clip_image=ref_img,       # reference image for the IP-Adapter
        ipadapter_scale=0.5,      # strength of the image conditioning
    ).images[0]
    image.save('./result.jpg')
ip-adapter.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9fe54774aa528e712d9145ff6a59dd93b1fcf1d5935304feffd980ae6d42ae03
size 1595970439
ip-adapter.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c6d90e1e9efbdc9db81b28420a9a5e4d3a0d6f7e9ef9eed013825f54d3239ac
size 1372601256
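
The adapter weights ship both as a pickled PyTorch checkpoint (`ip-adapter.bin`, the file referenced by the README example) and as `ip-adapter.safetensors`. The commit does not state whether `init_ipadapter` accepts the safetensors file directly, nor that the two files are byte-for-byte equivalent, so treat the snippet below purely as a hedged way to inspect the checkpoints.

```python
import torch
from safetensors.torch import load_file

# Paths assume you are inside the downloaded/cloned repository root.
safetensors_sd = load_file("ip-adapter.safetensors")
# weights_only=True is a safety precaution; drop it if the checkpoint needs full unpickling.
bin_sd = torch.load("ip-adapter.bin", map_location="cpu", weights_only=True)

print(f"safetensors tensors: {len(safetensors_sd)}")
print(f"bin entries:         {len(bin_sd)}")
for name in list(safetensors_sd)[:5]:
    print(name, tuple(safetensors_sd[name].shape))
```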
models/__init__.py ADDED
File without changes
models/attention.py ADDED
@@ -0,0 +1,1245 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+ from diffusers.utils import deprecate, logging
21
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
22
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
+ from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
24
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+
31
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
32
+ # "feed_forward_chunk_size" can be used to save memory
33
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
34
+ raise ValueError(
35
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
36
+ )
37
+
38
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
39
+ ff_output = torch.cat(
40
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
41
+ dim=chunk_dim,
42
+ )
43
+ return ff_output
44
+
45
+
46
+ @maybe_allow_in_graph
47
+ class GatedSelfAttentionDense(nn.Module):
48
+ r"""
49
+ A gated self-attention dense layer that combines visual features and object features.
50
+
51
+ Parameters:
52
+ query_dim (`int`): The number of channels in the query.
53
+ context_dim (`int`): The number of channels in the context.
54
+ n_heads (`int`): The number of heads to use for attention.
55
+ d_head (`int`): The number of channels in each head.
56
+ """
57
+
58
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
59
+ super().__init__()
60
+
61
+ # we need a linear projection since we need cat visual feature and obj feature
62
+ self.linear = nn.Linear(context_dim, query_dim)
63
+
64
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
65
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
66
+
67
+ self.norm1 = nn.LayerNorm(query_dim)
68
+ self.norm2 = nn.LayerNorm(query_dim)
69
+
70
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
71
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
72
+
73
+ self.enabled = True
74
+
75
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
76
+ if not self.enabled:
77
+ return x
78
+
79
+ n_visual = x.shape[1]
80
+ objs = self.linear(objs)
81
+
82
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
83
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
84
+
85
+ return x
86
+
87
+
88
+ @maybe_allow_in_graph
89
+ class JointTransformerBlock(nn.Module):
90
+ r"""
91
+ A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
92
+
93
+ Reference: https://arxiv.org/abs/2403.03206
94
+
95
+ Parameters:
96
+ dim (`int`): The number of channels in the input and output.
97
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
98
+ attention_head_dim (`int`): The number of channels in each head.
99
+ context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
100
+ processing of `context` conditions.
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ dim: int,
106
+ num_attention_heads: int,
107
+ attention_head_dim: int,
108
+ context_pre_only: bool = False,
109
+ qk_norm: Optional[str] = None,
110
+ use_dual_attention: bool = False,
111
+ ):
112
+ super().__init__()
113
+
114
+ self.use_dual_attention = use_dual_attention
115
+ self.context_pre_only = context_pre_only
116
+ context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
117
+
118
+ if use_dual_attention:
119
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
120
+ else:
121
+ self.norm1 = AdaLayerNormZero(dim)
122
+
123
+ if context_norm_type == "ada_norm_continous":
124
+ self.norm1_context = AdaLayerNormContinuous(
125
+ dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
126
+ )
127
+ elif context_norm_type == "ada_norm_zero":
128
+ self.norm1_context = AdaLayerNormZero(dim)
129
+ else:
130
+ raise ValueError(
131
+ f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
132
+ )
133
+
134
+ if hasattr(F, "scaled_dot_product_attention"):
135
+ processor = JointAttnProcessor2_0()
136
+ else:
137
+ raise ValueError(
138
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
139
+ )
140
+
141
+ self.attn = Attention(
142
+ query_dim=dim,
143
+ cross_attention_dim=None,
144
+ added_kv_proj_dim=dim,
145
+ dim_head=attention_head_dim,
146
+ heads=num_attention_heads,
147
+ out_dim=dim,
148
+ context_pre_only=context_pre_only,
149
+ bias=True,
150
+ processor=processor,
151
+ qk_norm=qk_norm,
152
+ eps=1e-6,
153
+ )
154
+
155
+ if use_dual_attention:
156
+ self.attn2 = Attention(
157
+ query_dim=dim,
158
+ cross_attention_dim=None,
159
+ dim_head=attention_head_dim,
160
+ heads=num_attention_heads,
161
+ out_dim=dim,
162
+ bias=True,
163
+ processor=processor,
164
+ qk_norm=qk_norm,
165
+ eps=1e-6,
166
+ )
167
+ else:
168
+ self.attn2 = None
169
+
170
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
171
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
172
+
173
+ if not context_pre_only:
174
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
175
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
176
+ else:
177
+ self.norm2_context = None
178
+ self.ff_context = None
179
+
180
+ # let chunk size default to None
181
+ self._chunk_size = None
182
+ self._chunk_dim = 0
183
+
184
+ # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
185
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
186
+ # Sets chunk feed-forward
187
+ self._chunk_size = chunk_size
188
+ self._chunk_dim = dim
189
+
190
+ def forward(
191
+ self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
192
+ joint_attention_kwargs=None,
193
+ ):
194
+ if self.use_dual_attention:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
196
+ hidden_states, emb=temb
197
+ )
198
+ else:
199
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
200
+
201
+ if self.context_pre_only:
202
+ norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
203
+ else:
204
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
205
+ encoder_hidden_states, emb=temb
206
+ )
207
+
208
+ # Attention.
209
+ attn_output, context_attn_output = self.attn(
210
+ hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
211
+ **({} if joint_attention_kwargs is None else joint_attention_kwargs),
212
+ )
213
+
214
+ # Process attention outputs for the `hidden_states`.
215
+ attn_output = gate_msa.unsqueeze(1) * attn_output
216
+ hidden_states = hidden_states + attn_output
217
+
218
+ if self.use_dual_attention:
219
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
220
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
221
+ hidden_states = hidden_states + attn_output2
222
+
223
+ norm_hidden_states = self.norm2(hidden_states)
224
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
225
+ if self._chunk_size is not None:
226
+ # "feed_forward_chunk_size" can be used to save memory
227
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
228
+ else:
229
+ ff_output = self.ff(norm_hidden_states)
230
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
231
+
232
+ hidden_states = hidden_states + ff_output
233
+
234
+ # Process attention outputs for the `encoder_hidden_states`.
235
+ if self.context_pre_only:
236
+ encoder_hidden_states = None
237
+ else:
238
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
239
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
240
+
241
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
242
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
243
+ if self._chunk_size is not None:
244
+ # "feed_forward_chunk_size" can be used to save memory
245
+ context_ff_output = _chunked_feed_forward(
246
+ self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
247
+ )
248
+ else:
249
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
250
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
251
+
252
+ return encoder_hidden_states, hidden_states
253
+
254
+
255
+ @maybe_allow_in_graph
256
+ class BasicTransformerBlock(nn.Module):
257
+ r"""
258
+ A basic Transformer block.
259
+
260
+ Parameters:
261
+ dim (`int`): The number of channels in the input and output.
262
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
263
+ attention_head_dim (`int`): The number of channels in each head.
264
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
265
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
266
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
267
+ num_embeds_ada_norm (:
268
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
269
+ attention_bias (:
270
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
271
+ only_cross_attention (`bool`, *optional*):
272
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
273
+ double_self_attention (`bool`, *optional*):
274
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
275
+ upcast_attention (`bool`, *optional*):
276
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
277
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
278
+ Whether to use learnable elementwise affine parameters for normalization.
279
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
280
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
281
+ final_dropout (`bool` *optional*, defaults to False):
282
+ Whether to apply a final dropout after the last feed-forward layer.
283
+ attention_type (`str`, *optional*, defaults to `"default"`):
284
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
285
+ positional_embeddings (`str`, *optional*, defaults to `None`):
286
+ The type of positional embeddings to apply to.
287
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
288
+ The maximum number of positional embeddings to apply.
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ dim: int,
294
+ num_attention_heads: int,
295
+ attention_head_dim: int,
296
+ dropout=0.0,
297
+ cross_attention_dim: Optional[int] = None,
298
+ activation_fn: str = "geglu",
299
+ num_embeds_ada_norm: Optional[int] = None,
300
+ attention_bias: bool = False,
301
+ only_cross_attention: bool = False,
302
+ double_self_attention: bool = False,
303
+ upcast_attention: bool = False,
304
+ norm_elementwise_affine: bool = True,
305
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
306
+ norm_eps: float = 1e-5,
307
+ final_dropout: bool = False,
308
+ attention_type: str = "default",
309
+ positional_embeddings: Optional[str] = None,
310
+ num_positional_embeddings: Optional[int] = None,
311
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
312
+ ada_norm_bias: Optional[int] = None,
313
+ ff_inner_dim: Optional[int] = None,
314
+ ff_bias: bool = True,
315
+ attention_out_bias: bool = True,
316
+ ):
317
+ super().__init__()
318
+ self.dim = dim
319
+ self.num_attention_heads = num_attention_heads
320
+ self.attention_head_dim = attention_head_dim
321
+ self.dropout = dropout
322
+ self.cross_attention_dim = cross_attention_dim
323
+ self.activation_fn = activation_fn
324
+ self.attention_bias = attention_bias
325
+ self.double_self_attention = double_self_attention
326
+ self.norm_elementwise_affine = norm_elementwise_affine
327
+ self.positional_embeddings = positional_embeddings
328
+ self.num_positional_embeddings = num_positional_embeddings
329
+ self.only_cross_attention = only_cross_attention
330
+
331
+ # We keep these boolean flags for backward-compatibility.
332
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
333
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
334
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
335
+ self.use_layer_norm = norm_type == "layer_norm"
336
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
337
+
338
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
339
+ raise ValueError(
340
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
341
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
342
+ )
343
+
344
+ self.norm_type = norm_type
345
+ self.num_embeds_ada_norm = num_embeds_ada_norm
346
+
347
+ if positional_embeddings and (num_positional_embeddings is None):
348
+ raise ValueError(
349
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
350
+ )
351
+
352
+ if positional_embeddings == "sinusoidal":
353
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
354
+ else:
355
+ self.pos_embed = None
356
+
357
+ # Define 3 blocks. Each block has its own normalization layer.
358
+ # 1. Self-Attn
359
+ if norm_type == "ada_norm":
360
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
361
+ elif norm_type == "ada_norm_zero":
362
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
363
+ elif norm_type == "ada_norm_continuous":
364
+ self.norm1 = AdaLayerNormContinuous(
365
+ dim,
366
+ ada_norm_continous_conditioning_embedding_dim,
367
+ norm_elementwise_affine,
368
+ norm_eps,
369
+ ada_norm_bias,
370
+ "rms_norm",
371
+ )
372
+ else:
373
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
374
+
375
+ self.attn1 = Attention(
376
+ query_dim=dim,
377
+ heads=num_attention_heads,
378
+ dim_head=attention_head_dim,
379
+ dropout=dropout,
380
+ bias=attention_bias,
381
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
382
+ upcast_attention=upcast_attention,
383
+ out_bias=attention_out_bias,
384
+ )
385
+
386
+ # 2. Cross-Attn
387
+ if cross_attention_dim is not None or double_self_attention:
388
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
389
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
390
+ # the second cross attention block.
391
+ if norm_type == "ada_norm":
392
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
393
+ elif norm_type == "ada_norm_continuous":
394
+ self.norm2 = AdaLayerNormContinuous(
395
+ dim,
396
+ ada_norm_continous_conditioning_embedding_dim,
397
+ norm_elementwise_affine,
398
+ norm_eps,
399
+ ada_norm_bias,
400
+ "rms_norm",
401
+ )
402
+ else:
403
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
404
+
405
+ self.attn2 = Attention(
406
+ query_dim=dim,
407
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
408
+ heads=num_attention_heads,
409
+ dim_head=attention_head_dim,
410
+ dropout=dropout,
411
+ bias=attention_bias,
412
+ upcast_attention=upcast_attention,
413
+ out_bias=attention_out_bias,
414
+ ) # is self-attn if encoder_hidden_states is none
415
+ else:
416
+ if norm_type == "ada_norm_single": # For Latte
417
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
418
+ else:
419
+ self.norm2 = None
420
+ self.attn2 = None
421
+
422
+ # 3. Feed-forward
423
+ if norm_type == "ada_norm_continuous":
424
+ self.norm3 = AdaLayerNormContinuous(
425
+ dim,
426
+ ada_norm_continous_conditioning_embedding_dim,
427
+ norm_elementwise_affine,
428
+ norm_eps,
429
+ ada_norm_bias,
430
+ "layer_norm",
431
+ )
432
+
433
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
434
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
435
+ elif norm_type == "layer_norm_i2vgen":
436
+ self.norm3 = None
437
+
438
+ self.ff = FeedForward(
439
+ dim,
440
+ dropout=dropout,
441
+ activation_fn=activation_fn,
442
+ final_dropout=final_dropout,
443
+ inner_dim=ff_inner_dim,
444
+ bias=ff_bias,
445
+ )
446
+
447
+ # 4. Fuser
448
+ if attention_type == "gated" or attention_type == "gated-text-image":
449
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
450
+
451
+ # 5. Scale-shift for PixArt-Alpha.
452
+ if norm_type == "ada_norm_single":
453
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
454
+
455
+ # let chunk size default to None
456
+ self._chunk_size = None
457
+ self._chunk_dim = 0
458
+
459
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
460
+ # Sets chunk feed-forward
461
+ self._chunk_size = chunk_size
462
+ self._chunk_dim = dim
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states: torch.Tensor,
467
+ attention_mask: Optional[torch.Tensor] = None,
468
+ encoder_hidden_states: Optional[torch.Tensor] = None,
469
+ encoder_attention_mask: Optional[torch.Tensor] = None,
470
+ timestep: Optional[torch.LongTensor] = None,
471
+ cross_attention_kwargs: Dict[str, Any] = None,
472
+ class_labels: Optional[torch.LongTensor] = None,
473
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
474
+ ) -> torch.Tensor:
475
+ if cross_attention_kwargs is not None:
476
+ if cross_attention_kwargs.get("scale", None) is not None:
477
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
478
+
479
+ # Notice that normalization is always applied before the real computation in the following blocks.
480
+ # 0. Self-Attention
481
+ batch_size = hidden_states.shape[0]
482
+
483
+ if self.norm_type == "ada_norm":
484
+ norm_hidden_states = self.norm1(hidden_states, timestep)
485
+ elif self.norm_type == "ada_norm_zero":
486
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
487
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
488
+ )
489
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
490
+ norm_hidden_states = self.norm1(hidden_states)
491
+ elif self.norm_type == "ada_norm_continuous":
492
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
493
+ elif self.norm_type == "ada_norm_single":
494
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
495
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
496
+ ).chunk(6, dim=1)
497
+ norm_hidden_states = self.norm1(hidden_states)
498
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
499
+ else:
500
+ raise ValueError("Incorrect norm used")
501
+
502
+ if self.pos_embed is not None:
503
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
504
+
505
+ # 1. Prepare GLIGEN inputs
506
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
507
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
508
+
509
+ attn_output = self.attn1(
510
+ norm_hidden_states,
511
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
512
+ attention_mask=attention_mask,
513
+ **cross_attention_kwargs,
514
+ )
515
+
516
+ if self.norm_type == "ada_norm_zero":
517
+ attn_output = gate_msa.unsqueeze(1) * attn_output
518
+ elif self.norm_type == "ada_norm_single":
519
+ attn_output = gate_msa * attn_output
520
+
521
+ hidden_states = attn_output + hidden_states
522
+ if hidden_states.ndim == 4:
523
+ hidden_states = hidden_states.squeeze(1)
524
+
525
+ # 1.2 GLIGEN Control
526
+ if gligen_kwargs is not None:
527
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
528
+
529
+ # 3. Cross-Attention
530
+ if self.attn2 is not None:
531
+ if self.norm_type == "ada_norm":
532
+ norm_hidden_states = self.norm2(hidden_states, timestep)
533
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
534
+ norm_hidden_states = self.norm2(hidden_states)
535
+ elif self.norm_type == "ada_norm_single":
536
+ # For PixArt norm2 isn't applied here:
537
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
538
+ norm_hidden_states = hidden_states
539
+ elif self.norm_type == "ada_norm_continuous":
540
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
541
+ else:
542
+ raise ValueError("Incorrect norm")
543
+
544
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
545
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
546
+
547
+ attn_output = self.attn2(
548
+ norm_hidden_states,
549
+ encoder_hidden_states=encoder_hidden_states,
550
+ attention_mask=encoder_attention_mask,
551
+ **cross_attention_kwargs,
552
+ )
553
+ hidden_states = attn_output + hidden_states
554
+
555
+ # 4. Feed-forward
556
+ # i2vgen doesn't have this norm 🤷‍♂️
557
+ if self.norm_type == "ada_norm_continuous":
558
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
559
+ elif not self.norm_type == "ada_norm_single":
560
+ norm_hidden_states = self.norm3(hidden_states)
561
+
562
+ if self.norm_type == "ada_norm_zero":
563
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
564
+
565
+ if self.norm_type == "ada_norm_single":
566
+ norm_hidden_states = self.norm2(hidden_states)
567
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
568
+
569
+ if self._chunk_size is not None:
570
+ # "feed_forward_chunk_size" can be used to save memory
571
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
572
+ else:
573
+ ff_output = self.ff(norm_hidden_states)
574
+
575
+ if self.norm_type == "ada_norm_zero":
576
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
577
+ elif self.norm_type == "ada_norm_single":
578
+ ff_output = gate_mlp * ff_output
579
+
580
+ hidden_states = ff_output + hidden_states
581
+ if hidden_states.ndim == 4:
582
+ hidden_states = hidden_states.squeeze(1)
583
+
584
+ return hidden_states
585
+
586
+
587
+ class LuminaFeedForward(nn.Module):
588
+ r"""
589
+ A feed-forward layer.
590
+
591
+ Parameters:
592
+ hidden_size (`int`):
593
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
594
+ hidden representations.
595
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
596
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
597
+ of this value.
598
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
599
+ dimension. Defaults to None.
600
+ """
601
+
602
+ def __init__(
603
+ self,
604
+ dim: int,
605
+ inner_dim: int,
606
+ multiple_of: Optional[int] = 256,
607
+ ffn_dim_multiplier: Optional[float] = None,
608
+ ):
609
+ super().__init__()
610
+ inner_dim = int(2 * inner_dim / 3)
611
+ # custom hidden_size factor multiplier
612
+ if ffn_dim_multiplier is not None:
613
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
614
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
615
+
616
+ self.linear_1 = nn.Linear(
617
+ dim,
618
+ inner_dim,
619
+ bias=False,
620
+ )
621
+ self.linear_2 = nn.Linear(
622
+ inner_dim,
623
+ dim,
624
+ bias=False,
625
+ )
626
+ self.linear_3 = nn.Linear(
627
+ dim,
628
+ inner_dim,
629
+ bias=False,
630
+ )
631
+ self.silu = FP32SiLU()
632
+
633
+ def forward(self, x):
634
+ return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
635
+
636
+
637
+ @maybe_allow_in_graph
638
+ class TemporalBasicTransformerBlock(nn.Module):
639
+ r"""
640
+ A basic Transformer block for video like data.
641
+
642
+ Parameters:
643
+ dim (`int`): The number of channels in the input and output.
644
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
645
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
646
+ attention_head_dim (`int`): The number of channels in each head.
647
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
648
+ """
649
+
650
+ def __init__(
651
+ self,
652
+ dim: int,
653
+ time_mix_inner_dim: int,
654
+ num_attention_heads: int,
655
+ attention_head_dim: int,
656
+ cross_attention_dim: Optional[int] = None,
657
+ ):
658
+ super().__init__()
659
+ self.is_res = dim == time_mix_inner_dim
660
+
661
+ self.norm_in = nn.LayerNorm(dim)
662
+
663
+ # Define 3 blocks. Each block has its own normalization layer.
664
+ # 1. Self-Attn
665
+ self.ff_in = FeedForward(
666
+ dim,
667
+ dim_out=time_mix_inner_dim,
668
+ activation_fn="geglu",
669
+ )
670
+
671
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
672
+ self.attn1 = Attention(
673
+ query_dim=time_mix_inner_dim,
674
+ heads=num_attention_heads,
675
+ dim_head=attention_head_dim,
676
+ cross_attention_dim=None,
677
+ )
678
+
679
+ # 2. Cross-Attn
680
+ if cross_attention_dim is not None:
681
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
682
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
683
+ # the second cross attention block.
684
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
685
+ self.attn2 = Attention(
686
+ query_dim=time_mix_inner_dim,
687
+ cross_attention_dim=cross_attention_dim,
688
+ heads=num_attention_heads,
689
+ dim_head=attention_head_dim,
690
+ ) # is self-attn if encoder_hidden_states is none
691
+ else:
692
+ self.norm2 = None
693
+ self.attn2 = None
694
+
695
+ # 3. Feed-forward
696
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
697
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
698
+
699
+ # let chunk size default to None
700
+ self._chunk_size = None
701
+ self._chunk_dim = None
702
+
703
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
704
+ # Sets chunk feed-forward
705
+ self._chunk_size = chunk_size
706
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
707
+ self._chunk_dim = 1
708
+
709
+ def forward(
710
+ self,
711
+ hidden_states: torch.Tensor,
712
+ num_frames: int,
713
+ encoder_hidden_states: Optional[torch.Tensor] = None,
714
+ ) -> torch.Tensor:
715
+ # Notice that normalization is always applied before the real computation in the following blocks.
716
+ # 0. Self-Attention
717
+ batch_size = hidden_states.shape[0]
718
+
719
+ batch_frames, seq_length, channels = hidden_states.shape
720
+ batch_size = batch_frames // num_frames
721
+
722
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
723
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
724
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
725
+
726
+ residual = hidden_states
727
+ hidden_states = self.norm_in(hidden_states)
728
+
729
+ if self._chunk_size is not None:
730
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
731
+ else:
732
+ hidden_states = self.ff_in(hidden_states)
733
+
734
+ if self.is_res:
735
+ hidden_states = hidden_states + residual
736
+
737
+ norm_hidden_states = self.norm1(hidden_states)
738
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
739
+ hidden_states = attn_output + hidden_states
740
+
741
+ # 3. Cross-Attention
742
+ if self.attn2 is not None:
743
+ norm_hidden_states = self.norm2(hidden_states)
744
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
745
+ hidden_states = attn_output + hidden_states
746
+
747
+ # 4. Feed-forward
748
+ norm_hidden_states = self.norm3(hidden_states)
749
+
750
+ if self._chunk_size is not None:
751
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
752
+ else:
753
+ ff_output = self.ff(norm_hidden_states)
754
+
755
+ if self.is_res:
756
+ hidden_states = ff_output + hidden_states
757
+ else:
758
+ hidden_states = ff_output
759
+
760
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
761
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
762
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
763
+
764
+ return hidden_states
765
+
766
+
767
+ class SkipFFTransformerBlock(nn.Module):
768
+ def __init__(
769
+ self,
770
+ dim: int,
771
+ num_attention_heads: int,
772
+ attention_head_dim: int,
773
+ kv_input_dim: int,
774
+ kv_input_dim_proj_use_bias: bool,
775
+ dropout=0.0,
776
+ cross_attention_dim: Optional[int] = None,
777
+ attention_bias: bool = False,
778
+ attention_out_bias: bool = True,
779
+ ):
780
+ super().__init__()
781
+ if kv_input_dim != dim:
782
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
783
+ else:
784
+ self.kv_mapper = None
785
+
786
+ self.norm1 = RMSNorm(dim, 1e-06)
787
+
788
+ self.attn1 = Attention(
789
+ query_dim=dim,
790
+ heads=num_attention_heads,
791
+ dim_head=attention_head_dim,
792
+ dropout=dropout,
793
+ bias=attention_bias,
794
+ cross_attention_dim=cross_attention_dim,
795
+ out_bias=attention_out_bias,
796
+ )
797
+
798
+ self.norm2 = RMSNorm(dim, 1e-06)
799
+
800
+ self.attn2 = Attention(
801
+ query_dim=dim,
802
+ cross_attention_dim=cross_attention_dim,
803
+ heads=num_attention_heads,
804
+ dim_head=attention_head_dim,
805
+ dropout=dropout,
806
+ bias=attention_bias,
807
+ out_bias=attention_out_bias,
808
+ )
809
+
810
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
811
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
812
+
813
+ if self.kv_mapper is not None:
814
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
815
+
816
+ norm_hidden_states = self.norm1(hidden_states)
817
+
818
+ attn_output = self.attn1(
819
+ norm_hidden_states,
820
+ encoder_hidden_states=encoder_hidden_states,
821
+ **cross_attention_kwargs,
822
+ )
823
+
824
+ hidden_states = attn_output + hidden_states
825
+
826
+ norm_hidden_states = self.norm2(hidden_states)
827
+
828
+ attn_output = self.attn2(
829
+ norm_hidden_states,
830
+ encoder_hidden_states=encoder_hidden_states,
831
+ **cross_attention_kwargs,
832
+ )
833
+
834
+ hidden_states = attn_output + hidden_states
835
+
836
+ return hidden_states
837
+
838
+
839
+ @maybe_allow_in_graph
840
+ class FreeNoiseTransformerBlock(nn.Module):
841
+ r"""
842
+ A FreeNoise Transformer block.
843
+
844
+ Parameters:
845
+ dim (`int`):
846
+ The number of channels in the input and output.
847
+ num_attention_heads (`int`):
848
+ The number of heads to use for multi-head attention.
849
+ attention_head_dim (`int`):
850
+ The number of channels in each head.
851
+ dropout (`float`, *optional*, defaults to 0.0):
852
+ The dropout probability to use.
853
+ cross_attention_dim (`int`, *optional*):
854
+ The size of the encoder_hidden_states vector for cross attention.
855
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
856
+ Activation function to be used in feed-forward.
857
+ num_embeds_ada_norm (`int`, *optional*):
858
+ The number of diffusion steps used during training. See `Transformer2DModel`.
859
+ attention_bias (`bool`, defaults to `False`):
860
+ Configure if the attentions should contain a bias parameter.
861
+ only_cross_attention (`bool`, defaults to `False`):
862
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
863
+ double_self_attention (`bool`, defaults to `False`):
864
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
865
+ upcast_attention (`bool`, defaults to `False`):
866
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
867
+ norm_elementwise_affine (`bool`, defaults to `True`):
868
+ Whether to use learnable elementwise affine parameters for normalization.
869
+ norm_type (`str`, defaults to `"layer_norm"`):
870
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
871
+ final_dropout (`bool` defaults to `False`):
872
+ Whether to apply a final dropout after the last feed-forward layer.
873
+ attention_type (`str`, defaults to `"default"`):
874
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
875
+ positional_embeddings (`str`, *optional*):
876
+ The type of positional embeddings to apply to.
877
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
878
+ The maximum number of positional embeddings to apply.
879
+ ff_inner_dim (`int`, *optional*):
880
+ Hidden dimension of feed-forward MLP.
881
+ ff_bias (`bool`, defaults to `True`):
882
+ Whether or not to use bias in feed-forward MLP.
883
+ attention_out_bias (`bool`, defaults to `True`):
884
+ Whether or not to use bias in attention output project layer.
885
+ context_length (`int`, defaults to `16`):
886
+ The maximum number of frames that the FreeNoise block processes at once.
887
+ context_stride (`int`, defaults to `4`):
888
+ The number of frames to be skipped before starting to process a new batch of `context_length` frames.
889
+ weighting_scheme (`str`, defaults to `"pyramid"`):
890
+ The weighting scheme to use for weighting averaging of processed latent frames. As described in the
891
+ Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
892
+ used.
893
+ """
894
+
895
+ def __init__(
896
+ self,
897
+ dim: int,
898
+ num_attention_heads: int,
899
+ attention_head_dim: int,
900
+ dropout: float = 0.0,
901
+ cross_attention_dim: Optional[int] = None,
902
+ activation_fn: str = "geglu",
903
+ num_embeds_ada_norm: Optional[int] = None,
904
+ attention_bias: bool = False,
905
+ only_cross_attention: bool = False,
906
+ double_self_attention: bool = False,
907
+ upcast_attention: bool = False,
908
+ norm_elementwise_affine: bool = True,
909
+ norm_type: str = "layer_norm",
910
+ norm_eps: float = 1e-5,
911
+ final_dropout: bool = False,
912
+ positional_embeddings: Optional[str] = None,
913
+ num_positional_embeddings: Optional[int] = None,
914
+ ff_inner_dim: Optional[int] = None,
915
+ ff_bias: bool = True,
916
+ attention_out_bias: bool = True,
917
+ context_length: int = 16,
918
+ context_stride: int = 4,
919
+ weighting_scheme: str = "pyramid",
920
+ ):
921
+ super().__init__()
922
+ self.dim = dim
923
+ self.num_attention_heads = num_attention_heads
924
+ self.attention_head_dim = attention_head_dim
925
+ self.dropout = dropout
926
+ self.cross_attention_dim = cross_attention_dim
927
+ self.activation_fn = activation_fn
928
+ self.attention_bias = attention_bias
929
+ self.double_self_attention = double_self_attention
930
+ self.norm_elementwise_affine = norm_elementwise_affine
931
+ self.positional_embeddings = positional_embeddings
932
+ self.num_positional_embeddings = num_positional_embeddings
933
+ self.only_cross_attention = only_cross_attention
934
+
935
+ self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
936
+
937
+ # We keep these boolean flags for backward-compatibility.
938
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
939
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
940
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
941
+ self.use_layer_norm = norm_type == "layer_norm"
942
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
943
+
944
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
945
+ raise ValueError(
946
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
947
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
948
+ )
949
+
950
+ self.norm_type = norm_type
951
+ self.num_embeds_ada_norm = num_embeds_ada_norm
952
+
953
+ if positional_embeddings and (num_positional_embeddings is None):
954
+ raise ValueError(
955
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
956
+ )
957
+
958
+ if positional_embeddings == "sinusoidal":
959
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
960
+ else:
961
+ self.pos_embed = None
962
+
963
+ # Define 3 blocks. Each block has its own normalization layer.
964
+ # 1. Self-Attn
965
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
966
+
967
+ self.attn1 = Attention(
968
+ query_dim=dim,
969
+ heads=num_attention_heads,
970
+ dim_head=attention_head_dim,
971
+ dropout=dropout,
972
+ bias=attention_bias,
973
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
974
+ upcast_attention=upcast_attention,
975
+ out_bias=attention_out_bias,
976
+ )
977
+
978
+ # 2. Cross-Attn
979
+ if cross_attention_dim is not None or double_self_attention:
980
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
981
+
982
+ self.attn2 = Attention(
983
+ query_dim=dim,
984
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
985
+ heads=num_attention_heads,
986
+ dim_head=attention_head_dim,
987
+ dropout=dropout,
988
+ bias=attention_bias,
989
+ upcast_attention=upcast_attention,
990
+ out_bias=attention_out_bias,
991
+ ) # is self-attn if encoder_hidden_states is none
992
+
993
+ # 3. Feed-forward
994
+ self.ff = FeedForward(
995
+ dim,
996
+ dropout=dropout,
997
+ activation_fn=activation_fn,
998
+ final_dropout=final_dropout,
999
+ inner_dim=ff_inner_dim,
1000
+ bias=ff_bias,
1001
+ )
1002
+
1003
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1004
+
1005
+ # let chunk size default to None
1006
+ self._chunk_size = None
1007
+ self._chunk_dim = 0
1008
+
1009
+ def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1010
+ frame_indices = []
1011
+ for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1012
+ window_start = i
1013
+ window_end = min(num_frames, i + self.context_length)
1014
+ frame_indices.append((window_start, window_end))
1015
+ return frame_indices
1016
+
1017
+ def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1018
+ if weighting_scheme == "flat":
1019
+ weights = [1.0] * num_frames
1020
+
1021
+ elif weighting_scheme == "pyramid":
1022
+ if num_frames % 2 == 0:
1023
+ # num_frames = 4 => [1, 2, 2, 1]
1024
+ mid = num_frames // 2
1025
+ weights = list(range(1, mid + 1))
1026
+ weights = weights + weights[::-1]
1027
+ else:
1028
+ # num_frames = 5 => [1, 2, 3, 2, 1]
1029
+ mid = (num_frames + 1) // 2
1030
+ weights = list(range(1, mid))
1031
+ weights = weights + [mid] + weights[::-1]
1032
+
1033
+ elif weighting_scheme == "delayed_reverse_sawtooth":
1034
+ if num_frames % 2 == 0:
1035
+ # num_frames = 4 => [0.01, 2, 2, 1]
1036
+ mid = num_frames // 2
1037
+ weights = [0.01] * (mid - 1) + [mid]
1038
+ weights = weights + list(range(mid, 0, -1))
1039
+ else:
1040
+ # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
1041
+ mid = (num_frames + 1) // 2
1042
+ weights = [0.01] * mid
1043
+ weights = weights + list(range(mid, 0, -1))
1044
+ else:
1045
+ raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1046
+
1047
+ return weights
1048
+
1049
+ def set_free_noise_properties(
1050
+ self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1051
+ ) -> None:
1052
+ self.context_length = context_length
1053
+ self.context_stride = context_stride
1054
+ self.weighting_scheme = weighting_scheme
1055
+
1056
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1057
+ # Sets chunk feed-forward
1058
+ self._chunk_size = chunk_size
1059
+ self._chunk_dim = dim
1060
+
1061
+ def forward(
1062
+ self,
1063
+ hidden_states: torch.Tensor,
1064
+ attention_mask: Optional[torch.Tensor] = None,
1065
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1066
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1067
+ cross_attention_kwargs: Dict[str, Any] = None,
1068
+ *args,
1069
+ **kwargs,
1070
+ ) -> torch.Tensor:
1071
+ if cross_attention_kwargs is not None:
1072
+ if cross_attention_kwargs.get("scale", None) is not None:
1073
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1074
+
1075
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1076
+
1077
+ # hidden_states: [B x H x W, F, C]
1078
+ device = hidden_states.device
1079
+ dtype = hidden_states.dtype
1080
+
1081
+ num_frames = hidden_states.size(1)
1082
+ frame_indices = self._get_frame_indices(num_frames)
1083
+ frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1084
+ frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1085
+ is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1086
+
1087
+ # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
1088
+ # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
1089
+ # [(0, 16), (4, 20), (8, 24), (10, 26)]
1090
+ if not is_last_frame_batch_complete:
1091
+ if num_frames < self.context_length:
1092
+ raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
1093
+ last_frame_batch_length = num_frames - frame_indices[-1][1]
1094
+ frame_indices.append((num_frames - self.context_length, num_frames))
1095
+
1096
+ num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1097
+ accumulated_values = torch.zeros_like(hidden_states)
1098
+
1099
+ for i, (frame_start, frame_end) in enumerate(frame_indices):
1100
+ # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
1101
+ # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
1102
+ # essentially a non-multiple of `context_length`.
1103
+ weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1104
+ weights *= frame_weights
1105
+
1106
+ hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1107
+
1108
+ # Notice that normalization is always applied before the real computation in the following blocks.
1109
+ # 1. Self-Attention
1110
+ norm_hidden_states = self.norm1(hidden_states_chunk)
1111
+
1112
+ if self.pos_embed is not None:
1113
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1114
+
1115
+ attn_output = self.attn1(
1116
+ norm_hidden_states,
1117
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1118
+ attention_mask=attention_mask,
1119
+ **cross_attention_kwargs,
1120
+ )
1121
+
1122
+ hidden_states_chunk = attn_output + hidden_states_chunk
1123
+ if hidden_states_chunk.ndim == 4:
1124
+ hidden_states_chunk = hidden_states_chunk.squeeze(1)
1125
+
1126
+ # 2. Cross-Attention
1127
+ if self.attn2 is not None:
1128
+ norm_hidden_states = self.norm2(hidden_states_chunk)
1129
+
1130
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1131
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
1132
+
1133
+ attn_output = self.attn2(
1134
+ norm_hidden_states,
1135
+ encoder_hidden_states=encoder_hidden_states,
1136
+ attention_mask=encoder_attention_mask,
1137
+ **cross_attention_kwargs,
1138
+ )
1139
+ hidden_states_chunk = attn_output + hidden_states_chunk
1140
+
1141
+ if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1142
+ accumulated_values[:, -last_frame_batch_length:] += (
1143
+ hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1144
+ )
1145
+ num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
1146
+ else:
1147
+ accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1148
+ num_times_accumulated[:, frame_start:frame_end] += weights
1149
+
1150
+ # TODO(aryan): Maybe this could be done in a better way.
1151
+ #
1152
+ # Previously, this was:
1153
+ # hidden_states = torch.where(
1154
+ # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1155
+ # )
1156
+ #
1157
+ # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
1158
+ # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1159
+ # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
1160
+ # looked into this deeply because other memory optimizations led to more pronounced reductions.
1161
+ hidden_states = torch.cat(
1162
+ [
1163
+ torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1164
+ for accumulated_split, num_times_split in zip(
1165
+ accumulated_values.split(self.context_length, dim=1),
1166
+ num_times_accumulated.split(self.context_length, dim=1),
1167
+ )
1168
+ ],
1169
+ dim=1,
1170
+ ).to(dtype)
1171
+
1172
+ # 3. Feed-forward
1173
+ norm_hidden_states = self.norm3(hidden_states)
1174
+
1175
+ if self._chunk_size is not None:
1176
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1177
+ else:
1178
+ ff_output = self.ff(norm_hidden_states)
1179
+
1180
+ hidden_states = ff_output + hidden_states
1181
+ if hidden_states.ndim == 4:
1182
+ hidden_states = hidden_states.squeeze(1)
1183
+
1184
+ return hidden_states
1185
+
1186
+
1187
+ class FeedForward(nn.Module):
1188
+ r"""
1189
+ A feed-forward layer.
1190
+
1191
+ Parameters:
1192
+ dim (`int`): The number of channels in the input.
1193
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1194
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1195
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1196
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1197
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1198
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1199
+ """
1200
+
1201
+ def __init__(
1202
+ self,
1203
+ dim: int,
1204
+ dim_out: Optional[int] = None,
1205
+ mult: int = 4,
1206
+ dropout: float = 0.0,
1207
+ activation_fn: str = "geglu",
1208
+ final_dropout: bool = False,
1209
+ inner_dim=None,
1210
+ bias: bool = True,
1211
+ ):
1212
+ super().__init__()
1213
+ if inner_dim is None:
1214
+ inner_dim = int(dim * mult)
1215
+ dim_out = dim_out if dim_out is not None else dim
1216
+
1217
+ if activation_fn == "gelu":
1218
+ act_fn = GELU(dim, inner_dim, bias=bias)
1219
+ elif activation_fn == "gelu-approximate":
1220
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1221
+ elif activation_fn == "geglu":
1222
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1223
+ elif activation_fn == "geglu-approximate":
1224
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1225
+ elif activation_fn == "swiglu":
1226
+ act_fn = SwiGLU(dim, inner_dim, bias=bias)
1227
+
1228
+ self.net = nn.ModuleList([])
1229
+ # project in
1230
+ self.net.append(act_fn)
1231
+ # project dropout
1232
+ self.net.append(nn.Dropout(dropout))
1233
+ # project out
1234
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1235
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1236
+ if final_dropout:
1237
+ self.net.append(nn.Dropout(dropout))
1238
+
1239
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1240
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1241
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1242
+ deprecate("scale", "1.0.0", deprecation_message)
1243
+ for module in self.net:
1244
+ hidden_states = module(hidden_states)
1245
+ return hidden_states
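Note on the block above: the FreeNoise-style forward pass splits the frame axis into overlapping context windows, runs attention per window, and then blends the windows back together by accumulating weighted chunks and dividing by per-frame coverage counts. The standalone sketch below is not part of the committed file; it only replays that accumulate-and-normalize step with a simplified window layout and uniform weights so the tensor bookkeeping is easy to verify.

```python
import torch

def blend_context_windows(frames: torch.Tensor, context_length: int = 16, context_stride: int = 4) -> torch.Tensor:
    # frames: (batch, num_frames, channels) stand-in for per-frame hidden states.
    batch, num_frames, _ = frames.shape
    accumulated = torch.zeros_like(frames)
    counts = torch.zeros((batch, num_frames, 1), dtype=frames.dtype)

    # Overlapping (start, end) windows; the last one may be shorter than
    # `context_length` when num_frames is not a multiple of the stride.
    starts = list(range(0, max(num_frames - context_length, 0) + 1, context_stride))
    if starts[-1] + context_length < num_frames:
        starts.append(num_frames - context_length)
    windows = [(s, min(s + context_length, num_frames)) for s in starts]

    for start, end in windows:
        chunk = frames[:, start:end]  # in the real block this chunk would go through attention
        weights = torch.ones((batch, end - start, 1), dtype=frames.dtype)
        accumulated[:, start:end] += chunk * weights
        counts[:, start:end] += weights

    # Average every frame by the number of windows that covered it.
    return accumulated / counts.clamp(min=1)

x = torch.randn(1, 19, 8)                  # 19 frames: not a multiple of 16
out = blend_context_windows(x)
assert out.shape == x.shape
assert torch.allclose(out, x, atol=1e-5)   # with uniform weights the blend is an identity
```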
models/resampler.py ADDED
@@ -0,0 +1,304 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.models.embeddings import Timesteps, TimestepEmbedding
8
+
9
+ def get_timestep_embedding(
10
+ timesteps: torch.Tensor,
11
+ embedding_dim: int,
12
+ flip_sin_to_cos: bool = False,
13
+ downscale_freq_shift: float = 1,
14
+ scale: float = 1,
15
+ max_period: int = 10000,
16
+ ):
17
+ """
18
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
19
+
20
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
21
+ These may be fractional.
22
+ :param embedding_dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
24
+ """
25
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
26
+
27
+ half_dim = embedding_dim // 2
28
+ exponent = -math.log(max_period) * torch.arange(
29
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
30
+ )
31
+ exponent = exponent / (half_dim - downscale_freq_shift)
32
+
33
+ emb = torch.exp(exponent)
34
+ emb = timesteps[:, None].float() * emb[None, :]
35
+
36
+ # scale embeddings
37
+ emb = scale * emb
38
+
39
+ # concat sine and cosine embeddings
40
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
41
+
42
+ # flip sine and cosine embeddings
43
+ if flip_sin_to_cos:
44
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
45
+
46
+ # zero pad
47
+ if embedding_dim % 2 == 1:
48
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
49
+ return emb
50
+
51
+
52
+ # FFN
53
+ def FeedForward(dim, mult=4):
54
+ inner_dim = int(dim * mult)
55
+ return nn.Sequential(
56
+ nn.LayerNorm(dim),
57
+ nn.Linear(dim, inner_dim, bias=False),
58
+ nn.GELU(),
59
+ nn.Linear(inner_dim, dim, bias=False),
60
+ )
61
+
62
+
63
+ def reshape_tensor(x, heads):
64
+ bs, length, width = x.shape
65
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
66
+ x = x.view(bs, length, heads, -1)
67
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
68
+ x = x.transpose(1, 2)
69
+ # (bs, n_heads, length, dim_per_head), made contiguous for the attention matmuls
70
+ x = x.reshape(bs, heads, length, -1)
71
+ return x
72
+
73
+
74
+ class PerceiverAttention(nn.Module):
75
+ def __init__(self, *, dim, dim_head=64, heads=8):
76
+ super().__init__()
77
+ self.scale = dim_head**-0.5
78
+ self.dim_head = dim_head
79
+ self.heads = heads
80
+ inner_dim = dim_head * heads
81
+
82
+ self.norm1 = nn.LayerNorm(dim)
83
+ self.norm2 = nn.LayerNorm(dim)
84
+
85
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
86
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
87
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
88
+
89
+
90
+ def forward(self, x, latents, shift=None, scale=None):
91
+ """
92
+ Args:
93
+ x (torch.Tensor): image features
94
+ shape (b, n1, D)
95
+ latents (torch.Tensor): latent features
96
+ shape (b, n2, D)
97
+ """
98
+ x = self.norm1(x)
99
+ latents = self.norm2(latents)
100
+
101
+ if shift is not None and scale is not None:
102
+ latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
103
+
104
+ b, l, _ = latents.shape
105
+
106
+ q = self.to_q(latents)
107
+ kv_input = torch.cat((x, latents), dim=-2)
108
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
109
+
110
+ q = reshape_tensor(q, self.heads)
111
+ k = reshape_tensor(k, self.heads)
112
+ v = reshape_tensor(v, self.heads)
113
+
114
+ # attention
115
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
116
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
117
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
118
+ out = weight @ v
119
+
120
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
121
+
122
+ return self.to_out(out)
123
+
124
+
125
+ class Resampler(nn.Module):
126
+ def __init__(
127
+ self,
128
+ dim=1024,
129
+ depth=8,
130
+ dim_head=64,
131
+ heads=16,
132
+ num_queries=8,
133
+ embedding_dim=768,
134
+ output_dim=1024,
135
+ ff_mult=4,
136
+ *args,
137
+ **kwargs,
138
+ ):
139
+ super().__init__()
140
+
141
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
142
+
143
+ self.proj_in = nn.Linear(embedding_dim, dim)
144
+
145
+ self.proj_out = nn.Linear(dim, output_dim)
146
+ self.norm_out = nn.LayerNorm(output_dim)
147
+
148
+ self.layers = nn.ModuleList([])
149
+ for _ in range(depth):
150
+ self.layers.append(
151
+ nn.ModuleList(
152
+ [
153
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
154
+ FeedForward(dim=dim, mult=ff_mult),
155
+ ]
156
+ )
157
+ )
158
+
159
+ def forward(self, x):
160
+
161
+ latents = self.latents.repeat(x.size(0), 1, 1)
162
+
163
+ x = self.proj_in(x)
164
+
165
+ for attn, ff in self.layers:
166
+ latents = attn(x, latents) + latents
167
+ latents = ff(latents) + latents
168
+
169
+ latents = self.proj_out(latents)
170
+ return self.norm_out(latents)
171
+
172
+
173
+ class TimeResampler(nn.Module):
174
+ def __init__(
175
+ self,
176
+ dim=1024,
177
+ depth=8,
178
+ dim_head=64,
179
+ heads=16,
180
+ num_queries=8,
181
+ embedding_dim=768,
182
+ output_dim=1024,
183
+ ff_mult=4,
184
+ timestep_in_dim=320,
185
+ timestep_flip_sin_to_cos=True,
186
+ timestep_freq_shift=0,
187
+ ):
188
+ super().__init__()
189
+
190
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
191
+
192
+ self.proj_in = nn.Linear(embedding_dim, dim)
193
+
194
+ self.proj_out = nn.Linear(dim, output_dim)
195
+ self.norm_out = nn.LayerNorm(output_dim)
196
+
197
+ self.layers = nn.ModuleList([])
198
+ for _ in range(depth):
199
+ self.layers.append(
200
+ nn.ModuleList(
201
+ [
202
+ # msa
203
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
204
+ # ff
205
+ FeedForward(dim=dim, mult=ff_mult),
206
+ # adaLN
207
+ nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
208
+ ]
209
+ )
210
+ )
211
+
212
+ # time
213
+ self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
214
+ self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
215
+
216
+ # adaLN
217
+ # self.adaLN_modulation = nn.Sequential(
218
+ # nn.SiLU(),
219
+ # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
220
+ # )
221
+
222
+
223
+ def forward(self, x, timestep, need_temb=False):
224
+ timestep_emb = self.embedding_time(x, timestep) # bs, dim
225
+
226
+ latents = self.latents.repeat(x.size(0), 1, 1)
227
+
228
+ x = self.proj_in(x)
229
+ x = x + timestep_emb[:, None]
230
+
231
+ for attn, ff, adaLN_modulation in self.layers:
232
+ shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
233
+ latents = attn(x, latents, shift_msa, scale_msa) + latents
234
+
235
+ res = latents
236
+ for idx_ff in range(len(ff)):
237
+ layer_ff = ff[idx_ff]
238
+ latents = layer_ff(latents)
239
+ if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN
240
+ latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
241
+ latents = latents + res
242
+
243
+ # latents = ff(latents) + latents
244
+
245
+ latents = self.proj_out(latents)
246
+ latents = self.norm_out(latents)
247
+
248
+ if need_temb:
249
+ return latents, timestep_emb
250
+ else:
251
+ return latents
252
+
253
+
254
+
255
+ def embedding_time(self, sample, timestep):
256
+
257
+ # 1. time
258
+ timesteps = timestep
259
+ if not torch.is_tensor(timesteps):
260
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
261
+ # This would be a good case for the `match` statement (Python 3.10+)
262
+ is_mps = sample.device.type == "mps"
263
+ if isinstance(timestep, float):
264
+ dtype = torch.float32 if is_mps else torch.float64
265
+ else:
266
+ dtype = torch.int32 if is_mps else torch.int64
267
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
268
+ elif len(timesteps.shape) == 0:
269
+ timesteps = timesteps[None].to(sample.device)
270
+
271
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
272
+ timesteps = timesteps.expand(sample.shape[0])
273
+
274
+ t_emb = self.time_proj(timesteps)
275
+
276
+ # timesteps does not contain any weights and will always return f32 tensors
277
+ # but time_embedding might actually be running in fp16. so we need to cast here.
278
+ # there might be better ways to encapsulate this.
279
+ t_emb = t_emb.to(dtype=sample.dtype)
280
+
281
+ emb = self.time_embedding(t_emb, None)
282
+ return emb
283
+
284
+
285
+
286
+
287
+
288
+ if __name__ == '__main__':
289
+ model = TimeResampler(
290
+ dim=1280,
291
+ depth=4,
292
+ dim_head=64,
293
+ heads=20,
294
+ num_queries=16,
295
+ embedding_dim=512,
296
+ output_dim=2048,
297
+ ff_mult=4,
298
+ timestep_in_dim=320,
299
+ timestep_flip_sin_to_cos=True,
300
+ timestep_freq_shift=0,
301
302
+ )
303
+
304
+
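A quick shape check of the `TimeResampler` above may be useful: it resamples image-encoder tokens into a fixed number of query tokens, conditioned on the diffusion timestep via adaLN. The snippet below is only a smoke test; the 1152-dim SigLIP features, the 729-token sequence length, and the 2432-dim output (the joint-attention width of SD3.5-Large) are illustrative assumptions, not values read from the shipped checkpoint.

```python
import torch
from models.resampler import TimeResampler

resampler = TimeResampler(
    dim=1280,
    depth=4,
    dim_head=64,
    heads=20,
    num_queries=64,        # 64 image tokens, matching the model card
    embedding_dim=1152,    # assumed SigLIP hidden size
    output_dim=2432,       # assumed SD3.5-Large joint-attention width
    ff_mult=4,
    timestep_in_dim=320,
    timestep_flip_sin_to_cos=True,
    timestep_freq_shift=0,
)

image_tokens = torch.randn(2, 729, 1152)   # e.g. SigLIP patch tokens for a batch of 2
timestep = torch.tensor([999, 999])        # one denoising timestep per sample
ip_tokens, temb = resampler(image_tokens, timestep, need_temb=True)
print(ip_tokens.shape, temb.shape)         # torch.Size([2, 64, 2432]) torch.Size([2, 1280])
```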
models/transformer_sd3.py ADDED
@@ -0,0 +1,375 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
+ from .attention import JointTransformerBlock
24
+ from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+ from diffusers.models.normalization import AdaLayerNormContinuous
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
29
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
36
+ """
37
+ The Transformer model introduced in Stable Diffusion 3.
38
+
39
+ Reference: https://arxiv.org/abs/2403.03206
40
+
41
+ Parameters:
42
+ sample_size (`int`): The width of the latent images. This is fixed during training since
43
+ it is used to learn a number of position embeddings.
44
+ patch_size (`int`): Patch size to turn the input data into small patches.
45
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
46
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
47
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
48
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
49
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
50
+ caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
51
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
52
+ out_channels (`int`, defaults to 16): Number of output channels.
53
+
54
+ """
55
+
56
+ _supports_gradient_checkpointing = True
57
+
58
+ @register_to_config
59
+ def __init__(
60
+ self,
61
+ sample_size: int = 128,
62
+ patch_size: int = 2,
63
+ in_channels: int = 16,
64
+ num_layers: int = 18,
65
+ attention_head_dim: int = 64,
66
+ num_attention_heads: int = 18,
67
+ joint_attention_dim: int = 4096,
68
+ caption_projection_dim: int = 1152,
69
+ pooled_projection_dim: int = 2048,
70
+ out_channels: int = 16,
71
+ pos_embed_max_size: int = 96,
72
+ dual_attention_layers: Tuple[
73
+ int, ...
74
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
75
+ qk_norm: Optional[str] = None,
76
+ ):
77
+ super().__init__()
78
+ default_out_channels = in_channels
79
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
80
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
81
+
82
+ self.pos_embed = PatchEmbed(
83
+ height=self.config.sample_size,
84
+ width=self.config.sample_size,
85
+ patch_size=self.config.patch_size,
86
+ in_channels=self.config.in_channels,
87
+ embed_dim=self.inner_dim,
88
+ pos_embed_max_size=pos_embed_max_size, # hard-code for now.
89
+ )
90
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
91
+ embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
92
+ )
93
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
94
+
95
+ # `attention_head_dim` is doubled to account for the mixing.
96
+ # It needs to be crafted when we get the actual checkpoints.
97
+ self.transformer_blocks = nn.ModuleList(
98
+ [
99
+ JointTransformerBlock(
100
+ dim=self.inner_dim,
101
+ num_attention_heads=self.config.num_attention_heads,
102
+ attention_head_dim=self.config.attention_head_dim,
103
+ context_pre_only=i == num_layers - 1,
104
+ qk_norm=qk_norm,
105
+ use_dual_attention=True if i in dual_attention_layers else False,
106
+ )
107
+ for i in range(self.config.num_layers)
108
+ ]
109
+ )
110
+
111
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
112
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
113
+
114
+ self.gradient_checkpointing = False
115
+
116
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
117
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
118
+ """
119
+ Enables [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers)
+ for the feed-forward layers of the transformer blocks.
121
+
122
+ Parameters:
123
+ chunk_size (`int`, *optional*):
124
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
125
+ over each tensor of dim=`dim`.
126
+ dim (`int`, *optional*, defaults to `0`):
127
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
128
+ or dim=1 (sequence length).
129
+ """
130
+ if dim not in [0, 1]:
131
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
132
+
133
+ # By default chunk size is 1
134
+ chunk_size = chunk_size or 1
135
+
136
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
137
+ if hasattr(module, "set_chunk_feed_forward"):
138
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
139
+
140
+ for child in module.children():
141
+ fn_recursive_feed_forward(child, chunk_size, dim)
142
+
143
+ for module in self.children():
144
+ fn_recursive_feed_forward(module, chunk_size, dim)
145
+
146
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
147
+ def disable_forward_chunking(self):
148
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
149
+ if hasattr(module, "set_chunk_feed_forward"):
150
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
151
+
152
+ for child in module.children():
153
+ fn_recursive_feed_forward(child, chunk_size, dim)
154
+
155
+ for module in self.children():
156
+ fn_recursive_feed_forward(module, None, 0)
157
+
158
+ @property
159
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
160
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
161
+ r"""
162
+ Returns:
163
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
165
+ """
166
+ # set recursively
167
+ processors = {}
168
+
169
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
170
+ if hasattr(module, "get_processor"):
171
+ processors[f"{name}.processor"] = module.get_processor()
172
+
173
+ for sub_name, child in module.named_children():
174
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
175
+
176
+ return processors
177
+
178
+ for name, module in self.named_children():
179
+ fn_recursive_add_processors(name, module, processors)
180
+
181
+ return processors
182
+
183
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
184
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
185
+ r"""
186
+ Sets the attention processor to use to compute attention.
187
+
188
+ Parameters:
189
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
190
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
191
+ for **all** `Attention` layers.
192
+
193
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
194
+ processor. This is strongly recommended when setting trainable attention processors.
195
+
196
+ """
197
+ count = len(self.attn_processors.keys())
198
+
199
+ if isinstance(processor, dict) and len(processor) != count:
200
+ raise ValueError(
201
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
202
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
203
+ )
204
+
205
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
206
+ if hasattr(module, "set_processor"):
207
+ if not isinstance(processor, dict):
208
+ module.set_processor(processor)
209
+ else:
210
+ module.set_processor(processor.pop(f"{name}.processor"))
211
+
212
+ for sub_name, child in module.named_children():
213
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
214
+
215
+ for name, module in self.named_children():
216
+ fn_recursive_attn_processor(name, module, processor)
217
+
218
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
219
+ def fuse_qkv_projections(self):
220
+ """
221
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
222
+ are fused. For cross-attention modules, key and value projection matrices are fused.
223
+
224
+ <Tip warning={true}>
225
+
226
+ This API is 🧪 experimental.
227
+
228
+ </Tip>
229
+ """
230
+ self.original_attn_processors = None
231
+
232
+ for _, attn_processor in self.attn_processors.items():
233
+ if "Added" in str(attn_processor.__class__.__name__):
234
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
235
+
236
+ self.original_attn_processors = self.attn_processors
237
+
238
+ for module in self.modules():
239
+ if isinstance(module, Attention):
240
+ module.fuse_projections(fuse=True)
241
+
242
+ self.set_attn_processor(FusedJointAttnProcessor2_0())
243
+
244
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
245
+ def unfuse_qkv_projections(self):
246
+ """Disables the fused QKV projection if enabled.
247
+
248
+ <Tip warning={true}>
249
+
250
+ This API is 🧪 experimental.
251
+
252
+ </Tip>
253
+
254
+ """
255
+ if self.original_attn_processors is not None:
256
+ self.set_attn_processor(self.original_attn_processors)
257
+
258
+ def _set_gradient_checkpointing(self, module, value=False):
259
+ if hasattr(module, "gradient_checkpointing"):
260
+ module.gradient_checkpointing = value
261
+
262
+ def forward(
263
+ self,
264
+ hidden_states: torch.FloatTensor,
265
+ encoder_hidden_states: torch.FloatTensor = None,
266
+ pooled_projections: torch.FloatTensor = None,
267
+ timestep: torch.LongTensor = None,
268
+ block_controlnet_hidden_states: List = None,
269
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
270
+ return_dict: bool = True,
271
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
272
+ """
273
+ The [`SD3Transformer2DModel`] forward method.
274
+
275
+ Args:
276
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
277
+ Input `hidden_states`.
278
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
279
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
280
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
281
+ from the embeddings of input conditions.
282
+ timestep ( `torch.LongTensor`):
283
+ Used to indicate denoising step.
284
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
285
+ A list of tensors that if specified are added to the residuals of transformer blocks.
286
+ joint_attention_kwargs (`dict`, *optional*):
287
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
288
+ `self.processor` in
289
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
290
+ return_dict (`bool`, *optional*, defaults to `True`):
291
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
292
+ tuple.
293
+
294
+ Returns:
295
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
296
+ `tuple` where the first element is the sample tensor.
297
+ """
298
+ if joint_attention_kwargs is not None:
299
+ joint_attention_kwargs = joint_attention_kwargs.copy()
300
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
301
+ else:
302
+ lora_scale = 1.0
303
+
304
+ if USE_PEFT_BACKEND:
305
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
306
+ scale_lora_layers(self, lora_scale)
307
+ else:
308
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
309
+ logger.warning(
310
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
311
+ )
312
+
313
+ height, width = hidden_states.shape[-2:]
314
+
315
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
316
+ temb = self.time_text_embed(timestep, pooled_projections)
317
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
318
+
319
+ for index_block, block in enumerate(self.transformer_blocks):
320
+ if self.training and self.gradient_checkpointing:
321
+
322
+ def create_custom_forward(module, return_dict=None):
323
+ def custom_forward(*inputs):
324
+ if return_dict is not None:
325
+ return module(*inputs, return_dict=return_dict)
326
+ else:
327
+ return module(*inputs)
328
+
329
+ return custom_forward
330
+
331
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
332
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
333
+ create_custom_forward(block),
334
+ hidden_states,
335
+ encoder_hidden_states,
336
+ temb,
337
+ joint_attention_kwargs,
338
+ **ckpt_kwargs,
339
+ )
340
+
341
+ else:
342
+ encoder_hidden_states, hidden_states = block(
343
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
344
+ joint_attention_kwargs=joint_attention_kwargs,
345
+ )
346
+
347
+ # controlnet residual
348
+ if block_controlnet_hidden_states is not None and block.context_pre_only is False:
349
+ interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
350
+ hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
351
+
352
+ hidden_states = self.norm_out(hidden_states, temb)
353
+ hidden_states = self.proj_out(hidden_states)
354
+
355
+ # unpatchify
356
+ patch_size = self.config.patch_size
357
+ height = height // patch_size
358
+ width = width // patch_size
359
+
360
+ hidden_states = hidden_states.reshape(
361
+ shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
362
+ )
363
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
364
+ output = hidden_states.reshape(
365
+ shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
366
+ )
367
+
368
+ if USE_PEFT_BACKEND:
369
+ # remove `lora_scale` from each PEFT layer
370
+ unscale_lora_layers(self, lora_scale)
371
+
372
+ if not return_dict:
373
+ return (output,)
374
+
375
+ return Transformer2DModelOutput(sample=output)
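For readers tracing shapes through `forward`, the closing unpatchify step can be replayed in isolation: each patch token carries a `patch_size * patch_size * out_channels` vector that gets folded back into the latent image. The toy snippet below mirrors the reshape/einsum above with made-up sizes (real SD3.5-Large latents for a 1024x1024 image are 128x128 with `patch_size=2`).

```python
import torch

batch, out_channels, patch_size = 1, 16, 2
latent_h = latent_w = 8                          # toy latent resolution
h_patches, w_patches = latent_h // patch_size, latent_w // patch_size

# `proj_out` emits one flattened patch per token: (batch, num_patches, p*p*C)
tokens = torch.randn(batch, h_patches * w_patches, patch_size * patch_size * out_channels)

x = tokens.reshape(batch, h_patches, w_patches, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
latents = x.reshape(batch, out_channels, h_patches * patch_size, w_patches * patch_size)
print(latents.shape)                             # torch.Size([1, 16, 8, 8])
```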
pipeline_stable_diffusion_3_ipa.py ADDED
@@ -0,0 +1,1235 @@
1
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from transformers import (
22
+ CLIPTextModelWithProjection,
23
+ CLIPTokenizer,
24
+ T5EncoderModel,
25
+ T5TokenizerFast,
26
+ )
27
+
28
+ from diffusers.image_processor import VaeImageProcessor
29
+ from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
30
+ from diffusers.models.autoencoders import AutoencoderKL
31
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
32
+ from diffusers.utils import (
33
+ USE_PEFT_BACKEND,
34
+ is_torch_xla_available,
35
+ logging,
36
+ replace_example_docstring,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from diffusers.utils.torch_utils import randn_tensor
41
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
+ from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
43
+
44
+ from models.resampler import TimeResampler
45
+ from models.transformer_sd3 import SD3Transformer2DModel
46
+ from diffusers.models.normalization import RMSNorm
47
+ from einops import rearrange
48
+
49
+
50
+ if is_torch_xla_available():
51
+ import torch_xla.core.xla_model as xm
52
+
53
+ XLA_AVAILABLE = True
54
+ else:
55
+ XLA_AVAILABLE = False
56
+
57
+
58
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
+
60
+ EXAMPLE_DOC_STRING = """
61
+ Examples:
62
+ ```py
63
+ >>> import torch
64
+ >>> from diffusers import StableDiffusion3Pipeline
65
+
66
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
67
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
68
+ ... )
69
+ >>> pipe.to("cuda")
70
+ >>> prompt = "A cat holding a sign that says hello world"
71
+ >>> image = pipe(prompt).images[0]
72
+ >>> image.save("sd3.png")
73
+ ```
74
+ """
75
+
76
+
77
+ class AdaLayerNorm(nn.Module):
78
+ """
79
+ Adaptive layer norm conditioned on a timestep embedding. In `'normal'` mode it returns the modulated
+ input only; in `'zero'` mode it additionally returns the adaLN-Zero gate/shift/scale tensors.
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ time_embedding_dim (`int`, *optional*): Dimension of the timestep embedding used for conditioning.
84
+ """
85
+
86
+ def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'):
87
+ super().__init__()
88
+
89
+ self.silu = nn.SiLU()
90
+ num_params_dict = dict(
91
+ zero=6,
92
+ normal=2,
93
+ )
94
+ num_params = num_params_dict[mode]
95
+ self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True)
96
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
97
+ self.mode = mode
98
+
99
+ def forward(
100
+ self,
101
+ x,
102
+ hidden_dtype = None,
103
+ emb = None,
104
+ ):
105
+ emb = self.linear(self.silu(emb))
106
+ if self.mode == 'normal':
107
+ shift_msa, scale_msa = emb.chunk(2, dim=1)
108
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
109
+ return x
110
+
111
+ elif self.mode == 'zero':
112
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
113
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
114
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
115
+
116
+
117
+ class JointIPAttnProcessor(torch.nn.Module):
118
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
119
+
120
+ def __init__(
121
+ self,
122
+ hidden_size=None,
123
+ cross_attention_dim=None,
124
+ ip_hidden_states_dim=None,
125
+ ip_encoder_hidden_states_dim=None,
126
+ head_dim=None,
127
+ timesteps_emb_dim=1280,
128
+ ):
129
+ super().__init__()
130
+
131
+ self.norm_ip = AdaLayerNorm(ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim)
132
+ self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
133
+ self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
134
+ self.norm_q = RMSNorm(head_dim, 1e-6)
135
+ self.norm_k = RMSNorm(head_dim, 1e-6)
136
+ self.norm_ip_k = RMSNorm(head_dim, 1e-6)
137
+
138
+
139
+ def __call__(
140
+ self,
141
+ attn,
142
+ hidden_states: torch.FloatTensor,
143
+ encoder_hidden_states: torch.FloatTensor = None,
144
+ attention_mask: Optional[torch.FloatTensor] = None,
145
+ emb_dict=None,
146
+ *args,
147
+ **kwargs,
148
+ ) -> torch.FloatTensor:
149
+ residual = hidden_states
150
+
151
+ batch_size = hidden_states.shape[0]
152
+
153
+ # `sample` projections.
154
+ query = attn.to_q(hidden_states)
155
+ key = attn.to_k(hidden_states)
156
+ value = attn.to_v(hidden_states)
157
+ img_query = query
158
+ img_key = key
159
+ img_value = value
160
+
161
+ inner_dim = key.shape[-1]
162
+ head_dim = inner_dim // attn.heads
163
+
164
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
165
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
166
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
167
+
168
+ if attn.norm_q is not None:
169
+ query = attn.norm_q(query)
170
+ if attn.norm_k is not None:
171
+ key = attn.norm_k(key)
172
+
173
+ # `context` projections.
174
+ if encoder_hidden_states is not None:
175
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
176
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
177
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
178
+
179
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
180
+ batch_size, -1, attn.heads, head_dim
181
+ ).transpose(1, 2)
182
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
183
+ batch_size, -1, attn.heads, head_dim
184
+ ).transpose(1, 2)
185
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
186
+ batch_size, -1, attn.heads, head_dim
187
+ ).transpose(1, 2)
188
+
189
+ if attn.norm_added_q is not None:
190
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
191
+ if attn.norm_added_k is not None:
192
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
193
+
194
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
195
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
196
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
197
+
198
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
199
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
200
+ hidden_states = hidden_states.to(query.dtype)
201
+
202
+ if encoder_hidden_states is not None:
203
+ # Split the attention outputs.
204
+ hidden_states, encoder_hidden_states = (
205
+ hidden_states[:, : residual.shape[1]],
206
+ hidden_states[:, residual.shape[1] :],
207
+ )
208
+ if not attn.context_pre_only:
209
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
210
+
211
+
212
+ # IPadapter
213
+ ip_hidden_states = emb_dict.get('ip_hidden_states', None)
214
+ ip_hidden_states = self.get_ip_hidden_states(
215
+ attn,
216
+ img_query,
217
+ ip_hidden_states,
218
+ img_key,
219
+ img_value,
220
+ None,
221
+ None,
222
+ emb_dict['temb'],
223
+ )
224
+ if ip_hidden_states is not None:
225
+ hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0)
226
+
227
+
228
+ # linear proj
229
+ hidden_states = attn.to_out[0](hidden_states)
230
+ # dropout
231
+ hidden_states = attn.to_out[1](hidden_states)
232
+
233
+ if encoder_hidden_states is not None:
234
+ return hidden_states, encoder_hidden_states
235
+ else:
236
+ return hidden_states
237
+
238
+
239
+ def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None):
240
+ if ip_hidden_states is None:
241
+ return None
242
+
243
+ if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'):
244
+ return None
245
+
246
+ # norm ip input
247
+ norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=temb)
248
+
249
+ # to k and v
250
+ ip_key = self.to_k_ip(norm_ip_hidden_states)
251
+ ip_value = self.to_v_ip(norm_ip_hidden_states)
252
+
253
+ # reshape
254
+ query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads)
255
+ img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
256
+ img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
257
+ ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
258
+ ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
259
+
260
+ # norm
261
+ query = self.norm_q(query)
262
+ img_key = self.norm_k(img_key)
263
+ ip_key = self.norm_ip_k(ip_key)
264
+
265
+ # cat img
266
+ key = torch.cat([img_key, ip_key], dim=2)
267
+ value = torch.cat([img_value, ip_value], dim=2)
268
+
269
+ #
270
+ ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
271
+ ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
272
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
273
+ return ip_hidden_states
274
+
275
+
276
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
277
+ def retrieve_timesteps(
278
+ scheduler,
279
+ num_inference_steps: Optional[int] = None,
280
+ device: Optional[Union[str, torch.device]] = None,
281
+ timesteps: Optional[List[int]] = None,
282
+ sigmas: Optional[List[float]] = None,
283
+ **kwargs,
284
+ ):
285
+ """
286
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
287
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
288
+
289
+ Args:
290
+ scheduler (`SchedulerMixin`):
291
+ The scheduler to get timesteps from.
292
+ num_inference_steps (`int`):
293
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
294
+ must be `None`.
295
+ device (`str` or `torch.device`, *optional*):
296
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
297
+ timesteps (`List[int]`, *optional*):
298
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
299
+ `num_inference_steps` and `sigmas` must be `None`.
300
+ sigmas (`List[float]`, *optional*):
301
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
302
+ `num_inference_steps` and `timesteps` must be `None`.
303
+
304
+ Returns:
305
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
306
+ second element is the number of inference steps.
307
+ """
308
+ if timesteps is not None and sigmas is not None:
309
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
310
+ if timesteps is not None:
311
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
312
+ if not accepts_timesteps:
313
+ raise ValueError(
314
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
315
+ f" timestep schedules. Please check whether you are using the correct scheduler."
316
+ )
317
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
318
+ timesteps = scheduler.timesteps
319
+ num_inference_steps = len(timesteps)
320
+ elif sigmas is not None:
321
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
322
+ if not accept_sigmas:
323
+ raise ValueError(
324
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
325
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
326
+ )
327
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
328
+ timesteps = scheduler.timesteps
329
+ num_inference_steps = len(timesteps)
330
+ else:
331
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
332
+ timesteps = scheduler.timesteps
333
+ return timesteps, num_inference_steps
334
+
335
+
336
+ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
337
+ r"""
338
+ Args:
339
+ transformer ([`SD3Transformer2DModel`]):
340
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
341
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
342
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
343
+ vae ([`AutoencoderKL`]):
344
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
345
+ text_encoder ([`CLIPTextModelWithProjection`]):
346
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
347
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
348
+ with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
349
+ as its dimension.
350
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
351
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
352
+ specifically the
353
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
354
+ variant.
355
+ text_encoder_3 ([`T5EncoderModel`]):
356
+ Frozen text-encoder. Stable Diffusion 3 uses
357
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
358
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
359
+ tokenizer (`CLIPTokenizer`):
360
+ Tokenizer of class
361
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
362
+ tokenizer_2 (`CLIPTokenizer`):
363
+ Second Tokenizer of class
364
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
365
+ tokenizer_3 (`T5TokenizerFast`):
366
+ Tokenizer of class
367
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
368
+ """
369
+
370
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
371
+ _optional_components = []
372
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
373
+
374
+ def __init__(
375
+ self,
376
+ transformer: SD3Transformer2DModel,
377
+ scheduler: FlowMatchEulerDiscreteScheduler,
378
+ vae: AutoencoderKL,
379
+ text_encoder: CLIPTextModelWithProjection,
380
+ tokenizer: CLIPTokenizer,
381
+ text_encoder_2: CLIPTextModelWithProjection,
382
+ tokenizer_2: CLIPTokenizer,
383
+ text_encoder_3: T5EncoderModel,
384
+ tokenizer_3: T5TokenizerFast,
385
+ ):
386
+ super().__init__()
387
+
388
+ self.register_modules(
389
+ vae=vae,
390
+ text_encoder=text_encoder,
391
+ text_encoder_2=text_encoder_2,
392
+ text_encoder_3=text_encoder_3,
393
+ tokenizer=tokenizer,
394
+ tokenizer_2=tokenizer_2,
395
+ tokenizer_3=tokenizer_3,
396
+ transformer=transformer,
397
+ scheduler=scheduler,
398
+ )
399
+ self.vae_scale_factor = (
400
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
401
+ )
402
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
403
+ self.tokenizer_max_length = (
404
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
405
+ )
406
+ self.default_sample_size = (
407
+ self.transformer.config.sample_size
408
+ if hasattr(self, "transformer") and self.transformer is not None
409
+ else 128
410
+ )
411
+
412
+ def _get_t5_prompt_embeds(
413
+ self,
414
+ prompt: Union[str, List[str]] = None,
415
+ num_images_per_prompt: int = 1,
416
+ max_sequence_length: int = 256,
417
+ device: Optional[torch.device] = None,
418
+ dtype: Optional[torch.dtype] = None,
419
+ ):
420
+ device = device or self._execution_device
421
+ dtype = dtype or self.text_encoder.dtype
422
+
423
+ prompt = [prompt] if isinstance(prompt, str) else prompt
424
+ batch_size = len(prompt)
425
+
426
+ if self.text_encoder_3 is None:
427
+ return torch.zeros(
428
+ (
429
+ batch_size * num_images_per_prompt,
430
+ self.tokenizer_max_length,
431
+ self.transformer.config.joint_attention_dim,
432
+ ),
433
+ device=device,
434
+ dtype=dtype,
435
+ )
436
+
437
+ text_inputs = self.tokenizer_3(
438
+ prompt,
439
+ padding="max_length",
440
+ max_length=max_sequence_length,
441
+ truncation=True,
442
+ add_special_tokens=True,
443
+ return_tensors="pt",
444
+ )
445
+ text_input_ids = text_inputs.input_ids
446
+ untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
447
+
448
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
449
+ removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
450
+ logger.warning(
451
+ "The following part of your input was truncated because `max_sequence_length` is set to "
452
+ f" {max_sequence_length} tokens: {removed_text}"
453
+ )
454
+
455
+ prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
456
+
457
+ dtype = self.text_encoder_3.dtype
458
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
459
+
460
+ _, seq_len, _ = prompt_embeds.shape
461
+
462
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
463
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
464
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
465
+
466
+ return prompt_embeds
467
+
468
+ def _get_clip_prompt_embeds(
469
+ self,
470
+ prompt: Union[str, List[str]],
471
+ num_images_per_prompt: int = 1,
472
+ device: Optional[torch.device] = None,
473
+ clip_skip: Optional[int] = None,
474
+ clip_model_index: int = 0,
475
+ ):
476
+ device = device or self._execution_device
477
+
478
+ clip_tokenizers = [self.tokenizer, self.tokenizer_2]
479
+ clip_text_encoders = [self.text_encoder, self.text_encoder_2]
480
+
481
+ tokenizer = clip_tokenizers[clip_model_index]
482
+ text_encoder = clip_text_encoders[clip_model_index]
483
+
484
+ prompt = [prompt] if isinstance(prompt, str) else prompt
485
+ batch_size = len(prompt)
486
+
487
+ text_inputs = tokenizer(
488
+ prompt,
489
+ padding="max_length",
490
+ max_length=self.tokenizer_max_length,
491
+ truncation=True,
492
+ return_tensors="pt",
493
+ )
494
+
495
+ text_input_ids = text_inputs.input_ids
496
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
497
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
498
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
499
+ logger.warning(
500
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
501
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
502
+ )
503
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
504
+ pooled_prompt_embeds = prompt_embeds[0]
505
+
506
+ if clip_skip is None:
507
+ prompt_embeds = prompt_embeds.hidden_states[-2]
508
+ else:
509
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
510
+
511
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
512
+
513
+ _, seq_len, _ = prompt_embeds.shape
514
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
515
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
516
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
517
+
518
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
519
+ pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
520
+
521
+ return prompt_embeds, pooled_prompt_embeds
522
+
523
+ def encode_prompt(
524
+ self,
525
+ prompt: Union[str, List[str]],
526
+ prompt_2: Union[str, List[str]],
527
+ prompt_3: Union[str, List[str]],
528
+ device: Optional[torch.device] = None,
529
+ num_images_per_prompt: int = 1,
530
+ do_classifier_free_guidance: bool = True,
531
+ negative_prompt: Optional[Union[str, List[str]]] = None,
532
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
533
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
534
+ prompt_embeds: Optional[torch.FloatTensor] = None,
535
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
536
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
537
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
538
+ clip_skip: Optional[int] = None,
539
+ max_sequence_length: int = 256,
540
+ lora_scale: Optional[float] = None,
541
+ ):
542
+ r"""
543
+
544
+ Args:
545
+ prompt (`str` or `List[str]`, *optional*):
546
+ prompt to be encoded
547
+ prompt_2 (`str` or `List[str]`, *optional*):
548
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
549
+ used in all text-encoders
550
+ prompt_3 (`str` or `List[str]`, *optional*):
551
+ The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
552
+ used in all text-encoders
553
+ device: (`torch.device`):
554
+ torch device
555
+ num_images_per_prompt (`int`):
556
+ number of images that should be generated per prompt
557
+ do_classifier_free_guidance (`bool`):
558
+ whether to use classifier free guidance or not
559
+ negative_prompt (`str` or `List[str]`, *optional*):
560
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
561
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
562
+ less than `1`).
563
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
564
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
565
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
566
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used in all text-encoders.
569
+ prompt_embeds (`torch.FloatTensor`, *optional*):
570
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
571
+ provided, text embeddings will be generated from `prompt` input argument.
572
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
573
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
574
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
575
+ argument.
576
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
577
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
578
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
579
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
580
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
581
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
582
+ input argument.
583
+ clip_skip (`int`, *optional*):
584
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
585
+ the output of the pre-final layer will be used for computing the prompt embeddings.
586
+ lora_scale (`float`, *optional*):
587
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
588
+ """
589
+ device = device or self._execution_device
590
+
591
+ # set lora scale so that monkey patched LoRA
592
+ # function of text encoder can correctly access it
593
+ if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
594
+ self._lora_scale = lora_scale
595
+
596
+ # dynamically adjust the LoRA scale
597
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
598
+ scale_lora_layers(self.text_encoder, lora_scale)
599
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
600
+ scale_lora_layers(self.text_encoder_2, lora_scale)
601
+
602
+ prompt = [prompt] if isinstance(prompt, str) else prompt
603
+ if prompt is not None:
604
+ batch_size = len(prompt)
605
+ else:
606
+ batch_size = prompt_embeds.shape[0]
607
+
608
+ if prompt_embeds is None:
609
+ prompt_2 = prompt_2 or prompt
610
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
611
+
612
+ prompt_3 = prompt_3 or prompt
613
+ prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
614
+
615
+ prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
616
+ prompt=prompt,
617
+ device=device,
618
+ num_images_per_prompt=num_images_per_prompt,
619
+ clip_skip=clip_skip,
620
+ clip_model_index=0,
621
+ )
622
+ prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
623
+ prompt=prompt_2,
624
+ device=device,
625
+ num_images_per_prompt=num_images_per_prompt,
626
+ clip_skip=clip_skip,
627
+ clip_model_index=1,
628
+ )
629
+ clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
630
+
631
+ t5_prompt_embed = self._get_t5_prompt_embeds(
632
+ prompt=prompt_3,
633
+ num_images_per_prompt=num_images_per_prompt,
634
+ max_sequence_length=max_sequence_length,
635
+ device=device,
636
+ )
637
+
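+ # the two CLIP embeddings are zero-padded up to the T5 hidden width below, then
+ # concatenated with the T5 tokens along the sequence dimension to form the joint
+ # text conditioning fed to the SD3 transformer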
+ clip_prompt_embeds = torch.nn.functional.pad(
+ clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+ )
+
+ prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+ negative_prompt_3 = negative_prompt_3 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+ negative_prompt_3 = (
+ batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+ )
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
+ negative_prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=0,
+ )
+ negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
+ negative_prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ clip_skip=None,
+ clip_model_index=1,
+ )
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
+
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
+ prompt=negative_prompt_3,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
+ negative_clip_prompt_embeds,
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
+ )
+
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
+ negative_pooled_prompt_embeds = torch.cat(
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ negative_prompt_3=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_3 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+ elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+ raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+
+ @torch.inference_mode()
+ def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432):
+ from transformers import SiglipVisionModel, SiglipImageProcessor
+ state_dict = torch.load(ip_adapter_path, map_location="cpu")
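+ # the checkpoint holds two sub-state-dicts, both consumed below:
+ # "image_proj" (TimeResampler weights) and "ip_adapter" (per-block attention-processor weights)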
+
+ device, dtype = self.transformer.device, self.transformer.dtype
+ image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path)
+ image_processor = SiglipImageProcessor.from_pretrained(image_encoder_path)
+ image_encoder.eval()
+ image_encoder.to(device, dtype=dtype)
+ self.image_encoder = image_encoder
+ self.clip_image_processor = image_processor
+
+ sample_class = TimeResampler
+ image_proj_model = sample_class(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=nb_token,
+ embedding_dim=1152,
+ output_dim=output_dim,
+ ff_mult=4,
+ timestep_in_dim=320,
+ timestep_flip_sin_to_cos=True,
+ timestep_freq_shift=0,
+ )
+ image_proj_model.eval()
+ image_proj_model.to(device, dtype=dtype)
+ key_name = image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
+ print(f"=> loading image_proj_model: {key_name}")
+
+ self.image_proj_model = image_proj_model
+
+
+ attn_procs = {}
+ transformer = self.transformer
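+ # the attention processor of every joint attention block is replaced with a
+ # JointIPAttnProcessor, so each block also receives the IP-Adapter image tokens
+ # passed in through `joint_attention_kwargs` during denoising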
+ for idx_name, name in enumerate(transformer.attn_processors.keys()):
+ hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads
+ ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads
+ ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim
+
+ attn_procs[name] = JointIPAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=transformer.config.caption_projection_dim,
+ ip_hidden_states_dim=ip_hidden_states_dim,
+ ip_encoder_hidden_states_dim=ip_encoder_hidden_states_dim,
+ head_dim=transformer.config.attention_head_dim,
+ timesteps_emb_dim=1280,
+ ).to(device, dtype=dtype)
+
+ self.transformer.set_attn_processor(attn_procs)
+ tmp_ip_layers = torch.nn.ModuleList(self.transformer.attn_processors.values())
+
+ key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
+ print(f"=> loading ip_adapter: {key_name}")
+
+
+ @torch.inference_mode()
+ def encode_clip_image_emb(self, clip_image, device, dtype):
+
+ # clip
+ clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
+ clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
+ clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
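+ # a zero tensor is prepended as the unconditional image embedding, mirroring the
+ # negative/positive ordering used for the text embeddings under classifier-free guidance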
+ clip_image_embeds = torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0)
+
+ return clip_image_embeds
+
+
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ prompt_3: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt_3: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 256,
+
+ # ipa
+ clip_image=None,
+ ipadapter_scale=1.0,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
+ will be used instead
+ prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
+ will be used instead
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used instead
+ negative_prompt_3 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
+ `text_encoder_3`. If not defined, `negative_prompt` is used instead
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
+ of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+ clip_image (`PIL.Image.Image`, *optional*):
+ The reference image whose SigLIP features condition the IP-Adapter attention layers.
+ ipadapter_scale (`float`, *optional*, defaults to 1.0):
+ Scale factor applied to the IP-Adapter image conditioning.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ prompt_3,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_3=prompt_3,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ negative_prompt_3=negative_prompt_3,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ clip_skip=self.clip_skip,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+ # 3. prepare clip emb
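+ # the reference image is stretched to a square of its longer side before SigLIP
+ # preprocessing, so non-square inputs are distorted rather than center-cropped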
+ clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
+ clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
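+ # the TimeResampler is timestep-conditioned, so the image tokens and the extra timestep
+ # embedding consumed by the IP attention processors are recomputed at every denoising step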
+ image_prompt_embeds, timestep_emb = self.image_proj_model(
+ clip_image_embeds,
+ timestep.to(dtype=latents.dtype),
+ need_temb=True
+ )
+
+ joint_attention_kwargs = dict(
+ emb_dict=dict(
+ ip_hidden_states=image_prompt_embeds,
+ temb=timestep_emb,
+ scale=ipadapter_scale,
+ )
+ )
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusion3PipelineOutput(images=image)
teasers/0.png ADDED

Git LFS Details

  • SHA256: 6325e12735c57a61449fc94330d6e1e744977994bedff1fe6a2f37588d0a448e
  • Pointer size: 132 Bytes
  • Size of remote file: 5.2 MB
teasers/1.png ADDED

Git LFS Details

  • SHA256: 6bdca1eae51d34f587bea5cc218e861f1c88678c9f69d12ba1931fbcc567e9db
  • Pointer size: 132 Bytes
  • Size of remote file: 5.32 MB