AngelBottomless committed
Commit 0a4fc35 · verified · 1 Parent(s): 2da64d3

Upload 18 files

LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
cache.py ADDED
@@ -0,0 +1,199 @@
+ import argparse
+ import os
+ import hashlib
+ import functools
+ import json
+ import yaml
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from diffusers import AutoencoderKL
+ from torchvision import transforms
+ from tqdm import tqdm
+
+ from imgproc import (
+     generate_crop_size_list,
+     to_rgb_if_rgba,
+     var_center_crop,
+ )
+ from data import read_general
+
+ # ---- Flux VAE scaling parameters ----
+ VAE_SCALE = 0.3611
+ VAE_SHIFT = 0.1159
+
+ def handle_image(image: Image.Image) -> Image.Image:
+     """
+     Ensure the image is in RGB format, converting from RGBA, L, P, etc.
+     Raise ValueError if the mode is unrecognized.
+     """
+     mode = image.mode.upper()
+     if mode == "RGB":
+         return image
+     elif mode == "RGBA":
+         return to_rgb_if_rgba(image)
+     elif mode in ("L", "P"):
+         return image.convert("RGB")
+     else:
+         raise ValueError(f"Unsupported image mode: {mode}")
+
+ def encode(vae: AutoencoderKL, img_tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
+     """
+     Encode a normalized image tensor to latents using the Flux VAE, applying SHIFT+SCALE.
+     img_tensor shape: (C, H, W) or (1, C, H, W). We'll reshape to (1, C, H, W) if needed.
+     """
+     if img_tensor.dim() == 3:
+         img_tensor = img_tensor.unsqueeze(0)  # (1, C, H, W)
+     img_tensor = img_tensor.to(device, non_blocking=True)
+     with torch.no_grad():
+         # The VAE weights are float16, so the input tensor is cast to match before encode.
+         latent_dist = vae.encode(img_tensor).latent_dist
+         # Use .mode()[0] or .sample() depending on whether you prefer the mode or a random sample.
+         latents = latent_dist.mode()[0]
+         latents = (latents - VAE_SHIFT) * VAE_SCALE
+     return latents.float()
+
+ def load_image_paths_from_yaml(yaml_path: str) -> list:
+     """
+     Parse a YAML containing a 'META' key with paths to .jsonl files.
+     For each .jsonl (with 'type' == 'image_text'), read lines of JSON
+     where we expect an 'image_path' field. Collect these paths in a list.
+     """
+     with open(yaml_path, "r", encoding="utf-8") as f:
+         data = yaml.safe_load(f)
+
+     image_files = []
+     meta_list = data.get("META", [])
+     for meta_item in meta_list:
+         # Example: path=/data0/DanbooruWebp/booru1116Webp.jsonl
+         #          type=image_text
+         ftype = meta_item.get("type", "")
+         fpath = meta_item.get("path", "")
+         if ftype != "image_text":
+             # skip unknown types
+             continue
+         if not os.path.isfile(fpath):
+             print(f"[Warning] JSONL file not found: {fpath}")
+             continue
+
+         # Open the .jsonl and parse its lines
+         with open(fpath, "r", encoding="utf-8") as fin:
+             for line in fin:
+                 line = line.strip()
+                 if not line:
+                     continue
+                 try:
+                     obj = json.loads(line)
+                     if "image_path" in obj:
+                         # This is the actual disk path for the image
+                         image_files.append(obj["image_path"])
+                 except Exception as e:
+                     print(f"[Warning] JSON parse error in {fpath}: {e}")
+                     continue
+
+     return image_files
+
+ def main():
+     parser = argparse.ArgumentParser(description="Cache image latents using Flux VAE")
+     parser.add_argument("--data_yaml", type=str, required=True,
+                         help="Path to dataset YAML config (with META -> .jsonl paths)")
+     parser.add_argument("--resolution", type=int, required=True,
+                         help="Target resolution (e.g., 256, 512, 1024) for center-crop/resize")
+     parser.add_argument("--total_split", type=int, default=1,
+                         help="Total number of parallel splits/workers")
+     parser.add_argument("--current_worker_index", type=int, default=0,
+                         help="Index of this worker (0-based)")
+     parser.add_argument("--patch_size", type=int, default=8,
+                         help="Patch size used for generating potential crop sizes")
+     parser.add_argument("--random_top_k", type=int, default=1,
+                         help="Number of top crop options from var_center_crop to randomly pick")
+     args = parser.parse_args()
+
+     # ------------------------------------------------------------------
+     # 1) Set up the VAE model for encoding
+     # ------------------------------------------------------------------
+     vae = AutoencoderKL.from_pretrained(
+         "black-forest-labs/FLUX.1-dev",
+         subfolder="vae",
+         torch_dtype=torch.float16
+     ).eval()
+
+     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     vae.to(device)
+
+     # ------------------------------------------------------------------
+     # 2) Prepare the transform (crop -> tensor -> normalize).
+     #    This must match how images are processed before training.
+     # ------------------------------------------------------------------
+     max_num_patches = round((args.resolution / (args.patch_size * 1.0)) ** 2)
+     crop_size_list = generate_crop_size_list(max_num_patches, args.patch_size)
+     image_transform = transforms.Compose([
+         transforms.Lambda(functools.partial(var_center_crop,
+                                             crop_size_list=crop_size_list,
+                                             random_top_k=args.random_top_k)),
+         transforms.ToTensor(),
+         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+     ])
+
+     # ------------------------------------------------------------------
+     # 3) Load image paths from the YAML / JSONL references
+     # ------------------------------------------------------------------
+     image_files = load_image_paths_from_yaml(args.data_yaml)
+     if not image_files:
+         print("[INFO] No image files found. Check your YAML & JSONL contents.")
+         return
+
+     # ------------------------------------------------------------------
+     # 4) Process each image => transform => encode => save .npz
+     # ------------------------------------------------------------------
+     worker_idx = args.current_worker_index
+     total_split = args.total_split
+     res = args.resolution
+
+     for image_path in tqdm(image_files, desc=f"Worker {worker_idx}"):
+         # 4.a) Determine if this file belongs to the current worker
+         hash_val = int(hashlib.sha1(image_path.encode("utf-8")).hexdigest(), 16)
+         if hash_val % total_split != worker_idx:
+             continue
+
+         # 4.b) Construct the cache path
+         base, _ = os.path.splitext(image_path)
+         out_path = f"{base}_{res}.npz"
+         if os.path.exists(out_path):
+             continue
+
+         # 4.c) Read the image from disk & handle its mode
+         try:
+             pil_image = Image.open(read_general(image_path))
+             pil_image = handle_image(pil_image)  # ensure RGB
+         except Exception as e:
+             print(f"[Warning] Could not open image {image_path}: {e}")
+             continue
+
+         # Optionally, you can do a simple resize (if your training expects it).
+         # Otherwise, rely solely on var_center_crop to pick a final crop size.
+         pil_image = pil_image.resize((res, res), Image.Resampling.LANCZOS)
+
+         # 4.d) Apply var_center_crop -> ToTensor -> Normalize
+         try:
+             transformed_tensor = image_transform(pil_image)  # shape=(3, H, W)
+         except Exception as e:
+             print(f"[Warning] Skipping {image_path} due to transform error: {e}")
+             continue
+         transformed_tensor = transformed_tensor.to(torch.float16)
+         # 4.e) Encode with the Flux VAE (shift+scale) => latent
+         latents = encode(vae, transformed_tensor, device=device)
+         latents_np = latents.cpu().numpy()  # shape=(C, H//8, W//8) typically
+
+         # 4.f) Save latents to .npz
+         try:
+             np.savez_compressed(out_path, latent=latents_np)
+         except Exception as e:
+             print(f"[Error] Saving .npz for {image_path} failed: {e}")
+
+ if __name__ == "__main__":
+     main()
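
For reference, a cached latent can be sanity-checked by inverting the shift/scale that cache.py applies and decoding it back to pixels. This is a minimal sketch, not part of the commit; the file name image_256.npz is a hypothetical output of the script above, and it assumes the same FLUX.1-dev VAE is available.

import numpy as np
import torch
from diffusers import AutoencoderKL

VAE_SCALE = 0.3611
VAE_SHIFT = 0.1159

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.float16
).eval()

# Load a latent cached by cache.py (hypothetical path).
latent = np.load("image_256.npz")["latent"]
z = torch.from_numpy(latent).unsqueeze(0).to(torch.float16)

# Invert the caching transform: cache.py stored (z_raw - VAE_SHIFT) * VAE_SCALE.
z_raw = z / VAE_SCALE + VAE_SHIFT

with torch.no_grad():
    decoded = vae.decode(z_raw).sample  # pixel values roughly in [-1, 1]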
data/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .data_reader import *
+ from .dataset import *
data/data_reader.py ADDED
@@ -0,0 +1,758 @@
+ import os
+ import json
+ import time
+ import logging
+ from io import BytesIO
+ from typing import Union, Optional, Tuple, Dict, Any, Protocol, List
+
+ import requests
+ from PIL import Image
+
+ # Disable Pillow’s large image pixel limit.
+ Image.MAX_IMAGE_PIXELS = None
+
+ #####################################################
+ # Configure Logging with Level Argument
+ #####################################################
+ logger = logging.getLogger(__name__)
+
+
+ def configure_logging(level: Union[str, int] = logging.INFO):
+     """
+     Configures the root logger (and thus 'logger') to a specific logging level.
+
+     :param level: Either a string like 'DEBUG'/'INFO'/'WARNING'
+                   or an integer like logging.DEBUG/logging.INFO/etc.
+     """
+     if isinstance(level, str):
+         level = getattr(logging, level.upper(), logging.INFO)
+
+     logging.basicConfig(
+         level=level,
+         format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
+     )
+
+
+ # Global Ceph/petrel client
+ client = None  # type: ignore
+
+ # Cache for JSON data loaded from a repo
+ loaded_jsons: Dict[str, Any] = {}
+
+ #####################################################
+ # Helpers for Hugging Face Token & HTTP Session
+ #####################################################
+
+
+ def _get_hf_access_token() -> Optional[str]:
+     """
+     Retrieves the Hugging Face access token from the environment or from 'env.json'.
+     Returns None if not found.
+     """
+     hf_access_token = os.environ.get("HF_ACCESS_TOKEN")
+     if not hf_access_token and os.path.isfile("env.json"):
+         with open("env.json", "r", encoding="utf-8") as f:
+             env_data = json.load(f)
+         hf_access_token = env_data.get("HF_ACCESS_TOKEN")
+
+     if not hf_access_token:
+         return None
+
+     return hf_access_token
+
+
+ def get_hf_session() -> requests.Session:
+     """
+     Creates and returns a requests.Session object with the Hugging Face token in the headers.
+     """
+     token = _get_hf_access_token()
+     session = requests.Session()
+     if token:
+         session.headers.update({"Authorization": f"Bearer {token}"})
+     return session
+
+
+ #####################################################
+ # Ceph/Petrel Client Initialization
+ #####################################################
+
+
+ def init_ceph_client_if_needed():
+     """
+     Initializes the global Ceph/petrel `client` if it has not yet been set.
+     """
+     global client
+     if client is None:
+         logger.info("Initializing Ceph/petrel client...")
+         start_time = time.time()
+         from petrel_client.client import Client  # noqa
+
+         client = Client("./petreloss.conf")
+         end_time = time.time()
+         logger.info(
+             f"Initialized Ceph/petrel client in {end_time - start_time:.2f} seconds"
+         )
+
+
+ #####################################################
+ # Reading & Caching JSON
+ #####################################################
+
+
+ def read_json_from_repo(
+     session: requests.Session, repo_addr: str, file_name: str, cache_dir: str
+ ) -> Optional[Dict[str, Any]]:
+     """
+     Reads JSON from a given repository address and file name, with caching:
+       1. If cached in memory (loaded_jsons), returns it.
+       2. Otherwise, checks the local disk cache (cache_dir).
+       3. If not found on disk, downloads and saves it locally.
+
+     :param session: requests.Session
+     :param repo_addr: URL base (e.g. "https://github.com/user/repo/tree/main")
+     :param file_name: Name of the JSON file
+     :param cache_dir: Local directory to store the cache
+     :return: Parsed JSON object or None
+     """
+     unique_key = f"{repo_addr}/{file_name}"
+     if unique_key in loaded_jsons:
+         logger.debug(f"Found in-memory cache for {unique_key}")
+         return loaded_jsons[unique_key]
+
+     # Check the local disk cache
+     cache_file = os.path.join(cache_dir, file_name)
+     if os.path.exists(cache_file):
+         logger.debug(f"Reading from local cache: {cache_file}")
+         with open(cache_file, "r", encoding="utf-8") as f:
+             result = json.load(f)
+         loaded_jsons[unique_key] = result
+         return result
+     else:
+         # Download and cache
+         url = f"{repo_addr}/{file_name}"
+         logger.debug(f"Downloading JSON from {url}")
+         response = session.get(url)
+         try:
+             response.raise_for_status()
+         except requests.HTTPError:
+             if response.status_code == 404:
+                 loaded_jsons[unique_key] = None
+                 return None
+             raise
+         data = response.json()
+         os.makedirs(cache_dir, exist_ok=True)
+         with open(cache_file, "w", encoding="utf-8") as f:
+             json.dump(data, f, indent=4)
+         loaded_jsons[unique_key] = data
+         return data
+
+
+ def load_json_index(
+     session: requests.Session,
+     json_url: str,
+     cache_path: Optional[str] = None,
+ ) -> Optional[Dict[str, Any]]:
+     """
+     Download (if needed) and cache a JSON file from `json_url`.
+     If `cache_path` is provided, data is saved/loaded from that path.
+
+     :param session: requests.Session
+     :param json_url: Direct URL to the JSON file
+     :param cache_path: Local path for caching the JSON
+     :return: Parsed JSON (dict) or None if 404
+     """
+     if cache_path is not None and os.path.isfile(cache_path):
+         logger.debug(f"Found cached JSON at {cache_path}")
+         with open(cache_path, "r", encoding="utf-8") as f:
+             return json.load(f)
+
+     logger.debug(f"Requesting JSON index from {json_url}")
+     resp = session.get(json_url)
+     if resp.status_code == 404:
+         logger.warning(f"JSON index not found (404): {json_url}")
+         return None
+     resp.raise_for_status()
+
+     data = resp.json()
+     if cache_path is not None:
+         os.makedirs(os.path.dirname(cache_path), exist_ok=True)
+         with open(cache_path, "w", encoding="utf-8") as f:
+             json.dump(data, f)
+         logger.debug(f"Saved JSON index to {cache_path}")
+     return data
+
+
+ #####################################################
+ # Downloading Byte Ranges
+ #####################################################
+
+
+ def download_range(session: requests.Session, url: str, start: int, end: int) -> bytes:
+     """
+     Downloads the inclusive byte range [start, end] from the specified URL via
+     an HTTP Range request and returns the raw bytes.
+
+     :param session: A requests.Session with appropriate headers
+     :param url: The file URL to download
+     :param start: Start byte (inclusive)
+     :param end: End byte (inclusive)
+     :return: Raw bytes of the specified range
+     """
+     headers = {"Range": f"bytes={start}-{end}"}
+     logger.debug(f"Downloading range {start}-{end} from {url}")
+     response = session.get(url, headers=headers, stream=True)
+     response.raise_for_status()
+     return response.content
+
+
+ #####################################################
+ # Repository Protocol and Implementations
+ #####################################################
+
+
+ class BaseRepository(Protocol):
+     """
+     A Protocol that each repository must implement. Must have a method:
+       find_image(session, image_id) -> (tar_url, start_offset, end_offset, filename) or None
+     """
+
+     def find_image(
+         self, session: requests.Session, image_id: Union[int, str]
+     ) -> Optional[Tuple[str, int, int, str]]: ...
+
+
+ def primary_subfolder_from_id(x: int) -> str:
+     """
+     Given an integer image ID, return a subfolder name based on the ID mod 1000.
+     E.g., 7502245 -> '0245'.
+     """
+     if not isinstance(x, int):
+         raise ValueError(f"Primary subfolder requires an integer ID, given: {x}")
+     val = x % 1000
+     return f"{val:04d}"
+
+
+ def secondary_chunk_from_id(x: int, chunk_size: int = 1000) -> int:
+     """
+     Compute the chunk index for a 'secondary' dataset given an image ID.
+     """
+     return x % chunk_size
+
+
+ class PrimaryRepository(BaseRepository):
+     """
+     Example of a 'primary' dataset repository:
+       - .tar files named "NNNN.tar" where NNNN = image_id % 1000
+       - Each .tar file has a companion JSON index "NNNN.json"
+       - The JSON maps "7501000.jpg" -> [start_offset, end_offset]
+     """
+
+     def __init__(self, base_url: str, cache_dir: str, entry: Optional[str] = None):
+         self.base_url = base_url
+         self.cache_dir = cache_dir
+         self.entry = entry
+         os.makedirs(self.cache_dir, exist_ok=True)
+
+     def _build_primary_id_map(self, json_index: Dict[str, Any]) -> Dict[int, str]:
+         """
+         From a JSON index like { "7501000.jpg": [start, end], ... },
+         create a map of integer ID -> filename (e.g. 7501000 -> "7501000.jpg").
+         """
+         out = {}
+         for filename in json_index.keys():
+             root, _ = os.path.splitext(filename)
+             try:
+                 num = int(root)
+                 out[num] = filename
+             except ValueError:
+                 continue
+         return out
+
+     def find_image(
+         self, session: requests.Session, image_id: Union[int, str]
+     ) -> Optional[Tuple[str, int, int, str]]:
+         if isinstance(image_id, str):
+             try:
+                 image_id = int(image_id)
+             except ValueError:
+                 logger.error(f"Invalid image ID: {image_id}")
+                 return None
+         folder = primary_subfolder_from_id(image_id)
+         json_name = f"{folder}.json"
+         json_url = f"{self.base_url}/{json_name}"
+         cache_path = os.path.join(self.cache_dir, json_name)
+
+         logger.debug(f"Looking for image {image_id} in {json_name} (folder: {folder})")
+         json_index = load_json_index(session, json_url, cache_path)
+         if not json_index:
+             logger.debug(f"No JSON index found for folder {folder}")
+             return None
+
+         # Build a map of integer_id -> filename
+         id_map = self._build_primary_id_map(json_index)
+         filename = id_map.get(image_id)
+         if not filename:
+             logger.debug(f"Image ID {image_id} not found in index for folder {folder}")
+             return None
+
+         start_offset, end_offset = json_index[filename]
+         tar_url = f"{self.base_url}/{folder}.tar"
+         logger.debug(
+             f"Found image {image_id} in {folder}.tar ({start_offset}-{end_offset})"
+         )
+         return tar_url, start_offset, end_offset, filename
+
+
+ class SecondaryRepository(BaseRepository):
+     """
+     Example for a 'secondary' dataset that:
+       - Uses chunk-based storage (each chunk is named data-XXXX.tar)
+       - For each chunk, there's a corresponding data-XXXX.json with a "files" mapping
+     """
+
+     def __init__(
+         self,
+         tar_base_url: str,
+         json_base_url: str,
+         cache_dir: str,
+         chunk_size: int = 1000,
+         entry: Optional[str] = None,
+     ):
+         self.tar_base_url = tar_base_url
+         self.json_base_url = json_base_url
+         self.cache_dir = cache_dir
+         self.chunk_size = chunk_size
+         self.entry = entry
+         os.makedirs(self.cache_dir, exist_ok=True)
+
+     def find_image(
+         self, session: requests.Session, image_id: Union[int, str]
+     ) -> Optional[Tuple[str, int, int, str]]:
+         if isinstance(image_id, str):
+             try:
+                 image_id = int(image_id)
+             except ValueError:
+                 logger.error(f"Invalid image ID: {image_id}")
+                 return None
+         chunk_index = secondary_chunk_from_id(image_id, self.chunk_size)
+         data_name = f"data-{chunk_index:04d}"
+
+         json_url = f"{self.json_base_url}/{data_name}.json"
+         cache_path = os.path.join(self.cache_dir, f"{data_name}.json")
+
+         logger.debug(f"Looking for image {image_id} in chunk {data_name}")
+         data = load_json_index(session, json_url, cache_path)
+         if not data or "files" not in data:
+             logger.debug(f"No file mapping found in {data_name}.json")
+             return None
+
+         filename_key = f"{image_id}.webp"
+         file_dict = data["files"].get(filename_key)
+         if not file_dict:
+             logger.debug(f"Image ID {image_id} not found in chunk {data_name}")
+             return None
+
+         offset = file_dict["offset"]
+         size = file_dict["size"]
+         start_offset = offset
+         end_offset = offset + size - 1  # inclusive
+
+         tar_url = f"{self.tar_base_url}/{data_name}.tar"
+         logger.info(
+             f"Found image {image_id} in {data_name}.tar ({start_offset}-{end_offset})"
+         )
+         return tar_url, start_offset, end_offset, filename_key
+
+
+ class CustomRepository(BaseRepository):
+     """
+     Repository that relies on a single 'all_indices.json' plus a structure:
+       key -> "tar_path#file_name"
+     and then a nested mapping of tar_path -> file_name -> [start_offset, end_offset].
+     """
+
+     def __init__(self, base_url: str, cache_dir: str, entry: Optional[str] = None):
+         self.base_url = base_url
+         self.cache_dir = cache_dir
+         self.entry = entry
+         os.makedirs(self.cache_dir, exist_ok=True)
+
+     def get_range_for_key(
+         self, session: requests.Session, key: Union[int, str]
+     ) -> Optional[Tuple[str, int, int, str]]:
+         # internal_map.json: { key: "tar_path#file_name" };
+         # all_indices.json: { tar_path: { file_name: [start, end] } }
+         key = str(key)
+         key_index = read_json_from_repo(
+             session, self.base_url, "internal_map.json", self.cache_dir
+         )
+         if key_index is None:
+             logger.debug(f"No internal_map.json found in custom repo: {self.base_url}")
+             return None
+         real_key = key_index.get(key)
+         if not real_key:
+             logger.debug(f"Key {key} not found in custom repo index")
+             return None
+         repo_index = read_json_from_repo(
+             session, self.base_url, "all_indices.json", self.cache_dir
+         )
+         if repo_index is None:
+             logger.debug(f"No all_indices.json found in custom repo: {self.base_url}")
+             return None
+         tar_path, file_name = real_key.split("#", 1)
+         if tar_path not in repo_index:
+             logger.debug(f"Key {real_key} not found in custom repo index")
+             return None
+         tar_info = repo_index.get(tar_path, {}).get(file_name, None)
+         if not tar_info or len(tar_info) < 2:
+             return None
+
+         start, end = tar_info
+         tar_url = f"{self.base_url}/{tar_path}"
+         logger.info(
+             f"Found key '{key}' in custom repository {tar_path} ({start}-{end})"
+         )
+         return tar_url, start, end, file_name
+
+     def find_image(
+         self, session: requests.Session, image_id: str
+     ) -> Optional[Tuple[str, int, int, str]]:
+         return self.get_range_for_key(session, image_id)
+
+
+ #####################################################
+ # Repository Configuration
+ #####################################################
+
+ class RepositoryConfig:
+     """
+     Manages loading/storing repository configurations from a JSON file,
+     and instantiates the corresponding repository objects, including custom 'entry' prefixes.
+     """
+
+     def __init__(self, config_path: str):
+         """
+         :param config_path: Path to the JSON configuration file.
+         """
+         self.config_path = config_path
+         # Lists to hold instantiated repository objects
+         self.repositories: List[BaseRepository] = []
+         self.custom_repositories: List[CustomRepository] = []
+
+         # Map from entry string -> list of repositories that handle that entry
+         self.entry_map: Dict[str, List[BaseRepository]] = {}
+
+     def load(self):
+         """
+         Reads the config file from disk and populates repositories and entry_map.
+         """
+         if not os.path.isfile(self.config_path):
+             raise FileNotFoundError(f"Config file not found: {self.config_path}")
+
+         logger.debug(f"Loading repository configuration from {self.config_path}")
+         print(f"Loading repository configuration from {self.config_path}")
+         with open(self.config_path, "r", encoding="utf-8") as f:
+             data = json.load(f)
+
+         self.from_dict(data)
+
+     def from_dict(self, data: Dict[str, Any]):
+         """
+         Populates repositories/customs from a dictionary, building self.entry_map as well.
+
+         :param data: A dict corresponding to the structure of `repository.json`.
+         """
+         # Clear existing repos
+         self.repositories.clear()
+         self.custom_repositories.clear()
+         self.entry_map.clear()
+
+         # Load standard repositories
+         repos_config = data.get("repositories", [])
+         for repo_dict in repos_config:
+             repo_obj = self._create_repository(repo_dict)
+             if repo_obj is not None:
+                 self.repositories.append(repo_obj)
+                 # If there's an "entry", register it in entry_map
+                 entry_name = repo_dict.get("entry")
+                 if entry_name:
+                     self.entry_map.setdefault(entry_name, []).append(repo_obj)
+
+         # Load custom repositories
+         custom_config = data.get("customs", [])
+         for custom_dict in custom_config:
+             custom_obj = self._create_custom_repository(custom_dict)
+             if custom_obj is not None:
+                 self.custom_repositories.append(custom_obj)
+                 entry_name = custom_dict.get("entry")
+                 if entry_name:
+                     self.entry_map.setdefault(entry_name, []).append(custom_obj)
+         logger.info(
+             f"Loaded {len(self.repositories)} standard repositories, "
+             f"{len(self.custom_repositories)} custom repositories, "
+             f"with {len(self.entry_map)} distinct entries."
+         )
+
+     def _create_repository(self, config: Dict[str, Any]) -> Optional[BaseRepository]:
+         """
+         Internal helper to instantiate a standard repository based on 'type'.
+         """
+         repo_type = config.get("type")
+         entry = config.get("entry", None)  # optional entry prefix
+
+         if repo_type == "primary":
+             base_url = config.get("base_url")
+             cache_dir = config.get("cache_dir")
+             if base_url and cache_dir:
+                 return PrimaryRepository(
+                     base_url=base_url,
+                     cache_dir=cache_dir,
+                     entry=entry,  # pass to the constructor
+                 )
+             else:
+                 logger.warning(
+                     "Invalid 'primary' repo config; missing base_url or cache_dir."
+                 )
+                 return None
+
+         elif repo_type == "secondary":
+             tar_base_url = config.get("tar_base_url")
+             json_base_url = config.get("json_base_url")
+             cache_dir = config.get("cache_dir")
+             chunk_size = config.get("chunk_size", 1000)
+             if tar_base_url and json_base_url and cache_dir:
+                 return SecondaryRepository(
+                     tar_base_url=tar_base_url,
+                     json_base_url=json_base_url,
+                     cache_dir=cache_dir,
+                     chunk_size=chunk_size,
+                     entry=entry,
+                 )
+             else:
+                 logger.warning(
+                     "Invalid 'secondary' repo config; missing tar_base_url/json_base_url/cache_dir."
+                 )
+                 return None
+
+         else:
+             logger.warning(
+                 f"Repository type '{repo_type}' is not recognized or not supported."
+             )
+             return None
+
+     def _create_custom_repository(
+         self, config: Dict[str, Any]
+     ) -> Optional[CustomRepository]:
+         """
+         Internal helper to instantiate a custom repository.
+         """
+         repo_type = config.get("type")
+         entry = config.get("entry", None)
+
+         if repo_type == "custom":
+             base_url = config.get("base_url")
+             cache_dir = config.get("cache_dir")
+             if base_url and cache_dir:
+                 return CustomRepository(
+                     base_url=base_url, cache_dir=cache_dir, entry=entry
+                 )
+             else:
+                 logger.warning(
+                     "Invalid 'custom' repo config; missing base_url or cache_dir."
+                 )
+                 return None
+
+         else:
+             logger.warning(
+                 f"Custom repository type '{repo_type}' is not recognized or not supported."
+             )
+             return None
+
+     def to_dict(self) -> Dict[str, Any]:
+         """
+         Reconstructs the config dictionary from the current repository objects.
+         """
+         return {
+             "repositories": [self._repo_to_dict(repo) for repo in self.repositories],
+             "customs": [
+                 self._custom_repo_to_dict(crepo) for crepo in self.custom_repositories
+             ],
+         }
+
+     def _repo_to_dict(self, repo: BaseRepository) -> Dict[str, Any]:
+         """
+         Rebuilds the config dict for a standard repository from its attributes.
+         """
+         # Each repository is expected to carry an .entry attribute.
+         entry_val = getattr(repo, "entry", None)
+
+         if isinstance(repo, PrimaryRepository):
+             return {
+                 "type": "primary",
+                 "base_url": repo.base_url,
+                 "cache_dir": repo.cache_dir,
+                 "entry": entry_val,
+             }
+         elif isinstance(repo, SecondaryRepository):
+             return {
+                 "type": "secondary",
+                 "tar_base_url": repo.tar_base_url,
+                 "json_base_url": repo.json_base_url,
+                 "cache_dir": repo.cache_dir,
+                 "chunk_size": repo.chunk_size,
+                 "entry": entry_val,
+             }
+         else:
+             return {"type": "unknown", "entry": entry_val}
+
+     def _custom_repo_to_dict(self, repo: CustomRepository) -> Dict[str, Any]:
+         """
+         Rebuilds the config dict for a CustomRepository from its attributes.
+         """
+         return {
+             "type": "custom",
+             "base_url": repo.base_url,
+             "cache_dir": repo.cache_dir,
+             "entry": getattr(repo, "entry", None),
+         }
+
+     def save(self, path: Optional[str] = None):
+         """
+         Saves the current config (based on the instantiated repo objects) back to a JSON file.
+         :param path: Optional; if None, uses self.config_path.
+         """
+         if path is None:
+             path = self.config_path
+
+         data = self.to_dict()
+         with open(path, "w", encoding="utf-8") as f:
+             json.dump(data, f, indent=4)
+         logger.info(f"Repository configuration saved to {path}")
+
+     def get_repositories_for_entry(self, entry: str) -> List[Union[BaseRepository, CustomRepository]]:
+         """
+         Retrieves the list of repositories (both standard and custom) that are mapped to a given entry prefix.
+         """
+         return self.entry_map.get(entry, [])
+
+     def search_entry_and_key(self, entry: str, key: str) -> Optional[BytesIO]:
+         """
+         Searches the repositories registered for `entry` and downloads the file for `key`.
+         Returns a BytesIO with the file content, or None if it is not found.
+         """
+         repositories = self.get_repositories_for_entry(entry)
+         if not repositories:
+             logger.warning(f"No repositories found for entry: {entry}")
+             return None
+         base_repos = BaseRepositoryPool(repositories)
+         result = base_repos.download_by_id(key)
+         if result:
+             return result
+         return None
+
+
+ #####################################################
+ class RepositoryPool(Protocol):
+     """
+     A Protocol for a set of repositories that can be searched for a given image ID.
+     Implementations must provide a download_by_id method.
+     """
+     def download_by_id(self, image_id: int) -> Optional[BytesIO]: ...
+
+
+ class BaseRepositoryPool(RepositoryPool):
+     """
+     A pool of BaseRepository objects, allowing for a unified download_by_id method.
+     """
+
+     def __init__(self, repositories: List[BaseRepository]):
+         self.repositories = repositories
+
+     def download_by_id(self, image_id: int) -> Optional[BytesIO]:
+         session = get_hf_session()
+         info = None
+         for repo in self.repositories:
+             info = repo.find_image(session, image_id)
+             logger.debug(f"Searching for image {image_id} in {repo}, result: {info}")
+             if info:
+                 break
+         if not info:
+             msg = f"Image ID {image_id} was not found in any repository. (Base)"
+             logger.info(msg)
+             return None
+         tar_url, start_offset, end_offset, _ = info
+         file_bytes = download_range(session, tar_url, start_offset, end_offset)
+         logger.debug(f"Successfully downloaded image {image_id} from {tar_url}")
+         return BytesIO(file_bytes)
+
+
+ #####################################################
+ # Universal Read Function
+ #####################################################
+ REPOSITORY_CONFIG: RepositoryConfig = RepositoryConfig("repository.json")
+ REPOSITORY_CONFIG.load()
+
+ def read_general(path: str) -> Union[str, BytesIO]:
+     """
+     A universal read function:
+       - If the path starts with "s3://", uses the Ceph/petrel client to retrieve data.
+       - If the path uses an entry prefix such as "danbooru://", parses out the key and
+         downloads it from the repositories configured for that entry; returns a BytesIO.
+       - Otherwise, treats the path as local and returns it as-is, raising
+         FileNotFoundError if it does not exist.
+
+     :param path: The path or URI to read
+     :return: Either a local path string or an in-memory BytesIO
+     """
+     config = REPOSITORY_CONFIG
+     if path.startswith("s3://"):
+         init_ceph_client_if_needed()
+         logger.debug(f"Downloading from Ceph/petrel: {path}")
+         file_data = client.get(path)  # type: ignore
+         return BytesIO(file_data)
+     if "://" in path:
+         parts = path.split("://", 1)
+         entry = parts[0]
+         result = config.search_entry_and_key(entry, parts[1])
+         if result:
+             return result
+         raise FileNotFoundError(f"Image ID not found in any repository: {path}")
+     # A local path must exist on disk
+     if not os.path.exists(path):
+         raise FileNotFoundError(f"File not found: {path}")
+
+     # Otherwise, assume it's a normal local path
+     logger.debug(f"Returning local path: {path}")
+     return path
+
+
+ if __name__ == "__main__":
+     # Configure logging at the desired level
+     configure_logging("DEBUG")  # or "INFO", "WARNING", etc.
+
+     # Example usage:
+     # try:
+     #     data = read_general("danbooru://6706939")
+     #     if isinstance(data, BytesIO):
+     #         img = Image.open(data)
+     #         img.show()
+     # except FileNotFoundError as e:
+     #     logger.error(str(e))
+     # try:
+     #     data = read_general("danbooru://8884993")
+     #     if isinstance(data, BytesIO):
+     #         img = Image.open(data)
+     #         img.show()
+     # except FileNotFoundError as e:
+     #     logger.error(str(e))
+     #
+     try:
+         data = read_general("anime://fancaps/8183457")
+         if isinstance(data, BytesIO):
+             img = Image.open(data)
+             img.show()
+     except FileNotFoundError as e:
+         logger.error(str(e))
+     # Other usage examples:
+     # data2 = read_general("s3://bucket_name/path/to/object.jpg")
+     # data3 = read_general("some/local/path.jpg")
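
The module expects a repository.json in the working directory (loaded at import time). A hypothetical sketch of the structure that RepositoryConfig.from_dict() parses; the URLs, cache paths, and entry names below are placeholders, not the actual dataset endpoints:

import json

example_config = {
    "repositories": [
        {
            "type": "primary",
            "base_url": "https://example.com/primary-set",  # placeholder URL
            "cache_dir": "./cache/primary",
            "entry": "danbooru",
        },
        {
            "type": "secondary",
            "tar_base_url": "https://example.com/secondary-tars",   # placeholder
            "json_base_url": "https://example.com/secondary-index",  # placeholder
            "cache_dir": "./cache/secondary",
            "chunk_size": 1000,
            "entry": "danbooru",
        },
    ],
    "customs": [
        {
            "type": "custom",
            "base_url": "https://example.com/custom-set",  # placeholder
            "cache_dir": "./cache/custom",
            "entry": "anime",
        }
    ],
}

with open("repository.json", "w", encoding="utf-8") as f:
    json.dump(example_config, f, indent=4)

With such a file in place, read_general("danbooru://6706939") would consult both "danbooru" repositories in order, while read_general("anime://fancaps/8183457") would be routed to the custom repository.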
data/dataset.py ADDED
@@ -0,0 +1,271 @@
1
+ from abc import ABC, abstractmethod
2
+ import copy
3
+ import json
4
+ import logging
5
+ import os
6
+ from pathlib import Path
7
+ import random
8
+ from time import sleep
9
+ import traceback
10
+ import warnings
11
+ import pandas as pd
12
+ from tqdm import tqdm
13
+ import h5py
14
+ import torch.distributed as dist
15
+ from torch.utils.data import Dataset
16
+ import yaml
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class DataBriefReportException(Exception):
21
+ def __init__(self, message=None):
22
+ self.message = message
23
+
24
+ def __str__(self):
25
+ return f"{self.__class__}: {self.message}"
26
+
27
+
28
+ class DataNoReportException(Exception):
29
+ def __init__(self, message=None):
30
+ self.message = message
31
+
32
+ def __str__(self):
33
+ return f"{self.__class__}: {self.message}"
34
+
35
+
36
+ class ItemProcessor(ABC):
37
+ @abstractmethod
38
+ def process_item(self, data_item, training_mode=False):
39
+ raise NotImplementedError
40
+ def is_huggingface_path(path: str) -> bool:
41
+ # Heuristic: Hugging Face dataset paths are in format "user/dataset"
42
+ # and not an existing local file or directory.
43
+ return ("/" in path and not os.path.exists(path) and not "booru" in path) or (os.path.exists(path) and os.path.isdir(path))
44
+
45
+ global_log_count = 0
46
+ def log_every_n(n, msg):
47
+ global global_log_count
48
+ if global_log_count % n == 0:
49
+ logger.warning(msg)
50
+ global_log_count += 1
51
+ class MyDataset(Dataset):
52
+ def __init__(self, config_path, item_processor: ItemProcessor, cache_on_disk=False):
53
+ logger.info(f"read dataset config from {config_path}")
54
+ with open(config_path, "r") as f:
55
+ self.config = yaml.load(f, Loader=yaml.FullLoader)
56
+ logger.info("DATASET CONFIG:")
57
+ logger.info(self.config)
58
+
59
+ self.cache_on_disk = cache_on_disk
60
+ if self.cache_on_disk:
61
+ cache_dir = self._get_cache_dir(config_path)
62
+ if int(os.environ["LOCAL_RANK"]) == 0:
63
+ local_rank = dist.get_rank()
64
+ print(f"Building cache on rank {local_rank}")
65
+ self._collect_annotations_and_save_to_cache(cache_dir)
66
+ dist.barrier()
67
+ ann, group_indice_range = self._load_annotations_from_cache(cache_dir)
68
+ else:
69
+ cache_dir = None
70
+ ann, group_indice_range = self._collect_annotations()
71
+
72
+ self.ann = ann
73
+ self.group_indices = {key: list(range(val[0], val[1])) for key, val in group_indice_range.items()}
74
+
75
+ logger.info(f"total length: {len(self)}")
76
+
77
+ self.item_processor = item_processor
78
+
79
+ def __len__(self):
80
+ return len(self.ann)
81
+
82
+ def _collect_annotations(self):
83
+ meta_type_to_caption_type = {
84
+ "image_text" : "prompt",
85
+ "image_nl_caption" : "sentence",
86
+ "image_alttext" : "alttext",
87
+ "default" : "prompt",
88
+ "super_high_quality_caption" : "super_high_quality_caption",
89
+ "image_tags" : "tags",
90
+ }
91
+ switchable_keys = ["prompt", "sentence", "alttext", "super_high_quality_caption", "tags"]
92
+ group_ann = {}
93
+ for meta in self.config["META"]:
94
+ meta_path, meta_type = meta["path"], meta.get("type", "default")
95
+ meta_key = meta_type_to_caption_type.get(meta_type, "prompt")
96
+ logger.info(f"Reading {meta_path} with type {meta_type} and key {meta_key}")
97
+ if is_huggingface_path(meta_path):
98
+ raise NotImplementedError("Hugging Face datasets are not supported in this minimal example.")
99
+ else:
100
+ meta_ext = os.path.splitext(meta_path)[-1]
101
+ if meta_ext == ".json":
102
+ # with open(meta_path) as f:
103
+ # meta_l = json.load(f)
104
+ with open(meta_path, 'r') as json_file:
105
+ f = json_file.read()
106
+ meta_l = json.loads(f)
107
+ elif meta_ext == ".jsonl":
108
+ meta_l = []
109
+ with open(meta_path) as f:
110
+ for i, line in tqdm(enumerate(f), desc=f"Reading {meta_path}"):
111
+ try:
112
+ read_result = json.loads(line)
113
+ if isinstance(read_result, dict):
114
+ for key in switchable_keys:
115
+ if key in read_result and meta_key != key:
116
+ read_result[meta_key] = read_result[key]
117
+ read_result.pop(key)
118
+ break
119
+ if read_result[meta_key].strip():
120
+ meta_l.append(read_result)
121
+ else:
122
+ logger.error(f"Empty prompt in {meta_path} line {i}, file: {meta_path}")
123
+ log_every_n(10000, f"line {i}: {read_result}")
124
+ else:
125
+ raise ValueError(f"Expected a dictionary, got {type(read_result)} for {meta_path} line {i}")
126
+ except json.decoder.JSONDecodeError as e:
127
+ logger.error(f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}")
128
+ raise e
129
+ elif meta_ext == ".parquet":
130
+ meta_l = []
131
+ df = pd.read_parquet(meta_path) # Read the Parquet file into a DataFrame
132
+ pq_cols = meta.get("pq_cols", None)
133
+ if pq_cols is not None:
134
+ cols = pq_cols.split(",")
135
+ else:
136
+ cols = None
137
+ if cols:
138
+ if "index" not in cols:
139
+ raise ValueError(f"The 'index' column must be included in the 'pq_cols' list., in {meta_path}")
140
+ if not all([col in df.columns for col in cols]):
141
+ raise ValueError(f"Columns in 'pq_cols' must be present in the Parquet file., in {meta_path}")
142
+ for _, row in tqdm(df.iterrows(), total=len(df), desc=f"Reading {meta_path}"):
143
+ # Pull the 'index' column (whatever column indicates image index/id)
144
+ index_val = row["index"]
145
+
146
+ # For each *other* column in the row, if not None/NaN, use it as "prompt"
147
+ for col in df.columns:
148
+ if col == "index":
149
+ continue
150
+ if cols:
151
+ if col not in cols:
152
+ continue
153
+ # Skip if the value is None or NaN
154
+ if pd.notna(row[col]) and str(row[col]):
155
+ log_every_n(10000, f"{meta_key}: {row[col]}")
156
+ meta_l.append({
157
+ "image_path": f"danbooru://{index_val}" if not os.path.exists(index_val) and "://" not in str(index_val) else str(index_val),
158
+ meta_key: str(row[col]) # Cast to str in case it's not a string
159
+ })
160
+ else:
161
+ raise NotImplementedError(
162
+ f'Unknown meta file extension: "{meta_ext}". '
163
+ f"Currently, .json, .jsonl, .parquet (with index column + caption columns) are supported. "
164
+ "If you are using a supported format, please set the file extension so that the proper parsing "
165
+ "routine can be called."
166
+ )
167
+ logger.info(f"{meta_path}, type{meta_type}: len {len(meta_l)}")
168
+ if "ratio" in meta:
169
+ random.seed(0)
170
+ meta_l = random.sample(meta_l, int(len(meta_l) * meta["ratio"]))
171
+ logger.info(f"sample (ratio = {meta['ratio']}) {len(meta_l)} items")
172
+ if "root" in meta:
173
+ for item in meta_l:
174
+ for path_key in ["path", "image_url", "image", "image_path"]:
175
+ if path_key in item:
176
+ item[path_key] = os.path.join(meta["root"], item[path_key])
177
+ if meta_type not in group_ann:
178
+ group_ann[meta_type] = []
179
+ group_ann[meta_type] += meta_l
180
+
181
+ ann = sum(list(group_ann.values()), start=[])
182
+
183
+ group_indice_range = {}
184
+ start_pos = 0
185
+ for meta_type, meta_l in group_ann.items():
186
+ group_indice_range[meta_type] = [start_pos, start_pos + len(meta_l)]
187
+ start_pos = start_pos + len(meta_l)
188
+
189
+ return ann, group_indice_range
190
+
191
+ def _collect_annotations_and_save_to_cache(self, cache_dir):
192
+ if (Path(cache_dir) / "data.h5").exists() and (Path(cache_dir) / "ready").exists():
193
+ # off-the-shelf annotation cache exists
194
+ warnings.warn(
195
+ f"Use existing h5 data cache: {Path(cache_dir)}\n"
196
+ f"Note: if the actual data defined by the data config has changed since your last run, "
197
+ f"please delete the cache manually and re-run this experiment, or the data actually used "
198
+ f"will not be updated"
199
+ )
200
+ return
201
+
202
+ Path(cache_dir).mkdir(parents=True, exist_ok=True)
203
+ ann, group_indice_range = self._collect_annotations()
204
+
205
+ # when cache on disk, rank0 saves items to an h5 file
206
+ serialized_ann = [json.dumps(_) for _ in ann]
207
+ logger.info(f"start to build data cache to: {Path(cache_dir)}")
208
+ with h5py.File(Path(cache_dir) / "data.h5", "w") as file:
209
+ dt = h5py.vlen_dtype(str)
210
+ h5_ann = file.create_dataset("ann", (len(serialized_ann),), dtype=dt)
211
+ h5_ann[:] = serialized_ann
212
+ file.create_dataset("group_indice_range", data=json.dumps(group_indice_range))
213
+ with open(Path(cache_dir) / "ready", "w") as f:
214
+ f.write("ready")
215
+ logger.info(f"data cache built")
216
+
217
+ @staticmethod
218
+ def _get_cache_dir(config_path):
219
+ config_identifier = config_path
220
+ disallowed_chars = ["/", "\\", ".", "?", "!"]
221
+ for _ in disallowed_chars:
222
+ config_identifier = config_identifier.replace(_, "-")
223
+ cache_dir = f"./accessory_data_cache/{config_identifier}"
224
+ return cache_dir
225
+
226
+ @staticmethod
227
+ def _load_annotations_from_cache(cache_dir):
228
+ while not (Path(cache_dir) / "ready").exists():
229
+ # cache has not yet been completed by rank 0
230
+ assert int(os.environ["LOCAL_RANK"]) != 0
231
+ sleep(1)
232
+ cache_file = h5py.File(Path(cache_dir) / "data.h5", "r")
233
+ annotations = cache_file["ann"]
234
+ group_indice_range = json.loads(cache_file["group_indice_range"].asstr()[()])
235
+ return annotations, group_indice_range
236
+
237
+ def get_item_func(self, index):
238
+ data_item = self.ann[index]
239
+ if self.cache_on_disk:
240
+ data_item = json.loads(data_item)
241
+ else:
242
+ data_item = copy.deepcopy(data_item)
243
+
244
+ return self.item_processor.process_item(data_item, training_mode=True)
245
+
246
+ def __getitem__(self, index):
247
+ try:
248
+ return self.get_item_func(index)
249
+ except Exception as e:
250
+ if isinstance(e, DataNoReportException):
251
+ pass
252
+ elif isinstance(e, DataBriefReportException):
253
+ logger.info(e)
254
+ else:
255
+ logger.info(
256
+ f"Item {index} errored, annotation:\n"
257
+ f"{self.ann[index]}\n"
258
+ f"Error:\n"
259
+ f"{traceback.format_exc()}"
260
+ )
261
+ for group_name, indices_this_group in self.group_indices.items():
262
+ if indices_this_group[0] <= index <= indices_this_group[-1]:
263
+ if index == indices_this_group[0]:
264
+ new_index = indices_this_group[-1]
265
+ else:
266
+ new_index = index - 1
267
+ return self[new_index]
268
+ raise RuntimeError
269
+
270
+ def groups(self):
271
+ return list(self.group_indices.values())
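
For reference, a minimal sketch of a .parquet metadata file that the parquet branch above would accept. Only the "index" column is required by the loader; the "caption" and "tag_string" columns here are hypothetical caption columns:

    # Build a toy metadata parquet (column names besides "index" are assumptions).
    import pandas as pd

    df = pd.DataFrame(
        {
            "index": [1000001, 1000002],  # ids resolve to danbooru://<id> when not a local path
            "caption": ["a girl holding an umbrella", None],  # None/NaN cells are skipped
            "tag_string": ["1girl, umbrella, rain", "landscape, sky"],
        }
    )
    df.to_parquet("meta.parquet")

A matching data-config entry could then set "pq_cols" to "index,caption,tag_string" to restrict which caption columns are read.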
grad_norm.py ADDED
@@ -0,0 +1,60 @@
+ from typing import Dict
+
+ import fairscale.nn.model_parallel.initialize as fs_init
+ from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
+ import torch
+ import torch.distributed as dist
+ import torch.nn as nn
+
+
+ def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]:
+     ret_dict = {}
+     for module_name, module in model.named_modules():
+
+         def param_fqn(param_name):
+             return param_name if module_name == "" else module_name + "." + param_name
+
+         if isinstance(module, ColumnParallelLinear):
+             ret_dict[param_fqn("weight")] = 0
+             if module.bias is not None:
+                 ret_dict[param_fqn("bias")] = 0
+         elif isinstance(module, RowParallelLinear):
+             ret_dict[param_fqn("weight")] = 1
+             if module.bias is not None:
+                 ret_dict[param_fqn("bias")] = -1
+         elif isinstance(module, ParallelEmbedding):
+             ret_dict[param_fqn("weight")] = 1
+         else:
+             for param_name, param in module.named_parameters(recurse=False):
+                 ret_dict[param_fqn(param_name)] = -1
+     return ret_dict
+
+
+ def calculate_l2_grad_norm(
+     model: nn.Module,
+     model_parallel_dim_dict: Dict[str, int],
+ ) -> float:
+     mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
+     non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
+
+     for name, param in model.named_parameters():
+         if param.grad is None:
+             continue
+         name = ".".join(x for x in name.split(".") if not x.startswith("_"))
+         assert name in model_parallel_dim_dict
+         if model_parallel_dim_dict[name] < 0:
+             non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
+         else:
+             mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
+
+     dist.all_reduce(mp_norm_sq)
+     dist.all_reduce(non_mp_norm_sq)
+     non_mp_norm_sq /= fs_init.get_model_parallel_world_size()
+
+     return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5
+
+
+ def scale_grad(model: nn.Module, factor: float) -> None:
+     for param in model.parameters():
+         if param.grad is not None:
+             param.grad.mul_(factor)
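
calculate_l2_grad_norm divides the non-model-parallel term by the model-parallel world size because those gradients are replicated across model-parallel ranks, so the preceding all-reduce counts each of them that many times. A hedged sketch of how the helpers compose into global-norm gradient clipping (model, loss, optimizer, and the max_norm threshold are assumptions, not defined in this file):

    mp_dims = get_model_parallel_dim_dict(model)  # build once, after model creation

    loss.backward()
    grad_norm = calculate_l2_grad_norm(model, mp_dims)
    max_norm = 1.0  # assumed clipping threshold
    if grad_norm > max_norm:
        scale_grad(model, max_norm / grad_norm)  # in-place rescale of every grad
    optimizer.step()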
imgproc.py ADDED
@@ -0,0 +1,80 @@
+ import random
+
+ from PIL import Image
+ import numpy as np
+
+
+ def center_crop_arr(pil_image, image_size):
+     """
+     Center cropping implementation from ADM.
+     https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+     """
+     while min(*pil_image.size) >= 2 * image_size:
+         pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+     scale = image_size / min(*pil_image.size)
+     pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+     arr = np.array(pil_image)
+     crop_y = (arr.shape[0] - image_size) // 2
+     crop_x = (arr.shape[1] - image_size) // 2
+     return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
+
+
+ def center_crop(pil_image, crop_size):
+     while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
+         pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+     scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
+     pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+     # crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
+     # crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
+     crop_left = (pil_image.size[0] - crop_size[0]) // 2
+     crop_upper = (pil_image.size[1] - crop_size[1]) // 2
+     crop_right = crop_left + crop_size[0]
+     crop_lower = crop_upper + crop_size[1]
+     return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
+
+
+ def var_center_crop(pil_image, crop_size_list, random_top_k=4):
+     w, h = pil_image.size
+     rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
+     crop_size = random.choice(
+         sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
+     )[1]
+     return center_crop(pil_image, crop_size)
+
+
+ def var_center_crop_128(pil_image, crop_size_list, random_top_k=4):
+     w, h = pil_image.size
+     rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
+     crop_size = random.choice(
+         sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
+     )[1]
+     # note: the sampled crop_size is unused; crop to the largest multiple-of-128
+     # size that fits inside the original image instead
+     return center_crop(pil_image, ((w // 128) * 128, (h // 128) * 128))
+
+
+ def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
+     assert max_ratio >= 1.0
+     crop_size_list = []
+     wp, hp = num_patches, 1
+     while wp > 0:
+         if max(wp, hp) / min(wp, hp) <= max_ratio:
+             if ((wp * patch_size) // 32) % 2 == 0 and ((hp * patch_size) // 32) % 2 == 0:
+                 crop_size_list.append((wp * patch_size, hp * patch_size))
+         if (hp + 1) * wp <= num_patches:
+             hp += 1
+         else:
+             wp -= 1
+     return crop_size_list
+
+
+ def to_rgb_if_rgba(img: Image.Image):
+     if img.mode.upper() == "RGBA":
+         rgb_img = Image.new("RGB", img.size, (255, 255, 255))
+         rgb_img.paste(img, mask=img.split()[3])  # 3 is the alpha channel
+         return rgb_img
+     elif img.mode.upper() == "P":
+         return img.convert("RGB")
+     else:
+         return img
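
A small usage sketch of the bucketing helpers: generate_crop_size_list enumerates (w, h) buckets whose patch count stays within num_patches and whose aspect ratio stays within max_ratio, and var_center_crop picks at random among the best-fitting buckets for a given image. The patch budget below is an illustrative assumption:

    from PIL import Image

    # all buckets with up to 64x64 patches of 16 px, i.e. up to 1024x1024 pixels
    crop_sizes = generate_crop_size_list(num_patches=64 * 64, patch_size=16)
    img = Image.new("RGB", (1920, 1080))
    print(var_center_crop(img, crop_sizes).size)  # one of the top-4 aspect-ratio matches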
models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import NextDiT_2B_GQA_patch2_Adaln_Refiner, NextDiT_3B_GQA_patch2_Adaln_Refiner, NextDiT_4B_GQA_patch2_Adaln_Refiner, NextDiT_7B_GQA_patch2_Adaln_Refiner
models/components.py ADDED
@@ -0,0 +1,54 @@
+ import warnings
+
+ import torch
+ import torch.nn as nn
+
+ try:
+     from apex.normalization import FusedRMSNorm as RMSNorm
+ except ImportError:
+     warnings.warn("Cannot import apex RMSNorm, switching to the vanilla implementation")
+
+     class RMSNorm(torch.nn.Module):
+         def __init__(self, dim: int, eps: float = 1e-6):
+             """
+             Initialize the RMSNorm normalization layer.
+
+             Args:
+                 dim (int): The dimension of the input tensor.
+                 eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+             Attributes:
+                 eps (float): A small value added to the denominator for numerical stability.
+                 weight (nn.Parameter): Learnable scaling parameter.
+
+             """
+             super().__init__()
+             self.eps = eps
+             self.weight = nn.Parameter(torch.ones(dim))
+
+         def _norm(self, x):
+             """
+             Apply the RMSNorm normalization to the input tensor.
+
+             Args:
+                 x (torch.Tensor): The input tensor.
+
+             Returns:
+                 torch.Tensor: The normalized tensor.
+
+             """
+             return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+         def forward(self, x):
+             """
+             Forward pass through the RMSNorm layer.
+
+             Args:
+                 x (torch.Tensor): The input tensor.
+
+             Returns:
+                 torch.Tensor: The output tensor after applying RMSNorm.
+
+             """
+             output = self._norm(x.float()).type_as(x)
+             return output * self.weight
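
A quick numerical sanity check of the fallback (this assumes apex is absent, so RMSNorm above is the pure-PyTorch class): after normalization every token has unit root-mean-square, since weight is initialized to ones.

    import torch

    norm = RMSNorm(dim=8)
    x = torch.randn(2, 4, 8) * 5.0
    y = norm(x)
    print(y.pow(2).mean(-1))  # ~1.0 for every token position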
models/model.py ADDED
@@ -0,0 +1,930 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # References:
+ # GLIDE: https://github.com/openai/glide-text2im
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
+ # --------------------------------------------------------
+
+ import math
+ from typing import List, Optional, Tuple
+
+ from flash_attn import flash_attn_varlen_func
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from .components import RMSNorm
+
+
+ def modulate(x, scale):
+     return x * (1 + scale.unsqueeze(1))
+
+
+ #############################################################################
+ #               Embedding Layers for Timesteps and Class Labels             #
+ #############################################################################
+
+
+ class TimestepEmbedder(nn.Module):
+     """
+     Embeds scalar timesteps into vector representations.
+     """
+
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(
+                 frequency_embedding_size,
+                 hidden_size,
+                 bias=True,
+             ),
+             nn.SiLU(),
+             nn.Linear(
+                 hidden_size,
+                 hidden_size,
+                 bias=True,
+             ),
+         )
+         nn.init.normal_(self.mlp[0].weight, std=0.02)
+         nn.init.zeros_(self.mlp[0].bias)
+         nn.init.normal_(self.mlp[2].weight, std=0.02)
+         nn.init.zeros_(self.mlp[2].bias)
+
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         """
+         Create sinusoidal timestep embeddings.
+         :param t: a 1-D Tensor of N indices, one per batch element.
+                   These may be fractional.
+         :param dim: the dimension of the output.
+         :param max_period: controls the minimum frequency of the embeddings.
+         :return: an (N, D) Tensor of positional embeddings.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+             device=t.device
+         )
+         args = t[:, None].float() * freqs[None]
+         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+         return embedding
+
+     def forward(self, t):
+         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+         t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
+         return t_emb
+
+
+ #############################################################################
+ #                             Core NextDiT Model                            #
+ #############################################################################
+
+
+ class JointAttention(nn.Module):
+     """Multi-head attention module."""
+
+     def __init__(
+         self,
+         dim: int,
+         n_heads: int,
+         n_kv_heads: Optional[int],
+         qk_norm: bool,
+     ):
+         """
+         Initialize the Attention module.
+
+         Args:
+             dim (int): Number of input dimensions.
+             n_heads (int): Number of heads.
+             n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
+
+         """
+         super().__init__()
+         self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
+         self.n_local_heads = n_heads
+         self.n_local_kv_heads = self.n_kv_heads
+         self.n_rep = self.n_local_heads // self.n_local_kv_heads
+         self.head_dim = dim // n_heads
+
+         self.qkv = nn.Linear(
+             dim,
+             (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
+             bias=False,
+         )
+         nn.init.xavier_uniform_(self.qkv.weight)
+
+         self.out = nn.Linear(
+             n_heads * self.head_dim,
+             dim,
+             bias=False,
+         )
+         nn.init.xavier_uniform_(self.out.weight)
+
+         if qk_norm:
+             self.q_norm = RMSNorm(self.head_dim)
+             self.k_norm = RMSNorm(self.head_dim)
+         else:
+             self.q_norm = self.k_norm = nn.Identity()
+
+     @staticmethod
+     def apply_rotary_emb(
+         x_in: torch.Tensor,
+         freqs_cis: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Apply rotary embeddings to the input tensor using the given frequency
+         tensor.
+
+         This function applies rotary embeddings to the given query or key
+         tensor 'x_in' using the provided frequency tensor 'freqs_cis'. The
+         input tensor is reshaped as complex numbers, and the frequency tensor
+         is reshaped for broadcasting compatibility. The resulting tensor
+         contains rotary embeddings and is returned as a real tensor.
+
+         Args:
+             x_in (torch.Tensor): Query or key tensor to apply rotary embeddings.
+             freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
+                 exponentials.
+
+         Returns:
+             torch.Tensor: The input tensor with rotary embeddings applied.
+         """
+         with torch.cuda.amp.autocast(enabled=False):
+             x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+             freqs_cis = freqs_cis.unsqueeze(2)
+             x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+             return x_out.type_as(x_in)
+
+     # copied from huggingface modeling_llama.py
+     def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+         def _get_unpad_data(attention_mask):
+             seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+             indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+             max_seqlen_in_batch = seqlens_in_batch.max().item()
+             cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+             return (
+                 indices,
+                 cu_seqlens,
+                 max_seqlen_in_batch,
+             )
+
+         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+         key_layer = index_first_axis(
+             key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+             indices_k,
+         )
+         value_layer = index_first_axis(
+             value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
+             indices_k,
+         )
+         if query_length == kv_seq_len:
+             query_layer = index_first_axis(
+                 query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
+                 indices_k,
+             )
+             cu_seqlens_q = cu_seqlens_k
+             max_seqlen_in_batch_q = max_seqlen_in_batch_k
+             indices_q = indices_k
+         elif query_length == 1:
+             max_seqlen_in_batch_q = 1
+             cu_seqlens_q = torch.arange(
+                 batch_size + 1, dtype=torch.int32, device=query_layer.device
+             )  # There is a memcpy here, that is very bad.
+             indices_q = cu_seqlens_q[:-1]
+             query_layer = query_layer.squeeze(1)
+         else:
+             # The -q_len: slice assumes left padding.
+             attention_mask = attention_mask[:, -query_length:]
+             query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+         return (
+             query_layer,
+             key_layer,
+             value_layer,
+             indices_q,
+             (cu_seqlens_q, cu_seqlens_k),
+             (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         x_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Args:
+             x: Input token sequence of shape (batch, seq_len, dim).
+             x_mask: Boolean mask of valid positions, shape (batch, seq_len).
+             freqs_cis: Precomputed rotary frequencies.
+
+         Returns:
+             torch.Tensor: Attention output of shape (batch, seq_len, dim).
+         """
+         bsz, seqlen, _ = x.shape
+         dtype = x.dtype
+
+         xq, xk, xv = torch.split(
+             self.qkv(x),
+             [
+                 self.n_local_heads * self.head_dim,
+                 self.n_local_kv_heads * self.head_dim,
+                 self.n_local_kv_heads * self.head_dim,
+             ],
+             dim=-1,
+         )
+         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+         xq = self.q_norm(xq)
+         xk = self.k_norm(xk)
+         xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
+         xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+         xq, xk = xq.to(dtype), xk.to(dtype)
+
+         softmax_scale = math.sqrt(1 / self.head_dim)
+
+         if dtype in [torch.float16, torch.bfloat16]:
+             # begin var_len flash attn
+             (
+                 query_states,
+                 key_states,
+                 value_states,
+                 indices_q,
+                 cu_seq_lens,
+                 max_seq_lens,
+             ) = self._upad_input(xq, xk, xv, x_mask, seqlen)
+
+             cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+             max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+             attn_output_unpad = flash_attn_varlen_func(
+                 query_states,
+                 key_states,
+                 value_states,
+                 cu_seqlens_q=cu_seqlens_q,
+                 cu_seqlens_k=cu_seqlens_k,
+                 max_seqlen_q=max_seqlen_in_batch_q,
+                 max_seqlen_k=max_seqlen_in_batch_k,
+                 dropout_p=0.0,
+                 causal=False,
+                 softmax_scale=softmax_scale,
+             )
+             output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
+             # end var_len_flash_attn
+
+         else:
+             n_rep = self.n_local_heads // self.n_local_kv_heads
+             if n_rep >= 1:
+                 xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+                 xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+             output = (
+                 F.scaled_dot_product_attention(
+                     xq.permute(0, 2, 1, 3),
+                     xk.permute(0, 2, 1, 3),
+                     xv.permute(0, 2, 1, 3),
+                     attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
+                     scale=softmax_scale,
+                 )
+                 .permute(0, 2, 1, 3)
+                 .to(dtype)
+             )
+
+         output = output.flatten(-2)
+
+         return self.out(output)
+
+
+ class FeedForward(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         hidden_dim: int,
+         multiple_of: int,
+         ffn_dim_multiplier: Optional[float],
+     ):
+         """
+         Initialize the FeedForward module.
+
+         Args:
+             dim (int): Input dimension.
+             hidden_dim (int): Hidden dimension of the feedforward layer.
+             multiple_of (int): Value to ensure hidden dimension is a multiple
+                 of this value.
+             ffn_dim_multiplier (float, optional): Custom multiplier for hidden
+                 dimension. Defaults to None.
+
+         """
+         super().__init__()
+         # custom dim factor multiplier
+         if ffn_dim_multiplier is not None:
+             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+         self.w1 = nn.Linear(
+             dim,
+             hidden_dim,
+             bias=False,
+         )
+         nn.init.xavier_uniform_(self.w1.weight)
+         self.w2 = nn.Linear(
+             hidden_dim,
+             dim,
+             bias=False,
+         )
+         nn.init.xavier_uniform_(self.w2.weight)
+         self.w3 = nn.Linear(
+             dim,
+             hidden_dim,
+             bias=False,
+         )
+         nn.init.xavier_uniform_(self.w3.weight)
+
+     # @torch.compile
+     def _forward_silu_gating(self, x1, x3):
+         return F.silu(x1) * x3
+
+     def forward(self, x):
+         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
+
+
+ class JointTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         layer_id: int,
+         dim: int,
+         n_heads: int,
+         n_kv_heads: int,
+         multiple_of: int,
+         ffn_dim_multiplier: float,
+         norm_eps: float,
+         qk_norm: bool,
+         modulation=True,
+     ) -> None:
+         """
+         Initialize a TransformerBlock.
+
+         Args:
+             layer_id (int): Identifier for the layer.
+             dim (int): Embedding dimension of the input features.
+             n_heads (int): Number of attention heads.
+             n_kv_heads (Optional[int]): Number of attention heads in key and
+                 value features (if using GQA), or set to None for the same as
+                 query.
+             multiple_of (int): Rounding unit for the feedforward hidden dimension.
+             ffn_dim_multiplier (float): Custom multiplier for the feedforward
+                 hidden dimension.
+             norm_eps (float): Epsilon used by the RMSNorm layers.
+
+         """
+         super().__init__()
+         self.dim = dim
+         self.head_dim = dim // n_heads
+         self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm)
+         self.feed_forward = FeedForward(
+             dim=dim,
+             hidden_dim=4 * dim,
+             multiple_of=multiple_of,
+             ffn_dim_multiplier=ffn_dim_multiplier,
+         )
+         self.layer_id = layer_id
+         self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
+         self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+
+         self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
+         self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+         self.modulation = modulation
+         if modulation:
+             self.adaLN_modulation = nn.Sequential(
+                 nn.SiLU(),
+                 nn.Linear(
+                     min(dim, 1024),
+                     4 * dim,
+                     bias=True,
+                 ),
+             )
+             nn.init.zeros_(self.adaLN_modulation[1].weight)
+             nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         x_mask: torch.Tensor,
+         freqs_cis: torch.Tensor,
+         adaln_input: Optional[torch.Tensor] = None,
+     ):
+         """
+         Perform a forward pass through the TransformerBlock.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+             freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+         Returns:
+             torch.Tensor: Output tensor after applying attention and
+             feedforward layers.
+
+         """
+         if self.modulation:
+             assert adaln_input is not None
+             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
+
+             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
+                 self.attention(
+                     modulate(self.attention_norm1(x), scale_msa),
+                     x_mask,
+                     freqs_cis,
+                 )
+             )
+             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
+                 self.feed_forward(
+                     modulate(self.ffn_norm1(x), scale_mlp),
+                 )
+             )
+         else:
+             assert adaln_input is None
+             x = x + self.attention_norm2(
+                 self.attention(
+                     self.attention_norm1(x),
+                     x_mask,
+                     freqs_cis,
+                 )
+             )
+             x = x + self.ffn_norm2(
+                 self.feed_forward(
+                     self.ffn_norm1(x),
+                 )
+             )
+         return x
+
+
+ class FinalLayer(nn.Module):
+     """
+     The final layer of NextDiT.
+     """
+
+     def __init__(self, hidden_size, patch_size, out_channels):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(
+             hidden_size,
+             elementwise_affine=False,
+             eps=1e-6,
+         )
+         self.linear = nn.Linear(
+             hidden_size,
+             patch_size * patch_size * out_channels,
+             bias=True,
+         )
+         nn.init.zeros_(self.linear.weight)
+         nn.init.zeros_(self.linear.bias)
+
+         self.adaLN_modulation = nn.Sequential(
+             nn.SiLU(),
+             nn.Linear(
+                 min(hidden_size, 1024),
+                 hidden_size,
+                 bias=True,
+             ),
+         )
+         nn.init.zeros_(self.adaLN_modulation[1].weight)
+         nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+     def forward(self, x, c):
+         scale = self.adaLN_modulation(c)
+         x = modulate(self.norm_final(x), scale)
+         x = self.linear(x)
+         return x
+
+
+ class RopeEmbedder:
+     def __init__(
+         self, theta: float = 10000.0, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512)
+     ):
+         super().__init__()
+         self.theta = theta
+         self.axes_dims = axes_dims
+         self.axes_lens = axes_lens
+         self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
+
+     def __call__(self, ids: torch.Tensor):
+         self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
+         result = []
+         for i in range(len(self.axes_dims)):
+             index = ids[:, :, i : i + 1].repeat(1, 1, self.freqs_cis[i].shape[-1]).to(torch.int64)
+             result.append(torch.gather(self.freqs_cis[i].unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+         return torch.cat(result, dim=-1)
+
+
+ class NextDiT(nn.Module):
+     """
+     Diffusion model with a Transformer backbone.
+     """
+
+     def __init__(
+         self,
+         patch_size: int = 2,
+         in_channels: int = 4,
+         dim: int = 4096,
+         n_layers: int = 32,
+         n_refiner_layers: int = 2,
+         n_heads: int = 32,
+         n_kv_heads: Optional[int] = None,
+         multiple_of: int = 256,
+         ffn_dim_multiplier: Optional[float] = None,
+         norm_eps: float = 1e-5,
+         qk_norm: bool = False,
+         cap_feat_dim: int = 5120,
+         axes_dims: List[int] = (16, 56, 56),
+         axes_lens: List[int] = (1, 512, 512),
+     ) -> None:
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = in_channels
+         self.patch_size = patch_size
+
+         self.x_embedder = nn.Linear(
+             in_features=patch_size * patch_size * in_channels,
+             out_features=dim,
+             bias=True,
+         )
+         nn.init.xavier_uniform_(self.x_embedder.weight)
+         nn.init.constant_(self.x_embedder.bias, 0.0)
+
+         self.noise_refiner = nn.ModuleList(
+             [
+                 JointTransformerBlock(
+                     layer_id,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     multiple_of,
+                     ffn_dim_multiplier,
+                     norm_eps,
+                     qk_norm,
+                     modulation=True,
+                 )
+                 for layer_id in range(n_refiner_layers)
+             ]
+         )
+         self.context_refiner = nn.ModuleList(
+             [
+                 JointTransformerBlock(
+                     layer_id,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     multiple_of,
+                     ffn_dim_multiplier,
+                     norm_eps,
+                     qk_norm,
+                     modulation=False,
+                 )
+                 for layer_id in range(n_refiner_layers)
+             ]
+         )
+
+         self.t_embedder = TimestepEmbedder(min(dim, 1024))
+         self.cap_embedder = nn.Sequential(
+             RMSNorm(cap_feat_dim, eps=norm_eps),
+             nn.Linear(
+                 cap_feat_dim,
+                 dim,
+                 bias=True,
+             ),
+         )
+         nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
+         # nn.init.zeros_(self.cap_embedder[1].weight)
+         nn.init.zeros_(self.cap_embedder[1].bias)
+
+         self.layers = nn.ModuleList(
+             [
+                 JointTransformerBlock(
+                     layer_id,
+                     dim,
+                     n_heads,
+                     n_kv_heads,
+                     multiple_of,
+                     ffn_dim_multiplier,
+                     norm_eps,
+                     qk_norm,
+                 )
+                 for layer_id in range(n_layers)
+             ]
+         )
+         self.norm_final = RMSNorm(dim, eps=norm_eps)
+         self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
+
+         assert (dim // n_heads) == sum(axes_dims)
+         self.axes_dims = axes_dims
+         self.axes_lens = axes_lens
+         self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
+         self.dim = dim
+         self.n_heads = n_heads
+
+     def unpatchify(
+         self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
+     ) -> List[torch.Tensor]:
+         """
+         x: (N, T, patch_size**2 * C)
+         imgs: (N, H, W, C)
+         """
+         pH = pW = self.patch_size
+         imgs = []
+         for i in range(x.size(0)):
+             H, W = img_size[i]
+             begin = cap_size[i]
+             end = begin + (H // pH) * (W // pW)
+             imgs.append(
+                 x[i][begin:end]
+                 .view(H // pH, W // pW, pH, pW, self.out_channels)
+                 .permute(4, 0, 2, 1, 3)
+                 .flatten(3, 4)
+                 .flatten(1, 2)
+             )
+
+         if return_tensor:
+             imgs = torch.stack(imgs, dim=0)
+         return imgs
+
+     def patchify_and_embed(
+         self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
+         bsz = len(x)
+         pH = pW = self.patch_size
+         device = x[0].device
+
+         l_effective_cap_len = cap_mask.sum(dim=1).tolist()
+         img_sizes = [(img.size(1), img.size(2)) for img in x]
+         l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
+
+         max_seq_len = max(
+             (cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
+         )
+         max_cap_len = max(l_effective_cap_len)
+         max_img_len = max(l_effective_img_len)
+
+         position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
+
+         for i in range(bsz):
+             cap_len = l_effective_cap_len[i]
+             img_len = l_effective_img_len[i]
+             H, W = img_sizes[i]
+             H_tokens, W_tokens = H // pH, W // pW
+             assert H_tokens * W_tokens == img_len
+
+             position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
+             position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
+             row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+             col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+             position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
+             position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
+
+         freqs_cis = self.rope_embedder(position_ids)
+
+         # build freqs_cis for cap and image individually
+         cap_freqs_cis_shape = list(freqs_cis.shape)
+         # cap_freqs_cis_shape[1] = max_cap_len
+         cap_freqs_cis_shape[1] = cap_feats.shape[1]
+         cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
+
+         img_freqs_cis_shape = list(freqs_cis.shape)
+         img_freqs_cis_shape[1] = max_img_len
+         img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
+
+         for i in range(bsz):
+             cap_len = l_effective_cap_len[i]
+             img_len = l_effective_img_len[i]
+             cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
+             img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
+
+         # refine context
+         for layer in self.context_refiner:
+             cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+
+         # refine image
+         flat_x = []
+         for i in range(bsz):
+             img = x[i]
+             C, H, W = img.size()
+             img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
+             flat_x.append(img)
+         x = flat_x
+         padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
+         padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device)
+         for i in range(bsz):
+             padded_img_embed[i, : l_effective_img_len[i]] = x[i]
+             padded_img_mask[i, : l_effective_img_len[i]] = True
+
+         padded_img_embed = self.x_embedder(padded_img_embed)
+         for layer in self.noise_refiner:
+             padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
+
+         mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
+         padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
+         for i in range(bsz):
+             cap_len = l_effective_cap_len[i]
+             img_len = l_effective_img_len[i]
+
+             mask[i, : cap_len + img_len] = True
+             padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
+             padded_full_embed[i, cap_len : cap_len + img_len] = padded_img_embed[i, :img_len]
+
+         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
+
+     def forward(self, x, t, cap_feats, cap_mask):
+         """
+         Forward pass of NextDiT.
+         t: (N,) tensor of diffusion timesteps
+         cap_feats: (N, L, D) tensor of caption features
+         cap_mask: (N, L) mask of valid caption tokens
+         """
+         t = self.t_embedder(t)  # (N, D)
+         adaln_input = t
+
+         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute
+
+         x_is_tensor = isinstance(x, torch.Tensor)
+         x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t)
+         freqs_cis = freqs_cis.to(x.device)
+
+         for layer in self.layers:
+             x = layer(x, mask, freqs_cis, adaln_input)
+
+         x = self.final_layer(x, adaln_input)
+         x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)
+
+         return x
+
+     def forward_with_cfg(
+         self,
+         x,
+         t,
+         cap_feats,
+         cap_mask,
+         cfg_scale,
+         cfg_trunc=100,
+         renorm_cfg=1,
+     ):
+         """
+         Forward pass of NextDiT, but also batches the unconditional forward pass
+         for classifier-free guidance.
+         """
+         # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+         half = x[: len(x) // 2]
+         if t[0] < cfg_trunc:
+             combined = torch.cat([half, half], dim=0)
+             model_out = self.forward(combined, t, cap_feats, cap_mask)
+             # Guidance is applied to the first `in_channels` channels; any extra
+             # channels in the model output are passed through unchanged.
+             eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
+             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+             half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+             if float(renorm_cfg) > 0.0:
+                 # optionally cap the norm of the guided prediction at
+                 # `renorm_cfg` times the norm of the conditional prediction
+                 ori_pos_norm = torch.linalg.vector_norm(
+                     cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
+                 )
+                 max_new_norm = ori_pos_norm * float(renorm_cfg)
+                 new_pos_norm = torch.linalg.vector_norm(
+                     half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
+                 )
+                 if new_pos_norm >= max_new_norm:
+                     half_eps = half_eps * (max_new_norm / new_pos_norm)
+         else:
+             combined = half
+             model_out = self.forward(combined, t[: len(x) // 2], cap_feats[: len(x) // 2], cap_mask[: len(x) // 2])
+             eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
+             half_eps = eps
+
+         output = torch.cat([half_eps, half_eps], dim=0)
+         return output
+
+     @staticmethod
+     def precompute_freqs_cis(
+         dim: List[int],
+         end: List[int],
+         theta: float = 10000.0,
+     ):
+         """
+         Precompute the frequency tensors for complex exponentials (cis) with
+         the given dimensions.
+
+         This function calculates frequency tensors with complex exponentials
+         using the given dimension list 'dim' and end indices 'end'. The 'theta'
+         parameter scales the frequencies. The returned tensors contain complex
+         values in complex64 data type.
+
+         Args:
+             dim (list): Dimension of each frequency tensor.
+             end (list): End index for precomputing each set of frequencies.
+             theta (float, optional): Scaling factor for frequency computation.
+                 Defaults to 10000.0.
+
+         Returns:
+             List[torch.Tensor]: Precomputed frequency tensors with complex
+             exponentials, one per axis.
+         """
+         freqs_cis = []
+         for i, (d, e) in enumerate(zip(dim, end)):
+             freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
+             timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
+             freqs = torch.outer(timestep, freqs).float()
+             freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64)  # complex64
+             freqs_cis.append(freqs_cis_i)
+
+         return freqs_cis
+
+     def parameter_count(self) -> int:
+         total_params = 0
+
+         def _recursive_count_params(module):
+             nonlocal total_params
+             for param in module.parameters(recurse=False):
+                 total_params += param.numel()
+             for submodule in module.children():
+                 _recursive_count_params(submodule)
+
+         _recursive_count_params(self)
+         return total_params
+
+     def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
+         return list(self.layers)
+
+     def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
+         return list(self.layers)
+
+
+ #############################################################################
+ #                               NextDiT Configs                             #
+ #############################################################################
+
+
+ def NextDiT_2B_GQA_patch2_Adaln_Refiner(**kwargs):
+     return NextDiT(
+         patch_size=2,
+         dim=2304,
+         n_layers=26,
+         n_heads=24,
+         n_kv_heads=8,
+         axes_dims=[32, 32, 32],
+         axes_lens=[300, 512, 512],
+         **kwargs,
+     )
+
+
+ def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs):
+     return NextDiT(
+         patch_size=2,
+         dim=2592,
+         n_layers=30,
+         n_heads=24,
+         n_kv_heads=8,
+         axes_dims=[36, 36, 36],
+         axes_lens=[300, 512, 512],
+         **kwargs,
+     )
+
+
+ def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs):
+     return NextDiT(
+         patch_size=2,
+         dim=2880,
+         n_layers=32,
+         n_heads=24,
+         n_kv_heads=8,
+         axes_dims=[40, 40, 40],
+         axes_lens=[300, 512, 512],
+         **kwargs,
+     )
+
+
+ def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
+     return NextDiT(
+         patch_size=2,
+         dim=3840,
+         n_layers=32,
+         n_heads=32,
+         n_kv_heads=8,
+         axes_dims=[40, 40, 40],
+         axes_lens=[300, 512, 512],
+         **kwargs,
+     )
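
All four configs keep the invariant asserted in NextDiT.__init__: head_dim == sum(axes_dims), i.e. each attention head's dimension is split evenly across the three rotary axes (caption position, image row, image column). A hedged instantiation sketch (flash-attn must be installed for the import to succeed; the qk_norm and cap_feat_dim values below are assumptions, not values this file prescribes):

    model = NextDiT_2B_GQA_patch2_Adaln_Refiner(qk_norm=True, cap_feat_dim=2304)
    head_dim = model.dim // model.n_heads     # 2304 / 24 = 96
    assert head_dim == sum(model.axes_dims)   # 32 + 32 + 32 = 96
    print(f"{model.parameter_count() / 1e9:.2f}B parameters")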
parallel.py ADDED
@@ -0,0 +1,97 @@
+ #!/usr/bin/env python
+
+ import os
+ import subprocess
+ from time import sleep
+
+ import fairscale.nn.model_parallel.initialize as fs_init
+ import torch
+ import torch.distributed as dist
+ from datetime import timedelta
+
+
+ def _setup_dist_env_from_slurm(args):
+     while not os.environ.get("MASTER_ADDR", ""):
+         os.environ["MASTER_ADDR"] = (
+             subprocess.check_output(
+                 "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"],
+                 shell=True,
+             )
+             .decode()
+             .strip()
+         )
+         sleep(1)
+     if not os.environ.get("MASTER_PORT"):
+         os.environ["MASTER_PORT"] = str(args.master_port)
+     if not os.environ.get("WORLD_SIZE"):
+         os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"]
+     if not os.environ.get("RANK"):
+         os.environ["RANK"] = os.environ["SLURM_PROCID"]
+     if not os.environ.get("LOCAL_RANK"):
+         os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
+     if not os.environ.get("LOCAL_WORLD_SIZE"):
+         os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"]
+
+
+ _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None
+ _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1
+
+
+ def get_local_rank() -> int:
+     return _LOCAL_RANK
+
+
+ def get_local_world_size() -> int:
+     return _LOCAL_WORLD_SIZE
+
+
+ def distributed_init(args):
+     if any(x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]):
+         _setup_dist_env_from_slurm(args)
+
+     dist.init_process_group("nccl", timeout=timedelta(hours=5))
+     fs_init.initialize_model_parallel(args.model_parallel_size)
+     torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
+
+     global _LOCAL_RANK, _LOCAL_WORLD_SIZE
+     _LOCAL_RANK = int(os.environ["LOCAL_RANK"])
+     _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
+
+     global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP
+     local_ranks, local_world_sizes = [
+         torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1)
+     ]
+     dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda"))
+     dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda"))
+     local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist()
+     node_ranks = [[0]]
+     for i in range(1, dist.get_world_size()):
+         if len(node_ranks[-1]) == local_world_sizes[i - 1]:
+             node_ranks.append([])
+         else:
+             assert local_world_sizes[i] == local_world_sizes[i - 1]
+         node_ranks[-1].append(i)
+     for ranks in node_ranks:
+         group = dist.new_group(ranks)
+         if dist.get_rank() in ranks:
+             assert _INTRA_NODE_PROCESS_GROUP is None
+             _INTRA_NODE_PROCESS_GROUP = group
+     assert _INTRA_NODE_PROCESS_GROUP is not None
+
+     if min(local_world_sizes) == max(local_world_sizes):
+         for i in range(get_local_world_size()):
+             group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size())))
+             if i == get_local_rank():
+                 assert _INTER_NODE_PROCESS_GROUP is None
+                 _INTER_NODE_PROCESS_GROUP = group
+         assert _INTER_NODE_PROCESS_GROUP is not None
+
+
+ def get_intra_node_process_group():
+     assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized."
+     return _INTRA_NODE_PROCESS_GROUP
+
+
+ def get_inter_node_process_group():
+     assert _INTER_NODE_PROCESS_GROUP is not None, "Inter-node process group is not initialized."
+     return _INTER_NODE_PROCESS_GROUP
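
The intra- and inter-node groups are derived purely from the gathered local world sizes: every node contributes one contiguous block of ranks, and rank i of every node joins inter-node group i. The grouping arithmetic can be checked in plain Python (an assumed 8-rank job with 4 ranks per node; no torch.distributed needed):

    world_size, local_world_size = 8, 4
    intra = [list(range(i, i + local_world_size)) for i in range(0, world_size, local_world_size)]
    print(intra)  # [[0, 1, 2, 3], [4, 5, 6, 7]] -> one intra-node group per node
    inter = [list(range(i, world_size, local_world_size)) for i in range(local_world_size)]
    print(inter)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> one inter-node group per local rank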
transport/__init__.py ADDED
@@ -0,0 +1,70 @@
+ from .transport import ModelType, PathType, Sampler, Transport, WeightType
+
+
+ def create_transport(
+     path_type="Linear",
+     prediction="velocity",
+     loss_weight=None,
+     train_eps=None,
+     sample_eps=None,
+     snr_type="uniform",
+     do_shift=True,
+     seq_len=1024,  # corresponding to 512x512
+ ):
+     """Function for creating a Transport object.
+     **Note**: model prediction defaults to velocity
+     Args:
+     - path_type: type of path to use; defaults to "Linear"
+     - prediction: model prediction type: "velocity" (default), "noise", or "score"
+     - loss_weight: optional loss weighting: "velocity" or "likelihood"
+     - train_eps: small epsilon for avoiding instability during training
+     - sample_eps: small epsilon for avoiding instability during sampling
+     - snr_type: how timesteps/SNR are sampled during training
+     - do_shift: whether to apply the sequence-length-dependent timestep shift
+     - seq_len: token sequence length that the shift is calibrated for
+     """
+
+     if prediction == "noise":
+         model_type = ModelType.NOISE
+     elif prediction == "score":
+         model_type = ModelType.SCORE
+     else:
+         model_type = ModelType.VELOCITY
+
+     if loss_weight == "velocity":
+         loss_type = WeightType.VELOCITY
+     elif loss_weight == "likelihood":
+         loss_type = WeightType.LIKELIHOOD
+     else:
+         loss_type = WeightType.NONE
+
+     path_choice = {
+         "Linear": PathType.LINEAR,
+         "GVP": PathType.GVP,
+         "VP": PathType.VP,
+     }
+
+     path_type = path_choice[path_type]
+
+     if path_type in [PathType.VP]:
+         train_eps = 1e-5 if train_eps is None else train_eps
+         sample_eps = 1e-3 if sample_eps is None else sample_eps
+     elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
+         train_eps = 1e-3 if train_eps is None else train_eps
+         sample_eps = 1e-3 if sample_eps is None else sample_eps
+     else:  # velocity & [GVP, LINEAR] is stable everywhere
+         train_eps = 0
+         sample_eps = 0
+
+     # create flow state
+     state = Transport(
+         model_type=model_type,
+         path_type=path_type,
+         loss_type=loss_type,
+         train_eps=train_eps,
+         sample_eps=sample_eps,
+         snr_type=snr_type,
+         do_shift=do_shift,
+         seq_len=seq_len,
+     )
+
+     return state
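
A hedged usage sketch: the defaults already give a velocity-prediction linear-path transport, so a typical call mainly adjusts the sequence length. The value 4096 below assumes 2x2 patches on 64x64 latents (1024x1024 images), and the commented training_losses call is the customary entry point in this family of codebases, not something this file confirms:

    transport = create_transport(
        path_type="Linear",
        prediction="velocity",
        snr_type="uniform",
        do_shift=True,
        seq_len=4096,
    )
    # loss_dict = transport.training_losses(model, x1, model_kwargs)  # assumed API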
transport/dpm_solver.py ADDED
@@ -0,0 +1,1386 @@
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
18
+ import os
19
+
20
+ import torch
21
+ from tqdm import tqdm
22
+
23
+
24
+ class NoiseScheduleFlow:
25
+ def __init__(
26
+ self,
27
+ schedule="discrete_flow",
28
+ ):
29
+ """Create a wrapper class for the forward SDE (EDM type)."""
30
+ self.T = 1
31
+ self.t0 = 0.001
32
+ self.schedule = schedule # ['continuous', 'discrete_flow']
33
+ self.total_N = 1000
34
+
35
+ def marginal_log_mean_coeff(self, t):
36
+ """
37
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
38
+ """
39
+ return torch.log(self.marginal_alpha(t))
40
+
41
+ def marginal_alpha(self, t):
42
+ """
43
+ Compute alpha_t of a given continuous-time label t in [0, T].
44
+ """
45
+ return 1 - t
46
+
47
+ @staticmethod
48
+ def marginal_std(t):
49
+ """
50
+ Compute sigma_t of a given continuous-time label t in [0, T].
51
+ """
52
+ return t
53
+
54
+ def marginal_lambda(self, t):
55
+ """
56
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
57
+ """
58
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
59
+ log_std = torch.log(self.marginal_std(t))
60
+ return log_mean_coeff - log_std
61
+
62
+ @staticmethod
63
+ def inverse_lambda(lamb):
64
+ """
65
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
66
+ """
67
+ return torch.exp(-lamb)
68
+
69
+
70
+ def model_wrapper(
71
+ model,
72
+ noise_schedule,
73
+ model_type="noise",
74
+ model_kwargs={},
75
+ guidance_type="uncond",
76
+ condition=None,
77
+ unconditional_condition=None,
78
+ guidance_scale=1.0,
79
+ interval_guidance=[0, 1.0],
80
+ classifier_fn=None,
81
+ classifier_kwargs={},
82
+ ):
83
+ """Create a wrapper function for the noise prediction model.
84
+
85
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
86
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
87
+
88
+ We support four types of the diffusion model by setting `model_type`:
89
+
90
+ 1. "noise": noise prediction model. (Trained by predicting noise).
91
+
92
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
93
+
94
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
95
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
96
+
97
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
98
+ arXiv preprint arXiv:2202.00512 (2022).
99
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
100
+ arXiv preprint arXiv:2210.02303 (2022).
101
+
102
+ 4. "score": marginal score function. (Trained by denoising score matching).
103
+ Note that the score function and the noise prediction model follows a simple relationship:
104
+ ```
105
+ noise(x_t, t) = -sigma_t * score(x_t, t)
106
+ ```
107
+
108
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
109
+ 1. "uncond": unconditional sampling by DPMs.
110
+ The input `model` has the following format:
111
+ ``
112
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
113
+ ``
114
+
115
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
116
+ The input `model` has the following format:
117
+ ``
118
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
119
+ ``
120
+
121
+ The input `classifier_fn` has the following format:
122
+ ``
123
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
124
+ ``
125
+
126
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
127
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
128
+
129
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
130
+ The input `model` has the following format:
131
+ ``
132
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
133
+ ``
134
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
135
+
136
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
137
+ arXiv preprint arXiv:2207.12598 (2022).
138
+
139
+
140
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
141
+ or continuous-time labels (i.e. epsilon to T).
142
+
143
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
144
+ ``
145
+ def model_fn(x, t_continuous) -> noise:
146
+ t_input = get_model_input_time(t_continuous)
147
+ return noise_pred(model, x, t_input, **model_kwargs)
148
+ ``
149
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
150
+
151
+ ===============================================================
152
+
153
+ Args:
154
+ model: A diffusion model with the corresponding format described above.
155
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
156
+ model_type: A `str`. The parameterization type of the diffusion model.
157
+ "noise" or "x_start" or "v" or "score".
158
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
159
+ guidance_type: A `str`. The type of the guidance for sampling.
160
+ "uncond" or "classifier" or "classifier-free".
161
+ condition: A pytorch tensor. The condition for the guided sampling.
162
+ Only used for "classifier" or "classifier-free" guidance type.
163
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
164
+ Only used for "classifier-free" guidance type.
165
+ guidance_scale: A `float`. The scale for the guided sampling.
166
+ classifier_fn: A classifier function. Only used for the classifier guidance.
167
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
168
+ Returns:
169
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
170
+ """
171
+
172
+ def get_model_input_time(t_continuous):
173
+ """
174
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
175
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
176
+ For continuous-time DPMs, we just use `t_continuous`.
177
+ """
178
+ if noise_schedule.schedule == "discrete":
179
+ return (t_continuous - 1.0 / noise_schedule.total_N) * noise_schedule.total_N
180
+ elif noise_schedule.schedule == "discrete_flow":
181
+ return t_continuous * noise_schedule.total_N
182
+ else:
183
+ return t_continuous
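+ # Example: for a discrete-time DPM with total_N == 1000, the mapping above sends
+ # t_continuous == 1 / 1000 to t_input == 0 and t_continuous == 1.0 to t_input == 999,
+ # i.e. it recovers the training-time labels 0, ..., 999.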
184
+
185
+ def noise_pred_fn(x, t_continuous, cond=None):
186
+ t_input = get_model_input_time(t_continuous)
187
+ if cond is None:
188
+ output = model(x, t_input, **model_kwargs)
189
+ else:
190
+ output = model(x, t_input, cond, **model_kwargs)
191
+ if model_type == "noise":
192
+ return output
193
+ elif model_type == "x_start":
194
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
195
+ return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
196
+ elif model_type == "v":
197
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
198
+ return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
199
+ elif model_type == "score":
200
+ sigma_t = noise_schedule.marginal_std(t_continuous)
201
+ return -expand_dims(sigma_t, x.dim()) * output
202
+ elif model_type == "flow":
203
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ # Some models return a (prediction, extras) tuple; keep only the prediction.
+ if isinstance(output, (tuple, list)):
+ output = output[0]
+ noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output + x
+ return noise
209
+
210
+ def cond_grad_fn(x, t_input):
211
+ """
212
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
213
+ """
214
+ with torch.enable_grad():
215
+ x_in = x.detach().requires_grad_(True)
216
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
217
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
218
+
219
+ def model_fn(x, t_continuous):
220
+ """
221
+ The noise prediction model function that is used for DPM-Solver.
222
+ """
223
+ guidance_tp = guidance_type
224
+ if guidance_tp == "uncond":
225
+ return noise_pred_fn(x, t_continuous)
226
+ elif guidance_tp == "classifier":
227
+ assert classifier_fn is not None
228
+ t_input = get_model_input_time(t_continuous)
229
+ cond_grad = cond_grad_fn(x, t_input)
230
+ sigma_t = noise_schedule.marginal_std(t_continuous)
231
+ noise = noise_pred_fn(x, t_continuous)
232
+ return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
233
+ elif guidance_tp == "classifier-free":
234
+ if (
235
+ guidance_scale == 1.0
236
+ or unconditional_condition is None
237
+ or not (interval_guidance[0] < t_continuous[0] < interval_guidance[1])
238
+ ):
239
+ return noise_pred_fn(x, t_continuous, cond=condition)
240
+ else:
241
+ x_in = torch.cat([x] * 2)
242
+ t_in = torch.cat([t_continuous] * 2)
243
+ c_in = torch.cat([unconditional_condition, condition])
244
+ model_out = noise_pred_fn(x_in, t_in, cond=c_in)
+ # Some models return a (prediction, extras) tuple; keep only the prediction.
+ if isinstance(model_out, (tuple, list)):
+ model_out = model_out[0]
+ noise_uncond, noise = model_out.chunk(2)
248
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
249
+
250
+ assert model_type in ["noise", "x_start", "v", "score", "flow"]
251
+ assert guidance_type in [
252
+ "uncond",
253
+ "classifier",
254
+ "classifier-free",
255
+ ]
256
+ return model_fn
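+
+ # Usage sketch (illustrative only; `my_unet`, `betas`, `cond_emb` and `uncond_emb` are
+ # placeholder names, not part of this file): wrap an epsilon-prediction UNet for
+ # classifier-free guidance and sample with DPM-Solver++.
+ #
+ # ns = NoiseScheduleVP(schedule="discrete", betas=betas)
+ # model_fn = model_wrapper(
+ #     my_unet, ns, model_type="noise", guidance_type="classifier-free",
+ #     condition=cond_emb, unconditional_condition=uncond_emb, guidance_scale=7.5,
+ # )
+ # dpm_solver = DPM_Solver(model_fn, ns, algorithm_type="dpmsolver++")
+ # x0 = dpm_solver.sample(x_T, steps=20, order=2, skip_type="time_uniform", method="multistep")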
257
+
258
+
259
+ class DPM_Solver:
260
+ def __init__(
261
+ self,
262
+ model_fn,
263
+ noise_schedule,
264
+ algorithm_type="dpmsolver++",
265
+ correcting_x0_fn=None,
266
+ correcting_xt_fn=None,
267
+ thresholding_max_val=1.0,
268
+ dynamic_thresholding_ratio=0.995,
269
+ ):
270
+ """Construct a DPM-Solver.
271
+
272
+ We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
273
+
274
+ We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
275
+ can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
276
+ dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
277
+ DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
278
+ DPMs (such as stable-diffusion).
279
+
280
+ To support advanced algorithms in image-to-image applications, we also support corrector functions for
281
+ both x0 and xt.
282
+
283
+ Args:
284
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
285
+ ``
286
+ def model_fn(x, t_continuous):
287
+ return noise
288
+ ``
289
+ The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
290
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
291
+ algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
292
+ correcting_x0_fn: A `str` or a function with the following format:
293
+ ```
294
+ def correcting_x0_fn(x0, t):
295
+ x0_new = ...
296
+ return x0_new
297
+ ```
298
+ This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
299
+ ```
300
+ x0_pred = data_pred_model(xt, t)
301
+ if correcting_x0_fn is not None:
302
+ x0_pred = correcting_x0_fn(x0_pred, t)
303
+ xt_1 = update(x0_pred, xt, t)
304
+ ```
305
+ If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
306
+ correcting_xt_fn: A function with the following format:
307
+ ```
308
+ def correcting_xt_fn(xt, t, step):
309
+ x_new = ...
310
+ return x_new
311
+ ```
312
+ This function is to correct the intermediate samples xt at each sampling step. e.g.,
313
+ ```
314
+ xt = ...
315
+ xt = correcting_xt_fn(xt, t, step)
316
+ ```
317
+ thresholding_max_val: A `float`. The max value for thresholding.
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
+ dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
+ Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
321
+
322
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
323
+ Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
324
+ with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
325
+ """
326
+ self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
327
+ self.noise_schedule = noise_schedule
328
+ assert algorithm_type in ["dpmsolver", "dpmsolver++"]
329
+ self.algorithm_type = algorithm_type
330
+ if correcting_x0_fn == "dynamic_thresholding":
331
+ self.correcting_x0_fn = self.dynamic_thresholding_fn
332
+ else:
333
+ self.correcting_x0_fn = correcting_x0_fn
334
+ self.correcting_xt_fn = correcting_xt_fn
335
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
336
+ self.thresholding_max_val = thresholding_max_val
337
+ self.register_progress_bar()
338
+
339
+ def register_progress_bar(self, progress_fn=None):
340
+ """
341
+ Register a progress bar callback function
342
+
343
+ Args:
344
+ progress_fn: Callback function that takes current step and total steps as parameters
345
+ """
346
+ self.progress_fn = progress_fn if progress_fn is not None else lambda step, total: None
347
+
348
+ def update_progress(self, step, total_steps):
349
+ """
350
+ Update sampling progress
351
+
352
+ Args:
353
+ step: Current step number
354
+ total_steps: Total number of steps
355
+ """
356
+ if hasattr(self, "progress_fn"):
+ try:
+ self.progress_fn(step / total_steps, desc=f"Generating {step}/{total_steps}")
+ except TypeError:
+ # Fall back to the plain (step, total_steps) callback signature.
+ self.progress_fn(step, total_steps)
365
+
366
+ def dynamic_thresholding_fn(self, x0, t):
367
+ """
368
+ The dynamic thresholding method.
369
+ """
370
+ dims = x0.dim()
371
+ p = self.dynamic_thresholding_ratio
372
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
373
+ s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
374
+ x0 = torch.clamp(x0, -s, s) / s
375
+ return x0
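+
+ # In words: with dynamic_thresholding_ratio == 0.995, `s` is the per-sample 99.5th percentile
+ # of |x0| (floored at thresholding_max_val), so clamping to [-s, s] and dividing by s keeps x0
+ # within [-1, 1] while only saturating the most extreme 0.5% of values.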
376
+
377
+ def noise_prediction_fn(self, x, t):
378
+ """
379
+ Return the noise prediction of the model.
380
+ """
381
+ return self.model(x, t)
382
+
383
+ def data_prediction_fn(self, x, t):
384
+ """
385
+ Return the data prediction of the model (with corrector).
386
+ """
387
+ noise = self.noise_prediction_fn(x, t)
388
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
389
+ x0 = (x - sigma_t * noise) / alpha_t
390
+ if self.correcting_x0_fn is not None:
391
+ x0 = self.correcting_x0_fn(x0, t)
392
+ return x0
393
+
394
+ def model_fn(self, x, t):
395
+ """
396
+ Convert the model to the noise prediction model or the data prediction model.
397
+ """
398
+ if self.algorithm_type == "dpmsolver++":
399
+ return self.data_prediction_fn(x, t)
400
+ else:
401
+ return self.noise_prediction_fn(x, t)
402
+
403
+ def get_time_steps(self, skip_type, t_T, t_0, N, device, shift=1.0):
404
+ """Compute the intermediate time steps for sampling.
405
+
406
+ Args:
407
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
408
+ - 'logSNR': uniform logSNR for the time steps.
409
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
410
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
411
+ t_T: A `float`. The starting time of the sampling (default is T).
412
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
413
+ N: An `int`. The number of time-step intervals (the function returns N + 1 time points).
414
+ device: A torch device.
415
+ Returns:
416
+ A pytorch tensor of the time steps, with the shape (N + 1,).
417
+ """
418
+ if skip_type == "logSNR":
419
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
420
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
421
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
422
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
423
+ elif skip_type == "time_uniform":
424
+ return torch.linspace(t_T, t_0, N + 1).to(device)
425
+ elif skip_type == "time_quadratic":
426
+ t_order = 2
427
+ t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
428
+ return t
429
+ elif skip_type == "time_uniform_flow":
430
+ betas = torch.linspace(t_T, t_0, N + 1).to(device)
431
+ sigmas = 1.0 - betas
432
+ sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
433
+ return sigmas
434
+ else:
435
+ raise ValueError(
436
+ f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'"
437
+ )
438
+
439
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
440
+ """
441
+ Get the order of each step for sampling by the singlestep DPM-Solver.
442
+
443
+ We combine DPM-Solver-1, -2 and -3 to use up all the function evaluations; the combined scheme is named "DPM-Solver-fast".
444
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
445
+ - If order == 1:
446
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
447
+ - If order == 2:
448
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
449
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
450
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
451
+ - If order == 3:
452
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
453
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
454
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
455
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
456
+
457
+ ============================================
458
+ Args:
459
+ steps: An `int`. The total number of function evaluations (NFE).
+ order: An `int`. The max order for the solver (2 or 3).
461
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
462
+ - 'logSNR': uniform logSNR for the time steps.
463
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
464
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
465
+ t_T: A `float`. The starting time of the sampling (default is T).
466
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
467
+ device: A torch device.
468
+ Returns:
469
+ timesteps_outer: A pytorch tensor of the outer time steps.
+ orders: A list of the solver order of each step.
470
+ """
471
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3] * (K - 1) + [1]
+ else:
+ orders = [3] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
503
+ if skip_type == "logSNR":
504
+ # To reproduce the results in DPM-Solver paper
505
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
506
+ else:
507
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
508
+ torch.cumsum(
509
+ torch.tensor(
510
+ [
511
+ 0,
512
+ ]
513
+ + orders
514
+ ),
515
+ 0,
516
+ ).to(device)
517
+ ]
518
+ return timesteps_outer, orders
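+
+ # Worked example: steps == 6 with order == 3 gives K == 3 and orders == [3, 2, 1], i.e. one
+ # step of DPM-Solver-3, one of DPM-Solver-2 and one of DPM-Solver-1, spending exactly
+ # 3 + 2 + 1 == 6 function evaluations.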
519
+
520
+ def denoise_to_zero_fn(self, x, s):
521
+ """
522
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity with a first-order discretization.
523
+ """
524
+ return self.data_prediction_fn(x, s)
525
+
526
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
527
+ """
528
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
529
+
530
+ Args:
531
+ x: A pytorch tensor. The initial value at time `s`.
532
+ s: A pytorch tensor. The starting time, with the shape (1,).
533
+ t: A pytorch tensor. The ending time, with the shape (1,).
534
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
535
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
536
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
537
+ Returns:
538
+ x_t: A pytorch tensor. The approximated solution at time `t`.
539
+ """
540
+ ns = self.noise_schedule
541
+ dims = x.dim()
542
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
543
+ h = lambda_t - lambda_s
544
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
545
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
546
+ alpha_t = torch.exp(log_alpha_t)
547
+
548
+ if self.algorithm_type == "dpmsolver++":
549
+ phi_1 = torch.expm1(-h)
550
+ if model_s is None:
551
+ model_s = self.model_fn(x, s)
552
+ x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
553
+ if return_intermediate:
554
+ return x_t, {"model_s": model_s}
555
+ else:
556
+ return x_t
557
+ else:
558
+ phi_1 = torch.expm1(h)
559
+ if model_s is None:
560
+ model_s = self.model_fn(x, s)
561
+ x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
562
+ if return_intermediate:
563
+ return x_t, {"model_s": model_s}
564
+ else:
565
+ return x_t
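+
+ # In lambda (half-logSNR) space the dpmsolver++ branch above reads, with h = lambda_t - lambda_s
+ # and phi_1 = expm1(-h):
+ #     x_t = (sigma_t / sigma_s) * x_s - alpha_t * phi_1 * x0_theta(x_s, s),
+ # the first-order exponential-integrator step for the data-prediction ODE (equivalent to DDIM).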
566
+
567
+ def singlestep_dpm_solver_second_update(
568
+ self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"
569
+ ):
570
+ """
571
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
572
+
573
+ Args:
574
+ x: A pytorch tensor. The initial value at time `s`.
575
+ s: A pytorch tensor. The starting time, with the shape (1,).
576
+ t: A pytorch tensor. The ending time, with the shape (1,).
577
+ r1: A `float`. The hyperparameter of the second-order solver.
578
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
579
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
580
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
581
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
582
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
583
+ Returns:
584
+ x_t: A pytorch tensor. The approximated solution at time `t`.
585
+ """
586
+ if solver_type not in ["dpmsolver", "taylor"]:
587
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
588
+ if r1 is None:
589
+ r1 = 0.5
590
+ ns = self.noise_schedule
591
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
592
+ h = lambda_t - lambda_s
593
+ lambda_s1 = lambda_s + r1 * h
594
+ s1 = ns.inverse_lambda(lambda_s1)
595
+ log_alpha_s, log_alpha_s1, log_alpha_t = (
596
+ ns.marginal_log_mean_coeff(s),
597
+ ns.marginal_log_mean_coeff(s1),
598
+ ns.marginal_log_mean_coeff(t),
599
+ )
600
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
601
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
602
+
603
+ if self.algorithm_type == "dpmsolver++":
604
+ phi_11 = torch.expm1(-r1 * h)
605
+ phi_1 = torch.expm1(-h)
606
+
607
+ if model_s is None:
608
+ model_s = self.model_fn(x, s)
609
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
610
+ model_s1 = self.model_fn(x_s1, s1)
611
+ if solver_type == "dpmsolver":
612
+ x_t = (
613
+ (sigma_t / sigma_s) * x
614
+ - (alpha_t * phi_1) * model_s
615
+ - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
616
+ )
617
+ elif solver_type == "taylor":
618
+ x_t = (
619
+ (sigma_t / sigma_s) * x
620
+ - (alpha_t * phi_1) * model_s
621
+ + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
622
+ )
623
+ else:
624
+ phi_11 = torch.expm1(r1 * h)
625
+ phi_1 = torch.expm1(h)
626
+
627
+ if model_s is None:
628
+ model_s = self.model_fn(x, s)
629
+ x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
630
+ model_s1 = self.model_fn(x_s1, s1)
631
+ if solver_type == "dpmsolver":
632
+ x_t = (
633
+ torch.exp(log_alpha_t - log_alpha_s) * x
634
+ - (sigma_t * phi_1) * model_s
635
+ - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
636
+ )
637
+ elif solver_type == "taylor":
638
+ x_t = (
639
+ torch.exp(log_alpha_t - log_alpha_s) * x
640
+ - (sigma_t * phi_1) * model_s
641
+ - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
642
+ )
643
+ if return_intermediate:
644
+ return x_t, {"model_s": model_s, "model_s1": model_s1}
645
+ else:
646
+ return x_t
647
+
648
+ def singlestep_dpm_solver_third_update(
649
+ self,
650
+ x,
651
+ s,
652
+ t,
653
+ r1=1.0 / 3.0,
654
+ r2=2.0 / 3.0,
655
+ model_s=None,
656
+ model_s1=None,
657
+ return_intermediate=False,
658
+ solver_type="dpmsolver",
659
+ ):
660
+ """
661
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
662
+
663
+ Args:
664
+ x: A pytorch tensor. The initial value at time `s`.
665
+ s: A pytorch tensor. The starting time, with the shape (1,).
666
+ t: A pytorch tensor. The ending time, with the shape (1,).
667
+ r1: A `float`. The hyperparameter of the third-order solver.
668
+ r2: A `float`. The hyperparameter of the third-order solver.
669
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
670
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
671
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
672
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
673
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
674
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
675
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
676
+ Returns:
677
+ x_t: A pytorch tensor. The approximated solution at time `t`.
678
+ """
679
+ if solver_type not in ["dpmsolver", "taylor"]:
680
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
681
+ if r1 is None:
682
+ r1 = 1.0 / 3.0
683
+ if r2 is None:
684
+ r2 = 2.0 / 3.0
685
+ ns = self.noise_schedule
686
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
687
+ h = lambda_t - lambda_s
688
+ lambda_s1 = lambda_s + r1 * h
689
+ lambda_s2 = lambda_s + r2 * h
690
+ s1 = ns.inverse_lambda(lambda_s1)
691
+ s2 = ns.inverse_lambda(lambda_s2)
692
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
693
+ ns.marginal_log_mean_coeff(s),
694
+ ns.marginal_log_mean_coeff(s1),
695
+ ns.marginal_log_mean_coeff(s2),
696
+ ns.marginal_log_mean_coeff(t),
697
+ )
698
+ sigma_s, sigma_s1, sigma_s2, sigma_t = (
699
+ ns.marginal_std(s),
700
+ ns.marginal_std(s1),
701
+ ns.marginal_std(s2),
702
+ ns.marginal_std(t),
703
+ )
704
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
705
+
706
+ if self.algorithm_type == "dpmsolver++":
707
+ phi_11 = torch.expm1(-r1 * h)
708
+ phi_12 = torch.expm1(-r2 * h)
709
+ phi_1 = torch.expm1(-h)
710
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
711
+ phi_2 = phi_1 / h + 1.0
712
+ phi_3 = phi_2 / h - 0.5
713
+
714
+ if model_s is None:
715
+ model_s = self.model_fn(x, s)
716
+ if model_s1 is None:
717
+ x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
718
+ model_s1 = self.model_fn(x_s1, s1)
719
+ x_s2 = (
720
+ (sigma_s2 / sigma_s) * x
721
+ - (alpha_s2 * phi_12) * model_s
722
+ + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
723
+ )
724
+ model_s2 = self.model_fn(x_s2, s2)
725
+ if solver_type == "dpmsolver":
726
+ x_t = (
727
+ (sigma_t / sigma_s) * x
728
+ - (alpha_t * phi_1) * model_s
729
+ + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
730
+ )
731
+ elif solver_type == "taylor":
732
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
733
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
734
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
735
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
736
+ x_t = (
737
+ (sigma_t / sigma_s) * x
738
+ - (alpha_t * phi_1) * model_s
739
+ + (alpha_t * phi_2) * D1
740
+ - (alpha_t * phi_3) * D2
741
+ )
742
+ else:
743
+ phi_11 = torch.expm1(r1 * h)
744
+ phi_12 = torch.expm1(r2 * h)
745
+ phi_1 = torch.expm1(h)
746
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
747
+ phi_2 = phi_1 / h - 1.0
748
+ phi_3 = phi_2 / h - 0.5
749
+
750
+ if model_s is None:
751
+ model_s = self.model_fn(x, s)
752
+ if model_s1 is None:
753
+ x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
754
+ model_s1 = self.model_fn(x_s1, s1)
755
+ x_s2 = (
756
+ (torch.exp(log_alpha_s2 - log_alpha_s)) * x
757
+ - (sigma_s2 * phi_12) * model_s
758
+ - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
759
+ )
760
+ model_s2 = self.model_fn(x_s2, s2)
761
+ if solver_type == "dpmsolver":
762
+ x_t = (
763
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
764
+ - (sigma_t * phi_1) * model_s
765
+ - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
766
+ )
767
+ elif solver_type == "taylor":
768
+ D1_0 = (1.0 / r1) * (model_s1 - model_s)
769
+ D1_1 = (1.0 / r2) * (model_s2 - model_s)
770
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
771
+ D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
772
+ x_t = (
773
+ (torch.exp(log_alpha_t - log_alpha_s)) * x
774
+ - (sigma_t * phi_1) * model_s
775
+ - (sigma_t * phi_2) * D1
776
+ - (sigma_t * phi_3) * D2
777
+ )
778
+
779
+ if return_intermediate:
780
+ return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
781
+ else:
782
+ return x_t
783
+
784
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
785
+ """
786
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
787
+
788
+ Args:
789
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times; each has the shape (1,).
792
+ t: A pytorch tensor. The ending time, with the shape (1,).
793
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
794
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
795
+ Returns:
796
+ x_t: A pytorch tensor. The approximated solution at time `t`.
797
+ """
798
+ if solver_type not in ["dpmsolver", "taylor"]:
799
+ raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
800
+ ns = self.noise_schedule
801
+ model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
802
+ t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
803
+ lambda_prev_1, lambda_prev_0, lambda_t = (
804
+ ns.marginal_lambda(t_prev_1),
805
+ ns.marginal_lambda(t_prev_0),
806
+ ns.marginal_lambda(t),
807
+ )
808
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
809
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
810
+ alpha_t = torch.exp(log_alpha_t)
811
+
812
+ h_0 = lambda_prev_0 - lambda_prev_1
813
+ h = lambda_t - lambda_prev_0
814
+ r0 = h_0 / h
815
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
816
+ if self.algorithm_type == "dpmsolver++":
817
+ phi_1 = torch.expm1(-h)
818
+ if solver_type == "dpmsolver":
819
+ x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
820
+ elif solver_type == "taylor":
821
+ x_t = (
822
+ (sigma_t / sigma_prev_0) * x
823
+ - (alpha_t * phi_1) * model_prev_0
824
+ + (alpha_t * (phi_1 / h + 1.0)) * D1_0
825
+ )
826
+ else:
827
+ phi_1 = torch.expm1(h)
828
+ if solver_type == "dpmsolver":
829
+ x_t = (
830
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
831
+ - (sigma_t * phi_1) * model_prev_0
832
+ - 0.5 * (sigma_t * phi_1) * D1_0
833
+ )
834
+ elif solver_type == "taylor":
835
+ x_t = (
836
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
837
+ - (sigma_t * phi_1) * model_prev_0
838
+ - (sigma_t * (phi_1 / h - 1.0)) * D1_0
839
+ )
840
+ return x_t
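+
+ # D1_0 above rescales the difference of the two cached model outputs by r0 = h_0 / h, giving a
+ # finite-difference estimate of h * d(model)/d(lambda) at lambda_prev_0; reusing the cached
+ # evaluation is what keeps the multistep solver at one model call per step.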
841
+
842
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
843
+ """
844
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
845
+
846
+ Args:
847
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times; each has the shape (1,).
850
+ t: A pytorch tensor. The ending time, with the shape (1,).
851
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
852
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
853
+ Returns:
854
+ x_t: A pytorch tensor. The approximated solution at time `t`.
855
+ """
856
+ ns = self.noise_schedule
857
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
858
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
859
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
860
+ ns.marginal_lambda(t_prev_2),
861
+ ns.marginal_lambda(t_prev_1),
862
+ ns.marginal_lambda(t_prev_0),
863
+ ns.marginal_lambda(t),
864
+ )
865
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
866
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
867
+ alpha_t = torch.exp(log_alpha_t)
868
+
869
+ h_1 = lambda_prev_1 - lambda_prev_2
870
+ h_0 = lambda_prev_0 - lambda_prev_1
871
+ h = lambda_t - lambda_prev_0
872
+ r0, r1 = h_0 / h, h_1 / h
873
+ D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
874
+ D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
875
+ D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
876
+ D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
877
+ if self.algorithm_type == "dpmsolver++":
878
+ phi_1 = torch.expm1(-h)
879
+ phi_2 = phi_1 / h + 1.0
880
+ phi_3 = phi_2 / h - 0.5
881
+ x_t = (
882
+ (sigma_t / sigma_prev_0) * x
883
+ - (alpha_t * phi_1) * model_prev_0
884
+ + (alpha_t * phi_2) * D1
885
+ - (alpha_t * phi_3) * D2
886
+ )
887
+ else:
888
+ phi_1 = torch.expm1(h)
889
+ phi_2 = phi_1 / h - 1.0
890
+ phi_3 = phi_2 / h - 0.5
891
+ x_t = (
892
+ (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
893
+ - (sigma_t * phi_1) * model_prev_0
894
+ - (sigma_t * phi_2) * D1
895
+ - (sigma_t * phi_3) * D2
896
+ )
897
+ return x_t
898
+
899
+ def singlestep_dpm_solver_update(
900
+ self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None
901
+ ):
902
+ """
903
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
904
+
905
+ Args:
906
+ x: A pytorch tensor. The initial value at time `s`.
907
+ s: A pytorch tensor. The starting time, with the shape (1,).
908
+ t: A pytorch tensor. The ending time, with the shape (1,).
909
+ order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
910
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
911
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
912
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
913
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
914
+ r2: A `float`. The hyperparameter of the third-order solver.
915
+ Returns:
916
+ x_t: A pytorch tensor. The approximated solution at time `t`.
917
+ """
918
+ if order == 1:
919
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
920
+ elif order == 2:
921
+ return self.singlestep_dpm_solver_second_update(
922
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
923
+ )
924
+ elif order == 3:
925
+ return self.singlestep_dpm_solver_third_update(
926
+ x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
927
+ )
928
+ else:
929
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
930
+
931
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
932
+ """
933
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
934
+
935
+ Args:
936
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times; each has the shape (1,).
939
+ t: A pytorch tensor. The ending time, with the shape (1,).
940
+ order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
941
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
942
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
943
+ Returns:
944
+ x_t: A pytorch tensor. The approximated solution at time `t`.
945
+ """
946
+ if order == 1:
947
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
948
+ elif order == 2:
949
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
950
+ elif order == 3:
951
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
952
+ else:
953
+ raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")
954
+
955
+ def dpm_solver_adaptive(
956
+ self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"
957
+ ):
958
+ """
959
+ The adaptive step size solver based on singlestep DPM-Solver.
960
+
961
+ Args:
962
+ x: A pytorch tensor. The initial value at time `t_T`.
963
+ order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
964
+ t_T: A `float`. The starting time of the sampling (default is T).
965
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
966
+ h_init: A `float`. The initial step size (for logSNR).
967
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
968
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
969
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
970
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
971
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
972
+ solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
973
+ The type slightly impacts the performance. We recommend the 'dpmsolver' type.
974
+ Returns:
975
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
976
+
977
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
978
+ """
979
+ ns = self.noise_schedule
980
+ s = t_T * torch.ones((1,)).to(x)
981
+ lambda_s = ns.marginal_lambda(s)
982
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
983
+ h = h_init * torch.ones_like(s).to(x)
984
+ x_prev = x
985
+ nfe = 0
986
+ if order == 2:
987
+ r1 = 0.5
988
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
989
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
990
+ x, s, t, r1=r1, solver_type=solver_type, **kwargs
991
+ )
992
+ elif order == 3:
993
+ r1, r2 = 1.0 / 3.0, 2.0 / 3.0
994
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
995
+ x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
996
+ )
997
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
998
+ x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
999
+ )
1000
+ else:
1001
+ raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
1002
+ while torch.abs(s - t_0).mean() > t_err:
1003
+ t = ns.inverse_lambda(lambda_s + h)
1004
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
1005
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1006
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1007
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1008
+ E = norm_fn((x_higher - x_lower) / delta).max()
1009
+ if torch.all(E <= 1.0):
1010
+ x = x_higher
1011
+ s = t
1012
+ x_prev = x_lower
1013
+ lambda_s = ns.marginal_lambda(s)
1014
+ h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
1015
+ nfe += order
1016
+ print("adaptive solver nfe", nfe)
1017
+ return x
1018
+
1019
+ def add_noise(self, x, t, noise=None):
1020
+ """
1021
+ Compute the noised input xt = alpha_t * x + sigma_t * noise.
1022
+
1023
+ Args:
1024
+ x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1025
+ t: A `torch.Tensor` with shape `(t_size,)`.
+ noise: An optional `torch.Tensor` with shape `(t_size, batch_size, *shape)`; if None, it is sampled from N(0, I).
+ Returns:
+ xt with shape `(t_size, batch_size, *shape)`, squeezed to `(batch_size, *shape)` when t_size == 1.
1028
+ """
1029
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1030
+ if noise is None:
1031
+ noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1032
+ x = x.reshape((-1, *x.shape))
1033
+ xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1034
+ if t.shape[0] == 1:
1035
+ return xt.squeeze(0)
1036
+ else:
1037
+ return xt
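+
+ # Usage sketch: for x of shape (B, C, H, W) and t of shape (3,), add_noise returns a tensor of
+ # shape (3, B, C, H, W) holding x noised to the three times; for a single time (t_size == 1)
+ # the leading dimension is squeezed away.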
1038
+
1039
+ def inverse(
1040
+ self,
1041
+ x,
1042
+ steps=20,
1043
+ t_start=None,
1044
+ t_end=None,
1045
+ order=2,
1046
+ skip_type="time_uniform",
1047
+ method="multistep",
1048
+ lower_order_final=True,
1049
+ denoise_to_zero=False,
1050
+ solver_type="dpmsolver",
1051
+ atol=0.0078,
1052
+ rtol=0.05,
1053
+ return_intermediate=False,
1054
+ ):
1055
+ """
1056
+ Invert the sample `x` from time `t_start` to time `t_end` by DPM-Solver.
1057
+ For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
1058
+ """
1059
+ t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
1060
+ t_T = self.noise_schedule.T if t_end is None else t_end
1061
+ assert (
1062
+ t_0 > 0 and t_T > 0
1063
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1064
+ return self.sample(
1065
+ x,
1066
+ steps=steps,
1067
+ t_start=t_0,
1068
+ t_end=t_T,
1069
+ order=order,
1070
+ skip_type=skip_type,
1071
+ method=method,
1072
+ lower_order_final=lower_order_final,
1073
+ denoise_to_zero=denoise_to_zero,
1074
+ solver_type=solver_type,
1075
+ atol=atol,
1076
+ rtol=rtol,
1077
+ return_intermediate=return_intermediate,
1078
+ )
1079
+
1080
+ def sample(
1081
+ self,
1082
+ x,
1083
+ steps=20,
1084
+ t_start=None,
1085
+ t_end=None,
1086
+ order=2,
1087
+ skip_type="time_uniform",
1088
+ method="multistep",
1089
+ lower_order_final=True,
1090
+ denoise_to_zero=False,
1091
+ solver_type="dpmsolver",
1092
+ atol=0.0078,
1093
+ rtol=0.05,
1094
+ return_intermediate=False,
1095
+ flow_shift=1.0,
1096
+ ):
1097
+ """
1098
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1099
+
1100
+ =====================================================
1101
+
1102
+ We support the following algorithms for both noise prediction model and data prediction model:
1103
+ - 'singlestep':
1104
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1105
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1106
+ The total number of function evaluations (NFE) == `steps`.
1107
+ Given a fixed NFE == `steps`, the sampling procedure is:
1108
+ - If `order` == 1:
1109
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1110
+ - If `order` == 2:
1111
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1112
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1113
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1114
+ - If `order` == 3:
1115
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1116
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1117
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1118
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1119
+ - 'multistep':
1120
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1121
+ We initialize the first `order` values by lower order multistep solvers.
1122
+ Given a fixed NFE == `steps`, the sampling procedure is:
1123
+ Denote K = steps.
1124
+ - If `order` == 1:
1125
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
1126
+ - If `order` == 2:
1127
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1130
+ - 'singlestep_fixed':
1131
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1132
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1133
+ - 'adaptive':
1134
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1135
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1136
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation
+ costs (NFE) and the sample quality.
1138
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1139
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1140
+
1141
+ =====================================================
1142
+
1143
+ Some advice on choosing the algorithm:
1144
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1145
+ Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1146
+ e.g., DPM-Solver:
1147
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1148
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1149
+ skip_type='time_uniform', method='singlestep')
1150
+ e.g., DPM-Solver++:
1151
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1152
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1153
+ skip_type='time_uniform', method='singlestep')
1154
+ - For **guided sampling with large guidance scale** by DPMs:
1155
+ Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1156
+ e.g.
1157
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1158
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1159
+ skip_type='time_uniform', method='multistep')
1160
+
1161
+ We support three types of `skip_type`:
1162
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images.**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images.**
1164
+ - 'time_quadratic': quadratic time for the time steps.
1165
+
1166
+ =====================================================
1167
+ Args:
1168
+ x: A pytorch tensor. The initial value at time `t_start`
1169
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1170
+ steps: An `int`. The total number of function evaluations (NFE).
1171
+ t_start: A `float`. The starting time of the sampling.
1172
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1173
+ t_end: A `float`. The ending time of the sampling.
1174
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1175
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1176
+ For discrete-time DPMs:
1177
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1178
+ For continuous-time DPMs:
1179
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1180
+ order: An `int`. The order of DPM-Solver.
1181
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1182
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1183
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1184
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1185
+
1186
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when sampling
+ diffusion models by diffusion SDEs for low-resolution images (such as CIFAR-10).
+ However, we observed that this trick does not matter for high-resolution images.
+ As it needs an additional NFE, we do not recommend it for high-resolution images.
1192
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1193
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1194
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1195
+ (especially for steps <= 10). So we recommend setting it to `True`.
1196
+ solver_type: A `str`. The Taylor expansion type for the solver, `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1197
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1198
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1199
+ return_intermediate: A `bool`. Whether to save the xt at each step.
1200
+ When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
1201
+ Returns:
1202
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1203
+
1204
+ """
1205
+ t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
1206
+ t_T = self.noise_schedule.T if t_start is None else t_start
1207
+ assert (
1208
+ t_0 > 0 and t_T > 0
1209
+ ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1210
+ if return_intermediate:
1211
+ assert method in [
1212
+ "multistep",
1213
+ "singlestep",
1214
+ "singlestep_fixed",
1215
+ ], "Cannot use adaptive solver when saving intermediate values"
1216
+ if self.correcting_xt_fn is not None:
1217
+ assert method in [
1218
+ "multistep",
1219
+ "singlestep",
1220
+ "singlestep_fixed",
1221
+ ], "Cannot use adaptive solver when correcting_xt_fn is not None"
1222
+ device = x.device
1223
+ intermediates = []
1224
+ with torch.no_grad():
1225
+ if method == "adaptive":
1226
+ x = self.dpm_solver_adaptive(
1227
+ x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
1228
+ )
1229
+ elif method == "multistep":
1230
+ assert steps >= order
1231
+ timesteps = self.get_time_steps(
1232
+ skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device, shift=flow_shift
1233
+ )
1234
+ assert timesteps.shape[0] - 1 == steps
1235
+ # Init the initial values.
1236
+ step = 0
1237
+ t = timesteps[step]
1238
+ t_prev_list = [t]
1239
+ model_prev_list = [self.model_fn(x, t)]
1240
+ if self.correcting_xt_fn is not None:
1241
+ x = self.correcting_xt_fn(x, t, step)
1242
+ if return_intermediate:
1243
+ intermediates.append(x)
1244
+ self.update_progress(step + 1, len(timesteps))
1245
+ # Init the first `order` values by lower order multistep DPM-Solver.
1246
+ for step in range(1, order):
1247
+ t = timesteps[step]
1248
+ x = self.multistep_dpm_solver_update(
1249
+ x, model_prev_list, t_prev_list, t, step, solver_type=solver_type
1250
+ )
1251
+ if self.correcting_xt_fn is not None:
1252
+ x = self.correcting_xt_fn(x, t, step)
1253
+ if return_intermediate:
1254
+ intermediates.append(x)
1255
+ t_prev_list.append(t)
1256
+ model_prev_list.append(self.model_fn(x, t))
1257
+ # update progress bar
1258
+ self.update_progress(step + 1, len(timesteps))
1259
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1260
+ for step in tqdm(range(order, steps + 1), disable=os.getenv("DPM_TQDM", "False") == "True"):
1261
+ t = timesteps[step]
1262
+ # Originally the lower order was used only for steps < 10 (`if lower_order_final and steps < 10:`);
+ # always honoring `lower_order_final` was recommended by Shuchen Xue.
+ if lower_order_final:
1265
+ step_order = min(order, steps + 1 - step)
1266
+ else:
1267
+ step_order = order
1268
+ x = self.multistep_dpm_solver_update(
1269
+ x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type
1270
+ )
1271
+ if self.correcting_xt_fn is not None:
1272
+ x = self.correcting_xt_fn(x, t, step)
1273
+ if return_intermediate:
1274
+ intermediates.append(x)
1275
+ for i in range(order - 1):
1276
+ t_prev_list[i] = t_prev_list[i + 1]
1277
+ model_prev_list[i] = model_prev_list[i + 1]
1278
+ t_prev_list[-1] = t
1279
+ # We do not need to evaluate the final model value.
1280
+ if step < steps:
1281
+ model_prev_list[-1] = self.model_fn(x, t)
1282
+ # update progress bar
1283
+ self.update_progress(step + 1, len(timesteps))
1284
+ elif method in ["singlestep", "singlestep_fixed"]:
1285
+ if method == "singlestep":
1286
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
1287
+ steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device
1288
+ )
1289
+ elif method == "singlestep_fixed":
1290
+ K = steps // order
1291
+ orders = [
1292
+ order,
1293
+ ] * K
1294
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1295
+ for step, order in enumerate(orders):
1296
+ s, t = timesteps_outer[step], timesteps_outer[step + 1]
1297
+ timesteps_inner = self.get_time_steps(
1298
+ skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device
1299
+ )
1300
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1301
+ h = lambda_inner[-1] - lambda_inner[0]
1302
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1303
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1304
+ x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1305
+ if self.correcting_xt_fn is not None:
1306
+ x = self.correcting_xt_fn(x, t, step)
1307
+ if return_intermediate:
1308
+ intermediates.append(x)
1309
+ self.update_progress(step + 1, len(timesteps_outer))
1310
+ else:
1311
+ raise ValueError(f"Got wrong method {method}")
1312
+ if denoise_to_zero:
1313
+ t = torch.ones((1,)).to(device) * t_0
1314
+ x = self.denoise_to_zero_fn(x, t)
1315
+ if self.correcting_xt_fn is not None:
1316
+ x = self.correcting_xt_fn(x, t, step + 1)
1317
+ if return_intermediate:
1318
+ intermediates.append(x)
1319
+ if return_intermediate:
1320
+ return x, intermediates
1321
+ else:
1322
+ return x
1323
+
1324
+
1325
+ #############################################################
1326
+ # other utility functions
1327
+ #############################################################
1328
+
1329
+
1330
+ def interpolate_fn(x, xp, yp):
1331
+ """
1332
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1333
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1334
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the two outermost keypoints of xp to define the linear function.)
1335
+
1336
+ Args:
1337
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1338
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1339
+ yp: PyTorch tensor with shape [C, K].
1340
+ Returns:
1341
+ The function values f(x), with shape [N, C].
1342
+ """
1343
+ N, K = x.shape[0], xp.shape[1]
1344
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1345
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1346
+ x_idx = torch.argmin(x_indices, dim=2)
1347
+ cand_start_idx = x_idx - 1
1348
+ start_idx = torch.where(
1349
+ torch.eq(x_idx, 0),
1350
+ torch.tensor(1, device=x.device),
1351
+ torch.where(
1352
+ torch.eq(x_idx, K),
1353
+ torch.tensor(K - 2, device=x.device),
1354
+ cand_start_idx,
1355
+ ),
1356
+ )
1357
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1358
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1359
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1360
+ start_idx2 = torch.where(
1361
+ torch.eq(x_idx, 0),
1362
+ torch.tensor(0, device=x.device),
1363
+ torch.where(
1364
+ torch.eq(x_idx, K),
1365
+ torch.tensor(K - 2, device=x.device),
1366
+ cand_start_idx,
1367
+ ),
1368
+ )
1369
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1370
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1371
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1372
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1373
+ return cand
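+
+
+ # Example: with keypoints xp = [[0., 1.]] and yp = [[0., 2.]], interpolate_fn at x = [[0.5]]
+ # returns [[1.0]], and at x = [[2.0]] it extrapolates along the outermost segment to [[4.0]].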
1374
+
1375
+
1376
+ def expand_dims(v, dims):
1377
+ """
1378
+ Expand the tensor `v` to the dim `dims`.
1379
+
1380
+ Args:
1381
+ `v`: a PyTorch tensor with shape [N].
1382
+ `dims`: an `int`.
1383
+ Returns:
1384
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1385
+ """
1386
+ return v[(...,) + (None,) * (dims - 1)]
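+
+
+ # Example: expand_dims(v, 4) for v of shape [N] returns a view of shape [N, 1, 1, 1], so a
+ # per-sample scalar such as alpha_t broadcasts against a batch of shape [N, C, H, W].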
transport/integrators.py ADDED
@@ -0,0 +1,122 @@
+ import torch as th
2
+ from torchdiffeq import odeint
3
+ from .utils import time_shift, get_lin_function
4
+
5
+ class sde:
6
+ """SDE solver class"""
7
+
8
+ def __init__(
9
+ self,
10
+ drift,
11
+ diffusion,
12
+ *,
13
+ t0,
14
+ t1,
15
+ num_steps,
16
+ sampler_type,
17
+ ):
18
+ assert t0 < t1, "SDE sampler has to be in forward time"
19
+
20
+ self.num_timesteps = num_steps
21
+ self.t = th.linspace(t0, t1, num_steps)
22
+ self.dt = self.t[1] - self.t[0]
23
+ self.drift = drift
24
+ self.diffusion = diffusion
25
+ self.sampler_type = sampler_type
26
+
27
+ def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
28
+ w_cur = th.randn(x.size()).to(x)
29
+ t = th.ones(x.size(0)).to(x) * t
30
+ dw = w_cur * th.sqrt(self.dt)
31
+ drift = self.drift(x, t, model, **model_kwargs)
32
+ diffusion = self.diffusion(x, t)
33
+ mean_x = x + drift * self.dt
34
+ x = mean_x + th.sqrt(2 * diffusion) * dw
35
+ return x, mean_x
36
+
37
+ def __Heun_step(self, x, _, t, model, **model_kwargs):
38
+ w_cur = th.randn(x.size()).to(x)
39
+ dw = w_cur * th.sqrt(self.dt)
40
+ t_cur = th.ones(x.size(0)).to(x) * t
41
+ diffusion = self.diffusion(x, t_cur)
42
+ xhat = x + th.sqrt(2 * diffusion) * dw
43
+ K1 = self.drift(xhat, t_cur, model, **model_kwargs)
44
+ xp = xhat + self.dt * K1
45
+ K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
46
+ return (
47
+ xhat + 0.5 * self.dt * (K1 + K2),
48
+ xhat,
49
+ ) # at last time point we do not perform the heun step
50
+
51
+ def __forward_fn(self):
52
+ """TODO: generalize here by adding all private functions ending with steps to it"""
53
+ sampler_dict = {
54
+ "Euler": self.__Euler_Maruyama_step,
55
+ "Heun": self.__Heun_step,
56
+ }
57
+
58
+ try:
+ sampler = sampler_dict[self.sampler_type]
+ except KeyError:
+ raise NotImplementedError(f"Sampler type '{self.sampler_type}' is not implemented.")
62
+
63
+ return sampler
64
+
65
+ def sample(self, init, model, **model_kwargs):
66
+ """forward loop of sde"""
67
+ x = init
68
+ mean_x = init
69
+ samples = []
70
+ sampler = self.__forward_fn()
71
+ for ti in self.t[:-1]:
72
+ with th.no_grad():
73
+ x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
74
+ samples.append(x)
75
+
76
+ return samples
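+
+ # Usage sketch (illustrative; `drift_fn`, `diffusion_fn` and `model` are placeholders supplied
+ # by the surrounding transport code):
+ #
+ # solver = sde(drift_fn, diffusion_fn, t0=0.0, t1=1.0, num_steps=50, sampler_type="Euler")
+ # trajectory = solver.sample(th.randn(8, 4, 32, 32), model)  # list of num_steps - 1 tensors
+ # x1 = trajectory[-1]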
77
+
78
+
79
+ class ode:
80
+ """ODE solver class"""
81
+
82
+ def __init__(
83
+ self,
84
+ drift,
85
+ *,
86
+ t0,
87
+ t1,
88
+ sampler_type,
89
+ num_steps,
90
+ atol,
91
+ rtol,
92
+ do_shift=False,
93
+ time_shifting_factor=None,
94
+ ):
95
+ assert t0 < t1, "ODE sampler has to be in forward time"
96
+
97
+ self.drift = drift
98
+ self.do_shift = do_shift
99
+ self.t = th.linspace(t0, t1, num_steps)
100
+ if time_shifting_factor:
101
+ self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
102
+ self.atol = atol
103
+ self.rtol = rtol
104
+ self.sampler_type = sampler_type
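+
+ # The reparameterization t <- t / (t + k - k * t) with k == time_shifting_factor fixes the
+ # endpoints t == 0 and t == 1 while, for k > 1, concentrating the integration points near t == 0.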
105
+
106
+ def sample(self, x, model, **model_kwargs):
107
+ x = x.float()
108
+ device = x[0].device if isinstance(x, tuple) else x.device
109
+
110
+ def _fn(t, x):
111
+ t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
112
+ model_output = self.drift(x, t, model, **model_kwargs).float()
113
+ return model_output
114
+
115
+ t = self.t.to(device)
116
+ if self.do_shift:
117
+ mu = get_lin_function(y1=0.5, y2=1.15)(x.shape[1])
118
+ t = time_shift(mu, 1.0, t)
119
+ atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
120
+ rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
121
+ samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
122
+ return samples
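A quick note on the `time_shifting_factor` warp in `ode.__init__` above (a minimal sketch; the factor value 3.0 is made up for illustration): the map t / (t + f - f*t) keeps the endpoints 0 and 1 fixed while pulling interior grid points toward 0, so more integration steps are spent near the start of the trajectory:

    import torch as th

    f = 3.0                            # hypothetical time_shifting_factor
    t = th.linspace(0.0, 1.0, 6)
    t_shifted = t / (t + f - f * t)    # same warp as in ode.__init__
    # tensor([0.0000, 0.0769, 0.1818, 0.3333, 0.5714, 1.0000])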
transport/path.py ADDED
@@ -0,0 +1,201 @@
+ import numpy as np
+ import torch as th
+
+
+ def expand_t_like_x(t, x):
+     """Function to reshape the time t to the broadcastable dimensions of x
+     Args:
+       t: [batch_dim,], time vector
+       x: [batch_dim,...], data point
+     """
+     dims = [1] * len(x[0].size())
+     t = t.view(t.size(0), *dims)
+     return t
+
+
+ #################### Coupling Plans ####################
+
+
+ class ICPlan:
+     """Linear Coupling Plan"""
+
+     def __init__(self, sigma=0.0):
+         self.sigma = sigma
+
+     def compute_alpha_t(self, t):
+         """Compute the data coefficient along the path"""
+         return t, 1
+
+     def compute_sigma_t(self, t):
+         """Compute the noise coefficient along the path"""
+         return 1 - t, -1
+
+     def compute_d_alpha_alpha_ratio_t(self, t):
+         """Compute the ratio between d_alpha and alpha"""
+         return 1 / t
+
+     def compute_drift(self, x, t):
+         """We always output the sde according to the score parametrization;"""
+         t = expand_t_like_x(t, x)
+         alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
+         sigma_t, d_sigma_t = self.compute_sigma_t(t)
+         drift = alpha_ratio * x
+         diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t
+
+         return -drift, diffusion
+
+     def compute_diffusion(self, x, t, form="constant", norm=1.0):
+         """Compute the diffusion term of the SDE
+         Args:
+           x: [batch_dim, ...], data point
+           t: [batch_dim,], time vector
+           form: str, form of the diffusion term
+           norm: float, norm of the diffusion term
+         """
+         t = expand_t_like_x(t, x)
+         choices = {
+             "constant": norm,
+             "SBDM": norm * self.compute_drift(x, t)[1],
+             "sigma": norm * self.compute_sigma_t(t)[0],
+             "linear": norm * (1 - t),
+             "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
+             "increasing-decreasing": norm * th.sin(np.pi * t) ** 2,
+         }
+
+         try:
+             diffusion = choices[form]
+         except KeyError:
+             raise NotImplementedError(f"Diffusion form {form} not implemented")
+
+         return diffusion
+
+
+     def get_score_from_velocity(self, velocity, x, t):
+         """Wrapper function: transform a velocity prediction model to a score
+         Args:
+           velocity: [batch_dim, ...] shaped tensor; velocity model output
+           x: [batch_dim, ...] shaped tensor; x_t data point
+           t: [batch_dim,] time tensor
+         """
+         t = expand_t_like_x(t, x)
+         alpha_t, d_alpha_t = self.compute_alpha_t(t)
+         sigma_t, d_sigma_t = self.compute_sigma_t(t)
+         mean = x
+         reverse_alpha_ratio = alpha_t / d_alpha_t
+         var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
+         score = (reverse_alpha_ratio * velocity - mean) / var
+         return score
+
+     def get_noise_from_velocity(self, velocity, x, t):
+         """Wrapper function: transform a velocity prediction model to a denoiser
+         Args:
+           velocity: [batch_dim, ...] shaped tensor; velocity model output
+           x: [batch_dim, ...] shaped tensor; x_t data point
+           t: [batch_dim,] time tensor
+         """
+         t = expand_t_like_x(t, x)
+         alpha_t, d_alpha_t = self.compute_alpha_t(t)
+         sigma_t, d_sigma_t = self.compute_sigma_t(t)
+         mean = x
+         reverse_alpha_ratio = alpha_t / d_alpha_t
+         var = reverse_alpha_ratio * d_sigma_t - sigma_t
+         noise = (reverse_alpha_ratio * velocity - mean) / var
+         return noise
+
+     def get_velocity_from_score(self, score, x, t):
+         """Wrapper function: transform a score prediction model to a velocity
+         Args:
+           score: [batch_dim, ...] shaped tensor; score model output
+           x: [batch_dim, ...] shaped tensor; x_t data point
+           t: [batch_dim,] time tensor
+         """
+         t = expand_t_like_x(t, x)
+         drift, var = self.compute_drift(x, t)
+         velocity = var * score - drift
+         return velocity
+
+     def compute_mu_t(self, t, x0, x1):
+         """Compute the mean of the time-dependent density p_t"""
+         t = expand_t_like_x(t, x1)
+         alpha_t, _ = self.compute_alpha_t(t)
+         sigma_t, _ = self.compute_sigma_t(t)
+         if isinstance(x1, (list, tuple)):
+             return [alpha_t[i] * x1[i] + sigma_t[i] * x0[i] for i in range(len(x1))]
+         else:
+             return alpha_t * x1 + sigma_t * x0
+
+     def compute_xt(self, t, x0, x1):
+         """Sample xt from the time-dependent density p_t; rng is required"""
+         xt = self.compute_mu_t(t, x0, x1)
+         return xt
+
+     def compute_ut(self, t, x0, x1, xt):
+         """Compute the vector field corresponding to p_t"""
+         t = expand_t_like_x(t, x1)
+         _, d_alpha_t = self.compute_alpha_t(t)
+         _, d_sigma_t = self.compute_sigma_t(t)
+         if isinstance(x1, (list, tuple)):
+             return [d_alpha_t * x1[i] + d_sigma_t * x0[i] for i in range(len(x1))]
+         else:
+             return d_alpha_t * x1 + d_sigma_t * x0
+
+     def plan(self, t, x0, x1):
+         xt = self.compute_xt(t, x0, x1)
+         ut = self.compute_ut(t, x0, x1, xt)
+         return t, xt, ut
+
+
+ class VPCPlan(ICPlan):
+     """class for VP path flow matching"""
+
+     def __init__(self, sigma_min=0.1, sigma_max=20.0):
+         self.sigma_min = sigma_min
+         self.sigma_max = sigma_max
+         self.log_mean_coeff = (
+             lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
+         )
+         self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
+
+     def compute_alpha_t(self, t):
+         """Compute the coefficient of x1"""
+         alpha_t = self.log_mean_coeff(t)
+         alpha_t = th.exp(alpha_t)
+         d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
+         return alpha_t, d_alpha_t
+
+     def compute_sigma_t(self, t):
+         """Compute the coefficient of x0"""
+         p_sigma_t = 2 * self.log_mean_coeff(t)
+         sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
+         d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
+         return sigma_t, d_sigma_t
+
+     def compute_d_alpha_alpha_ratio_t(self, t):
+         """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
+         return self.d_log_mean_coeff(t)
+
+     def compute_drift(self, x, t):
+         """Compute the drift term of the SDE"""
+         t = expand_t_like_x(t, x)
+         beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
+         return -0.5 * beta_t * x, beta_t / 2
+
+
+ class GVPCPlan(ICPlan):
+     def __init__(self, sigma=0.0):
+         super().__init__(sigma)
+
+     def compute_alpha_t(self, t):
+         """Compute the coefficient of x1"""
+         alpha_t = th.sin(t * np.pi / 2)
+         d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
+         return alpha_t, d_alpha_t
+
+     def compute_sigma_t(self, t):
+         """Compute the coefficient of x0"""
+         sigma_t = th.cos(t * np.pi / 2)
+         d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
+         return sigma_t, d_sigma_t
+
+     def compute_d_alpha_alpha_ratio_t(self, t):
+         """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
+         return np.pi / (2 * th.tan(t * np.pi / 2))
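To make the coupling-plan interface concrete (a minimal sketch using only the classes above): for the linear plan, alpha_t = t and sigma_t = 1 - t, so the interpolant and its velocity come out in closed form:

    import torch as th

    plan = ICPlan()
    x1 = th.randn(4, 3)       # data batch
    x0 = th.randn_like(x1)    # noise batch
    t = th.rand(4)            # per-sample times in [0, 1]
    t, xt, ut = plan.plan(t, x0, x1)
    # xt == t * x1 + (1 - t) * x0  (t broadcast over the feature dim)
    # ut == x1 - x0                (the constant velocity of the straight path)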
transport/transport.py ADDED
@@ -0,0 +1,490 @@
+ import enum
+ import math
+ from typing import Callable
+
+ import numpy as np
+ import torch as th
+
+ from . import path
+ from .integrators import ode, sde
+ from .utils import mean_flat, expand_dims
+ from .dpm_solver import NoiseScheduleFlow, model_wrapper, DPM_Solver
+
+
+ class ModelType(enum.Enum):
+     """
+     Which type of output the model predicts.
+     """
+
+     NOISE = enum.auto()  # the model predicts epsilon
+     SCORE = enum.auto()  # the model predicts \nabla \log p(x)
+     VELOCITY = enum.auto()  # the model predicts v(x)
+
+
+ class PathType(enum.Enum):
+     """
+     Which type of path to use.
+     """
+
+     LINEAR = enum.auto()
+     GVP = enum.auto()
+     VP = enum.auto()
+
+
+ class WeightType(enum.Enum):
+     """
+     Which type of weighting to use.
+     """
+
+     NONE = enum.auto()
+     VELOCITY = enum.auto()
+     LIKELIHOOD = enum.auto()
+
+
+ class Transport:
+     def __init__(self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type, do_shift, seq_len):
+         path_options = {
+             PathType.LINEAR: path.ICPlan,
+             PathType.GVP: path.GVPCPlan,
+             PathType.VP: path.VPCPlan,
+         }
+
+         self.loss_type = loss_type
+         self.model_type = model_type
+         self.path_sampler = path_options[path_type]()
+         self.train_eps = train_eps
+         self.sample_eps = sample_eps
+
+         self.snr_type = snr_type
+         self.do_shift = do_shift
+         self.seq_len = seq_len
+
+     def prior_logp(self, z):
+         """
+         Standard multivariate normal prior
+         Assume z is batched
+         """
+         shape = th.tensor(z.size())
+         N = th.prod(shape[1:])
+         _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
+         return th.vmap(_fn)(z)
+
+     def check_interval(
+         self,
+         train_eps,
+         sample_eps,
+         *,
+         diffusion_form="SBDM",
+         sde=False,
+         reverse=False,
+         eval=False,
+         last_step_size=0.0,
+     ):
+         t0 = 0
+         t1 = 1
+         eps = train_eps if not eval else sample_eps
+         if type(self.path_sampler) in [path.VPCPlan]:
+             t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
+
+         elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
+             self.model_type != ModelType.VELOCITY or sde
+         ):  # avoid numerical issue by taking a first semi-implicit step
+             t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
+             t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
+
+         if reverse:
+             t0, t1 = 1 - t0, 1 - t1
+
+         return t0, t1
+
+     def sample(self, x1):
+         """Sampling x0 & t based on the shape of x1 (if needed)
+         Args:
+           x1 - data point; [batch, *dim]
+         """
+         if isinstance(x1, (list, tuple)):
+             x0 = [th.randn_like(img_start) for img_start in x1]
+         else:
+             x0 = th.randn_like(x1)
+         t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
+
+         if self.snr_type.startswith("uniform"):
+             assert t0 == 0.0 and t1 == 1.0, "not implemented."
+             if "_" in self.snr_type:
+                 _, t0, t1 = self.snr_type.split("_")
+                 t0, t1 = float(t0), float(t1)
+             t = th.rand((len(x1),)) * (t1 - t0) + t0
+         elif self.snr_type == "lognorm":
+             u = th.normal(mean=0.0, std=1.0, size=(len(x1),))
+             t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
+         else:
+             raise NotImplementedError("Not implemented snr_type %s" % self.snr_type)
+
+         if self.do_shift:
+             base_shift: float = 0.5
+             max_shift: float = 1.15
+             mu = self.get_lin_function(y1=base_shift, y2=max_shift)(self.seq_len)
+             t = self.time_shift(mu, 1.0, t)
+         t = t.to(x1[0])
+         return t, x0, x1
+
+     def time_shift(self, mu: float, sigma: float, t: th.Tensor):
+         # the following implementation was originally written for t=0: clean / t=1: noise;
+         # since we adopt the reverse convention, the 1 - t operations are needed
+         t = 1 - t
+         t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+         t = 1 - t
+         return t
+
+     def get_lin_function(
+         self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
+     ) -> Callable[[float], float]:
+         m = (y2 - y1) / (x2 - x1)
+         b = y1 - m * x1
+         return lambda x: m * x + b
+
+     def training_losses(self, model, x1, model_kwargs=None):
+         """Loss for training the score model
+         Args:
+           - model: backbone model; could be score, noise, or velocity
+           - x1: datapoint
+           - model_kwargs: additional arguments for the model
+         """
+         if model_kwargs is None:
+             model_kwargs = {}
+         t, x0, x1 = self.sample(x1)
+         t, xt, ut = self.path_sampler.plan(t, x0, x1)
+         if "cond" in model_kwargs:
+             conds = model_kwargs.pop("cond")
+             xt = [th.cat([x, cond], dim=0) if cond is not None else x for x, cond in zip(xt, conds)]
+         model_output = model(xt, t, **model_kwargs)
+         B = len(x0)
+
+         terms = {}
+         # terms['pred'] = model_output
+         if self.model_type == ModelType.VELOCITY:
+             if isinstance(x1, (list, tuple)):
+                 assert len(model_output) == len(ut) == len(x1)
+                 for i in range(B):
+                     assert (
+                         model_output[i].shape == ut[i].shape == x1[i].shape
+                     ), f"{model_output[i].shape} {ut[i].shape} {x1[i].shape}"
+                 terms["task_loss"] = th.stack(
+                     [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
+                     dim=0,
+                 )
+             else:
+                 terms["task_loss"] = mean_flat(((model_output - ut) ** 2))
+         else:
+             raise NotImplementedError
+
+         terms["loss"] = terms["task_loss"]
+         terms["task_loss"] = terms["task_loss"].clone().detach()
+         terms["t"] = t
+         return terms
+
+     def get_drift(self):
+         """member function for obtaining the drift of the probability flow ODE"""
+
+         def score_ode(x, t, model, **model_kwargs):
+             drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
+             model_output = model(x, t, **model_kwargs)
+             return -drift_mean + drift_var * model_output  # by change of variable
+
+         def noise_ode(x, t, model, **model_kwargs):
+             drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
+             sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
+             model_output = model(x, t, **model_kwargs)
+             score = model_output / -sigma_t
+             return -drift_mean + drift_var * score
+
+         def velocity_ode(x, t, model, **model_kwargs):
+             model_output = model(x, t, **model_kwargs)
+             return model_output
+
+         if self.model_type == ModelType.NOISE:
+             drift_fn = noise_ode
+         elif self.model_type == ModelType.SCORE:
+             drift_fn = score_ode
+         else:
+             drift_fn = velocity_ode
+
+         def body_fn(x, t, model, **model_kwargs):
+             model_output = drift_fn(x, t, model, **model_kwargs)
+             assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
+             return model_output
+
+         return body_fn
+
+     def get_score(
+         self,
+     ):
+         """member function for obtaining the score of
+         x_t = alpha_t * x + sigma_t * eps"""
+         if self.model_type == ModelType.NOISE:
+             score_fn = (
+                 lambda x, t, model, **kwargs: model(x, t, **kwargs)
+                 / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
+             )
+         elif self.model_type == ModelType.SCORE:
+             score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
+         elif self.model_type == ModelType.VELOCITY:
+             score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
+                 model(x, t, **kwargs), x, t
+             )
+         else:
+             raise NotImplementedError()
+
+         return score_fn
+
+
+ class Sampler:
+     """Sampler class for the transport model"""
+
+     def __init__(
+         self,
+         transport,
+     ):
+         """Constructor for a general sampler; supporting different sampling methods
+         Args:
+           - transport: a transport object specifying the model prediction & interpolant type
+         """
+
+         self.transport = transport
+         self.drift = self.transport.get_drift()
+         self.score = self.transport.get_score()
+
+     def __get_sde_diffusion_and_drift(
+         self,
+         *,
+         diffusion_form="SBDM",
+         diffusion_norm=1.0,
+     ):
+         def diffusion_fn(x, t):
+             diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
+             return diffusion
+
+         sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
+             x, t, model, **kwargs
+         )
+
+         sde_diffusion = diffusion_fn
+
+         return sde_drift, sde_diffusion
+
+     def __get_last_step(
+         self,
+         sde_drift,
+         *,
+         last_step,
+         last_step_size,
+     ):
+         """Get the last step function of the SDE solver"""
+
+         if last_step is None:
+             last_step_fn = lambda x, t, model, **model_kwargs: x
+         elif last_step == "Mean":
+             last_step_fn = (
+                 lambda x, t, model, **model_kwargs: x + sde_drift(x, t, model, **model_kwargs) * last_step_size
+             )
+         elif last_step == "Tweedie":
+             alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
+             sigma = self.transport.path_sampler.compute_sigma_t
+             last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (
+                 sigma(t)[0][0] ** 2
+             ) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
+         elif last_step == "Euler":
+             last_step_fn = (
+                 lambda x, t, model, **model_kwargs: x + self.drift(x, t, model, **model_kwargs) * last_step_size
+             )
+         else:
+             raise NotImplementedError()
+
+         return last_step_fn
+
+     def sample_sde(
+         self,
+         *,
+         sampling_method="Euler",
+         diffusion_form="SBDM",
+         diffusion_norm=1.0,
+         last_step="Mean",
+         last_step_size=0.04,
+         num_steps=250,
+     ):
+         """returns a sampling function with the given SDE settings
+         Args:
+           - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
+           - diffusion_form: functional form of the diffusion coefficient; defaults to matching SBDM
+           - diffusion_norm: magnitude of the diffusion coefficient; defaults to 1
+           - last_step: type of the last step; defaults to identity
+           - last_step_size: size of the last step; defaults to match the stride of 250 steps over [0,1]
+           - num_steps: total number of integration steps for the SDE
+         """
+
+         if last_step is None:
+             last_step_size = 0.0
+
+         sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
+             diffusion_form=diffusion_form,
+             diffusion_norm=diffusion_norm,
+         )
+
+         t0, t1 = self.transport.check_interval(
+             self.transport.train_eps,
+             self.transport.sample_eps,
+             diffusion_form=diffusion_form,
+             sde=True,
+             eval=True,
+             reverse=False,
+             last_step_size=last_step_size,
+         )
+
+         _sde = sde(
+             sde_drift,
+             sde_diffusion,
+             t0=t0,
+             t1=t1,
+             num_steps=num_steps,
+             sampler_type=sampling_method,
+         )
+
+         last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
+
+         def _sample(init, model, **model_kwargs):
+             xs = _sde.sample(init, model, **model_kwargs)
+             ts = th.ones(init.size(0), device=init.device) * t1
+             x = last_step_fn(xs[-1], ts, model, **model_kwargs)
+             xs.append(x)
+
+             assert len(xs) == num_steps, "Number of samples does not match the number of steps"
+
+             return xs
+
+         return _sample
+
+     def sample_dpm(
+         self,
+         model,
+         model_kwargs=None,
+     ):
+
+         noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")
+
+         def noise_pred_fn(x, t_continuous):
+             output = model(x, 1 - t_continuous, **model_kwargs)
+             _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+             try:
+                 noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output
+             except Exception:  # the model may return a tuple; fall back to its first element
+                 noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0]
+             return noise
+
+         return DPM_Solver(noise_pred_fn, noise_schedule, algorithm_type="dpmsolver++").sample
+
+
+     def sample_ode(
+         self,
+         *,
+         sampling_method="dopri5",
+         num_steps=50,
+         atol=1e-6,
+         rtol=1e-3,
+         reverse=False,
+         do_shift=False,
+         time_shifting_factor=None,
+     ):
+         """returns a sampling function with the given ODE settings
+         Args:
+           - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
+           - num_steps:
+             - fixed solver (Euler, Heun): the actual number of integration steps performed
+             - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
+           - atol: absolute error tolerance for the solver
+           - rtol: relative error tolerance for the solver
+         """
+
+         # for flux
+         drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs)
+
+         t0, t1 = self.transport.check_interval(
+             self.transport.train_eps,
+             self.transport.sample_eps,
+             sde=False,
+             eval=True,
+             reverse=reverse,
+             last_step_size=0.0,
+         )
+
+         _ode = ode(
+             drift=drift,
+             t0=t0,
+             t1=t1,
+             sampler_type=sampling_method,
+             num_steps=num_steps,
+             atol=atol,
+             rtol=rtol,
+             do_shift=do_shift,
+             time_shifting_factor=time_shifting_factor,
+         )
+
+         return _ode.sample
+
+     def sample_ode_likelihood(
+         self,
+         *,
+         sampling_method="dopri5",
+         num_steps=50,
+         atol=1e-6,
+         rtol=1e-3,
+     ):
+         """returns a sampling function for calculating likelihood with the given ODE settings
+         Args:
+           - sampling_method: type of sampler used in solving the ODE; defaults to dopri5
+           - num_steps:
+             - fixed solver (Euler, Heun): the actual number of integration steps performed
+             - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
+           - atol: absolute error tolerance for the solver
+           - rtol: relative error tolerance for the solver
+         """
+
+         def _likelihood_drift(x, t, model, **model_kwargs):
+             x, _ = x
+             # Rademacher noise for the Hutchinson divergence estimate
+             eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
+             t = th.ones_like(t) * (1 - t)
+             with th.enable_grad():
+                 x.requires_grad = True
+                 grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
+                 logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
+                 drift = self.drift(x, t, model, **model_kwargs)
+             return (-drift, logp_grad)
+
+         t0, t1 = self.transport.check_interval(
+             self.transport.train_eps,
+             self.transport.sample_eps,
+             sde=False,
+             eval=True,
+             reverse=False,
+             last_step_size=0.0,
+         )
+
+         _ode = ode(
+             drift=_likelihood_drift,
+             t0=t0,
+             t1=t1,
+             sampler_type=sampling_method,
+             num_steps=num_steps,
+             atol=atol,
+             rtol=rtol,
+         )
+
+         def _sample_fn(x, model, **model_kwargs):
+             init_logp = th.zeros(x.size(0)).to(x)
+             input = (x, init_logp)
+             drift, delta_logp = _ode.sample(input, model, **model_kwargs)
+             drift, delta_logp = drift[-1], delta_logp[-1]
+             prior_logp = self.transport.prior_logp(drift)
+             logp = prior_logp - delta_logp
+             return logp, drift
+
+         return _sample_fn
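Putting the pieces together (a hypothetical wiring sketch, not an invocation taken from this repo; the argument values are illustrative):

    from transport.transport import Transport, Sampler, ModelType, PathType, WeightType

    transport = Transport(
        model_type=ModelType.VELOCITY, path_type=PathType.LINEAR,
        loss_type=WeightType.NONE, train_eps=0.0, sample_eps=0.0,
        snr_type="uniform", do_shift=False, seq_len=256,
    )
    sampler = Sampler(transport)
    sample_fn = sampler.sample_ode(sampling_method="euler", num_steps=50)
    # xs = sample_fn(noise, model)  # noise: [B, ...]; model(x, t, **kwargs) -> velocity field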
transport/utils.py ADDED
@@ -0,0 +1,56 @@
+ import torch as th
+ import math
+
+ class EasyDict:
+     def __init__(self, sub_dict):
+         for k, v in sub_dict.items():
+             setattr(self, k, v)
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+
+ def mean_flat(x):
+     """
+     Take the mean over all non-batch dimensions.
+     """
+     return th.mean(x, dim=list(range(1, len(x.size()))))
+
+
+ def log_state(state):
+     result = []
+
+     sorted_state = dict(sorted(state.items()))
+     for key, value in sorted_state.items():
+         # Check if the value is an instance of a class
+         if "<object" in str(value) or "object at" in str(value):
+             result.append(f"{key}: [{value.__class__.__name__}]")
+         else:
+             result.append(f"{key}: {value}")
+
+     return "\n".join(result)
+
+ def time_shift(mu: float, sigma: float, t: th.Tensor):
+     # the following implementation was originally written for t=0: clean / t=1: noise;
+     # since we adopt the reverse convention, the 1 - t operations are needed
+     t = 1 - t
+     t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+     t = 1 - t
+     return t
+
+ def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
+     m = (y2 - y1) / (x2 - x1)
+     b = y1 - m * x1
+     return lambda x: m * x + b
+
+ def expand_dims(v, dims):
+     """
+     Expand the tensor `v` to the dim `dims`.
+
+     Args:
+         `v`: a PyTorch tensor with shape [N].
+         `dims`: an `int`.
+     Returns:
+         a PyTorch tensor with shape [N, 1, 1, ..., 1] whose total number of dimensions is `dims`.
+     """
+     return v[(...,) + (None,) * (dims - 1)]
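As a usage sketch (the sequence length 1024 is an arbitrary example), `get_lin_function` maps a sequence length to the shift parameter mu, and `time_shift` then warps the sampling times so that, under this repo's convention (t=0: noise, t=1: clean), timesteps are pulled toward the noisy end:

    import torch as th

    mu = get_lin_function(y1=0.5, y2=1.15)(1024)  # ≈ 0.63 for 1024 tokens
    t = th.linspace(0.1, 0.9, 5)
    print(time_shift(mu, 1.0, t))  # every t moves toward 0 (the noise end); e.g. 0.5 -> ≈0.35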
util/misc.py ADDED
@@ -0,0 +1,150 @@
+ from collections import defaultdict, deque
+ import datetime
+ import logging
+ import random
+ import time
+
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+
+ logger = logging.getLogger(__name__)
+
+
+ def random_seed(seed=0):
+     random.seed(seed)
+     torch.random.manual_seed(seed)
+     np.random.seed(seed)
+
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=1000, fmt=None):
+         if fmt is None:
+             fmt = "{avg:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
+         )
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t", window_size=1000, fmt=None):
+         self.meters = defaultdict(lambda: SmoothedValue(window_size, fmt))
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if v is None:
+                 continue
+             elif isinstance(v, (torch.Tensor, float, int)):
+                 self.meters[k].update(v.item() if isinstance(v, torch.Tensor) else v)
+             elif isinstance(v, list):
+                 for i, sub_v in enumerate(v):
+                     self.meters[f"{k}_{i}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v)
+             elif isinstance(v, dict):
+                 for sub_key, sub_v in v.items():
+                     self.meters[f"{k}_{sub_key}"].update(sub_v.item() if isinstance(sub_v, torch.Tensor) else sub_v)
+             else:
+                 raise TypeError(f"Unsupported type {type(v)} for metric {k}")
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append("{}: {}".format(name, str(meter)))
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None, start_iter=0, samples_per_iter=None):
+         i = start_iter
+         if not header:
+             header = ""
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt="{avg:.4f}")
+         data_time = SmoothedValue(fmt="{avg:.4f}")
+         log_msg = [header, "[{0}/{1}]", "{meters}", "time: {time}", "data: {data}"]
+         if samples_per_iter is not None:
+             log_msg.append("samples/sec: {samples_per_sec:.2f}")
+         if torch.cuda.is_available():
+             log_msg.append("max mem: {memory:.0f}")
+         log_msg = self.delimiter.join(log_msg)
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0:
+                 try:
+                     total_len = len(iterable)
+                 except TypeError:  # e.g. a generator without __len__
+                     total_len = "unknown"
+
+                 msg_kwargs = {
+                     "meters": str(self),
+                     "time": str(iter_time),
+                     "data": str(data_time),
+                 }
+                 if samples_per_iter is not None:
+                     msg_kwargs["samples_per_sec"] = samples_per_iter / iter_time.avg
+                 if torch.cuda.is_available():
+                     msg_kwargs["memory"] = torch.cuda.max_memory_allocated() / MB
+
+                 logger.info(log_msg.format(i, total_len, **msg_kwargs))
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         logger.info("{} Total time: {}".format(header, total_time_str))
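Finally, a minimal usage sketch for the logging utilities (the loss tensor here is a stand-in for a real training loss):

    import logging
    import torch

    logging.basicConfig(level=logging.INFO)

    metric_logger = MetricLogger(delimiter="  ")
    for step in metric_logger.log_every(range(100), print_freq=10, header="Epoch [0]"):
        loss = torch.rand(())            # stand-in for a real training loss
        metric_logger.update(loss=loss)  # .item() is taken internally for tensors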