Jahnavibh committed
Commit 4124fa1 · 1 parent: 7c4259c

Upload 123 files

This view is limited to 50 files because it contains too many changes.
Files changed (50):
  1. One-2-3-45-master 2/.DS_Store +0 -0
  2. One-2-3-45-master 2/.gitattributes +35 -0
  3. One-2-3-45-master 2/.gitignore +11 -0
  4. One-2-3-45-master 2/LICENSE +201 -0
  5. One-2-3-45-master 2/README.md +221 -0
  6. One-2-3-45-master 2/configs/sd-objaverse-finetune-c_concat-256.yaml +117 -0
  7. One-2-3-45-master 2/download_ckpt.py +30 -0
  8. One-2-3-45-master 2/elevation_estimate/.gitignore +3 -0
  9. One-2-3-45-master 2/elevation_estimate/__init__.py +0 -0
  10. One-2-3-45-master 2/elevation_estimate/estimate_wild_imgs.py +10 -0
  11. One-2-3-45-master 2/elevation_estimate/loftr/__init__.py +2 -0
  12. One-2-3-45-master 2/elevation_estimate/loftr/backbone/__init__.py +11 -0
  13. One-2-3-45-master 2/elevation_estimate/loftr/backbone/resnet_fpn.py +199 -0
  14. One-2-3-45-master 2/elevation_estimate/loftr/loftr.py +81 -0
  15. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/__init__.py +2 -0
  16. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/fine_preprocess.py +59 -0
  17. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/linear_attention.py +81 -0
  18. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/transformer.py +101 -0
  19. One-2-3-45-master 2/elevation_estimate/loftr/utils/coarse_matching.py +261 -0
  20. One-2-3-45-master 2/elevation_estimate/loftr/utils/cvpr_ds_config.py +50 -0
  21. One-2-3-45-master 2/elevation_estimate/loftr/utils/fine_matching.py +74 -0
  22. One-2-3-45-master 2/elevation_estimate/loftr/utils/geometry.py +54 -0
  23. One-2-3-45-master 2/elevation_estimate/loftr/utils/position_encoding.py +42 -0
  24. One-2-3-45-master 2/elevation_estimate/loftr/utils/supervision.py +151 -0
  25. One-2-3-45-master 2/elevation_estimate/pyproject.toml +7 -0
  26. One-2-3-45-master 2/elevation_estimate/utils/__init__.py +0 -0
  27. One-2-3-45-master 2/elevation_estimate/utils/elev_est_api.py +205 -0
  28. One-2-3-45-master 2/elevation_estimate/utils/plotting.py +154 -0
  29. One-2-3-45-master 2/elevation_estimate/utils/plt_utils.py +318 -0
  30. One-2-3-45-master 2/elevation_estimate/utils/utils3d.py +62 -0
  31. One-2-3-45-master 2/elevation_estimate/utils/weights/.gitkeep +0 -0
  32. One-2-3-45-master 2/example.ipynb +0 -0
  33. One-2-3-45-master 2/ldm/data/__init__.py +0 -0
  34. One-2-3-45-master 2/ldm/data/base.py +40 -0
  35. One-2-3-45-master 2/ldm/data/coco.py +253 -0
  36. One-2-3-45-master 2/ldm/data/dummy.py +34 -0
  37. One-2-3-45-master 2/ldm/data/imagenet.py +394 -0
  38. One-2-3-45-master 2/ldm/data/inpainting/__init__.py +0 -0
  39. One-2-3-45-master 2/ldm/data/inpainting/synthetic_mask.py +166 -0
  40. One-2-3-45-master 2/ldm/data/laion.py +537 -0
  41. One-2-3-45-master 2/ldm/data/lsun.py +92 -0
  42. One-2-3-45-master 2/ldm/data/nerf_like.py +165 -0
  43. One-2-3-45-master 2/ldm/data/simple.py +526 -0
  44. One-2-3-45-master 2/ldm/extras.py +77 -0
  45. One-2-3-45-master 2/ldm/guidance.py +96 -0
  46. One-2-3-45-master 2/ldm/lr_scheduler.py +98 -0
  47. One-2-3-45-master 2/ldm/models/autoencoder.py +443 -0
  48. One-2-3-45-master 2/ldm/models/diffusion/__init__.py +0 -0
  49. One-2-3-45-master 2/ldm/models/diffusion/classifier.py +267 -0
  50. One-2-3-45-master 2/ldm/models/diffusion/ddim.py +326 -0
One-2-3-45-master 2/.DS_Store ADDED
Binary file (6.15 kB).
 
One-2-3-45-master 2/.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
One-2-3-45-master 2/.gitignore ADDED
@@ -0,0 +1,11 @@
+ __pycache__/
+ exp/
+ src/
+ *.DS_Store
+ *.ipynb
+ *.egg-info/
+ *.ckpt
+ *.pth
+
+ !example.ipynb
+ !reconstruction/exp
One-2-3-45-master 2/LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
One-2-3-45-master 2/README.md ADDED
@@ -0,0 +1,221 @@
+ <p align="center" width="100%">
+ <img src="https://github.com/Dustinpro/Dustinpro/assets/23076389/0fbdb69a-0fb4-4b42-b9da-e0b28532bdfd" width="80%" height="80%">
+ </p>
+
+
+ <p align="center">
+ [<a href="https://arxiv.org/pdf/2306.16928.pdf"><strong>Paper</strong></a>]
+ [<a href="http://one-2-3-45.com"><strong>Project</strong></a>]
+ [<a href="https://huggingface.co/spaces/One-2-3-45/One-2-3-45"><strong>Demo</strong></a>]
+ [<a href="#citation"><strong>BibTeX</strong></a>]
+ </p>
+
+ <p align="center">
+ <a href="https://huggingface.co/spaces/One-2-3-45/One-2-3-45">
+ <img alt="Hugging Face Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space_of_the_Week_%F0%9F%94%A5-blue">
+ </a>
+ </p>
+
+ One-2-3-45 rethinks how to leverage 2D diffusion models for 3D AIGC and introduces a novel forward-only paradigm that avoids time-consuming per-shape optimization.
+
+ https://github.com/One-2-3-45/One-2-3-45/assets/16759292/a81d6e32-8d29-43a5-b044-b5112b9f9664
+
+
+
+ https://github.com/One-2-3-45/One-2-3-45/assets/16759292/5ecd45ef-8fd3-4643-af4c-fac3050a0428
+
+
+ ## News
+ **[09/21/2023]**
+ One-2-3-45 was accepted to NeurIPS 2023. See you in New Orleans!
+
+ **[09/11/2023]**
+ Training code released.
+
+ **[08/18/2023]**
+ Inference code released.
+
+ **[07/24/2023]**
+ Our demo reached the Hugging Face top-4 trending list and was featured in 🤗 Spaces of the Week 🔥! Special thanks to Hugging Face 🤗 for sponsoring this demo!!
+
+ **[07/11/2023]**
+ [Online interactive demo](https://huggingface.co/spaces/One-2-3-45/One-2-3-45) released! Explore it and create your own 3D models in just 45 seconds!
+
+ **[06/29/2023]**
+ Check out our [paper](https://arxiv.org/pdf/2306.16928.pdf). [[X](https://twitter.com/_akhaliq/status/1674617785119305728)]
+
+ ## Installation
+ Hardware requirement: an NVIDIA GPU with at least 18 GB of memory (_e.g._, RTX 3090 or A10). Tested on Ubuntu.
+
+ We offer two ways to set up the environment:
+
+ ### Traditional Installation
+ <details>
+ <summary>Step 1: Install Debian packages.</summary>
+
+ ```bash
+ sudo apt update && sudo apt install git-lfs libsparsehash-dev build-essential
+ ```
+ </details>
+
+ <details>
+ <summary>Step 2: Create and activate a conda environment.</summary>
+
+ ```bash
+ conda create -n One2345 python=3.10
+ conda activate One2345
+ ```
+ </details>
+
+ <details>
+ <summary>Step 3: Clone the repository to the local machine.</summary>
+
+ ```bash
+ # Make sure you have git-lfs installed.
+ git lfs install
+ git clone https://github.com/One-2-3-45/One-2-3-45
+ cd One-2-3-45
+ ```
+ </details>
+
+ <details>
+ <summary>Step 4: Install project dependencies using pip.</summary>
+
+ ```bash
+ # Ensure that the installed CUDA version matches PyTorch's CUDA version.
+ # Example: CUDA 11.8 installation
+ wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
+ sudo sh cuda_11.8.0_520.61.05_linux.run
+ export PATH="/usr/local/cuda-11.8/bin:$PATH"
+ export LD_LIBRARY_PATH="/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH"
+ # Install PyTorch 2.0
+ pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+ # Install dependencies
+ pip install -r requirements.txt
+ # Install inplace_abn and torchsparse
+ export TORCH_CUDA_ARCH_LIST="7.0;7.2;8.0;8.6+PTX" # CUDA architectures. Modify according to your hardware.
+ export IABN_FORCE_CUDA=1
+ pip install inplace_abn
+ FORCE_CUDA=1 pip install --no-cache-dir git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
+ ```
+ </details>
+
+ <details>
+ <summary>Step 5: Download model checkpoints.</summary>
+
+ ```bash
+ python download_ckpt.py
+ ```
+ </details>
+
+
+ ### Installation by Docker Images
+ <details>
+ <summary>Option 1: Pull and Play (environment and checkpoints). (~22.3 GB)</summary>
+
+ ```bash
+ # Pull the Docker image that contains the full repository.
+ docker pull chaoxu98/one2345:demo_1.0
+ # An interactive demo will be launched automatically upon running the container.
+ # It will provide a public URL like XXXXXXX.gradio.live
+ docker run --name One-2-3-45_demo --gpus all -it chaoxu98/one2345:demo_1.0
+ ```
+ </details>
+
+ <details>
+ <summary>Option 2: Environment Only. (~7.3 GB)</summary>
+
+ ```bash
+ # Pull the Docker image that has all project dependencies installed.
+ docker pull chaoxu98/one2345:1.0
+ # Start a Docker container named One-2-3-45.
+ docker run --name One-2-3-45 --gpus all -it chaoxu98/one2345:1.0
+ # Get a bash shell in the container.
+ docker exec -it One-2-3-45 /bin/bash
+ # Clone the repository to the local machine.
+ git clone https://github.com/One-2-3-45/One-2-3-45
+ cd One-2-3-45
+ # Download model checkpoints.
+ python download_ckpt.py
+ # Refer to Getting Started for inference.
+ ```
+ </details>
+
+ ## Getting Started (Inference)
+
+ The first run takes longer because the models need to be compiled.
+
+ Expected time cost per image: 40 s on an NVIDIA A6000.
+ ```bash
+ # 1. Script
+ python run.py --img_path PATH_TO_INPUT_IMG --half_precision
+
+ # 2. Interactive demo (Gradio) with a friendly web interface
+ # A URL will be provided in the output
+ # (Local: 127.0.0.1:7860; Public: XXXXXXX.gradio.live)
+ cd demo/
+ python app.py
+
+ # 3. Jupyter Notebook
+ example.ipynb
+ ```
+
+ ## Training Your Own Model
+
+ ### Data Preparation
+ We use the Objaverse-LVIS dataset for training and render the selected shapes (those with a CC-BY license) into 2D images with Blender.
+ #### Download the training images.
+ Download all One2345.zip.part-* files (5 files in total) from <a href="https://huggingface.co/datasets/One-2-3-45/training_data/tree/main">here</a> and then concatenate them into a single .zip file with the following command:
+ ```bash
+ cat One2345.zip.part-* > One2345.zip
+ ```
+
+ #### Unzip the training images zip file.
+ Unzip the archive into a folder of your choice (`YOUR_BASE_FOLDER`) with the following command:
+
+ ```bash
+ unzip One2345.zip -d YOUR_BASE_FOLDER
+ ```
+
+ #### Download meta files.
+
+ Download `One2345_training_pose.json` and `lvis_split_cc_by.json` from <a href="https://huggingface.co/datasets/One-2-3-45/training_data/tree/main">here</a> and put them into the same folder as the training images (`YOUR_BASE_FOLDER`).
+
+ Your file structure should look like this:
+ ```
+ # One2345 is your base folder used in the previous steps
+
+ One2345
+ ├── One2345_training_pose.json
+ ├── lvis_split_cc_by.json
+ └── zero12345_narrow
+     ├── 000-000
+     ├── 000-001
+     ├── 000-002
+     ...
+     └── 000-159
+
+ ```
+
+ ### Training
+ Specify `trainpath`, `valpath`, and `testpath` in the config file `./reconstruction/confs/one2345_lod_train.conf` to be the `YOUR_BASE_FOLDER` used in the data preparation steps, then run the following command:
+ ```bash
+ cd reconstruction
+ python exp_runner_generic_blender_train.py --mode train --conf confs/one2345_lod_train.conf
+ ```
+ Experiment logs and checkpoints will be saved in `./reconstruction/exp/`.
+
+ ## Citation
+
+ If you find our code helpful, please cite our paper:
+
+ ```
+ @misc{liu2023one2345,
+       title={One-2-3-45: Any Single Image to 3D Mesh in 45 Seconds without Per-Shape Optimization},
+       author={Minghua Liu and Chao Xu and Haian Jin and Linghao Chen and Mukund Varma T and Zexiang Xu and Hao Su},
+       year={2023},
+       eprint={2306.16928},
+       archivePrefix={arXiv},
+       primaryClass={cs.CV}
+ }
+ ```
One-2-3-45-master 2/configs/sd-objaverse-finetune-c_concat-256.yaml ADDED
@@ -0,0 +1,117 @@
+ model:
+   base_learning_rate: 1.0e-04
+   target: ldm.models.diffusion.ddpm.LatentDiffusion
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     num_timesteps_cond: 1
+     log_every_t: 200
+     timesteps: 1000
+     first_stage_key: "image_target"
+     cond_stage_key: "image_cond"
+     image_size: 32
+     channels: 4
+     cond_stage_trainable: false   # Note: different from the one we trained before
+     conditioning_key: hybrid
+     monitor: val/loss_simple_ema
+     scale_factor: 0.18215
+
+     scheduler_config: # 10000 warmup steps
+       target: ldm.lr_scheduler.LambdaLinearScheduler
+       params:
+         warm_up_steps: [ 100 ]
+         cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+         f_start: [ 1.e-6 ]
+         f_max: [ 1. ]
+         f_min: [ 1. ]
+
+     unet_config:
+       target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 8
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_heads: 8
+         use_spatial_transformer: True
+         transformer_depth: 1
+         context_dim: 768
+         use_checkpoint: True
+         legacy: False
+
+     first_stage_config:
+       target: ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
+
+
+ data:
+   target: ldm.data.simple.ObjaverseDataModuleFromConfig
+   params:
+     root_dir: 'views_whole_sphere'
+     batch_size: 192
+     num_workers: 16
+     total_view: 4
+     train:
+       validation: False
+       image_transforms:
+         size: 256
+
+     validation:
+       validation: True
+       image_transforms:
+         size: 256
+
+
+ lightning:
+   find_unused_parameters: false
+   metrics_over_trainsteps_checkpoint: True
+   modelcheckpoint:
+     params:
+       every_n_train_steps: 5000
+   callbacks:
+     image_logger:
+       target: main.ImageLogger
+       params:
+         batch_frequency: 500
+         max_images: 32
+         increase_log_steps: False
+         log_first_step: True
+         log_images_kwargs:
+           use_ema_scope: False
+           inpaint: False
+           plot_progressive_rows: False
+           plot_diffusion_rows: False
+           N: 32
+           unconditional_guidance_scale: 3.0
+           unconditional_guidance_label: [""]
+
+   trainer:
+     benchmark: True
+     val_check_interval: 5000000 # really sorry
+     num_sanity_val_steps: 0
+     accumulate_grad_batches: 1
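Configs in this repository follow the Stable Diffusion convention of `target`/`params` blocks. A minimal, hedged sketch of how such a file is typically loaded, assuming the standard `instantiate_from_config` helper found in Stable-Diffusion-style codebases (its location in `ldm/util.py` is an assumption; that file is not shown in this truncated 50-file view):

```python
# Hedged sketch: load the finetune config and build the model from it.
# Assumes the usual ldm-style helper `instantiate_from_config`
# (commonly located in ldm/util.py in Stable Diffusion derived codebases).
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # assumed location

config = OmegaConf.load("configs/sd-objaverse-finetune-c_concat-256.yaml")

# `model.target` names the class (ldm.models.diffusion.ddpm.LatentDiffusion)
# and `model.params` is passed to its constructor.
model = instantiate_from_config(config.model)

# The `data` block builds ldm.data.simple.ObjaverseDataModuleFromConfig the same way.
datamodule = instantiate_from_config(config.data)
```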
One-2-3-45-master 2/download_ckpt.py ADDED
@@ -0,0 +1,30 @@
+ import urllib.request
+ from tqdm import tqdm
+
+ def download_checkpoint(url, save_path):
+     try:
+         with urllib.request.urlopen(url) as response, open(save_path, 'wb') as file:
+             file_size = int(response.info().get('Content-Length', -1))
+             chunk_size = 8192
+             num_chunks = file_size // chunk_size if file_size > chunk_size else 1
+
+             with tqdm(total=file_size, unit='B', unit_scale=True, desc='Downloading', ncols=100) as pbar:
+                 for chunk in iter(lambda: response.read(chunk_size), b''):
+                     file.write(chunk)
+                     pbar.update(len(chunk))
+
+         print(f"Checkpoint downloaded and saved to: {save_path}")
+     except Exception as e:
+         print(f"Error downloading checkpoint: {e}")
+
+ if __name__ == "__main__":
+     ckpts = {
+         "sam_vit_h_4b8939.pth": "https://huggingface.co/One-2-3-45/code/resolve/main/sam_vit_h_4b8939.pth",
+         "zero123-xl.ckpt": "https://huggingface.co/One-2-3-45/code/resolve/main/zero123-xl.ckpt",
+         "elevation_estimate/utils/weights/indoor_ds_new.ckpt": "https://huggingface.co/One-2-3-45/code/resolve/main/one2345_elev_est/tools/weights/indoor_ds_new.ckpt",
+         "reconstruction/exp/lod0/checkpoints/ckpt_215000.pth": "https://huggingface.co/One-2-3-45/code/resolve/main/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth"
+     }
+     for ckpt_name, ckpt_url in ckpts.items():
+         print(f"Downloading checkpoint: {ckpt_name}")
+         download_checkpoint(ckpt_url, ckpt_name)
+
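The script streams each URL into the matching local path. A hedged sketch of reusing `download_checkpoint` for a single file (the `os.makedirs` call is an addition for illustration; the script itself assumes the target directories exist):

```python
# Sketch: fetch only the SparseNeuS reconstruction checkpoint.
import os
from download_ckpt import download_checkpoint

ckpt_path = "reconstruction/exp/lod0/checkpoints/ckpt_215000.pth"
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)  # not part of the original script
download_checkpoint(
    "https://huggingface.co/One-2-3-45/code/resolve/main/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth",
    ckpt_path,
)
```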
One-2-3-45-master 2/elevation_estimate/.gitignore ADDED
@@ -0,0 +1,3 @@
+ build/
+ .idea/
+ *.egg-info/
One-2-3-45-master 2/elevation_estimate/__init__.py ADDED
File without changes
One-2-3-45-master 2/elevation_estimate/estimate_wild_imgs.py ADDED
@@ -0,0 +1,10 @@
+ import os.path as osp
+ from .utils.elev_est_api import elev_est_api
+
+ def estimate_elev(root_dir):
+     img_dir = osp.join(root_dir, "stage2_8")
+     img_paths = []
+     for i in range(4):
+         img_paths.append(f"{img_dir}/0_{i}.png")
+     elev = elev_est_api(img_paths)
+     return elev
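`estimate_elev` expects the four stage-2 views rendered by the pipeline under `<root_dir>/stage2_8/`, named `0_0.png` through `0_3.png`. A short usage sketch (the directory path below is a placeholder):

```python
# Sketch: estimate the elevation of the input view from four rendered images.
from elevation_estimate.estimate_wild_imgs import estimate_elev

root_dir = "./exp/my_input_image"  # placeholder; must contain stage2_8/0_0.png ... 0_3.png
elevation = estimate_elev(root_dir)
print(f"Estimated elevation: {elevation}")
```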
One-2-3-45-master 2/elevation_estimate/loftr/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .loftr import LoFTR
+ from .utils.cvpr_ds_config import default_cfg
One-2-3-45-master 2/elevation_estimate/loftr/backbone/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
+
+
+ def build_backbone(config):
+     if config['backbone_type'] == 'ResNetFPN':
+         if config['resolution'] == (8, 2):
+             return ResNetFPN_8_2(config['resnetfpn'])
+         elif config['resolution'] == (16, 4):
+             return ResNetFPN_16_4(config['resnetfpn'])
+     else:
+         raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
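`build_backbone` dispatches on `backbone_type` and `resolution`, and `resnet_fpn.py` below additionally reads `initial_dim` and `block_dims` from the nested `resnetfpn` block. A sketch of the minimal config shape this factory accepts; the dimension values are illustrative (they mirror the published LoFTR defaults) rather than taken from this diff:

```python
# Sketch: minimal config accepted by build_backbone.
from elevation_estimate.loftr.backbone import build_backbone

backbone_config = {
    'backbone_type': 'ResNetFPN',
    'resolution': (8, 2),               # selects ResNetFPN_8_2
    'resnetfpn': {
        'initial_dim': 128,             # illustrative values, not from this diff
        'block_dims': [128, 196, 256],  # one entry per ResNet stage
    },
}
backbone = build_backbone(backbone_config)
```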
One-2-3-45-master 2/elevation_estimate/loftr/backbone/resnet_fpn.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def conv1x1(in_planes, out_planes, stride=1):
6
+ """1x1 convolution without padding"""
7
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
8
+
9
+
10
+ def conv3x3(in_planes, out_planes, stride=1):
11
+ """3x3 convolution with padding"""
12
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
13
+
14
+
15
+ class BasicBlock(nn.Module):
16
+ def __init__(self, in_planes, planes, stride=1):
17
+ super().__init__()
18
+ self.conv1 = conv3x3(in_planes, planes, stride)
19
+ self.conv2 = conv3x3(planes, planes)
20
+ self.bn1 = nn.BatchNorm2d(planes)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+ self.relu = nn.ReLU(inplace=True)
23
+
24
+ if stride == 1:
25
+ self.downsample = None
26
+ else:
27
+ self.downsample = nn.Sequential(
28
+ conv1x1(in_planes, planes, stride=stride),
29
+ nn.BatchNorm2d(planes)
30
+ )
31
+
32
+ def forward(self, x):
33
+ y = x
34
+ y = self.relu(self.bn1(self.conv1(y)))
35
+ y = self.bn2(self.conv2(y))
36
+
37
+ if self.downsample is not None:
38
+ x = self.downsample(x)
39
+
40
+ return self.relu(x+y)
41
+
42
+
43
+ class ResNetFPN_8_2(nn.Module):
44
+ """
45
+ ResNet+FPN, output resolution are 1/8 and 1/2.
46
+ Each block has 2 layers.
47
+ """
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ # Config
52
+ block = BasicBlock
53
+ initial_dim = config['initial_dim']
54
+ block_dims = config['block_dims']
55
+
56
+ # Class Variable
57
+ self.in_planes = initial_dim
58
+
59
+ # Networks
60
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
61
+ self.bn1 = nn.BatchNorm2d(initial_dim)
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
65
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
66
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
67
+
68
+ # 3. FPN upsample
69
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
70
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
71
+ self.layer2_outconv2 = nn.Sequential(
72
+ conv3x3(block_dims[2], block_dims[2]),
73
+ nn.BatchNorm2d(block_dims[2]),
74
+ nn.LeakyReLU(),
75
+ conv3x3(block_dims[2], block_dims[1]),
76
+ )
77
+ self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
78
+ self.layer1_outconv2 = nn.Sequential(
79
+ conv3x3(block_dims[1], block_dims[1]),
80
+ nn.BatchNorm2d(block_dims[1]),
81
+ nn.LeakyReLU(),
82
+ conv3x3(block_dims[1], block_dims[0]),
83
+ )
84
+
85
+ for m in self.modules():
86
+ if isinstance(m, nn.Conv2d):
87
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
88
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
89
+ nn.init.constant_(m.weight, 1)
90
+ nn.init.constant_(m.bias, 0)
91
+
92
+ def _make_layer(self, block, dim, stride=1):
93
+ layer1 = block(self.in_planes, dim, stride=stride)
94
+ layer2 = block(dim, dim, stride=1)
95
+ layers = (layer1, layer2)
96
+
97
+ self.in_planes = dim
98
+ return nn.Sequential(*layers)
99
+
100
+ def forward(self, x):
101
+ # ResNet Backbone
102
+ x0 = self.relu(self.bn1(self.conv1(x)))
103
+ x1 = self.layer1(x0) # 1/2
104
+ x2 = self.layer2(x1) # 1/4
105
+ x3 = self.layer3(x2) # 1/8
106
+
107
+ # FPN
108
+ x3_out = self.layer3_outconv(x3)
109
+
110
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
111
+ x2_out = self.layer2_outconv(x2)
112
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
113
+
114
+ x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
115
+ x1_out = self.layer1_outconv(x1)
116
+ x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
117
+
118
+ return [x3_out, x1_out]
119
+
120
+
121
+ class ResNetFPN_16_4(nn.Module):
122
+ """
123
+ ResNet+FPN, output resolution are 1/16 and 1/4.
124
+ Each block has 2 layers.
125
+ """
126
+
127
+ def __init__(self, config):
128
+ super().__init__()
129
+ # Config
130
+ block = BasicBlock
131
+ initial_dim = config['initial_dim']
132
+ block_dims = config['block_dims']
133
+
134
+ # Class Variable
135
+ self.in_planes = initial_dim
136
+
137
+ # Networks
138
+ self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
139
+ self.bn1 = nn.BatchNorm2d(initial_dim)
140
+ self.relu = nn.ReLU(inplace=True)
141
+
142
+ self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
143
+ self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
144
+ self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
145
+ self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
146
+
147
+ # 3. FPN upsample
148
+ self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
149
+ self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
150
+ self.layer3_outconv2 = nn.Sequential(
151
+ conv3x3(block_dims[3], block_dims[3]),
152
+ nn.BatchNorm2d(block_dims[3]),
153
+ nn.LeakyReLU(),
154
+ conv3x3(block_dims[3], block_dims[2]),
155
+ )
156
+
157
+ self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
158
+ self.layer2_outconv2 = nn.Sequential(
159
+ conv3x3(block_dims[2], block_dims[2]),
160
+ nn.BatchNorm2d(block_dims[2]),
161
+ nn.LeakyReLU(),
162
+ conv3x3(block_dims[2], block_dims[1]),
163
+ )
164
+
165
+ for m in self.modules():
166
+ if isinstance(m, nn.Conv2d):
167
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
168
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
169
+ nn.init.constant_(m.weight, 1)
170
+ nn.init.constant_(m.bias, 0)
171
+
172
+ def _make_layer(self, block, dim, stride=1):
173
+ layer1 = block(self.in_planes, dim, stride=stride)
174
+ layer2 = block(dim, dim, stride=1)
175
+ layers = (layer1, layer2)
176
+
177
+ self.in_planes = dim
178
+ return nn.Sequential(*layers)
179
+
180
+ def forward(self, x):
181
+ # ResNet Backbone
182
+ x0 = self.relu(self.bn1(self.conv1(x)))
183
+ x1 = self.layer1(x0) # 1/2
184
+ x2 = self.layer2(x1) # 1/4
185
+ x3 = self.layer3(x2) # 1/8
186
+ x4 = self.layer4(x3) # 1/16
187
+
188
+ # FPN
189
+ x4_out = self.layer4_outconv(x4)
190
+
191
+ x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
192
+ x3_out = self.layer3_outconv(x3)
193
+ x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
194
+
195
+ x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
196
+ x2_out = self.layer2_outconv(x2)
197
+ x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
198
+
199
+ return [x4_out, x2_out]
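Both FPN variants take a single-channel (grayscale) batch and return a coarse and a fine feature map. A quick shape check for `ResNetFPN_8_2`, reusing the illustrative dimensions from the sketch above:

```python
# Sketch: shape check for ResNetFPN_8_2 on a grayscale batch.
import torch
from elevation_estimate.loftr.backbone.resnet_fpn import ResNetFPN_8_2

net = ResNetFPN_8_2({'initial_dim': 128, 'block_dims': [128, 196, 256]})  # illustrative dims
x = torch.randn(1, 1, 480, 640)  # (N, 1, H, W)

coarse, fine = net(x)
print(coarse.shape)  # torch.Size([1, 256, 60, 80])   -> 1/8 resolution
print(fine.shape)    # torch.Size([1, 128, 240, 320]) -> 1/2 resolution
```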
One-2-3-45-master 2/elevation_estimate/loftr/loftr.py ADDED
@@ -0,0 +1,81 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops.einops import rearrange
4
+
5
+ from .backbone import build_backbone
6
+ from .utils.position_encoding import PositionEncodingSine
7
+ from .loftr_module import LocalFeatureTransformer, FinePreprocess
8
+ from .utils.coarse_matching import CoarseMatching
9
+ from .utils.fine_matching import FineMatching
10
+
11
+
12
+ class LoFTR(nn.Module):
13
+ def __init__(self, config):
14
+ super().__init__()
15
+ # Misc
16
+ self.config = config
17
+
18
+ # Modules
19
+ self.backbone = build_backbone(config)
20
+ self.pos_encoding = PositionEncodingSine(
21
+ config['coarse']['d_model'],
22
+ temp_bug_fix=config['coarse']['temp_bug_fix'])
23
+ self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
24
+ self.coarse_matching = CoarseMatching(config['match_coarse'])
25
+ self.fine_preprocess = FinePreprocess(config)
26
+ self.loftr_fine = LocalFeatureTransformer(config["fine"])
27
+ self.fine_matching = FineMatching()
28
+
29
+ def forward(self, data):
30
+ """
31
+ Update:
32
+ data (dict): {
33
+ 'image0': (torch.Tensor): (N, 1, H, W)
34
+ 'image1': (torch.Tensor): (N, 1, H, W)
35
+ 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
36
+ 'mask1'(optional) : (torch.Tensor): (N, H, W)
37
+ }
38
+ """
39
+ # 1. Local Feature CNN
40
+ data.update({
41
+ 'bs': data['image0'].size(0),
42
+ 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
43
+ })
44
+
45
+ if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
46
+ feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
47
+ (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
48
+ else: # handle different input shapes
49
+ (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
50
+
51
+ data.update({
52
+ 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
53
+ 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
54
+ })
55
+
56
+ # 2. coarse-level loftr module
57
+ # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
58
+ feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
59
+ feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
60
+
61
+ mask_c0 = mask_c1 = None # mask is useful in training
62
+ if 'mask0' in data:
63
+ mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
64
+ feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
65
+
66
+ # 3. match coarse-level
67
+ self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
68
+
69
+ # 4. fine-level refinement
70
+ feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
71
+ if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
72
+ feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
73
+
74
+ # 5. match fine-level
75
+ self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
76
+
77
+ def load_state_dict(self, state_dict, *args, **kwargs):
78
+ for k in list(state_dict.keys()):
79
+ if k.startswith('matcher.'):
80
+ state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
81
+ return super().load_state_dict(state_dict, *args, **kwargs)
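The matcher mutates the `data` dict in place, so results are read back from it after the forward pass. A hedged end-to-end sketch using the `default_cfg` re-exported by `loftr/__init__.py`; the checkpoint layout (a top-level `state_dict` key) and the fine-level output keys live in files not shown in this 50-file view, so only the coarse-level keys documented in `coarse_matching.py` are read here:

```python
# Hedged sketch: run LoFTR on an image pair and read the coarse matches.
import torch
from elevation_estimate.loftr import LoFTR, default_cfg

matcher = LoFTR(config=default_cfg)
ckpt = torch.load("elevation_estimate/utils/weights/indoor_ds_new.ckpt")
matcher.load_state_dict(ckpt["state_dict"])  # 'state_dict' key assumed (LoFTR release format)
matcher.eval()

data = {
    "image0": torch.randn(1, 1, 480, 640),  # (N, 1, H, W) grayscale
    "image1": torch.randn(1, 1, 480, 640),
}
with torch.no_grad():
    matcher(data)  # results are written back into `data`

# Coarse-level keys documented in CoarseMatching.get_coarse_match:
mkpts0, mkpts1, mconf = data["mkpts0_c"], data["mkpts1_c"], data["mconf"]
```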
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .transformer import LocalFeatureTransformer
+ from .fine_preprocess import FinePreprocess
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/fine_preprocess.py ADDED
@@ -0,0 +1,59 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops.einops import rearrange, repeat
5
+
6
+
7
+ class FinePreprocess(nn.Module):
8
+ def __init__(self, config):
9
+ super().__init__()
10
+
11
+ self.config = config
12
+ self.cat_c_feat = config['fine_concat_coarse_feat']
13
+ self.W = self.config['fine_window_size']
14
+
15
+ d_model_c = self.config['coarse']['d_model']
16
+ d_model_f = self.config['fine']['d_model']
17
+ self.d_model_f = d_model_f
18
+ if self.cat_c_feat:
19
+ self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
20
+ self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
21
+
22
+ self._reset_parameters()
23
+
24
+ def _reset_parameters(self):
25
+ for p in self.parameters():
26
+ if p.dim() > 1:
27
+ nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
28
+
29
+ def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
30
+ W = self.W
31
+ stride = data['hw0_f'][0] // data['hw0_c'][0]
32
+
33
+ data.update({'W': W})
34
+ if data['b_ids'].shape[0] == 0:
35
+ feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
36
+ feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
37
+ return feat0, feat1
38
+
39
+ # 1. unfold(crop) all local windows
40
+ feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
41
+ feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
42
+ feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
43
+ feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
44
+
45
+ # 2. select only the predicted matches
46
+ feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
47
+ feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
48
+
49
+ # option: use coarse-level loftr feature as context: concat and linear
50
+ if self.cat_c_feat:
51
+ feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
52
+ feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
53
+ feat_cf_win = self.merge_feat(torch.cat([
54
+ torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
55
+ repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
56
+ ], -1))
57
+ feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
58
+
59
+ return feat_f0_unfold, feat_f1_unfold
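The `F.unfold` call crops a `W x W` patch of fine features around every coarse-grid location; `stride` is the ratio between the fine and coarse resolutions. A small shape sketch of that step in isolation (window size, stride, and feature dims are illustrative):

```python
# Sketch: what the unfold step in FinePreprocess.forward produces.
import torch
import torch.nn.functional as F
from einops import rearrange

W, stride = 5, 4                          # illustrative fine_window_size and hw0_f // hw0_c
feat_f0 = torch.randn(1, 128, 240, 320)   # fine features (N, C, Hf, Wf)

unfolded = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2)
print(unfolded.shape)                     # (1, 128 * 25, 60 * 80): one WxW patch per coarse cell
patches = rearrange(unfolded, 'n (c ww) l -> n l ww c', ww=W ** 2)
print(patches.shape)                      # (1, 4800, 25, 128)
```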
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/linear_attention.py ADDED
@@ -0,0 +1,81 @@
1
+ """
2
+ Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
3
+ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
4
+ """
5
+
6
+ import torch
7
+ from torch.nn import Module, Dropout
8
+
9
+
10
+ def elu_feature_map(x):
11
+ return torch.nn.functional.elu(x) + 1
12
+
13
+
14
+ class LinearAttention(Module):
15
+ def __init__(self, eps=1e-6):
16
+ super().__init__()
17
+ self.feature_map = elu_feature_map
18
+ self.eps = eps
19
+
20
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
21
+ """ Multi-Head linear attention proposed in "Transformers are RNNs"
22
+ Args:
23
+ queries: [N, L, H, D]
24
+ keys: [N, S, H, D]
25
+ values: [N, S, H, D]
26
+ q_mask: [N, L]
27
+ kv_mask: [N, S]
28
+ Returns:
29
+ queried_values: (N, L, H, D)
30
+ """
31
+ Q = self.feature_map(queries)
32
+ K = self.feature_map(keys)
33
+
34
+ # set padded position to zero
35
+ if q_mask is not None:
36
+ Q = Q * q_mask[:, :, None, None]
37
+ if kv_mask is not None:
38
+ K = K * kv_mask[:, :, None, None]
39
+ values = values * kv_mask[:, :, None, None]
40
+
41
+ v_length = values.size(1)
42
+ values = values / v_length # prevent fp16 overflow
43
+ KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
44
+ Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
45
+ queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
46
+
47
+ return queried_values.contiguous()
48
+
49
+
50
+ class FullAttention(Module):
51
+ def __init__(self, use_dropout=False, attention_dropout=0.1):
52
+ super().__init__()
53
+ self.use_dropout = use_dropout
54
+ self.dropout = Dropout(attention_dropout)
55
+
56
+ def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
57
+ """ Multi-head scaled dot-product attention, a.k.a full attention.
58
+ Args:
59
+ queries: [N, L, H, D]
60
+ keys: [N, S, H, D]
61
+ values: [N, S, H, D]
62
+ q_mask: [N, L]
63
+ kv_mask: [N, S]
64
+ Returns:
65
+ queried_values: (N, L, H, D)
66
+ """
67
+
68
+ # Compute the unnormalized attention and apply the masks
69
+ QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
70
+ if kv_mask is not None:
71
+ QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
72
+
73
+ # Compute the attention and the weighted average
74
+ softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
75
+ A = torch.softmax(softmax_temp * QK, dim=2)
76
+ if self.use_dropout:
77
+ A = self.dropout(A)
78
+
79
+ queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
80
+
81
+ return queried_values.contiguous()
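The einsum chain in `LinearAttention` applies the kernel trick Q·(KᵀV) instead of softmax(Q·Kᵀ)·V, so the cost grows linearly with sequence length, while `FullAttention` materializes the full L x S attention map. A small sketch of the shared interface (shapes only; the two variants are not numerically equivalent, since the ELU feature map only approximates softmax attention):

```python
# Sketch: both attention variants consume and produce [N, L/S, H, D] tensors.
import torch
from elevation_estimate.loftr.loftr_module.linear_attention import LinearAttention, FullAttention

N, L, S, H, D = 1, 1024, 1024, 8, 32
q = torch.randn(N, L, H, D)
k = torch.randn(N, S, H, D)
v = torch.randn(N, S, H, D)

out_linear = LinearAttention()(q, k, v)  # cost linear in L and S
out_full = FullAttention()(q, k, v)      # builds an L x S x H attention tensor
print(out_linear.shape, out_full.shape)  # both: torch.Size([1, 1024, 8, 32])
```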
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/transformer.py ADDED
@@ -0,0 +1,101 @@
1
+ import copy
2
+ import torch
3
+ import torch.nn as nn
4
+ from .linear_attention import LinearAttention, FullAttention
5
+
6
+
7
+ class LoFTREncoderLayer(nn.Module):
8
+ def __init__(self,
9
+ d_model,
10
+ nhead,
11
+ attention='linear'):
12
+ super(LoFTREncoderLayer, self).__init__()
13
+
14
+ self.dim = d_model // nhead
15
+ self.nhead = nhead
16
+
17
+ # multi-head attention
18
+ self.q_proj = nn.Linear(d_model, d_model, bias=False)
19
+ self.k_proj = nn.Linear(d_model, d_model, bias=False)
20
+ self.v_proj = nn.Linear(d_model, d_model, bias=False)
21
+ self.attention = LinearAttention() if attention == 'linear' else FullAttention()
22
+ self.merge = nn.Linear(d_model, d_model, bias=False)
23
+
24
+ # feed-forward network
25
+ self.mlp = nn.Sequential(
26
+ nn.Linear(d_model*2, d_model*2, bias=False),
27
+ nn.ReLU(True),
28
+ nn.Linear(d_model*2, d_model, bias=False),
29
+ )
30
+
31
+ # norm and dropout
32
+ self.norm1 = nn.LayerNorm(d_model)
33
+ self.norm2 = nn.LayerNorm(d_model)
34
+
35
+ def forward(self, x, source, x_mask=None, source_mask=None):
36
+ """
37
+ Args:
38
+ x (torch.Tensor): [N, L, C]
39
+ source (torch.Tensor): [N, S, C]
40
+ x_mask (torch.Tensor): [N, L] (optional)
41
+ source_mask (torch.Tensor): [N, S] (optional)
42
+ """
43
+ bs = x.size(0)
44
+ query, key, value = x, source, source
45
+
46
+ # multi-head attention
47
+ query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
48
+ key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
49
+ value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
50
+ message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
51
+ message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
52
+ message = self.norm1(message)
53
+
54
+ # feed-forward network
55
+ message = self.mlp(torch.cat([x, message], dim=2))
56
+ message = self.norm2(message)
57
+
58
+ return x + message
59
+
60
+
61
+ class LocalFeatureTransformer(nn.Module):
62
+ """A Local Feature Transformer (LoFTR) module."""
63
+
64
+ def __init__(self, config):
65
+ super(LocalFeatureTransformer, self).__init__()
66
+
67
+ self.config = config
68
+ self.d_model = config['d_model']
69
+ self.nhead = config['nhead']
70
+ self.layer_names = config['layer_names']
71
+ encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
72
+ self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
73
+ self._reset_parameters()
74
+
75
+ def _reset_parameters(self):
76
+ for p in self.parameters():
77
+ if p.dim() > 1:
78
+ nn.init.xavier_uniform_(p)
79
+
80
+ def forward(self, feat0, feat1, mask0=None, mask1=None):
81
+ """
82
+ Args:
83
+ feat0 (torch.Tensor): [N, L, C]
84
+ feat1 (torch.Tensor): [N, S, C]
85
+ mask0 (torch.Tensor): [N, L] (optional)
86
+ mask1 (torch.Tensor): [N, S] (optional)
87
+ """
88
+
89
+ assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
90
+
91
+ for layer, name in zip(self.layers, self.layer_names):
92
+ if name == 'self':
93
+ feat0 = layer(feat0, feat0, mask0, mask0)
94
+ feat1 = layer(feat1, feat1, mask1, mask1)
95
+ elif name == 'cross':
96
+ feat0 = layer(feat0, feat1, mask0, mask1)
97
+ feat1 = layer(feat1, feat0, mask1, mask0)
98
+ else:
99
+ raise KeyError
100
+
101
+ return feat0, feat1
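`LocalFeatureTransformer` alternates self- and cross-attention according to `layer_names`. A sketch of the coarse-level configuration it expects; the layer pattern and dimensions below mirror the published LoFTR defaults and are assumptions here, not values read from this diff:

```python
# Sketch: building and running the coarse-level LoFTR transformer.
import torch
from elevation_estimate.loftr.loftr_module import LocalFeatureTransformer

coarse_cfg = {
    'd_model': 256,
    'nhead': 8,
    'layer_names': ['self', 'cross'] * 4,  # assumed default pattern
    'attention': 'linear',
}
loftr_coarse = LocalFeatureTransformer(coarse_cfg)

feat0 = torch.randn(1, 4800, 256)  # flattened coarse features [N, L, C]
feat1 = torch.randn(1, 4800, 256)  # [N, S, C]
feat0, feat1 = loftr_coarse(feat0, feat1)
```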
One-2-3-45-master 2/elevation_estimate/loftr/utils/coarse_matching.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops.einops import rearrange
5
+
6
+ INF = 1e9
7
+
8
+ def mask_border(m, b: int, v):
9
+ """ Mask borders with value
10
+ Args:
11
+ m (torch.Tensor): [N, H0, W0, H1, W1]
12
+ b (int)
13
+ v (m.dtype)
14
+ """
15
+ if b <= 0:
16
+ return
17
+
18
+ m[:, :b] = v
19
+ m[:, :, :b] = v
20
+ m[:, :, :, :b] = v
21
+ m[:, :, :, :, :b] = v
22
+ m[:, -b:] = v
23
+ m[:, :, -b:] = v
24
+ m[:, :, :, -b:] = v
25
+ m[:, :, :, :, -b:] = v
26
+
27
+
28
+ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
29
+ if bd <= 0:
30
+ return
31
+
32
+ m[:, :bd] = v
33
+ m[:, :, :bd] = v
34
+ m[:, :, :, :bd] = v
35
+ m[:, :, :, :, :bd] = v
36
+
37
+ h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
38
+ h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
39
+ for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
40
+ m[b_idx, h0 - bd:] = v
41
+ m[b_idx, :, w0 - bd:] = v
42
+ m[b_idx, :, :, h1 - bd:] = v
43
+ m[b_idx, :, :, :, w1 - bd:] = v
44
+
45
+
46
+ def compute_max_candidates(p_m0, p_m1):
47
+ """Compute the max candidates of all pairs within a batch
48
+
49
+ Args:
50
+ p_m0, p_m1 (torch.Tensor): padded masks
51
+ """
52
+ h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
53
+ h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
54
+ max_cand = torch.sum(
55
+ torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
56
+ return max_cand
57
+
58
+
59
+ class CoarseMatching(nn.Module):
60
+ def __init__(self, config):
61
+ super().__init__()
62
+ self.config = config
63
+ # general config
64
+ self.thr = config['thr']
65
+ self.border_rm = config['border_rm']
66
+ # -- # for trainig fine-level LoFTR
67
+ self.train_coarse_percent = config['train_coarse_percent']
68
+ self.train_pad_num_gt_min = config['train_pad_num_gt_min']
69
+
70
+ # we provide 2 options for differentiable matching
71
+ self.match_type = config['match_type']
72
+ if self.match_type == 'dual_softmax':
73
+ self.temperature = config['dsmax_temperature']
74
+ elif self.match_type == 'sinkhorn':
75
+ try:
76
+ from .superglue import log_optimal_transport
77
+ except ImportError:
78
+ raise ImportError("download superglue.py first!")
79
+ self.log_optimal_transport = log_optimal_transport
80
+ self.bin_score = nn.Parameter(
81
+ torch.tensor(config['skh_init_bin_score'], requires_grad=True))
82
+ self.skh_iters = config['skh_iters']
83
+ self.skh_prefilter = config['skh_prefilter']
84
+ else:
85
+ raise NotImplementedError()
86
+
87
+ def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
88
+ """
89
+ Args:
90
+ feat0 (torch.Tensor): [N, L, C]
91
+ feat1 (torch.Tensor): [N, S, C]
92
+ data (dict)
93
+ mask_c0 (torch.Tensor): [N, L] (optional)
94
+ mask_c1 (torch.Tensor): [N, S] (optional)
95
+ Update:
96
+ data (dict): {
97
+ 'b_ids' (torch.Tensor): [M'],
98
+ 'i_ids' (torch.Tensor): [M'],
99
+ 'j_ids' (torch.Tensor): [M'],
100
+ 'gt_mask' (torch.Tensor): [M'],
101
+ 'mkpts0_c' (torch.Tensor): [M, 2],
102
+ 'mkpts1_c' (torch.Tensor): [M, 2],
103
+ 'mconf' (torch.Tensor): [M]}
104
+ NOTE: M' != M during training.
105
+ """
106
+ N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
107
+
108
+ # normalize
109
+ feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
110
+ [feat_c0, feat_c1])
111
+
112
+ if self.match_type == 'dual_softmax':
113
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
114
+ feat_c1) / self.temperature
115
+ if mask_c0 is not None:
116
+ sim_matrix.masked_fill_(
117
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
118
+ -INF)
119
+ conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
120
+
121
+ elif self.match_type == 'sinkhorn':
122
+ # sinkhorn, dustbin included
123
+ sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
124
+ if mask_c0 is not None:
125
+ sim_matrix[:, :L, :S].masked_fill_(
126
+ ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
127
+ -INF)
128
+
129
+ # build uniform prior & use sinkhorn
130
+ log_assign_matrix = self.log_optimal_transport(
131
+ sim_matrix, self.bin_score, self.skh_iters)
132
+ assign_matrix = log_assign_matrix.exp()
133
+ conf_matrix = assign_matrix[:, :-1, :-1]
134
+
135
+ # filter prediction with dustbin score (only in evaluation mode)
136
+ if not self.training and self.skh_prefilter:
137
+ filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L]
138
+ filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S]
139
+ conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
140
+ conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
141
+
142
+ if self.config['sparse_spvs']:
143
+ data.update({'conf_matrix_with_bin': assign_matrix.clone()})
144
+
145
+ data.update({'conf_matrix': conf_matrix})
146
+
147
+ # predict coarse matches from conf_matrix
148
+ data.update(**self.get_coarse_match(conf_matrix, data))
149
+
150
+ @torch.no_grad()
151
+ def get_coarse_match(self, conf_matrix, data):
152
+ """
153
+ Args:
154
+ conf_matrix (torch.Tensor): [N, L, S]
155
+ data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
156
+ Returns:
157
+ coarse_matches (dict): {
158
+ 'b_ids' (torch.Tensor): [M'],
159
+ 'i_ids' (torch.Tensor): [M'],
160
+ 'j_ids' (torch.Tensor): [M'],
161
+ 'gt_mask' (torch.Tensor): [M'],
162
+ 'm_bids' (torch.Tensor): [M],
163
+ 'mkpts0_c' (torch.Tensor): [M, 2],
164
+ 'mkpts1_c' (torch.Tensor): [M, 2],
165
+ 'mconf' (torch.Tensor): [M]}
166
+ """
167
+ axes_lengths = {
168
+ 'h0c': data['hw0_c'][0],
169
+ 'w0c': data['hw0_c'][1],
170
+ 'h1c': data['hw1_c'][0],
171
+ 'w1c': data['hw1_c'][1]
172
+ }
173
+ _device = conf_matrix.device
174
+ # 1. confidence thresholding
175
+ mask = conf_matrix > self.thr
176
+ mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
177
+ **axes_lengths)
178
+ if 'mask0' not in data:
179
+ mask_border(mask, self.border_rm, False)
180
+ else:
181
+ mask_border_with_padding(mask, self.border_rm, False,
182
+ data['mask0'], data['mask1'])
183
+ mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
184
+ **axes_lengths)
185
+
186
+ # 2. mutual nearest
187
+ mask = mask \
188
+ * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
189
+ * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
190
+
191
+ # 3. find all valid coarse matches
192
+ # this only works when at most one `True` in each row
193
+ mask_v, all_j_ids = mask.max(dim=2)
194
+ b_ids, i_ids = torch.where(mask_v)
195
+ j_ids = all_j_ids[b_ids, i_ids]
196
+ mconf = conf_matrix[b_ids, i_ids, j_ids]
197
+
198
+ # 4. Random sampling of training samples for fine-level LoFTR
199
+ # (optional) pad samples with gt coarse-level matches
200
+ if self.training:
201
+ # NOTE:
202
+ # The sampling is performed across all pairs in a batch without manually balancing
203
+ # #samples for fine-level increases w.r.t. batch_size
204
+ if 'mask0' not in data:
205
+ num_candidates_max = mask.size(0) * max(
206
+ mask.size(1), mask.size(2))
207
+ else:
208
+ num_candidates_max = compute_max_candidates(
209
+ data['mask0'], data['mask1'])
210
+ num_matches_train = int(num_candidates_max *
211
+ self.train_coarse_percent)
212
+ num_matches_pred = len(b_ids)
213
+ assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
214
+
215
+ # pred_indices is to select from prediction
216
+ if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
217
+ pred_indices = torch.arange(num_matches_pred, device=_device)
218
+ else:
219
+ pred_indices = torch.randint(
220
+ num_matches_pred,
221
+ (num_matches_train - self.train_pad_num_gt_min, ),
222
+ device=_device)
223
+
224
+ # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
225
+ gt_pad_indices = torch.randint(
226
+ len(data['spv_b_ids']),
227
+ (max(num_matches_train - num_matches_pred,
228
+ self.train_pad_num_gt_min), ),
229
+ device=_device)
230
+ mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
231
+
232
+ b_ids, i_ids, j_ids, mconf = map(
233
+ lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
234
+ dim=0),
235
+ *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
236
+ [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
237
+
238
+ # These matches select patches that feed into fine-level network
239
+ coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
240
+
241
+ # 4. Update with matches in original image resolution
242
+ scale = data['hw0_i'][0] / data['hw0_c'][0]
243
+ scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
244
+ scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
245
+ mkpts0_c = torch.stack(
246
+ [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
247
+ dim=1) * scale0
248
+ mkpts1_c = torch.stack(
249
+ [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
250
+ dim=1) * scale1
251
+
252
+ # These matches are the current prediction (for visualization)
253
+ coarse_matches.update({
254
+ 'gt_mask': mconf == 0,
255
+ 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
256
+ 'mkpts0_c': mkpts0_c[mconf != 0],
257
+ 'mkpts1_c': mkpts1_c[mconf != 0],
258
+ 'mconf': mconf[mconf != 0]
259
+ })
260
+
261
+ return coarse_matches
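Not from the diff: a minimal, self-contained sketch of the dual-softmax matching rule implemented above (confidence thresholding plus a mutual-nearest-neighbour check), using random features and an illustrative threshold:

import torch
import torch.nn.functional as F

N, L, S, C = 1, 6, 5, 32
temperature, thr = 0.1, 0.2
feat_c0 = torch.randn(N, L, C) / C**0.5
feat_c1 = torch.randn(N, S, C) / C**0.5

sim = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) / temperature
conf = F.softmax(sim, 1) * F.softmax(sim, 2)        # dual-softmax confidence

mask = conf > thr                                   # 1. confidence thresholding
mask = mask \
    * (conf == conf.max(dim=2, keepdim=True)[0]) \
    * (conf == conf.max(dim=1, keepdim=True)[0])    # 2. mutual nearest neighbour

mask_v, all_j_ids = mask.max(dim=2)                 # 3. collect valid matches
b_ids, i_ids = torch.where(mask_v)
j_ids = all_j_ids[b_ids, i_ids]
print(b_ids, i_ids, j_ids, conf[b_ids, i_ids, j_ids])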
One-2-3-45-master 2/elevation_estimate/loftr/utils/cvpr_ds_config.py ADDED
@@ -0,0 +1,50 @@
1
+ from yacs.config import CfgNode as CN
2
+
3
+
4
+ def lower_config(yacs_cfg):
5
+ if not isinstance(yacs_cfg, CN):
6
+ return yacs_cfg
7
+ return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
8
+
9
+
10
+ _CN = CN()
11
+ _CN.BACKBONE_TYPE = 'ResNetFPN'
12
+ _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
13
+ _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
14
+ _CN.FINE_CONCAT_COARSE_FEAT = True
15
+
16
+ # 1. LoFTR-backbone (local feature CNN) config
17
+ _CN.RESNETFPN = CN()
18
+ _CN.RESNETFPN.INITIAL_DIM = 128
19
+ _CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
20
+
21
+ # 2. LoFTR-coarse module config
22
+ _CN.COARSE = CN()
23
+ _CN.COARSE.D_MODEL = 256
24
+ _CN.COARSE.D_FFN = 256
25
+ _CN.COARSE.NHEAD = 8
26
+ _CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
27
+ _CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
28
+ _CN.COARSE.TEMP_BUG_FIX = False
29
+
30
+ # 3. Coarse-Matching config
31
+ _CN.MATCH_COARSE = CN()
32
+ _CN.MATCH_COARSE.THR = 0.2
33
+ _CN.MATCH_COARSE.BORDER_RM = 2
34
+ _CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax', 'sinkhorn']
35
+ _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
36
+ _CN.MATCH_COARSE.SKH_ITERS = 3
37
+ _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
38
+ _CN.MATCH_COARSE.SKH_PREFILTER = True
39
+ _CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory
40
+ _CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
41
+
42
+ # 4. LoFTR-fine module config
43
+ _CN.FINE = CN()
44
+ _CN.FINE.D_MODEL = 128
45
+ _CN.FINE.D_FFN = 128
46
+ _CN.FINE.NHEAD = 8
47
+ _CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
48
+ _CN.FINE.ATTENTION = 'linear'
49
+
50
+ default_cfg = lower_config(_CN)
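A short usage sketch (not part of the diff): default_cfg is the lower-cased nested dict consumed by the LoFTR modules, so individual options can be overridden on a deep copy before building the matcher. The import path assumes this repository's package layout.

from copy import deepcopy
from elevation_estimate.loftr.utils.cvpr_ds_config import default_cfg

cfg = deepcopy(default_cfg)                 # plain nested dict after lower_config(_CN)
cfg['match_coarse']['thr'] = 0.3            # e.g. a stricter coarse-matching threshold
cfg['coarse']['temp_bug_fix'] = True        # needed for the newer indoor_ds_new.ckpt
print(cfg['match_coarse']['match_type'], cfg['resolution'])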
One-2-3-45-master 2/elevation_estimate/loftr/utils/fine_matching.py ADDED
@@ -0,0 +1,74 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from kornia.geometry.subpix import dsnt
6
+ from kornia.utils.grid import create_meshgrid
7
+
8
+
9
+ class FineMatching(nn.Module):
10
+ """FineMatching with s2d paradigm"""
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ def forward(self, feat_f0, feat_f1, data):
16
+ """
17
+ Args:
18
+ feat0 (torch.Tensor): [M, WW, C]
19
+ feat1 (torch.Tensor): [M, WW, C]
20
+ data (dict)
21
+ Update:
22
+ data (dict):{
23
+ 'expec_f' (torch.Tensor): [M, 3],
24
+ 'mkpts0_f' (torch.Tensor): [M, 2],
25
+ 'mkpts1_f' (torch.Tensor): [M, 2]}
26
+ """
27
+ M, WW, C = feat_f0.shape
28
+ W = int(math.sqrt(WW))
29
+ scale = data['hw0_i'][0] / data['hw0_f'][0]
30
+ self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
31
+
32
+ # corner case: if no coarse matches found
33
+ if M == 0:
34
+ assert self.training == False, "M is always >0, when training, see coarse_matching.py"
35
+ # logger.warning('No matches found in coarse-level.')
36
+ data.update({
37
+ 'expec_f': torch.empty(0, 3, device=feat_f0.device),
38
+ 'mkpts0_f': data['mkpts0_c'],
39
+ 'mkpts1_f': data['mkpts1_c'],
40
+ })
41
+ return
42
+
43
+ feat_f0_picked = feat_f0[:, WW//2, :]
44
+ sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
45
+ softmax_temp = 1. / C**.5
46
+ heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
47
+
48
+ # compute coordinates from heatmap
49
+ coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
50
+ grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
51
+
52
+ # compute std over <x, y>
53
+ var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
54
+ std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
55
+
56
+ # for fine-level supervision
57
+ data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
58
+
59
+ # compute absolute kpt coords
60
+ self.get_fine_match(coords_normalized, data)
61
+
62
+ @torch.no_grad()
63
+ def get_fine_match(self, coords_normed, data):
64
+ W, WW, C, scale = self.W, self.WW, self.C, self.scale
65
+
66
+ # mkpts0_f and mkpts1_f
67
+ mkpts0_f = data['mkpts0_c']
68
+ scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
69
+ mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
70
+
71
+ data.update({
72
+ "mkpts0_f": mkpts0_f,
73
+ "mkpts1_f": mkpts1_f
74
+ })
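For reference (not in the diff), the spatial-expectation step above can be reproduced in plain PyTorch; this sketch mirrors kornia's dsnt.spatial_expectation2d / create_meshgrid on a random heatmap, with variable names chosen to match the code:

import torch

W = 5
heatmap = torch.softmax(torch.randn(3, W * W), dim=1).view(-1, W, W)     # [M, W, W]

xs = torch.linspace(-1, 1, W)                    # normalized coordinates in [-1, 1]
grid_x = xs.view(1, W).expand(W, W)              # x varies along columns
grid_y = xs.view(W, 1).expand(W, W)              # y varies along rows
grid_normalized = torch.stack([grid_x, grid_y], dim=-1).reshape(1, W * W, 2)

coords_normalized = torch.sum(grid_normalized * heatmap.view(-1, W * W, 1), dim=1)  # soft-argmax, [M, 2]
var = torch.sum(grid_normalized**2 * heatmap.view(-1, W * W, 1), dim=1) - coords_normalized**2
std = torch.sqrt(var.clamp(min=1e-10)).sum(-1)   # per-match uncertainty, [M]
print(coords_normalized.shape, std.shape)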
One-2-3-45-master 2/elevation_estimate/loftr/utils/geometry.py ADDED
@@ -0,0 +1,54 @@
1
+ import torch
2
+
3
+
4
+ @torch.no_grad()
5
+ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
6
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
7
+ Also check covisibility and depth consistency.
8
+ Depth is consistent if relative error < 0.2 (hard-coded).
9
+
10
+ Args:
11
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
12
+ depth0 (torch.Tensor): [N, H, W],
13
+ depth1 (torch.Tensor): [N, H, W],
14
+ T_0to1 (torch.Tensor): [N, 3, 4],
15
+ K0 (torch.Tensor): [N, 3, 3],
16
+ K1 (torch.Tensor): [N, 3, 3],
17
+ Returns:
18
+ calculable_mask (torch.Tensor): [N, L]
19
+ warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
20
+ """
21
+ kpts0_long = kpts0.round().long()
22
+
23
+ # Sample depth, get calculable_mask on depth != 0
24
+ kpts0_depth = torch.stack(
25
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
26
+ ) # (N, L)
27
+ nonzero_mask = kpts0_depth != 0
28
+
29
+ # Unproject
30
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
31
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
32
+
33
+ # Rigid Transform
34
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
35
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
36
+
37
+ # Project
38
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
39
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
40
+
41
+ # Covisible Check
42
+ h, w = depth1.shape[1:3]
43
+ covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
44
+ (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
45
+ w_kpts0_long = w_kpts0.long()
46
+ w_kpts0_long[~covisible_mask, :] = 0
47
+
48
+ w_kpts0_depth = torch.stack(
49
+ [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
50
+ ) # (N, L)
51
+ consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
52
+ valid_mask = nonzero_mask * covisible_mask * consistent_mask
53
+
54
+ return valid_mask, w_kpts0
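A quick sanity-check sketch (not part of the diff): with unit depth, identical intrinsics and an identity relative pose, warp_kpts should map keypoints (almost) onto themselves and mark them valid. The import path assumes this repository's package layout and the shapes follow the docstring.

import torch
from elevation_estimate.loftr.utils.geometry import warp_kpts

N, H, W = 1, 32, 32
kpts0 = torch.tensor([[[8., 8.], [16., 20.]]])        # [N, L, 2]
depth = torch.ones(N, H, W)                            # unit depth everywhere
T_0to1 = torch.eye(4)[None, :3, :]                     # [N, 3, 4] identity transform
K = torch.tensor([[[20., 0., 16.], [0., 20., 16.], [0., 0., 1.]]])

valid_mask, w_kpts0 = warp_kpts(kpts0, depth, depth, T_0to1, K, K)
print(valid_mask)   # both points covisible and depth-consistent
print(w_kpts0)      # ~= kpts0, up to the +1e-4 regularizer in the projection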
One-2-3-45-master 2/elevation_estimate/loftr/utils/position_encoding.py ADDED
@@ -0,0 +1,42 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PositionEncodingSine(nn.Module):
7
+ """
8
+ This is a sinusoidal position encoding generalized to 2-dimensional images
9
+ """
10
+
11
+ def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
12
+ """
13
+ Args:
14
+ max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
15
+ temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
16
+ the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
17
+ on the final performance. For now, we keep both impls for backward compatibility.
18
+ We will remove the buggy impl after re-training all variants of our released models.
19
+ """
20
+ super().__init__()
21
+
22
+ pe = torch.zeros((d_model, *max_shape))
23
+ y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
24
+ x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
25
+ if temp_bug_fix:
26
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
27
+ else: # a buggy implementation (for backward compatability only)
28
+ div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
29
+ div_term = div_term[:, None, None] # [C//4, 1, 1]
30
+ pe[0::4, :, :] = torch.sin(x_position * div_term)
31
+ pe[1::4, :, :] = torch.cos(x_position * div_term)
32
+ pe[2::4, :, :] = torch.sin(y_position * div_term)
33
+ pe[3::4, :, :] = torch.cos(y_position * div_term)
34
+
35
+ self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
36
+
37
+ def forward(self, x):
38
+ """
39
+ Args:
40
+ x: [N, C, H, W]
41
+ """
42
+ return x + self.pe[:, :, :x.size(2), :x.size(3)]
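Usage sketch (not in the diff): the encoding is added channel-wise to a [N, C, H, W] coarse feature map, so it only needs to be built once per d_model. Import path assumes this repository's layout.

import torch
from elevation_estimate.loftr.utils.position_encoding import PositionEncodingSine

pos_enc = PositionEncodingSine(d_model=256, temp_bug_fix=True)
feat_c = torch.rand(2, 256, 60, 80)     # e.g. a 480x640 image at 1/8 resolution
out = pos_enc(feat_c)
print(out.shape)                         # torch.Size([2, 256, 60, 80])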
One-2-3-45-master 2/elevation_estimate/loftr/utils/supervision.py ADDED
@@ -0,0 +1,151 @@
1
+ from math import log
2
+ from loguru import logger
3
+
4
+ import torch
5
+ from einops import repeat
6
+ from kornia.utils import create_meshgrid
7
+
8
+ from .geometry import warp_kpts
9
+
10
+ ############## ↓ Coarse-Level supervision ↓ ##############
11
+
12
+
13
+ @torch.no_grad()
14
+ def mask_pts_at_padded_regions(grid_pt, mask):
15
+ """For megadepth dataset, zero-padding exists in images"""
16
+ mask = repeat(mask, 'n h w -> n (h w) c', c=2)
17
+ grid_pt[~mask.bool()] = 0
18
+ return grid_pt
19
+
20
+
21
+ @torch.no_grad()
22
+ def spvs_coarse(data, config):
23
+ """
24
+ Update:
25
+ data (dict): {
26
+ "conf_matrix_gt": [N, hw0, hw1],
27
+ 'spv_b_ids': [M]
28
+ 'spv_i_ids': [M]
29
+ 'spv_j_ids': [M]
30
+ 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
31
+ 'spv_pt1_i': [N, hw1, 2], in original image resolution
32
+ }
33
+
34
+ NOTE:
35
+ - for scannet dataset, there're 3 kinds of resolution {i, c, f}
36
+ - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
37
+ """
38
+ # 1. misc
39
+ device = data['image0'].device
40
+ N, _, H0, W0 = data['image0'].shape
41
+ _, _, H1, W1 = data['image1'].shape
42
+ scale = config['LOFTR']['RESOLUTION'][0]
43
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
44
+ scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
45
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
46
+
47
+ # 2. warp grids
48
+ # create kpts in meshgrid and resize them to image resolution
49
+ grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
50
+ grid_pt0_i = scale0 * grid_pt0_c
51
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
52
+ grid_pt1_i = scale1 * grid_pt1_c
53
+
54
+ # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
55
+ if 'mask0' in data:
56
+ grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
57
+ grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
58
+
59
+ # warp kpts bi-directionally and resize them to coarse-level resolution
60
+ # (no depth consistency check, since it leads to worse results experimentally)
61
+ # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
62
+ _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
63
+ _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
64
+ w_pt0_c = w_pt0_i / scale1
65
+ w_pt1_c = w_pt1_i / scale0
66
+
67
+ # 3. check if mutual nearest neighbor
68
+ w_pt0_c_round = w_pt0_c[:, :, :].round().long()
69
+ nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
70
+ w_pt1_c_round = w_pt1_c[:, :, :].round().long()
71
+ nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
72
+
73
+ # corner case: out of boundary
74
+ def out_bound_mask(pt, w, h):
75
+ return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
76
+ nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
77
+ nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
78
+
79
+ loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
80
+ correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
81
+ correct_0to1[:, 0] = False # ignore the top-left corner
82
+
83
+ # 4. construct a gt conf_matrix
84
+ conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
85
+ b_ids, i_ids = torch.where(correct_0to1 != 0)
86
+ j_ids = nearest_index1[b_ids, i_ids]
87
+
88
+ conf_matrix_gt[b_ids, i_ids, j_ids] = 1
89
+ data.update({'conf_matrix_gt': conf_matrix_gt})
90
+
91
+ # 5. save coarse matches(gt) for training fine level
92
+ if len(b_ids) == 0:
93
+ logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
94
+ # this won't affect fine-level loss calculation
95
+ b_ids = torch.tensor([0], device=device)
96
+ i_ids = torch.tensor([0], device=device)
97
+ j_ids = torch.tensor([0], device=device)
98
+
99
+ data.update({
100
+ 'spv_b_ids': b_ids,
101
+ 'spv_i_ids': i_ids,
102
+ 'spv_j_ids': j_ids
103
+ })
104
+
105
+ # 6. save intermediate results (for fast fine-level computation)
106
+ data.update({
107
+ 'spv_w_pt0_i': w_pt0_i,
108
+ 'spv_pt1_i': grid_pt1_i
109
+ })
110
+
111
+
112
+ def compute_supervision_coarse(data, config):
113
+ assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
114
+ data_source = data['dataset_name'][0]
115
+ if data_source.lower() in ['scannet', 'megadepth']:
116
+ spvs_coarse(data, config)
117
+ else:
118
+ raise ValueError(f'Unknown data source: {data_source}')
119
+
120
+
121
+ ############## ↓ Fine-Level supervision ↓ ##############
122
+
123
+ @torch.no_grad()
124
+ def spvs_fine(data, config):
125
+ """
126
+ Update:
127
+ data (dict):{
128
+ "expec_f_gt": [M, 2]}
129
+ """
130
+ # 1. misc
131
+ # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
132
+ w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
133
+ scale = config['LOFTR']['RESOLUTION'][1]
134
+ radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
135
+
136
+ # 2. get coarse prediction
137
+ b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
138
+
139
+ # 3. compute gt
140
+ scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
141
+ # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
142
+ expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
143
+ data.update({"expec_f_gt": expec_f_gt})
144
+
145
+
146
+ def compute_supervision_fine(data, config):
147
+ data_source = data['dataset_name'][0]
148
+ if data_source.lower() in ['scannet', 'megadepth']:
149
+ spvs_fine(data, config)
150
+ else:
151
+ raise NotImplementedError
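Toy sketch (not in the diff) of step 4 above: ground-truth mutual matches are scattered into a one-hot coarse confidence matrix; the index values here are purely illustrative.

import torch

N, h0w0, h1w1 = 1, 9, 9
b_ids = torch.tensor([0, 0])
i_ids = torch.tensor([2, 5])            # grid cells in image 0
j_ids = torch.tensor([4, 5])            # matched grid cells in image 1

conf_matrix_gt = torch.zeros(N, h0w0, h1w1)
conf_matrix_gt[b_ids, i_ids, j_ids] = 1
print(conf_matrix_gt.sum())             # tensor(2.)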
One-2-3-45-master 2/elevation_estimate/pyproject.toml ADDED
@@ -0,0 +1,7 @@
1
+ [project]
2
+ name = "elevation_estimate"
3
+ version = "0.1"
4
+
5
+ [tool.setuptools.packages.find]
6
+ exclude = ["configs", "tests"] # empty by default
7
+ namespaces = false # true by default
One-2-3-45-master 2/elevation_estimate/utils/__init__.py ADDED
File without changes
One-2-3-45-master 2/elevation_estimate/utils/elev_est_api.py ADDED
@@ -0,0 +1,205 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import os.path as osp
5
+ import imageio
6
+ from copy import deepcopy
7
+
8
+ import loguru
9
+ import torch
10
+ import matplotlib.cm as cm
11
+ import matplotlib.pyplot as plt
12
+
13
+ from ..loftr import LoFTR, default_cfg
14
+ from . import plt_utils
15
+ from .plotting import make_matching_figure
16
+ from .utils3d import rect_to_img, canonical_to_camera, calc_pose
17
+
18
+
19
+ class ElevEstHelper:
20
+ _feature_matcher = None
21
+
22
+ @classmethod
23
+ def get_feature_matcher(cls):
24
+ if cls._feature_matcher is None:
25
+ loguru.logger.info("Loading feature matcher...")
26
+ _default_cfg = deepcopy(default_cfg)
27
+ _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt
28
+ matcher = LoFTR(config=_default_cfg)
29
+ current_dir = os.path.dirname(os.path.abspath(__file__))
30
+ ckpt_path = os.path.join(current_dir, "weights/indoor_ds_new.ckpt")
31
+ if not osp.exists(ckpt_path):
32
+ loguru.logger.info("Downloading feature matcher...")
33
+ os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
34
+ import gdown
35
+ gdown.cached_download(url="https://drive.google.com/uc?id=19s3QvcCWQ6g-N1PrYlDCg-2mOJZ3kkgS",
36
+ path=ckpt_path)
37
+ matcher.load_state_dict(torch.load(ckpt_path)['state_dict'])
38
+ matcher = matcher.eval().cuda()
39
+ cls._feature_matcher = matcher
40
+ return cls._feature_matcher
41
+
42
+
43
+ def mask_out_bkgd(img_path, dbg=False):
44
+ img = imageio.imread_v2(img_path)
45
+ if img.shape[-1] == 4:
46
+ fg_mask = img[:, :, 3]
47
+ else:
48
+ loguru.logger.info("Image has no alpha channel, using thresholding to mask out background")
49
+ fg_mask = ~(img > 245).all(axis=-1)
50
+ if dbg:
51
+ plt.imshow(plt_utils.vis_mask(img, fg_mask.astype(np.uint8), color=[0, 255, 0]))
52
+ plt.show()
53
+ return fg_mask
54
+
55
+
56
+ def get_feature_matching(img_paths, dbg=False):
57
+ assert len(img_paths) == 4
58
+ matcher = ElevEstHelper.get_feature_matcher()
59
+ feature_matching = {}
60
+ masks = []
61
+ for i in range(4):
62
+ mask = mask_out_bkgd(img_paths[i], dbg=dbg)
63
+ masks.append(mask)
64
+ for i in range(0, 4):
65
+ for j in range(i + 1, 4):
66
+ img0_pth = img_paths[i]
67
+ img1_pth = img_paths[j]
68
+ mask0 = masks[i]
69
+ mask1 = masks[j]
70
+ img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)
71
+ img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)
72
+ original_shape = img0_raw.shape
73
+ img0_raw_resized = cv2.resize(img0_raw, (480, 480))
74
+ img1_raw_resized = cv2.resize(img1_raw, (480, 480))
75
+
76
+ img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255.
77
+ img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255.
78
+ batch = {'image0': img0, 'image1': img1}
79
+
80
+ # Inference with LoFTR and get prediction
81
+ with torch.no_grad():
82
+ matcher(batch)
83
+ mkpts0 = batch['mkpts0_f'].cpu().numpy()
84
+ mkpts1 = batch['mkpts1_f'].cpu().numpy()
85
+ mconf = batch['mconf'].cpu().numpy()
86
+ mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480
87
+ mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480
88
+ mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480
89
+ mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480
90
+ keep0 = mask0[mkpts0[:, 1].astype(int), mkpts0[:, 0].astype(int)]
91
+ keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)]
92
+ keep = np.logical_and(keep0, keep1)
93
+ mkpts0 = mkpts0[keep]
94
+ mkpts1 = mkpts1[keep]
95
+ mconf = mconf[keep]
96
+ if dbg:
97
+ # Draw visualization
98
+ color = cm.jet(mconf)
99
+ text = [
100
+ 'LoFTR',
101
+ 'Matches: {}'.format(len(mkpts0)),
102
+ ]
103
+ fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text)
104
+ fig.show()
105
+ feature_matching[f"{i}_{j}"] = np.concatenate([mkpts0, mkpts1, mconf[:, None]], axis=1)
106
+
107
+ return feature_matching
108
+
109
+
110
+ def gen_pose_hypothesis(center_elevation):
111
+ elevations = np.radians(
112
+ [center_elevation, center_elevation - 10, center_elevation + 10, center_elevation, center_elevation]) # 45~120
113
+ azimuths = np.radians([30, 30, 30, 20, 40])
114
+ input_poses = calc_pose(elevations, azimuths, len(azimuths))
115
+ input_poses = input_poses[1:]
116
+ input_poses[..., 1] *= -1
117
+ input_poses[..., 2] *= -1
118
+ return input_poses
119
+
120
+
121
+ def ba_error_general(K, matches, poses):
122
+ projmat0 = K @ poses[0].inverse()[:3, :4]
123
+ projmat1 = K @ poses[1].inverse()[:3, :4]
124
+ match_01 = matches[0]
125
+ pts0 = match_01[:, :2]
126
+ pts1 = match_01[:, 2:4]
127
+ Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(),
128
+ pts0.cpu().numpy().T, pts1.cpu().numpy().T)
129
+ Xref = Xref[:3] / Xref[3:]
130
+ Xref = Xref.T
131
+ Xref = torch.from_numpy(Xref).cuda().float()
132
+ reproj_error = 0
133
+ for match, cp in zip(matches[1:], poses[2:]):
134
+ dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1))
135
+ if dist.numel() > 0:
136
+ # print("dist.shape", dist.shape)
137
+ m0to2_index = dist.argmin(1)
138
+ keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1
139
+ if keep.sum() > 0:
140
+ xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse()))
141
+ reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1)
142
+ conf02 = match[m0to2_index][keep][:, -1]
143
+ reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum())
144
+
145
+ return reproj_error
146
+
147
+
148
+ def find_optim_elev(elevs, nimgs, matches, K, dbg=False):
149
+ errs = []
150
+ for elev in elevs:
151
+ err = 0
152
+ cam_poses = gen_pose_hypothesis(elev)
153
+ for start in range(nimgs - 1):
154
+ batch_matches, batch_poses = [], []
155
+ for i in range(start, nimgs + start):
156
+ ci = i % nimgs
157
+ batch_poses.append(cam_poses[ci])
158
+ for j in range(nimgs - 1):
159
+ key = f"{start}_{(start + j + 1) % nimgs}"
160
+ match = matches[key]
161
+ batch_matches.append(match)
162
+ err += ba_error_general(K, batch_matches, batch_poses)
163
+ errs.append(err)
164
+ errs = torch.tensor(errs)
165
+ if dbg:
166
+ plt.plot(elevs, errs)
167
+ plt.show()
168
+ optim_elev = elevs[torch.argmin(errs)].item()
169
+ return optim_elev
170
+
171
+
172
+ def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None, dbg=False):
173
+ flag = True
174
+ matches = {}
175
+ for i in range(4):
176
+ for j in range(i + 1, 4):
177
+ match_ij = feature_matching[f"{i}_{j}"]
178
+ if len(match_ij) == 0:
179
+ flag = False
180
+ match_ji = np.concatenate([match_ij[:, 2:4], match_ij[:, 0:2], match_ij[:, 4:5]], axis=1)
181
+ matches[f"{i}_{j}"] = torch.from_numpy(match_ij).float().cuda()
182
+ matches[f"{j}_{i}"] = torch.from_numpy(match_ji).float().cuda()
183
+ if not flag:
184
+ loguru.logger.info("0 matches, could not estimate elevation")
185
+ return None
186
+ interval = 10
187
+ elevs = np.arange(min_elev, max_elev, interval)
188
+ optim_elev1 = find_optim_elev(elevs, 4, matches, K)
189
+
190
+ elevs = np.arange(optim_elev1 - 10, optim_elev1 + 10, 1)
191
+ optim_elev2 = find_optim_elev(elevs, 4, matches, K)
192
+
193
+ return optim_elev2
194
+
195
+
196
+ def elev_est_api(img_paths, min_elev=30, max_elev=150, K=None, dbg=False):
197
+ feature_matching = get_feature_matching(img_paths, dbg=dbg)
198
+ if K is None:
199
+ loguru.logger.warning("K is not provided, using default K")
200
+ K = np.array([[280.0, 0, 128.0],
201
+ [0, 280.0, 128.0],
202
+ [0, 0, 1]])
203
+ K = torch.from_numpy(K).cuda().float()
204
+ elev = get_elev_est(feature_matching, min_elev, max_elev, K, dbg=dbg)
205
+ return elev
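End-to-end usage sketch (not in the diff): the API expects four nearby views of the same object; the paths below are placeholders, and a CUDA device plus the automatically downloaded LoFTR checkpoint are required.

from elevation_estimate.utils.elev_est_api import elev_est_api

img_paths = [f"./exp/view_{i}.png" for i in range(4)]   # hypothetical RGBA renders
elev = elev_est_api(img_paths, min_elev=30, max_elev=150, dbg=False)
print("estimated elevation (deg):", elev)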
One-2-3-45-master 2/elevation_estimate/utils/plotting.py ADDED
@@ -0,0 +1,154 @@
1
+ import bisect
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib
5
+
6
+
7
+ def _compute_conf_thresh(data):
8
+ dataset_name = data['dataset_name'][0].lower()
9
+ if dataset_name == 'scannet':
10
+ thr = 5e-4
11
+ elif dataset_name == 'megadepth':
12
+ thr = 1e-4
13
+ else:
14
+ raise ValueError(f'Unknown dataset: {dataset_name}')
15
+ return thr
16
+
17
+
18
+ # --- VISUALIZATION --- #
19
+
20
+ def make_matching_figure(
21
+ img0, img1, mkpts0, mkpts1, color,
22
+ kpts0=None, kpts1=None, text=[], dpi=75, path=None):
23
+ # draw image pair
24
+ assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
25
+ fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
26
+ axes[0].imshow(img0, cmap='gray')
27
+ axes[1].imshow(img1, cmap='gray')
28
+ for i in range(2): # clear all frames
29
+ axes[i].get_yaxis().set_ticks([])
30
+ axes[i].get_xaxis().set_ticks([])
31
+ for spine in axes[i].spines.values():
32
+ spine.set_visible(False)
33
+ plt.tight_layout(pad=1)
34
+
35
+ if kpts0 is not None:
36
+ assert kpts1 is not None
37
+ axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
38
+ axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
39
+
40
+ # draw matches
41
+ if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
42
+ fig.canvas.draw()
43
+ transFigure = fig.transFigure.inverted()
44
+ fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
45
+ fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
46
+ fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
47
+ (fkpts0[i, 1], fkpts1[i, 1]),
48
+ transform=fig.transFigure, c=color[i], linewidth=1)
49
+ for i in range(len(mkpts0))]
50
+
51
+ axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
52
+ axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
53
+
54
+ # put txts
55
+ txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
56
+ fig.text(
57
+ 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
58
+ fontsize=15, va='top', ha='left', color=txt_color)
59
+
60
+ # save or return figure
61
+ if path:
62
+ plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
63
+ plt.close()
64
+ else:
65
+ return fig
66
+
67
+
68
+ def _make_evaluation_figure(data, b_id, alpha='dynamic'):
69
+ b_mask = data['m_bids'] == b_id
70
+ conf_thr = _compute_conf_thresh(data)
71
+
72
+ img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
73
+ img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
74
+ kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
75
+ kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
76
+
77
+ # for megadepth, we visualize matches on the resized image
78
+ if 'scale0' in data:
79
+ kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
80
+ kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
81
+
82
+ epi_errs = data['epi_errs'][b_mask].cpu().numpy()
83
+ correct_mask = epi_errs < conf_thr
84
+ precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
85
+ n_correct = np.sum(correct_mask)
86
+ n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
87
+ recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
88
+ # recall might be larger than 1, since the calculation of conf_matrix_gt
89
+ # uses groundtruth depths and camera poses, but epipolar distance is used here.
90
+
91
+ # matching info
92
+ if alpha == 'dynamic':
93
+ alpha = dynamic_alpha(len(correct_mask))
94
+ color = error_colormap(epi_errs, conf_thr, alpha=alpha)
95
+
96
+ text = [
97
+ f'#Matches {len(kpts0)}',
98
+ f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
99
+ f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
100
+ ]
101
+
102
+ # make the figure
103
+ figure = make_matching_figure(img0, img1, kpts0, kpts1,
104
+ color, text=text)
105
+ return figure
106
+
107
+ def _make_confidence_figure(data, b_id):
108
+ # TODO: Implement confidence figure
109
+ raise NotImplementedError()
110
+
111
+
112
+ def make_matching_figures(data, config, mode='evaluation'):
113
+ """ Make matching figures for a batch.
114
+
115
+ Args:
116
+ data (Dict): a batch updated by PL_LoFTR.
117
+ config (Dict): matcher config
118
+ Returns:
119
+ figures (Dict[str, List[plt.figure]]
120
+ """
121
+ assert mode in ['evaluation', 'confidence'] # 'confidence'
122
+ figures = {mode: []}
123
+ for b_id in range(data['image0'].size(0)):
124
+ if mode == 'evaluation':
125
+ fig = _make_evaluation_figure(
126
+ data, b_id,
127
+ alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
128
+ elif mode == 'confidence':
129
+ fig = _make_confidence_figure(data, b_id)
130
+ else:
131
+ raise ValueError(f'Unknown plot mode: {mode}')
132
+ figures[mode].append(fig)
133
+ return figures
134
+
135
+
136
+ def dynamic_alpha(n_matches,
137
+ milestones=[0, 300, 1000, 2000],
138
+ alphas=[1.0, 0.8, 0.4, 0.2]):
139
+ if n_matches == 0:
140
+ return 1.0
141
+ ranges = list(zip(alphas, alphas[1:] + [None]))
142
+ loc = bisect.bisect_right(milestones, n_matches) - 1
143
+ _range = ranges[loc]
144
+ if _range[1] is None:
145
+ return _range[0]
146
+ return _range[1] + (milestones[loc + 1] - n_matches) / (
147
+ milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
148
+
149
+
150
+ def error_colormap(err, thr, alpha=1.0):
151
+ assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
152
+ x = 1 - np.clip(err / (thr * 2), 0, 1)
153
+ return np.clip(
154
+ np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
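Small usage sketch (not in the diff) for the two helpers above: the alpha fades as the match count grows, and the colormap turns epipolar errors into per-match RGBA colours. Import path assumes this repository's layout.

import numpy as np
from elevation_estimate.utils.plotting import dynamic_alpha, error_colormap

alpha = dynamic_alpha(n_matches=500)        # interpolated between the 300 and 1000 milestones
errs = np.array([0.0, 5e-5, 2e-4])
colors = error_colormap(errs, thr=1e-4, alpha=alpha)
print(round(alpha, 3), colors.shape)        # 0.686 (3, 4)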
One-2-3-45-master 2/elevation_estimate/utils/plt_utils.py ADDED
@@ -0,0 +1,318 @@
1
+ import os.path as osp
2
+ import os
3
+ import matplotlib.pyplot as plt
4
+ import torch
5
+ import cv2
6
+ import math
7
+
8
+ import numpy as np
9
+ import tqdm
10
+ from cv2 import findContours
11
+ from dl_ext.primitive import safe_zip
12
+ from dl_ext.timer import EvalTime
13
+
14
+
15
+ def plot_confidence(confidence):
16
+ n = len(confidence)
17
+ plt.plot(np.arange(n), confidence)
18
+ plt.show()
19
+
20
+
21
+ def image_grid(
22
+ images,
23
+ rows=None,
24
+ cols=None,
25
+ fill: bool = True,
26
+ show_axes: bool = False,
27
+ rgb=None,
28
+ show=True,
29
+ label=None,
30
+ **kwargs
31
+ ):
32
+ """
33
+ A util function for plotting a grid of images.
34
+ Args:
35
+ images: (N, H, W, 4) array of RGBA images
36
+ rows: number of rows in the grid
37
+ cols: number of columns in the grid
38
+ fill: boolean indicating if the space between images should be filled
39
+ show_axes: boolean indicating if the axes of the plots should be visible
40
+ rgb: boolean, If True, only RGB channels are plotted.
41
+ If False, only the alpha channel is plotted.
42
+ Returns:
43
+ None
44
+ """
45
+ evaltime = EvalTime(disable=True)
46
+ evaltime('')
47
+ if isinstance(images, torch.Tensor):
48
+ images = images.detach().cpu()
49
+ if len(images[0].shape) == 2:
50
+ rgb = False
51
+ if images[0].shape[-1] == 2:
52
+ # flow
53
+ images = [flow_to_image(im) for im in images]
54
+ if (rows is None) != (cols is None):
55
+ raise ValueError("Specify either both rows and cols or neither.")
56
+
57
+ if rows is None:
58
+ rows = int(len(images) ** 0.5)
59
+ cols = math.ceil(len(images) / rows)
60
+
61
+ gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
62
+ if len(images) < 50:
63
+ figsize = (10, 10)
64
+ else:
65
+ figsize = (15, 15)
66
+ evaltime('0.5')
67
+ plt.figure(figsize=figsize)
68
+ # fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=figsize)
69
+ if label:
70
+ # fig.suptitle(label, fontsize=30)
71
+ plt.suptitle(label, fontsize=30)
72
+ # bleed = 0
73
+ # fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
74
+ evaltime('subplots')
75
+
76
+ # for i, (ax, im) in enumerate(tqdm.tqdm(zip(axarr.ravel(), images), leave=True, total=len(images))):
77
+ for i in range(len(images)):
78
+ # evaltime(f'{i} begin')
79
+ plt.subplot(rows, cols, i + 1)
80
+ if rgb:
81
+ # only render RGB channels
82
+ plt.imshow(images[i][..., :3], **kwargs)
83
+ # ax.imshow(im[..., :3], **kwargs)
84
+ else:
85
+ # only render Alpha channel
86
+ plt.imshow(images[i], **kwargs)
87
+ # ax.imshow(im, **kwargs)
88
+ if not show_axes:
89
+ plt.axis('off')
90
+ # ax.set_axis_off()
91
+ # ax.set_title(f'{i}')
92
+ plt.title(f'{i}')
93
+ # evaltime(f'{i} end')
94
+ evaltime('2')
95
+ if show:
96
+ plt.show()
97
+ # return fig
98
+
99
+
100
+ def depth_grid(
101
+ depths,
102
+ rows=None,
103
+ cols=None,
104
+ fill: bool = True,
105
+ show_axes: bool = False,
106
+ ):
107
+ """
108
+ A util function for plotting a grid of images.
109
+ Args:
110
+ images: (N, H, W, 4) array of RGBA images
111
+ rows: number of rows in the grid
112
+ cols: number of columns in the grid
113
+ fill: boolean indicating if the space between images should be filled
114
+ show_axes: boolean indicating if the axes of the plots should be visible
115
+ rgb: boolean, If True, only RGB channels are plotted.
116
+ If False, only the alpha channel is plotted.
117
+ Returns:
118
+ None
119
+ """
120
+ if (rows is None) != (cols is None):
121
+ raise ValueError("Specify either both rows and cols or neither.")
122
+
123
+ if rows is None:
124
+ rows = len(depths)
125
+ cols = 1
126
+
127
+ gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
128
+ fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9))
129
+ bleed = 0
130
+ fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
131
+
132
+ for ax, im in zip(axarr.ravel(), depths):
133
+ ax.imshow(im)
134
+ if not show_axes:
135
+ ax.set_axis_off()
136
+ plt.show()
137
+
138
+
139
+ def hover_masks_on_imgs(images, masks):
140
+ masks = np.array(masks)
141
+ new_imgs = []
142
+ tids = list(range(1, masks.max() + 1))
143
+ colors = colormap(rgb=True, lighten=True)
144
+ for im, mask in tqdm.tqdm(safe_zip(images, masks), total=len(images)):
145
+ for tid in tids:
146
+ im = vis_mask(
147
+ im,
148
+ (mask == tid).astype(np.uint8),
149
+ color=colors[tid],
150
+ alpha=0.5,
151
+ border_alpha=0.5,
152
+ border_color=[255, 255, 255],
153
+ border_thick=3)
154
+ new_imgs.append(im)
155
+ return new_imgs
156
+
157
+
158
+ def vis_mask(img,
159
+ mask,
160
+ color=[255, 255, 255],
161
+ alpha=0.4,
162
+ show_border=True,
163
+ border_alpha=0.5,
164
+ border_thick=1,
165
+ border_color=None):
166
+ """Visualizes a single binary mask."""
167
+ if isinstance(mask, torch.Tensor):
168
+ from anypose.utils.pn_utils import to_array
169
+ mask = to_array(mask > 0).astype(np.uint8)
170
+ img = img.astype(np.float32)
171
+ idx = np.nonzero(mask)
172
+
173
+ img[idx[0], idx[1], :] *= 1.0 - alpha
174
+ img[idx[0], idx[1], :] += [alpha * x for x in color]
175
+
176
+ if show_border:
177
+ contours, _ = findContours(
178
+ mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
179
+ # contours = [c for c in contours if c.shape[0] > 10]
180
+ if border_color is None:
181
+ border_color = color
182
+ if not isinstance(border_color, list):
183
+ border_color = border_color.tolist()
184
+ if border_alpha < 1:
185
+ with_border = img.copy()
186
+ cv2.drawContours(with_border, contours, -1, border_color,
187
+ border_thick, cv2.LINE_AA)
188
+ img = (1 - border_alpha) * img + border_alpha * with_border
189
+ else:
190
+ cv2.drawContours(img, contours, -1, border_color, border_thick,
191
+ cv2.LINE_AA)
192
+
193
+ return img.astype(np.uint8)
194
+
195
+
196
+ def colormap(rgb=False, lighten=True):
197
+ """Copied from Detectron codebase."""
198
+ color_list = np.array(
199
+ [
200
+ 0.000, 0.447, 0.741,
201
+ 0.850, 0.325, 0.098,
202
+ 0.929, 0.694, 0.125,
203
+ 0.494, 0.184, 0.556,
204
+ 0.466, 0.674, 0.188,
205
+ 0.301, 0.745, 0.933,
206
+ 0.635, 0.078, 0.184,
207
+ 0.300, 0.300, 0.300,
208
+ 0.600, 0.600, 0.600,
209
+ 1.000, 0.000, 0.000,
210
+ 1.000, 0.500, 0.000,
211
+ 0.749, 0.749, 0.000,
212
+ 0.000, 1.000, 0.000,
213
+ 0.000, 0.000, 1.000,
214
+ 0.667, 0.000, 1.000,
215
+ 0.333, 0.333, 0.000,
216
+ 0.333, 0.667, 0.000,
217
+ 0.333, 1.000, 0.000,
218
+ 0.667, 0.333, 0.000,
219
+ 0.667, 0.667, 0.000,
220
+ 0.667, 1.000, 0.000,
221
+ 1.000, 0.333, 0.000,
222
+ 1.000, 0.667, 0.000,
223
+ 1.000, 1.000, 0.000,
224
+ 0.000, 0.333, 0.500,
225
+ 0.000, 0.667, 0.500,
226
+ 0.000, 1.000, 0.500,
227
+ 0.333, 0.000, 0.500,
228
+ 0.333, 0.333, 0.500,
229
+ 0.333, 0.667, 0.500,
230
+ 0.333, 1.000, 0.500,
231
+ 0.667, 0.000, 0.500,
232
+ 0.667, 0.333, 0.500,
233
+ 0.667, 0.667, 0.500,
234
+ 0.667, 1.000, 0.500,
235
+ 1.000, 0.000, 0.500,
236
+ 1.000, 0.333, 0.500,
237
+ 1.000, 0.667, 0.500,
238
+ 1.000, 1.000, 0.500,
239
+ 0.000, 0.333, 1.000,
240
+ 0.000, 0.667, 1.000,
241
+ 0.000, 1.000, 1.000,
242
+ 0.333, 0.000, 1.000,
243
+ 0.333, 0.333, 1.000,
244
+ 0.333, 0.667, 1.000,
245
+ 0.333, 1.000, 1.000,
246
+ 0.667, 0.000, 1.000,
247
+ 0.667, 0.333, 1.000,
248
+ 0.667, 0.667, 1.000,
249
+ 0.667, 1.000, 1.000,
250
+ 1.000, 0.000, 1.000,
251
+ 1.000, 0.333, 1.000,
252
+ 1.000, 0.667, 1.000,
253
+ 0.167, 0.000, 0.000,
254
+ 0.333, 0.000, 0.000,
255
+ 0.500, 0.000, 0.000,
256
+ 0.667, 0.000, 0.000,
257
+ 0.833, 0.000, 0.000,
258
+ 1.000, 0.000, 0.000,
259
+ 0.000, 0.167, 0.000,
260
+ 0.000, 0.333, 0.000,
261
+ 0.000, 0.500, 0.000,
262
+ 0.000, 0.667, 0.000,
263
+ 0.000, 0.833, 0.000,
264
+ 0.000, 1.000, 0.000,
265
+ 0.000, 0.000, 0.167,
266
+ 0.000, 0.000, 0.333,
267
+ 0.000, 0.000, 0.500,
268
+ 0.000, 0.000, 0.667,
269
+ 0.000, 0.000, 0.833,
270
+ 0.000, 0.000, 1.000,
271
+ 0.000, 0.000, 0.000,
272
+ 0.143, 0.143, 0.143,
273
+ 0.286, 0.286, 0.286,
274
+ 0.429, 0.429, 0.429,
275
+ 0.571, 0.571, 0.571,
276
+ 0.714, 0.714, 0.714,
277
+ 0.857, 0.857, 0.857,
278
+ 1.000, 1.000, 1.000
279
+ ]
280
+ ).astype(np.float32)
281
+ color_list = color_list.reshape((-1, 3))
282
+ if not rgb:
283
+ color_list = color_list[:, ::-1]
284
+
285
+ if lighten:
286
+ # Make all the colors a little lighter / whiter. This is copied
287
+ # from the detectron visualization code (search for 'w_ratio').
288
+ w_ratio = 0.4
289
+ color_list = (color_list * (1 - w_ratio) + w_ratio)
290
+ return color_list * 255
291
+
292
+
293
+ def vis_layer_mask(masks, save_path=None):
294
+ masks = torch.as_tensor(masks)
295
+ tids = masks.unique().tolist()
296
+ tids.remove(0)
297
+ for tid in tqdm.tqdm(tids):
298
+ show = save_path is None
299
+ image_grid(masks == tid, label=f'{tid}', show=show)
300
+ if save_path:
301
+ os.makedirs(osp.dirname(save_path), exist_ok=True)
302
+ plt.savefig(save_path % tid)
303
+ plt.close('all')
304
+
305
+
306
+ def show(x, **kwargs):
307
+ if isinstance(x, torch.Tensor):
308
+ x = x.detach().cpu()
309
+ plt.imshow(x, **kwargs)
310
+ plt.show()
311
+
312
+
313
+ def vis_title(rgb, text, shift_y=30):
314
+ tmp = rgb.copy()
315
+ shift_x = rgb.shape[1] // 2
316
+ cv2.putText(tmp, text,
317
+ (shift_x, shift_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2, lineType=cv2.LINE_AA)
318
+ return tmp
One-2-3-45-master 2/elevation_estimate/utils/utils3d.py ADDED
@@ -0,0 +1,62 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ def cart_to_hom(pts):
6
+ """
7
+ :param pts: (N, 3 or 2)
8
+ :return pts_hom: (N, 4 or 3)
9
+ """
10
+ if isinstance(pts, np.ndarray):
11
+ pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1)
12
+ else:
13
+ ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device)
14
+ pts_hom = torch.cat((pts, ones), dim=-1)
15
+ return pts_hom
16
+
17
+
18
+ def hom_to_cart(pts):
19
+ return pts[..., :-1] / pts[..., -1:]
20
+
21
+
22
+ def canonical_to_camera(pts, pose):
23
+ pts = cart_to_hom(pts)
24
+ pts = pts @ pose.transpose(-1, -2)
25
+ pts = hom_to_cart(pts)
26
+ return pts
27
+
28
+
29
+ def rect_to_img(K, pts_rect):
30
+ from dl_ext.vision_ext.datasets.kitti.structures import Calibration
31
+ pts_2d_hom = pts_rect @ K.t()
32
+ pts_img = Calibration.hom_to_cart(pts_2d_hom)
33
+ return pts_img
34
+
35
+
36
+ def calc_pose(phis, thetas, size, radius=1.2):
37
+ import torch
38
+ def normalize(vectors):
39
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
40
+
41
+ device = torch.device('cuda')
42
+ thetas = torch.FloatTensor(thetas).to(device)
43
+ phis = torch.FloatTensor(phis).to(device)
44
+
45
+ centers = torch.stack([
46
+ radius * torch.sin(thetas) * torch.sin(phis),
47
+ -radius * torch.cos(thetas) * torch.sin(phis),
48
+ radius * torch.cos(phis),
49
+ ], dim=-1) # [B, 3]
50
+
51
+ # lookat
52
+ forward_vector = normalize(centers).squeeze(0)
53
+ up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1)
54
+ right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1))
55
+ if right_vector.pow(2).sum() < 0.01:
56
+ right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
57
+ up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1))
58
+
59
+ poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
60
+ poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
61
+ poses[:, :3, 3] = centers
62
+ return poses
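Usage sketch (not in the diff): calc_pose builds the camera-to-world matrices used by gen_pose_hypothesis in elev_est_api.py; note that the function places its tensors on 'cuda', so a GPU is required.

import numpy as np
from elevation_estimate.utils.utils3d import calc_pose

center_elev = 60.0
elevations = np.radians([center_elev, center_elev - 10, center_elev + 10, center_elev, center_elev])
azimuths = np.radians([30, 30, 30, 20, 40])
poses = calc_pose(elevations, azimuths, len(azimuths))   # [5, 4, 4] camera-to-world poses
print(poses.shape)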
One-2-3-45-master 2/elevation_estimate/utils/weights/.gitkeep ADDED
File without changes
One-2-3-45-master 2/example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
One-2-3-45-master 2/ldm/data/__init__.py ADDED
File without changes
One-2-3-45-master 2/ldm/data/base.py ADDED
@@ -0,0 +1,40 @@
1
+ import os
2
+ import numpy as np
3
+ from abc import abstractmethod
4
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
5
+
6
+
7
+ class Txt2ImgIterableBaseDataset(IterableDataset):
8
+ '''
9
+ Define an interface to make the IterableDatasets for text2img data chainable
10
+ '''
11
+ def __init__(self, num_records=0, valid_ids=None, size=256):
12
+ super().__init__()
13
+ self.num_records = num_records
14
+ self.valid_ids = valid_ids
15
+ self.sample_ids = valid_ids
16
+ self.size = size
17
+
18
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
19
+
20
+ def __len__(self):
21
+ return self.num_records
22
+
23
+ @abstractmethod
24
+ def __iter__(self):
25
+ pass
26
+
27
+
28
+ class PRNGMixin(object):
29
+ """
30
+ Adds a prng property which is a numpy RandomState which gets
31
+ reinitialized whenever the pid changes to avoid synchronized sampling
32
+ behavior when used in conjunction with multiprocessing.
33
+ """
34
+ @property
35
+ def prng(self):
36
+ currentpid = os.getpid()
37
+ if getattr(self, "_initpid", None) != currentpid:
38
+ self._initpid = currentpid
39
+ self._prng = np.random.RandomState()
40
+ return self._prng
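Illustrative sketch (not in the diff): a dataset that mixes in PRNGMixin so each DataLoader worker process re-seeds its own RandomState instead of sampling in lockstep; the class name is hypothetical.

from torch.utils.data import Dataset
from ldm.data.base import PRNGMixin

class NoisyDummy(Dataset, PRNGMixin):
    def __len__(self):
        return 8

    def __getitem__(self, i):
        # self.prng is re-created whenever the pid changes (see PRNGMixin.prng above)
        return {"x": self.prng.randn(3)}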
One-2-3-45-master 2/ldm/data/coco.py ADDED
@@ -0,0 +1,253 @@
1
+ import os
2
+ import json
3
+ import albumentations
4
+ import numpy as np
5
+ from PIL import Image
6
+ from tqdm import tqdm
7
+ from torch.utils.data import Dataset
8
+ from abc import abstractmethod
9
+
10
+
11
+ class CocoBase(Dataset):
12
+ """needed for (image, caption, segmentation) pairs"""
13
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
14
+ crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None):
15
+ self.split = self.get_split()
16
+ self.size = size
17
+ if crop_size is None:
18
+ self.crop_size = size
19
+ else:
20
+ self.crop_size = crop_size
21
+
22
+ assert crop_type in [None, 'random', 'center']
23
+ self.crop_type = crop_type
24
+ self.use_segmenation = use_segmentation
25
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
26
+ self.stuffthing = use_stuffthing # include thing in segmentation
27
+ if self.onehot and not self.stuffthing:
28
+ raise NotImplemented("One hot mode is only supported for the "
29
+ "stuffthings version because labels are stored "
30
+ "a bit different.")
31
+
32
+ data_json = datajson
33
+ with open(data_json) as json_file:
34
+ self.json_data = json.load(json_file)
35
+ self.img_id_to_captions = dict()
36
+ self.img_id_to_filepath = dict()
37
+ self.img_id_to_segmentation_filepath = dict()
38
+
39
+ assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
40
+ f"captions_val{self.year()}.json"]
41
+ # TODO currently hardcoded paths, would be better to follow logic in
42
+ # cocstuff pixelmaps
43
+ if self.use_segmenation:
44
+ if self.stuffthing:
45
+ self.segmentation_prefix = (
46
+ f"data/cocostuffthings/val{self.year()}" if
47
+ data_json.endswith(f"captions_val{self.year()}.json") else
48
+ f"data/cocostuffthings/train{self.year()}")
49
+ else:
50
+ self.segmentation_prefix = (
51
+ f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
52
+ data_json.endswith(f"captions_val{self.year()}.json") else
53
+ f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")
54
+
55
+ imagedirs = self.json_data["images"]
56
+ self.labels = {"image_ids": list()}
57
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
58
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
59
+ self.img_id_to_captions[imgdir["id"]] = list()
60
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
61
+ if self.use_segmenation:
62
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
63
+ self.segmentation_prefix, pngfilename)
64
+ if given_files is not None:
65
+ if pngfilename in given_files:
66
+ self.labels["image_ids"].append(imgdir["id"])
67
+ else:
68
+ self.labels["image_ids"].append(imgdir["id"])
69
+
70
+ capdirs = self.json_data["annotations"]
71
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
72
+ # there are in average 5 captions per image
73
+ #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
74
+ self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])
75
+
76
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
77
+ if self.split=="validation":
78
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
79
+ else:
80
+ # default option for train is random crop
81
+ if self.crop_type in [None, 'random']:
82
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
83
+ else:
84
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
85
+ self.preprocessor = albumentations.Compose(
86
+ [self.rescaler, self.cropper],
87
+ additional_targets={"segmentation": "image"})
88
+ if force_no_crop:
89
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
90
+ self.preprocessor = albumentations.Compose(
91
+ [self.rescaler],
92
+ additional_targets={"segmentation": "image"})
93
+
94
+ @abstractmethod
95
+ def year(self):
96
+ raise NotImplementedError()
97
+
98
+ def __len__(self):
99
+ return len(self.labels["image_ids"])
100
+
101
+ def preprocess_image(self, image_path, segmentation_path=None):
102
+ image = Image.open(image_path)
103
+ if not image.mode == "RGB":
104
+ image = image.convert("RGB")
105
+ image = np.array(image).astype(np.uint8)
106
+ if segmentation_path:
107
+ segmentation = Image.open(segmentation_path)
108
+ if not self.onehot and not segmentation.mode == "RGB":
109
+ segmentation = segmentation.convert("RGB")
110
+ segmentation = np.array(segmentation).astype(np.uint8)
111
+ if self.onehot:
112
+ assert self.stuffthing
113
+ # stored in caffe format: unlabeled==255. stuff and thing from
114
+ # 0-181. to be compatible with the labels in
115
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
116
+ # we shift stuffthing one to the right and put unlabeled in zero
117
+ # as long as segmentation is uint8 shifting to right handles the
118
+ # latter too
119
+ assert segmentation.dtype == np.uint8
120
+ segmentation = segmentation + 1
121
+
122
+ processed = self.preprocessor(image=image, segmentation=segmentation)
123
+
124
+ image, segmentation = processed["image"], processed["segmentation"]
125
+ else:
126
+ image = self.preprocessor(image=image,)['image']
127
+
128
+ image = (image / 127.5 - 1.0).astype(np.float32)
129
+ if segmentation_path:
130
+ if self.onehot:
131
+ assert segmentation.dtype == np.uint8
132
+ # make it one hot
133
+ n_labels = 183
134
+ flatseg = np.ravel(segmentation)
135
+ onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
136
+ onehot[np.arange(flatseg.size), flatseg] = True
137
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
138
+ segmentation = onehot
139
+ else:
140
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
141
+ return image, segmentation
142
+ else:
143
+ return image
144
+
145
+ def __getitem__(self, i):
146
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
147
+ if self.use_segmenation:
148
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
149
+ image, segmentation = self.preprocess_image(img_path, seg_path)
150
+ else:
151
+ image = self.preprocess_image(img_path)
152
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
153
+ # randomly draw one of all available captions per image
154
+ caption = captions[np.random.randint(0, len(captions))]
155
+ example = {"image": image,
156
+ #"caption": [str(caption[0])],
157
+ "caption": caption,
158
+ "img_path": img_path,
159
+ "filename_": img_path.split(os.sep)[-1]
160
+ }
161
+ if self.use_segmenation:
162
+ example.update({"seg_path": seg_path, 'segmentation': segmentation})
163
+ return example
164
+
165
+
166
+ class CocoImagesAndCaptionsTrain2017(CocoBase):
167
+ """returns a pair of (image, caption)"""
168
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,):
169
+ super().__init__(size=size,
170
+ dataroot="data/coco/train2017",
171
+ datajson="data/coco/annotations/captions_train2017.json",
172
+ onehot_segmentation=onehot_segmentation,
173
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
174
+
175
+ def get_split(self):
176
+ return "train"
177
+
178
+ def year(self):
179
+ return '2017'
180
+
181
+
182
+ class CocoImagesAndCaptionsValidation2017(CocoBase):
183
+ """returns a pair of (image, caption)"""
184
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
185
+ given_files=None):
186
+ super().__init__(size=size,
187
+ dataroot="data/coco/val2017",
188
+ datajson="data/coco/annotations/captions_val2017.json",
189
+ onehot_segmentation=onehot_segmentation,
190
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
191
+ given_files=given_files)
192
+
193
+ def get_split(self):
194
+ return "validation"
195
+
196
+ def year(self):
197
+ return '2017'
198
+
199
+
200
+
201
+ class CocoImagesAndCaptionsTrain2014(CocoBase):
202
+ """returns a pair of (image, caption)"""
203
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'):
204
+ super().__init__(size=size,
205
+ dataroot="data/coco/train2014",
206
+ datajson="data/coco/annotations2014/annotations/captions_train2014.json",
207
+ onehot_segmentation=onehot_segmentation,
208
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
209
+ use_segmentation=False,
210
+ crop_type=crop_type)
211
+
212
+ def get_split(self):
213
+ return "train"
214
+
215
+ def year(self):
216
+ return '2014'
217
+
218
+ class CocoImagesAndCaptionsValidation2014(CocoBase):
219
+ """returns a pair of (image, caption)"""
220
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
221
+ given_files=None,crop_type='center',**kwargs):
222
+ super().__init__(size=size,
223
+ dataroot="data/coco/val2014",
224
+ datajson="data/coco/annotations2014/annotations/captions_val2014.json",
225
+ onehot_segmentation=onehot_segmentation,
226
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
227
+ given_files=given_files,
228
+ use_segmentation=False,
229
+ crop_type=crop_type)
230
+
231
+ def get_split(self):
232
+ return "validation"
233
+
234
+ def year(self):
235
+ return '2014'
236
+
237
+ if __name__ == '__main__':
238
+ with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
239
+ json_data = json.load(json_file)
240
+ capdirs = json_data["annotations"]
241
+ import pudb; pudb.set_trace()
242
+ #d2 = CocoImagesAndCaptionsTrain2014(size=256)
243
+ d2 = CocoImagesAndCaptionsValidation2014(size=256)
244
+ print("constructed dataset.")
245
+ print(f"length of {d2.__class__.__name__}: {len(d2)}")
246
+
247
+ ex2 = d2[0]
248
+ # ex3 = d3[0]
249
+ # print(ex1["image"].shape)
250
+ print(ex2["image"].shape)
251
+ # print(ex3["image"].shape)
252
+ # print(ex1["segmentation"].shape)
253
+ print(ex2["caption"].__class__.__name__)
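A quick standalone sketch of the one-hot segmentation path in CocoBase.preprocess_image above (illustrative only, on a toy 2x2 label map; not part of the uploaded file): the uint8 +1 shift wraps the caffe-style unlabeled value 255 to 0 and moves the stuff/thing classes 0-181 to 1-182 before one-hot encoding.

import numpy as np

# toy caffe-format segmentation map: 255 == unlabeled, classes 0-181
segmentation = np.array([[255, 0], [42, 181]], dtype=np.uint8)

# uint8 arithmetic wraps 255 -> 0, so a single +1 both shifts the classes
# to 1-182 and sends "unlabeled" to 0, as the comment in preprocess_image notes
shifted = segmentation + 1

n_labels = 183
flatseg = np.ravel(shifted)
onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
onehot[np.arange(flatseg.size), flatseg] = True
onehot = onehot.reshape(shifted.shape + (n_labels,)).astype(int)
print(onehot.shape)  # (2, 2, 183)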
One-2-3-45-master 2/ldm/data/dummy.py ADDED
@@ -0,0 +1,34 @@
1
+ import numpy as np
2
+ import random
3
+ import string
4
+ from torch.utils.data import Dataset, Subset
5
+
6
+ class DummyData(Dataset):
7
+ def __init__(self, length, size):
8
+ self.length = length
9
+ self.size = size
10
+
11
+ def __len__(self):
12
+ return self.length
13
+
14
+ def __getitem__(self, i):
15
+ x = np.random.randn(*self.size)
16
+ letters = string.ascii_lowercase
17
+ y = ''.join(random.choice(letters) for _ in range(10))
18
+ return {"jpg": x, "txt": y}
19
+
20
+
21
+ class DummyDataWithEmbeddings(Dataset):
22
+ def __init__(self, length, size, emb_size):
23
+ self.length = length
24
+ self.size = size
25
+ self.emb_size = emb_size
26
+
27
+ def __len__(self):
28
+ return self.length
29
+
30
+ def __getitem__(self, i):
31
+ x = np.random.randn(*self.size)
32
+ y = np.random.randn(*self.emb_size).astype(np.float32)
33
+ return {"jpg": x, "txt": y}
34
+
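For a quick smoke test of a training loop, the dummy datasets batch cleanly with the default PyTorch collate (sketch below; the length and sizes are arbitrary example values, and it assumes the repo root is on PYTHONPATH).

from torch.utils.data import DataLoader
from ldm.data.dummy import DummyDataWithEmbeddings

ds = DummyDataWithEmbeddings(length=8, size=(64, 64, 3), emb_size=(77, 768))
loader = DataLoader(ds, batch_size=4)
batch = next(iter(loader))
print(batch["jpg"].shape, batch["txt"].shape)  # [4, 64, 64, 3] and [4, 77, 768]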
One-2-3-45-master 2/ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
1
+ import os, yaml, pickle, shutil, tarfile, glob
2
+ import cv2
3
+ import albumentations
4
+ import PIL
5
+ import numpy as np
6
+ import torchvision.transforms.functional as TF
7
+ from omegaconf import OmegaConf
8
+ from functools import partial
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, Subset
12
+
13
+ import taming.data.utils as tdu
14
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
+ from taming.data.imagenet import ImagePaths
16
+
17
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
+
19
+
20
+ def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
+ with open(path_to_yaml) as f:
22
+ di2s = yaml.safe_load(f)  # yaml.load needs an explicit Loader on PyYAML >= 5; safe_load suffices for this index file
23
+ return dict((v,k) for k,v in di2s.items())
24
+
25
+
26
+ class ImageNetBase(Dataset):
27
+ def __init__(self, config=None):
28
+ self.config = config or OmegaConf.create()
29
+ if not type(self.config)==dict:
30
+ self.config = OmegaConf.to_container(self.config)
31
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
+ self._prepare()
34
+ self._prepare_synset_to_human()
35
+ self._prepare_idx_to_synset()
36
+ self._prepare_human_to_integer_label()
37
+ self._load()
38
+
39
+ def __len__(self):
40
+ return len(self.data)
41
+
42
+ def __getitem__(self, i):
43
+ return self.data[i]
44
+
45
+ def _prepare(self):
46
+ raise NotImplementedError()
47
+
48
+ def _filter_relpaths(self, relpaths):
49
+ ignore = set([
50
+ "n06596364_9591.JPEG",
51
+ ])
52
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
+ if "sub_indices" in self.config:
54
+ indices = str_to_indices(self.config["sub_indices"])
55
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
+ files = []
58
+ for rpath in relpaths:
59
+ syn = rpath.split("/")[0]
60
+ if syn in synsets:
61
+ files.append(rpath)
62
+ return files
63
+ else:
64
+ return relpaths
65
+
66
+ def _prepare_synset_to_human(self):
67
+ SIZE = 2655750
68
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
70
+ if (not os.path.exists(self.human_dict) or
71
+ not os.path.getsize(self.human_dict)==SIZE):
72
+ download(URL, self.human_dict)
73
+
74
+ def _prepare_idx_to_synset(self):
75
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
+ if (not os.path.exists(self.idx2syn)):
78
+ download(URL, self.idx2syn)
79
+
80
+ def _prepare_human_to_integer_label(self):
81
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
+ if (not os.path.exists(self.human2integer)):
84
+ download(URL, self.human2integer)
85
+ with open(self.human2integer, "r") as f:
86
+ lines = f.read().splitlines()
87
+ assert len(lines) == 1000
88
+ self.human2integer_dict = dict()
89
+ for line in lines:
90
+ value, key = line.split(":")
91
+ self.human2integer_dict[key] = int(value)
92
+
93
+ def _load(self):
94
+ with open(self.txt_filelist, "r") as f:
95
+ self.relpaths = f.read().splitlines()
96
+ l1 = len(self.relpaths)
97
+ self.relpaths = self._filter_relpaths(self.relpaths)
98
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
+
100
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
101
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
+
103
+ unique_synsets = np.unique(self.synsets)
104
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
+ if not self.keep_orig_class_label:
106
+ self.class_labels = [class_dict[s] for s in self.synsets]
107
+ else:
108
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
+
110
+ with open(self.human_dict, "r") as f:
111
+ human_dict = f.read().splitlines()
112
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
+
114
+ self.human_labels = [human_dict[s] for s in self.synsets]
115
+
116
+ labels = {
117
+ "relpath": np.array(self.relpaths),
118
+ "synsets": np.array(self.synsets),
119
+ "class_label": np.array(self.class_labels),
120
+ "human_label": np.array(self.human_labels),
121
+ }
122
+
123
+ if self.process_images:
124
+ self.size = retrieve(self.config, "size", default=256)
125
+ self.data = ImagePaths(self.abspaths,
126
+ labels=labels,
127
+ size=self.size,
128
+ random_crop=self.random_crop,
129
+ )
130
+ else:
131
+ self.data = self.abspaths
132
+
133
+
134
+ class ImageNetTrain(ImageNetBase):
135
+ NAME = "ILSVRC2012_train"
136
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
+ FILES = [
139
+ "ILSVRC2012_img_train.tar",
140
+ ]
141
+ SIZES = [
142
+ 147897477120,
143
+ ]
144
+
145
+ def __init__(self, process_images=True, data_root=None, **kwargs):
146
+ self.process_images = process_images
147
+ self.data_root = data_root
148
+ super().__init__(**kwargs)
149
+
150
+ def _prepare(self):
151
+ if self.data_root:
152
+ self.root = os.path.join(self.data_root, self.NAME)
153
+ else:
154
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
+
157
+ self.datadir = os.path.join(self.root, "data")
158
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
+ self.expected_length = 1281167
160
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
+ default=True)
162
+ if not tdu.is_prepared(self.root):
163
+ # prep
164
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
+
166
+ datadir = self.datadir
167
+ if not os.path.exists(datadir):
168
+ path = os.path.join(self.root, self.FILES[0])
169
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
+ import academictorrents as at
171
+ atpath = at.get(self.AT_HASH, datastore=self.root)
172
+ assert atpath == path
173
+
174
+ print("Extracting {} to {}".format(path, datadir))
175
+ os.makedirs(datadir, exist_ok=True)
176
+ with tarfile.open(path, "r:") as tar:
177
+ tar.extractall(path=datadir)
178
+
179
+ print("Extracting sub-tars.")
180
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
+ for subpath in tqdm(subpaths):
182
+ subdir = subpath[:-len(".tar")]
183
+ os.makedirs(subdir, exist_ok=True)
184
+ with tarfile.open(subpath, "r:") as tar:
185
+ tar.extractall(path=subdir)
186
+
187
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
+ filelist = sorted(filelist)
190
+ filelist = "\n".join(filelist)+"\n"
191
+ with open(self.txt_filelist, "w") as f:
192
+ f.write(filelist)
193
+
194
+ tdu.mark_prepared(self.root)
195
+
196
+
197
+ class ImageNetValidation(ImageNetBase):
198
+ NAME = "ILSVRC2012_validation"
199
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
+ FILES = [
203
+ "ILSVRC2012_img_val.tar",
204
+ "validation_synset.txt",
205
+ ]
206
+ SIZES = [
207
+ 6744924160,
208
+ 1950000,
209
+ ]
210
+
211
+ def __init__(self, process_images=True, data_root=None, **kwargs):
212
+ self.data_root = data_root
213
+ self.process_images = process_images
214
+ super().__init__(**kwargs)
215
+
216
+ def _prepare(self):
217
+ if self.data_root:
218
+ self.root = os.path.join(self.data_root, self.NAME)
219
+ else:
220
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
+ self.datadir = os.path.join(self.root, "data")
223
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
+ self.expected_length = 50000
225
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
+ default=False)
227
+ if not tdu.is_prepared(self.root):
228
+ # prep
229
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
+
231
+ datadir = self.datadir
232
+ if not os.path.exists(datadir):
233
+ path = os.path.join(self.root, self.FILES[0])
234
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
+ import academictorrents as at
236
+ atpath = at.get(self.AT_HASH, datastore=self.root)
237
+ assert atpath == path
238
+
239
+ print("Extracting {} to {}".format(path, datadir))
240
+ os.makedirs(datadir, exist_ok=True)
241
+ with tarfile.open(path, "r:") as tar:
242
+ tar.extractall(path=datadir)
243
+
244
+ vspath = os.path.join(self.root, self.FILES[1])
245
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
+ download(self.VS_URL, vspath)
247
+
248
+ with open(vspath, "r") as f:
249
+ synset_dict = f.read().splitlines()
250
+ synset_dict = dict(line.split() for line in synset_dict)
251
+
252
+ print("Reorganizing into synset folders")
253
+ synsets = np.unique(list(synset_dict.values()))
254
+ for s in synsets:
255
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
+ for k, v in synset_dict.items():
257
+ src = os.path.join(datadir, k)
258
+ dst = os.path.join(datadir, v)
259
+ shutil.move(src, dst)
260
+
261
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
+ filelist = sorted(filelist)
264
+ filelist = "\n".join(filelist)+"\n"
265
+ with open(self.txt_filelist, "w") as f:
266
+ f.write(filelist)
267
+
268
+ tdu.mark_prepared(self.root)
269
+
270
+
271
+
272
+ class ImageNetSR(Dataset):
273
+ def __init__(self, size=None,
274
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
+ random_crop=True):
276
+ """
277
+ Imagenet Superresolution Dataloader
278
+ Performs the following ops in order:
279
+ 1. crops a crop of size s from image either as random or center crop
280
+ 2. resizes crop to size with cv2.area_interpolation
281
+ 3. degrades resized crop with degradation_fn
282
+
283
+ :param size: resizing to size after cropping
284
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
+ :param downscale_f: Low Resolution Downsample factor
286
+ :param min_crop_f: determines crop size s,
287
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
+ :param max_crop_f: ""
289
+ :param data_root:
290
+ :param random_crop:
291
+ """
292
+ self.base = self.get_base()
293
+ assert size
294
+ assert (size / downscale_f).is_integer()
295
+ self.size = size
296
+ self.LR_size = int(size / downscale_f)
297
+ self.min_crop_f = min_crop_f
298
+ self.max_crop_f = max_crop_f
299
+ assert(max_crop_f <= 1.)
300
+ self.center_crop = not random_crop
301
+
302
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
+
304
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
+
306
+ if degradation == "bsrgan":
307
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
+
309
+ elif degradation == "bsrgan_light":
310
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
+
312
+ else:
313
+ interpolation_fn = {
314
+ "cv_nearest": cv2.INTER_NEAREST,
315
+ "cv_bilinear": cv2.INTER_LINEAR,
316
+ "cv_bicubic": cv2.INTER_CUBIC,
317
+ "cv_area": cv2.INTER_AREA,
318
+ "cv_lanczos": cv2.INTER_LANCZOS4,
319
+ "pil_nearest": PIL.Image.NEAREST,
320
+ "pil_bilinear": PIL.Image.BILINEAR,
321
+ "pil_bicubic": PIL.Image.BICUBIC,
322
+ "pil_box": PIL.Image.BOX,
323
+ "pil_hamming": PIL.Image.HAMMING,
324
+ "pil_lanczos": PIL.Image.LANCZOS,
325
+ }[degradation]
326
+
327
+ self.pil_interpolation = degradation.startswith("pil_")
328
+
329
+ if self.pil_interpolation:
330
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
+
332
+ else:
333
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
+ interpolation=interpolation_fn)
335
+
336
+ def __len__(self):
337
+ return len(self.base)
338
+
339
+ def __getitem__(self, i):
340
+ example = self.base[i]
341
+ image = Image.open(example["file_path_"])
342
+
343
+ if not image.mode == "RGB":
344
+ image = image.convert("RGB")
345
+
346
+ image = np.array(image).astype(np.uint8)
347
+
348
+ min_side_len = min(image.shape[:2])
349
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
+ crop_side_len = int(crop_side_len)
351
+
352
+ if self.center_crop:
353
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
+
355
+ else:
356
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
+
358
+ image = self.cropper(image=image)["image"]
359
+ image = self.image_rescaler(image=image)["image"]
360
+
361
+ if self.pil_interpolation:
362
+ image_pil = PIL.Image.fromarray(image)
363
+ LR_image = self.degradation_process(image_pil)
364
+ LR_image = np.array(LR_image).astype(np.uint8)
365
+
366
+ else:
367
+ LR_image = self.degradation_process(image=image)["image"]
368
+
369
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
+ example["caption"] = example["human_label"] # dummy caption
372
+ return example
373
+
374
+
375
+ class ImageNetSRTrain(ImageNetSR):
376
+ def __init__(self, **kwargs):
377
+ super().__init__(**kwargs)
378
+
379
+ def get_base(self):
380
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
+ indices = pickle.load(f)
382
+ dset = ImageNetTrain(process_images=False,)
383
+ return Subset(dset, indices)
384
+
385
+
386
+ class ImageNetSRValidation(ImageNetSR):
387
+ def __init__(self, **kwargs):
388
+ super().__init__(**kwargs)
389
+
390
+ def get_base(self):
391
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
+ indices = pickle.load(f)
393
+ dset = ImageNetValidation(process_images=False,)
394
+ return Subset(dset, indices)
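A minimal sketch of the crop / rescale / degrade order that ImageNetSR.__getitem__ applies, run on a random array so it needs no ImageNet download; cv2 area interpolation stands in here for the bsrgan degradation functions, and the crop fraction 0.75 is just an example value.

import albumentations
import cv2
import numpy as np

size, downscale_f = 256, 4
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

# 1. crop (center crop shown; training uses RandomCrop), 2. rescale, 3. degrade
crop_side_len = int(min(image.shape[:2]) * 0.75)  # c drawn from (min_crop_f, max_crop_f)
cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
degrader = albumentations.SmallestMaxSize(max_size=size // downscale_f, interpolation=cv2.INTER_AREA)

image = rescaler(image=cropper(image=image)["image"])["image"]
LR_image = degrader(image=image)["image"]
print(image.shape, LR_image.shape)  # (256, 256, 3) (64, 64, 3)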
One-2-3-45-master 2/ldm/data/inpainting/__init__.py ADDED
File without changes
One-2-3-45-master 2/ldm/data/inpainting/synthetic_mask.py ADDED
@@ -0,0 +1,166 @@
1
+ from PIL import Image, ImageDraw
2
+ import numpy as np
3
+
4
+ settings = {
5
+ "256narrow": {
6
+ "p_irr": 1,
7
+ "min_n_irr": 4,
8
+ "max_n_irr": 50,
9
+ "max_l_irr": 40,
10
+ "max_w_irr": 10,
11
+ "min_n_box": None,
12
+ "max_n_box": None,
13
+ "min_s_box": None,
14
+ "max_s_box": None,
15
+ "marg": None,
16
+ },
17
+ "256train": {
18
+ "p_irr": 0.5,
19
+ "min_n_irr": 1,
20
+ "max_n_irr": 5,
21
+ "max_l_irr": 200,
22
+ "max_w_irr": 100,
23
+ "min_n_box": 1,
24
+ "max_n_box": 4,
25
+ "min_s_box": 30,
26
+ "max_s_box": 150,
27
+ "marg": 10,
28
+ },
29
+ "512train": { # TODO: experimental
30
+ "p_irr": 0.5,
31
+ "min_n_irr": 1,
32
+ "max_n_irr": 5,
33
+ "max_l_irr": 450,
34
+ "max_w_irr": 250,
35
+ "min_n_box": 1,
36
+ "max_n_box": 4,
37
+ "min_s_box": 30,
38
+ "max_s_box": 300,
39
+ "marg": 10,
40
+ },
41
+ "512train-large": { # TODO: experimental
42
+ "p_irr": 0.5,
43
+ "min_n_irr": 1,
44
+ "max_n_irr": 5,
45
+ "max_l_irr": 450,
46
+ "max_w_irr": 400,
47
+ "min_n_box": 1,
48
+ "max_n_box": 4,
49
+ "min_s_box": 75,
50
+ "max_s_box": 450,
51
+ "marg": 10,
52
+ },
53
+ }
54
+
55
+
56
+ def gen_segment_mask(mask, start, end, brush_width):
57
+ mask = mask > 0
58
+ mask = (255 * mask).astype(np.uint8)
59
+ mask = Image.fromarray(mask)
60
+ draw = ImageDraw.Draw(mask)
61
+ draw.line([start, end], fill=255, width=brush_width, joint="curve")
62
+ mask = np.array(mask) / 255
63
+ return mask
64
+
65
+
66
+ def gen_box_mask(mask, masked):
67
+ x_0, y_0, w, h = masked
68
+ mask[y_0:y_0 + h, x_0:x_0 + w] = 1
69
+ return mask
70
+
71
+
72
+ def gen_round_mask(mask, masked, radius):
73
+ x_0, y_0, w, h = masked
74
+ xy = [(x_0, y_0), (x_0 + w, y_0 + w)]
75
+
76
+ mask = mask > 0
77
+ mask = (255 * mask).astype(np.uint8)
78
+ mask = Image.fromarray(mask)
79
+ draw = ImageDraw.Draw(mask)
80
+ draw.rounded_rectangle(xy, radius=radius, fill=255)
81
+ mask = np.array(mask) / 255
82
+ return mask
83
+
84
+
85
+ def gen_large_mask(prng, img_h, img_w,
86
+ marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr,
87
+ min_n_box, max_n_box, min_s_box, max_s_box):
88
+ """
89
+ img_h: int, an image height
90
+ img_w: int, an image width
91
+ marg: int, a margin for a box starting coordinate
92
+ p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask
93
+
94
+ min_n_irr: int, min number of segments
95
+ max_n_irr: int, max number of segments
96
+ max_l_irr: max length of a segment in polygonal chain
97
+ max_w_irr: max width of a segment in polygonal chain
98
+
99
+ min_n_box: int, min bound for the number of box primitives
100
+ max_n_box: int, max bound for the number of box primitives
101
+ min_s_box: int, min length of a box side
102
+ max_s_box: int, max length of a box side
103
+ """
104
+
105
+ mask = np.zeros((img_h, img_w))
106
+ uniform = prng.randint
107
+
108
+ if np.random.uniform(0, 1) < p_irr: # generate polygonal chain
109
+ n = uniform(min_n_irr, max_n_irr) # sample number of segments
110
+
111
+ for _ in range(n):
112
+ y = uniform(0, img_h) # sample a starting point
113
+ x = uniform(0, img_w)
114
+
115
+ a = uniform(0, 360) # sample angle
116
+ l = uniform(10, max_l_irr) # sample segment length
117
+ w = uniform(5, max_w_irr) # sample a segment width
118
+
119
+ # draw segment starting from (x,y) to (x_,y_) using brush of width w
120
+ x_ = x + l * np.sin(a)
121
+ y_ = y + l * np.cos(a)
122
+
123
+ mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w)
124
+ x, y = x_, y_
125
+ else: # generate Box masks
126
+ n = uniform(min_n_box, max_n_box) # sample number of rectangles
127
+
128
+ for _ in range(n):
129
+ h = uniform(min_s_box, max_s_box) # sample box shape
130
+ w = uniform(min_s_box, max_s_box)
131
+
132
+ x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box
133
+ y_0 = uniform(marg, img_h - marg - h)
134
+
135
+ if np.random.uniform(0, 1) < 0.5:
136
+ mask = gen_box_mask(mask, masked=(x_0, y_0, w, h))
137
+ else:
138
+ r = uniform(0, 60) # sample radius
139
+ mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r)
140
+ return mask
141
+
142
+
143
+ make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
144
+ make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
145
+ make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
146
+ make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
147
+
148
+
149
+ MASK_MODES = {
150
+ "256train": make_lama_mask,
151
+ "256narrow": make_narrow_lama_mask,
152
+ "512train": make_512_lama_mask,
153
+ "512train-large": make_512_lama_mask_large
154
+ }
155
+
156
+ if __name__ == "__main__":
157
+ import sys
158
+
159
+ out = sys.argv[1]
160
+
161
+ prng = np.random.RandomState(1)
162
+ kwargs = settings["256train"]
163
+ mask = gen_large_mask(prng, 256, 256, **kwargs)
164
+ mask = (255 * mask).astype(np.uint8)
165
+ mask = Image.fromarray(mask)
166
+ mask.save(out)
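The mask generators above are easy to exercise on their own; a small sketch (the prng argument seeds the mask geometry, but the irregular-vs-box branch choice uses the global np.random state, so coverage varies between runs):

import numpy as np
from ldm.data.inpainting.synthetic_mask import MASK_MODES

prng = np.random.RandomState(0)
mask = MASK_MODES["256train"](prng, 256, 256)  # same as make_lama_mask
print(mask.shape, f"{mask.mean():.1%} of pixels masked")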
One-2-3-45-master 2/ldm/data/laion.py ADDED
@@ -0,0 +1,537 @@
1
+ import webdataset as wds
2
+ import kornia
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+ import torchvision
7
+ from PIL import Image
8
+ import glob
9
+ import random
10
+ import numpy as np
11
+ import pytorch_lightning as pl
12
+ from tqdm import tqdm
13
+ from omegaconf import OmegaConf
14
+ from einops import rearrange
15
+ import torch
16
+ from webdataset.handlers import warn_and_continue
17
+
18
+
19
+ from ldm.util import instantiate_from_config
20
+ from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
21
+ from ldm.data.base import PRNGMixin
22
+
23
+
24
+ class DataWithWings(torch.utils.data.IterableDataset):
25
+ def __init__(self, min_size, transform=None, target_transform=None):
26
+ self.min_size = min_size
27
+ self.transform = transform if transform is not None else torch.nn.Identity()  # this module imports torch but never nn directly
28
+ self.target_transform = target_transform if target_transform is not None else torch.nn.Identity()
29
+ self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
30
+ self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
31
+ self.pwatermark_threshold = 0.8
32
+ self.punsafe_threshold = 0.5
33
+ self.aesthetic_threshold = 5.
34
+ self.total_samples = 0
35
+ self.samples = 0
36
+ location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'
37
+
38
+ self.inner_dataset = wds.DataPipeline(
39
+ wds.ResampledShards(location),
40
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
41
+ wds.shuffle(1000, handler=wds.warn_and_continue),
42
+ wds.decode('pilrgb', handler=wds.warn_and_continue),
43
+ wds.map(self._add_tags, handler=wds.ignore_and_continue),
44
+ wds.select(self._filter_predicate),
45
+ wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
46
+ wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
47
+ )
48
+
49
+ @staticmethod
50
+ def _compute_hash(url, text):
51
+ if url is None:
52
+ url = ''
53
+ if text is None:
54
+ text = ''
55
+ total = (url + text).encode('utf-8')
56
+ return mmh3.hash64(total)[0]
57
+
58
+ def _add_tags(self, x):
59
+ hsh = self._compute_hash(x['json']['url'], x['txt'])
60
+ pwatermark, punsafe = self.kv[hsh]
61
+ aesthetic = self.kv_aesthetic[hsh][0]
62
+ return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}
63
+
64
+ def _punsafe_to_class(self, punsafe):
65
+ return torch.tensor(punsafe >= self.punsafe_threshold).long()
66
+
67
+ def _filter_predicate(self, x):
68
+ try:
69
+ return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
70
+ except:
71
+ return False
72
+
73
+ def __iter__(self):
74
+ return iter(self.inner_dataset)
75
+
76
+
77
+ def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
78
+ """Take a list of samples (as dictionary) and create a batch, preserving the keys.
79
+ If `tensors` is True, `ndarray` objects are combined into
80
+ tensor batches.
81
+ :param dict samples: list of samples
82
+ :param bool tensors: whether to turn lists of ndarrays into a single ndarray
83
+ :returns: single sample consisting of a batch
84
+ :rtype: dict
85
+ """
86
+ keys = set.intersection(*[set(sample.keys()) for sample in samples])
87
+ batched = {key: [] for key in keys}
88
+
89
+ for s in samples:
90
+ [batched[key].append(s[key]) for key in batched]
91
+
92
+ result = {}
93
+ for key in batched:
94
+ if isinstance(batched[key][0], (int, float)):
95
+ if combine_scalars:
96
+ result[key] = np.array(list(batched[key]))
97
+ elif isinstance(batched[key][0], torch.Tensor):
98
+ if combine_tensors:
99
+ result[key] = torch.stack(list(batched[key]))
100
+ elif isinstance(batched[key][0], np.ndarray):
101
+ if combine_tensors:
102
+ result[key] = np.array(list(batched[key]))
103
+ else:
104
+ result[key] = list(batched[key])
105
+ return result
106
+
107
+
108
+ class WebDataModuleFromConfig(pl.LightningDataModule):
109
+ def __init__(self, tar_base, batch_size, train=None, validation=None,
110
+ test=None, num_workers=4, multinode=True, min_size=None,
111
+ max_pwatermark=1.0,
112
+ **kwargs):
113
+ super().__init__(self)
114
+ print(f'Setting tar base to {tar_base}')
115
+ self.tar_base = tar_base
116
+ self.batch_size = batch_size
117
+ self.num_workers = num_workers
118
+ self.train = train
119
+ self.validation = validation
120
+ self.test = test
121
+ self.multinode = multinode
122
+ self.min_size = min_size # filter out very small images
123
+ self.max_pwatermark = max_pwatermark # filter out watermarked images
124
+
125
+ def make_loader(self, dataset_config, train=True):
126
+ if 'image_transforms' in dataset_config:
127
+ image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
128
+ else:
129
+ image_transforms = []
130
+
131
+ image_transforms.extend([torchvision.transforms.ToTensor(),
132
+ torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
133
+ image_transforms = torchvision.transforms.Compose(image_transforms)
134
+
135
+ if 'transforms' in dataset_config:
136
+ transforms_config = OmegaConf.to_container(dataset_config.transforms)
137
+ else:
138
+ transforms_config = dict()
139
+
140
+ transform_dict = {dkey: load_partial_from_config(transforms_config[dkey])
141
+ if transforms_config[dkey] != 'identity' else identity
142
+ for dkey in transforms_config}
143
+ img_key = dataset_config.get('image_key', 'jpeg')
144
+ transform_dict.update({img_key: image_transforms})
145
+
146
+ if 'postprocess' in dataset_config:
147
+ postprocess = instantiate_from_config(dataset_config['postprocess'])
148
+ else:
149
+ postprocess = None
150
+
151
+ shuffle = dataset_config.get('shuffle', 0)
152
+ shardshuffle = shuffle > 0
153
+
154
+ nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
155
+
156
+ if self.tar_base == "__improvedaesthetic__":
157
+ print("## Warning, loading the same improved aesthetic dataset "
158
+ "for all splits and ignoring shards parameter.")
159
+ tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
160
+ else:
161
+ tars = os.path.join(self.tar_base, dataset_config.shards)
162
+
163
+ dset = wds.WebDataset(
164
+ tars,
165
+ nodesplitter=nodesplitter,
166
+ shardshuffle=shardshuffle,
167
+ handler=wds.warn_and_continue).repeat().shuffle(shuffle)
168
+ print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
169
+
170
+ dset = (dset
171
+ .select(self.filter_keys)
172
+ .decode('pil', handler=wds.warn_and_continue)
173
+ .select(self.filter_size)
174
+ .map_dict(**transform_dict, handler=wds.warn_and_continue)
175
+ )
176
+ if postprocess is not None:
177
+ dset = dset.map(postprocess)
178
+ dset = (dset
179
+ .batched(self.batch_size, partial=False,
180
+ collation_fn=dict_collation_fn)
181
+ )
182
+
183
+ loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
184
+ num_workers=self.num_workers)
185
+
186
+ return loader
187
+
188
+ def filter_size(self, x):
189
+ try:
190
+ valid = True
191
+ if self.min_size is not None and self.min_size > 1:
192
+ try:
193
+ valid = valid and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
194
+ except Exception:
195
+ valid = False
196
+ if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
197
+ try:
198
+ valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
199
+ except Exception:
200
+ valid = False
201
+ return valid
202
+ except Exception:
203
+ return False
204
+
205
+ def filter_keys(self, x):
206
+ try:
207
+ return ("jpg" in x) and ("txt" in x)
208
+ except Exception:
209
+ return False
210
+
211
+ def train_dataloader(self):
212
+ return self.make_loader(self.train)
213
+
214
+ def val_dataloader(self):
215
+ return self.make_loader(self.validation, train=False)
216
+
217
+ def test_dataloader(self):
218
+ return self.make_loader(self.test, train=False)
219
+
220
+
221
+ from ldm.modules.image_degradation import degradation_fn_bsr_light
222
+ import cv2
223
+
224
+ class AddLR(object):
225
+ def __init__(self, factor, output_size, initial_size=None, image_key="jpg"):
226
+ self.factor = factor
227
+ self.output_size = output_size
228
+ self.image_key = image_key
229
+ self.initial_size = initial_size
230
+
231
+ def pt2np(self, x):
232
+ x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
233
+ return x
234
+
235
+ def np2pt(self, x):
236
+ x = torch.from_numpy(x)/127.5-1.0
237
+ return x
238
+
239
+ def __call__(self, sample):
240
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
241
+ x = self.pt2np(sample[self.image_key])
242
+ if self.initial_size is not None:
243
+ x = cv2.resize(x, (self.initial_size, self.initial_size), interpolation=2)
244
+ x = degradation_fn_bsr_light(x, sf=self.factor)['image']
245
+ x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2)
246
+ x = self.np2pt(x)
247
+ sample['lr'] = x
248
+ return sample
249
+
250
+ class AddBW(object):
251
+ def __init__(self, image_key="jpg"):
252
+ self.image_key = image_key
253
+
254
+ def pt2np(self, x):
255
+ x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
256
+ return x
257
+
258
+ def np2pt(self, x):
259
+ x = torch.from_numpy(x)/127.5-1.0
260
+ return x
261
+
262
+ def __call__(self, sample):
263
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
264
+ x = sample[self.image_key]
265
+ w = torch.rand(3, device=x.device)
266
+ w /= w.sum()
267
+ out = torch.einsum('hwc,c->hw', x, w)
268
+
269
+ # Keep as 3ch so we can pass to encoder, also we might want to add hints
270
+ sample['lr'] = out.unsqueeze(-1).tile(1,1,3)
271
+ return sample
272
+
273
+ class AddMask(PRNGMixin):
274
+ def __init__(self, mode="512train", p_drop=0.):
275
+ super().__init__()
276
+ assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
277
+ self.make_mask = MASK_MODES[mode]
278
+ self.p_drop = p_drop
279
+
280
+ def __call__(self, sample):
281
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
282
+ x = sample['jpg']
283
+ mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
284
+ if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
285
+ mask = np.ones_like(mask)
286
+ mask[mask < 0.5] = 0
287
+ mask[mask > 0.5] = 1
288
+ mask = torch.from_numpy(mask[..., None])
289
+ sample['mask'] = mask
290
+ sample['masked_image'] = x * (mask < 0.5)
291
+ return sample
292
+
293
+
294
+ class AddEdge(PRNGMixin):
295
+ def __init__(self, mode="512train", mask_edges=True):
296
+ super().__init__()
297
+ assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
298
+ self.make_mask = MASK_MODES[mode]
299
+ self.n_down_choices = [0]
300
+ self.sigma_choices = [1, 2]
301
+ self.mask_edges = mask_edges
302
+
303
+ @torch.no_grad()
304
+ def __call__(self, sample):
305
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
306
+ x = sample['jpg']
307
+
308
+ mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
309
+ mask[mask < 0.5] = 0
310
+ mask[mask > 0.5] = 1
311
+ mask = torch.from_numpy(mask[..., None])
312
+ sample['mask'] = mask
313
+
314
+ n_down_idx = self.prng.choice(len(self.n_down_choices))
315
+ sigma_idx = self.prng.choice(len(self.sigma_choices))
316
+
317
+ n_choices = len(self.n_down_choices)*len(self.sigma_choices)
318
+ raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx),
319
+ (len(self.n_down_choices), len(self.sigma_choices)))
320
+ normalized_idx = raveled_idx/max(1, n_choices-1)
321
+
322
+ n_down = self.n_down_choices[n_down_idx]
323
+ sigma = self.sigma_choices[sigma_idx]
324
+
325
+ kernel_size = 4*sigma+1
326
+ kernel_size = (kernel_size, kernel_size)
327
+ sigma = (sigma, sigma)
328
+ canny = kornia.filters.Canny(
329
+ low_threshold=0.1,
330
+ high_threshold=0.2,
331
+ kernel_size=kernel_size,
332
+ sigma=sigma,
333
+ hysteresis=True,
334
+ )
335
+ y = (x+1.0)/2.0 # in 01
336
+ y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous()
337
+
338
+ # down
339
+ for i_down in range(n_down):
340
+ size = min(y.shape[-2], y.shape[-1])//2
341
+ y = kornia.geometry.transform.resize(y, size, antialias=True)
342
+
343
+ # edge
344
+ _, y = canny(y)
345
+
346
+ if n_down > 0:
347
+ size = x.shape[0], x.shape[1]
348
+ y = kornia.geometry.transform.resize(y, size, interpolation="nearest")
349
+
350
+ y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous()
351
+ y = y*2.0-1.0
352
+
353
+ if self.mask_edges:
354
+ sample['masked_image'] = y * (mask < 0.5)
355
+ else:
356
+ sample['masked_image'] = y
357
+ sample['mask'] = torch.zeros_like(sample['mask'])
358
+
359
+ # concat normalized idx
360
+ sample['smoothing_strength'] = torch.ones_like(sample['mask'])*normalized_idx
361
+
362
+ return sample
363
+
364
+
365
+ def example00():
366
+ url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
367
+ dataset = wds.WebDataset(url)
368
+ example = next(iter(dataset))
369
+ for k in example:
370
+ print(k, type(example[k]))
371
+
372
+ print(example["__key__"])
373
+ for k in ["json", "txt"]:
374
+ print(example[k].decode())
375
+
376
+ image = Image.open(io.BytesIO(example["jpg"]))
377
+ outdir = "tmp"
378
+ os.makedirs(outdir, exist_ok=True)
379
+ image.save(os.path.join(outdir, example["__key__"] + ".png"))
380
+
381
+
382
+ def load_example(example):
383
+ return {
384
+ "key": example["__key__"],
385
+ "image": Image.open(io.BytesIO(example["jpg"])),
386
+ "text": example["txt"].decode(),
387
+ }
388
+
389
+
390
+ for i, example in tqdm(enumerate(dataset)):
391
+ ex = load_example(example)
392
+ print(ex["image"].size, ex["text"])
393
+ if i >= 100:
394
+ break
395
+
396
+
397
+ def example01():
398
+ # the first laion shards contain ~10k examples each
399
+ url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -"
400
+
401
+ batch_size = 3
402
+ shuffle_buffer = 10000
403
+ dset = wds.WebDataset(
404
+ url,
405
+ nodesplitter=wds.shardlists.split_by_node,
406
+ shardshuffle=True,
407
+ )
408
+ dset = (dset
409
+ .shuffle(shuffle_buffer, initial=shuffle_buffer)
410
+ .decode('pil', handler=warn_and_continue)
411
+ .batched(batch_size, partial=False,
412
+ collation_fn=dict_collation_fn)
413
+ )
414
+
415
+ num_workers = 2
416
+ loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers)
417
+
418
+ batch_sizes = list()
419
+ keys_per_epoch = list()
420
+ for epoch in range(5):
421
+ keys = list()
422
+ for batch in tqdm(loader):
423
+ batch_sizes.append(len(batch["__key__"]))
424
+ keys.append(batch["__key__"])
425
+
426
+ for bs in batch_sizes:
427
+ assert bs==batch_size
428
+ print(f"{len(batch_sizes)} batches of size {batch_size}.")
429
+ batch_sizes = list()
430
+
431
+ keys_per_epoch.append(keys)
432
+ for i_batch in [0, 1, -1]:
433
+ print(f"Batch {i_batch} of epoch {epoch}:")
434
+ print(keys[i_batch])
435
+ print("next epoch.")
436
+
437
+
438
+ def example02():
439
+ from omegaconf import OmegaConf
440
+ from torch.utils.data.distributed import DistributedSampler
441
+ from torch.utils.data import IterableDataset
442
+ from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
443
+ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
444
+
445
+ #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
446
+ #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
447
+ config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
448
+ datamod = WebDataModuleFromConfig(**config["data"]["params"])
449
+ dataloader = datamod.train_dataloader()
450
+
451
+ for batch in dataloader:
452
+ print(batch.keys())
453
+ print(batch["jpg"].shape)
454
+ break
455
+
456
+
457
+ def example03():
458
+ # improved aesthetics
459
+ tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
460
+ dataset = wds.WebDataset(tars)
461
+
462
+ def filter_keys(x):
463
+ try:
464
+ return ("jpg" in x) and ("txt" in x)
465
+ except Exception:
466
+ return False
467
+
468
+ def filter_size(x):
469
+ try:
470
+ return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
471
+ except Exception:
472
+ return False
473
+
474
+ def filter_watermark(x):
475
+ try:
476
+ return x['json']['pwatermark'] < 0.5
477
+ except Exception:
478
+ return False
479
+
480
+ dataset = (dataset
481
+ .select(filter_keys)
482
+ .decode('pil', handler=wds.warn_and_continue))
483
+ n_save = 20
484
+ n_total = 0
485
+ n_large = 0
486
+ n_large_nowm = 0
487
+ for i, example in enumerate(dataset):
488
+ n_total += 1
489
+ if filter_size(example):
490
+ n_large += 1
491
+ if filter_watermark(example):
492
+ n_large_nowm += 1
493
+ if n_large_nowm < n_save+1:
494
+ image = example["jpg"]
495
+ image.save(os.path.join("tmp", f"{n_large_nowm-1:06}.png"))
496
+
497
+ if i%500 == 0:
498
+ print(i)
499
+ print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%")
500
+ if n_large > 0:
501
+ print(f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm/n_large*100:.2f}%")
502
+
503
+
504
+
505
+ def example04():
506
+ # improved aesthetics
507
+ for i_shard in range(60208)[::-1]:
508
+ print(i_shard)
509
+ tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard)
510
+ dataset = wds.WebDataset(tars)
511
+
512
+ def filter_keys(x):
513
+ try:
514
+ return ("jpg" in x) and ("txt" in x)
515
+ except Exception:
516
+ return False
517
+
518
+ def filter_size(x):
519
+ try:
520
+ return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
521
+ except Exception:
522
+ return False
523
+
524
+ dataset = (dataset
525
+ .select(filter_keys)
526
+ .decode('pil', handler=wds.warn_and_continue))
527
+ try:
528
+ example = next(iter(dataset))
529
+ except Exception:
530
+ print(f"Error @ {i_shard}")
531
+
532
+
533
+ if __name__ == "__main__":
534
+ #example01()
535
+ #example02()
536
+ example03()
537
+ #example04()
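A small illustration of dict_collation_fn above (importing ldm.data.laion assumes the repo's webdataset/kornia dependencies are installed): it keeps only the keys shared by every sample, stacks tensors, and turns scalar values into numpy arrays.

import torch
from ldm.data.laion import dict_collation_fn

samples = [
    {"jpg": torch.zeros(3, 8, 8), "txt": "a photo", "score": 0.5},
    {"jpg": torch.ones(3, 8, 8), "txt": "a render", "score": 0.9, "extra": 1},
]
batch = dict_collation_fn(samples)
print(sorted(batch))       # ['jpg', 'score', 'txt'] -- 'extra' is not shared, so it is dropped
print(batch["jpg"].shape)  # torch.Size([2, 3, 8, 8])
print(batch["score"])      # array([0.5, 0.9])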
One-2-3-45-master 2/ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+
9
+ class LSUNBase(Dataset):
10
+ def __init__(self,
11
+ txt_file,
12
+ data_root,
13
+ size=None,
14
+ interpolation="bicubic",
15
+ flip_p=0.5
16
+ ):
17
+ self.data_paths = txt_file
18
+ self.data_root = data_root
19
+ with open(self.data_paths, "r") as f:
20
+ self.image_paths = f.read().splitlines()
21
+ self._length = len(self.image_paths)
22
+ self.labels = {
23
+ "relative_file_path_": [l for l in self.image_paths],
24
+ "file_path_": [os.path.join(self.data_root, l)
25
+ for l in self.image_paths],
26
+ }
27
+
28
+ self.size = size
29
+ self.interpolation = {"linear": PIL.Image.LINEAR,
30
+ "bilinear": PIL.Image.BILINEAR,
31
+ "bicubic": PIL.Image.BICUBIC,
32
+ "lanczos": PIL.Image.LANCZOS,
33
+ }[interpolation]
34
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
+
36
+ def __len__(self):
37
+ return self._length
38
+
39
+ def __getitem__(self, i):
40
+ example = dict((k, self.labels[k][i]) for k in self.labels)
41
+ image = Image.open(example["file_path_"])
42
+ if not image.mode == "RGB":
43
+ image = image.convert("RGB")
44
+
45
+ # default to score-sde preprocessing
46
+ img = np.array(image).astype(np.uint8)
47
+ crop = min(img.shape[0], img.shape[1])
48
+ h, w, = img.shape[0], img.shape[1]
49
+ img = img[(h - crop) // 2:(h + crop) // 2,
50
+ (w - crop) // 2:(w + crop) // 2]
51
+
52
+ image = Image.fromarray(img)
53
+ if self.size is not None:
54
+ image = image.resize((self.size, self.size), resample=self.interpolation)
55
+
56
+ image = self.flip(image)
57
+ image = np.array(image).astype(np.uint8)
58
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
+ return example
60
+
61
+
62
+ class LSUNChurchesTrain(LSUNBase):
63
+ def __init__(self, **kwargs):
64
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
+
66
+
67
+ class LSUNChurchesValidation(LSUNBase):
68
+ def __init__(self, flip_p=0., **kwargs):
69
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
+ flip_p=flip_p, **kwargs)
71
+
72
+
73
+ class LSUNBedroomsTrain(LSUNBase):
74
+ def __init__(self, **kwargs):
75
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
+
77
+
78
+ class LSUNBedroomsValidation(LSUNBase):
79
+ def __init__(self, flip_p=0.0, **kwargs):
80
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
+ flip_p=flip_p, **kwargs)
82
+
83
+
84
+ class LSUNCatsTrain(LSUNBase):
85
+ def __init__(self, **kwargs):
86
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
+
88
+
89
+ class LSUNCatsValidation(LSUNBase):
90
+ def __init__(self, flip_p=0., **kwargs):
91
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
+ flip_p=flip_p, **kwargs)
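Typical usage, assuming the LSUN file lists and images have already been placed under data/lsun/ as the class defaults expect:

from torch.utils.data import DataLoader
from ldm.data.lsun import LSUNChurchesTrain

dataset = LSUNChurchesTrain(size=256)
example = dataset[0]
print(example["image"].shape, example["image"].dtype)  # (256, 256, 3) float32

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)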
One-2-3-45-master 2/ldm/data/nerf_like.py ADDED
@@ -0,0 +1,165 @@
1
+ from torch.utils.data import Dataset
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ import torch
6
+ import imageio
7
+ import math
8
+ import cv2
9
+ from torchvision import transforms
10
+
11
+ def cartesian_to_spherical(xyz):
12
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
13
+ xy = xyz[:,0]**2 + xyz[:,1]**2
14
+ z = np.sqrt(xy + xyz[:,2]**2)
15
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
16
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
17
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
18
+ return np.array([theta, azimuth, z])
19
+
20
+
21
+ def get_T(T_target, T_cond):
22
+ theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
23
+ theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
24
+
25
+ d_theta = theta_target - theta_cond
26
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
27
+ d_z = z_target - z_cond
28
+
29
+ d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
30
+ return d_T
31
+
32
+ def get_spherical(T_target, T_cond):
33
+ theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
34
+ theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
35
+
36
+ d_theta = theta_target - theta_cond
37
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
38
+ d_z = z_target - z_cond
39
+
40
+ d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()])
41
+ return d_T
42
+
43
+ class RTMV(Dataset):
44
+ def __init__(self, root_dir='datasets/RTMV/google_scanned',\
45
+ first_K=64, resolution=256, load_target=False):
46
+ self.root_dir = root_dir
47
+ self.scene_list = sorted(next(os.walk(root_dir))[1])
48
+ self.resolution = resolution
49
+ self.first_K = first_K
50
+ self.load_target = load_target
51
+
52
+ def __len__(self):
53
+ return len(self.scene_list)
54
+
55
+ def __getitem__(self, idx):
56
+ scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
57
+ with open(os.path.join(scene_dir, 'transforms.json'), "r") as f:
58
+ meta = json.load(f)
59
+ imgs = []
60
+ poses = []
61
+ for i_img in range(self.first_K):
62
+ meta_img = meta['frames'][i_img]
63
+
64
+ if i_img == 0 or self.load_target:
65
+ img_path = os.path.join(scene_dir, meta_img['file_path'])
66
+ img = imageio.imread(img_path)
67
+ img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
68
+ imgs.append(img)
69
+
70
+ c2w = meta_img['transform_matrix']
71
+ poses.append(c2w)
72
+
73
+ imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
74
+ imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
75
+ imgs = imgs * 2 - 1. # convert to stable diffusion range
76
+ poses = torch.tensor(np.array(poses).astype(np.float32))
77
+ return imgs, poses
78
+
79
+ def blend_rgba(self, img):
80
+ img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
81
+ return img
82
+
83
+
84
+ class GSO(Dataset):
85
+ def __init__(self, root_dir='datasets/GoogleScannedObjects',\
86
+ split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'):
87
+ self.root_dir = root_dir
88
+ with open(os.path.join(root_dir, '%s.json' % split), "r") as f:
89
+ self.scene_list = json.load(f)
90
+ self.resolution = resolution
91
+ self.first_K = first_K
92
+ self.load_target = load_target
93
+ self.name = name
94
+
95
+ def __len__(self):
96
+ return len(self.scene_list)
97
+
98
+ def __getitem__(self, idx):
99
+ scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
100
+ with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f:
101
+ meta = json.load(f)
102
+ imgs = []
103
+ poses = []
104
+ for i_img in range(self.first_K):
105
+ meta_img = meta['frames'][i_img]
106
+
107
+ if i_img == 0 or self.load_target:
108
+ img_path = os.path.join(scene_dir, meta_img['file_path'])
109
+ img = imageio.imread(img_path)
110
+ img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
111
+ imgs.append(img)
112
+
113
+ c2w = meta_img['transform_matrix']
114
+ poses.append(c2w)
115
+
116
+ imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
117
+ mask = imgs[:, :, :, -1]
118
+ imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
119
+ imgs = imgs * 2 - 1. # convert to stable diffusion range
120
+ poses = torch.tensor(np.array(poses).astype(np.float32))
121
+ return imgs, poses
122
+
123
+ def blend_rgba(self, img):
124
+ img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
125
+ return img
126
+
127
+ class WILD(Dataset):
128
+ def __init__(self, root_dir='data/nerf_wild',\
129
+ first_K=33, resolution=256, load_target=False):
130
+ self.root_dir = root_dir
131
+ self.scene_list = sorted(next(os.walk(root_dir))[1])
132
+ self.resolution = resolution
133
+ self.first_K = first_K
134
+ self.load_target = load_target
135
+
136
+ def __len__(self):
137
+ return len(self.scene_list)
138
+
139
+ def __getitem__(self, idx):
140
+ scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
141
+ with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f:
142
+ meta = json.load(f)
143
+ imgs = []
144
+ poses = []
145
+ for i_img in range(self.first_K):
146
+ meta_img = meta['frames'][i_img]
147
+
148
+ if i_img == 0 or self.load_target:
149
+ img_path = os.path.join(scene_dir, meta_img['file_path'])
150
+ img = imageio.imread(img_path + '.png')
151
+ img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
152
+ imgs.append(img)
153
+
154
+ c2w = meta_img['transform_matrix']
155
+ poses.append(c2w)
156
+
157
+ imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
158
+ imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
159
+ imgs = imgs * 2 - 1. # convert to stable diffusion range
160
+ poses = torch.tensor(np.array(poses).astype(np.float32))
161
+ return imgs, poses
162
+
163
+ def blend_rgba(self, img):
164
+ img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
165
+ return img
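The pose helpers at the top of this file can be sanity-checked with toy camera positions: for two cameras on the unit sphere 90 degrees apart in azimuth, get_T returns [d_theta, sin(d_azimuth), cos(d_azimuth), d_radius] and get_spherical the same deltas in degrees.

import numpy as np
from ldm.data.nerf_like import get_T, get_spherical

T_cond = np.array([1.0, 0.0, 0.0])    # camera on the +x axis
T_target = np.array([0.0, 1.0, 0.0])  # camera on the +y axis, same radius and elevation

print(get_T(T_target, T_cond))          # ~tensor([0., 1., 0., 0.])
print(get_spherical(T_target, T_cond))  # ~tensor([0., 90., 0.])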
One-2-3-45-master 2/ldm/data/simple.py ADDED
@@ -0,0 +1,526 @@
1
+ from typing import Dict
2
+ import webdataset as wds
3
+ import numpy as np
4
+ from omegaconf import DictConfig, ListConfig
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from pathlib import Path
8
+ import json
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+ import torchvision
12
+ from einops import rearrange
13
+ from ldm.util import instantiate_from_config
14
+ from datasets import load_dataset
15
+ import pytorch_lightning as pl
16
+ import copy
17
+ import csv
18
+ import cv2
19
+ import random
20
+ import matplotlib.pyplot as plt
21
+ from torch.utils.data import DataLoader
22
+ import json
23
+ import os, sys
24
+ import webdataset as wds
25
+ import math
26
+ from torch.utils.data.distributed import DistributedSampler
27
+
28
+ # Some hacky things to make experimentation easier
29
+ def make_transform_multi_folder_data(paths, caption_files=None, **kwargs):
30
+ ds = make_multi_folder_data(paths, caption_files, **kwargs)
31
+ return TransformDataset(ds)
32
+
33
+ def make_nfp_data(base_path):
34
+ dirs = list(Path(base_path).glob("*/"))
35
+ print(f"Found {len(dirs)} folders")
36
+ print(dirs)
37
+ tforms = [transforms.Resize(512), transforms.CenterCrop(512)]
38
+ datasets = [NfpDataset(x, image_transforms=copy.copy(tforms), default_caption="A view from a train window") for x in dirs]
39
+ return torch.utils.data.ConcatDataset(datasets)
40
+
41
+
42
+ class VideoDataset(Dataset):
43
+ def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2):
44
+ self.root_dir = Path(root_dir)
45
+ self.caption_file = caption_file
46
+ self.n = n
47
+ ext = "mp4"
48
+ self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
49
+ self.offset = offset
50
+
51
+ if isinstance(image_transforms, ListConfig):
52
+ image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
53
+ image_transforms.extend([transforms.ToTensor(),
54
+ transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
55
+ image_transforms = transforms.Compose(image_transforms)
56
+ self.tform = image_transforms
57
+ with open(self.caption_file) as f:
58
+ reader = csv.reader(f)
59
+ rows = [row for row in reader]
60
+ self.captions = dict(rows)
61
+
62
+ def __len__(self):
63
+ return len(self.paths)
64
+
65
+ def __getitem__(self, index):
66
+ for i in range(10):
67
+ try:
68
+ return self._load_sample(index)
69
+ except Exception:
70
+ # Not really good enough but...
71
+ print("uh oh")
72
+
73
+ def _load_sample(self, index):
74
+ n = self.n
75
+ filename = self.paths[index]
76
+ min_frame = 2*self.offset + 2
77
+ vid = cv2.VideoCapture(str(filename))
78
+ max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
79
+ curr_frame_n = random.randint(min_frame, max_frames)
80
+ vid.set(cv2.CAP_PROP_POS_FRAMES,curr_frame_n)
81
+ _, curr_frame = vid.read()
82
+
83
+ prev_frames = []
84
+ for i in range(n):
85
+ prev_frame_n = curr_frame_n - (i+1)*self.offset
86
+ vid.set(cv2.CAP_PROP_POS_FRAMES,prev_frame_n)
87
+ _, prev_frame = vid.read()
88
+ prev_frame = self.tform(Image.fromarray(prev_frame[...,::-1]))
89
+ prev_frames.append(prev_frame)
90
+
91
+ vid.release()
92
+ caption = self.captions[filename.name]
93
+ data = {
94
+ "image": self.tform(Image.fromarray(curr_frame[...,::-1])),
95
+ "prev": torch.cat(prev_frames, dim=-1),
96
+ "txt": caption
97
+ }
98
+ return data
99
+
100
+ # end hacky things
101
+
102
+
103
+ def make_tranforms(image_transforms):
104
+ # if isinstance(image_transforms, ListConfig):
105
+ # image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
106
+ image_transforms = []
107
+ image_transforms.extend([transforms.ToTensor(),
108
+ transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
109
+ image_transforms = transforms.Compose(image_transforms)
110
+ return image_transforms
111
+
112
+
113
+ def make_multi_folder_data(paths, caption_files=None, **kwargs):
114
+ """Make a concat dataset from multiple folders
115
+ Don't suport captions yet
116
+
117
+ If paths is a list, that's ok, if it's a Dict interpret it as:
118
+ k=folder v=n_times to repeat that
119
+ """
120
+ list_of_paths = []
121
+ if isinstance(paths, (Dict, DictConfig)):
122
+ assert caption_files is None, \
123
+ "Caption files not yet supported for repeats"
124
+ for folder_path, repeats in paths.items():
125
+ list_of_paths.extend([folder_path]*repeats)
126
+ paths = list_of_paths
127
+
128
+ if caption_files is not None:
129
+ datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
130
+ else:
131
+ datasets = [FolderData(p, **kwargs) for p in paths]
132
+ return torch.utils.data.ConcatDataset(datasets)
133
+
134
+
135
+
136
+ class NfpDataset(Dataset):
137
+ def __init__(self,
138
+ root_dir,
139
+ image_transforms=[],
140
+ ext="jpg",
141
+ default_caption="",
142
+ ) -> None:
143
+ """assume sequential frames and a deterministic transform"""
144
+
145
+ self.root_dir = Path(root_dir)
146
+ self.default_caption = default_caption
147
+
148
+ self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
149
+ self.tform = make_tranforms(image_transforms)
150
+
151
+ def __len__(self):
152
+ return len(self.paths) - 1
153
+
154
+
155
+ def __getitem__(self, index):
156
+ prev = self.paths[index]
157
+ curr = self.paths[index+1]
158
+ data = {}
159
+ data["image"] = self._load_im(curr)
160
+ data["prev"] = self._load_im(prev)
161
+ data["txt"] = self.default_caption
162
+ return data
163
+
164
+ def _load_im(self, filename):
165
+ im = Image.open(filename).convert("RGB")
166
+ return self.tform(im)
167
+
168
+ class ObjaverseDataModuleFromConfig(pl.LightningDataModule):
169
+ def __init__(self, root_dir, batch_size, total_view, train=None, validation=None,
170
+ test=None, num_workers=4, **kwargs):
171
+ super().__init__(self)
172
+ self.root_dir = root_dir
173
+ self.batch_size = batch_size
174
+ self.num_workers = num_workers
175
+ self.total_view = total_view
176
+
177
+ if train is not None:
178
+ dataset_config = train
179
+ if validation is not None:
180
+ dataset_config = validation
181
+
182
+ if 'image_transforms' in dataset_config:
183
+ image_transforms = [torchvision.transforms.Resize(dataset_config.image_transforms.size)]
184
+ else:
185
+ image_transforms = []
186
+ image_transforms.extend([transforms.ToTensor(),
187
+ transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
188
+ self.image_transforms = torchvision.transforms.Compose(image_transforms)
189
+
190
+
191
+ def train_dataloader(self):
192
+ dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False, \
193
+ image_transforms=self.image_transforms)
194
+ sampler = DistributedSampler(dataset)
195
+ return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
196
+
197
+ def val_dataloader(self):
198
+ dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True, \
199
+ image_transforms=self.image_transforms)
200
+ sampler = DistributedSampler(dataset)
201
+ return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
202
+
203
+ def test_dataloader(self):
204
+ return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True),  # self.validation is not defined on this module; reuse the validation split for testing
205
+ batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
206
+
207
+
208
+ class ObjaverseData(Dataset):
209
+ def __init__(self,
210
+ root_dir='.objaverse/hf-objaverse-v1/views',
211
+ image_transforms=[],
212
+ ext="png",
213
+ default_trans=torch.zeros(3),
214
+ postprocess=None,
215
+ return_paths=False,
216
+ total_view=4,
217
+ validation=False
218
+ ) -> None:
219
+ """Create a dataset from a folder of images.
220
+ If you pass in a root directory it will be searched for images
221
+ ending in ext (ext can be a list)
222
+ """
223
+ self.root_dir = Path(root_dir)
224
+ self.default_trans = default_trans
225
+ self.return_paths = return_paths
226
+ if isinstance(postprocess, DictConfig):
227
+ postprocess = instantiate_from_config(postprocess)
228
+ self.postprocess = postprocess
229
+ self.total_view = total_view
230
+
231
+ if not isinstance(ext, (tuple, list, ListConfig)):
232
+ ext = [ext]
233
+
234
+ with open(os.path.join(root_dir, 'valid_paths.json')) as f:
235
+ self.paths = json.load(f)
236
+
237
+ total_objects = len(self.paths)
238
+ if validation:
239
+ self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
240
+ else:
241
+ self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
242
+ print('============= length of dataset %d =============' % len(self.paths))
243
+ self.tform = image_transforms
244
+
245
+ def __len__(self):
246
+ return len(self.paths)
247
+
248
+ def cartesian_to_spherical(self, xyz):
249
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
250
+ xy = xyz[:,0]**2 + xyz[:,1]**2
251
+ z = np.sqrt(xy + xyz[:,2]**2)
252
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
253
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
254
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
255
+ return np.array([theta, azimuth, z])
256
+
257
+ def get_T(self, target_RT, cond_RT):
258
+ R, T = target_RT[:3, :3], target_RT[:, -1]
259
+ T_target = -R.T @ T
260
+
261
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
262
+ T_cond = -R.T @ T
263
+
264
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
265
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
266
+
267
+ d_theta = theta_target - theta_cond
268
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
269
+ d_z = z_target - z_cond
270
+
271
+ d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
272
+ return d_T
273
+
274
+ def load_im(self, path, color):
275
+ '''
276
+ replace background pixel with random color in rendering
277
+ '''
278
+ try:
279
+ img = plt.imread(path)
280
+ except:
281
+ print(path)
282
+ sys.exit()
283
+ img[img[:, :, -1] == 0.] = color
284
+ img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
285
+ return img
286
+
287
+ def __getitem__(self, index):
288
+
289
+ data = {}
290
+ if self.paths[index][-2:] == '_1': # dirty fix for rendering dataset twice
291
+ total_view = 8
292
+ else:
293
+ total_view = 4
294
+ index_target, index_cond = random.sample(range(total_view), 2) # without replacement
295
+ filename = os.path.join(self.root_dir, self.paths[index])
296
+
297
+ # print(self.paths[index])
298
+
299
+ if self.return_paths:
300
+ data["path"] = str(filename)
301
+
302
+ color = [1., 1., 1., 1.]
303
+
304
+ try:
305
+ target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
306
+ cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color))
307
+ target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
308
+ cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))
309
+ except:
310
+ # very hacky solution, sorry about this
311
+ filename = os.path.join(self.root_dir, '692db5f2d3a04bb286cb977a7dba903e_1') # this one we know is valid
312
+ target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
313
+ cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color))
314
+ target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
315
+ cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))
316
+ target_im = torch.zeros_like(target_im)
317
+ cond_im = torch.zeros_like(cond_im)
318
+
319
+ data["image_target"] = target_im
320
+ data["image_cond"] = cond_im
321
+ data["T"] = self.get_T(target_RT, cond_RT)
322
+
323
+ if self.postprocess is not None:
324
+ data = self.postprocess(data)
325
+
326
+ return data
327
+
328
+ def process_im(self, im):
329
+ im = im.convert("RGB")
330
+ return self.tform(im)
331
+
332
+ class FolderData(Dataset):
333
+ def __init__(self,
334
+ root_dir,
335
+ caption_file=None,
336
+ image_transforms=[],
337
+ ext="jpg",
338
+ default_caption="",
339
+ postprocess=None,
340
+ return_paths=False,
341
+ ) -> None:
342
+ """Create a dataset from a folder of images.
343
+ If you pass in a root directory it will be searched for images
344
+ ending in ext (ext can be a list)
345
+ """
346
+ self.root_dir = Path(root_dir)
347
+ self.default_caption = default_caption
348
+ self.return_paths = return_paths
349
+ if isinstance(postprocess, DictConfig):
350
+ postprocess = instantiate_from_config(postprocess)
351
+ self.postprocess = postprocess
352
+ if caption_file is not None:
353
+ with open(caption_file, "rt") as f:
354
+ ext = Path(caption_file).suffix.lower()
355
+ if ext == ".json":
356
+ captions = json.load(f)
357
+ elif ext == ".jsonl":
358
+ lines = f.readlines()
359
+ lines = [json.loads(x) for x in lines]
360
+ captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
361
+ else:
362
+ raise ValueError(f"Unrecognised format: {ext}")
363
+ self.captions = captions
364
+ else:
365
+ self.captions = None
366
+
367
+ if not isinstance(ext, (tuple, list, ListConfig)):
368
+ ext = [ext]
369
+
370
+ # Only used if there is no caption file
371
+ self.paths = []
372
+ for e in ext:
373
+ self.paths.extend(sorted(list(self.root_dir.rglob(f"*.{e}"))))
374
+ self.tform = make_tranforms(image_transforms)
375
+
376
+ def __len__(self):
377
+ if self.captions is not None:
378
+ return len(self.captions.keys())
379
+ else:
380
+ return len(self.paths)
381
+
382
+ def __getitem__(self, index):
383
+ data = {}
384
+ if self.captions is not None:
385
+ chosen = list(self.captions.keys())[index]
386
+ caption = self.captions.get(chosen, None)
387
+ if caption is None:
388
+ caption = self.default_caption
389
+ filename = self.root_dir/chosen
390
+ else:
391
+ filename = self.paths[index]
392
+
393
+ if self.return_paths:
394
+ data["path"] = str(filename)
395
+
396
+ im = Image.open(filename).convert("RGB")
397
+ im = self.process_im(im)
398
+ data["image"] = im
399
+
400
+ if self.captions is not None:
401
+ data["txt"] = caption
402
+ else:
403
+ data["txt"] = self.default_caption
404
+
405
+ if self.postprocess is not None:
406
+ data = self.postprocess(data)
407
+
408
+ return data
409
+
410
+ def process_im(self, im):
411
+ im = im.convert("RGB")
412
+ return self.tform(im)
413
+ import random
414
+
415
+ class TransformDataset():
416
+ def __init__(self, ds, extra_label="sksbspic"):
417
+ self.ds = ds
418
+ self.extra_label = extra_label
419
+ self.transforms = {
420
+ "align": transforms.Resize(768),
421
+ "centerzoom": transforms.CenterCrop(768),
422
+ "randzoom": transforms.RandomCrop(768),
423
+ }
424
+
425
+
426
+ def __getitem__(self, index):
427
+ data = self.ds[index]
428
+
429
+ im = data['image']
430
+ im = im.permute(2,0,1)
431
+ # In case data is smaller than expected
432
+ im = transforms.Resize(1024)(im)
433
+
434
+ tform_name = random.choice(list(self.transforms.keys()))
435
+ im = self.transforms[tform_name](im)
436
+
437
+ im = im.permute(1,2,0)
438
+
439
+ data['image'] = im
440
+ data['txt'] = data['txt'] + f" {self.extra_label} {tform_name}"
441
+
442
+ return data
443
+
444
+ def __len__(self):
445
+ return len(self.ds)
446
+
447
+ def hf_dataset(
448
+ name,
449
+ image_transforms=[],
450
+ image_column="image",
451
+ text_column="text",
452
+ split='train',
453
+ image_key='image',
454
+ caption_key='txt',
455
+ ):
456
+ """Make huggingface dataset with appropriate list of transforms applied
457
+ """
458
+ ds = load_dataset(name, split=split)
459
+ tform = make_tranforms(image_transforms)
460
+
461
+ assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
462
+ assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
463
+
464
+ def pre_process(examples):
465
+ processed = {}
466
+ processed[image_key] = [tform(im) for im in examples[image_column]]
467
+ processed[caption_key] = examples[text_column]
468
+ return processed
469
+
470
+ ds.set_transform(pre_process)
471
+ return ds
472
+
473
+ class TextOnly(Dataset):
474
+ def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
475
+ """Returns only captions with dummy images"""
476
+ self.output_size = output_size
477
+ self.image_key = image_key
478
+ self.caption_key = caption_key
479
+ if isinstance(captions, Path):
480
+ self.captions = self._load_caption_file(captions)
481
+ else:
482
+ self.captions = captions
483
+
484
+ if n_gpus > 1:
485
+ # hack to make sure that all the captions appear on each gpu
486
+ repeated = [n_gpus*[x] for x in self.captions]
487
+ self.captions = []
488
+ [self.captions.extend(x) for x in repeated]
489
+
490
+ def __len__(self):
491
+ return len(self.captions)
492
+
493
+ def __getitem__(self, index):
494
+ dummy_im = torch.zeros(3, self.output_size, self.output_size)
495
+ dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
496
+ return {self.image_key: dummy_im, self.caption_key: self.captions[index]}
497
+
498
+ def _load_caption_file(self, filename):
499
+ with open(filename, 'rt') as f:
500
+ captions = f.readlines()
501
+ return [x.strip('\n') for x in captions]
502
+
503
+
504
+
505
+ import random
506
+ import json
507
+ class IdRetreivalDataset(FolderData):
508
+ def __init__(self, ret_file, *args, **kwargs):
509
+ super().__init__(*args, **kwargs)
510
+ with open(ret_file, "rt") as f:
511
+ self.ret = json.load(f)
512
+
513
+ def __getitem__(self, index):
514
+ data = super().__getitem__(index)
515
+ key = self.paths[index].name
516
+ matches = self.ret[key]
517
+ if len(matches) > 0:
518
+ retreived = random.choice(matches)
519
+ else:
520
+ retreived = key
521
+ filename = self.root_dir/retreived
522
+ im = Image.open(filename).convert("RGB")
523
+ im = self.process_im(im)
524
+ # data["match"] = im
525
+ data["match"] = torch.cat((data["image"], im), dim=-1)
526
+ return data
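For orientation, here is a rough usage sketch of the FolderData class defined above, assuming the ldm package from this upload is on PYTHONPATH. The folder path is a placeholder, not part of the repository:

# assumes this module is importable as ldm.data.simple
from ldm.data.simple import FolderData

ds = FolderData(root_dir="/path/to/images", ext="jpg",
                default_caption="a photo", return_paths=True)

sample = ds[0]
# "image" is an HWC tensor in [-1, 1] (see make_tranforms); "txt" falls back to default_caption
print(sample["image"].shape, sample["image"].min().item(), sample["txt"], sample["path"])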
One-2-3-45-master 2/ldm/extras.py ADDED
@@ -0,0 +1,77 @@
+ from pathlib import Path
+ from omegaconf import OmegaConf
+ import torch
+ from ldm.util import instantiate_from_config
+ import logging
+ from contextlib import contextmanager
+
+
+ @contextmanager
+ def all_logging_disabled(highest_level=logging.CRITICAL):
+     """
+     A context manager that will prevent any logging messages
+     triggered during the body from being processed.
+
+     :param highest_level: the maximum logging level in use.
+       This would only need to be changed if a custom level greater than CRITICAL
+       is defined.
+
+     https://gist.github.com/simon-weber/7853144
+     """
+     # two kind-of hacks here:
+     # * can't get the highest logging level in effect => delegate to the user
+     # * can't get the current module-level override => use an undocumented
+     #   (but non-private!) interface
+
+     previous_level = logging.root.manager.disable
+
+     logging.disable(highest_level)
+
+     try:
+         yield None
+     finally:
+         logging.disable(previous_level)
+
+
+ def load_training_dir(train_dir, device, epoch="last"):
+     """Load a checkpoint and config from a training directory."""
+     train_dir = Path(train_dir)
+     ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
+     assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
+     config = list(train_dir.rglob("*-project.yaml"))
+     assert len(config) > 0, f"didn't find any config in {train_dir}"
+     if len(config) > 1:
+         print(f"found {len(config)} matching config files")
+         config = sorted(config)[-1]
+         print(f"selecting {config}")
+     else:
+         config = config[0]
+
+     config = OmegaConf.load(config)
+     return load_model_from_config(config, ckpt[0], device)
+
+
+ def load_model_from_config(config, ckpt, device="cpu", verbose=False):
+     """Load a model from a config and a checkpoint.
+     If config is a path, it is loaded with OmegaConf first.
+     """
+     if isinstance(config, (str, Path)):
+         config = OmegaConf.load(config)
+
+     with all_logging_disabled():
+         print(f"Loading model from {ckpt}")
+         pl_sd = torch.load(ckpt, map_location="cpu")
+         global_step = pl_sd["global_step"]
+         sd = pl_sd["state_dict"]
+         model = instantiate_from_config(config.model)
+         m, u = model.load_state_dict(sd, strict=False)
+         if len(m) > 0 and verbose:
+             print("missing keys:")
+             print(m)
+         if len(u) > 0 and verbose:
+             print("unexpected keys:")
+             print(u)
+         model.to(device)
+         model.eval()
+         model.cond_stage_model.device = device
+         return model
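A minimal sketch of how the two loaders above are typically called; both paths are hypothetical placeholders and only work where a real config/checkpoint pair exists:

import torch
from ldm.extras import load_model_from_config, load_training_dir

device = "cuda" if torch.cuda.is_available() else "cpu"

# load from an explicit config/checkpoint pair (placeholder paths)
model = load_model_from_config("path/to/config.yaml", "path/to/last.ckpt", device=device)

# or search a training directory for "*last.ckpt" and the newest "*-project.yaml"
# model = load_training_dir("path/to/train_dir", device, epoch="last")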
One-2-3-45-master 2/ldm/guidance.py ADDED
@@ -0,0 +1,96 @@
+ from typing import List, Tuple
+ from scipy import interpolate
+ import numpy as np
+ import torch
+ import matplotlib.pyplot as plt
+ from IPython.display import clear_output
+ import abc
+
+
+ class GuideModel(torch.nn.Module, abc.ABC):
+     def __init__(self) -> None:
+         super().__init__()
+
+     @abc.abstractmethod
+     def preprocess(self, x_img):
+         pass
+
+     @abc.abstractmethod
+     def compute_loss(self, inp):
+         pass
+
+
+ class Guider(torch.nn.Module):
+     def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
+         """Apply classifier guidance
+
+         Specify a guidance scale as either a scalar,
+         or as a schedule given by a list of (t, scale) tuples with t in 0->1, e.g.
+         [(0, 10), (0.5, 20), (1, 50)]
+         """
+         super().__init__()
+         self.sampler = sampler
+         self.index = 0
+         self.show = verbose
+         self.guide_model = guide_model
+         self.history = []
+
+         if isinstance(scale, (Tuple, List)):
+             times = np.array([x[0] for x in scale])
+             values = np.array([x[1] for x in scale])
+             self.scale_schedule = {"times": times, "values": values}
+         else:
+             self.scale_schedule = float(scale)
+
+         self.ddim_timesteps = sampler.ddim_timesteps
+         self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
+
+     def get_scales(self):
+         if isinstance(self.scale_schedule, float):
+             return len(self.ddim_timesteps) * [self.scale_schedule]
+
+         interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
+         fractional_steps = np.array(self.ddim_timesteps) / self.ddpm_num_timesteps
+         return interpolater(fractional_steps)
+
+     def modify_score(self, model, e_t, x, t, c):
+         # TODO look up index by t
+         scale = self.get_scales()[self.index]
+
+         if scale == 0:
+             return e_t
+
+         sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
+         with torch.enable_grad():
+             x_in = x.detach().requires_grad_(True)
+             pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
+             x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)
+
+             inp = self.guide_model.preprocess(x_img)
+             loss = self.guide_model.compute_loss(inp)
+             grads = torch.autograd.grad(loss.sum(), x_in)[0]
+             correction = grads * scale
+
+             if self.show:
+                 clear_output(wait=True)
+                 print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
+                 self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
+                 plt.imshow((inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2)
+                 plt.axis('off')
+                 plt.show()
+                 plt.imshow(correction[0][0].detach().cpu())
+                 plt.axis('off')
+                 plt.show()
+
+         e_t_mod = e_t - sqrt_1ma * correction
+         if self.show:
+             fig, axs = plt.subplots(1, 3)
+             axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
+             axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
+             axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
+             plt.show()
+         self.index += 1
+         return e_t_mod
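GuideModel is abstract, so a concrete guide must supply preprocess and compute_loss. The subclass below is a toy sketch (the brightness objective is invented purely for illustration); since modify_score subtracts sqrt(1-alpha) times the gradient of the returned value from e_t, the returned value is treated as a score to be increased:

import torch
from ldm.guidance import GuideModel, Guider

class BrightnessGuide(GuideModel):
    """Toy guide: pushes decoded images toward a target mean brightness."""
    def __init__(self, target=0.5):
        super().__init__()
        self.target = target

    def preprocess(self, x_img):
        # decoded images arrive in [-1, 1]; map them to [0, 1]
        return (x_img.clamp(-1, 1) + 1) / 2

    def compute_loss(self, inp):
        # negative squared error, so increasing this score reduces the brightness error
        return -((inp.mean(dim=(1, 2, 3)) - self.target) ** 2)

# wiring sketch: sampler must already have a DDIM schedule (ddim_timesteps) set up
# guider = Guider(sampler, BrightnessGuide(), scale=[(0, 10), (1, 50)], verbose=False)
# e_t = guider.modify_score(model, e_t, x, t, c)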
One-2-3-45-master 2/ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
+ import numpy as np
+
+
+ class LambdaWarmUpCosineScheduler:
+     """
+     note: use with a base_lr of 1.0
+     """
+     def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+         self.lr_warm_up_steps = warm_up_steps
+         self.lr_start = lr_start
+         self.lr_min = lr_min
+         self.lr_max = lr_max
+         self.lr_max_decay_steps = max_decay_steps
+         self.last_lr = 0.
+         self.verbosity_interval = verbosity_interval
+
+     def schedule(self, n, **kwargs):
+         if self.verbosity_interval > 0:
+             if n % self.verbosity_interval == 0:
+                 print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+         if n < self.lr_warm_up_steps:
+             lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+             self.last_lr = lr
+             return lr
+         else:
+             t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+             t = min(t, 1.0)
+             lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
+             self.last_lr = lr
+             return lr
+
+     def __call__(self, n, **kwargs):
+         return self.schedule(n, **kwargs)
+
+
+ class LambdaWarmUpCosineScheduler2:
+     """
+     supports repeated iterations, configurable via lists
+     note: use with a base_lr of 1.0.
+     """
+     def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+         assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+         self.lr_warm_up_steps = warm_up_steps
+         self.f_start = f_start
+         self.f_min = f_min
+         self.f_max = f_max
+         self.cycle_lengths = cycle_lengths
+         self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+         self.last_f = 0.
+         self.verbosity_interval = verbosity_interval
+
+     def find_in_interval(self, n):
+         interval = 0
+         for cl in self.cum_cycles[1:]:
+             if n <= cl:
+                 return interval
+             interval += 1
+
+     def schedule(self, n, **kwargs):
+         cycle = self.find_in_interval(n)
+         n = n - self.cum_cycles[cycle]
+         if self.verbosity_interval > 0:
+             if n % self.verbosity_interval == 0:
+                 print(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")
+         if n < self.lr_warm_up_steps[cycle]:
+             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+             self.last_f = f
+             return f
+         else:
+             t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+             t = min(t, 1.0)
+             f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
+             self.last_f = f
+             return f
+
+     def __call__(self, n, **kwargs):
+         return self.schedule(n, **kwargs)
+
+
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+     def schedule(self, n, **kwargs):
+         cycle = self.find_in_interval(n)
+         n = n - self.cum_cycles[cycle]
+         if self.verbosity_interval > 0:
+             if n % self.verbosity_interval == 0:
+                 print(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")
+
+         if n < self.lr_warm_up_steps[cycle]:
+             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+             self.last_f = f
+             return f
+         else:
+             f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+             self.last_f = f
+             return f
+
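These schedules return a learning-rate multiplier per step and are meant to be wrapped by torch's LambdaLR with a base lr of 1.0, as the docstrings note. A small runnable sketch:

import torch
from torch.optim.lr_scheduler import LambdaLR
from ldm.lr_scheduler import LambdaWarmUpCosineScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1.0)  # base_lr of 1.0, so the multiplier is the effective lr

sched_fn = LambdaWarmUpCosineScheduler(
    warm_up_steps=100, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=1000)
scheduler = LambdaLR(opt, lr_lambda=sched_fn.schedule)

for step in range(5):
    opt.step()
    scheduler.step()
    print(step, opt.param_groups[0]["lr"])  # linear warm-up toward lr_max, then cosine decay to lr_min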
One-2-3-45-master 2/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,443 @@
+ import torch
+ # numpy, version, LambdaLR and LitEma are all referenced later in this file
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch.nn.functional as F
+ from contextlib import contextmanager
+ from packaging import version
+ from torch.optim.lr_scheduler import LambdaLR
+
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+ # LitEma is assumed to live in ldm.modules.ema, as in the upstream latent-diffusion codebase
+ from ldm.modules.ema import LitEma
+
+ from ldm.util import instantiate_from_config
+
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def configure_optimizers(self):
198
+ lr_d = self.learning_rate
199
+ lr_g = self.lr_g_factor*self.learning_rate
200
+ print("lr_d", lr_d)
201
+ print("lr_g", lr_g)
202
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
+ list(self.decoder.parameters())+
204
+ list(self.quantize.parameters())+
205
+ list(self.quant_conv.parameters())+
206
+ list(self.post_quant_conv.parameters()),
207
+ lr=lr_g, betas=(0.5, 0.9))
208
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
+ lr=lr_d, betas=(0.5, 0.9))
210
+
211
+ if self.scheduler_config is not None:
212
+ scheduler = instantiate_from_config(self.scheduler_config)
213
+
214
+ print("Setting up LambdaLR scheduler...")
215
+ scheduler = [
216
+ {
217
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
+ 'interval': 'step',
219
+ 'frequency': 1
220
+ },
221
+ {
222
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
+ 'interval': 'step',
224
+ 'frequency': 1
225
+ },
226
+ ]
227
+ return [opt_ae, opt_disc], scheduler
228
+ return [opt_ae, opt_disc], []
229
+
230
+ def get_last_layer(self):
231
+ return self.decoder.conv_out.weight
232
+
233
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
+ log = dict()
235
+ x = self.get_input(batch, self.image_key)
236
+ x = x.to(self.device)
237
+ if only_inputs:
238
+ log["inputs"] = x
239
+ return log
240
+ xrec, _ = self(x)
241
+ if x.shape[1] > 3:
242
+ # colorize with random projection
243
+ assert xrec.shape[1] > 3
244
+ x = self.to_rgb(x)
245
+ xrec = self.to_rgb(xrec)
246
+ log["inputs"] = x
247
+ log["reconstructions"] = xrec
248
+ if plot_ema:
249
+ with self.ema_scope():
250
+ xrec_ema, _ = self(x)
251
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
+ log["reconstructions_ema"] = xrec_ema
253
+ return log
254
+
255
+ def to_rgb(self, x):
256
+ assert self.image_key == "segmentation"
257
+ if not hasattr(self, "colorize"):
258
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
+ x = F.conv2d(x, weight=self.colorize)
260
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
+ return x
262
+
263
+
264
+ class VQModelInterface(VQModel):
265
+ def __init__(self, embed_dim, *args, **kwargs):
266
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
+ self.embed_dim = embed_dim
268
+
269
+ def encode(self, x):
270
+ h = self.encoder(x)
271
+ h = self.quant_conv(h)
272
+ return h
273
+
274
+ def decode(self, h, force_not_quantize=False):
275
+ # also go through quantization layer
276
+ if not force_not_quantize:
277
+ quant, emb_loss, info = self.quantize(h)
278
+ else:
279
+ quant = h
280
+ quant = self.post_quant_conv(quant)
281
+ dec = self.decoder(quant)
282
+ return dec
283
+
284
+
285
+ class AutoencoderKL(pl.LightningModule):
286
+ def __init__(self,
287
+ ddconfig,
288
+ lossconfig,
289
+ embed_dim,
290
+ ckpt_path=None,
291
+ ignore_keys=[],
292
+ image_key="image",
293
+ colorize_nlabels=None,
294
+ monitor=None,
295
+ ):
296
+ super().__init__()
297
+ self.image_key = image_key
298
+ self.encoder = Encoder(**ddconfig)
299
+ self.decoder = Decoder(**ddconfig)
300
+ self.loss = instantiate_from_config(lossconfig)
301
+ assert ddconfig["double_z"]
302
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
+ self.embed_dim = embed_dim
305
+ if colorize_nlabels is not None:
306
+ assert type(colorize_nlabels)==int
307
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
+ if monitor is not None:
309
+ self.monitor = monitor
310
+ if ckpt_path is not None:
311
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
+
313
+ def init_from_ckpt(self, path, ignore_keys=list()):
314
+ sd = torch.load(path, map_location="cpu")["state_dict"]
315
+ keys = list(sd.keys())
316
+ for k in keys:
317
+ for ik in ignore_keys:
318
+ if k.startswith(ik):
319
+ print("Deleting key {} from state_dict.".format(k))
320
+ del sd[k]
321
+ self.load_state_dict(sd, strict=False)
322
+ print(f"Restored from {path}")
323
+
324
+ def encode(self, x):
325
+ h = self.encoder(x)
326
+ moments = self.quant_conv(h)
327
+ posterior = DiagonalGaussianDistribution(moments)
328
+ return posterior
329
+
330
+ def decode(self, z):
331
+ z = self.post_quant_conv(z)
332
+ dec = self.decoder(z)
333
+ return dec
334
+
335
+ def forward(self, input, sample_posterior=True):
336
+ posterior = self.encode(input)
337
+ if sample_posterior:
338
+ z = posterior.sample()
339
+ else:
340
+ z = posterior.mode()
341
+ dec = self.decode(z)
342
+ return dec, posterior
343
+
344
+ def get_input(self, batch, k):
345
+ x = batch[k]
346
+ if len(x.shape) == 3:
347
+ x = x[..., None]
348
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349
+ return x
350
+
351
+ def training_step(self, batch, batch_idx, optimizer_idx):
352
+ inputs = self.get_input(batch, self.image_key)
353
+ reconstructions, posterior = self(inputs)
354
+
355
+ if optimizer_idx == 0:
356
+ # train encoder+decoder+logvar
357
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358
+ last_layer=self.get_last_layer(), split="train")
359
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361
+ return aeloss
362
+
363
+ if optimizer_idx == 1:
364
+ # train the discriminator
365
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366
+ last_layer=self.get_last_layer(), split="train")
367
+
368
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370
+ return discloss
371
+
372
+ def validation_step(self, batch, batch_idx):
373
+ inputs = self.get_input(batch, self.image_key)
374
+ reconstructions, posterior = self(inputs)
375
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376
+ last_layer=self.get_last_layer(), split="val")
377
+
378
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379
+ last_layer=self.get_last_layer(), split="val")
380
+
381
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382
+ self.log_dict(log_dict_ae)
383
+ self.log_dict(log_dict_disc)
384
+ return self.log_dict
385
+
386
+ def configure_optimizers(self):
387
+ lr = self.learning_rate
388
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389
+ list(self.decoder.parameters())+
390
+ list(self.quant_conv.parameters())+
391
+ list(self.post_quant_conv.parameters()),
392
+ lr=lr, betas=(0.5, 0.9))
393
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394
+ lr=lr, betas=(0.5, 0.9))
395
+ return [opt_ae, opt_disc], []
396
+
397
+ def get_last_layer(self):
398
+ return self.decoder.conv_out.weight
399
+
400
+ @torch.no_grad()
401
+ def log_images(self, batch, only_inputs=False, **kwargs):
402
+ log = dict()
403
+ x = self.get_input(batch, self.image_key)
404
+ x = x.to(self.device)
405
+ if not only_inputs:
406
+ xrec, posterior = self(x)
407
+ if x.shape[1] > 3:
408
+ # colorize with random projection
409
+ assert xrec.shape[1] > 3
410
+ x = self.to_rgb(x)
411
+ xrec = self.to_rgb(xrec)
412
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413
+ log["reconstructions"] = xrec
414
+ log["inputs"] = x
415
+ return log
416
+
417
+ def to_rgb(self, x):
418
+ assert self.image_key == "segmentation"
419
+ if not hasattr(self, "colorize"):
420
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421
+ x = F.conv2d(x, weight=self.colorize)
422
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423
+ return x
424
+
425
+
426
+ class IdentityFirstStage(torch.nn.Module):
427
+ def __init__(self, *args, vq_interface=False, **kwargs):
428
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429
+ super().__init__()
430
+
431
+ def encode(self, x, *args, **kwargs):
432
+ return x
433
+
434
+ def decode(self, x, *args, **kwargs):
435
+ return x
436
+
437
+ def quantize(self, x, *args, **kwargs):
438
+ if self.vq_interface:
439
+ return x, None, [None, None, None]
440
+ return x
441
+
442
+ def forward(self, x, *args, **kwargs):
443
+ return x
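AutoencoderKL.encode returns a DiagonalGaussianDistribution built from the 2*embed_dim "moments" channels produced by quant_conv. A standalone sketch of that object on a dummy tensor (no trained weights involved, assuming the ldm package from this upload is importable):

import torch
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

# 2 * embed_dim channels: the first half is the mean, the second half the log-variance
moments = torch.randn(1, 2 * 4, 32, 32)
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()    # stochastic latent, used when sample_posterior=True
z_det = posterior.mode()  # deterministic latent (the mean), used otherwise
kl = posterior.kl()       # KL term that the training loss regularizes

print(z.shape, z_det.shape, kl.shape)  # [1, 4, 32, 32], [1, 4, 32, 32], [1]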
One-2-3-45-master 2/ldm/models/diffusion/__init__.py ADDED
File without changes
One-2-3-45-master 2/ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
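The accuracy bookkeeping above reduces to compute_top_k; the standalone snippet below mirrors its logic on dummy logits so the reduction behaviour is easy to see:

import torch

logits = torch.randn(8, 10)           # batch of 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))

_, top_ks = torch.topk(logits, k=5, dim=1)
# reduction="mean": fraction of samples whose label appears among the top-5 predictions
acc_at_5 = (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc_at_5)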
One-2-3-45-master 2/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,326 @@
+ """SAMPLING ONLY."""
+
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from functools import partial
+ from einops import rearrange
+
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+ from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
+
+
+ class DDIMSampler(object):
+     def __init__(self, model, schedule="linear", **kwargs):
+         super().__init__()
+         self.model = model
+         self.ddpm_num_timesteps = model.num_timesteps
+         self.schedule = schedule
+         self.device = model.device
+
+     def to(self, device):
+         """Same as `to` in a torch module.
+         Don't really understand why this isn't a module in the first place."""
+         for k, v in self.__dict__.items():
+             if isinstance(v, torch.Tensor):
+                 new_v = getattr(self, k).to(device)
+                 setattr(self, k, new_v)
+
+
+     def register_buffer(self, name, attr, device=None):
+         if type(attr) == torch.Tensor:
+             attr = attr.to(device)
+             # if attr.device != torch.device("cuda"):
+             #     attr = attr.to(torch.device("cuda"))
+         setattr(self, name, attr)
+
+     def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+         self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+                                                   num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+         alphas_cumprod = self.model.alphas_cumprod
+         assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+         to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+         self.register_buffer('betas', to_torch(self.model.betas), self.device)
+         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod), self.device)
+         self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev), self.device)
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())), self.device)
+         self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())), self.device)
+         self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())), self.device)
+         self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())), self.device)
+         self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)), self.device)
+
+         # ddim sampling parameters
+         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+                                                                                    ddim_timesteps=self.ddim_timesteps,
+                                                                                    eta=ddim_eta, verbose=verbose)
+         self.register_buffer('ddim_sigmas', ddim_sigmas, self.device)
+         self.register_buffer('ddim_alphas', ddim_alphas, self.device)
+         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev, self.device)
+         self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas), self.device)
+         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+                 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+         self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps, self.device)
+
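+     # The schedule registered above follows the DDIM parameterization: the per-step noise is
+     #   sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t)) * sqrt(1 - a_t / a_prev)
+     # (a_* are cumulative alpha products), so eta = 0 gives fully deterministic DDIM updates
+     # and eta = 1 recovers the DDPM posterior variance.
+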
+     @torch.no_grad()
+     def sample(self,
+                S,
+                batch_size,
+                shape,
+                conditioning=None,
+                callback=None,
+                normals_sequence=None,
+                img_callback=None,
+                quantize_x0=False,
+                eta=0.,
+                mask=None,
+                x0=None,
+                temperature=1.,
+                noise_dropout=0.,
+                score_corrector=None,
+                corrector_kwargs=None,
+                verbose=True,
+                x_T=None,
+                log_every_t=100,
+                unconditional_guidance_scale=1.,
+                unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+                dynamic_threshold=None,
+                **kwargs
+                ):
+         if conditioning is not None:
+             if isinstance(conditioning, dict):
+                 ctmp = conditioning[list(conditioning.keys())[0]]
+                 while isinstance(ctmp, list): ctmp = ctmp[0]
+                 cbs = ctmp.shape[0]
+                 if cbs != batch_size:
+                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+             else:
+                 if conditioning.shape[0] != batch_size:
+                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+         self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+         # sampling
+         C, H, W = shape
+         size = (batch_size, C, H, W)
+         print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+         samples, intermediates = self.ddim_sampling(conditioning, size,
+                                                     callback=callback,
+                                                     img_callback=img_callback,
+                                                     quantize_denoised=quantize_x0,
+                                                     mask=mask, x0=x0,
+                                                     ddim_use_original_steps=False,
+                                                     noise_dropout=noise_dropout,
+                                                     temperature=temperature,
+                                                     score_corrector=score_corrector,
+                                                     corrector_kwargs=corrector_kwargs,
+                                                     x_T=x_T,
+                                                     log_every_t=log_every_t,
+                                                     unconditional_guidance_scale=unconditional_guidance_scale,
+                                                     unconditional_conditioning=unconditional_conditioning,
+                                                     dynamic_threshold=dynamic_threshold,
+                                                     )
+         return samples, intermediates
+
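+     # In `sample` above, S is the number of DDIM steps handed to make_schedule; passing
+     # unconditional_conditioning together with unconditional_guidance_scale > 1 enables
+     # classifier-free guidance inside p_sample_ddim below.
+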
+     @torch.no_grad()
+     def ddim_sampling(self, cond, shape,
+                       x_T=None, ddim_use_original_steps=False,
+                       callback=None, timesteps=None, quantize_denoised=False,
+                       mask=None, x0=None, img_callback=None, log_every_t=100,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+                       t_start=-1):
+         device = self.model.betas.device
+         b = shape[0]
+         if x_T is None:
+             img = torch.randn(shape, device=device)
+         else:
+             img = x_T
+
+         if timesteps is None:
+             timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+         elif timesteps is not None and not ddim_use_original_steps:
+             subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+             timesteps = self.ddim_timesteps[:subset_end]
+
+         timesteps = timesteps[:t_start]
+
+         intermediates = {'x_inter': [img], 'pred_x0': [img]}
+         time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+         total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+         print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+         iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+             if mask is not None:
+                 assert x0 is not None
+                 img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
+                 img = img_orig * mask + (1. - mask) * img
+
+             outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+                                       quantize_denoised=quantize_denoised, temperature=temperature,
+                                       noise_dropout=noise_dropout, score_corrector=score_corrector,
+                                       corrector_kwargs=corrector_kwargs,
+                                       unconditional_guidance_scale=unconditional_guidance_scale,
+                                       unconditional_conditioning=unconditional_conditioning,
+                                       dynamic_threshold=dynamic_threshold)
+             img, pred_x0 = outs
+             if callback:
+                 img = callback(i, img, pred_x0)
+             if img_callback: img_callback(pred_x0, i)
+
+             if index % log_every_t == 0 or index == total_steps - 1:
+                 intermediates['x_inter'].append(img)
+                 intermediates['pred_x0'].append(pred_x0)
+
+         return img, intermediates
+
+     @torch.no_grad()
+     def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+                       unconditional_guidance_scale=1., unconditional_conditioning=None,
+                       dynamic_threshold=None):
+         b, *_, device = *x.shape, x.device
+
+         if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+             e_t = self.model.apply_model(x, t, c)
+         else:
+             x_in = torch.cat([x] * 2)
+             t_in = torch.cat([t] * 2)
+             if isinstance(c, dict):
+                 assert isinstance(unconditional_conditioning, dict)
+                 c_in = dict()
+                 for k in c:
+                     if isinstance(c[k], list):
+                         c_in[k] = [torch.cat([
+                             unconditional_conditioning[k][i],
+                             c[k][i]]) for i in range(len(c[k]))]
+                     else:
+                         c_in[k] = torch.cat([
+                             unconditional_conditioning[k],
+                             c[k]])
+             else:
+                 c_in = torch.cat([unconditional_conditioning, c])
+             e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+         if score_corrector is not None:
+             assert self.model.parameterization == "eps"
+             e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+         alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+         sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+         # select parameters corresponding to the currently considered timestep
+         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+         sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+         # current prediction for x_0
+         pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+         if quantize_denoised:
+             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+         if dynamic_threshold is not None:
+             pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+
+         # direction pointing to x_t
+         dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+         if noise_dropout > 0.:
+             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+         return x_prev, pred_x0
+
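+     # p_sample_ddim above is the DDIM update rule:
+     #   x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * z,
+     # with pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t) and z ~ N(0, I) scaled by `temperature`.
+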
+     @torch.no_grad()
+     def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+                unconditional_guidance_scale=1.0, unconditional_conditioning=None):
+         num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+         assert t_enc <= num_reference_steps
+         num_steps = t_enc
+
+         if use_original_steps:
+             alphas_next = self.alphas_cumprod[:num_steps]
+             alphas = self.alphas_cumprod_prev[:num_steps]
+         else:
+             alphas_next = self.ddim_alphas[:num_steps]
+             alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+         x_next = x0
+         intermediates = []
+         inter_steps = []
+         for i in tqdm(range(num_steps), desc='Encoding Image'):
+             t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+             if unconditional_guidance_scale == 1.:
+                 noise_pred = self.model.apply_model(x_next, t, c)
+             else:
+                 assert unconditional_conditioning is not None
+                 e_t_uncond, noise_pred = torch.chunk(
+                     self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+                                            torch.cat((unconditional_conditioning, c))), 2)
+                 noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+             xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+             weighted_noise_pred = alphas_next[i].sqrt() * (
+                     (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+             x_next = xt_weighted + weighted_noise_pred
+             if return_intermediates and i % (
+                     num_steps // return_intermediates) == 0 and i < num_steps - 1:
+                 intermediates.append(x_next)
+                 inter_steps.append(i)
+             elif return_intermediates and i >= num_steps - 2:
+                 intermediates.append(x_next)
+                 inter_steps.append(i)
+
+         out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+         if return_intermediates:
+             out.update({'intermediates': intermediates})
+         return x_next, out
+
+     @torch.no_grad()
+     def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+         # fast, but does not allow for exact reconstruction
+         # t serves as an index to gather the correct alphas
+         if use_original_steps:
+             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+         else:
+             sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+         if noise is None:
+             noise = torch.randn_like(x0)
+         return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+                 extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
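+     # stochastic_encode above samples x_t directly from the forward process q(x_t | x_0):
+     #   x_t = sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise,
+     # which is fast but, unlike the step-by-step encode() above, does not allow exact reconstruction.
+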
+     @torch.no_grad()
+     def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+                use_original_steps=False):
+
+         timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+         timesteps = timesteps[:t_start]
+
+         time_range = np.flip(timesteps)
+         total_steps = timesteps.shape[0]
+         print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+         iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+         x_dec = x_latent
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+             x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+                                           unconditional_guidance_scale=unconditional_guidance_scale,
+                                           unconditional_conditioning=unconditional_conditioning)
+         return x_dec
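
For orientation, here is a minimal sketch of how this sampler is typically driven. It assumes `model` is an already-loaded LatentDiffusion-style model (exposing `apply_model`, `device`, and `decode_first_stage`) and that `cond` / `uncond` are conditioning tensors prepared elsewhere; those names and the 4x32x32 latent shape are illustrative assumptions, not values taken from this repository.

import torch
from ldm.models.diffusion.ddim import DDIMSampler

# `model`, `cond`, and `uncond` are assumed to exist; see the note above.
sampler = DDIMSampler(model)
samples, intermediates = sampler.sample(
    S=50,                              # number of DDIM steps
    batch_size=cond.shape[0],
    shape=(4, 32, 32),                 # latent (C, H, W), illustrative
    conditioning=cond,
    eta=0.0,                           # deterministic DDIM
    unconditional_guidance_scale=3.0,  # classifier-free guidance strength
    unconditional_conditioning=uncond,
)
with torch.no_grad():
    images = model.decode_first_stage(samples)  # back to image space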