Upload folder using huggingface_hub
- .gitattributes +5 -0
- README.md +96 -0
- assets/1.jpg +3 -0
- assets/2.jpg +3 -0
- infer_sd35_large_ipa.py +40 -0
- ip-adapter.bin +3 -0
- ip-adapter.safetensors +3 -0
- models/__init__.py +0 -0
- models/attention.py +1245 -0
- models/resampler.py +304 -0
- models/transformer_sd3.py +375 -0
- pipeline_stable_diffusion_3_ipa.py +1235 -0
- teasers/0.png +3 -0
- teasers/1.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+teasers/0.png filter=lfs diff=lfs merge=lfs -text
+teasers/1.png filter=lfs diff=lfs merge=lfs -text
+ip-adapter.safetensors filter=lfs diff=lfs merge=lfs -text
+assets/1.jpg filter=lfs diff=lfs merge=lfs -text
+assets/2.jpg filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,96 @@
---
license: other
license_name: stabilityai-ai-community
license_link: >-
  https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md
language:
- en
library_name: diffusers
pipeline_tag: text-to-image
tags:
- Text-to-Image
- IP-Adapter
- StableDiffusion3Pipeline
- image-generation
- Stable Diffusion
base_model:
- stabilityai/stable-diffusion-3.5-large
---

# SD3.5-Large-IP-Adapter

This repository contains an IP-Adapter for the SD3.5-Large model, released by researchers from the [InstantX Team](https://huggingface.co/InstantX). The reference image works much like a text prompt, so it may occasionally be unresponsive or may interfere with the text prompt. We hope you enjoy this model; have fun and share your creative works with us [on Twitter](https://x.com/instantx_ai).

# Model Card

This is a regular IP-Adapter, where the new layers are added into all 38 transformer blocks. We use [google/siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) to encode the reference image for its superior performance, and adopt a TimeResampler to project the image features. The number of image tokens is set to 64.
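For intuition, here is a minimal, self-contained sketch of what such a query-based resampler can look like. This is only an illustration of the idea, not the `TimeResampler` shipped in `models/resampler.py` (which additionally conditions on the diffusion timestep); the class name, layer layout, and shapes below are assumptions.

```python
import torch
import torch.nn as nn


class ToyImageTokenResampler(nn.Module):
    """Toy perceiver-style resampler: a fixed set of learnable queries
    cross-attends to SigLIP patch features and returns 64 image tokens."""

    def __init__(self, dim: int = 1152, num_queries: int = 64, heads: int = 8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(1, num_queries, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # image_features: (batch, num_patches, dim) from the SigLIP vision tower
        queries = self.queries.expand(image_features.shape[0], -1, -1)
        tokens, _ = self.attn(queries, image_features, image_features)
        return self.norm(tokens)


# Illustrative shapes: siglip-so400m-patch14-384 yields 729 patch tokens of width 1152.
features = torch.randn(2, 729, 1152)
print(ToyImageTokenResampler()(features).shape)  # torch.Size([2, 64, 1152])
```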

# Showcases

<div class="container">
  <img src="./teasers/0.png" width="1024"/>
  <img src="./teasers/1.png" width="1024"/>
</div>

# Inference

The code has not been integrated into diffusers yet; please use the local files in this repository for now.
```python
import torch
from PIL import Image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline

model_path = 'stabilityai/stable-diffusion-3.5-large'
ip_adapter_path = './ip-adapter.bin'
image_encoder_path = "google/siglip-so400m-patch14-384"

transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path, transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

pipe.init_ipadapter(
    ip_adapter_path=ip_adapter_path,
    image_encoder_path=image_encoder_path,
    nb_token=64,
)

ref_img = Image.open('./assets/1.jpg').convert('RGB')

# please note that SD3.5 Large is sensitive to highres generation like 1536x1536
image = pipe(
    width=1024,
    height=1024,
    prompt='a cat',
    negative_prompt="lowres, low quality, worst quality",
    num_inference_steps=24,
    guidance_scale=5.0,
    generator=torch.Generator("cuda").manual_seed(42),
    clip_image=ref_img,
    ipadapter_scale=0.5,
).images[0]
image.save('./result.jpg')
```
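The `ipadapter_scale` argument balances how strongly the reference image steers the result against the text prompt. As a small usage sketch (reusing the `pipe` and `ref_img` from the example above; the specific values are just a suggestion), you can sweep a few strengths and keep the one you prefer:

```python
# Sketch: compare a few adapter strengths with a fixed seed.
for scale in (0.3, 0.5, 0.7):
    image = pipe(
        prompt='a cat',
        negative_prompt="lowres, low quality, worst quality",
        width=1024,
        height=1024,
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=torch.Generator("cuda").manual_seed(42),
        clip_image=ref_img,
        ipadapter_scale=scale,
    ).images[0]
    image.save(f'./result_scale_{scale}.jpg')
```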

# Community ComfyUI Support

Please refer to [Slickytail/ComfyUI-InstantX-IPAdapter-SD3](https://github.com/Slickytail/ComfyUI-InstantX-IPAdapter-SD3).

# License

The model is released under the [stabilityai-ai-community](https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md) license. All rights reserved.

# Acknowledgements

This project is sponsored by [HuggingFace](https://huggingface.co/) and [fal.ai](https://fal.ai/). Thanks to [Slickytail](https://github.com/Slickytail) for supporting the ComfyUI node.

# Citation

If you find this project useful in your research, please cite us via
```
@misc{sd35-large-ipa,
  author = {InstantX Team},
  title = {InstantX SD3.5-Large IP-Adapter Page},
  year = {2024},
}
```
assets/1.jpg
ADDED
Git LFS Details
assets/2.jpg
ADDED
Git LFS Details
infer_sd35_large_ipa.py
ADDED
@@ -0,0 +1,40 @@
import torch
from PIL import Image

from models.transformer_sd3 import SD3Transformer2DModel
from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline


if __name__ == '__main__':

    model_path = 'stabilityai/stable-diffusion-3.5-large'
    ip_adapter_path = './ip-adapter.bin'
    image_encoder_path = "google/siglip-so400m-patch14-384"

    transformer = SD3Transformer2DModel.from_pretrained(
        model_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )

    pipe = StableDiffusion3Pipeline.from_pretrained(
        model_path, transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    pipe.init_ipadapter(
        ip_adapter_path=ip_adapter_path,
        image_encoder_path=image_encoder_path,
        nb_token=64,
    )

    ref_img = Image.open('./assets/1.jpg').convert('RGB')
    image = pipe(
        width=1024,
        height=1024,
        prompt='a cat',
        negative_prompt="lowres, low quality, worst quality",
        num_inference_steps=24,
        guidance_scale=5.0,
        generator=torch.Generator("cuda").manual_seed(42),
        clip_image=ref_img,
        ipadapter_scale=0.5,
    ).images[0]
    image.save('./result.jpg')
ip-adapter.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9fe54774aa528e712d9145ff6a59dd93b1fcf1d5935304feffd980ae6d42ae03
size 1595970439
ip-adapter.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c6d90e1e9efbdc9db81b28420a9a5e4d3a0d6f7e9ef9eed013825f54d3239ac
size 1372601256
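As a quick, optional sanity check after downloading (assuming the `safetensors` Python package is installed), the checkpoint can be opened to list its tensor names. This is only a convenience sketch, not part of the repository:

```python
from safetensors import safe_open

# Open the adapter checkpoint lazily, without instantiating the model.
with safe_open("ip-adapter.safetensors", framework="pt", device="cpu") as f:
    keys = list(f.keys())
    print(f"{len(keys)} tensors; first few: {keys[:5]}")
```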
models/__init__.py
ADDED
File without changes
models/attention.py
ADDED
@@ -0,0 +1,1245 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX


logger = logging.get_logger(__name__)


def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    ff_output = torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
    return ff_output


@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that combines visual features and object features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # we need a linear projection since we need cat visual feature and obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x


@maybe_allow_in_graph
class JointTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
            processing of `context` conditions.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        context_pre_only: bool = False,
        qk_norm: Optional[str] = None,
        use_dual_attention: bool = False,
    ):
        super().__init__()

        self.use_dual_attention = use_dual_attention
        self.context_pre_only = context_pre_only
        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

        if use_dual_attention:
            self.norm1 = SD35AdaLayerNormZeroX(dim)
        else:
            self.norm1 = AdaLayerNormZero(dim)

        if context_norm_type == "ada_norm_continous":
            self.norm1_context = AdaLayerNormContinuous(
                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
            )
        elif context_norm_type == "ada_norm_zero":
            self.norm1_context = AdaLayerNormZero(dim)
        else:
            raise ValueError(
                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
            )

        if hasattr(F, "scaled_dot_product_attention"):
            processor = JointAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )

        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=1e-6,
        )

        if use_dual_attention:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=None,
                dim_head=attention_head_dim,
                heads=num_attention_heads,
                out_dim=dim,
                bias=True,
                processor=processor,
                qk_norm=qk_norm,
                eps=1e-6,
            )
        else:
            self.attn2 = None

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        if not context_pre_only:
            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        else:
            self.norm2_context = None
            self.ff_context = None

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
        joint_attention_kwargs=None,
    ):
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                hidden_states, emb=temb
            )
        else:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        if self.context_pre_only:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
        else:
            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
                encoder_hidden_states, emb=temb
            )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
            **({} if joint_attention_kwargs is None else joint_attention_kwargs),
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        if self.context_pre_only:
            encoder_hidden_states = None
        else:
            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
            encoder_hidden_states = encoder_hidden_states + context_attn_output

            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            if self._chunk_size is not None:
                # "feed_forward_chunk_size" can be used to save memory
                context_ff_output = _chunked_feed_forward(
                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
                )
            else:
                context_ff_output = self.ff_context(norm_encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states

255 |
+
@maybe_allow_in_graph
|
256 |
+
class BasicTransformerBlock(nn.Module):
|
257 |
+
r"""
|
258 |
+
A basic Transformer block.
|
259 |
+
|
260 |
+
Parameters:
|
261 |
+
dim (`int`): The number of channels in the input and output.
|
262 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
263 |
+
attention_head_dim (`int`): The number of channels in each head.
|
264 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
265 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
266 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
267 |
+
num_embeds_ada_norm (:
|
268 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
269 |
+
attention_bias (:
|
270 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
271 |
+
only_cross_attention (`bool`, *optional*):
|
272 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
273 |
+
double_self_attention (`bool`, *optional*):
|
274 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
275 |
+
upcast_attention (`bool`, *optional*):
|
276 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
277 |
+
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
278 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
279 |
+
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
280 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
281 |
+
final_dropout (`bool` *optional*, defaults to False):
|
282 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
283 |
+
attention_type (`str`, *optional*, defaults to `"default"`):
|
284 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
285 |
+
positional_embeddings (`str`, *optional*, defaults to `None`):
|
286 |
+
The type of positional embeddings to apply to.
|
287 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
288 |
+
The maximum number of positional embeddings to apply.
|
289 |
+
"""
|
290 |
+
|
291 |
+
def __init__(
|
292 |
+
self,
|
293 |
+
dim: int,
|
294 |
+
num_attention_heads: int,
|
295 |
+
attention_head_dim: int,
|
296 |
+
dropout=0.0,
|
297 |
+
cross_attention_dim: Optional[int] = None,
|
298 |
+
activation_fn: str = "geglu",
|
299 |
+
num_embeds_ada_norm: Optional[int] = None,
|
300 |
+
attention_bias: bool = False,
|
301 |
+
only_cross_attention: bool = False,
|
302 |
+
double_self_attention: bool = False,
|
303 |
+
upcast_attention: bool = False,
|
304 |
+
norm_elementwise_affine: bool = True,
|
305 |
+
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
|
306 |
+
norm_eps: float = 1e-5,
|
307 |
+
final_dropout: bool = False,
|
308 |
+
attention_type: str = "default",
|
309 |
+
positional_embeddings: Optional[str] = None,
|
310 |
+
num_positional_embeddings: Optional[int] = None,
|
311 |
+
ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
|
312 |
+
ada_norm_bias: Optional[int] = None,
|
313 |
+
ff_inner_dim: Optional[int] = None,
|
314 |
+
ff_bias: bool = True,
|
315 |
+
attention_out_bias: bool = True,
|
316 |
+
):
|
317 |
+
super().__init__()
|
318 |
+
self.dim = dim
|
319 |
+
self.num_attention_heads = num_attention_heads
|
320 |
+
self.attention_head_dim = attention_head_dim
|
321 |
+
self.dropout = dropout
|
322 |
+
self.cross_attention_dim = cross_attention_dim
|
323 |
+
self.activation_fn = activation_fn
|
324 |
+
self.attention_bias = attention_bias
|
325 |
+
self.double_self_attention = double_self_attention
|
326 |
+
self.norm_elementwise_affine = norm_elementwise_affine
|
327 |
+
self.positional_embeddings = positional_embeddings
|
328 |
+
self.num_positional_embeddings = num_positional_embeddings
|
329 |
+
self.only_cross_attention = only_cross_attention
|
330 |
+
|
331 |
+
# We keep these boolean flags for backward-compatibility.
|
332 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
333 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
334 |
+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
|
335 |
+
self.use_layer_norm = norm_type == "layer_norm"
|
336 |
+
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
|
337 |
+
|
338 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
339 |
+
raise ValueError(
|
340 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
341 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
342 |
+
)
|
343 |
+
|
344 |
+
self.norm_type = norm_type
|
345 |
+
self.num_embeds_ada_norm = num_embeds_ada_norm
|
346 |
+
|
347 |
+
if positional_embeddings and (num_positional_embeddings is None):
|
348 |
+
raise ValueError(
|
349 |
+
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
|
350 |
+
)
|
351 |
+
|
352 |
+
if positional_embeddings == "sinusoidal":
|
353 |
+
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
|
354 |
+
else:
|
355 |
+
self.pos_embed = None
|
356 |
+
|
357 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
358 |
+
# 1. Self-Attn
|
359 |
+
if norm_type == "ada_norm":
|
360 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
361 |
+
elif norm_type == "ada_norm_zero":
|
362 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
363 |
+
elif norm_type == "ada_norm_continuous":
|
364 |
+
self.norm1 = AdaLayerNormContinuous(
|
365 |
+
dim,
|
366 |
+
ada_norm_continous_conditioning_embedding_dim,
|
367 |
+
norm_elementwise_affine,
|
368 |
+
norm_eps,
|
369 |
+
ada_norm_bias,
|
370 |
+
"rms_norm",
|
371 |
+
)
|
372 |
+
else:
|
373 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
374 |
+
|
375 |
+
self.attn1 = Attention(
|
376 |
+
query_dim=dim,
|
377 |
+
heads=num_attention_heads,
|
378 |
+
dim_head=attention_head_dim,
|
379 |
+
dropout=dropout,
|
380 |
+
bias=attention_bias,
|
381 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
382 |
+
upcast_attention=upcast_attention,
|
383 |
+
out_bias=attention_out_bias,
|
384 |
+
)
|
385 |
+
|
386 |
+
# 2. Cross-Attn
|
387 |
+
if cross_attention_dim is not None or double_self_attention:
|
388 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
389 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
390 |
+
# the second cross attention block.
|
391 |
+
if norm_type == "ada_norm":
|
392 |
+
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
393 |
+
elif norm_type == "ada_norm_continuous":
|
394 |
+
self.norm2 = AdaLayerNormContinuous(
|
395 |
+
dim,
|
396 |
+
ada_norm_continous_conditioning_embedding_dim,
|
397 |
+
norm_elementwise_affine,
|
398 |
+
norm_eps,
|
399 |
+
ada_norm_bias,
|
400 |
+
"rms_norm",
|
401 |
+
)
|
402 |
+
else:
|
403 |
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
404 |
+
|
405 |
+
self.attn2 = Attention(
|
406 |
+
query_dim=dim,
|
407 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
408 |
+
heads=num_attention_heads,
|
409 |
+
dim_head=attention_head_dim,
|
410 |
+
dropout=dropout,
|
411 |
+
bias=attention_bias,
|
412 |
+
upcast_attention=upcast_attention,
|
413 |
+
out_bias=attention_out_bias,
|
414 |
+
) # is self-attn if encoder_hidden_states is none
|
415 |
+
else:
|
416 |
+
if norm_type == "ada_norm_single": # For Latte
|
417 |
+
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
418 |
+
else:
|
419 |
+
self.norm2 = None
|
420 |
+
self.attn2 = None
|
421 |
+
|
422 |
+
# 3. Feed-forward
|
423 |
+
if norm_type == "ada_norm_continuous":
|
424 |
+
self.norm3 = AdaLayerNormContinuous(
|
425 |
+
dim,
|
426 |
+
ada_norm_continous_conditioning_embedding_dim,
|
427 |
+
norm_elementwise_affine,
|
428 |
+
norm_eps,
|
429 |
+
ada_norm_bias,
|
430 |
+
"layer_norm",
|
431 |
+
)
|
432 |
+
|
433 |
+
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
|
434 |
+
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
435 |
+
elif norm_type == "layer_norm_i2vgen":
|
436 |
+
self.norm3 = None
|
437 |
+
|
438 |
+
self.ff = FeedForward(
|
439 |
+
dim,
|
440 |
+
dropout=dropout,
|
441 |
+
activation_fn=activation_fn,
|
442 |
+
final_dropout=final_dropout,
|
443 |
+
inner_dim=ff_inner_dim,
|
444 |
+
bias=ff_bias,
|
445 |
+
)
|
446 |
+
|
447 |
+
# 4. Fuser
|
448 |
+
if attention_type == "gated" or attention_type == "gated-text-image":
|
449 |
+
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
|
450 |
+
|
451 |
+
# 5. Scale-shift for PixArt-Alpha.
|
452 |
+
if norm_type == "ada_norm_single":
|
453 |
+
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
454 |
+
|
455 |
+
# let chunk size default to None
|
456 |
+
self._chunk_size = None
|
457 |
+
self._chunk_dim = 0
|
458 |
+
|
459 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
460 |
+
# Sets chunk feed-forward
|
461 |
+
self._chunk_size = chunk_size
|
462 |
+
self._chunk_dim = dim
|
463 |
+
|
464 |
+
def forward(
|
465 |
+
self,
|
466 |
+
hidden_states: torch.Tensor,
|
467 |
+
attention_mask: Optional[torch.Tensor] = None,
|
468 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
469 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
470 |
+
timestep: Optional[torch.LongTensor] = None,
|
471 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
472 |
+
class_labels: Optional[torch.LongTensor] = None,
|
473 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
474 |
+
) -> torch.Tensor:
|
475 |
+
if cross_attention_kwargs is not None:
|
476 |
+
if cross_attention_kwargs.get("scale", None) is not None:
|
477 |
+
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
|
478 |
+
|
479 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
480 |
+
# 0. Self-Attention
|
481 |
+
batch_size = hidden_states.shape[0]
|
482 |
+
|
483 |
+
if self.norm_type == "ada_norm":
|
484 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
485 |
+
elif self.norm_type == "ada_norm_zero":
|
486 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
487 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
488 |
+
)
|
489 |
+
elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
|
490 |
+
norm_hidden_states = self.norm1(hidden_states)
|
491 |
+
elif self.norm_type == "ada_norm_continuous":
|
492 |
+
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
493 |
+
elif self.norm_type == "ada_norm_single":
|
494 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
495 |
+
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
496 |
+
).chunk(6, dim=1)
|
497 |
+
norm_hidden_states = self.norm1(hidden_states)
|
498 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
499 |
+
else:
|
500 |
+
raise ValueError("Incorrect norm used")
|
501 |
+
|
502 |
+
if self.pos_embed is not None:
|
503 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
504 |
+
|
505 |
+
# 1. Prepare GLIGEN inputs
|
506 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
507 |
+
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
508 |
+
|
509 |
+
attn_output = self.attn1(
|
510 |
+
norm_hidden_states,
|
511 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
512 |
+
attention_mask=attention_mask,
|
513 |
+
**cross_attention_kwargs,
|
514 |
+
)
|
515 |
+
|
516 |
+
if self.norm_type == "ada_norm_zero":
|
517 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
518 |
+
elif self.norm_type == "ada_norm_single":
|
519 |
+
attn_output = gate_msa * attn_output
|
520 |
+
|
521 |
+
hidden_states = attn_output + hidden_states
|
522 |
+
if hidden_states.ndim == 4:
|
523 |
+
hidden_states = hidden_states.squeeze(1)
|
524 |
+
|
525 |
+
# 1.2 GLIGEN Control
|
526 |
+
if gligen_kwargs is not None:
|
527 |
+
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
528 |
+
|
529 |
+
# 3. Cross-Attention
|
530 |
+
if self.attn2 is not None:
|
531 |
+
if self.norm_type == "ada_norm":
|
532 |
+
norm_hidden_states = self.norm2(hidden_states, timestep)
|
533 |
+
elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
|
534 |
+
norm_hidden_states = self.norm2(hidden_states)
|
535 |
+
elif self.norm_type == "ada_norm_single":
|
536 |
+
# For PixArt norm2 isn't applied here:
|
537 |
+
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
538 |
+
norm_hidden_states = hidden_states
|
539 |
+
elif self.norm_type == "ada_norm_continuous":
|
540 |
+
norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
541 |
+
else:
|
542 |
+
raise ValueError("Incorrect norm")
|
543 |
+
|
544 |
+
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
|
545 |
+
norm_hidden_states = self.pos_embed(norm_hidden_states)
|
546 |
+
|
547 |
+
attn_output = self.attn2(
|
548 |
+
norm_hidden_states,
|
549 |
+
encoder_hidden_states=encoder_hidden_states,
|
550 |
+
attention_mask=encoder_attention_mask,
|
551 |
+
**cross_attention_kwargs,
|
552 |
+
)
|
553 |
+
hidden_states = attn_output + hidden_states
|
554 |
+
|
555 |
+
# 4. Feed-forward
|
556 |
+
# i2vgen doesn't have this norm 🤷♂️
|
557 |
+
if self.norm_type == "ada_norm_continuous":
|
558 |
+
norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
|
559 |
+
elif not self.norm_type == "ada_norm_single":
|
560 |
+
norm_hidden_states = self.norm3(hidden_states)
|
561 |
+
|
562 |
+
if self.norm_type == "ada_norm_zero":
|
563 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
564 |
+
|
565 |
+
if self.norm_type == "ada_norm_single":
|
566 |
+
norm_hidden_states = self.norm2(hidden_states)
|
567 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
568 |
+
|
569 |
+
if self._chunk_size is not None:
|
570 |
+
# "feed_forward_chunk_size" can be used to save memory
|
571 |
+
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
572 |
+
else:
|
573 |
+
ff_output = self.ff(norm_hidden_states)
|
574 |
+
|
575 |
+
if self.norm_type == "ada_norm_zero":
|
576 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
577 |
+
elif self.norm_type == "ada_norm_single":
|
578 |
+
ff_output = gate_mlp * ff_output
|
579 |
+
|
580 |
+
hidden_states = ff_output + hidden_states
|
581 |
+
if hidden_states.ndim == 4:
|
582 |
+
hidden_states = hidden_states.squeeze(1)
|
583 |
+
|
584 |
+
return hidden_states
|
585 |
+
|
586 |
+
|
587 |
+
class LuminaFeedForward(nn.Module):
|
588 |
+
r"""
|
589 |
+
A feed-forward layer.
|
590 |
+
|
591 |
+
Parameters:
|
592 |
+
hidden_size (`int`):
|
593 |
+
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
594 |
+
hidden representations.
|
595 |
+
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
|
596 |
+
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
|
597 |
+
of this value.
|
598 |
+
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
|
599 |
+
dimension. Defaults to None.
|
600 |
+
"""
|
601 |
+
|
602 |
+
def __init__(
|
603 |
+
self,
|
604 |
+
dim: int,
|
605 |
+
inner_dim: int,
|
606 |
+
multiple_of: Optional[int] = 256,
|
607 |
+
ffn_dim_multiplier: Optional[float] = None,
|
608 |
+
):
|
609 |
+
super().__init__()
|
610 |
+
inner_dim = int(2 * inner_dim / 3)
|
611 |
+
# custom hidden_size factor multiplier
|
612 |
+
if ffn_dim_multiplier is not None:
|
613 |
+
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
614 |
+
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
615 |
+
|
616 |
+
self.linear_1 = nn.Linear(
|
617 |
+
dim,
|
618 |
+
inner_dim,
|
619 |
+
bias=False,
|
620 |
+
)
|
621 |
+
self.linear_2 = nn.Linear(
|
622 |
+
inner_dim,
|
623 |
+
dim,
|
624 |
+
bias=False,
|
625 |
+
)
|
626 |
+
self.linear_3 = nn.Linear(
|
627 |
+
dim,
|
628 |
+
inner_dim,
|
629 |
+
bias=False,
|
630 |
+
)
|
631 |
+
self.silu = FP32SiLU()
|
632 |
+
|
633 |
+
def forward(self, x):
|
634 |
+
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
|
635 |
+
|
636 |
+
|
637 |
+
@maybe_allow_in_graph
|
638 |
+
class TemporalBasicTransformerBlock(nn.Module):
|
639 |
+
r"""
|
640 |
+
A basic Transformer block for video like data.
|
641 |
+
|
642 |
+
Parameters:
|
643 |
+
dim (`int`): The number of channels in the input and output.
|
644 |
+
time_mix_inner_dim (`int`): The number of channels for temporal attention.
|
645 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
646 |
+
attention_head_dim (`int`): The number of channels in each head.
|
647 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
648 |
+
"""
|
649 |
+
|
650 |
+
def __init__(
|
651 |
+
self,
|
652 |
+
dim: int,
|
653 |
+
time_mix_inner_dim: int,
|
654 |
+
num_attention_heads: int,
|
655 |
+
attention_head_dim: int,
|
656 |
+
cross_attention_dim: Optional[int] = None,
|
657 |
+
):
|
658 |
+
super().__init__()
|
659 |
+
self.is_res = dim == time_mix_inner_dim
|
660 |
+
|
661 |
+
self.norm_in = nn.LayerNorm(dim)
|
662 |
+
|
663 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
664 |
+
# 1. Self-Attn
|
665 |
+
self.ff_in = FeedForward(
|
666 |
+
dim,
|
667 |
+
dim_out=time_mix_inner_dim,
|
668 |
+
activation_fn="geglu",
|
669 |
+
)
|
670 |
+
|
671 |
+
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
|
672 |
+
self.attn1 = Attention(
|
673 |
+
query_dim=time_mix_inner_dim,
|
674 |
+
heads=num_attention_heads,
|
675 |
+
dim_head=attention_head_dim,
|
676 |
+
cross_attention_dim=None,
|
677 |
+
)
|
678 |
+
|
679 |
+
# 2. Cross-Attn
|
680 |
+
if cross_attention_dim is not None:
|
681 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
682 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
683 |
+
# the second cross attention block.
|
684 |
+
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
|
685 |
+
self.attn2 = Attention(
|
686 |
+
query_dim=time_mix_inner_dim,
|
687 |
+
cross_attention_dim=cross_attention_dim,
|
688 |
+
heads=num_attention_heads,
|
689 |
+
dim_head=attention_head_dim,
|
690 |
+
) # is self-attn if encoder_hidden_states is none
|
691 |
+
else:
|
692 |
+
self.norm2 = None
|
693 |
+
self.attn2 = None
|
694 |
+
|
695 |
+
# 3. Feed-forward
|
696 |
+
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
|
697 |
+
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
|
698 |
+
|
699 |
+
# let chunk size default to None
|
700 |
+
self._chunk_size = None
|
701 |
+
self._chunk_dim = None
|
702 |
+
|
703 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
|
704 |
+
# Sets chunk feed-forward
|
705 |
+
self._chunk_size = chunk_size
|
706 |
+
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
|
707 |
+
self._chunk_dim = 1
|
708 |
+
|
709 |
+
def forward(
|
710 |
+
self,
|
711 |
+
hidden_states: torch.Tensor,
|
712 |
+
num_frames: int,
|
713 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
714 |
+
) -> torch.Tensor:
|
715 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
716 |
+
# 0. Self-Attention
|
717 |
+
batch_size = hidden_states.shape[0]
|
718 |
+
|
719 |
+
batch_frames, seq_length, channels = hidden_states.shape
|
720 |
+
batch_size = batch_frames // num_frames
|
721 |
+
|
722 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
|
723 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
724 |
+
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
|
725 |
+
|
726 |
+
residual = hidden_states
|
727 |
+
hidden_states = self.norm_in(hidden_states)
|
728 |
+
|
729 |
+
if self._chunk_size is not None:
|
730 |
+
hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
|
731 |
+
else:
|
732 |
+
hidden_states = self.ff_in(hidden_states)
|
733 |
+
|
734 |
+
if self.is_res:
|
735 |
+
hidden_states = hidden_states + residual
|
736 |
+
|
737 |
+
norm_hidden_states = self.norm1(hidden_states)
|
738 |
+
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
|
739 |
+
hidden_states = attn_output + hidden_states
|
740 |
+
|
741 |
+
# 3. Cross-Attention
|
742 |
+
if self.attn2 is not None:
|
743 |
+
norm_hidden_states = self.norm2(hidden_states)
|
744 |
+
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
745 |
+
hidden_states = attn_output + hidden_states
|
746 |
+
|
747 |
+
# 4. Feed-forward
|
748 |
+
norm_hidden_states = self.norm3(hidden_states)
|
749 |
+
|
750 |
+
if self._chunk_size is not None:
|
751 |
+
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
752 |
+
else:
|
753 |
+
ff_output = self.ff(norm_hidden_states)
|
754 |
+
|
755 |
+
if self.is_res:
|
756 |
+
hidden_states = ff_output + hidden_states
|
757 |
+
else:
|
758 |
+
hidden_states = ff_output
|
759 |
+
|
760 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
|
761 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3)
|
762 |
+
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
|
763 |
+
|
764 |
+
return hidden_states
|
765 |
+
|
766 |
+
|
767 |
+
class SkipFFTransformerBlock(nn.Module):
|
768 |
+
def __init__(
|
769 |
+
self,
|
770 |
+
dim: int,
|
771 |
+
num_attention_heads: int,
|
772 |
+
attention_head_dim: int,
|
773 |
+
kv_input_dim: int,
|
774 |
+
kv_input_dim_proj_use_bias: bool,
|
775 |
+
dropout=0.0,
|
776 |
+
cross_attention_dim: Optional[int] = None,
|
777 |
+
attention_bias: bool = False,
|
778 |
+
attention_out_bias: bool = True,
|
779 |
+
):
|
780 |
+
super().__init__()
|
781 |
+
if kv_input_dim != dim:
|
782 |
+
self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
|
783 |
+
else:
|
784 |
+
self.kv_mapper = None
|
785 |
+
|
786 |
+
self.norm1 = RMSNorm(dim, 1e-06)
|
787 |
+
|
788 |
+
self.attn1 = Attention(
|
789 |
+
query_dim=dim,
|
790 |
+
heads=num_attention_heads,
|
791 |
+
dim_head=attention_head_dim,
|
792 |
+
dropout=dropout,
|
793 |
+
bias=attention_bias,
|
794 |
+
cross_attention_dim=cross_attention_dim,
|
795 |
+
out_bias=attention_out_bias,
|
796 |
+
)
|
797 |
+
|
798 |
+
self.norm2 = RMSNorm(dim, 1e-06)
|
799 |
+
|
800 |
+
self.attn2 = Attention(
|
801 |
+
query_dim=dim,
|
802 |
+
cross_attention_dim=cross_attention_dim,
|
803 |
+
heads=num_attention_heads,
|
804 |
+
dim_head=attention_head_dim,
|
805 |
+
dropout=dropout,
|
806 |
+
bias=attention_bias,
|
807 |
+
out_bias=attention_out_bias,
|
808 |
+
)
|
809 |
+
|
810 |
+
def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
|
811 |
+
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
812 |
+
|
813 |
+
if self.kv_mapper is not None:
|
814 |
+
encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
|
815 |
+
|
816 |
+
norm_hidden_states = self.norm1(hidden_states)
|
817 |
+
|
818 |
+
attn_output = self.attn1(
|
819 |
+
norm_hidden_states,
|
820 |
+
encoder_hidden_states=encoder_hidden_states,
|
821 |
+
**cross_attention_kwargs,
|
822 |
+
)
|
823 |
+
|
824 |
+
hidden_states = attn_output + hidden_states
|
825 |
+
|
826 |
+
norm_hidden_states = self.norm2(hidden_states)
|
827 |
+
|
828 |
+
attn_output = self.attn2(
|
829 |
+
norm_hidden_states,
|
830 |
+
encoder_hidden_states=encoder_hidden_states,
|
831 |
+
**cross_attention_kwargs,
|
832 |
+
)
|
833 |
+
|
834 |
+
hidden_states = attn_output + hidden_states
|
835 |
+
|
836 |
+
return hidden_states
|
837 |
+
|
838 |
+
|
839 |
+
@maybe_allow_in_graph
|
840 |
+
class FreeNoiseTransformerBlock(nn.Module):
|
841 |
+
r"""
|
842 |
+
A FreeNoise Transformer block.
|
843 |
+
|
844 |
+
Parameters:
|
845 |
+
dim (`int`):
|
846 |
+
The number of channels in the input and output.
|
847 |
+
num_attention_heads (`int`):
|
848 |
+
The number of heads to use for multi-head attention.
|
849 |
+
attention_head_dim (`int`):
|
850 |
+
The number of channels in each head.
|
851 |
+
dropout (`float`, *optional*, defaults to 0.0):
|
852 |
+
The dropout probability to use.
|
853 |
+
cross_attention_dim (`int`, *optional*):
|
854 |
+
The size of the encoder_hidden_states vector for cross attention.
|
855 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
856 |
+
Activation function to be used in feed-forward.
|
857 |
+
num_embeds_ada_norm (`int`, *optional*):
|
858 |
+
The number of diffusion steps used during training. See `Transformer2DModel`.
|
859 |
+
attention_bias (`bool`, defaults to `False`):
|
860 |
+
Configure if the attentions should contain a bias parameter.
|
861 |
+
only_cross_attention (`bool`, defaults to `False`):
|
862 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
863 |
+
double_self_attention (`bool`, defaults to `False`):
|
864 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
865 |
+
upcast_attention (`bool`, defaults to `False`):
|
866 |
+
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
867 |
+
norm_elementwise_affine (`bool`, defaults to `True`):
|
868 |
+
Whether to use learnable elementwise affine parameters for normalization.
|
869 |
+
norm_type (`str`, defaults to `"layer_norm"`):
|
870 |
+
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
871 |
+
final_dropout (`bool` defaults to `False`):
|
872 |
+
Whether to apply a final dropout after the last feed-forward layer.
|
873 |
+
attention_type (`str`, defaults to `"default"`):
|
874 |
+
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
875 |
+
positional_embeddings (`str`, *optional*):
|
876 |
+
The type of positional embeddings to apply to.
|
877 |
+
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
878 |
+
The maximum number of positional embeddings to apply.
|
879 |
+
ff_inner_dim (`int`, *optional*):
|
880 |
+
Hidden dimension of feed-forward MLP.
|
881 |
+
ff_bias (`bool`, defaults to `True`):
|
882 |
+
Whether or not to use bias in feed-forward MLP.
|
883 |
+
attention_out_bias (`bool`, defaults to `True`):
|
884 |
+
Whether or not to use bias in attention output project layer.
|
885 |
+
context_length (`int`, defaults to `16`):
|
886 |
+
The maximum number of frames that the FreeNoise block processes at once.
|
887 |
+
context_stride (`int`, defaults to `4`):
|
888 |
+
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
|
889 |
+
weighting_scheme (`str`, defaults to `"pyramid"`):
|
890 |
+
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
|
891 |
+
Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
|
892 |
+
used.
|
893 |
+
"""
|
894 |
+
|
895 |
+
def __init__(
|
896 |
+
self,
|
897 |
+
dim: int,
|
898 |
+
num_attention_heads: int,
|
899 |
+
attention_head_dim: int,
|
900 |
+
dropout: float = 0.0,
|
901 |
+
cross_attention_dim: Optional[int] = None,
|
902 |
+
activation_fn: str = "geglu",
|
903 |
+
num_embeds_ada_norm: Optional[int] = None,
|
904 |
+
attention_bias: bool = False,
|
905 |
+
only_cross_attention: bool = False,
|
906 |
+
double_self_attention: bool = False,
|
907 |
+
upcast_attention: bool = False,
|
908 |
+
norm_elementwise_affine: bool = True,
|
909 |
+
norm_type: str = "layer_norm",
|
910 |
+
norm_eps: float = 1e-5,
|
911 |
+
final_dropout: bool = False,
|
912 |
+
positional_embeddings: Optional[str] = None,
|
913 |
+
num_positional_embeddings: Optional[int] = None,
|
914 |
+
ff_inner_dim: Optional[int] = None,
|
915 |
+
ff_bias: bool = True,
|
916 |
+
attention_out_bias: bool = True,
|
917 |
+
context_length: int = 16,
|
918 |
+
context_stride: int = 4,
|
919 |
+
weighting_scheme: str = "pyramid",
|
920 |
+
):
|
921 |
+
super().__init__()
|
922 |
+
self.dim = dim
|
923 |
+
self.num_attention_heads = num_attention_heads
|
924 |
+
self.attention_head_dim = attention_head_dim
|
925 |
+
self.dropout = dropout
|
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.double_self_attention = double_self_attention
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.only_cross_attention = only_cross_attention

        self.set_free_noise_properties(context_length, context_stride, weighting_scheme)

        # We keep these boolean flags for backward-compatibility.
        self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
        self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
        self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
        self.use_layer_norm = norm_type == "layer_norm"
        self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"

        if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
            raise ValueError(
                f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

        self.norm_type = norm_type
        self.num_embeds_ada_norm = num_embeds_ada_norm

        if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embeddings` type is defined, `num_positional_embeddings` must also be defined."
            )

        if positional_embeddings == "sinusoidal":
            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
        else:
            self.pos_embed = None

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
            )  # is self-attn if encoder_hidden_states is none

        # 3. Feed-forward
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
        frame_indices = []
        for i in range(0, num_frames - self.context_length + 1, self.context_stride):
            window_start = i
            window_end = min(num_frames, i + self.context_length)
            frame_indices.append((window_start, window_end))
        return frame_indices

    def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
        if weighting_scheme == "flat":
            weights = [1.0] * num_frames

        elif weighting_scheme == "pyramid":
            if num_frames % 2 == 0:
                # num_frames = 4 => [1, 2, 2, 1]
                mid = num_frames // 2
                weights = list(range(1, mid + 1))
                weights = weights + weights[::-1]
            else:
                # num_frames = 5 => [1, 2, 3, 2, 1]
                mid = (num_frames + 1) // 2
                weights = list(range(1, mid))
                weights = weights + [mid] + weights[::-1]

        elif weighting_scheme == "delayed_reverse_sawtooth":
            if num_frames % 2 == 0:
                # num_frames = 4 => [0.01, 2, 2, 1]
                mid = num_frames // 2
                weights = [0.01] * (mid - 1) + [mid]
                weights = weights + list(range(mid, 0, -1))
            else:
                # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
                mid = (num_frames + 1) // 2
                weights = [0.01] * (mid - 1)
                weights = weights + list(range(mid, 0, -1))
        else:
            raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")

        return weights

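    # Illustrative note (added; not part of the upstream diffusers implementation): for an even
    # window such as num_frames=16, the "pyramid" scheme yields [1, 2, ..., 8, 8, ..., 2, 1], so
    # frames near the centre of each window dominate when overlapping windows are averaged in
    # `forward`, whereas "flat" weights every frame in the window equally.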
    def set_free_noise_properties(
        self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
    ) -> None:
        self.context_length = context_length
        self.context_stride = context_stride
        self.weighting_scheme = weighting_scheme

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        # hidden_states: [B x H x W, F, C]
        device = hidden_states.device
        dtype = hidden_states.dtype

        num_frames = hidden_states.size(1)
        frame_indices = self._get_frame_indices(num_frames)
        frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
        frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
        is_last_frame_batch_complete = frame_indices[-1][1] == num_frames

        # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
        # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
        # [(0, 16), (4, 20), (8, 24), (9, 25)]
        if not is_last_frame_batch_complete:
            if num_frames < self.context_length:
                raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
            last_frame_batch_length = num_frames - frame_indices[-1][1]
            frame_indices.append((num_frames - self.context_length, num_frames))

        num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
        accumulated_values = torch.zeros_like(hidden_states)

        for i, (frame_start, frame_end) in enumerate(frame_indices):
            # The slicing here handles cases like frame_indices=[(0, 16), (16, 20)], i.e. when the user
            # provides a video whose frame count is not a multiple of `context_length`.
            weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
            weights *= frame_weights

            hidden_states_chunk = hidden_states[:, frame_start:frame_end]

            # Notice that normalization is always applied before the real computation in the following blocks.
            # 1. Self-Attention
            norm_hidden_states = self.norm1(hidden_states_chunk)

            if self.pos_embed is not None:
                norm_hidden_states = self.pos_embed(norm_hidden_states)

            attn_output = self.attn1(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
                attention_mask=attention_mask,
                **cross_attention_kwargs,
            )

            hidden_states_chunk = attn_output + hidden_states_chunk
            if hidden_states_chunk.ndim == 4:
                hidden_states_chunk = hidden_states_chunk.squeeze(1)

            # 2. Cross-Attention
            if self.attn2 is not None:
                norm_hidden_states = self.norm2(hidden_states_chunk)

                if self.pos_embed is not None and self.norm_type != "ada_norm_single":
                    norm_hidden_states = self.pos_embed(norm_hidden_states)

                attn_output = self.attn2(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=encoder_attention_mask,
                    **cross_attention_kwargs,
                )
                hidden_states_chunk = attn_output + hidden_states_chunk

            if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
                accumulated_values[:, -last_frame_batch_length:] += (
                    hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
                )
                num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length:]
            else:
                accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
                num_times_accumulated[:, frame_start:frame_end] += weights

        # TODO(aryan): Maybe this could be done in a better way.
        #
        # Previously, this was:
        # hidden_states = torch.where(
        #     num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
        # )
        #
        # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
        # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
        # from tensors being copied - which is why we resort to splitting and concatenating here. I've not particularly
        # looked into this deeply because other memory optimizations led to more pronounced reductions.
        hidden_states = torch.cat(
            [
                torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
                for accumulated_split, num_times_split in zip(
                    accumulated_values.split(self.context_length, dim=1),
                    num_times_accumulated.split(self.context_length, dim=1),
                )
            ],
            dim=1,
        ).to(dtype)

        # 3. Feed-forward
        norm_hidden_states = self.norm3(hidden_states)

        if self._chunk_size is not None:
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states

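    # Illustrative note (added; not in the upstream file): with num_frames=25, context_length=16 and
    # context_stride=4, `forward` runs attention over the windows (0, 16), (4, 20), (8, 24) plus the
    # completion window (9, 25); overlapping frames are accumulated with their frame weights and then
    # divided by `num_times_accumulated`, giving a weighted average of the per-window outputs.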

class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
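For reference, a minimal sketch of how the `FeedForward` block above behaves on a token sequence; the shapes and the `"geglu"` default are illustrative assumptions, not values read from the shipped checkpoints:
```python
import torch
from models.attention import FeedForward  # the local file shown above

ff = FeedForward(dim=64, mult=4, activation_fn="geglu")
tokens = torch.randn(2, 77, 64)   # (batch, sequence, channels)
out = ff(tokens)                  # GEGLU -> dropout -> linear projection back to `dim`
print(out.shape)                  # torch.Size([2, 77, 64])
```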
models/resampler.py
ADDED
@@ -0,0 +1,304 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math

import torch
import torch.nn as nn

from diffusers.models.embeddings import Timesteps, TimestepEmbedding


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
    embeddings. :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


# FFN
def FeedForward(dim, mult=4):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )


def reshape_tensor(x, heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # keep (bs, n_heads, length, dim_per_head), made contiguous for the matmuls below
    x = x.reshape(bs, heads, length, -1)
    return x


class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, shift=None, scale=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        if shift is not None and scale is not None:
            latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class Resampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

    def forward(self, x):

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        return self.norm_out(latents)

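# Illustrative note (added comment): with the defaults above, Resampler(embedding_dim=768,
# num_queries=8, output_dim=1024) maps image features of shape (b, n1, 768) onto a fixed set of
# 8 learned query tokens of shape (b, 8, 1024), independent of the input token count n1.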
class TimeResampler(nn.Module):
    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        timestep_in_dim=320,
        timestep_flip_sin_to_cos=True,
        timestep_freq_shift=0,
    ):
        super().__init__()

        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)

        self.proj_in = nn.Linear(embedding_dim, dim)

        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        # msa
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        # ff
                        FeedForward(dim=dim, mult=ff_mult),
                        # adaLN
                        nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
                    ]
                )
            )

        # time
        self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
        self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")

        # adaLN
        # self.adaLN_modulation = nn.Sequential(
        #     nn.SiLU(),
        #     nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
        # )

    def forward(self, x, timestep, need_temb=False):
        timestep_emb = self.embedding_time(x, timestep)  # bs, dim

        latents = self.latents.repeat(x.size(0), 1, 1)

        x = self.proj_in(x)
        x = x + timestep_emb[:, None]

        for attn, ff, adaLN_modulation in self.layers:
            shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
            latents = attn(x, latents, shift_msa, scale_msa) + latents

            res = latents
            for idx_ff in range(len(ff)):
                layer_ff = ff[idx_ff]
                latents = layer_ff(latents)
                if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm):  # adaLN
                    latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
            latents = latents + res

            # latents = ff(latents) + latents

        latents = self.proj_out(latents)
        latents = self.norm_out(latents)

        if need_temb:
            return latents, timestep_emb
        else:
            return latents

    def embedding_time(self, sample, timestep):

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, None)
        return emb


if __name__ == '__main__':
    model = TimeResampler(
        dim=1280,
        depth=4,
        dim_head=64,
        heads=20,
        num_queries=16,
        embedding_dim=512,
        output_dim=2048,
        ff_mult=4,
        timestep_in_dim=320,
        timestep_flip_sin_to_cos=True,
        timestep_freq_shift=0,
        in_channel_extra_emb=2048,
    )
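A minimal sketch of how the `TimeResampler` above can be exercised on SigLIP-sized features; the 64 queries follow the model card, while the 1152-dim input and 2432-dim output are assumptions about this checkpoint rather than values read from it:
```python
import torch
from models.resampler import TimeResampler

# 64 learned queries resample SigLIP patch features into image-prompt tokens,
# with adaLN modulation driven by the diffusion timestep.
resampler = TimeResampler(
    dim=1280, depth=4, dim_head=64, heads=20,
    num_queries=64,        # image token number from the model card
    embedding_dim=1152,    # siglip-so400m-patch14-384 hidden size (assumed)
    output_dim=2432,       # assumed to match the SD3.5-Large transformer width
    timestep_in_dim=320,
)
image_feats = torch.randn(1, 729, 1152)   # SigLIP patch tokens
timestep = torch.tensor([999])
ip_tokens, temb = resampler(image_feats, timestep, need_temb=True)
print(ip_tokens.shape, temb.shape)        # (1, 64, 2432), (1, 1280)
```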
models/transformer_sd3.py
ADDED
@@ -0,0 +1,375 @@
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from .attention import JointTransformerBlock
from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """
    The Transformer model introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
        out_channels (`int`, defaults to 16): Number of output channels.

    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: int = 128,
        patch_size: int = 2,
        in_channels: int = 16,
        num_layers: int = 18,
        attention_head_dim: int = 64,
        num_attention_heads: int = 18,
        joint_attention_dim: int = 4096,
        caption_projection_dim: int = 1152,
        pooled_projection_dim: int = 2048,
        out_channels: int = 16,
        pos_embed_max_size: int = 96,
        dual_attention_layers: Tuple[
            int, ...
        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        default_out_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else default_out_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            pos_embed_max_size=pos_embed_max_size,  # hard-code for now.
        )
        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)

        # `attention_head_dim` is doubled to account for the mixing.
        # It needs to be crafted when we get the actual checkpoints.
        self.transformer_blocks = nn.ModuleList(
            [
                JointTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                    context_pre_only=i == num_layers - 1,
                    qk_norm=qk_norm,
                    use_dual_attention=True if i in dual_attention_layers else False,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedJointAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        pooled_projections: torch.FloatTensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`SD3Transformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
                A list of tensors that if specified are added to the residuals of transformer blocks.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
        temb = self.time_text_embed(timestep, pooled_projections)
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    joint_attention_kwargs,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
                hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # unpatchify
        patch_size = self.config.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
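A minimal sketch of how this customized `SD3Transformer2DModel` would be loaded with the SD3.5-Large weights before the IP-Adapter pipeline below is built; the `subfolder` and dtype are assumptions based on the usual diffusers repository layout:
```python
import torch
from models.transformer_sd3 import SD3Transformer2DModel

# Load the base weights into the local class, whose transformer blocks accept
# `joint_attention_kwargs` in forward().
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
```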
pipeline_stable_diffusion_3_ipa.py
ADDED
@@ -0,0 +1,1235 @@
1 |
+
# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from transformers import (
|
22 |
+
CLIPTextModelWithProjection,
|
23 |
+
CLIPTokenizer,
|
24 |
+
T5EncoderModel,
|
25 |
+
T5TokenizerFast,
|
26 |
+
)
|
27 |
+
|
28 |
+
from diffusers.image_processor import VaeImageProcessor
|
29 |
+
from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
30 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
31 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
32 |
+
from diffusers.utils import (
|
33 |
+
USE_PEFT_BACKEND,
|
34 |
+
is_torch_xla_available,
|
35 |
+
logging,
|
36 |
+
replace_example_docstring,
|
37 |
+
scale_lora_layers,
|
38 |
+
unscale_lora_layers,
|
39 |
+
)
|
40 |
+
from diffusers.utils.torch_utils import randn_tensor
|
41 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
42 |
+
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
43 |
+
|
44 |
+
from models.resampler import TimeResampler
|
45 |
+
from models.transformer_sd3 import SD3Transformer2DModel
|
46 |
+
from diffusers.models.normalization import RMSNorm
|
47 |
+
from einops import rearrange
|
48 |
+
|
49 |
+
|
50 |
+
if is_torch_xla_available():
|
51 |
+
import torch_xla.core.xla_model as xm
|
52 |
+
|
53 |
+
XLA_AVAILABLE = True
|
54 |
+
else:
|
55 |
+
XLA_AVAILABLE = False
|
56 |
+
|
57 |
+
|
58 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
59 |
+
|
60 |
+
EXAMPLE_DOC_STRING = """
|
61 |
+
Examples:
|
62 |
+
```py
|
63 |
+
>>> import torch
|
64 |
+
>>> from diffusers import StableDiffusion3Pipeline
|
65 |
+
|
66 |
+
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
|
67 |
+
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
68 |
+
... )
|
69 |
+
>>> pipe.to("cuda")
|
70 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
71 |
+
>>> image = pipe(prompt).images[0]
|
72 |
+
>>> image.save("sd3.png")
|
73 |
+
```
|
74 |
+
"""
|
75 |
+
|
76 |
+
|
77 |
+
class AdaLayerNorm(nn.Module):
|
78 |
+
"""
|
79 |
+
Norm layer adaptive layer norm zero (adaLN-Zero).
|
80 |
+
|
81 |
+
Parameters:
|
82 |
+
embedding_dim (`int`): The size of each embedding vector.
|
83 |
+
num_embeddings (`int`): The size of the embeddings dictionary.
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
self.silu = nn.SiLU()
|
90 |
+
num_params_dict = dict(
|
91 |
+
zero=6,
|
92 |
+
normal=2,
|
93 |
+
)
|
94 |
+
num_params = num_params_dict[mode]
|
95 |
+
self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True)
|
96 |
+
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
97 |
+
self.mode = mode
|
98 |
+
|
99 |
+
def forward(
|
100 |
+
self,
|
101 |
+
x,
|
102 |
+
hidden_dtype = None,
|
103 |
+
emb = None,
|
104 |
+
):
|
105 |
+
emb = self.linear(self.silu(emb))
|
106 |
+
if self.mode == 'normal':
|
107 |
+
shift_msa, scale_msa = emb.chunk(2, dim=1)
|
108 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
109 |
+
return x
|
110 |
+
|
111 |
+
elif self.mode == 'zero':
|
112 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
113 |
+
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
114 |
+
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
115 |
+
|
116 |
+
|
117 |
+
class JointIPAttnProcessor(torch.nn.Module):
|
118 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
119 |
+
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
hidden_size=None,
|
123 |
+
cross_attention_dim=None,
|
124 |
+
ip_hidden_states_dim=None,
|
125 |
+
ip_encoder_hidden_states_dim=None,
|
126 |
+
head_dim=None,
|
127 |
+
timesteps_emb_dim=1280,
|
128 |
+
):
|
129 |
+
super().__init__()
|
130 |
+
|
131 |
+
self.norm_ip = AdaLayerNorm(ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim)
|
132 |
+
self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
|
133 |
+
self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
|
134 |
+
self.norm_q = RMSNorm(head_dim, 1e-6)
|
135 |
+
self.norm_k = RMSNorm(head_dim, 1e-6)
|
136 |
+
self.norm_ip_k = RMSNorm(head_dim, 1e-6)
|
137 |
+
|
138 |
+
|
139 |
+
def __call__(
|
140 |
+
self,
|
141 |
+
attn,
|
142 |
+
hidden_states: torch.FloatTensor,
|
143 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
144 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
145 |
+
emb_dict=None,
|
146 |
+
*args,
|
147 |
+
**kwargs,
|
148 |
+
) -> torch.FloatTensor:
|
149 |
+
residual = hidden_states
|
150 |
+
|
151 |
+
batch_size = hidden_states.shape[0]
|
152 |
+
|
153 |
+
# `sample` projections.
|
154 |
+
query = attn.to_q(hidden_states)
|
155 |
+
key = attn.to_k(hidden_states)
|
156 |
+
value = attn.to_v(hidden_states)
|
157 |
+
img_query = query
|
158 |
+
img_key = key
|
159 |
+
img_value = value
|
160 |
+
|
161 |
+
inner_dim = key.shape[-1]
|
162 |
+
head_dim = inner_dim // attn.heads
|
163 |
+
|
164 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
165 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
166 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
167 |
+
|
168 |
+
if attn.norm_q is not None:
|
169 |
+
query = attn.norm_q(query)
|
170 |
+
if attn.norm_k is not None:
|
171 |
+
key = attn.norm_k(key)
|
172 |
+
|
173 |
+
# `context` projections.
|
174 |
+
if encoder_hidden_states is not None:
|
175 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
176 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
177 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
178 |
+
|
179 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
180 |
+
batch_size, -1, attn.heads, head_dim
|
181 |
+
).transpose(1, 2)
|
182 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
183 |
+
batch_size, -1, attn.heads, head_dim
|
184 |
+
).transpose(1, 2)
|
185 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
186 |
+
batch_size, -1, attn.heads, head_dim
|
187 |
+
).transpose(1, 2)
|
188 |
+
|
189 |
+
if attn.norm_added_q is not None:
|
190 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
191 |
+
if attn.norm_added_k is not None:
|
192 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
193 |
+
|
194 |
+
query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
|
195 |
+
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
|
196 |
+
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
|
197 |
+
|
198 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
199 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
200 |
+
hidden_states = hidden_states.to(query.dtype)
|
201 |
+
|
202 |
+
if encoder_hidden_states is not None:
|
203 |
+
# Split the attention outputs.
|
204 |
+
hidden_states, encoder_hidden_states = (
|
205 |
+
hidden_states[:, : residual.shape[1]],
|
206 |
+
hidden_states[:, residual.shape[1] :],
|
207 |
+
)
|
208 |
+
if not attn.context_pre_only:
|
209 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
210 |
+
|
211 |
+
|
212 |
+
# IPadapter
|
213 |
+
ip_hidden_states = emb_dict.get('ip_hidden_states', None)
|
214 |
+
ip_hidden_states = self.get_ip_hidden_states(
|
215 |
+
attn,
|
216 |
+
img_query,
|
217 |
+
ip_hidden_states,
|
218 |
+
img_key,
|
219 |
+
img_value,
|
220 |
+
None,
|
221 |
+
None,
|
222 |
+
emb_dict['temb'],
|
223 |
+
)
|
224 |
+
if ip_hidden_states is not None:
|
225 |
+
hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0)
|
226 |
+
|
227 |
+
|
228 |
+
# linear proj
|
229 |
+
hidden_states = attn.to_out[0](hidden_states)
|
230 |
+
# dropout
|
231 |
+
hidden_states = attn.to_out[1](hidden_states)
|
232 |
+
|
233 |
+
if encoder_hidden_states is not None:
|
234 |
+
return hidden_states, encoder_hidden_states
|
235 |
+
else:
|
236 |
+
return hidden_states
|
237 |
+
|
238 |
+
|
239 |
+
def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None):
|
240 |
+
if ip_hidden_states is None:
|
241 |
+
return None
|
242 |
+
|
243 |
+
if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'):
|
244 |
+
return None
|
245 |
+
|
246 |
+
# norm ip input
|
247 |
+
norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=temb)
|
248 |
+
|
249 |
+
# to k and v
|
250 |
+
ip_key = self.to_k_ip(norm_ip_hidden_states)
|
251 |
+
ip_value = self.to_v_ip(norm_ip_hidden_states)
|
252 |
+
|
253 |
+
# reshape
|
254 |
+
query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads)
|
255 |
+
img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
|
256 |
+
img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
|
257 |
+
ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
|
258 |
+
ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
|
259 |
+
|
260 |
+
# norm
|
261 |
+
query = self.norm_q(query)
|
262 |
+
img_key = self.norm_k(img_key)
|
263 |
+
ip_key = self.norm_ip_k(ip_key)
|
264 |
+
|
265 |
+
# cat img
|
266 |
+
key = torch.cat([img_key, ip_key], dim=2)
|
267 |
+
value = torch.cat([img_value, ip_value], dim=2)
|
268 |
+
|
269 |
+
# attention of the image query over the concatenated image and IP-Adapter tokens
|
270 |
+
ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
271 |
+
ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
|
272 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
273 |
+
return ip_hidden_states
|
274 |
+
|
275 |
+
|
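# `get_ip_hidden_states` implements a decoupled image-prompt attention: the image-stream
# query attends over the concatenation of its own keys/values and the IP-Adapter
# keys/values, approximately
#
#     ip_out = softmax(Q_img @ [K_img ; K_ip]^T / sqrt(d)) @ [V_img ; V_ip]
#
# with K_ip/V_ip projected from the resampler tokens after `norm_ip`, which is
# conditioned on the timestep embedding `temb`. The projection modules referenced here
# (`to_k_ip`, `to_v_ip`, `norm_ip`, `norm_ip_k`, `norm_q`, `norm_k`) are created in the
# processor's constructor earlier in this file.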
276 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
277 |
+
def retrieve_timesteps(
|
278 |
+
scheduler,
|
279 |
+
num_inference_steps: Optional[int] = None,
|
280 |
+
device: Optional[Union[str, torch.device]] = None,
|
281 |
+
timesteps: Optional[List[int]] = None,
|
282 |
+
sigmas: Optional[List[float]] = None,
|
283 |
+
**kwargs,
|
284 |
+
):
|
285 |
+
"""
|
286 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
287 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
288 |
+
|
289 |
+
Args:
|
290 |
+
scheduler (`SchedulerMixin`):
|
291 |
+
The scheduler to get timesteps from.
|
292 |
+
num_inference_steps (`int`):
|
293 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
294 |
+
must be `None`.
|
295 |
+
device (`str` or `torch.device`, *optional*):
|
296 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
297 |
+
timesteps (`List[int]`, *optional*):
|
298 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
299 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
300 |
+
sigmas (`List[float]`, *optional*):
|
301 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
302 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
306 |
+
second element is the number of inference steps.
|
307 |
+
"""
|
308 |
+
if timesteps is not None and sigmas is not None:
|
309 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
310 |
+
if timesteps is not None:
|
311 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
312 |
+
if not accepts_timesteps:
|
313 |
+
raise ValueError(
|
314 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
315 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
316 |
+
)
|
317 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
318 |
+
timesteps = scheduler.timesteps
|
319 |
+
num_inference_steps = len(timesteps)
|
320 |
+
elif sigmas is not None:
|
321 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
322 |
+
if not accept_sigmas:
|
323 |
+
raise ValueError(
|
324 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
325 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
326 |
+
)
|
327 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
328 |
+
timesteps = scheduler.timesteps
|
329 |
+
num_inference_steps = len(timesteps)
|
330 |
+
else:
|
331 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
332 |
+
timesteps = scheduler.timesteps
|
333 |
+
return timesteps, num_inference_steps
|
334 |
+
|
335 |
+
|
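# Illustrative usage of the helper above (values are hypothetical): callers pass either
# a step count, explicit `timesteps`, or explicit `sigmas`, but never both of the latter.
#
#     timesteps, num_inference_steps = retrieve_timesteps(
#         pipe.scheduler, num_inference_steps=28, device="cuda"
#     )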
336 |
+
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
|
337 |
+
r"""
|
338 |
+
Args:
|
339 |
+
transformer ([`SD3Transformer2DModel`]):
|
340 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
341 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
342 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
343 |
+
vae ([`AutoencoderKL`]):
|
344 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
345 |
+
text_encoder ([`CLIPTextModelWithProjection`]):
|
346 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
347 |
+
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
348 |
+
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
349 |
+
as its dimension.
|
350 |
+
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
351 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
352 |
+
specifically the
|
353 |
+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
354 |
+
variant.
|
355 |
+
text_encoder_3 ([`T5EncoderModel`]):
|
356 |
+
Frozen text-encoder. Stable Diffusion 3 uses
|
357 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
358 |
+
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
359 |
+
tokenizer (`CLIPTokenizer`):
|
360 |
+
Tokenizer of class
|
361 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
362 |
+
tokenizer_2 (`CLIPTokenizer`):
|
363 |
+
Second Tokenizer of class
|
364 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
365 |
+
tokenizer_3 (`T5TokenizerFast`):
|
366 |
+
Tokenizer of class
|
367 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
368 |
+
"""
|
369 |
+
|
370 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
371 |
+
_optional_components = []
|
372 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
373 |
+
|
374 |
+
def __init__(
|
375 |
+
self,
|
376 |
+
transformer: SD3Transformer2DModel,
|
377 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
378 |
+
vae: AutoencoderKL,
|
379 |
+
text_encoder: CLIPTextModelWithProjection,
|
380 |
+
tokenizer: CLIPTokenizer,
|
381 |
+
text_encoder_2: CLIPTextModelWithProjection,
|
382 |
+
tokenizer_2: CLIPTokenizer,
|
383 |
+
text_encoder_3: T5EncoderModel,
|
384 |
+
tokenizer_3: T5TokenizerFast,
|
385 |
+
):
|
386 |
+
super().__init__()
|
387 |
+
|
388 |
+
self.register_modules(
|
389 |
+
vae=vae,
|
390 |
+
text_encoder=text_encoder,
|
391 |
+
text_encoder_2=text_encoder_2,
|
392 |
+
text_encoder_3=text_encoder_3,
|
393 |
+
tokenizer=tokenizer,
|
394 |
+
tokenizer_2=tokenizer_2,
|
395 |
+
tokenizer_3=tokenizer_3,
|
396 |
+
transformer=transformer,
|
397 |
+
scheduler=scheduler,
|
398 |
+
)
|
399 |
+
self.vae_scale_factor = (
|
400 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
401 |
+
)
|
402 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
403 |
+
self.tokenizer_max_length = (
|
404 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
405 |
+
)
|
406 |
+
self.default_sample_size = (
|
407 |
+
self.transformer.config.sample_size
|
408 |
+
if hasattr(self, "transformer") and self.transformer is not None
|
409 |
+
else 128
|
410 |
+
)
|
411 |
+
|
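# For the stock SD3.5 VAE, 2 ** (len(block_out_channels) - 1) equals the fallback value
# of 8, so a 1024x1024 image corresponds to a 128x128 latent grid, which is also the
# default sample size used above. These figures describe the standard configuration and
# are not enforced here.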
412 |
+
def _get_t5_prompt_embeds(
|
413 |
+
self,
|
414 |
+
prompt: Union[str, List[str]] = None,
|
415 |
+
num_images_per_prompt: int = 1,
|
416 |
+
max_sequence_length: int = 256,
|
417 |
+
device: Optional[torch.device] = None,
|
418 |
+
dtype: Optional[torch.dtype] = None,
|
419 |
+
):
|
420 |
+
device = device or self._execution_device
|
421 |
+
dtype = dtype or self.text_encoder.dtype
|
422 |
+
|
423 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
424 |
+
batch_size = len(prompt)
|
425 |
+
|
426 |
+
if self.text_encoder_3 is None:
|
427 |
+
return torch.zeros(
|
428 |
+
(
|
429 |
+
batch_size * num_images_per_prompt,
|
430 |
+
self.tokenizer_max_length,
|
431 |
+
self.transformer.config.joint_attention_dim,
|
432 |
+
),
|
433 |
+
device=device,
|
434 |
+
dtype=dtype,
|
435 |
+
)
|
436 |
+
|
437 |
+
text_inputs = self.tokenizer_3(
|
438 |
+
prompt,
|
439 |
+
padding="max_length",
|
440 |
+
max_length=max_sequence_length,
|
441 |
+
truncation=True,
|
442 |
+
add_special_tokens=True,
|
443 |
+
return_tensors="pt",
|
444 |
+
)
|
445 |
+
text_input_ids = text_inputs.input_ids
|
446 |
+
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
447 |
+
|
448 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
449 |
+
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
450 |
+
logger.warning(
|
451 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
452 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
453 |
+
)
|
454 |
+
|
455 |
+
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
456 |
+
|
457 |
+
dtype = self.text_encoder_3.dtype
|
458 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
459 |
+
|
460 |
+
_, seq_len, _ = prompt_embeds.shape
|
461 |
+
|
462 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
463 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
464 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
465 |
+
|
466 |
+
return prompt_embeds
|
467 |
+
|
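# When `text_encoder_3` is loaded, the result has shape
# (batch * num_images_per_prompt, max_sequence_length, t5_hidden_dim); when it is
# dropped, an all-zeros placeholder sized with `tokenizer_max_length` and
# `joint_attention_dim` is returned instead, so the T5 stream can be skipped cheaply.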
468 |
+
def _get_clip_prompt_embeds(
|
469 |
+
self,
|
470 |
+
prompt: Union[str, List[str]],
|
471 |
+
num_images_per_prompt: int = 1,
|
472 |
+
device: Optional[torch.device] = None,
|
473 |
+
clip_skip: Optional[int] = None,
|
474 |
+
clip_model_index: int = 0,
|
475 |
+
):
|
476 |
+
device = device or self._execution_device
|
477 |
+
|
478 |
+
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
479 |
+
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
480 |
+
|
481 |
+
tokenizer = clip_tokenizers[clip_model_index]
|
482 |
+
text_encoder = clip_text_encoders[clip_model_index]
|
483 |
+
|
484 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
485 |
+
batch_size = len(prompt)
|
486 |
+
|
487 |
+
text_inputs = tokenizer(
|
488 |
+
prompt,
|
489 |
+
padding="max_length",
|
490 |
+
max_length=self.tokenizer_max_length,
|
491 |
+
truncation=True,
|
492 |
+
return_tensors="pt",
|
493 |
+
)
|
494 |
+
|
495 |
+
text_input_ids = text_inputs.input_ids
|
496 |
+
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
497 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
498 |
+
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
499 |
+
logger.warning(
|
500 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
501 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
502 |
+
)
|
503 |
+
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
504 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
505 |
+
|
506 |
+
if clip_skip is None:
|
507 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
508 |
+
else:
|
509 |
+
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
510 |
+
|
511 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
512 |
+
|
513 |
+
_, seq_len, _ = prompt_embeds.shape
|
514 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
515 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
516 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
517 |
+
|
518 |
+
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
519 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
520 |
+
|
521 |
+
return prompt_embeds, pooled_prompt_embeds
|
522 |
+
|
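# Each CLIP branch yields two things: per-token hidden states (the penultimate layer by
# default, or an earlier layer via `clip_skip`) and a pooled projection. The sequence
# embeddings feed the transformer's context stream, while the pooled vectors feed its
# `pooled_projections` input.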
523 |
+
def encode_prompt(
|
524 |
+
self,
|
525 |
+
prompt: Union[str, List[str]],
|
526 |
+
prompt_2: Union[str, List[str]],
|
527 |
+
prompt_3: Union[str, List[str]],
|
528 |
+
device: Optional[torch.device] = None,
|
529 |
+
num_images_per_prompt: int = 1,
|
530 |
+
do_classifier_free_guidance: bool = True,
|
531 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
532 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
533 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
534 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
535 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
536 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
537 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
538 |
+
clip_skip: Optional[int] = None,
|
539 |
+
max_sequence_length: int = 256,
|
540 |
+
lora_scale: Optional[float] = None,
|
541 |
+
):
|
542 |
+
r"""
|
543 |
+
|
544 |
+
Args:
|
545 |
+
prompt (`str` or `List[str]`, *optional*):
|
546 |
+
prompt to be encoded
|
547 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
548 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
549 |
+
used in all text-encoders
|
550 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
551 |
+
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
552 |
+
used in all text-encoders
|
553 |
+
device: (`torch.device`):
|
554 |
+
torch device
|
555 |
+
num_images_per_prompt (`int`):
|
556 |
+
number of images that should be generated per prompt
|
557 |
+
do_classifier_free_guidance (`bool`):
|
558 |
+
whether to use classifier free guidance or not
|
559 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
560 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
561 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
562 |
+
less than `1`).
|
563 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
564 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
565 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
566 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
567 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
568 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
|
569 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
570 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
571 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
572 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
573 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
574 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
575 |
+
argument.
|
576 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
577 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
578 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
579 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
580 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
581 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
582 |
+
input argument.
|
583 |
+
clip_skip (`int`, *optional*):
|
584 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
585 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
586 |
+
lora_scale (`float`, *optional*):
|
587 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
588 |
+
"""
|
589 |
+
device = device or self._execution_device
|
590 |
+
|
591 |
+
# set lora scale so that monkey patched LoRA
|
592 |
+
# function of text encoder can correctly access it
|
593 |
+
if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
|
594 |
+
self._lora_scale = lora_scale
|
595 |
+
|
596 |
+
# dynamically adjust the LoRA scale
|
597 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
598 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
599 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
600 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
601 |
+
|
602 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
603 |
+
if prompt is not None:
|
604 |
+
batch_size = len(prompt)
|
605 |
+
else:
|
606 |
+
batch_size = prompt_embeds.shape[0]
|
607 |
+
|
608 |
+
if prompt_embeds is None:
|
609 |
+
prompt_2 = prompt_2 or prompt
|
610 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
611 |
+
|
612 |
+
prompt_3 = prompt_3 or prompt
|
613 |
+
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
614 |
+
|
615 |
+
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
616 |
+
prompt=prompt,
|
617 |
+
device=device,
|
618 |
+
num_images_per_prompt=num_images_per_prompt,
|
619 |
+
clip_skip=clip_skip,
|
620 |
+
clip_model_index=0,
|
621 |
+
)
|
622 |
+
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
623 |
+
prompt=prompt_2,
|
624 |
+
device=device,
|
625 |
+
num_images_per_prompt=num_images_per_prompt,
|
626 |
+
clip_skip=clip_skip,
|
627 |
+
clip_model_index=1,
|
628 |
+
)
|
629 |
+
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
630 |
+
|
631 |
+
t5_prompt_embed = self._get_t5_prompt_embeds(
|
632 |
+
prompt=prompt_3,
|
633 |
+
num_images_per_prompt=num_images_per_prompt,
|
634 |
+
max_sequence_length=max_sequence_length,
|
635 |
+
device=device,
|
636 |
+
)
|
637 |
+
|
638 |
+
clip_prompt_embeds = torch.nn.functional.pad(
|
639 |
+
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
640 |
+
)
|
641 |
+
|
642 |
+
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
643 |
+
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
644 |
+
|
645 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
646 |
+
negative_prompt = negative_prompt or ""
|
647 |
+
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
648 |
+
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
649 |
+
|
650 |
+
# normalize str to list
|
651 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
652 |
+
negative_prompt_2 = (
|
653 |
+
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
654 |
+
)
|
655 |
+
negative_prompt_3 = (
|
656 |
+
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
657 |
+
)
|
658 |
+
|
659 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
660 |
+
raise TypeError(
|
661 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
662 |
+
f" {type(prompt)}."
|
663 |
+
)
|
664 |
+
elif batch_size != len(negative_prompt):
|
665 |
+
raise ValueError(
|
666 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
667 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
668 |
+
" the batch size of `prompt`."
|
669 |
+
)
|
670 |
+
|
671 |
+
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
672 |
+
negative_prompt,
|
673 |
+
device=device,
|
674 |
+
num_images_per_prompt=num_images_per_prompt,
|
675 |
+
clip_skip=None,
|
676 |
+
clip_model_index=0,
|
677 |
+
)
|
678 |
+
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
679 |
+
negative_prompt_2,
|
680 |
+
device=device,
|
681 |
+
num_images_per_prompt=num_images_per_prompt,
|
682 |
+
clip_skip=None,
|
683 |
+
clip_model_index=1,
|
684 |
+
)
|
685 |
+
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
686 |
+
|
687 |
+
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
688 |
+
prompt=negative_prompt_3,
|
689 |
+
num_images_per_prompt=num_images_per_prompt,
|
690 |
+
max_sequence_length=max_sequence_length,
|
691 |
+
device=device,
|
692 |
+
)
|
693 |
+
|
694 |
+
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
695 |
+
negative_clip_prompt_embeds,
|
696 |
+
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
697 |
+
)
|
698 |
+
|
699 |
+
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
700 |
+
negative_pooled_prompt_embeds = torch.cat(
|
701 |
+
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
702 |
+
)
|
703 |
+
|
704 |
+
if self.text_encoder is not None:
|
705 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
706 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
707 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
708 |
+
|
709 |
+
if self.text_encoder_2 is not None:
|
710 |
+
if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
|
711 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
712 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
713 |
+
|
714 |
+
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
715 |
+
|
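# Shape sketch of the combination above (stock SD3.5 dimensions, shown for orientation
# only): the two CLIP streams are concatenated channel-wise (768 + 1280 = 2048),
# zero-padded up to the T5 width (4096), then concatenated with the T5 tokens along the
# sequence axis:
#
#     prompt_embeds        : (B, 77 + max_sequence_length, 4096)
#     pooled_prompt_embeds : (B, 768 + 1280) = (B, 2048)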
716 |
+
def check_inputs(
|
717 |
+
self,
|
718 |
+
prompt,
|
719 |
+
prompt_2,
|
720 |
+
prompt_3,
|
721 |
+
height,
|
722 |
+
width,
|
723 |
+
negative_prompt=None,
|
724 |
+
negative_prompt_2=None,
|
725 |
+
negative_prompt_3=None,
|
726 |
+
prompt_embeds=None,
|
727 |
+
negative_prompt_embeds=None,
|
728 |
+
pooled_prompt_embeds=None,
|
729 |
+
negative_pooled_prompt_embeds=None,
|
730 |
+
callback_on_step_end_tensor_inputs=None,
|
731 |
+
max_sequence_length=None,
|
732 |
+
):
|
733 |
+
if height % 8 != 0 or width % 8 != 0:
|
734 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
735 |
+
|
736 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
737 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
738 |
+
):
|
739 |
+
raise ValueError(
|
740 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
741 |
+
)
|
742 |
+
|
743 |
+
if prompt is not None and prompt_embeds is not None:
|
744 |
+
raise ValueError(
|
745 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
746 |
+
" only forward one of the two."
|
747 |
+
)
|
748 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
749 |
+
raise ValueError(
|
750 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
751 |
+
" only forward one of the two."
|
752 |
+
)
|
753 |
+
elif prompt_3 is not None and prompt_embeds is not None:
|
754 |
+
raise ValueError(
|
755 |
+
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
756 |
+
" only forward one of the two."
|
757 |
+
)
|
758 |
+
elif prompt is None and prompt_embeds is None:
|
759 |
+
raise ValueError(
|
760 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
761 |
+
)
|
762 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
763 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
764 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
765 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
766 |
+
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
767 |
+
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
768 |
+
|
769 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
770 |
+
raise ValueError(
|
771 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
772 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
773 |
+
)
|
774 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
775 |
+
raise ValueError(
|
776 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
777 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
778 |
+
)
|
779 |
+
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
780 |
+
raise ValueError(
|
781 |
+
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
782 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
783 |
+
)
|
784 |
+
|
785 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
786 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
787 |
+
raise ValueError(
|
788 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
789 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
790 |
+
f" {negative_prompt_embeds.shape}."
|
791 |
+
)
|
792 |
+
|
793 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
794 |
+
raise ValueError(
|
795 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
796 |
+
)
|
797 |
+
|
798 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
799 |
+
raise ValueError(
|
800 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
801 |
+
)
|
802 |
+
|
803 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
804 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
805 |
+
|
806 |
+
def prepare_latents(
|
807 |
+
self,
|
808 |
+
batch_size,
|
809 |
+
num_channels_latents,
|
810 |
+
height,
|
811 |
+
width,
|
812 |
+
dtype,
|
813 |
+
device,
|
814 |
+
generator,
|
815 |
+
latents=None,
|
816 |
+
):
|
817 |
+
if latents is not None:
|
818 |
+
return latents.to(device=device, dtype=dtype)
|
819 |
+
|
820 |
+
shape = (
|
821 |
+
batch_size,
|
822 |
+
num_channels_latents,
|
823 |
+
int(height) // self.vae_scale_factor,
|
824 |
+
int(width) // self.vae_scale_factor,
|
825 |
+
)
|
826 |
+
|
827 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
828 |
+
raise ValueError(
|
829 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
830 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
831 |
+
)
|
832 |
+
|
833 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
834 |
+
|
835 |
+
return latents
|
836 |
+
|
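# With the stock SD3.5 configuration this produces latents of shape
# (batch, 16, height // 8, width // 8); `num_channels_latents` is taken from
# `transformer.config.in_channels` at the call site in `__call__`.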
837 |
+
@property
|
838 |
+
def guidance_scale(self):
|
839 |
+
return self._guidance_scale
|
840 |
+
|
841 |
+
@property
|
842 |
+
def clip_skip(self):
|
843 |
+
return self._clip_skip
|
844 |
+
|
845 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
846 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
847 |
+
# corresponds to doing no classifier free guidance.
|
848 |
+
@property
|
849 |
+
def do_classifier_free_guidance(self):
|
850 |
+
return self._guidance_scale > 1
|
851 |
+
|
852 |
+
@property
|
853 |
+
def joint_attention_kwargs(self):
|
854 |
+
return self._joint_attention_kwargs
|
855 |
+
|
856 |
+
@property
|
857 |
+
def num_timesteps(self):
|
858 |
+
return self._num_timesteps
|
859 |
+
|
860 |
+
@property
|
861 |
+
def interrupt(self):
|
862 |
+
return self._interrupt
|
863 |
+
|
864 |
+
|
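# Illustrative setup call for the adapter loader defined below (paths are the ones
# shipped in this repo; adjust as needed, nb_token is the number of resampled image
# tokens, 64 for this adapter):
#
#     pipe.init_ipadapter(
#         ip_adapter_path="./ip-adapter.bin",
#         image_encoder_path="google/siglip-so400m-patch14-384",
#         nb_token=64,
#     )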
865 |
+
@torch.inference_mode()
|
866 |
+
def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432):
|
867 |
+
from transformers import SiglipVisionModel, SiglipImageProcessor
|
868 |
+
state_dict = torch.load(ip_adapter_path, map_location="cpu")
|
869 |
+
|
870 |
+
device, dtype = self.transformer.device, self.transformer.dtype
|
871 |
+
image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path)
|
872 |
+
image_processor = SiglipImageProcessor.from_pretrained(image_encoder_path)
|
873 |
+
image_encoder.eval()
|
874 |
+
image_encoder.to(device, dtype=dtype)
|
875 |
+
self.image_encoder = image_encoder
|
876 |
+
self.clip_image_processor = image_processor
|
877 |
+
|
878 |
+
sample_class = TimeResampler
|
879 |
+
image_proj_model = sample_class(
|
880 |
+
dim=1280,
|
881 |
+
depth=4,
|
882 |
+
dim_head=64,
|
883 |
+
heads=20,
|
884 |
+
num_queries=nb_token,
|
885 |
+
embedding_dim=1152,
|
886 |
+
output_dim=output_dim,
|
887 |
+
ff_mult=4,
|
888 |
+
timestep_in_dim=320,
|
889 |
+
timestep_flip_sin_to_cos=True,
|
890 |
+
timestep_freq_shift=0,
|
891 |
+
)
|
892 |
+
image_proj_model.eval()
|
893 |
+
image_proj_model.to(device, dtype=dtype)
|
894 |
+
key_name = image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
|
895 |
+
print(f"=> loading image_proj_model: {key_name}")
|
896 |
+
|
897 |
+
self.image_proj_model = image_proj_model
|
898 |
+
|
899 |
+
|
900 |
+
attn_procs = {}
|
901 |
+
transformer = self.transformer
|
902 |
+
for idx_name, name in enumerate(transformer.attn_processors.keys()):
|
903 |
+
hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads
|
904 |
+
ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads
|
905 |
+
ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim
|
906 |
+
|
907 |
+
attn_procs[name] = JointIPAttnProcessor(
|
908 |
+
hidden_size=hidden_size,
|
909 |
+
cross_attention_dim=transformer.config.caption_projection_dim,
|
910 |
+
ip_hidden_states_dim=ip_hidden_states_dim,
|
911 |
+
ip_encoder_hidden_states_dim=ip_encoder_hidden_states_dim,
|
912 |
+
head_dim=transformer.config.attention_head_dim,
|
913 |
+
timesteps_emb_dim=1280,
|
914 |
+
).to(device, dtype=dtype)
|
915 |
+
|
916 |
+
self.transformer.set_attn_processor(attn_procs)
|
917 |
+
tmp_ip_layers = torch.nn.ModuleList(self.transformer.attn_processors.values())
|
918 |
+
|
919 |
+
key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
|
920 |
+
print(f"=> loading ip_adapter: {key_name}")
|
921 |
+
|
922 |
+
|
923 |
+
@torch.inference_mode()
|
924 |
+
def encode_clip_image_emb(self, clip_image, device, dtype):
|
925 |
+
|
926 |
+
# clip
|
927 |
+
clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
|
928 |
+
clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
|
929 |
+
clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
|
930 |
+
clip_image_embeds = torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0)
|
931 |
+
|
932 |
+
return clip_image_embeds
|
933 |
+
|
934 |
+
|
935 |
+
|
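# The SigLIP penultimate hidden states (`hidden_states[-2]`) serve as the image prompt;
# an all-zeros tensor is prepended as the unconditional image embedding, so the returned
# batch is ordered [uncond, cond] and lines up with the CFG-duplicated text embeddings
# built in `__call__`.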
936 |
+
@torch.no_grad()
|
937 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
938 |
+
def __call__(
|
939 |
+
self,
|
940 |
+
prompt: Union[str, List[str]] = None,
|
941 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
942 |
+
prompt_3: Optional[Union[str, List[str]]] = None,
|
943 |
+
height: Optional[int] = None,
|
944 |
+
width: Optional[int] = None,
|
945 |
+
num_inference_steps: int = 28,
|
946 |
+
timesteps: List[int] = None,
|
947 |
+
guidance_scale: float = 7.0,
|
948 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
949 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
950 |
+
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
951 |
+
num_images_per_prompt: Optional[int] = 1,
|
952 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
953 |
+
latents: Optional[torch.FloatTensor] = None,
|
954 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
955 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
956 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
957 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
958 |
+
output_type: Optional[str] = "pil",
|
959 |
+
return_dict: bool = True,
|
960 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
961 |
+
clip_skip: Optional[int] = None,
|
962 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
963 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
964 |
+
max_sequence_length: int = 256,
|
965 |
+
|
966 |
+
# ipa
|
967 |
+
clip_image=None,
|
968 |
+
ipadapter_scale=1.0,
|
969 |
+
):
|
970 |
+
r"""
|
971 |
+
Function invoked when calling the pipeline for generation.
|
972 |
+
|
973 |
+
Args:
|
974 |
+
prompt (`str` or `List[str]`, *optional*):
|
975 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
976 |
+
instead.
|
977 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
978 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
979 |
+
used instead
|
980 |
+
prompt_3 (`str` or `List[str]`, *optional*):
|
981 |
+
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
982 |
+
used instead
|
983 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
984 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
985 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
986 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
987 |
+
num_inference_steps (`int`, *optional*, defaults to 28):
|
988 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
989 |
+
expense of slower inference.
|
990 |
+
timesteps (`List[int]`, *optional*):
|
991 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
992 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
993 |
+
passed will be used. Must be in descending order.
|
994 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
995 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
996 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
997 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
998 |
+
1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
|
999 |
+
usually at the expense of lower image quality.
|
1000 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
1001 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
1002 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
1003 |
+
less than `1`).
|
1004 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
1005 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
1006 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
1007 |
+
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
1008 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
1009 |
+
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
1010 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
1011 |
+
The number of images to generate per prompt.
|
1012 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
1013 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
1014 |
+
to make generation deterministic.
|
1015 |
+
latents (`torch.FloatTensor`, *optional*):
|
1016 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
1017 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
1018 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
1019 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
1020 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
1021 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
1022 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1023 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1024 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
1025 |
+
argument.
|
1026 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1027 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
1028 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
1029 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
1030 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
1031 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
1032 |
+
input argument.
|
1033 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
1034 |
+
The output format of the generated image. Choose between
|
1035 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
1036 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
1037 |
+
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
|
1038 |
+
of a plain tuple.
|
1039 |
+
joint_attention_kwargs (`dict`, *optional*):
|
1040 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
1041 |
+
`self.processor` in
|
1042 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
1043 |
+
callback_on_step_end (`Callable`, *optional*):
|
1044 |
+
A function that is called at the end of each denoising step during inference. The function is called
|
1045 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
1046 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
1047 |
+
`callback_on_step_end_tensor_inputs`.
|
1048 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
1049 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
1050 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
1051 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
1052 |
+
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
1053 |
+
|
1054 |
+
Examples:
|
1055 |
+
|
1056 |
+
Returns:
|
1057 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
1058 |
+
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
1059 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
1060 |
+
"""
|
1061 |
+
|
1062 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
1063 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
1064 |
+
|
1065 |
+
# 1. Check inputs. Raise error if not correct
|
1066 |
+
self.check_inputs(
|
1067 |
+
prompt,
|
1068 |
+
prompt_2,
|
1069 |
+
prompt_3,
|
1070 |
+
height,
|
1071 |
+
width,
|
1072 |
+
negative_prompt=negative_prompt,
|
1073 |
+
negative_prompt_2=negative_prompt_2,
|
1074 |
+
negative_prompt_3=negative_prompt_3,
|
1075 |
+
prompt_embeds=prompt_embeds,
|
1076 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1077 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1078 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1079 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
1080 |
+
max_sequence_length=max_sequence_length,
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
self._guidance_scale = guidance_scale
|
1084 |
+
self._clip_skip = clip_skip
|
1085 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
1086 |
+
self._interrupt = False
|
1087 |
+
|
1088 |
+
# 2. Define call parameters
|
1089 |
+
if prompt is not None and isinstance(prompt, str):
|
1090 |
+
batch_size = 1
|
1091 |
+
elif prompt is not None and isinstance(prompt, list):
|
1092 |
+
batch_size = len(prompt)
|
1093 |
+
else:
|
1094 |
+
batch_size = prompt_embeds.shape[0]
|
1095 |
+
|
1096 |
+
device = self._execution_device
|
1097 |
+
dtype = self.transformer.dtype
|
1098 |
+
|
1099 |
+
lora_scale = (
|
1100 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
1101 |
+
)
|
1102 |
+
(
|
1103 |
+
prompt_embeds,
|
1104 |
+
negative_prompt_embeds,
|
1105 |
+
pooled_prompt_embeds,
|
1106 |
+
negative_pooled_prompt_embeds,
|
1107 |
+
) = self.encode_prompt(
|
1108 |
+
prompt=prompt,
|
1109 |
+
prompt_2=prompt_2,
|
1110 |
+
prompt_3=prompt_3,
|
1111 |
+
negative_prompt=negative_prompt,
|
1112 |
+
negative_prompt_2=negative_prompt_2,
|
1113 |
+
negative_prompt_3=negative_prompt_3,
|
1114 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
1115 |
+
prompt_embeds=prompt_embeds,
|
1116 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
1117 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
1118 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
1119 |
+
device=device,
|
1120 |
+
clip_skip=self.clip_skip,
|
1121 |
+
num_images_per_prompt=num_images_per_prompt,
|
1122 |
+
max_sequence_length=max_sequence_length,
|
1123 |
+
lora_scale=lora_scale,
|
1124 |
+
)
|
1125 |
+
|
1126 |
+
if self.do_classifier_free_guidance:
|
1127 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
1128 |
+
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
1129 |
+
|
1130 |
+
# 3. prepare clip emb
|
1131 |
+
clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
|
1132 |
+
clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype)
|
1133 |
+
|
1134 |
+
# 4. Prepare timesteps
|
1135 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
1136 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
1137 |
+
self._num_timesteps = len(timesteps)
|
1138 |
+
|
1139 |
+
# 5. Prepare latent variables
|
1140 |
+
num_channels_latents = self.transformer.config.in_channels
|
1141 |
+
latents = self.prepare_latents(
|
1142 |
+
batch_size * num_images_per_prompt,
|
1143 |
+
num_channels_latents,
|
1144 |
+
height,
|
1145 |
+
width,
|
1146 |
+
prompt_embeds.dtype,
|
1147 |
+
device,
|
1148 |
+
generator,
|
1149 |
+
latents,
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
# 6. Denoising loop
|
1153 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
1154 |
+
for i, t in enumerate(timesteps):
|
1155 |
+
if self.interrupt:
|
1156 |
+
continue
|
1157 |
+
|
1158 |
+
# expand the latents if we are doing classifier free guidance
|
1159 |
+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
1160 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
1161 |
+
timestep = t.expand(latent_model_input.shape[0])
|
1162 |
+
|
1163 |
+
image_prompt_embeds, timestep_emb = self.image_proj_model(
|
1164 |
+
clip_image_embeds,
|
1165 |
+
timestep.to(dtype=latents.dtype),
|
1166 |
+
need_temb=True
|
1167 |
+
)
|
1168 |
+
|
1169 |
+
joint_attention_kwargs = dict(
|
1170 |
+
emb_dict=dict(
|
1171 |
+
ip_hidden_states=image_prompt_embeds,
|
1172 |
+
temb=timestep_emb,
|
1173 |
+
scale=ipadapter_scale,
|
1174 |
+
)
|
1175 |
+
)
|
1176 |
+
|
1177 |
+
noise_pred = self.transformer(
|
1178 |
+
hidden_states=latent_model_input,
|
1179 |
+
timestep=timestep,
|
1180 |
+
encoder_hidden_states=prompt_embeds,
|
1181 |
+
pooled_projections=pooled_prompt_embeds,
|
1182 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
1183 |
+
return_dict=False,
|
1184 |
+
)[0]
|
1185 |
+
|
1186 |
+
# perform guidance
|
1187 |
+
if self.do_classifier_free_guidance:
|
1188 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1189 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1190 |
+
|
1191 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1192 |
+
latents_dtype = latents.dtype
|
1193 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
1194 |
+
|
1195 |
+
if latents.dtype != latents_dtype:
|
1196 |
+
if torch.backends.mps.is_available():
|
1197 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
1198 |
+
latents = latents.to(latents_dtype)
|
1199 |
+
|
1200 |
+
if callback_on_step_end is not None:
|
1201 |
+
callback_kwargs = {}
|
1202 |
+
for k in callback_on_step_end_tensor_inputs:
|
1203 |
+
callback_kwargs[k] = locals()[k]
|
1204 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1205 |
+
|
1206 |
+
latents = callback_outputs.pop("latents", latents)
|
1207 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1208 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1209 |
+
negative_pooled_prompt_embeds = callback_outputs.pop(
|
1210 |
+
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
1211 |
+
)
|
1212 |
+
|
1213 |
+
# call the callback, if provided
|
1214 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1215 |
+
progress_bar.update()
|
1216 |
+
|
1217 |
+
if XLA_AVAILABLE:
|
1218 |
+
xm.mark_step()
|
1219 |
+
|
1220 |
+
if output_type == "latent":
|
1221 |
+
image = latents
|
1222 |
+
|
1223 |
+
else:
|
1224 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
1225 |
+
|
1226 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
1227 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
1228 |
+
|
1229 |
+
# Offload all models
|
1230 |
+
self.maybe_free_model_hooks()
|
1231 |
+
|
1232 |
+
if not return_dict:
|
1233 |
+
return (image,)
|
1234 |
+
|
1235 |
+
return StableDiffusion3PipelineOutput(images=image)
|
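# Minimal end-to-end sketch of the IP-Adapter call path added by this file, assuming
# `pipe` has already been created and `pipe.init_ipadapter(...)` has been run
# (argument values are illustrative, not prescriptive):
#
#     from PIL import Image
#     ref = Image.open("./assets/1.jpg").convert("RGB")
#     image = pipe(
#         prompt="a cat holding a sign that says hello world",
#         clip_image=ref,
#         ipadapter_scale=0.5,
#         num_inference_steps=28,
#         guidance_scale=7.0,
#     ).images[0]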
teasers/0.png
ADDED
(binary image stored with Git LFS)
teasers/1.png
ADDED
(binary image stored with Git LFS)