Jahnavibh commited on
Commit
5d13141
·
1 Parent(s): 445e126

Delete One-2-3-45-master 2

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. One-2-3-45-master 2/.DS_Store +0 -0
  2. One-2-3-45-master 2/.gitattributes +0 -35
  3. One-2-3-45-master 2/.gitignore +0 -11
  4. One-2-3-45-master 2/LICENSE +0 -201
  5. One-2-3-45-master 2/README.md +0 -221
  6. One-2-3-45-master 2/configs/sd-objaverse-finetune-c_concat-256.yaml +0 -117
  7. One-2-3-45-master 2/download_ckpt.py +0 -30
  8. One-2-3-45-master 2/elevation_estimate/.gitignore +0 -3
  9. One-2-3-45-master 2/elevation_estimate/__init__.py +0 -0
  10. One-2-3-45-master 2/elevation_estimate/estimate_wild_imgs.py +0 -10
  11. One-2-3-45-master 2/elevation_estimate/loftr/__init__.py +0 -2
  12. One-2-3-45-master 2/elevation_estimate/loftr/backbone/__init__.py +0 -11
  13. One-2-3-45-master 2/elevation_estimate/loftr/backbone/resnet_fpn.py +0 -199
  14. One-2-3-45-master 2/elevation_estimate/loftr/loftr.py +0 -81
  15. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/__init__.py +0 -2
  16. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/fine_preprocess.py +0 -59
  17. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/linear_attention.py +0 -81
  18. One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/transformer.py +0 -101
  19. One-2-3-45-master 2/elevation_estimate/loftr/utils/coarse_matching.py +0 -261
  20. One-2-3-45-master 2/elevation_estimate/loftr/utils/cvpr_ds_config.py +0 -50
  21. One-2-3-45-master 2/elevation_estimate/loftr/utils/fine_matching.py +0 -74
  22. One-2-3-45-master 2/elevation_estimate/loftr/utils/geometry.py +0 -54
  23. One-2-3-45-master 2/elevation_estimate/loftr/utils/position_encoding.py +0 -42
  24. One-2-3-45-master 2/elevation_estimate/loftr/utils/supervision.py +0 -151
  25. One-2-3-45-master 2/elevation_estimate/pyproject.toml +0 -7
  26. One-2-3-45-master 2/elevation_estimate/utils/__init__.py +0 -0
  27. One-2-3-45-master 2/elevation_estimate/utils/elev_est_api.py +0 -205
  28. One-2-3-45-master 2/elevation_estimate/utils/plotting.py +0 -154
  29. One-2-3-45-master 2/elevation_estimate/utils/plt_utils.py +0 -318
  30. One-2-3-45-master 2/elevation_estimate/utils/utils3d.py +0 -62
  31. One-2-3-45-master 2/elevation_estimate/utils/weights/.gitkeep +0 -0
  32. One-2-3-45-master 2/example.ipynb +0 -0
  33. One-2-3-45-master 2/ldm/data/__init__.py +0 -0
  34. One-2-3-45-master 2/ldm/data/base.py +0 -40
  35. One-2-3-45-master 2/ldm/data/coco.py +0 -253
  36. One-2-3-45-master 2/ldm/data/dummy.py +0 -34
  37. One-2-3-45-master 2/ldm/data/imagenet.py +0 -394
  38. One-2-3-45-master 2/ldm/data/inpainting/__init__.py +0 -0
  39. One-2-3-45-master 2/ldm/data/inpainting/synthetic_mask.py +0 -166
  40. One-2-3-45-master 2/ldm/data/laion.py +0 -537
  41. One-2-3-45-master 2/ldm/data/lsun.py +0 -92
  42. One-2-3-45-master 2/ldm/data/nerf_like.py +0 -165
  43. One-2-3-45-master 2/ldm/data/simple.py +0 -526
  44. One-2-3-45-master 2/ldm/extras.py +0 -77
  45. One-2-3-45-master 2/ldm/guidance.py +0 -96
  46. One-2-3-45-master 2/ldm/lr_scheduler.py +0 -98
  47. One-2-3-45-master 2/ldm/models/autoencoder.py +0 -443
  48. One-2-3-45-master 2/ldm/models/diffusion/__init__.py +0 -0
  49. One-2-3-45-master 2/ldm/models/diffusion/classifier.py +0 -267
  50. One-2-3-45-master 2/ldm/models/diffusion/ddim.py +0 -326
One-2-3-45-master 2/.DS_Store DELETED
Binary file (6.15 kB)
 
One-2-3-45-master 2/.gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/.gitignore DELETED
@@ -1,11 +0,0 @@
1
- __pycache__/
2
- exp/
3
- src/
4
- *.DS_Store
5
- *.ipynb
6
- *.egg-info/
7
- *.ckpt
8
- *.pth
9
-
10
- !example.ipynb
11
- !reconstruction/exp
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [yyyy] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/README.md DELETED
@@ -1,221 +0,0 @@
1
- <p align="center" width="100%">
2
- <img src="https://github.com/Dustinpro/Dustinpro/assets/23076389/0fbdb69a-0fb4-4b42-b9da-e0b28532bdfd" width="80%" height="80%">
3
- </p>
4
-
5
-
6
- <p align="center">
7
- [<a href="https://arxiv.org/pdf/2306.16928.pdf"><strong>Paper</strong></a>]
8
- [<a href="http://one-2-3-45.com"><strong>Project</strong></a>]
9
- [<a href="https://huggingface.co/spaces/One-2-3-45/One-2-3-45"><strong>Demo</strong></a>]
10
- [<a href="#citation"><strong>BibTeX</strong></a>]
11
- </p>
12
-
13
- <p align="center">
14
- <a href="https://huggingface.co/spaces/One-2-3-45/One-2-3-45">
15
- <img alt="Hugging Face Spaces" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Space_of_the_Week_%F0%9F%94%A5-blue">
16
- </a>
17
- </p>
18
-
19
- One-2-3-45 rethinks how to leverage 2D diffusion models for 3D AIGC and introduces a novel forward-only paradigm that avoids the time-consuming optimization.
20
-
21
- https://github.com/One-2-3-45/One-2-3-45/assets/16759292/a81d6e32-8d29-43a5-b044-b5112b9f9664
22
-
23
-
24
-
25
- https://github.com/One-2-3-45/One-2-3-45/assets/16759292/5ecd45ef-8fd3-4643-af4c-fac3050a0428
26
-
27
-
28
- ## News
29
- **[09/21/2023]**
30
- One-2-3-45 is accepted by NeurIPS 2023. See you in New Orleans!
31
-
32
- **[09/11/2023]**
33
- Training code released.
34
-
35
- **[08/18/2023]**
36
- Inference code released.
37
-
38
- **[07/24/2023]**
39
- Our demo reached the HuggingFace top 4 trending and was featured in 🤗 Spaces of the Week 🔥! Special thanks to HuggingFace 🤗 for sponsoring this demo!!
40
-
41
- **[07/11/2023]**
42
- [Online interactive demo](https://huggingface.co/spaces/One-2-3-45/One-2-3-45) released! Explore it and create your own 3D models in just 45 seconds!
43
-
44
- **[06/29/2023]**
45
- Check out our [paper](https://arxiv.org/pdf/2306.16928.pdf). [[X](https://twitter.com/_akhaliq/status/1674617785119305728)]
46
-
47
- ## Installation
48
- Hardware requirement: an NVIDIA GPU with memory >=18GB (_e.g._, RTX 3090 or A10). Tested on Ubuntu.
49
-
50
- We offer two ways to setup the environment:
51
-
52
- ### Traditional Installation
53
- <details>
54
- <summary>Step 1: Install Debian packages. </summary>
55
-
56
- ```bash
57
- sudo apt update && sudo apt install git-lfs libsparsehash-dev build-essential
58
- ```
59
- </details>
60
-
61
- <details>
62
- <summary>Step 2: Create and activate a conda environment. </summary>
63
-
64
- ```bash
65
- conda create -n One2345 python=3.10
66
- conda activate One2345
67
- ```
68
- </details>
69
-
70
- <details>
71
- <summary>Step 3: Clone the repository to the local machine. </summary>
72
-
73
- ```bash
74
- # Make sure you have git-lfs installed.
75
- git lfs install
76
- git clone https://github.com/One-2-3-45/One-2-3-45
77
- cd One-2-3-45
78
- ```
79
- </details>
80
-
81
- <details>
82
- <summary>Step 4: Install project dependencies using pip. </summary>
83
-
84
- ```bash
85
- # Ensure that the installed CUDA version matches the torch's cuda version.
86
- # Example: CUDA 11.8 installation
87
- wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
88
- sudo sh cuda_11.8.0_520.61.05_linux.run
89
- export PATH="/usr/local/cuda-11.8/bin:$PATH"
90
- export LD_LIBRARY_PATH="/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH"
91
- # Install PyTorch 2.0
92
- pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
93
- # Install dependencies
94
- pip install -r requirements.txt
95
- # Install inplace_abn and torchsparse
96
- export TORCH_CUDA_ARCH_LIST="7.0;7.2;8.0;8.6+PTX" # CUDA architectures. Modify according to your hardware.
97
- export IABN_FORCE_CUDA=1
98
- pip install inplace_abn
99
- FORCE_CUDA=1 pip install --no-cache-dir git+https://github.com/mit-han-lab/torchsparse.git@v1.4.0
100
- ```
101
- </details>
102
-
103
- <details>
104
- <summary>Step 5: Download model checkpoints. </summary>
105
-
106
- ```bash
107
- python download_ckpt.py
108
- ```
109
- </details>
110
-
111
-
112
- ### Installation by Docker Images
113
- <details>
114
- <summary>Option 1: Pull and Play (environment and checkpoints). (~22.3G)</summary>
115
-
116
- ```bash
117
- # Pull the Docker image that contains the full repository.
118
- docker pull chaoxu98/one2345:demo_1.0
119
- # An interactive demo will be launched automatically upon running the container.
120
- # This will provide a public URL like XXXXXXX.gradio.live
121
- docker run --name One-2-3-45_demo --gpus all -it chaoxu98/one2345:demo_1.0
122
- ```
123
- </details>
124
-
125
- <details>
126
- <summary>Option 2: Environment Only. (~7.3G)</summary>
127
-
128
- ```bash
129
- # Pull the Docker image that installed all project dependencies.
130
- docker pull chaoxu98/one2345:1.0
131
- # Start a Docker container named One2345.
132
- docker run --name One-2-3-45 --gpus all -it chaoxu98/one2345:1.0
133
- # Get a bash shell in the container.
134
- docker exec -it One-2-3-45 /bin/bash
135
- # Clone the repository to the local machine.
136
- git clone https://github.com/One-2-3-45/One-2-3-45
137
- cd One-2-3-45
138
- # Download model checkpoints.
139
- python download_ckpt.py
140
- # Refer to getting started for inference.
141
- ```
142
- </details>
143
-
144
- ## Getting Started (Inference)
145
-
146
- First-time running will take longer time to compile the models.
147
-
148
- Expected time cost per image: 40s on an NVIDIA A6000.
149
- ```bash
150
- # 1. Script
151
- python run.py --img_path PATH_TO_INPUT_IMG --half_precision
152
-
153
- # 2. Interactive demo (Gradio) with a friendly web interface
154
- # An URL will be provided in the output
155
- # (Local: 127.0.0.1:7860; Public: XXXXXXX.gradio.live)
156
- cd demo/
157
- python app.py
158
-
159
- # 3. Jupyter Notebook
160
- example.ipynb
161
- ```
162
-
163
- ## Training Your Own Model
164
-
165
- ### Data Preparation
166
- We use Objaverse-LVIS dataset for training and render the selected shapes (with CC-BY license) into 2D images with Blender.
167
- #### Download the training images.
168
- Download all One2345.zip.part-* files (5 files in total) from <a href="https://huggingface.co/datasets/One-2-3-45/training_data/tree/main">here</a> and then cat them into a single .zip file using the following command:
169
- ```bash
170
- cat One2345.zip.part-* > One2345.zip
171
- ```
172
-
173
- #### Unzip the training images zip file.
174
- Unzip the zip file into a folder specified by yourself (`YOUR_BASE_FOLDER`) with the following command:
175
-
176
- ```bash
177
- unzip One2345.zip -d YOUR_BASE_FOLDER
178
- ```
179
-
180
- #### Download meta files.
181
-
182
- Download `One2345_training_pose.json` and `lvis_split_cc_by.json` from <a href="https://huggingface.co/datasets/One-2-3-45/training_data/tree/main">here</a> and put them into the same folder as the training images (`YOUR_BASE_FOLDER`).
183
-
184
- Your file structure should look like this:
185
- ```
186
- # One2345 is your base folder used in the previous steps
187
-
188
- One2345
189
- ├── One2345_training_pose.json
190
- ├── lvis_split_cc_by.json
191
- └── zero12345_narrow
192
- ├── 000-000
193
- ├── 000-001
194
- ├── 000-002
195
- ...
196
- └── 000-159
197
-
198
- ```
199
-
200
- ### Training
201
- Specify the `trainpath`, `valpath`, and `testpath` in the config file `./reconstruction/confs/one2345_lod_train.conf` to be `YOUR_BASE_FOLDER` used in data preparation steps and run the following command:
202
- ```bash
203
- cd reconstruction
204
- python exp_runner_generic_blender_train.py --mode train --conf confs/one2345_lod_train.conf
205
- ```
206
- Experiment logs and checkpoints will be saved in `./reconstruction/exp/`.
207
-
208
- ## Citation
209
-
210
- If you find our code helpful, please cite our paper:
211
-
212
- ```
213
- @misc{liu2023one2345,
214
- title={One-2-3-45: Any Single Image to 3D Mesh in 45 Seconds without Per-Shape Optimization},
215
- author={Minghua Liu and Chao Xu and Haian Jin and Linghao Chen and Mukund Varma T and Zexiang Xu and Hao Su},
216
- year={2023},
217
- eprint={2306.16928},
218
- archivePrefix={arXiv},
219
- primaryClass={cs.CV}
220
- }
221
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/configs/sd-objaverse-finetune-c_concat-256.yaml DELETED
@@ -1,117 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-04
3
- target: ldm.models.diffusion.ddpm.LatentDiffusion
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "image_target"
11
- cond_stage_key: "image_cond"
12
- image_size: 32
13
- channels: 4
14
- cond_stage_trainable: false # Note: different from the one we trained before
15
- conditioning_key: hybrid
16
- monitor: val/loss_simple_ema
17
- scale_factor: 0.18215
18
-
19
- scheduler_config: # 10000 warmup steps
20
- target: ldm.lr_scheduler.LambdaLinearScheduler
21
- params:
22
- warm_up_steps: [ 100 ]
23
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
24
- f_start: [ 1.e-6 ]
25
- f_max: [ 1. ]
26
- f_min: [ 1. ]
27
-
28
- unet_config:
29
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
30
- params:
31
- image_size: 32 # unused
32
- in_channels: 8
33
- out_channels: 4
34
- model_channels: 320
35
- attention_resolutions: [ 4, 2, 1 ]
36
- num_res_blocks: 2
37
- channel_mult: [ 1, 2, 4, 4 ]
38
- num_heads: 8
39
- use_spatial_transformer: True
40
- transformer_depth: 1
41
- context_dim: 768
42
- use_checkpoint: True
43
- legacy: False
44
-
45
- first_stage_config:
46
- target: ldm.models.autoencoder.AutoencoderKL
47
- params:
48
- embed_dim: 4
49
- monitor: val/rec_loss
50
- ddconfig:
51
- double_z: true
52
- z_channels: 4
53
- resolution: 256
54
- in_channels: 3
55
- out_ch: 3
56
- ch: 128
57
- ch_mult:
58
- - 1
59
- - 2
60
- - 4
61
- - 4
62
- num_res_blocks: 2
63
- attn_resolutions: []
64
- dropout: 0.0
65
- lossconfig:
66
- target: torch.nn.Identity
67
-
68
- cond_stage_config:
69
- target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder
70
-
71
-
72
- data:
73
- target: ldm.data.simple.ObjaverseDataModuleFromConfig
74
- params:
75
- root_dir: 'views_whole_sphere'
76
- batch_size: 192
77
- num_workers: 16
78
- total_view: 4
79
- train:
80
- validation: False
81
- image_transforms:
82
- size: 256
83
-
84
- validation:
85
- validation: True
86
- image_transforms:
87
- size: 256
88
-
89
-
90
- lightning:
91
- find_unused_parameters: false
92
- metrics_over_trainsteps_checkpoint: True
93
- modelcheckpoint:
94
- params:
95
- every_n_train_steps: 5000
96
- callbacks:
97
- image_logger:
98
- target: main.ImageLogger
99
- params:
100
- batch_frequency: 500
101
- max_images: 32
102
- increase_log_steps: False
103
- log_first_step: True
104
- log_images_kwargs:
105
- use_ema_scope: False
106
- inpaint: False
107
- plot_progressive_rows: False
108
- plot_diffusion_rows: False
109
- N: 32
110
- unconditional_guidance_scale: 3.0
111
- unconditional_guidance_label: [""]
112
-
113
- trainer:
114
- benchmark: True
115
- val_check_interval: 5000000 # really sorry
116
- num_sanity_val_steps: 0
117
- accumulate_grad_batches: 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/download_ckpt.py DELETED
@@ -1,30 +0,0 @@
1
- import urllib.request
2
- from tqdm import tqdm
3
-
4
- def download_checkpoint(url, save_path):
5
- try:
6
- with urllib.request.urlopen(url) as response, open(save_path, 'wb') as file:
7
- file_size = int(response.info().get('Content-Length', -1))
8
- chunk_size = 8192
9
- num_chunks = file_size // chunk_size if file_size > chunk_size else 1
10
-
11
- with tqdm(total=file_size, unit='B', unit_scale=True, desc='Downloading', ncols=100) as pbar:
12
- for chunk in iter(lambda: response.read(chunk_size), b''):
13
- file.write(chunk)
14
- pbar.update(len(chunk))
15
-
16
- print(f"Checkpoint downloaded and saved to: {save_path}")
17
- except Exception as e:
18
- print(f"Error downloading checkpoint: {e}")
19
-
20
- if __name__ == "__main__":
21
- ckpts = {
22
- "sam_vit_h_4b8939.pth": "https://huggingface.co/One-2-3-45/code/resolve/main/sam_vit_h_4b8939.pth",
23
- "zero123-xl.ckpt": "https://huggingface.co/One-2-3-45/code/resolve/main/zero123-xl.ckpt",
24
- "elevation_estimate/utils/weights/indoor_ds_new.ckpt" : "https://huggingface.co/One-2-3-45/code/resolve/main/one2345_elev_est/tools/weights/indoor_ds_new.ckpt",
25
- "reconstruction/exp/lod0/checkpoints/ckpt_215000.pth": "https://huggingface.co/One-2-3-45/code/resolve/main/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth"
26
- }
27
- for ckpt_name, ckpt_url in ckpts.items():
28
- print(f"Downloading checkpoint: {ckpt_name}")
29
- download_checkpoint(ckpt_url, ckpt_name)
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/.gitignore DELETED
@@ -1,3 +0,0 @@
1
- build/
2
- .idea/
3
- *.egg-info/
 
 
 
 
One-2-3-45-master 2/elevation_estimate/__init__.py DELETED
File without changes
One-2-3-45-master 2/elevation_estimate/estimate_wild_imgs.py DELETED
@@ -1,10 +0,0 @@
1
- import os.path as osp
2
- from .utils.elev_est_api import elev_est_api
3
-
4
- def estimate_elev(root_dir):
5
- img_dir = osp.join(root_dir, "stage2_8")
6
- img_paths = []
7
- for i in range(4):
8
- img_paths.append(f"{img_dir}/0_{i}.png")
9
- elev = elev_est_api(img_paths)
10
- return elev
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .loftr import LoFTR
2
- from .utils.cvpr_ds_config import default_cfg
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/backbone/__init__.py DELETED
@@ -1,11 +0,0 @@
1
- from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
2
-
3
-
4
- def build_backbone(config):
5
- if config['backbone_type'] == 'ResNetFPN':
6
- if config['resolution'] == (8, 2):
7
- return ResNetFPN_8_2(config['resnetfpn'])
8
- elif config['resolution'] == (16, 4):
9
- return ResNetFPN_16_4(config['resnetfpn'])
10
- else:
11
- raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/backbone/resnet_fpn.py DELETED
@@ -1,199 +0,0 @@
1
- import torch.nn as nn
2
- import torch.nn.functional as F
3
-
4
-
5
- def conv1x1(in_planes, out_planes, stride=1):
6
- """1x1 convolution without padding"""
7
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
8
-
9
-
10
- def conv3x3(in_planes, out_planes, stride=1):
11
- """3x3 convolution with padding"""
12
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
13
-
14
-
15
- class BasicBlock(nn.Module):
16
- def __init__(self, in_planes, planes, stride=1):
17
- super().__init__()
18
- self.conv1 = conv3x3(in_planes, planes, stride)
19
- self.conv2 = conv3x3(planes, planes)
20
- self.bn1 = nn.BatchNorm2d(planes)
21
- self.bn2 = nn.BatchNorm2d(planes)
22
- self.relu = nn.ReLU(inplace=True)
23
-
24
- if stride == 1:
25
- self.downsample = None
26
- else:
27
- self.downsample = nn.Sequential(
28
- conv1x1(in_planes, planes, stride=stride),
29
- nn.BatchNorm2d(planes)
30
- )
31
-
32
- def forward(self, x):
33
- y = x
34
- y = self.relu(self.bn1(self.conv1(y)))
35
- y = self.bn2(self.conv2(y))
36
-
37
- if self.downsample is not None:
38
- x = self.downsample(x)
39
-
40
- return self.relu(x+y)
41
-
42
-
43
- class ResNetFPN_8_2(nn.Module):
44
- """
45
- ResNet+FPN, output resolution are 1/8 and 1/2.
46
- Each block has 2 layers.
47
- """
48
-
49
- def __init__(self, config):
50
- super().__init__()
51
- # Config
52
- block = BasicBlock
53
- initial_dim = config['initial_dim']
54
- block_dims = config['block_dims']
55
-
56
- # Class Variable
57
- self.in_planes = initial_dim
58
-
59
- # Networks
60
- self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
61
- self.bn1 = nn.BatchNorm2d(initial_dim)
62
- self.relu = nn.ReLU(inplace=True)
63
-
64
- self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
65
- self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
66
- self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
67
-
68
- # 3. FPN upsample
69
- self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
70
- self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
71
- self.layer2_outconv2 = nn.Sequential(
72
- conv3x3(block_dims[2], block_dims[2]),
73
- nn.BatchNorm2d(block_dims[2]),
74
- nn.LeakyReLU(),
75
- conv3x3(block_dims[2], block_dims[1]),
76
- )
77
- self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
78
- self.layer1_outconv2 = nn.Sequential(
79
- conv3x3(block_dims[1], block_dims[1]),
80
- nn.BatchNorm2d(block_dims[1]),
81
- nn.LeakyReLU(),
82
- conv3x3(block_dims[1], block_dims[0]),
83
- )
84
-
85
- for m in self.modules():
86
- if isinstance(m, nn.Conv2d):
87
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
88
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
89
- nn.init.constant_(m.weight, 1)
90
- nn.init.constant_(m.bias, 0)
91
-
92
- def _make_layer(self, block, dim, stride=1):
93
- layer1 = block(self.in_planes, dim, stride=stride)
94
- layer2 = block(dim, dim, stride=1)
95
- layers = (layer1, layer2)
96
-
97
- self.in_planes = dim
98
- return nn.Sequential(*layers)
99
-
100
- def forward(self, x):
101
- # ResNet Backbone
102
- x0 = self.relu(self.bn1(self.conv1(x)))
103
- x1 = self.layer1(x0) # 1/2
104
- x2 = self.layer2(x1) # 1/4
105
- x3 = self.layer3(x2) # 1/8
106
-
107
- # FPN
108
- x3_out = self.layer3_outconv(x3)
109
-
110
- x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
111
- x2_out = self.layer2_outconv(x2)
112
- x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
113
-
114
- x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
115
- x1_out = self.layer1_outconv(x1)
116
- x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
117
-
118
- return [x3_out, x1_out]
119
-
120
-
121
- class ResNetFPN_16_4(nn.Module):
122
- """
123
- ResNet+FPN, output resolution are 1/16 and 1/4.
124
- Each block has 2 layers.
125
- """
126
-
127
- def __init__(self, config):
128
- super().__init__()
129
- # Config
130
- block = BasicBlock
131
- initial_dim = config['initial_dim']
132
- block_dims = config['block_dims']
133
-
134
- # Class Variable
135
- self.in_planes = initial_dim
136
-
137
- # Networks
138
- self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
139
- self.bn1 = nn.BatchNorm2d(initial_dim)
140
- self.relu = nn.ReLU(inplace=True)
141
-
142
- self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
143
- self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
144
- self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
145
- self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
146
-
147
- # 3. FPN upsample
148
- self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
149
- self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
150
- self.layer3_outconv2 = nn.Sequential(
151
- conv3x3(block_dims[3], block_dims[3]),
152
- nn.BatchNorm2d(block_dims[3]),
153
- nn.LeakyReLU(),
154
- conv3x3(block_dims[3], block_dims[2]),
155
- )
156
-
157
- self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
158
- self.layer2_outconv2 = nn.Sequential(
159
- conv3x3(block_dims[2], block_dims[2]),
160
- nn.BatchNorm2d(block_dims[2]),
161
- nn.LeakyReLU(),
162
- conv3x3(block_dims[2], block_dims[1]),
163
- )
164
-
165
- for m in self.modules():
166
- if isinstance(m, nn.Conv2d):
167
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
168
- elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
169
- nn.init.constant_(m.weight, 1)
170
- nn.init.constant_(m.bias, 0)
171
-
172
- def _make_layer(self, block, dim, stride=1):
173
- layer1 = block(self.in_planes, dim, stride=stride)
174
- layer2 = block(dim, dim, stride=1)
175
- layers = (layer1, layer2)
176
-
177
- self.in_planes = dim
178
- return nn.Sequential(*layers)
179
-
180
- def forward(self, x):
181
- # ResNet Backbone
182
- x0 = self.relu(self.bn1(self.conv1(x)))
183
- x1 = self.layer1(x0) # 1/2
184
- x2 = self.layer2(x1) # 1/4
185
- x3 = self.layer3(x2) # 1/8
186
- x4 = self.layer4(x3) # 1/16
187
-
188
- # FPN
189
- x4_out = self.layer4_outconv(x4)
190
-
191
- x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
192
- x3_out = self.layer3_outconv(x3)
193
- x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
194
-
195
- x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
196
- x2_out = self.layer2_outconv(x2)
197
- x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
198
-
199
- return [x4_out, x2_out]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/loftr.py DELETED
@@ -1,81 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from einops.einops import rearrange
4
-
5
- from .backbone import build_backbone
6
- from .utils.position_encoding import PositionEncodingSine
7
- from .loftr_module import LocalFeatureTransformer, FinePreprocess
8
- from .utils.coarse_matching import CoarseMatching
9
- from .utils.fine_matching import FineMatching
10
-
11
-
12
- class LoFTR(nn.Module):
13
- def __init__(self, config):
14
- super().__init__()
15
- # Misc
16
- self.config = config
17
-
18
- # Modules
19
- self.backbone = build_backbone(config)
20
- self.pos_encoding = PositionEncodingSine(
21
- config['coarse']['d_model'],
22
- temp_bug_fix=config['coarse']['temp_bug_fix'])
23
- self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
24
- self.coarse_matching = CoarseMatching(config['match_coarse'])
25
- self.fine_preprocess = FinePreprocess(config)
26
- self.loftr_fine = LocalFeatureTransformer(config["fine"])
27
- self.fine_matching = FineMatching()
28
-
29
- def forward(self, data):
30
- """
31
- Update:
32
- data (dict): {
33
- 'image0': (torch.Tensor): (N, 1, H, W)
34
- 'image1': (torch.Tensor): (N, 1, H, W)
35
- 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
36
- 'mask1'(optional) : (torch.Tensor): (N, H, W)
37
- }
38
- """
39
- # 1. Local Feature CNN
40
- data.update({
41
- 'bs': data['image0'].size(0),
42
- 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
43
- })
44
-
45
- if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
46
- feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
47
- (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
48
- else: # handle different input shapes
49
- (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
50
-
51
- data.update({
52
- 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
53
- 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
54
- })
55
-
56
- # 2. coarse-level loftr module
57
- # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
58
- feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
59
- feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
60
-
61
- mask_c0 = mask_c1 = None # mask is useful in training
62
- if 'mask0' in data:
63
- mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
64
- feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
65
-
66
- # 3. match coarse-level
67
- self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
68
-
69
- # 4. fine-level refinement
70
- feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
71
- if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
72
- feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
73
-
74
- # 5. match fine-level
75
- self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
76
-
77
- def load_state_dict(self, state_dict, *args, **kwargs):
78
- for k in list(state_dict.keys()):
79
- if k.startswith('matcher.'):
80
- state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
81
- return super().load_state_dict(state_dict, *args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .transformer import LocalFeatureTransformer
2
- from .fine_preprocess import FinePreprocess
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/fine_preprocess.py DELETED
@@ -1,59 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from einops.einops import rearrange, repeat
5
-
6
-
7
- class FinePreprocess(nn.Module):
8
- def __init__(self, config):
9
- super().__init__()
10
-
11
- self.config = config
12
- self.cat_c_feat = config['fine_concat_coarse_feat']
13
- self.W = self.config['fine_window_size']
14
-
15
- d_model_c = self.config['coarse']['d_model']
16
- d_model_f = self.config['fine']['d_model']
17
- self.d_model_f = d_model_f
18
- if self.cat_c_feat:
19
- self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
20
- self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
21
-
22
- self._reset_parameters()
23
-
24
- def _reset_parameters(self):
25
- for p in self.parameters():
26
- if p.dim() > 1:
27
- nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
28
-
29
- def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
30
- W = self.W
31
- stride = data['hw0_f'][0] // data['hw0_c'][0]
32
-
33
- data.update({'W': W})
34
- if data['b_ids'].shape[0] == 0:
35
- feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
36
- feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
37
- return feat0, feat1
38
-
39
- # 1. unfold(crop) all local windows
40
- feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
41
- feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
42
- feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
43
- feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
44
-
45
- # 2. select only the predicted matches
46
- feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
47
- feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
48
-
49
- # option: use coarse-level loftr feature as context: concat and linear
50
- if self.cat_c_feat:
51
- feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
52
- feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
53
- feat_cf_win = self.merge_feat(torch.cat([
54
- torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
55
- repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
56
- ], -1))
57
- feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
58
-
59
- return feat_f0_unfold, feat_f1_unfold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/linear_attention.py DELETED
@@ -1,81 +0,0 @@
1
- """
2
- Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
3
- Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
4
- """
5
-
6
- import torch
7
- from torch.nn import Module, Dropout
8
-
9
-
10
- def elu_feature_map(x):
11
- return torch.nn.functional.elu(x) + 1
12
-
13
-
14
- class LinearAttention(Module):
15
- def __init__(self, eps=1e-6):
16
- super().__init__()
17
- self.feature_map = elu_feature_map
18
- self.eps = eps
19
-
20
- def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
21
- """ Multi-Head linear attention proposed in "Transformers are RNNs"
22
- Args:
23
- queries: [N, L, H, D]
24
- keys: [N, S, H, D]
25
- values: [N, S, H, D]
26
- q_mask: [N, L]
27
- kv_mask: [N, S]
28
- Returns:
29
- queried_values: (N, L, H, D)
30
- """
31
- Q = self.feature_map(queries)
32
- K = self.feature_map(keys)
33
-
34
- # set padded position to zero
35
- if q_mask is not None:
36
- Q = Q * q_mask[:, :, None, None]
37
- if kv_mask is not None:
38
- K = K * kv_mask[:, :, None, None]
39
- values = values * kv_mask[:, :, None, None]
40
-
41
- v_length = values.size(1)
42
- values = values / v_length # prevent fp16 overflow
43
- KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
44
- Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
45
- queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
46
-
47
- return queried_values.contiguous()
48
-
49
-
50
- class FullAttention(Module):
51
- def __init__(self, use_dropout=False, attention_dropout=0.1):
52
- super().__init__()
53
- self.use_dropout = use_dropout
54
- self.dropout = Dropout(attention_dropout)
55
-
56
- def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
57
- """ Multi-head scaled dot-product attention, a.k.a full attention.
58
- Args:
59
- queries: [N, L, H, D]
60
- keys: [N, S, H, D]
61
- values: [N, S, H, D]
62
- q_mask: [N, L]
63
- kv_mask: [N, S]
64
- Returns:
65
- queried_values: (N, L, H, D)
66
- """
67
-
68
- # Compute the unnormalized attention and apply the masks
69
- QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
70
- if kv_mask is not None:
71
- QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
72
-
73
- # Compute the attention and the weighted average
74
- softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
75
- A = torch.softmax(softmax_temp * QK, dim=2)
76
- if self.use_dropout:
77
- A = self.dropout(A)
78
-
79
- queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
80
-
81
- return queried_values.contiguous()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/loftr_module/transformer.py DELETED
@@ -1,101 +0,0 @@
1
- import copy
2
- import torch
3
- import torch.nn as nn
4
- from .linear_attention import LinearAttention, FullAttention
5
-
6
-
7
- class LoFTREncoderLayer(nn.Module):
8
- def __init__(self,
9
- d_model,
10
- nhead,
11
- attention='linear'):
12
- super(LoFTREncoderLayer, self).__init__()
13
-
14
- self.dim = d_model // nhead
15
- self.nhead = nhead
16
-
17
- # multi-head attention
18
- self.q_proj = nn.Linear(d_model, d_model, bias=False)
19
- self.k_proj = nn.Linear(d_model, d_model, bias=False)
20
- self.v_proj = nn.Linear(d_model, d_model, bias=False)
21
- self.attention = LinearAttention() if attention == 'linear' else FullAttention()
22
- self.merge = nn.Linear(d_model, d_model, bias=False)
23
-
24
- # feed-forward network
25
- self.mlp = nn.Sequential(
26
- nn.Linear(d_model*2, d_model*2, bias=False),
27
- nn.ReLU(True),
28
- nn.Linear(d_model*2, d_model, bias=False),
29
- )
30
-
31
- # norm and dropout
32
- self.norm1 = nn.LayerNorm(d_model)
33
- self.norm2 = nn.LayerNorm(d_model)
34
-
35
- def forward(self, x, source, x_mask=None, source_mask=None):
36
- """
37
- Args:
38
- x (torch.Tensor): [N, L, C]
39
- source (torch.Tensor): [N, S, C]
40
- x_mask (torch.Tensor): [N, L] (optional)
41
- source_mask (torch.Tensor): [N, S] (optional)
42
- """
43
- bs = x.size(0)
44
- query, key, value = x, source, source
45
-
46
- # multi-head attention
47
- query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
48
- key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
49
- value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
50
- message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
51
- message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
52
- message = self.norm1(message)
53
-
54
- # feed-forward network
55
- message = self.mlp(torch.cat([x, message], dim=2))
56
- message = self.norm2(message)
57
-
58
- return x + message
59
-
60
-
61
- class LocalFeatureTransformer(nn.Module):
62
- """A Local Feature Transformer (LoFTR) module."""
63
-
64
- def __init__(self, config):
65
- super(LocalFeatureTransformer, self).__init__()
66
-
67
- self.config = config
68
- self.d_model = config['d_model']
69
- self.nhead = config['nhead']
70
- self.layer_names = config['layer_names']
71
- encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
72
- self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
73
- self._reset_parameters()
74
-
75
- def _reset_parameters(self):
76
- for p in self.parameters():
77
- if p.dim() > 1:
78
- nn.init.xavier_uniform_(p)
79
-
80
- def forward(self, feat0, feat1, mask0=None, mask1=None):
81
- """
82
- Args:
83
- feat0 (torch.Tensor): [N, L, C]
84
- feat1 (torch.Tensor): [N, S, C]
85
- mask0 (torch.Tensor): [N, L] (optional)
86
- mask1 (torch.Tensor): [N, S] (optional)
87
- """
88
-
89
- assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
90
-
91
- for layer, name in zip(self.layers, self.layer_names):
92
- if name == 'self':
93
- feat0 = layer(feat0, feat0, mask0, mask0)
94
- feat1 = layer(feat1, feat1, mask1, mask1)
95
- elif name == 'cross':
96
- feat0 = layer(feat0, feat1, mask0, mask1)
97
- feat1 = layer(feat1, feat0, mask1, mask0)
98
- else:
99
- raise KeyError
100
-
101
- return feat0, feat1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/utils/coarse_matching.py DELETED
@@ -1,261 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from einops.einops import rearrange
5
-
6
- INF = 1e9
7
-
8
- def mask_border(m, b: int, v):
9
- """ Mask borders with value
10
- Args:
11
- m (torch.Tensor): [N, H0, W0, H1, W1]
12
- b (int)
13
- v (m.dtype)
14
- """
15
- if b <= 0:
16
- return
17
-
18
- m[:, :b] = v
19
- m[:, :, :b] = v
20
- m[:, :, :, :b] = v
21
- m[:, :, :, :, :b] = v
22
- m[:, -b:] = v
23
- m[:, :, -b:] = v
24
- m[:, :, :, -b:] = v
25
- m[:, :, :, :, -b:] = v
26
-
27
-
28
- def mask_border_with_padding(m, bd, v, p_m0, p_m1):
29
- if bd <= 0:
30
- return
31
-
32
- m[:, :bd] = v
33
- m[:, :, :bd] = v
34
- m[:, :, :, :bd] = v
35
- m[:, :, :, :, :bd] = v
36
-
37
- h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
38
- h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
39
- for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
40
- m[b_idx, h0 - bd:] = v
41
- m[b_idx, :, w0 - bd:] = v
42
- m[b_idx, :, :, h1 - bd:] = v
43
- m[b_idx, :, :, :, w1 - bd:] = v
44
-
45
-
46
- def compute_max_candidates(p_m0, p_m1):
47
- """Compute the max candidates of all pairs within a batch
48
-
49
- Args:
50
- p_m0, p_m1 (torch.Tensor): padded masks
51
- """
52
- h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
53
- h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
54
- max_cand = torch.sum(
55
- torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
56
- return max_cand
57
-
58
-
59
- class CoarseMatching(nn.Module):
60
- def __init__(self, config):
61
- super().__init__()
62
- self.config = config
63
- # general config
64
- self.thr = config['thr']
65
- self.border_rm = config['border_rm']
66
- # -- # for trainig fine-level LoFTR
67
- self.train_coarse_percent = config['train_coarse_percent']
68
- self.train_pad_num_gt_min = config['train_pad_num_gt_min']
69
-
70
- # we provide 2 options for differentiable matching
71
- self.match_type = config['match_type']
72
- if self.match_type == 'dual_softmax':
73
- self.temperature = config['dsmax_temperature']
74
- elif self.match_type == 'sinkhorn':
75
- try:
76
- from .superglue import log_optimal_transport
77
- except ImportError:
78
- raise ImportError("download superglue.py first!")
79
- self.log_optimal_transport = log_optimal_transport
80
- self.bin_score = nn.Parameter(
81
- torch.tensor(config['skh_init_bin_score'], requires_grad=True))
82
- self.skh_iters = config['skh_iters']
83
- self.skh_prefilter = config['skh_prefilter']
84
- else:
85
- raise NotImplementedError()
86
-
87
- def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
88
- """
89
- Args:
90
- feat0 (torch.Tensor): [N, L, C]
91
- feat1 (torch.Tensor): [N, S, C]
92
- data (dict)
93
- mask_c0 (torch.Tensor): [N, L] (optional)
94
- mask_c1 (torch.Tensor): [N, S] (optional)
95
- Update:
96
- data (dict): {
97
- 'b_ids' (torch.Tensor): [M'],
98
- 'i_ids' (torch.Tensor): [M'],
99
- 'j_ids' (torch.Tensor): [M'],
100
- 'gt_mask' (torch.Tensor): [M'],
101
- 'mkpts0_c' (torch.Tensor): [M, 2],
102
- 'mkpts1_c' (torch.Tensor): [M, 2],
103
- 'mconf' (torch.Tensor): [M]}
104
- NOTE: M' != M during training.
105
- """
106
- N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
107
-
108
- # normalize
109
- feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
110
- [feat_c0, feat_c1])
111
-
112
- if self.match_type == 'dual_softmax':
113
- sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
114
- feat_c1) / self.temperature
115
- if mask_c0 is not None:
116
- sim_matrix.masked_fill_(
117
- ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
118
- -INF)
119
- conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
120
-
121
- elif self.match_type == 'sinkhorn':
122
- # sinkhorn, dustbin included
123
- sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
124
- if mask_c0 is not None:
125
- sim_matrix[:, :L, :S].masked_fill_(
126
- ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
127
- -INF)
128
-
129
- # build uniform prior & use sinkhorn
130
- log_assign_matrix = self.log_optimal_transport(
131
- sim_matrix, self.bin_score, self.skh_iters)
132
- assign_matrix = log_assign_matrix.exp()
133
- conf_matrix = assign_matrix[:, :-1, :-1]
134
-
135
- # filter prediction with dustbin score (only in evaluation mode)
136
- if not self.training and self.skh_prefilter:
137
- filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L]
138
- filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S]
139
- conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
140
- conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
141
-
142
- if self.config['sparse_spvs']:
143
- data.update({'conf_matrix_with_bin': assign_matrix.clone()})
144
-
145
- data.update({'conf_matrix': conf_matrix})
146
-
147
- # predict coarse matches from conf_matrix
148
- data.update(**self.get_coarse_match(conf_matrix, data))
149
-
150
- @torch.no_grad()
151
- def get_coarse_match(self, conf_matrix, data):
152
- """
153
- Args:
154
- conf_matrix (torch.Tensor): [N, L, S]
155
- data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
156
- Returns:
157
- coarse_matches (dict): {
158
- 'b_ids' (torch.Tensor): [M'],
159
- 'i_ids' (torch.Tensor): [M'],
160
- 'j_ids' (torch.Tensor): [M'],
161
- 'gt_mask' (torch.Tensor): [M'],
162
- 'm_bids' (torch.Tensor): [M],
163
- 'mkpts0_c' (torch.Tensor): [M, 2],
164
- 'mkpts1_c' (torch.Tensor): [M, 2],
165
- 'mconf' (torch.Tensor): [M]}
166
- """
167
- axes_lengths = {
168
- 'h0c': data['hw0_c'][0],
169
- 'w0c': data['hw0_c'][1],
170
- 'h1c': data['hw1_c'][0],
171
- 'w1c': data['hw1_c'][1]
172
- }
173
- _device = conf_matrix.device
174
- # 1. confidence thresholding
175
- mask = conf_matrix > self.thr
176
- mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
177
- **axes_lengths)
178
- if 'mask0' not in data:
179
- mask_border(mask, self.border_rm, False)
180
- else:
181
- mask_border_with_padding(mask, self.border_rm, False,
182
- data['mask0'], data['mask1'])
183
- mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
184
- **axes_lengths)
185
-
186
- # 2. mutual nearest
187
- mask = mask \
188
- * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
189
- * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
190
-
191
- # 3. find all valid coarse matches
192
- # this only works when at most one `True` in each row
193
- mask_v, all_j_ids = mask.max(dim=2)
194
- b_ids, i_ids = torch.where(mask_v)
195
- j_ids = all_j_ids[b_ids, i_ids]
196
- mconf = conf_matrix[b_ids, i_ids, j_ids]
197
-
198
- # 4. Random sampling of training samples for fine-level LoFTR
199
- # (optional) pad samples with gt coarse-level matches
200
- if self.training:
201
- # NOTE:
202
- # The sampling is performed across all pairs in a batch without manually balancing
203
- # #samples for fine-level increases w.r.t. batch_size
204
- if 'mask0' not in data:
205
- num_candidates_max = mask.size(0) * max(
206
- mask.size(1), mask.size(2))
207
- else:
208
- num_candidates_max = compute_max_candidates(
209
- data['mask0'], data['mask1'])
210
- num_matches_train = int(num_candidates_max *
211
- self.train_coarse_percent)
212
- num_matches_pred = len(b_ids)
213
- assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
214
-
215
- # pred_indices is to select from prediction
216
- if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
217
- pred_indices = torch.arange(num_matches_pred, device=_device)
218
- else:
219
- pred_indices = torch.randint(
220
- num_matches_pred,
221
- (num_matches_train - self.train_pad_num_gt_min, ),
222
- device=_device)
223
-
224
- # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
225
- gt_pad_indices = torch.randint(
226
- len(data['spv_b_ids']),
227
- (max(num_matches_train - num_matches_pred,
228
- self.train_pad_num_gt_min), ),
229
- device=_device)
230
- mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
231
-
232
- b_ids, i_ids, j_ids, mconf = map(
233
- lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
234
- dim=0),
235
- *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
236
- [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
237
-
238
- # These matches select patches that feed into fine-level network
239
- coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
240
-
241
- # 4. Update with matches in original image resolution
242
- scale = data['hw0_i'][0] / data['hw0_c'][0]
243
- scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
244
- scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
245
- mkpts0_c = torch.stack(
246
- [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
247
- dim=1) * scale0
248
- mkpts1_c = torch.stack(
249
- [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
250
- dim=1) * scale1
251
-
252
- # These matches is the current prediction (for visualization)
253
- coarse_matches.update({
254
- 'gt_mask': mconf == 0,
255
- 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
256
- 'mkpts0_c': mkpts0_c[mconf != 0],
257
- 'mkpts1_c': mkpts1_c[mconf != 0],
258
- 'mconf': mconf[mconf != 0]
259
- })
260
-
261
- return coarse_matches
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/utils/cvpr_ds_config.py DELETED
@@ -1,50 +0,0 @@
1
- from yacs.config import CfgNode as CN
2
-
3
-
4
- def lower_config(yacs_cfg):
5
- if not isinstance(yacs_cfg, CN):
6
- return yacs_cfg
7
- return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
8
-
9
-
10
- _CN = CN()
11
- _CN.BACKBONE_TYPE = 'ResNetFPN'
12
- _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
13
- _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
14
- _CN.FINE_CONCAT_COARSE_FEAT = True
15
-
16
- # 1. LoFTR-backbone (local feature CNN) config
17
- _CN.RESNETFPN = CN()
18
- _CN.RESNETFPN.INITIAL_DIM = 128
19
- _CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
20
-
21
- # 2. LoFTR-coarse module config
22
- _CN.COARSE = CN()
23
- _CN.COARSE.D_MODEL = 256
24
- _CN.COARSE.D_FFN = 256
25
- _CN.COARSE.NHEAD = 8
26
- _CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
27
- _CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
28
- _CN.COARSE.TEMP_BUG_FIX = False
29
-
30
- # 3. Coarse-Matching config
31
- _CN.MATCH_COARSE = CN()
32
- _CN.MATCH_COARSE.THR = 0.2
33
- _CN.MATCH_COARSE.BORDER_RM = 2
34
- _CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn']
35
- _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
36
- _CN.MATCH_COARSE.SKH_ITERS = 3
37
- _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
38
- _CN.MATCH_COARSE.SKH_PREFILTER = True
39
- _CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory
40
- _CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
41
-
42
- # 4. LoFTR-fine module config
43
- _CN.FINE = CN()
44
- _CN.FINE.D_MODEL = 128
45
- _CN.FINE.D_FFN = 128
46
- _CN.FINE.NHEAD = 8
47
- _CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
48
- _CN.FINE.ATTENTION = 'linear'
49
-
50
- default_cfg = lower_config(_CN)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/utils/fine_matching.py DELETED
@@ -1,74 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
-
5
- from kornia.geometry.subpix import dsnt
6
- from kornia.utils.grid import create_meshgrid
7
-
8
-
9
- class FineMatching(nn.Module):
10
- """FineMatching with s2d paradigm"""
11
-
12
- def __init__(self):
13
- super().__init__()
14
-
15
- def forward(self, feat_f0, feat_f1, data):
16
- """
17
- Args:
18
- feat0 (torch.Tensor): [M, WW, C]
19
- feat1 (torch.Tensor): [M, WW, C]
20
- data (dict)
21
- Update:
22
- data (dict):{
23
- 'expec_f' (torch.Tensor): [M, 3],
24
- 'mkpts0_f' (torch.Tensor): [M, 2],
25
- 'mkpts1_f' (torch.Tensor): [M, 2]}
26
- """
27
- M, WW, C = feat_f0.shape
28
- W = int(math.sqrt(WW))
29
- scale = data['hw0_i'][0] / data['hw0_f'][0]
30
- self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
31
-
32
- # corner case: if no coarse matches found
33
- if M == 0:
34
- assert self.training == False, "M is always >0, when training, see coarse_matching.py"
35
- # logger.warning('No matches found in coarse-level.')
36
- data.update({
37
- 'expec_f': torch.empty(0, 3, device=feat_f0.device),
38
- 'mkpts0_f': data['mkpts0_c'],
39
- 'mkpts1_f': data['mkpts1_c'],
40
- })
41
- return
42
-
43
- feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
44
- sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
45
- softmax_temp = 1. / C**.5
46
- heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
47
-
48
- # compute coordinates from heatmap
49
- coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
50
- grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
51
-
52
- # compute std over <x, y>
53
- var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
54
- std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
55
-
56
- # for fine-level supervision
57
- data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
58
-
59
- # compute absolute kpt coords
60
- self.get_fine_match(coords_normalized, data)
61
-
62
- @torch.no_grad()
63
- def get_fine_match(self, coords_normed, data):
64
- W, WW, C, scale = self.W, self.WW, self.C, self.scale
65
-
66
- # mkpts0_f and mkpts1_f
67
- mkpts0_f = data['mkpts0_c']
68
- scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
69
- mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
70
-
71
- data.update({
72
- "mkpts0_f": mkpts0_f,
73
- "mkpts1_f": mkpts1_f
74
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/utils/geometry.py DELETED
@@ -1,54 +0,0 @@
1
- import torch
2
-
3
-
4
- @torch.no_grad()
5
- def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
6
- """ Warp kpts0 from I0 to I1 with depth, K and Rt
7
- Also check covisibility and depth consistency.
8
- Depth is consistent if relative error < 0.2 (hard-coded).
9
-
10
- Args:
11
- kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
12
- depth0 (torch.Tensor): [N, H, W],
13
- depth1 (torch.Tensor): [N, H, W],
14
- T_0to1 (torch.Tensor): [N, 3, 4],
15
- K0 (torch.Tensor): [N, 3, 3],
16
- K1 (torch.Tensor): [N, 3, 3],
17
- Returns:
18
- calculable_mask (torch.Tensor): [N, L]
19
- warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
20
- """
21
- kpts0_long = kpts0.round().long()
22
-
23
- # Sample depth, get calculable_mask on depth != 0
24
- kpts0_depth = torch.stack(
25
- [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
26
- ) # (N, L)
27
- nonzero_mask = kpts0_depth != 0
28
-
29
- # Unproject
30
- kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
31
- kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
32
-
33
- # Rigid Transform
34
- w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
35
- w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
36
-
37
- # Project
38
- w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
39
- w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
40
-
41
- # Covisible Check
42
- h, w = depth1.shape[1:3]
43
- covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
44
- (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
45
- w_kpts0_long = w_kpts0.long()
46
- w_kpts0_long[~covisible_mask, :] = 0
47
-
48
- w_kpts0_depth = torch.stack(
49
- [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
50
- ) # (N, L)
51
- consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
52
- valid_mask = nonzero_mask * covisible_mask * consistent_mask
53
-
54
- return valid_mask, w_kpts0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/utils/position_encoding.py DELETED
@@ -1,42 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
-
5
-
6
- class PositionEncodingSine(nn.Module):
7
- """
8
- This is a sinusoidal position encoding that generalized to 2-dimensional images
9
- """
10
-
11
- def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
12
- """
13
- Args:
14
- max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
15
- temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
16
- the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
17
- on the final performance. For now, we keep both impls for backward compatability.
18
- We will remove the buggy impl after re-training all variants of our released models.
19
- """
20
- super().__init__()
21
-
22
- pe = torch.zeros((d_model, *max_shape))
23
- y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
24
- x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
25
- if temp_bug_fix:
26
- div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
27
- else: # a buggy implementation (for backward compatability only)
28
- div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
29
- div_term = div_term[:, None, None] # [C//4, 1, 1]
30
- pe[0::4, :, :] = torch.sin(x_position * div_term)
31
- pe[1::4, :, :] = torch.cos(x_position * div_term)
32
- pe[2::4, :, :] = torch.sin(y_position * div_term)
33
- pe[3::4, :, :] = torch.cos(y_position * div_term)
34
-
35
- self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
36
-
37
- def forward(self, x):
38
- """
39
- Args:
40
- x: [N, C, H, W]
41
- """
42
- return x + self.pe[:, :, :x.size(2), :x.size(3)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/loftr/utils/supervision.py DELETED
@@ -1,151 +0,0 @@
1
- from math import log
2
- from loguru import logger
3
-
4
- import torch
5
- from einops import repeat
6
- from kornia.utils import create_meshgrid
7
-
8
- from .geometry import warp_kpts
9
-
10
- ############## ↓ Coarse-Level supervision ↓ ##############
11
-
12
-
13
- @torch.no_grad()
14
- def mask_pts_at_padded_regions(grid_pt, mask):
15
- """For megadepth dataset, zero-padding exists in images"""
16
- mask = repeat(mask, 'n h w -> n (h w) c', c=2)
17
- grid_pt[~mask.bool()] = 0
18
- return grid_pt
19
-
20
-
21
- @torch.no_grad()
22
- def spvs_coarse(data, config):
23
- """
24
- Update:
25
- data (dict): {
26
- "conf_matrix_gt": [N, hw0, hw1],
27
- 'spv_b_ids': [M]
28
- 'spv_i_ids': [M]
29
- 'spv_j_ids': [M]
30
- 'spv_w_pt0_i': [N, hw0, 2], in original image resolution
31
- 'spv_pt1_i': [N, hw1, 2], in original image resolution
32
- }
33
-
34
- NOTE:
35
- - for scannet dataset, there're 3 kinds of resolution {i, c, f}
36
- - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
37
- """
38
- # 1. misc
39
- device = data['image0'].device
40
- N, _, H0, W0 = data['image0'].shape
41
- _, _, H1, W1 = data['image1'].shape
42
- scale = config['LOFTR']['RESOLUTION'][0]
43
- scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
44
- scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale
45
- h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
46
-
47
- # 2. warp grids
48
- # create kpts in meshgrid and resize them to image resolution
49
- grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
50
- grid_pt0_i = scale0 * grid_pt0_c
51
- grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
52
- grid_pt1_i = scale1 * grid_pt1_c
53
-
54
- # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
55
- if 'mask0' in data:
56
- grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
57
- grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
58
-
59
- # warp kpts bi-directionally and resize them to coarse-level resolution
60
- # (no depth consistency check, since it leads to worse results experimentally)
61
- # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
62
- _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
63
- _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
64
- w_pt0_c = w_pt0_i / scale1
65
- w_pt1_c = w_pt1_i / scale0
66
-
67
- # 3. check if mutual nearest neighbor
68
- w_pt0_c_round = w_pt0_c[:, :, :].round().long()
69
- nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
70
- w_pt1_c_round = w_pt1_c[:, :, :].round().long()
71
- nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
72
-
73
- # corner case: out of boundary
74
- def out_bound_mask(pt, w, h):
75
- return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
76
- nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
77
- nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
78
-
79
- loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
80
- correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
81
- correct_0to1[:, 0] = False # ignore the top-left corner
82
-
83
- # 4. construct a gt conf_matrix
84
- conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
85
- b_ids, i_ids = torch.where(correct_0to1 != 0)
86
- j_ids = nearest_index1[b_ids, i_ids]
87
-
88
- conf_matrix_gt[b_ids, i_ids, j_ids] = 1
89
- data.update({'conf_matrix_gt': conf_matrix_gt})
90
-
91
- # 5. save coarse matches(gt) for training fine level
92
- if len(b_ids) == 0:
93
- logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
94
- # this won't affect fine-level loss calculation
95
- b_ids = torch.tensor([0], device=device)
96
- i_ids = torch.tensor([0], device=device)
97
- j_ids = torch.tensor([0], device=device)
98
-
99
- data.update({
100
- 'spv_b_ids': b_ids,
101
- 'spv_i_ids': i_ids,
102
- 'spv_j_ids': j_ids
103
- })
104
-
105
- # 6. save intermediate results (for fast fine-level computation)
106
- data.update({
107
- 'spv_w_pt0_i': w_pt0_i,
108
- 'spv_pt1_i': grid_pt1_i
109
- })
110
-
111
-
112
- def compute_supervision_coarse(data, config):
113
- assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
114
- data_source = data['dataset_name'][0]
115
- if data_source.lower() in ['scannet', 'megadepth']:
116
- spvs_coarse(data, config)
117
- else:
118
- raise ValueError(f'Unknown data source: {data_source}')
119
-
120
-
121
- ############## ↓ Fine-Level supervision ↓ ##############
122
-
123
- @torch.no_grad()
124
- def spvs_fine(data, config):
125
- """
126
- Update:
127
- data (dict):{
128
- "expec_f_gt": [M, 2]}
129
- """
130
- # 1. misc
131
- # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
132
- w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
133
- scale = config['LOFTR']['RESOLUTION'][1]
134
- radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
135
-
136
- # 2. get coarse prediction
137
- b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
138
-
139
- # 3. compute gt
140
- scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
141
- # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
142
- expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
143
- data.update({"expec_f_gt": expec_f_gt})
144
-
145
-
146
- def compute_supervision_fine(data, config):
147
- data_source = data['dataset_name'][0]
148
- if data_source.lower() in ['scannet', 'megadepth']:
149
- spvs_fine(data, config)
150
- else:
151
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/pyproject.toml DELETED
@@ -1,7 +0,0 @@
1
- [project]
2
- name = "elevation_estimate"
3
- version = "0.1"
4
-
5
- [tool.setuptools.packages.find]
6
- exclude = ["configs", "tests"] # empty by default
7
- namespaces = false # true by default
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/utils/__init__.py DELETED
File without changes
One-2-3-45-master 2/elevation_estimate/utils/elev_est_api.py DELETED
@@ -1,205 +0,0 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import os.path as osp
5
- import imageio
6
- from copy import deepcopy
7
-
8
- import loguru
9
- import torch
10
- import matplotlib.cm as cm
11
- import matplotlib.pyplot as plt
12
-
13
- from ..loftr import LoFTR, default_cfg
14
- from . import plt_utils
15
- from .plotting import make_matching_figure
16
- from .utils3d import rect_to_img, canonical_to_camera, calc_pose
17
-
18
-
19
- class ElevEstHelper:
20
- _feature_matcher = None
21
-
22
- @classmethod
23
- def get_feature_matcher(cls):
24
- if cls._feature_matcher is None:
25
- loguru.logger.info("Loading feature matcher...")
26
- _default_cfg = deepcopy(default_cfg)
27
- _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt
28
- matcher = LoFTR(config=_default_cfg)
29
- current_dir = os.path.dirname(os.path.abspath(__file__))
30
- ckpt_path = os.path.join(current_dir, "weights/indoor_ds_new.ckpt")
31
- if not osp.exists(ckpt_path):
32
- loguru.logger.info("Downloading feature matcher...")
33
- os.makedirs("weights", exist_ok=True)
34
- import gdown
35
- gdown.cached_download(url="https://drive.google.com/uc?id=19s3QvcCWQ6g-N1PrYlDCg-2mOJZ3kkgS",
36
- path=ckpt_path)
37
- matcher.load_state_dict(torch.load(ckpt_path)['state_dict'])
38
- matcher = matcher.eval().cuda()
39
- cls._feature_matcher = matcher
40
- return cls._feature_matcher
41
-
42
-
43
- def mask_out_bkgd(img_path, dbg=False):
44
- img = imageio.imread_v2(img_path)
45
- if img.shape[-1] == 4:
46
- fg_mask = img[:, :, :3]
47
- else:
48
- loguru.logger.info("Image has no alpha channel, using thresholding to mask out background")
49
- fg_mask = ~(img > 245).all(axis=-1)
50
- if dbg:
51
- plt.imshow(plt_utils.vis_mask(img, fg_mask.astype(np.uint8), color=[0, 255, 0]))
52
- plt.show()
53
- return fg_mask
54
-
55
-
56
- def get_feature_matching(img_paths, dbg=False):
57
- assert len(img_paths) == 4
58
- matcher = ElevEstHelper.get_feature_matcher()
59
- feature_matching = {}
60
- masks = []
61
- for i in range(4):
62
- mask = mask_out_bkgd(img_paths[i], dbg=dbg)
63
- masks.append(mask)
64
- for i in range(0, 4):
65
- for j in range(i + 1, 4):
66
- img0_pth = img_paths[i]
67
- img1_pth = img_paths[j]
68
- mask0 = masks[i]
69
- mask1 = masks[j]
70
- img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE)
71
- img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE)
72
- original_shape = img0_raw.shape
73
- img0_raw_resized = cv2.resize(img0_raw, (480, 480))
74
- img1_raw_resized = cv2.resize(img1_raw, (480, 480))
75
-
76
- img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255.
77
- img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255.
78
- batch = {'image0': img0, 'image1': img1}
79
-
80
- # Inference with LoFTR and get prediction
81
- with torch.no_grad():
82
- matcher(batch)
83
- mkpts0 = batch['mkpts0_f'].cpu().numpy()
84
- mkpts1 = batch['mkpts1_f'].cpu().numpy()
85
- mconf = batch['mconf'].cpu().numpy()
86
- mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480
87
- mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480
88
- mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480
89
- mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480
90
- keep0 = mask0[mkpts0[:, 1].astype(int), mkpts1[:, 0].astype(int)]
91
- keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)]
92
- keep = np.logical_and(keep0, keep1)
93
- mkpts0 = mkpts0[keep]
94
- mkpts1 = mkpts1[keep]
95
- mconf = mconf[keep]
96
- if dbg:
97
- # Draw visualization
98
- color = cm.jet(mconf)
99
- text = [
100
- 'LoFTR',
101
- 'Matches: {}'.format(len(mkpts0)),
102
- ]
103
- fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text)
104
- fig.show()
105
- feature_matching[f"{i}_{j}"] = np.concatenate([mkpts0, mkpts1, mconf[:, None]], axis=1)
106
-
107
- return feature_matching
108
-
109
-
110
- def gen_pose_hypothesis(center_elevation):
111
- elevations = np.radians(
112
- [center_elevation, center_elevation - 10, center_elevation + 10, center_elevation, center_elevation]) # 45~120
113
- azimuths = np.radians([30, 30, 30, 20, 40])
114
- input_poses = calc_pose(elevations, azimuths, len(azimuths))
115
- input_poses = input_poses[1:]
116
- input_poses[..., 1] *= -1
117
- input_poses[..., 2] *= -1
118
- return input_poses
119
-
120
-
121
- def ba_error_general(K, matches, poses):
122
- projmat0 = K @ poses[0].inverse()[:3, :4]
123
- projmat1 = K @ poses[1].inverse()[:3, :4]
124
- match_01 = matches[0]
125
- pts0 = match_01[:, :2]
126
- pts1 = match_01[:, 2:4]
127
- Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(),
128
- pts0.cpu().numpy().T, pts1.cpu().numpy().T)
129
- Xref = Xref[:3] / Xref[3:]
130
- Xref = Xref.T
131
- Xref = torch.from_numpy(Xref).cuda().float()
132
- reproj_error = 0
133
- for match, cp in zip(matches[1:], poses[2:]):
134
- dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1))
135
- if dist.numel() > 0:
136
- # print("dist.shape", dist.shape)
137
- m0to2_index = dist.argmin(1)
138
- keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1
139
- if keep.sum() > 0:
140
- xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse()))
141
- reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1)
142
- conf02 = match[m0to2_index][keep][:, -1]
143
- reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum())
144
-
145
- return reproj_error
146
-
147
-
148
- def find_optim_elev(elevs, nimgs, matches, K, dbg=False):
149
- errs = []
150
- for elev in elevs:
151
- err = 0
152
- cam_poses = gen_pose_hypothesis(elev)
153
- for start in range(nimgs - 1):
154
- batch_matches, batch_poses = [], []
155
- for i in range(start, nimgs + start):
156
- ci = i % nimgs
157
- batch_poses.append(cam_poses[ci])
158
- for j in range(nimgs - 1):
159
- key = f"{start}_{(start + j + 1) % nimgs}"
160
- match = matches[key]
161
- batch_matches.append(match)
162
- err += ba_error_general(K, batch_matches, batch_poses)
163
- errs.append(err)
164
- errs = torch.tensor(errs)
165
- if dbg:
166
- plt.plot(elevs, errs)
167
- plt.show()
168
- optim_elev = elevs[torch.argmin(errs)].item()
169
- return optim_elev
170
-
171
-
172
- def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None, dbg=False):
173
- flag = True
174
- matches = {}
175
- for i in range(4):
176
- for j in range(i + 1, 4):
177
- match_ij = feature_matching[f"{i}_{j}"]
178
- if len(match_ij) == 0:
179
- flag = False
180
- match_ji = np.concatenate([match_ij[:, 2:4], match_ij[:, 0:2], match_ij[:, 4:5]], axis=1)
181
- matches[f"{i}_{j}"] = torch.from_numpy(match_ij).float().cuda()
182
- matches[f"{j}_{i}"] = torch.from_numpy(match_ji).float().cuda()
183
- if not flag:
184
- loguru.logger.info("0 matches, could not estimate elevation")
185
- return None
186
- interval = 10
187
- elevs = np.arange(min_elev, max_elev, interval)
188
- optim_elev1 = find_optim_elev(elevs, 4, matches, K)
189
-
190
- elevs = np.arange(optim_elev1 - 10, optim_elev1 + 10, 1)
191
- optim_elev2 = find_optim_elev(elevs, 4, matches, K)
192
-
193
- return optim_elev2
194
-
195
-
196
- def elev_est_api(img_paths, min_elev=30, max_elev=150, K=None, dbg=False):
197
- feature_matching = get_feature_matching(img_paths, dbg=dbg)
198
- if K is None:
199
- loguru.logger.warning("K is not provided, using default K")
200
- K = np.array([[280.0, 0, 128.0],
201
- [0, 280.0, 128.0],
202
- [0, 0, 1]])
203
- K = torch.from_numpy(K).cuda().float()
204
- elev = get_elev_est(feature_matching, min_elev, max_elev, K, dbg=dbg)
205
- return elev
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/utils/plotting.py DELETED
@@ -1,154 +0,0 @@
1
- import bisect
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- import matplotlib
5
-
6
-
7
- def _compute_conf_thresh(data):
8
- dataset_name = data['dataset_name'][0].lower()
9
- if dataset_name == 'scannet':
10
- thr = 5e-4
11
- elif dataset_name == 'megadepth':
12
- thr = 1e-4
13
- else:
14
- raise ValueError(f'Unknown dataset: {dataset_name}')
15
- return thr
16
-
17
-
18
- # --- VISUALIZATION --- #
19
-
20
- def make_matching_figure(
21
- img0, img1, mkpts0, mkpts1, color,
22
- kpts0=None, kpts1=None, text=[], dpi=75, path=None):
23
- # draw image pair
24
- assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
25
- fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
26
- axes[0].imshow(img0, cmap='gray')
27
- axes[1].imshow(img1, cmap='gray')
28
- for i in range(2): # clear all frames
29
- axes[i].get_yaxis().set_ticks([])
30
- axes[i].get_xaxis().set_ticks([])
31
- for spine in axes[i].spines.values():
32
- spine.set_visible(False)
33
- plt.tight_layout(pad=1)
34
-
35
- if kpts0 is not None:
36
- assert kpts1 is not None
37
- axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
38
- axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
39
-
40
- # draw matches
41
- if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
42
- fig.canvas.draw()
43
- transFigure = fig.transFigure.inverted()
44
- fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
45
- fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
46
- fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
47
- (fkpts0[i, 1], fkpts1[i, 1]),
48
- transform=fig.transFigure, c=color[i], linewidth=1)
49
- for i in range(len(mkpts0))]
50
-
51
- axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
52
- axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
53
-
54
- # put txts
55
- txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
56
- fig.text(
57
- 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
58
- fontsize=15, va='top', ha='left', color=txt_color)
59
-
60
- # save or return figure
61
- if path:
62
- plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
63
- plt.close()
64
- else:
65
- return fig
66
-
67
-
68
- def _make_evaluation_figure(data, b_id, alpha='dynamic'):
69
- b_mask = data['m_bids'] == b_id
70
- conf_thr = _compute_conf_thresh(data)
71
-
72
- img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
73
- img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
74
- kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
75
- kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
76
-
77
- # for megadepth, we visualize matches on the resized image
78
- if 'scale0' in data:
79
- kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
80
- kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
81
-
82
- epi_errs = data['epi_errs'][b_mask].cpu().numpy()
83
- correct_mask = epi_errs < conf_thr
84
- precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
85
- n_correct = np.sum(correct_mask)
86
- n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
87
- recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
88
- # recall might be larger than 1, since the calculation of conf_matrix_gt
89
- # uses groundtruth depths and camera poses, but epipolar distance is used here.
90
-
91
- # matching info
92
- if alpha == 'dynamic':
93
- alpha = dynamic_alpha(len(correct_mask))
94
- color = error_colormap(epi_errs, conf_thr, alpha=alpha)
95
-
96
- text = [
97
- f'#Matches {len(kpts0)}',
98
- f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
99
- f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
100
- ]
101
-
102
- # make the figure
103
- figure = make_matching_figure(img0, img1, kpts0, kpts1,
104
- color, text=text)
105
- return figure
106
-
107
- def _make_confidence_figure(data, b_id):
108
- # TODO: Implement confidence figure
109
- raise NotImplementedError()
110
-
111
-
112
- def make_matching_figures(data, config, mode='evaluation'):
113
- """ Make matching figures for a batch.
114
-
115
- Args:
116
- data (Dict): a batch updated by PL_LoFTR.
117
- config (Dict): matcher config
118
- Returns:
119
- figures (Dict[str, List[plt.figure]]
120
- """
121
- assert mode in ['evaluation', 'confidence'] # 'confidence'
122
- figures = {mode: []}
123
- for b_id in range(data['image0'].size(0)):
124
- if mode == 'evaluation':
125
- fig = _make_evaluation_figure(
126
- data, b_id,
127
- alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
128
- elif mode == 'confidence':
129
- fig = _make_confidence_figure(data, b_id)
130
- else:
131
- raise ValueError(f'Unknown plot mode: {mode}')
132
- figures[mode].append(fig)
133
- return figures
134
-
135
-
136
- def dynamic_alpha(n_matches,
137
- milestones=[0, 300, 1000, 2000],
138
- alphas=[1.0, 0.8, 0.4, 0.2]):
139
- if n_matches == 0:
140
- return 1.0
141
- ranges = list(zip(alphas, alphas[1:] + [None]))
142
- loc = bisect.bisect_right(milestones, n_matches) - 1
143
- _range = ranges[loc]
144
- if _range[1] is None:
145
- return _range[0]
146
- return _range[1] + (milestones[loc + 1] - n_matches) / (
147
- milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
148
-
149
-
150
- def error_colormap(err, thr, alpha=1.0):
151
- assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
152
- x = 1 - np.clip(err / (thr * 2), 0, 1)
153
- return np.clip(
154
- np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/utils/plt_utils.py DELETED
@@ -1,318 +0,0 @@
1
- import os.path as osp
2
- import os
3
- import matplotlib.pyplot as plt
4
- import torch
5
- import cv2
6
- import math
7
-
8
- import numpy as np
9
- import tqdm
10
- from cv2 import findContours
11
- from dl_ext.primitive import safe_zip
12
- from dl_ext.timer import EvalTime
13
-
14
-
15
- def plot_confidence(confidence):
16
- n = len(confidence)
17
- plt.plot(np.arange(n), confidence)
18
- plt.show()
19
-
20
-
21
- def image_grid(
22
- images,
23
- rows=None,
24
- cols=None,
25
- fill: bool = True,
26
- show_axes: bool = False,
27
- rgb=None,
28
- show=True,
29
- label=None,
30
- **kwargs
31
- ):
32
- """
33
- A util function for plotting a grid of images.
34
- Args:
35
- images: (N, H, W, 4) array of RGBA images
36
- rows: number of rows in the grid
37
- cols: number of columns in the grid
38
- fill: boolean indicating if the space between images should be filled
39
- show_axes: boolean indicating if the axes of the plots should be visible
40
- rgb: boolean, If True, only RGB channels are plotted.
41
- If False, only the alpha channel is plotted.
42
- Returns:
43
- None
44
- """
45
- evaltime = EvalTime(disable=True)
46
- evaltime('')
47
- if isinstance(images, torch.Tensor):
48
- images = images.detach().cpu()
49
- if len(images[0].shape) == 2:
50
- rgb = False
51
- if images[0].shape[-1] == 2:
52
- # flow
53
- images = [flow_to_image(im) for im in images]
54
- if (rows is None) != (cols is None):
55
- raise ValueError("Specify either both rows and cols or neither.")
56
-
57
- if rows is None:
58
- rows = int(len(images) ** 0.5)
59
- cols = math.ceil(len(images) / rows)
60
-
61
- gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
62
- if len(images) < 50:
63
- figsize = (10, 10)
64
- else:
65
- figsize = (15, 15)
66
- evaltime('0.5')
67
- plt.figure(figsize=figsize)
68
- # fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=figsize)
69
- if label:
70
- # fig.suptitle(label, fontsize=30)
71
- plt.suptitle(label, fontsize=30)
72
- # bleed = 0
73
- # fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
74
- evaltime('subplots')
75
-
76
- # for i, (ax, im) in enumerate(tqdm.tqdm(zip(axarr.ravel(), images), leave=True, total=len(images))):
77
- for i in range(len(images)):
78
- # evaltime(f'{i} begin')
79
- plt.subplot(rows, cols, i + 1)
80
- if rgb:
81
- # only render RGB channels
82
- plt.imshow(images[i][..., :3], **kwargs)
83
- # ax.imshow(im[..., :3], **kwargs)
84
- else:
85
- # only render Alpha channel
86
- plt.imshow(images[i], **kwargs)
87
- # ax.imshow(im, **kwargs)
88
- if not show_axes:
89
- plt.axis('off')
90
- # ax.set_axis_off()
91
- # ax.set_title(f'{i}')
92
- plt.title(f'{i}')
93
- # evaltime(f'{i} end')
94
- evaltime('2')
95
- if show:
96
- plt.show()
97
- # return fig
98
-
99
-
100
- def depth_grid(
101
- depths,
102
- rows=None,
103
- cols=None,
104
- fill: bool = True,
105
- show_axes: bool = False,
106
- ):
107
- """
108
- A util function for plotting a grid of images.
109
- Args:
110
- images: (N, H, W, 4) array of RGBA images
111
- rows: number of rows in the grid
112
- cols: number of columns in the grid
113
- fill: boolean indicating if the space between images should be filled
114
- show_axes: boolean indicating if the axes of the plots should be visible
115
- rgb: boolean, If True, only RGB channels are plotted.
116
- If False, only the alpha channel is plotted.
117
- Returns:
118
- None
119
- """
120
- if (rows is None) != (cols is None):
121
- raise ValueError("Specify either both rows and cols or neither.")
122
-
123
- if rows is None:
124
- rows = len(depths)
125
- cols = 1
126
-
127
- gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
128
- fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9))
129
- bleed = 0
130
- fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
131
-
132
- for ax, im in zip(axarr.ravel(), depths):
133
- ax.imshow(im)
134
- if not show_axes:
135
- ax.set_axis_off()
136
- plt.show()
137
-
138
-
139
- def hover_masks_on_imgs(images, masks):
140
- masks = np.array(masks)
141
- new_imgs = []
142
- tids = list(range(1, masks.max() + 1))
143
- colors = colormap(rgb=True, lighten=True)
144
- for im, mask in tqdm.tqdm(safe_zip(images, masks), total=len(images)):
145
- for tid in tids:
146
- im = vis_mask(
147
- im,
148
- (mask == tid).astype(np.uint8),
149
- color=colors[tid],
150
- alpha=0.5,
151
- border_alpha=0.5,
152
- border_color=[255, 255, 255],
153
- border_thick=3)
154
- new_imgs.append(im)
155
- return new_imgs
156
-
157
-
158
- def vis_mask(img,
159
- mask,
160
- color=[255, 255, 255],
161
- alpha=0.4,
162
- show_border=True,
163
- border_alpha=0.5,
164
- border_thick=1,
165
- border_color=None):
166
- """Visualizes a single binary mask."""
167
- if isinstance(mask, torch.Tensor):
168
- from anypose.utils.pn_utils import to_array
169
- mask = to_array(mask > 0).astype(np.uint8)
170
- img = img.astype(np.float32)
171
- idx = np.nonzero(mask)
172
-
173
- img[idx[0], idx[1], :] *= 1.0 - alpha
174
- img[idx[0], idx[1], :] += [alpha * x for x in color]
175
-
176
- if show_border:
177
- contours, _ = findContours(
178
- mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
179
- # contours = [c for c in contours if c.shape[0] > 10]
180
- if border_color is None:
181
- border_color = color
182
- if not isinstance(border_color, list):
183
- border_color = border_color.tolist()
184
- if border_alpha < 1:
185
- with_border = img.copy()
186
- cv2.drawContours(with_border, contours, -1, border_color,
187
- border_thick, cv2.LINE_AA)
188
- img = (1 - border_alpha) * img + border_alpha * with_border
189
- else:
190
- cv2.drawContours(img, contours, -1, border_color, border_thick,
191
- cv2.LINE_AA)
192
-
193
- return img.astype(np.uint8)
194
-
195
-
196
- def colormap(rgb=False, lighten=True):
197
- """Copied from Detectron codebase."""
198
- color_list = np.array(
199
- [
200
- 0.000, 0.447, 0.741,
201
- 0.850, 0.325, 0.098,
202
- 0.929, 0.694, 0.125,
203
- 0.494, 0.184, 0.556,
204
- 0.466, 0.674, 0.188,
205
- 0.301, 0.745, 0.933,
206
- 0.635, 0.078, 0.184,
207
- 0.300, 0.300, 0.300,
208
- 0.600, 0.600, 0.600,
209
- 1.000, 0.000, 0.000,
210
- 1.000, 0.500, 0.000,
211
- 0.749, 0.749, 0.000,
212
- 0.000, 1.000, 0.000,
213
- 0.000, 0.000, 1.000,
214
- 0.667, 0.000, 1.000,
215
- 0.333, 0.333, 0.000,
216
- 0.333, 0.667, 0.000,
217
- 0.333, 1.000, 0.000,
218
- 0.667, 0.333, 0.000,
219
- 0.667, 0.667, 0.000,
220
- 0.667, 1.000, 0.000,
221
- 1.000, 0.333, 0.000,
222
- 1.000, 0.667, 0.000,
223
- 1.000, 1.000, 0.000,
224
- 0.000, 0.333, 0.500,
225
- 0.000, 0.667, 0.500,
226
- 0.000, 1.000, 0.500,
227
- 0.333, 0.000, 0.500,
228
- 0.333, 0.333, 0.500,
229
- 0.333, 0.667, 0.500,
230
- 0.333, 1.000, 0.500,
231
- 0.667, 0.000, 0.500,
232
- 0.667, 0.333, 0.500,
233
- 0.667, 0.667, 0.500,
234
- 0.667, 1.000, 0.500,
235
- 1.000, 0.000, 0.500,
236
- 1.000, 0.333, 0.500,
237
- 1.000, 0.667, 0.500,
238
- 1.000, 1.000, 0.500,
239
- 0.000, 0.333, 1.000,
240
- 0.000, 0.667, 1.000,
241
- 0.000, 1.000, 1.000,
242
- 0.333, 0.000, 1.000,
243
- 0.333, 0.333, 1.000,
244
- 0.333, 0.667, 1.000,
245
- 0.333, 1.000, 1.000,
246
- 0.667, 0.000, 1.000,
247
- 0.667, 0.333, 1.000,
248
- 0.667, 0.667, 1.000,
249
- 0.667, 1.000, 1.000,
250
- 1.000, 0.000, 1.000,
251
- 1.000, 0.333, 1.000,
252
- 1.000, 0.667, 1.000,
253
- 0.167, 0.000, 0.000,
254
- 0.333, 0.000, 0.000,
255
- 0.500, 0.000, 0.000,
256
- 0.667, 0.000, 0.000,
257
- 0.833, 0.000, 0.000,
258
- 1.000, 0.000, 0.000,
259
- 0.000, 0.167, 0.000,
260
- 0.000, 0.333, 0.000,
261
- 0.000, 0.500, 0.000,
262
- 0.000, 0.667, 0.000,
263
- 0.000, 0.833, 0.000,
264
- 0.000, 1.000, 0.000,
265
- 0.000, 0.000, 0.167,
266
- 0.000, 0.000, 0.333,
267
- 0.000, 0.000, 0.500,
268
- 0.000, 0.000, 0.667,
269
- 0.000, 0.000, 0.833,
270
- 0.000, 0.000, 1.000,
271
- 0.000, 0.000, 0.000,
272
- 0.143, 0.143, 0.143,
273
- 0.286, 0.286, 0.286,
274
- 0.429, 0.429, 0.429,
275
- 0.571, 0.571, 0.571,
276
- 0.714, 0.714, 0.714,
277
- 0.857, 0.857, 0.857,
278
- 1.000, 1.000, 1.000
279
- ]
280
- ).astype(np.float32)
281
- color_list = color_list.reshape((-1, 3))
282
- if not rgb:
283
- color_list = color_list[:, ::-1]
284
-
285
- if lighten:
286
- # Make all the colors a little lighter / whiter. This is copied
287
- # from the detectron visualization code (search for 'w_ratio').
288
- w_ratio = 0.4
289
- color_list = (color_list * (1 - w_ratio) + w_ratio)
290
- return color_list * 255
291
-
292
-
293
- def vis_layer_mask(masks, save_path=None):
294
- masks = torch.as_tensor(masks)
295
- tids = masks.unique().tolist()
296
- tids.remove(0)
297
- for tid in tqdm.tqdm(tids):
298
- show = save_path is None
299
- image_grid(masks == tid, label=f'{tid}', show=show)
300
- if save_path:
301
- os.makedirs(osp.dirname(save_path), exist_ok=True)
302
- plt.savefig(save_path % tid)
303
- plt.close('all')
304
-
305
-
306
- def show(x, **kwargs):
307
- if isinstance(x, torch.Tensor):
308
- x = x.detach().cpu()
309
- plt.imshow(x, **kwargs)
310
- plt.show()
311
-
312
-
313
- def vis_title(rgb, text, shift_y=30):
314
- tmp = rgb.copy()
315
- shift_x = rgb.shape[1] // 2
316
- cv2.putText(tmp, text,
317
- (shift_x, shift_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2, lineType=cv2.LINE_AA)
318
- return tmp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/utils/utils3d.py DELETED
@@ -1,62 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
-
5
- def cart_to_hom(pts):
6
- """
7
- :param pts: (N, 3 or 2)
8
- :return pts_hom: (N, 4 or 3)
9
- """
10
- if isinstance(pts, np.ndarray):
11
- pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1)
12
- else:
13
- ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device)
14
- pts_hom = torch.cat((pts, ones), dim=-1)
15
- return pts_hom
16
-
17
-
18
- def hom_to_cart(pts):
19
- return pts[..., :-1] / pts[..., -1:]
20
-
21
-
22
- def canonical_to_camera(pts, pose):
23
- pts = cart_to_hom(pts)
24
- pts = pts @ pose.transpose(-1, -2)
25
- pts = hom_to_cart(pts)
26
- return pts
27
-
28
-
29
- def rect_to_img(K, pts_rect):
30
- from dl_ext.vision_ext.datasets.kitti.structures import Calibration
31
- pts_2d_hom = pts_rect @ K.t()
32
- pts_img = Calibration.hom_to_cart(pts_2d_hom)
33
- return pts_img
34
-
35
-
36
- def calc_pose(phis, thetas, size, radius=1.2):
37
- import torch
38
- def normalize(vectors):
39
- return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10)
40
-
41
- device = torch.device('cuda')
42
- thetas = torch.FloatTensor(thetas).to(device)
43
- phis = torch.FloatTensor(phis).to(device)
44
-
45
- centers = torch.stack([
46
- radius * torch.sin(thetas) * torch.sin(phis),
47
- -radius * torch.cos(thetas) * torch.sin(phis),
48
- radius * torch.cos(phis),
49
- ], dim=-1) # [B, 3]
50
-
51
- # lookat
52
- forward_vector = normalize(centers).squeeze(0)
53
- up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1)
54
- right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1))
55
- if right_vector.pow(2).sum() < 0.01:
56
- right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1)
57
- up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1))
58
-
59
- poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1)
60
- poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
61
- poses[:, :3, 3] = centers
62
- return poses
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/elevation_estimate/utils/weights/.gitkeep DELETED
File without changes
One-2-3-45-master 2/example.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
One-2-3-45-master 2/ldm/data/__init__.py DELETED
File without changes
One-2-3-45-master 2/ldm/data/base.py DELETED
@@ -1,40 +0,0 @@
1
- import os
2
- import numpy as np
3
- from abc import abstractmethod
4
- from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
5
-
6
-
7
- class Txt2ImgIterableBaseDataset(IterableDataset):
8
- '''
9
- Define an interface to make the IterableDatasets for text2img data chainable
10
- '''
11
- def __init__(self, num_records=0, valid_ids=None, size=256):
12
- super().__init__()
13
- self.num_records = num_records
14
- self.valid_ids = valid_ids
15
- self.sample_ids = valid_ids
16
- self.size = size
17
-
18
- print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
19
-
20
- def __len__(self):
21
- return self.num_records
22
-
23
- @abstractmethod
24
- def __iter__(self):
25
- pass
26
-
27
-
28
- class PRNGMixin(object):
29
- """
30
- Adds a prng property which is a numpy RandomState which gets
31
- reinitialized whenever the pid changes to avoid synchronized sampling
32
- behavior when used in conjunction with multiprocessing.
33
- """
34
- @property
35
- def prng(self):
36
- currentpid = os.getpid()
37
- if getattr(self, "_initpid", None) != currentpid:
38
- self._initpid = currentpid
39
- self._prng = np.random.RandomState()
40
- return self._prng
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/coco.py DELETED
@@ -1,253 +0,0 @@
1
- import os
2
- import json
3
- import albumentations
4
- import numpy as np
5
- from PIL import Image
6
- from tqdm import tqdm
7
- from torch.utils.data import Dataset
8
- from abc import abstractmethod
9
-
10
-
11
- class CocoBase(Dataset):
12
- """needed for (image, caption, segmentation) pairs"""
13
- def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
14
- crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None):
15
- self.split = self.get_split()
16
- self.size = size
17
- if crop_size is None:
18
- self.crop_size = size
19
- else:
20
- self.crop_size = crop_size
21
-
22
- assert crop_type in [None, 'random', 'center']
23
- self.crop_type = crop_type
24
- self.use_segmenation = use_segmentation
25
- self.onehot = onehot_segmentation # return segmentation as rgb or one hot
26
- self.stuffthing = use_stuffthing # include thing in segmentation
27
- if self.onehot and not self.stuffthing:
28
- raise NotImplemented("One hot mode is only supported for the "
29
- "stuffthings version because labels are stored "
30
- "a bit different.")
31
-
32
- data_json = datajson
33
- with open(data_json) as json_file:
34
- self.json_data = json.load(json_file)
35
- self.img_id_to_captions = dict()
36
- self.img_id_to_filepath = dict()
37
- self.img_id_to_segmentation_filepath = dict()
38
-
39
- assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json",
40
- f"captions_val{self.year()}.json"]
41
- # TODO currently hardcoded paths, would be better to follow logic in
42
- # cocstuff pixelmaps
43
- if self.use_segmenation:
44
- if self.stuffthing:
45
- self.segmentation_prefix = (
46
- f"data/cocostuffthings/val{self.year()}" if
47
- data_json.endswith(f"captions_val{self.year()}.json") else
48
- f"data/cocostuffthings/train{self.year()}")
49
- else:
50
- self.segmentation_prefix = (
51
- f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if
52
- data_json.endswith(f"captions_val{self.year()}.json") else
53
- f"data/coco/annotations/stuff_train{self.year()}_pixelmaps")
54
-
55
- imagedirs = self.json_data["images"]
56
- self.labels = {"image_ids": list()}
57
- for imgdir in tqdm(imagedirs, desc="ImgToPath"):
58
- self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
59
- self.img_id_to_captions[imgdir["id"]] = list()
60
- pngfilename = imgdir["file_name"].replace("jpg", "png")
61
- if self.use_segmenation:
62
- self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
63
- self.segmentation_prefix, pngfilename)
64
- if given_files is not None:
65
- if pngfilename in given_files:
66
- self.labels["image_ids"].append(imgdir["id"])
67
- else:
68
- self.labels["image_ids"].append(imgdir["id"])
69
-
70
- capdirs = self.json_data["annotations"]
71
- for capdir in tqdm(capdirs, desc="ImgToCaptions"):
72
- # there are in average 5 captions per image
73
- #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
74
- self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"])
75
-
76
- self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
77
- if self.split=="validation":
78
- self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
79
- else:
80
- # default option for train is random crop
81
- if self.crop_type in [None, 'random']:
82
- self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
83
- else:
84
- self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
85
- self.preprocessor = albumentations.Compose(
86
- [self.rescaler, self.cropper],
87
- additional_targets={"segmentation": "image"})
88
- if force_no_crop:
89
- self.rescaler = albumentations.Resize(height=self.size, width=self.size)
90
- self.preprocessor = albumentations.Compose(
91
- [self.rescaler],
92
- additional_targets={"segmentation": "image"})
93
-
94
- @abstractmethod
95
- def year(self):
96
- raise NotImplementedError()
97
-
98
- def __len__(self):
99
- return len(self.labels["image_ids"])
100
-
101
- def preprocess_image(self, image_path, segmentation_path=None):
102
- image = Image.open(image_path)
103
- if not image.mode == "RGB":
104
- image = image.convert("RGB")
105
- image = np.array(image).astype(np.uint8)
106
- if segmentation_path:
107
- segmentation = Image.open(segmentation_path)
108
- if not self.onehot and not segmentation.mode == "RGB":
109
- segmentation = segmentation.convert("RGB")
110
- segmentation = np.array(segmentation).astype(np.uint8)
111
- if self.onehot:
112
- assert self.stuffthing
113
- # stored in caffe format: unlabeled==255. stuff and thing from
114
- # 0-181. to be compatible with the labels in
115
- # https://github.com/nightrome/cocostuff/blob/master/labels.txt
116
- # we shift stuffthing one to the right and put unlabeled in zero
117
- # as long as segmentation is uint8 shifting to right handles the
118
- # latter too
119
- assert segmentation.dtype == np.uint8
120
- segmentation = segmentation + 1
121
-
122
- processed = self.preprocessor(image=image, segmentation=segmentation)
123
-
124
- image, segmentation = processed["image"], processed["segmentation"]
125
- else:
126
- image = self.preprocessor(image=image,)['image']
127
-
128
- image = (image / 127.5 - 1.0).astype(np.float32)
129
- if segmentation_path:
130
- if self.onehot:
131
- assert segmentation.dtype == np.uint8
132
- # make it one hot
133
- n_labels = 183
134
- flatseg = np.ravel(segmentation)
135
- onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool)
136
- onehot[np.arange(flatseg.size), flatseg] = True
137
- onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
138
- segmentation = onehot
139
- else:
140
- segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
141
- return image, segmentation
142
- else:
143
- return image
144
-
145
- def __getitem__(self, i):
146
- img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
147
- if self.use_segmenation:
148
- seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
149
- image, segmentation = self.preprocess_image(img_path, seg_path)
150
- else:
151
- image = self.preprocess_image(img_path)
152
- captions = self.img_id_to_captions[self.labels["image_ids"][i]]
153
- # randomly draw one of all available captions per image
154
- caption = captions[np.random.randint(0, len(captions))]
155
- example = {"image": image,
156
- #"caption": [str(caption[0])],
157
- "caption": caption,
158
- "img_path": img_path,
159
- "filename_": img_path.split(os.sep)[-1]
160
- }
161
- if self.use_segmenation:
162
- example.update({"seg_path": seg_path, 'segmentation': segmentation})
163
- return example
164
-
165
-
166
- class CocoImagesAndCaptionsTrain2017(CocoBase):
167
- """returns a pair of (image, caption)"""
168
- def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,):
169
- super().__init__(size=size,
170
- dataroot="data/coco/train2017",
171
- datajson="data/coco/annotations/captions_train2017.json",
172
- onehot_segmentation=onehot_segmentation,
173
- use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
174
-
175
- def get_split(self):
176
- return "train"
177
-
178
- def year(self):
179
- return '2017'
180
-
181
-
182
- class CocoImagesAndCaptionsValidation2017(CocoBase):
183
- """returns a pair of (image, caption)"""
184
- def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
185
- given_files=None):
186
- super().__init__(size=size,
187
- dataroot="data/coco/val2017",
188
- datajson="data/coco/annotations/captions_val2017.json",
189
- onehot_segmentation=onehot_segmentation,
190
- use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
191
- given_files=given_files)
192
-
193
- def get_split(self):
194
- return "validation"
195
-
196
- def year(self):
197
- return '2017'
198
-
199
-
200
-
201
- class CocoImagesAndCaptionsTrain2014(CocoBase):
202
- """returns a pair of (image, caption)"""
203
- def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'):
204
- super().__init__(size=size,
205
- dataroot="data/coco/train2014",
206
- datajson="data/coco/annotations2014/annotations/captions_train2014.json",
207
- onehot_segmentation=onehot_segmentation,
208
- use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
209
- use_segmentation=False,
210
- crop_type=crop_type)
211
-
212
- def get_split(self):
213
- return "train"
214
-
215
- def year(self):
216
- return '2014'
217
-
218
- class CocoImagesAndCaptionsValidation2014(CocoBase):
219
- """returns a pair of (image, caption)"""
220
- def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
221
- given_files=None,crop_type='center',**kwargs):
222
- super().__init__(size=size,
223
- dataroot="data/coco/val2014",
224
- datajson="data/coco/annotations2014/annotations/captions_val2014.json",
225
- onehot_segmentation=onehot_segmentation,
226
- use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
227
- given_files=given_files,
228
- use_segmentation=False,
229
- crop_type=crop_type)
230
-
231
- def get_split(self):
232
- return "validation"
233
-
234
- def year(self):
235
- return '2014'
236
-
237
- if __name__ == '__main__':
238
- with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file:
239
- json_data = json.load(json_file)
240
- capdirs = json_data["annotations"]
241
- import pudb; pudb.set_trace()
242
- #d2 = CocoImagesAndCaptionsTrain2014(size=256)
243
- d2 = CocoImagesAndCaptionsValidation2014(size=256)
244
- print("constructed dataset.")
245
- print(f"length of {d2.__class__.__name__}: {len(d2)}")
246
-
247
- ex2 = d2[0]
248
- # ex3 = d3[0]
249
- # print(ex1["image"].shape)
250
- print(ex2["image"].shape)
251
- # print(ex3["image"].shape)
252
- # print(ex1["segmentation"].shape)
253
- print(ex2["caption"].__class__.__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/dummy.py DELETED
@@ -1,34 +0,0 @@
1
- import numpy as np
2
- import random
3
- import string
4
- from torch.utils.data import Dataset, Subset
5
-
6
- class DummyData(Dataset):
7
- def __init__(self, length, size):
8
- self.length = length
9
- self.size = size
10
-
11
- def __len__(self):
12
- return self.length
13
-
14
- def __getitem__(self, i):
15
- x = np.random.randn(*self.size)
16
- letters = string.ascii_lowercase
17
- y = ''.join(random.choice(string.ascii_lowercase) for i in range(10))
18
- return {"jpg": x, "txt": y}
19
-
20
-
21
- class DummyDataWithEmbeddings(Dataset):
22
- def __init__(self, length, size, emb_size):
23
- self.length = length
24
- self.size = size
25
- self.emb_size = emb_size
26
-
27
- def __len__(self):
28
- return self.length
29
-
30
- def __getitem__(self, i):
31
- x = np.random.randn(*self.size)
32
- y = np.random.randn(*self.emb_size).astype(np.float32)
33
- return {"jpg": x, "txt": y}
34
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/imagenet.py DELETED
@@ -1,394 +0,0 @@
1
- import os, yaml, pickle, shutil, tarfile, glob
2
- import cv2
3
- import albumentations
4
- import PIL
5
- import numpy as np
6
- import torchvision.transforms.functional as TF
7
- from omegaconf import OmegaConf
8
- from functools import partial
9
- from PIL import Image
10
- from tqdm import tqdm
11
- from torch.utils.data import Dataset, Subset
12
-
13
- import taming.data.utils as tdu
14
- from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
- from taming.data.imagenet import ImagePaths
16
-
17
- from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
-
19
-
20
- def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
- with open(path_to_yaml) as f:
22
- di2s = yaml.load(f)
23
- return dict((v,k) for k,v in di2s.items())
24
-
25
-
26
- class ImageNetBase(Dataset):
27
- def __init__(self, config=None):
28
- self.config = config or OmegaConf.create()
29
- if not type(self.config)==dict:
30
- self.config = OmegaConf.to_container(self.config)
31
- self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
- self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
- self._prepare()
34
- self._prepare_synset_to_human()
35
- self._prepare_idx_to_synset()
36
- self._prepare_human_to_integer_label()
37
- self._load()
38
-
39
- def __len__(self):
40
- return len(self.data)
41
-
42
- def __getitem__(self, i):
43
- return self.data[i]
44
-
45
- def _prepare(self):
46
- raise NotImplementedError()
47
-
48
- def _filter_relpaths(self, relpaths):
49
- ignore = set([
50
- "n06596364_9591.JPEG",
51
- ])
52
- relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
- if "sub_indices" in self.config:
54
- indices = str_to_indices(self.config["sub_indices"])
55
- synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
- self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
- files = []
58
- for rpath in relpaths:
59
- syn = rpath.split("/")[0]
60
- if syn in synsets:
61
- files.append(rpath)
62
- return files
63
- else:
64
- return relpaths
65
-
66
- def _prepare_synset_to_human(self):
67
- SIZE = 2655750
68
- URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
- self.human_dict = os.path.join(self.root, "synset_human.txt")
70
- if (not os.path.exists(self.human_dict) or
71
- not os.path.getsize(self.human_dict)==SIZE):
72
- download(URL, self.human_dict)
73
-
74
- def _prepare_idx_to_synset(self):
75
- URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
- self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
- if (not os.path.exists(self.idx2syn)):
78
- download(URL, self.idx2syn)
79
-
80
- def _prepare_human_to_integer_label(self):
81
- URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
- self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
- if (not os.path.exists(self.human2integer)):
84
- download(URL, self.human2integer)
85
- with open(self.human2integer, "r") as f:
86
- lines = f.read().splitlines()
87
- assert len(lines) == 1000
88
- self.human2integer_dict = dict()
89
- for line in lines:
90
- value, key = line.split(":")
91
- self.human2integer_dict[key] = int(value)
92
-
93
- def _load(self):
94
- with open(self.txt_filelist, "r") as f:
95
- self.relpaths = f.read().splitlines()
96
- l1 = len(self.relpaths)
97
- self.relpaths = self._filter_relpaths(self.relpaths)
98
- print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
-
100
- self.synsets = [p.split("/")[0] for p in self.relpaths]
101
- self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
-
103
- unique_synsets = np.unique(self.synsets)
104
- class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
- if not self.keep_orig_class_label:
106
- self.class_labels = [class_dict[s] for s in self.synsets]
107
- else:
108
- self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
-
110
- with open(self.human_dict, "r") as f:
111
- human_dict = f.read().splitlines()
112
- human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
-
114
- self.human_labels = [human_dict[s] for s in self.synsets]
115
-
116
- labels = {
117
- "relpath": np.array(self.relpaths),
118
- "synsets": np.array(self.synsets),
119
- "class_label": np.array(self.class_labels),
120
- "human_label": np.array(self.human_labels),
121
- }
122
-
123
- if self.process_images:
124
- self.size = retrieve(self.config, "size", default=256)
125
- self.data = ImagePaths(self.abspaths,
126
- labels=labels,
127
- size=self.size,
128
- random_crop=self.random_crop,
129
- )
130
- else:
131
- self.data = self.abspaths
132
-
133
-
134
- class ImageNetTrain(ImageNetBase):
135
- NAME = "ILSVRC2012_train"
136
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
- AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
- FILES = [
139
- "ILSVRC2012_img_train.tar",
140
- ]
141
- SIZES = [
142
- 147897477120,
143
- ]
144
-
145
- def __init__(self, process_images=True, data_root=None, **kwargs):
146
- self.process_images = process_images
147
- self.data_root = data_root
148
- super().__init__(**kwargs)
149
-
150
- def _prepare(self):
151
- if self.data_root:
152
- self.root = os.path.join(self.data_root, self.NAME)
153
- else:
154
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
-
157
- self.datadir = os.path.join(self.root, "data")
158
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
- self.expected_length = 1281167
160
- self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
- default=True)
162
- if not tdu.is_prepared(self.root):
163
- # prep
164
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
-
166
- datadir = self.datadir
167
- if not os.path.exists(datadir):
168
- path = os.path.join(self.root, self.FILES[0])
169
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
- import academictorrents as at
171
- atpath = at.get(self.AT_HASH, datastore=self.root)
172
- assert atpath == path
173
-
174
- print("Extracting {} to {}".format(path, datadir))
175
- os.makedirs(datadir, exist_ok=True)
176
- with tarfile.open(path, "r:") as tar:
177
- tar.extractall(path=datadir)
178
-
179
- print("Extracting sub-tars.")
180
- subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
- for subpath in tqdm(subpaths):
182
- subdir = subpath[:-len(".tar")]
183
- os.makedirs(subdir, exist_ok=True)
184
- with tarfile.open(subpath, "r:") as tar:
185
- tar.extractall(path=subdir)
186
-
187
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
- filelist = sorted(filelist)
190
- filelist = "\n".join(filelist)+"\n"
191
- with open(self.txt_filelist, "w") as f:
192
- f.write(filelist)
193
-
194
- tdu.mark_prepared(self.root)
195
-
196
-
197
- class ImageNetValidation(ImageNetBase):
198
- NAME = "ILSVRC2012_validation"
199
- URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
- AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
- VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
- FILES = [
203
- "ILSVRC2012_img_val.tar",
204
- "validation_synset.txt",
205
- ]
206
- SIZES = [
207
- 6744924160,
208
- 1950000,
209
- ]
210
-
211
- def __init__(self, process_images=True, data_root=None, **kwargs):
212
- self.data_root = data_root
213
- self.process_images = process_images
214
- super().__init__(**kwargs)
215
-
216
- def _prepare(self):
217
- if self.data_root:
218
- self.root = os.path.join(self.data_root, self.NAME)
219
- else:
220
- cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
- self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
- self.datadir = os.path.join(self.root, "data")
223
- self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
- self.expected_length = 50000
225
- self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
- default=False)
227
- if not tdu.is_prepared(self.root):
228
- # prep
229
- print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
-
231
- datadir = self.datadir
232
- if not os.path.exists(datadir):
233
- path = os.path.join(self.root, self.FILES[0])
234
- if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
- import academictorrents as at
236
- atpath = at.get(self.AT_HASH, datastore=self.root)
237
- assert atpath == path
238
-
239
- print("Extracting {} to {}".format(path, datadir))
240
- os.makedirs(datadir, exist_ok=True)
241
- with tarfile.open(path, "r:") as tar:
242
- tar.extractall(path=datadir)
243
-
244
- vspath = os.path.join(self.root, self.FILES[1])
245
- if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
- download(self.VS_URL, vspath)
247
-
248
- with open(vspath, "r") as f:
249
- synset_dict = f.read().splitlines()
250
- synset_dict = dict(line.split() for line in synset_dict)
251
-
252
- print("Reorganizing into synset folders")
253
- synsets = np.unique(list(synset_dict.values()))
254
- for s in synsets:
255
- os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
- for k, v in synset_dict.items():
257
- src = os.path.join(datadir, k)
258
- dst = os.path.join(datadir, v)
259
- shutil.move(src, dst)
260
-
261
- filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
- filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
- filelist = sorted(filelist)
264
- filelist = "\n".join(filelist)+"\n"
265
- with open(self.txt_filelist, "w") as f:
266
- f.write(filelist)
267
-
268
- tdu.mark_prepared(self.root)
269
-
270
-
271
-
272
- class ImageNetSR(Dataset):
273
- def __init__(self, size=None,
274
- degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
- random_crop=True):
276
- """
277
- Imagenet Superresolution Dataloader
278
- Performs following ops in order:
279
- 1. crops a crop of size s from image either as random or center crop
280
- 2. resizes crop to size with cv2.area_interpolation
281
- 3. degrades resized crop with degradation_fn
282
-
283
- :param size: resizing to size after cropping
284
- :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
- :param downscale_f: Low Resolution Downsample factor
286
- :param min_crop_f: determines crop size s,
287
- where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
- :param max_crop_f: ""
289
- :param data_root:
290
- :param random_crop:
291
- """
292
- self.base = self.get_base()
293
- assert size
294
- assert (size / downscale_f).is_integer()
295
- self.size = size
296
- self.LR_size = int(size / downscale_f)
297
- self.min_crop_f = min_crop_f
298
- self.max_crop_f = max_crop_f
299
- assert(max_crop_f <= 1.)
300
- self.center_crop = not random_crop
301
-
302
- self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
-
304
- self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
-
306
- if degradation == "bsrgan":
307
- self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
-
309
- elif degradation == "bsrgan_light":
310
- self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
-
312
- else:
313
- interpolation_fn = {
314
- "cv_nearest": cv2.INTER_NEAREST,
315
- "cv_bilinear": cv2.INTER_LINEAR,
316
- "cv_bicubic": cv2.INTER_CUBIC,
317
- "cv_area": cv2.INTER_AREA,
318
- "cv_lanczos": cv2.INTER_LANCZOS4,
319
- "pil_nearest": PIL.Image.NEAREST,
320
- "pil_bilinear": PIL.Image.BILINEAR,
321
- "pil_bicubic": PIL.Image.BICUBIC,
322
- "pil_box": PIL.Image.BOX,
323
- "pil_hamming": PIL.Image.HAMMING,
324
- "pil_lanczos": PIL.Image.LANCZOS,
325
- }[degradation]
326
-
327
- self.pil_interpolation = degradation.startswith("pil_")
328
-
329
- if self.pil_interpolation:
330
- self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
-
332
- else:
333
- self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
- interpolation=interpolation_fn)
335
-
336
- def __len__(self):
337
- return len(self.base)
338
-
339
- def __getitem__(self, i):
340
- example = self.base[i]
341
- image = Image.open(example["file_path_"])
342
-
343
- if not image.mode == "RGB":
344
- image = image.convert("RGB")
345
-
346
- image = np.array(image).astype(np.uint8)
347
-
348
- min_side_len = min(image.shape[:2])
349
- crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
- crop_side_len = int(crop_side_len)
351
-
352
- if self.center_crop:
353
- self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
-
355
- else:
356
- self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
-
358
- image = self.cropper(image=image)["image"]
359
- image = self.image_rescaler(image=image)["image"]
360
-
361
- if self.pil_interpolation:
362
- image_pil = PIL.Image.fromarray(image)
363
- LR_image = self.degradation_process(image_pil)
364
- LR_image = np.array(LR_image).astype(np.uint8)
365
-
366
- else:
367
- LR_image = self.degradation_process(image=image)["image"]
368
-
369
- example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
- example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
- example["caption"] = example["human_label"] # dummy caption
372
- return example
373
-
374
-
375
- class ImageNetSRTrain(ImageNetSR):
376
- def __init__(self, **kwargs):
377
- super().__init__(**kwargs)
378
-
379
- def get_base(self):
380
- with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
- indices = pickle.load(f)
382
- dset = ImageNetTrain(process_images=False,)
383
- return Subset(dset, indices)
384
-
385
-
386
- class ImageNetSRValidation(ImageNetSR):
387
- def __init__(self, **kwargs):
388
- super().__init__(**kwargs)
389
-
390
- def get_base(self):
391
- with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
- indices = pickle.load(f)
393
- dset = ImageNetValidation(process_images=False,)
394
- return Subset(dset, indices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/inpainting/__init__.py DELETED
File without changes
One-2-3-45-master 2/ldm/data/inpainting/synthetic_mask.py DELETED
@@ -1,166 +0,0 @@
1
- from PIL import Image, ImageDraw
2
- import numpy as np
3
-
4
- settings = {
5
- "256narrow": {
6
- "p_irr": 1,
7
- "min_n_irr": 4,
8
- "max_n_irr": 50,
9
- "max_l_irr": 40,
10
- "max_w_irr": 10,
11
- "min_n_box": None,
12
- "max_n_box": None,
13
- "min_s_box": None,
14
- "max_s_box": None,
15
- "marg": None,
16
- },
17
- "256train": {
18
- "p_irr": 0.5,
19
- "min_n_irr": 1,
20
- "max_n_irr": 5,
21
- "max_l_irr": 200,
22
- "max_w_irr": 100,
23
- "min_n_box": 1,
24
- "max_n_box": 4,
25
- "min_s_box": 30,
26
- "max_s_box": 150,
27
- "marg": 10,
28
- },
29
- "512train": { # TODO: experimental
30
- "p_irr": 0.5,
31
- "min_n_irr": 1,
32
- "max_n_irr": 5,
33
- "max_l_irr": 450,
34
- "max_w_irr": 250,
35
- "min_n_box": 1,
36
- "max_n_box": 4,
37
- "min_s_box": 30,
38
- "max_s_box": 300,
39
- "marg": 10,
40
- },
41
- "512train-large": { # TODO: experimental
42
- "p_irr": 0.5,
43
- "min_n_irr": 1,
44
- "max_n_irr": 5,
45
- "max_l_irr": 450,
46
- "max_w_irr": 400,
47
- "min_n_box": 1,
48
- "max_n_box": 4,
49
- "min_s_box": 75,
50
- "max_s_box": 450,
51
- "marg": 10,
52
- },
53
- }
54
-
55
-
56
- def gen_segment_mask(mask, start, end, brush_width):
57
- mask = mask > 0
58
- mask = (255 * mask).astype(np.uint8)
59
- mask = Image.fromarray(mask)
60
- draw = ImageDraw.Draw(mask)
61
- draw.line([start, end], fill=255, width=brush_width, joint="curve")
62
- mask = np.array(mask) / 255
63
- return mask
64
-
65
-
66
- def gen_box_mask(mask, masked):
67
- x_0, y_0, w, h = masked
68
- mask[y_0:y_0 + h, x_0:x_0 + w] = 1
69
- return mask
70
-
71
-
72
- def gen_round_mask(mask, masked, radius):
73
- x_0, y_0, w, h = masked
74
- xy = [(x_0, y_0), (x_0 + w, y_0 + w)]
75
-
76
- mask = mask > 0
77
- mask = (255 * mask).astype(np.uint8)
78
- mask = Image.fromarray(mask)
79
- draw = ImageDraw.Draw(mask)
80
- draw.rounded_rectangle(xy, radius=radius, fill=255)
81
- mask = np.array(mask) / 255
82
- return mask
83
-
84
-
85
- def gen_large_mask(prng, img_h, img_w,
86
- marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr,
87
- min_n_box, max_n_box, min_s_box, max_s_box):
88
- """
89
- img_h: int, an image height
90
- img_w: int, an image width
91
- marg: int, a margin for a box starting coordinate
92
- p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask
93
-
94
- min_n_irr: int, min number of segments
95
- max_n_irr: int, max number of segments
96
- max_l_irr: max length of a segment in polygonal chain
97
- max_w_irr: max width of a segment in polygonal chain
98
-
99
- min_n_box: int, min bound for the number of box primitives
100
- max_n_box: int, max bound for the number of box primitives
101
- min_s_box: int, min length of a box side
102
- max_s_box: int, max length of a box side
103
- """
104
-
105
- mask = np.zeros((img_h, img_w))
106
- uniform = prng.randint
107
-
108
- if np.random.uniform(0, 1) < p_irr: # generate polygonal chain
109
- n = uniform(min_n_irr, max_n_irr) # sample number of segments
110
-
111
- for _ in range(n):
112
- y = uniform(0, img_h) # sample a starting point
113
- x = uniform(0, img_w)
114
-
115
- a = uniform(0, 360) # sample angle
116
- l = uniform(10, max_l_irr) # sample segment length
117
- w = uniform(5, max_w_irr) # sample a segment width
118
-
119
- # draw segment starting from (x,y) to (x_,y_) using brush of width w
120
- x_ = x + l * np.sin(a)
121
- y_ = y + l * np.cos(a)
122
-
123
- mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w)
124
- x, y = x_, y_
125
- else: # generate Box masks
126
- n = uniform(min_n_box, max_n_box) # sample number of rectangles
127
-
128
- for _ in range(n):
129
- h = uniform(min_s_box, max_s_box) # sample box shape
130
- w = uniform(min_s_box, max_s_box)
131
-
132
- x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box
133
- y_0 = uniform(marg, img_h - marg - h)
134
-
135
- if np.random.uniform(0, 1) < 0.5:
136
- mask = gen_box_mask(mask, masked=(x_0, y_0, w, h))
137
- else:
138
- r = uniform(0, 60) # sample radius
139
- mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r)
140
- return mask
141
-
142
-
143
- make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"])
144
- make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"])
145
- make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"])
146
- make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"])
147
-
148
-
149
- MASK_MODES = {
150
- "256train": make_lama_mask,
151
- "256narrow": make_narrow_lama_mask,
152
- "512train": make_512_lama_mask,
153
- "512train-large": make_512_lama_mask_large
154
- }
155
-
156
- if __name__ == "__main__":
157
- import sys
158
-
159
- out = sys.argv[1]
160
-
161
- prng = np.random.RandomState(1)
162
- kwargs = settings["256train"]
163
- mask = gen_large_mask(prng, 256, 256, **kwargs)
164
- mask = (255 * mask).astype(np.uint8)
165
- mask = Image.fromarray(mask)
166
- mask.save(out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/laion.py DELETED
@@ -1,537 +0,0 @@
1
- import webdataset as wds
2
- import kornia
3
- from PIL import Image
4
- import io
5
- import os
6
- import torchvision
7
- from PIL import Image
8
- import glob
9
- import random
10
- import numpy as np
11
- import pytorch_lightning as pl
12
- from tqdm import tqdm
13
- from omegaconf import OmegaConf
14
- from einops import rearrange
15
- import torch
16
- from webdataset.handlers import warn_and_continue
17
-
18
-
19
- from ldm.util import instantiate_from_config
20
- from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES
21
- from ldm.data.base import PRNGMixin
22
-
23
-
24
- class DataWithWings(torch.utils.data.IterableDataset):
25
- def __init__(self, min_size, transform=None, target_transform=None):
26
- self.min_size = min_size
27
- self.transform = transform if transform is not None else nn.Identity()
28
- self.target_transform = target_transform if target_transform is not None else nn.Identity()
29
- self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee')
30
- self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e')
31
- self.pwatermark_threshold = 0.8
32
- self.punsafe_threshold = 0.5
33
- self.aesthetic_threshold = 5.
34
- self.total_samples = 0
35
- self.samples = 0
36
- location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -'
37
-
38
- self.inner_dataset = wds.DataPipeline(
39
- wds.ResampledShards(location),
40
- wds.tarfile_to_samples(handler=wds.warn_and_continue),
41
- wds.shuffle(1000, handler=wds.warn_and_continue),
42
- wds.decode('pilrgb', handler=wds.warn_and_continue),
43
- wds.map(self._add_tags, handler=wds.ignore_and_continue),
44
- wds.select(self._filter_predicate),
45
- wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue),
46
- wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue),
47
- )
48
-
49
- @staticmethod
50
- def _compute_hash(url, text):
51
- if url is None:
52
- url = ''
53
- if text is None:
54
- text = ''
55
- total = (url + text).encode('utf-8')
56
- return mmh3.hash64(total)[0]
57
-
58
- def _add_tags(self, x):
59
- hsh = self._compute_hash(x['json']['url'], x['txt'])
60
- pwatermark, punsafe = self.kv[hsh]
61
- aesthetic = self.kv_aesthetic[hsh][0]
62
- return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic}
63
-
64
- def _punsafe_to_class(self, punsafe):
65
- return torch.tensor(punsafe >= self.punsafe_threshold).long()
66
-
67
- def _filter_predicate(self, x):
68
- try:
69
- return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
70
- except:
71
- return False
72
-
73
- def __iter__(self):
74
- return iter(self.inner_dataset)
75
-
76
-
77
- def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
78
- """Take a list of samples (as dictionary) and create a batch, preserving the keys.
79
- If `tensors` is True, `ndarray` objects are combined into
80
- tensor batches.
81
- :param dict samples: list of samples
82
- :param bool tensors: whether to turn lists of ndarrays into a single ndarray
83
- :returns: single sample consisting of a batch
84
- :rtype: dict
85
- """
86
- keys = set.intersection(*[set(sample.keys()) for sample in samples])
87
- batched = {key: [] for key in keys}
88
-
89
- for s in samples:
90
- [batched[key].append(s[key]) for key in batched]
91
-
92
- result = {}
93
- for key in batched:
94
- if isinstance(batched[key][0], (int, float)):
95
- if combine_scalars:
96
- result[key] = np.array(list(batched[key]))
97
- elif isinstance(batched[key][0], torch.Tensor):
98
- if combine_tensors:
99
- result[key] = torch.stack(list(batched[key]))
100
- elif isinstance(batched[key][0], np.ndarray):
101
- if combine_tensors:
102
- result[key] = np.array(list(batched[key]))
103
- else:
104
- result[key] = list(batched[key])
105
- return result
106
-
107
-
108
- class WebDataModuleFromConfig(pl.LightningDataModule):
109
- def __init__(self, tar_base, batch_size, train=None, validation=None,
110
- test=None, num_workers=4, multinode=True, min_size=None,
111
- max_pwatermark=1.0,
112
- **kwargs):
113
- super().__init__(self)
114
- print(f'Setting tar base to {tar_base}')
115
- self.tar_base = tar_base
116
- self.batch_size = batch_size
117
- self.num_workers = num_workers
118
- self.train = train
119
- self.validation = validation
120
- self.test = test
121
- self.multinode = multinode
122
- self.min_size = min_size # filter out very small images
123
- self.max_pwatermark = max_pwatermark # filter out watermarked images
124
-
125
- def make_loader(self, dataset_config, train=True):
126
- if 'image_transforms' in dataset_config:
127
- image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
128
- else:
129
- image_transforms = []
130
-
131
- image_transforms.extend([torchvision.transforms.ToTensor(),
132
- torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
133
- image_transforms = torchvision.transforms.Compose(image_transforms)
134
-
135
- if 'transforms' in dataset_config:
136
- transforms_config = OmegaConf.to_container(dataset_config.transforms)
137
- else:
138
- transforms_config = dict()
139
-
140
- transform_dict = {dkey: load_partial_from_config(transforms_config[dkey])
141
- if transforms_config[dkey] != 'identity' else identity
142
- for dkey in transforms_config}
143
- img_key = dataset_config.get('image_key', 'jpeg')
144
- transform_dict.update({img_key: image_transforms})
145
-
146
- if 'postprocess' in dataset_config:
147
- postprocess = instantiate_from_config(dataset_config['postprocess'])
148
- else:
149
- postprocess = None
150
-
151
- shuffle = dataset_config.get('shuffle', 0)
152
- shardshuffle = shuffle > 0
153
-
154
- nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
155
-
156
- if self.tar_base == "__improvedaesthetic__":
157
- print("## Warning, loading the same improved aesthetic dataset "
158
- "for all splits and ignoring shards parameter.")
159
- tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
160
- else:
161
- tars = os.path.join(self.tar_base, dataset_config.shards)
162
-
163
- dset = wds.WebDataset(
164
- tars,
165
- nodesplitter=nodesplitter,
166
- shardshuffle=shardshuffle,
167
- handler=wds.warn_and_continue).repeat().shuffle(shuffle)
168
- print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
169
-
170
- dset = (dset
171
- .select(self.filter_keys)
172
- .decode('pil', handler=wds.warn_and_continue)
173
- .select(self.filter_size)
174
- .map_dict(**transform_dict, handler=wds.warn_and_continue)
175
- )
176
- if postprocess is not None:
177
- dset = dset.map(postprocess)
178
- dset = (dset
179
- .batched(self.batch_size, partial=False,
180
- collation_fn=dict_collation_fn)
181
- )
182
-
183
- loader = wds.WebLoader(dset, batch_size=None, shuffle=False,
184
- num_workers=self.num_workers)
185
-
186
- return loader
187
-
188
- def filter_size(self, x):
189
- try:
190
- valid = True
191
- if self.min_size is not None and self.min_size > 1:
192
- try:
193
- valid = valid and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size
194
- except Exception:
195
- valid = False
196
- if self.max_pwatermark is not None and self.max_pwatermark < 1.0:
197
- try:
198
- valid = valid and x['json']['pwatermark'] <= self.max_pwatermark
199
- except Exception:
200
- valid = False
201
- return valid
202
- except Exception:
203
- return False
204
-
205
- def filter_keys(self, x):
206
- try:
207
- return ("jpg" in x) and ("txt" in x)
208
- except Exception:
209
- return False
210
-
211
- def train_dataloader(self):
212
- return self.make_loader(self.train)
213
-
214
- def val_dataloader(self):
215
- return self.make_loader(self.validation, train=False)
216
-
217
- def test_dataloader(self):
218
- return self.make_loader(self.test, train=False)
219
-
220
-
221
- from ldm.modules.image_degradation import degradation_fn_bsr_light
222
- import cv2
223
-
224
- class AddLR(object):
225
- def __init__(self, factor, output_size, initial_size=None, image_key="jpg"):
226
- self.factor = factor
227
- self.output_size = output_size
228
- self.image_key = image_key
229
- self.initial_size = initial_size
230
-
231
- def pt2np(self, x):
232
- x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
233
- return x
234
-
235
- def np2pt(self, x):
236
- x = torch.from_numpy(x)/127.5-1.0
237
- return x
238
-
239
- def __call__(self, sample):
240
- # sample['jpg'] is tensor hwc in [-1, 1] at this point
241
- x = self.pt2np(sample[self.image_key])
242
- if self.initial_size is not None:
243
- x = cv2.resize(x, (self.initial_size, self.initial_size), interpolation=2)
244
- x = degradation_fn_bsr_light(x, sf=self.factor)['image']
245
- x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2)
246
- x = self.np2pt(x)
247
- sample['lr'] = x
248
- return sample
249
-
250
- class AddBW(object):
251
- def __init__(self, image_key="jpg"):
252
- self.image_key = image_key
253
-
254
- def pt2np(self, x):
255
- x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy()
256
- return x
257
-
258
- def np2pt(self, x):
259
- x = torch.from_numpy(x)/127.5-1.0
260
- return x
261
-
262
- def __call__(self, sample):
263
- # sample['jpg'] is tensor hwc in [-1, 1] at this point
264
- x = sample[self.image_key]
265
- w = torch.rand(3, device=x.device)
266
- w /= w.sum()
267
- out = torch.einsum('hwc,c->hw', x, w)
268
-
269
- # Keep as 3ch so we can pass to encoder, also we might want to add hints
270
- sample['lr'] = out.unsqueeze(-1).tile(1,1,3)
271
- return sample
272
-
273
- class AddMask(PRNGMixin):
274
- def __init__(self, mode="512train", p_drop=0.):
275
- super().__init__()
276
- assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
277
- self.make_mask = MASK_MODES[mode]
278
- self.p_drop = p_drop
279
-
280
- def __call__(self, sample):
281
- # sample['jpg'] is tensor hwc in [-1, 1] at this point
282
- x = sample['jpg']
283
- mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
284
- if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]):
285
- mask = np.ones_like(mask)
286
- mask[mask < 0.5] = 0
287
- mask[mask > 0.5] = 1
288
- mask = torch.from_numpy(mask[..., None])
289
- sample['mask'] = mask
290
- sample['masked_image'] = x * (mask < 0.5)
291
- return sample
292
-
293
-
294
- class AddEdge(PRNGMixin):
295
- def __init__(self, mode="512train", mask_edges=True):
296
- super().__init__()
297
- assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"'
298
- self.make_mask = MASK_MODES[mode]
299
- self.n_down_choices = [0]
300
- self.sigma_choices = [1, 2]
301
- self.mask_edges = mask_edges
302
-
303
- @torch.no_grad()
304
- def __call__(self, sample):
305
- # sample['jpg'] is tensor hwc in [-1, 1] at this point
306
- x = sample['jpg']
307
-
308
- mask = self.make_mask(self.prng, x.shape[0], x.shape[1])
309
- mask[mask < 0.5] = 0
310
- mask[mask > 0.5] = 1
311
- mask = torch.from_numpy(mask[..., None])
312
- sample['mask'] = mask
313
-
314
- n_down_idx = self.prng.choice(len(self.n_down_choices))
315
- sigma_idx = self.prng.choice(len(self.sigma_choices))
316
-
317
- n_choices = len(self.n_down_choices)*len(self.sigma_choices)
318
- raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx),
319
- (len(self.n_down_choices), len(self.sigma_choices)))
320
- normalized_idx = raveled_idx/max(1, n_choices-1)
321
-
322
- n_down = self.n_down_choices[n_down_idx]
323
- sigma = self.sigma_choices[sigma_idx]
324
-
325
- kernel_size = 4*sigma+1
326
- kernel_size = (kernel_size, kernel_size)
327
- sigma = (sigma, sigma)
328
- canny = kornia.filters.Canny(
329
- low_threshold=0.1,
330
- high_threshold=0.2,
331
- kernel_size=kernel_size,
332
- sigma=sigma,
333
- hysteresis=True,
334
- )
335
- y = (x+1.0)/2.0 # in 01
336
- y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous()
337
-
338
- # down
339
- for i_down in range(n_down):
340
- size = min(y.shape[-2], y.shape[-1])//2
341
- y = kornia.geometry.transform.resize(y, size, antialias=True)
342
-
343
- # edge
344
- _, y = canny(y)
345
-
346
- if n_down > 0:
347
- size = x.shape[0], x.shape[1]
348
- y = kornia.geometry.transform.resize(y, size, interpolation="nearest")
349
-
350
- y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous()
351
- y = y*2.0-1.0
352
-
353
- if self.mask_edges:
354
- sample['masked_image'] = y * (mask < 0.5)
355
- else:
356
- sample['masked_image'] = y
357
- sample['mask'] = torch.zeros_like(sample['mask'])
358
-
359
- # concat normalized idx
360
- sample['smoothing_strength'] = torch.ones_like(sample['mask'])*normalized_idx
361
-
362
- return sample
363
-
364
-
365
- def example00():
366
- url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -"
367
- dataset = wds.WebDataset(url)
368
- example = next(iter(dataset))
369
- for k in example:
370
- print(k, type(example[k]))
371
-
372
- print(example["__key__"])
373
- for k in ["json", "txt"]:
374
- print(example[k].decode())
375
-
376
- image = Image.open(io.BytesIO(example["jpg"]))
377
- outdir = "tmp"
378
- os.makedirs(outdir, exist_ok=True)
379
- image.save(os.path.join(outdir, example["__key__"] + ".png"))
380
-
381
-
382
- def load_example(example):
383
- return {
384
- "key": example["__key__"],
385
- "image": Image.open(io.BytesIO(example["jpg"])),
386
- "text": example["txt"].decode(),
387
- }
388
-
389
-
390
- for i, example in tqdm(enumerate(dataset)):
391
- ex = load_example(example)
392
- print(ex["image"].size, ex["text"])
393
- if i >= 100:
394
- break
395
-
396
-
397
- def example01():
398
- # the first laion shards contain ~10k examples each
399
- url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -"
400
-
401
- batch_size = 3
402
- shuffle_buffer = 10000
403
- dset = wds.WebDataset(
404
- url,
405
- nodesplitter=wds.shardlists.split_by_node,
406
- shardshuffle=True,
407
- )
408
- dset = (dset
409
- .shuffle(shuffle_buffer, initial=shuffle_buffer)
410
- .decode('pil', handler=warn_and_continue)
411
- .batched(batch_size, partial=False,
412
- collation_fn=dict_collation_fn)
413
- )
414
-
415
- num_workers = 2
416
- loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers)
417
-
418
- batch_sizes = list()
419
- keys_per_epoch = list()
420
- for epoch in range(5):
421
- keys = list()
422
- for batch in tqdm(loader):
423
- batch_sizes.append(len(batch["__key__"]))
424
- keys.append(batch["__key__"])
425
-
426
- for bs in batch_sizes:
427
- assert bs==batch_size
428
- print(f"{len(batch_sizes)} batches of size {batch_size}.")
429
- batch_sizes = list()
430
-
431
- keys_per_epoch.append(keys)
432
- for i_batch in [0, 1, -1]:
433
- print(f"Batch {i_batch} of epoch {epoch}:")
434
- print(keys[i_batch])
435
- print("next epoch.")
436
-
437
-
438
- def example02():
439
- from omegaconf import OmegaConf
440
- from torch.utils.data.distributed import DistributedSampler
441
- from torch.utils.data import IterableDataset
442
- from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
443
- from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
444
-
445
- #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml")
446
- #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml")
447
- config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml")
448
- datamod = WebDataModuleFromConfig(**config["data"]["params"])
449
- dataloader = datamod.train_dataloader()
450
-
451
- for batch in dataloader:
452
- print(batch.keys())
453
- print(batch["jpg"].shape)
454
- break
455
-
456
-
457
- def example03():
458
- # improved aesthetics
459
- tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -"
460
- dataset = wds.WebDataset(tars)
461
-
462
- def filter_keys(x):
463
- try:
464
- return ("jpg" in x) and ("txt" in x)
465
- except Exception:
466
- return False
467
-
468
- def filter_size(x):
469
- try:
470
- return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
471
- except Exception:
472
- return False
473
-
474
- def filter_watermark(x):
475
- try:
476
- return x['json']['pwatermark'] < 0.5
477
- except Exception:
478
- return False
479
-
480
- dataset = (dataset
481
- .select(filter_keys)
482
- .decode('pil', handler=wds.warn_and_continue))
483
- n_save = 20
484
- n_total = 0
485
- n_large = 0
486
- n_large_nowm = 0
487
- for i, example in enumerate(dataset):
488
- n_total += 1
489
- if filter_size(example):
490
- n_large += 1
491
- if filter_watermark(example):
492
- n_large_nowm += 1
493
- if n_large_nowm < n_save+1:
494
- image = example["jpg"]
495
- image.save(os.path.join("tmp", f"{n_large_nowm-1:06}.png"))
496
-
497
- if i%500 == 0:
498
- print(i)
499
- print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%")
500
- if n_large > 0:
501
- print(f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm/n_large*100:.2f}%")
502
-
503
-
504
-
505
- def example04():
506
- # improved aesthetics
507
- for i_shard in range(60208)[::-1]:
508
- print(i_shard)
509
- tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard)
510
- dataset = wds.WebDataset(tars)
511
-
512
- def filter_keys(x):
513
- try:
514
- return ("jpg" in x) and ("txt" in x)
515
- except Exception:
516
- return False
517
-
518
- def filter_size(x):
519
- try:
520
- return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512
521
- except Exception:
522
- return False
523
-
524
- dataset = (dataset
525
- .select(filter_keys)
526
- .decode('pil', handler=wds.warn_and_continue))
527
- try:
528
- example = next(iter(dataset))
529
- except Exception:
530
- print(f"Error @ {i_shard}")
531
-
532
-
533
- if __name__ == "__main__":
534
- #example01()
535
- #example02()
536
- example03()
537
- #example04()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/lsun.py DELETED
@@ -1,92 +0,0 @@
1
- import os
2
- import numpy as np
3
- import PIL
4
- from PIL import Image
5
- from torch.utils.data import Dataset
6
- from torchvision import transforms
7
-
8
-
9
- class LSUNBase(Dataset):
10
- def __init__(self,
11
- txt_file,
12
- data_root,
13
- size=None,
14
- interpolation="bicubic",
15
- flip_p=0.5
16
- ):
17
- self.data_paths = txt_file
18
- self.data_root = data_root
19
- with open(self.data_paths, "r") as f:
20
- self.image_paths = f.read().splitlines()
21
- self._length = len(self.image_paths)
22
- self.labels = {
23
- "relative_file_path_": [l for l in self.image_paths],
24
- "file_path_": [os.path.join(self.data_root, l)
25
- for l in self.image_paths],
26
- }
27
-
28
- self.size = size
29
- self.interpolation = {"linear": PIL.Image.LINEAR,
30
- "bilinear": PIL.Image.BILINEAR,
31
- "bicubic": PIL.Image.BICUBIC,
32
- "lanczos": PIL.Image.LANCZOS,
33
- }[interpolation]
34
- self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
-
36
- def __len__(self):
37
- return self._length
38
-
39
- def __getitem__(self, i):
40
- example = dict((k, self.labels[k][i]) for k in self.labels)
41
- image = Image.open(example["file_path_"])
42
- if not image.mode == "RGB":
43
- image = image.convert("RGB")
44
-
45
- # default to score-sde preprocessing
46
- img = np.array(image).astype(np.uint8)
47
- crop = min(img.shape[0], img.shape[1])
48
- h, w, = img.shape[0], img.shape[1]
49
- img = img[(h - crop) // 2:(h + crop) // 2,
50
- (w - crop) // 2:(w + crop) // 2]
51
-
52
- image = Image.fromarray(img)
53
- if self.size is not None:
54
- image = image.resize((self.size, self.size), resample=self.interpolation)
55
-
56
- image = self.flip(image)
57
- image = np.array(image).astype(np.uint8)
58
- example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
- return example
60
-
61
-
62
- class LSUNChurchesTrain(LSUNBase):
63
- def __init__(self, **kwargs):
64
- super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
-
66
-
67
- class LSUNChurchesValidation(LSUNBase):
68
- def __init__(self, flip_p=0., **kwargs):
69
- super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
- flip_p=flip_p, **kwargs)
71
-
72
-
73
- class LSUNBedroomsTrain(LSUNBase):
74
- def __init__(self, **kwargs):
75
- super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
-
77
-
78
- class LSUNBedroomsValidation(LSUNBase):
79
- def __init__(self, flip_p=0.0, **kwargs):
80
- super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
- flip_p=flip_p, **kwargs)
82
-
83
-
84
- class LSUNCatsTrain(LSUNBase):
85
- def __init__(self, **kwargs):
86
- super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
-
88
-
89
- class LSUNCatsValidation(LSUNBase):
90
- def __init__(self, flip_p=0., **kwargs):
91
- super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
- flip_p=flip_p, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/nerf_like.py DELETED
@@ -1,165 +0,0 @@
1
- from torch.utils.data import Dataset
2
- import os
3
- import json
4
- import numpy as np
5
- import torch
6
- import imageio
7
- import math
8
- import cv2
9
- from torchvision import transforms
10
-
11
- def cartesian_to_spherical(xyz):
12
- ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
13
- xy = xyz[:,0]**2 + xyz[:,1]**2
14
- z = np.sqrt(xy + xyz[:,2]**2)
15
- theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
16
- #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
17
- azimuth = np.arctan2(xyz[:,1], xyz[:,0])
18
- return np.array([theta, azimuth, z])
19
-
20
-
21
- def get_T(T_target, T_cond):
22
- theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
23
- theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
24
-
25
- d_theta = theta_target - theta_cond
26
- d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
27
- d_z = z_target - z_cond
28
-
29
- d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
30
- return d_T
31
-
32
- def get_spherical(T_target, T_cond):
33
- theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :])
34
- theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :])
35
-
36
- d_theta = theta_target - theta_cond
37
- d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
38
- d_z = z_target - z_cond
39
-
40
- d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()])
41
- return d_T
42
-
43
- class RTMV(Dataset):
44
- def __init__(self, root_dir='datasets/RTMV/google_scanned',\
45
- first_K=64, resolution=256, load_target=False):
46
- self.root_dir = root_dir
47
- self.scene_list = sorted(next(os.walk(root_dir))[1])
48
- self.resolution = resolution
49
- self.first_K = first_K
50
- self.load_target = load_target
51
-
52
- def __len__(self):
53
- return len(self.scene_list)
54
-
55
- def __getitem__(self, idx):
56
- scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
57
- with open(os.path.join(scene_dir, 'transforms.json'), "r") as f:
58
- meta = json.load(f)
59
- imgs = []
60
- poses = []
61
- for i_img in range(self.first_K):
62
- meta_img = meta['frames'][i_img]
63
-
64
- if i_img == 0 or self.load_target:
65
- img_path = os.path.join(scene_dir, meta_img['file_path'])
66
- img = imageio.imread(img_path)
67
- img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
68
- imgs.append(img)
69
-
70
- c2w = meta_img['transform_matrix']
71
- poses.append(c2w)
72
-
73
- imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
74
- imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
75
- imgs = imgs * 2 - 1. # convert to stable diffusion range
76
- poses = torch.tensor(np.array(poses).astype(np.float32))
77
- return imgs, poses
78
-
79
- def blend_rgba(self, img):
80
- img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
81
- return img
82
-
83
-
84
- class GSO(Dataset):
85
- def __init__(self, root_dir='datasets/GoogleScannedObjects',\
86
- split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'):
87
- self.root_dir = root_dir
88
- with open(os.path.join(root_dir, '%s.json' % split), "r") as f:
89
- self.scene_list = json.load(f)
90
- self.resolution = resolution
91
- self.first_K = first_K
92
- self.load_target = load_target
93
- self.name = name
94
-
95
- def __len__(self):
96
- return len(self.scene_list)
97
-
98
- def __getitem__(self, idx):
99
- scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
100
- with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f:
101
- meta = json.load(f)
102
- imgs = []
103
- poses = []
104
- for i_img in range(self.first_K):
105
- meta_img = meta['frames'][i_img]
106
-
107
- if i_img == 0 or self.load_target:
108
- img_path = os.path.join(scene_dir, meta_img['file_path'])
109
- img = imageio.imread(img_path)
110
- img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
111
- imgs.append(img)
112
-
113
- c2w = meta_img['transform_matrix']
114
- poses.append(c2w)
115
-
116
- imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
117
- mask = imgs[:, :, :, -1]
118
- imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
119
- imgs = imgs * 2 - 1. # convert to stable diffusion range
120
- poses = torch.tensor(np.array(poses).astype(np.float32))
121
- return imgs, poses
122
-
123
- def blend_rgba(self, img):
124
- img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
125
- return img
126
-
127
- class WILD(Dataset):
128
- def __init__(self, root_dir='data/nerf_wild',\
129
- first_K=33, resolution=256, load_target=False):
130
- self.root_dir = root_dir
131
- self.scene_list = sorted(next(os.walk(root_dir))[1])
132
- self.resolution = resolution
133
- self.first_K = first_K
134
- self.load_target = load_target
135
-
136
- def __len__(self):
137
- return len(self.scene_list)
138
-
139
- def __getitem__(self, idx):
140
- scene_dir = os.path.join(self.root_dir, self.scene_list[idx])
141
- with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f:
142
- meta = json.load(f)
143
- imgs = []
144
- poses = []
145
- for i_img in range(self.first_K):
146
- meta_img = meta['frames'][i_img]
147
-
148
- if i_img == 0 or self.load_target:
149
- img_path = os.path.join(scene_dir, meta_img['file_path'])
150
- img = imageio.imread(img_path + '.png')
151
- img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR)
152
- imgs.append(img)
153
-
154
- c2w = meta_img['transform_matrix']
155
- poses.append(c2w)
156
-
157
- imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs
158
- imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2)
159
- imgs = imgs * 2 - 1. # convert to stable diffusion range
160
- poses = torch.tensor(np.array(poses).astype(np.float32))
161
- return imgs, poses
162
-
163
- def blend_rgba(self, img):
164
- img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB
165
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/data/simple.py DELETED
@@ -1,526 +0,0 @@
1
- from typing import Dict
2
- import webdataset as wds
3
- import numpy as np
4
- from omegaconf import DictConfig, ListConfig
5
- import torch
6
- from torch.utils.data import Dataset
7
- from pathlib import Path
8
- import json
9
- from PIL import Image
10
- from torchvision import transforms
11
- import torchvision
12
- from einops import rearrange
13
- from ldm.util import instantiate_from_config
14
- from datasets import load_dataset
15
- import pytorch_lightning as pl
16
- import copy
17
- import csv
18
- import cv2
19
- import random
20
- import matplotlib.pyplot as plt
21
- from torch.utils.data import DataLoader
22
- import json
23
- import os, sys
24
- import webdataset as wds
25
- import math
26
- from torch.utils.data.distributed import DistributedSampler
27
-
28
- # Some hacky things to make experimentation easier
29
- def make_transform_multi_folder_data(paths, caption_files=None, **kwargs):
30
- ds = make_multi_folder_data(paths, caption_files, **kwargs)
31
- return TransformDataset(ds)
32
-
33
- def make_nfp_data(base_path):
34
- dirs = list(Path(base_path).glob("*/"))
35
- print(f"Found {len(dirs)} folders")
36
- print(dirs)
37
- tforms = [transforms.Resize(512), transforms.CenterCrop(512)]
38
- datasets = [NfpDataset(x, image_transforms=copy.copy(tforms), default_caption="A view from a train window") for x in dirs]
39
- return torch.utils.data.ConcatDataset(datasets)
40
-
41
-
42
- class VideoDataset(Dataset):
43
- def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2):
44
- self.root_dir = Path(root_dir)
45
- self.caption_file = caption_file
46
- self.n = n
47
- ext = "mp4"
48
- self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
49
- self.offset = offset
50
-
51
- if isinstance(image_transforms, ListConfig):
52
- image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
53
- image_transforms.extend([transforms.ToTensor(),
54
- transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
55
- image_transforms = transforms.Compose(image_transforms)
56
- self.tform = image_transforms
57
- with open(self.caption_file) as f:
58
- reader = csv.reader(f)
59
- rows = [row for row in reader]
60
- self.captions = dict(rows)
61
-
62
- def __len__(self):
63
- return len(self.paths)
64
-
65
- def __getitem__(self, index):
66
- for i in range(10):
67
- try:
68
- return self._load_sample(index)
69
- except Exception:
70
- # Not really good enough but...
71
- print("uh oh")
72
-
73
- def _load_sample(self, index):
74
- n = self.n
75
- filename = self.paths[index]
76
- min_frame = 2*self.offset + 2
77
- vid = cv2.VideoCapture(str(filename))
78
- max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
79
- curr_frame_n = random.randint(min_frame, max_frames)
80
- vid.set(cv2.CAP_PROP_POS_FRAMES,curr_frame_n)
81
- _, curr_frame = vid.read()
82
-
83
- prev_frames = []
84
- for i in range(n):
85
- prev_frame_n = curr_frame_n - (i+1)*self.offset
86
- vid.set(cv2.CAP_PROP_POS_FRAMES,prev_frame_n)
87
- _, prev_frame = vid.read()
88
- prev_frame = self.tform(Image.fromarray(prev_frame[...,::-1]))
89
- prev_frames.append(prev_frame)
90
-
91
- vid.release()
92
- caption = self.captions[filename.name]
93
- data = {
94
- "image": self.tform(Image.fromarray(curr_frame[...,::-1])),
95
- "prev": torch.cat(prev_frames, dim=-1),
96
- "txt": caption
97
- }
98
- return data
99
-
100
- # end hacky things
101
-
102
-
103
- def make_tranforms(image_transforms):
104
- # if isinstance(image_transforms, ListConfig):
105
- # image_transforms = [instantiate_from_config(tt) for tt in image_transforms]
106
- image_transforms = []
107
- image_transforms.extend([transforms.ToTensor(),
108
- transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
109
- image_transforms = transforms.Compose(image_transforms)
110
- return image_transforms
111
-
112
-
113
- def make_multi_folder_data(paths, caption_files=None, **kwargs):
114
- """Make a concat dataset from multiple folders
115
- Don't suport captions yet
116
-
117
- If paths is a list, that's ok, if it's a Dict interpret it as:
118
- k=folder v=n_times to repeat that
119
- """
120
- list_of_paths = []
121
- if isinstance(paths, (Dict, DictConfig)):
122
- assert caption_files is None, \
123
- "Caption files not yet supported for repeats"
124
- for folder_path, repeats in paths.items():
125
- list_of_paths.extend([folder_path]*repeats)
126
- paths = list_of_paths
127
-
128
- if caption_files is not None:
129
- datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)]
130
- else:
131
- datasets = [FolderData(p, **kwargs) for p in paths]
132
- return torch.utils.data.ConcatDataset(datasets)
133
-
134
-
135
-
136
- class NfpDataset(Dataset):
137
- def __init__(self,
138
- root_dir,
139
- image_transforms=[],
140
- ext="jpg",
141
- default_caption="",
142
- ) -> None:
143
- """assume sequential frames and a deterministic transform"""
144
-
145
- self.root_dir = Path(root_dir)
146
- self.default_caption = default_caption
147
-
148
- self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}")))
149
- self.tform = make_tranforms(image_transforms)
150
-
151
- def __len__(self):
152
- return len(self.paths) - 1
153
-
154
-
155
- def __getitem__(self, index):
156
- prev = self.paths[index]
157
- curr = self.paths[index+1]
158
- data = {}
159
- data["image"] = self._load_im(curr)
160
- data["prev"] = self._load_im(prev)
161
- data["txt"] = self.default_caption
162
- return data
163
-
164
- def _load_im(self, filename):
165
- im = Image.open(filename).convert("RGB")
166
- return self.tform(im)
167
-
168
- class ObjaverseDataModuleFromConfig(pl.LightningDataModule):
169
- def __init__(self, root_dir, batch_size, total_view, train=None, validation=None,
170
- test=None, num_workers=4, **kwargs):
171
- super().__init__(self)
172
- self.root_dir = root_dir
173
- self.batch_size = batch_size
174
- self.num_workers = num_workers
175
- self.total_view = total_view
176
-
177
- if train is not None:
178
- dataset_config = train
179
- if validation is not None:
180
- dataset_config = validation
181
-
182
- if 'image_transforms' in dataset_config:
183
- image_transforms = [torchvision.transforms.Resize(dataset_config.image_transforms.size)]
184
- else:
185
- image_transforms = []
186
- image_transforms.extend([transforms.ToTensor(),
187
- transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))])
188
- self.image_transforms = torchvision.transforms.Compose(image_transforms)
189
-
190
-
191
- def train_dataloader(self):
192
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False, \
193
- image_transforms=self.image_transforms)
194
- sampler = DistributedSampler(dataset)
195
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
196
-
197
- def val_dataloader(self):
198
- dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True, \
199
- image_transforms=self.image_transforms)
200
- sampler = DistributedSampler(dataset)
201
- return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
202
-
203
- def test_dataloader(self):
204
- return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=self.validation),\
205
- batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
206
-
207
-
208
- class ObjaverseData(Dataset):
209
- def __init__(self,
210
- root_dir='.objaverse/hf-objaverse-v1/views',
211
- image_transforms=[],
212
- ext="png",
213
- default_trans=torch.zeros(3),
214
- postprocess=None,
215
- return_paths=False,
216
- total_view=4,
217
- validation=False
218
- ) -> None:
219
- """Create a dataset from a folder of images.
220
- If you pass in a root directory it will be searched for images
221
- ending in ext (ext can be a list)
222
- """
223
- self.root_dir = Path(root_dir)
224
- self.default_trans = default_trans
225
- self.return_paths = return_paths
226
- if isinstance(postprocess, DictConfig):
227
- postprocess = instantiate_from_config(postprocess)
228
- self.postprocess = postprocess
229
- self.total_view = total_view
230
-
231
- if not isinstance(ext, (tuple, list, ListConfig)):
232
- ext = [ext]
233
-
234
- with open(os.path.join(root_dir, 'valid_paths.json')) as f:
235
- self.paths = json.load(f)
236
-
237
- total_objects = len(self.paths)
238
- if validation:
239
- self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation
240
- else:
241
- self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
242
- print('============= length of dataset %d =============' % len(self.paths))
243
- self.tform = image_transforms
244
-
245
- def __len__(self):
246
- return len(self.paths)
247
-
248
- def cartesian_to_spherical(self, xyz):
249
- ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
250
- xy = xyz[:,0]**2 + xyz[:,1]**2
251
- z = np.sqrt(xy + xyz[:,2]**2)
252
- theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
253
- #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
254
- azimuth = np.arctan2(xyz[:,1], xyz[:,0])
255
- return np.array([theta, azimuth, z])
256
-
257
- def get_T(self, target_RT, cond_RT):
258
- R, T = target_RT[:3, :3], target_RT[:, -1]
259
- T_target = -R.T @ T
260
-
261
- R, T = cond_RT[:3, :3], cond_RT[:, -1]
262
- T_cond = -R.T @ T
263
-
264
- theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
265
- theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
266
-
267
- d_theta = theta_target - theta_cond
268
- d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
269
- d_z = z_target - z_cond
270
-
271
- d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
272
- return d_T
273
-
274
- def load_im(self, path, color):
275
- '''
276
- replace background pixel with random color in rendering
277
- '''
278
- try:
279
- img = plt.imread(path)
280
- except:
281
- print(path)
282
- sys.exit()
283
- img[img[:, :, -1] == 0.] = color
284
- img = Image.fromarray(np.uint8(img[:, :, :3] * 255.))
285
- return img
286
-
287
- def __getitem__(self, index):
288
-
289
- data = {}
290
- if self.paths[index][-2:] == '_1': # dirty fix for rendering dataset twice
291
- total_view = 8
292
- else:
293
- total_view = 4
294
- index_target, index_cond = random.sample(range(total_view), 2) # without replacement
295
- filename = os.path.join(self.root_dir, self.paths[index])
296
-
297
- # print(self.paths[index])
298
-
299
- if self.return_paths:
300
- data["path"] = str(filename)
301
-
302
- color = [1., 1., 1., 1.]
303
-
304
- try:
305
- target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
306
- cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color))
307
- target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
308
- cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))
309
- except:
310
- # very hacky solution, sorry about this
311
- filename = os.path.join(self.root_dir, '692db5f2d3a04bb286cb977a7dba903e_1') # this one we know is valid
312
- target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color))
313
- cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color))
314
- target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target))
315
- cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond))
316
- target_im = torch.zeros_like(target_im)
317
- cond_im = torch.zeros_like(cond_im)
318
-
319
- data["image_target"] = target_im
320
- data["image_cond"] = cond_im
321
- data["T"] = self.get_T(target_RT, cond_RT)
322
-
323
- if self.postprocess is not None:
324
- data = self.postprocess(data)
325
-
326
- return data
327
-
328
- def process_im(self, im):
329
- im = im.convert("RGB")
330
- return self.tform(im)
331
-
332
- class FolderData(Dataset):
333
- def __init__(self,
334
- root_dir,
335
- caption_file=None,
336
- image_transforms=[],
337
- ext="jpg",
338
- default_caption="",
339
- postprocess=None,
340
- return_paths=False,
341
- ) -> None:
342
- """Create a dataset from a folder of images.
343
- If you pass in a root directory it will be searched for images
344
- ending in ext (ext can be a list)
345
- """
346
- self.root_dir = Path(root_dir)
347
- self.default_caption = default_caption
348
- self.return_paths = return_paths
349
- if isinstance(postprocess, DictConfig):
350
- postprocess = instantiate_from_config(postprocess)
351
- self.postprocess = postprocess
352
- if caption_file is not None:
353
- with open(caption_file, "rt") as f:
354
- ext = Path(caption_file).suffix.lower()
355
- if ext == ".json":
356
- captions = json.load(f)
357
- elif ext == ".jsonl":
358
- lines = f.readlines()
359
- lines = [json.loads(x) for x in lines]
360
- captions = {x["file_name"]: x["text"].strip("\n") for x in lines}
361
- else:
362
- raise ValueError(f"Unrecognised format: {ext}")
363
- self.captions = captions
364
- else:
365
- self.captions = None
366
-
367
- if not isinstance(ext, (tuple, list, ListConfig)):
368
- ext = [ext]
369
-
370
- # Only used if there is no caption file
371
- self.paths = []
372
- for e in ext:
373
- self.paths.extend(sorted(list(self.root_dir.rglob(f"*.{e}"))))
374
- self.tform = make_tranforms(image_transforms)
375
-
376
- def __len__(self):
377
- if self.captions is not None:
378
- return len(self.captions.keys())
379
- else:
380
- return len(self.paths)
381
-
382
- def __getitem__(self, index):
383
- data = {}
384
- if self.captions is not None:
385
- chosen = list(self.captions.keys())[index]
386
- caption = self.captions.get(chosen, None)
387
- if caption is None:
388
- caption = self.default_caption
389
- filename = self.root_dir/chosen
390
- else:
391
- filename = self.paths[index]
392
-
393
- if self.return_paths:
394
- data["path"] = str(filename)
395
-
396
- im = Image.open(filename).convert("RGB")
397
- im = self.process_im(im)
398
- data["image"] = im
399
-
400
- if self.captions is not None:
401
- data["txt"] = caption
402
- else:
403
- data["txt"] = self.default_caption
404
-
405
- if self.postprocess is not None:
406
- data = self.postprocess(data)
407
-
408
- return data
409
-
410
- def process_im(self, im):
411
- im = im.convert("RGB")
412
- return self.tform(im)
413
- import random
414
-
415
- class TransformDataset():
416
- def __init__(self, ds, extra_label="sksbspic"):
417
- self.ds = ds
418
- self.extra_label = extra_label
419
- self.transforms = {
420
- "align": transforms.Resize(768),
421
- "centerzoom": transforms.CenterCrop(768),
422
- "randzoom": transforms.RandomCrop(768),
423
- }
424
-
425
-
426
- def __getitem__(self, index):
427
- data = self.ds[index]
428
-
429
- im = data['image']
430
- im = im.permute(2,0,1)
431
- # In case data is smaller than expected
432
- im = transforms.Resize(1024)(im)
433
-
434
- tform_name = random.choice(list(self.transforms.keys()))
435
- im = self.transforms[tform_name](im)
436
-
437
- im = im.permute(1,2,0)
438
-
439
- data['image'] = im
440
- data['txt'] = data['txt'] + f" {self.extra_label} {tform_name}"
441
-
442
- return data
443
-
444
- def __len__(self):
445
- return len(self.ds)
446
-
447
- def hf_dataset(
448
- name,
449
- image_transforms=[],
450
- image_column="image",
451
- text_column="text",
452
- split='train',
453
- image_key='image',
454
- caption_key='txt',
455
- ):
456
- """Make huggingface dataset with appropriate list of transforms applied
457
- """
458
- ds = load_dataset(name, split=split)
459
- tform = make_tranforms(image_transforms)
460
-
461
- assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}"
462
- assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}"
463
-
464
- def pre_process(examples):
465
- processed = {}
466
- processed[image_key] = [tform(im) for im in examples[image_column]]
467
- processed[caption_key] = examples[text_column]
468
- return processed
469
-
470
- ds.set_transform(pre_process)
471
- return ds
472
-
473
- class TextOnly(Dataset):
474
- def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1):
475
- """Returns only captions with dummy images"""
476
- self.output_size = output_size
477
- self.image_key = image_key
478
- self.caption_key = caption_key
479
- if isinstance(captions, Path):
480
- self.captions = self._load_caption_file(captions)
481
- else:
482
- self.captions = captions
483
-
484
- if n_gpus > 1:
485
- # hack to make sure that all the captions appear on each gpu
486
- repeated = [n_gpus*[x] for x in self.captions]
487
- self.captions = []
488
- [self.captions.extend(x) for x in repeated]
489
-
490
- def __len__(self):
491
- return len(self.captions)
492
-
493
- def __getitem__(self, index):
494
- dummy_im = torch.zeros(3, self.output_size, self.output_size)
495
- dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c')
496
- return {self.image_key: dummy_im, self.caption_key: self.captions[index]}
497
-
498
- def _load_caption_file(self, filename):
499
- with open(filename, 'rt') as f:
500
- captions = f.readlines()
501
- return [x.strip('\n') for x in captions]
502
-
503
-
504
-
505
- import random
506
- import json
507
- class IdRetreivalDataset(FolderData):
508
- def __init__(self, ret_file, *args, **kwargs):
509
- super().__init__(*args, **kwargs)
510
- with open(ret_file, "rt") as f:
511
- self.ret = json.load(f)
512
-
513
- def __getitem__(self, index):
514
- data = super().__getitem__(index)
515
- key = self.paths[index].name
516
- matches = self.ret[key]
517
- if len(matches) > 0:
518
- retreived = random.choice(matches)
519
- else:
520
- retreived = key
521
- filename = self.root_dir/retreived
522
- im = Image.open(filename).convert("RGB")
523
- im = self.process_im(im)
524
- # data["match"] = im
525
- data["match"] = torch.cat((data["image"], im), dim=-1)
526
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/extras.py DELETED
@@ -1,77 +0,0 @@
1
- from pathlib import Path
2
- from omegaconf import OmegaConf
3
- import torch
4
- from ldm.util import instantiate_from_config
5
- import logging
6
- from contextlib import contextmanager
7
-
8
- from contextlib import contextmanager
9
- import logging
10
-
11
- @contextmanager
12
- def all_logging_disabled(highest_level=logging.CRITICAL):
13
- """
14
- A context manager that will prevent any logging messages
15
- triggered during the body from being processed.
16
-
17
- :param highest_level: the maximum logging level in use.
18
- This would only need to be changed if a custom level greater than CRITICAL
19
- is defined.
20
-
21
- https://gist.github.com/simon-weber/7853144
22
- """
23
- # two kind-of hacks here:
24
- # * can't get the highest logging level in effect => delegate to the user
25
- # * can't get the current module-level override => use an undocumented
26
- # (but non-private!) interface
27
-
28
- previous_level = logging.root.manager.disable
29
-
30
- logging.disable(highest_level)
31
-
32
- try:
33
- yield
34
- finally:
35
- logging.disable(previous_level)
36
-
37
- def load_training_dir(train_dir, device, epoch="last"):
38
- """Load a checkpoint and config from training directory"""
39
- train_dir = Path(train_dir)
40
- ckpt = list(train_dir.rglob(f"*{epoch}.ckpt"))
41
- assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files"
42
- config = list(train_dir.rglob(f"*-project.yaml"))
43
- assert len(ckpt) > 0, f"didn't find any config in {train_dir}"
44
- if len(config) > 1:
45
- print(f"found {len(config)} matching config files")
46
- config = sorted(config)[-1]
47
- print(f"selecting {config}")
48
- else:
49
- config = config[0]
50
-
51
-
52
- config = OmegaConf.load(config)
53
- return load_model_from_config(config, ckpt[0], device)
54
-
55
- def load_model_from_config(config, ckpt, device="cpu", verbose=False):
56
- """Loads a model from config and a ckpt
57
- if config is a path will use omegaconf to load
58
- """
59
- if isinstance(config, (str, Path)):
60
- config = OmegaConf.load(config)
61
-
62
- with all_logging_disabled():
63
- print(f"Loading model from {ckpt}")
64
- pl_sd = torch.load(ckpt, map_location="cpu")
65
- global_step = pl_sd["global_step"]
66
- sd = pl_sd["state_dict"]
67
- model = instantiate_from_config(config.model)
68
- m, u = model.load_state_dict(sd, strict=False)
69
- if len(m) > 0 and verbose:
70
- print("missing keys:")
71
- print(m)
72
- if len(u) > 0 and verbose:
73
- print("unexpected keys:")
74
- model.to(device)
75
- model.eval()
76
- model.cond_stage_model.device = device
77
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/guidance.py DELETED
@@ -1,96 +0,0 @@
1
- from typing import List, Tuple
2
- from scipy import interpolate
3
- import numpy as np
4
- import torch
5
- import matplotlib.pyplot as plt
6
- from IPython.display import clear_output
7
- import abc
8
-
9
-
10
- class GuideModel(torch.nn.Module, abc.ABC):
11
- def __init__(self) -> None:
12
- super().__init__()
13
-
14
- @abc.abstractmethod
15
- def preprocess(self, x_img):
16
- pass
17
-
18
- @abc.abstractmethod
19
- def compute_loss(self, inp):
20
- pass
21
-
22
-
23
- class Guider(torch.nn.Module):
24
- def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
25
- """Apply classifier guidance
26
-
27
- Specify a guidance scale as either a scalar
28
- Or a schedule as a list of tuples t = 0->1 and scale, e.g.
29
- [(0, 10), (0.5, 20), (1, 50)]
30
- """
31
- super().__init__()
32
- self.sampler = sampler
33
- self.index = 0
34
- self.show = verbose
35
- self.guide_model = guide_model
36
- self.history = []
37
-
38
- if isinstance(scale, (Tuple, List)):
39
- times = np.array([x[0] for x in scale])
40
- values = np.array([x[1] for x in scale])
41
- self.scale_schedule = {"times": times, "values": values}
42
- else:
43
- self.scale_schedule = float(scale)
44
-
45
- self.ddim_timesteps = sampler.ddim_timesteps
46
- self.ddpm_num_timesteps = sampler.ddpm_num_timesteps
47
-
48
-
49
- def get_scales(self):
50
- if isinstance(self.scale_schedule, float):
51
- return len(self.ddim_timesteps)*[self.scale_schedule]
52
-
53
- interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"])
54
- fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
55
- return interpolater(fractional_steps)
56
-
57
- def modify_score(self, model, e_t, x, t, c):
58
-
59
- # TODO look up index by t
60
- scale = self.get_scales()[self.index]
61
-
62
- if (scale == 0):
63
- return e_t
64
-
65
- sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
66
- with torch.enable_grad():
67
- x_in = x.detach().requires_grad_(True)
68
- pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
69
- x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
70
-
71
- inp = self.guide_model.preprocess(x_img)
72
- loss = self.guide_model.compute_loss(inp)
73
- grads = torch.autograd.grad(loss.sum(), x_in)[0]
74
- correction = grads * scale
75
-
76
- if self.show:
77
- clear_output(wait=True)
78
- print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item())
79
- self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()])
80
- plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
81
- plt.axis('off')
82
- plt.show()
83
- plt.imshow(correction[0][0].detach().cpu())
84
- plt.axis('off')
85
- plt.show()
86
-
87
-
88
- e_t_mod = e_t - sqrt_1ma*correction
89
- if self.show:
90
- fig, axs = plt.subplots(1, 3)
91
- axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
92
- axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
93
- axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
94
- plt.show()
95
- self.index += 1
96
- return e_t_mod
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/lr_scheduler.py DELETED
@@ -1,98 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- class LambdaWarmUpCosineScheduler:
5
- """
6
- note: use with a base_lr of 1.0
7
- """
8
- def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
- self.lr_warm_up_steps = warm_up_steps
10
- self.lr_start = lr_start
11
- self.lr_min = lr_min
12
- self.lr_max = lr_max
13
- self.lr_max_decay_steps = max_decay_steps
14
- self.last_lr = 0.
15
- self.verbosity_interval = verbosity_interval
16
-
17
- def schedule(self, n, **kwargs):
18
- if self.verbosity_interval > 0:
19
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
- if n < self.lr_warm_up_steps:
21
- lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
- self.last_lr = lr
23
- return lr
24
- else:
25
- t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
- t = min(t, 1.0)
27
- lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
- 1 + np.cos(t * np.pi))
29
- self.last_lr = lr
30
- return lr
31
-
32
- def __call__(self, n, **kwargs):
33
- return self.schedule(n,**kwargs)
34
-
35
-
36
- class LambdaWarmUpCosineScheduler2:
37
- """
38
- supports repeated iterations, configurable via lists
39
- note: use with a base_lr of 1.0.
40
- """
41
- def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
- assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
- self.lr_warm_up_steps = warm_up_steps
44
- self.f_start = f_start
45
- self.f_min = f_min
46
- self.f_max = f_max
47
- self.cycle_lengths = cycle_lengths
48
- self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
- self.last_f = 0.
50
- self.verbosity_interval = verbosity_interval
51
-
52
- def find_in_interval(self, n):
53
- interval = 0
54
- for cl in self.cum_cycles[1:]:
55
- if n <= cl:
56
- return interval
57
- interval += 1
58
-
59
- def schedule(self, n, **kwargs):
60
- cycle = self.find_in_interval(n)
61
- n = n - self.cum_cycles[cycle]
62
- if self.verbosity_interval > 0:
63
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
- f"current cycle {cycle}")
65
- if n < self.lr_warm_up_steps[cycle]:
66
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
- self.last_f = f
68
- return f
69
- else:
70
- t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
- t = min(t, 1.0)
72
- f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
- 1 + np.cos(t * np.pi))
74
- self.last_f = f
75
- return f
76
-
77
- def __call__(self, n, **kwargs):
78
- return self.schedule(n, **kwargs)
79
-
80
-
81
- class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
-
83
- def schedule(self, n, **kwargs):
84
- cycle = self.find_in_interval(n)
85
- n = n - self.cum_cycles[cycle]
86
- if self.verbosity_interval > 0:
87
- if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
- f"current cycle {cycle}")
89
-
90
- if n < self.lr_warm_up_steps[cycle]:
91
- f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
- self.last_f = f
93
- return f
94
- else:
95
- f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
- self.last_f = f
97
- return f
98
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/models/autoencoder.py DELETED
@@ -1,443 +0,0 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- import torch.nn.functional as F
4
- from contextlib import contextmanager
5
-
6
- from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
-
8
- from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
- from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
-
11
- from ldm.util import instantiate_from_config
12
-
13
-
14
- class VQModel(pl.LightningModule):
15
- def __init__(self,
16
- ddconfig,
17
- lossconfig,
18
- n_embed,
19
- embed_dim,
20
- ckpt_path=None,
21
- ignore_keys=[],
22
- image_key="image",
23
- colorize_nlabels=None,
24
- monitor=None,
25
- batch_resize_range=None,
26
- scheduler_config=None,
27
- lr_g_factor=1.0,
28
- remap=None,
29
- sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
- use_ema=False
31
- ):
32
- super().__init__()
33
- self.embed_dim = embed_dim
34
- self.n_embed = n_embed
35
- self.image_key = image_key
36
- self.encoder = Encoder(**ddconfig)
37
- self.decoder = Decoder(**ddconfig)
38
- self.loss = instantiate_from_config(lossconfig)
39
- self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
- remap=remap,
41
- sane_index_shape=sane_index_shape)
42
- self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
- if colorize_nlabels is not None:
45
- assert type(colorize_nlabels)==int
46
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
- if monitor is not None:
48
- self.monitor = monitor
49
- self.batch_resize_range = batch_resize_range
50
- if self.batch_resize_range is not None:
51
- print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
-
53
- self.use_ema = use_ema
54
- if self.use_ema:
55
- self.model_ema = LitEma(self)
56
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
-
58
- if ckpt_path is not None:
59
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
- self.scheduler_config = scheduler_config
61
- self.lr_g_factor = lr_g_factor
62
-
63
- @contextmanager
64
- def ema_scope(self, context=None):
65
- if self.use_ema:
66
- self.model_ema.store(self.parameters())
67
- self.model_ema.copy_to(self)
68
- if context is not None:
69
- print(f"{context}: Switched to EMA weights")
70
- try:
71
- yield None
72
- finally:
73
- if self.use_ema:
74
- self.model_ema.restore(self.parameters())
75
- if context is not None:
76
- print(f"{context}: Restored training weights")
77
-
78
- def init_from_ckpt(self, path, ignore_keys=list()):
79
- sd = torch.load(path, map_location="cpu")["state_dict"]
80
- keys = list(sd.keys())
81
- for k in keys:
82
- for ik in ignore_keys:
83
- if k.startswith(ik):
84
- print("Deleting key {} from state_dict.".format(k))
85
- del sd[k]
86
- missing, unexpected = self.load_state_dict(sd, strict=False)
87
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
- if len(missing) > 0:
89
- print(f"Missing Keys: {missing}")
90
- print(f"Unexpected Keys: {unexpected}")
91
-
92
- def on_train_batch_end(self, *args, **kwargs):
93
- if self.use_ema:
94
- self.model_ema(self)
95
-
96
- def encode(self, x):
97
- h = self.encoder(x)
98
- h = self.quant_conv(h)
99
- quant, emb_loss, info = self.quantize(h)
100
- return quant, emb_loss, info
101
-
102
- def encode_to_prequant(self, x):
103
- h = self.encoder(x)
104
- h = self.quant_conv(h)
105
- return h
106
-
107
- def decode(self, quant):
108
- quant = self.post_quant_conv(quant)
109
- dec = self.decoder(quant)
110
- return dec
111
-
112
- def decode_code(self, code_b):
113
- quant_b = self.quantize.embed_code(code_b)
114
- dec = self.decode(quant_b)
115
- return dec
116
-
117
- def forward(self, input, return_pred_indices=False):
118
- quant, diff, (_,_,ind) = self.encode(input)
119
- dec = self.decode(quant)
120
- if return_pred_indices:
121
- return dec, diff, ind
122
- return dec, diff
123
-
124
- def get_input(self, batch, k):
125
- x = batch[k]
126
- if len(x.shape) == 3:
127
- x = x[..., None]
128
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
- if self.batch_resize_range is not None:
130
- lower_size = self.batch_resize_range[0]
131
- upper_size = self.batch_resize_range[1]
132
- if self.global_step <= 4:
133
- # do the first few batches with max size to avoid later oom
134
- new_resize = upper_size
135
- else:
136
- new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
- if new_resize != x.shape[2]:
138
- x = F.interpolate(x, size=new_resize, mode="bicubic")
139
- x = x.detach()
140
- return x
141
-
142
- def training_step(self, batch, batch_idx, optimizer_idx):
143
- # https://github.com/pytorch/pytorch/issues/37142
144
- # try not to fool the heuristics
145
- x = self.get_input(batch, self.image_key)
146
- xrec, qloss, ind = self(x, return_pred_indices=True)
147
-
148
- if optimizer_idx == 0:
149
- # autoencode
150
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
- last_layer=self.get_last_layer(), split="train",
152
- predicted_indices=ind)
153
-
154
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
- return aeloss
156
-
157
- if optimizer_idx == 1:
158
- # discriminator
159
- discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
- last_layer=self.get_last_layer(), split="train")
161
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
- return discloss
163
-
164
- def validation_step(self, batch, batch_idx):
165
- log_dict = self._validation_step(batch, batch_idx)
166
- with self.ema_scope():
167
- log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
- return log_dict
169
-
170
- def _validation_step(self, batch, batch_idx, suffix=""):
171
- x = self.get_input(batch, self.image_key)
172
- xrec, qloss, ind = self(x, return_pred_indices=True)
173
- aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
- self.global_step,
175
- last_layer=self.get_last_layer(),
176
- split="val"+suffix,
177
- predicted_indices=ind
178
- )
179
-
180
- discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
- self.global_step,
182
- last_layer=self.get_last_layer(),
183
- split="val"+suffix,
184
- predicted_indices=ind
185
- )
186
- rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
- self.log(f"val{suffix}/rec_loss", rec_loss,
188
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
- self.log(f"val{suffix}/aeloss", aeloss,
190
- prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
- if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
- del log_dict_ae[f"val{suffix}/rec_loss"]
193
- self.log_dict(log_dict_ae)
194
- self.log_dict(log_dict_disc)
195
- return self.log_dict
196
-
197
- def configure_optimizers(self):
198
- lr_d = self.learning_rate
199
- lr_g = self.lr_g_factor*self.learning_rate
200
- print("lr_d", lr_d)
201
- print("lr_g", lr_g)
202
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
- list(self.decoder.parameters())+
204
- list(self.quantize.parameters())+
205
- list(self.quant_conv.parameters())+
206
- list(self.post_quant_conv.parameters()),
207
- lr=lr_g, betas=(0.5, 0.9))
208
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
- lr=lr_d, betas=(0.5, 0.9))
210
-
211
- if self.scheduler_config is not None:
212
- scheduler = instantiate_from_config(self.scheduler_config)
213
-
214
- print("Setting up LambdaLR scheduler...")
215
- scheduler = [
216
- {
217
- 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
- 'interval': 'step',
219
- 'frequency': 1
220
- },
221
- {
222
- 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
- 'interval': 'step',
224
- 'frequency': 1
225
- },
226
- ]
227
- return [opt_ae, opt_disc], scheduler
228
- return [opt_ae, opt_disc], []
229
-
230
- def get_last_layer(self):
231
- return self.decoder.conv_out.weight
232
-
233
- def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
- log = dict()
235
- x = self.get_input(batch, self.image_key)
236
- x = x.to(self.device)
237
- if only_inputs:
238
- log["inputs"] = x
239
- return log
240
- xrec, _ = self(x)
241
- if x.shape[1] > 3:
242
- # colorize with random projection
243
- assert xrec.shape[1] > 3
244
- x = self.to_rgb(x)
245
- xrec = self.to_rgb(xrec)
246
- log["inputs"] = x
247
- log["reconstructions"] = xrec
248
- if plot_ema:
249
- with self.ema_scope():
250
- xrec_ema, _ = self(x)
251
- if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
- log["reconstructions_ema"] = xrec_ema
253
- return log
254
-
255
- def to_rgb(self, x):
256
- assert self.image_key == "segmentation"
257
- if not hasattr(self, "colorize"):
258
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
- x = F.conv2d(x, weight=self.colorize)
260
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
- return x
262
-
263
-
264
- class VQModelInterface(VQModel):
265
- def __init__(self, embed_dim, *args, **kwargs):
266
- super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
- self.embed_dim = embed_dim
268
-
269
- def encode(self, x):
270
- h = self.encoder(x)
271
- h = self.quant_conv(h)
272
- return h
273
-
274
- def decode(self, h, force_not_quantize=False):
275
- # also go through quantization layer
276
- if not force_not_quantize:
277
- quant, emb_loss, info = self.quantize(h)
278
- else:
279
- quant = h
280
- quant = self.post_quant_conv(quant)
281
- dec = self.decoder(quant)
282
- return dec
283
-
284
-
285
- class AutoencoderKL(pl.LightningModule):
286
- def __init__(self,
287
- ddconfig,
288
- lossconfig,
289
- embed_dim,
290
- ckpt_path=None,
291
- ignore_keys=[],
292
- image_key="image",
293
- colorize_nlabels=None,
294
- monitor=None,
295
- ):
296
- super().__init__()
297
- self.image_key = image_key
298
- self.encoder = Encoder(**ddconfig)
299
- self.decoder = Decoder(**ddconfig)
300
- self.loss = instantiate_from_config(lossconfig)
301
- assert ddconfig["double_z"]
302
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
- self.embed_dim = embed_dim
305
- if colorize_nlabels is not None:
306
- assert type(colorize_nlabels)==int
307
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
- if monitor is not None:
309
- self.monitor = monitor
310
- if ckpt_path is not None:
311
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
-
313
- def init_from_ckpt(self, path, ignore_keys=list()):
314
- sd = torch.load(path, map_location="cpu")["state_dict"]
315
- keys = list(sd.keys())
316
- for k in keys:
317
- for ik in ignore_keys:
318
- if k.startswith(ik):
319
- print("Deleting key {} from state_dict.".format(k))
320
- del sd[k]
321
- self.load_state_dict(sd, strict=False)
322
- print(f"Restored from {path}")
323
-
324
- def encode(self, x):
325
- h = self.encoder(x)
326
- moments = self.quant_conv(h)
327
- posterior = DiagonalGaussianDistribution(moments)
328
- return posterior
329
-
330
- def decode(self, z):
331
- z = self.post_quant_conv(z)
332
- dec = self.decoder(z)
333
- return dec
334
-
335
- def forward(self, input, sample_posterior=True):
336
- posterior = self.encode(input)
337
- if sample_posterior:
338
- z = posterior.sample()
339
- else:
340
- z = posterior.mode()
341
- dec = self.decode(z)
342
- return dec, posterior
343
-
344
- def get_input(self, batch, k):
345
- x = batch[k]
346
- if len(x.shape) == 3:
347
- x = x[..., None]
348
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349
- return x
350
-
351
- def training_step(self, batch, batch_idx, optimizer_idx):
352
- inputs = self.get_input(batch, self.image_key)
353
- reconstructions, posterior = self(inputs)
354
-
355
- if optimizer_idx == 0:
356
- # train encoder+decoder+logvar
357
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358
- last_layer=self.get_last_layer(), split="train")
359
- self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361
- return aeloss
362
-
363
- if optimizer_idx == 1:
364
- # train the discriminator
365
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366
- last_layer=self.get_last_layer(), split="train")
367
-
368
- self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370
- return discloss
371
-
372
- def validation_step(self, batch, batch_idx):
373
- inputs = self.get_input(batch, self.image_key)
374
- reconstructions, posterior = self(inputs)
375
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376
- last_layer=self.get_last_layer(), split="val")
377
-
378
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379
- last_layer=self.get_last_layer(), split="val")
380
-
381
- self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382
- self.log_dict(log_dict_ae)
383
- self.log_dict(log_dict_disc)
384
- return self.log_dict
385
-
386
- def configure_optimizers(self):
387
- lr = self.learning_rate
388
- opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389
- list(self.decoder.parameters())+
390
- list(self.quant_conv.parameters())+
391
- list(self.post_quant_conv.parameters()),
392
- lr=lr, betas=(0.5, 0.9))
393
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394
- lr=lr, betas=(0.5, 0.9))
395
- return [opt_ae, opt_disc], []
396
-
397
- def get_last_layer(self):
398
- return self.decoder.conv_out.weight
399
-
400
- @torch.no_grad()
401
- def log_images(self, batch, only_inputs=False, **kwargs):
402
- log = dict()
403
- x = self.get_input(batch, self.image_key)
404
- x = x.to(self.device)
405
- if not only_inputs:
406
- xrec, posterior = self(x)
407
- if x.shape[1] > 3:
408
- # colorize with random projection
409
- assert xrec.shape[1] > 3
410
- x = self.to_rgb(x)
411
- xrec = self.to_rgb(xrec)
412
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413
- log["reconstructions"] = xrec
414
- log["inputs"] = x
415
- return log
416
-
417
- def to_rgb(self, x):
418
- assert self.image_key == "segmentation"
419
- if not hasattr(self, "colorize"):
420
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421
- x = F.conv2d(x, weight=self.colorize)
422
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423
- return x
424
-
425
-
426
- class IdentityFirstStage(torch.nn.Module):
427
- def __init__(self, *args, vq_interface=False, **kwargs):
428
- self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429
- super().__init__()
430
-
431
- def encode(self, x, *args, **kwargs):
432
- return x
433
-
434
- def decode(self, x, *args, **kwargs):
435
- return x
436
-
437
- def quantize(self, x, *args, **kwargs):
438
- if self.vq_interface:
439
- return x, None, [None, None, None]
440
- return x
441
-
442
- def forward(self, x, *args, **kwargs):
443
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/models/diffusion/__init__.py DELETED
File without changes
One-2-3-45-master 2/ldm/models/diffusion/classifier.py DELETED
@@ -1,267 +0,0 @@
1
- import os
2
- import torch
3
- import pytorch_lightning as pl
4
- from omegaconf import OmegaConf
5
- from torch.nn import functional as F
6
- from torch.optim import AdamW
7
- from torch.optim.lr_scheduler import LambdaLR
8
- from copy import deepcopy
9
- from einops import rearrange
10
- from glob import glob
11
- from natsort import natsorted
12
-
13
- from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
- from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
-
16
- __models__ = {
17
- 'class_label': EncoderUNetModel,
18
- 'segmentation': UNetModel
19
- }
20
-
21
-
22
- def disabled_train(self, mode=True):
23
- """Overwrite model.train with this function to make sure train/eval mode
24
- does not change anymore."""
25
- return self
26
-
27
-
28
- class NoisyLatentImageClassifier(pl.LightningModule):
29
-
30
- def __init__(self,
31
- diffusion_path,
32
- num_classes,
33
- ckpt_path=None,
34
- pool='attention',
35
- label_key=None,
36
- diffusion_ckpt_path=None,
37
- scheduler_config=None,
38
- weight_decay=1.e-2,
39
- log_steps=10,
40
- monitor='val/loss',
41
- *args,
42
- **kwargs):
43
- super().__init__(*args, **kwargs)
44
- self.num_classes = num_classes
45
- # get latest config of diffusion model
46
- diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
- self.diffusion_config = OmegaConf.load(diffusion_config).model
48
- self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
- self.load_diffusion()
50
-
51
- self.monitor = monitor
52
- self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
- self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
- self.log_steps = log_steps
55
-
56
- self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
- else self.diffusion_model.cond_stage_key
58
-
59
- assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
-
61
- if self.label_key not in __models__:
62
- raise NotImplementedError()
63
-
64
- self.load_classifier(ckpt_path, pool)
65
-
66
- self.scheduler_config = scheduler_config
67
- self.use_scheduler = self.scheduler_config is not None
68
- self.weight_decay = weight_decay
69
-
70
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
- sd = torch.load(path, map_location="cpu")
72
- if "state_dict" in list(sd.keys()):
73
- sd = sd["state_dict"]
74
- keys = list(sd.keys())
75
- for k in keys:
76
- for ik in ignore_keys:
77
- if k.startswith(ik):
78
- print("Deleting key {} from state_dict.".format(k))
79
- del sd[k]
80
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
- sd, strict=False)
82
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
- if len(missing) > 0:
84
- print(f"Missing Keys: {missing}")
85
- if len(unexpected) > 0:
86
- print(f"Unexpected Keys: {unexpected}")
87
-
88
- def load_diffusion(self):
89
- model = instantiate_from_config(self.diffusion_config)
90
- self.diffusion_model = model.eval()
91
- self.diffusion_model.train = disabled_train
92
- for param in self.diffusion_model.parameters():
93
- param.requires_grad = False
94
-
95
- def load_classifier(self, ckpt_path, pool):
96
- model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
- model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
- model_config.out_channels = self.num_classes
99
- if self.label_key == 'class_label':
100
- model_config.pool = pool
101
-
102
- self.model = __models__[self.label_key](**model_config)
103
- if ckpt_path is not None:
104
- print('#####################################################################')
105
- print(f'load from ckpt "{ckpt_path}"')
106
- print('#####################################################################')
107
- self.init_from_ckpt(ckpt_path)
108
-
109
- @torch.no_grad()
110
- def get_x_noisy(self, x, t, noise=None):
111
- noise = default(noise, lambda: torch.randn_like(x))
112
- continuous_sqrt_alpha_cumprod = None
113
- if self.diffusion_model.use_continuous_noise:
114
- continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
- # todo: make sure t+1 is correct here
116
-
117
- return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
- continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
-
120
- def forward(self, x_noisy, t, *args, **kwargs):
121
- return self.model(x_noisy, t)
122
-
123
- @torch.no_grad()
124
- def get_input(self, batch, k):
125
- x = batch[k]
126
- if len(x.shape) == 3:
127
- x = x[..., None]
128
- x = rearrange(x, 'b h w c -> b c h w')
129
- x = x.to(memory_format=torch.contiguous_format).float()
130
- return x
131
-
132
- @torch.no_grad()
133
- def get_conditioning(self, batch, k=None):
134
- if k is None:
135
- k = self.label_key
136
- assert k is not None, 'Needs to provide label key'
137
-
138
- targets = batch[k].to(self.device)
139
-
140
- if self.label_key == 'segmentation':
141
- targets = rearrange(targets, 'b h w c -> b c h w')
142
- for down in range(self.numd):
143
- h, w = targets.shape[-2:]
144
- targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
-
146
- # targets = rearrange(targets,'b c h w -> b h w c')
147
-
148
- return targets
149
-
150
- def compute_top_k(self, logits, labels, k, reduction="mean"):
151
- _, top_ks = torch.topk(logits, k, dim=1)
152
- if reduction == "mean":
153
- return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
- elif reduction == "none":
155
- return (top_ks == labels[:, None]).float().sum(dim=-1)
156
-
157
- def on_train_epoch_start(self):
158
- # save some memory
159
- self.diffusion_model.model.to('cpu')
160
-
161
- @torch.no_grad()
162
- def write_logs(self, loss, logits, targets):
163
- log_prefix = 'train' if self.training else 'val'
164
- log = {}
165
- log[f"{log_prefix}/loss"] = loss.mean()
166
- log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
- logits, targets, k=1, reduction="mean"
168
- )
169
- log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
- logits, targets, k=5, reduction="mean"
171
- )
172
-
173
- self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
- self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
- self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
- lr = self.optimizers().param_groups[0]['lr']
177
- self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
-
179
- def shared_step(self, batch, t=None):
180
- x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
- targets = self.get_conditioning(batch)
182
- if targets.dim() == 4:
183
- targets = targets.argmax(dim=1)
184
- if t is None:
185
- t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
- else:
187
- t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
- x_noisy = self.get_x_noisy(x, t)
189
- logits = self(x_noisy, t)
190
-
191
- loss = F.cross_entropy(logits, targets, reduction='none')
192
-
193
- self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
-
195
- loss = loss.mean()
196
- return loss, logits, x_noisy, targets
197
-
198
- def training_step(self, batch, batch_idx):
199
- loss, *_ = self.shared_step(batch)
200
- return loss
201
-
202
- def reset_noise_accs(self):
203
- self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
- range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
-
206
- def on_validation_start(self):
207
- self.reset_noise_accs()
208
-
209
- @torch.no_grad()
210
- def validation_step(self, batch, batch_idx):
211
- loss, *_ = self.shared_step(batch)
212
-
213
- for t in self.noisy_acc:
214
- _, logits, _, targets = self.shared_step(batch, t)
215
- self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
- self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
-
218
- return loss
219
-
220
- def configure_optimizers(self):
221
- optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
-
223
- if self.use_scheduler:
224
- scheduler = instantiate_from_config(self.scheduler_config)
225
-
226
- print("Setting up LambdaLR scheduler...")
227
- scheduler = [
228
- {
229
- 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
- 'interval': 'step',
231
- 'frequency': 1
232
- }]
233
- return [optimizer], scheduler
234
-
235
- return optimizer
236
-
237
- @torch.no_grad()
238
- def log_images(self, batch, N=8, *args, **kwargs):
239
- log = dict()
240
- x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
- log['inputs'] = x
242
-
243
- y = self.get_conditioning(batch)
244
-
245
- if self.label_key == 'class_label':
246
- y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
- log['labels'] = y
248
-
249
- if ismap(y):
250
- log['labels'] = self.diffusion_model.to_rgb(y)
251
-
252
- for step in range(self.log_steps):
253
- current_time = step * self.log_time_interval
254
-
255
- _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
-
257
- log[f'inputs@t{current_time}'] = x_noisy
258
-
259
- pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
- pred = rearrange(pred, 'b h w c -> b c h w')
261
-
262
- log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
-
264
- for key in log:
265
- log[key] = log[key][:N]
266
-
267
- return log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
One-2-3-45-master 2/ldm/models/diffusion/ddim.py DELETED
@@ -1,326 +0,0 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
- from functools import partial
7
- from einops import rearrange
8
-
9
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
10
- from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding
11
-
12
-
13
- class DDIMSampler(object):
14
- def __init__(self, model, schedule="linear", **kwargs):
15
- super().__init__()
16
- self.model = model
17
- self.ddpm_num_timesteps = model.num_timesteps
18
- self.schedule = schedule
19
- self.device = model.device
20
-
21
- def to(self, device):
22
- """Same as to in torch module
23
- Don't really underestand why this isn't a module in the first place"""
24
- for k, v in self.__dict__.items():
25
- if isinstance(v, torch.Tensor):
26
- new_v = getattr(self, k).to(device)
27
- setattr(self, k, new_v)
28
-
29
-
30
- def register_buffer(self, name, attr, device=None):
31
- if type(attr) == torch.Tensor:
32
- attr = attr.to(device)
33
- # if attr.device != torch.device("cuda"):
34
- # attr = attr.to(torch.device("cuda"))
35
- setattr(self, name, attr)
36
-
37
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
38
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
39
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
40
- alphas_cumprod = self.model.alphas_cumprod
41
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
42
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
43
-
44
- self.register_buffer('betas', to_torch(self.model.betas), self.device)
45
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod), self.device)
46
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev), self.device)
47
-
48
- # calculations for diffusion q(x_t | x_{t-1}) and others
49
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())), self.device)
50
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())), self.device)
51
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())), self.device)
52
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())), self.device)
53
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)), self.device)
54
-
55
- # ddim sampling parameters
56
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
57
- ddim_timesteps=self.ddim_timesteps,
58
- eta=ddim_eta,verbose=verbose)
59
- self.register_buffer('ddim_sigmas', ddim_sigmas, self.device)
60
- self.register_buffer('ddim_alphas', ddim_alphas, self.device)
61
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev, self.device)
62
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas), self.device)
63
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
64
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
65
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
66
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps, self.device)
67
-
68
- @torch.no_grad()
69
- def sample(self,
70
- S,
71
- batch_size,
72
- shape,
73
- conditioning=None,
74
- callback=None,
75
- normals_sequence=None,
76
- img_callback=None,
77
- quantize_x0=False,
78
- eta=0.,
79
- mask=None,
80
- x0=None,
81
- temperature=1.,
82
- noise_dropout=0.,
83
- score_corrector=None,
84
- corrector_kwargs=None,
85
- verbose=True,
86
- x_T=None,
87
- log_every_t=100,
88
- unconditional_guidance_scale=1.,
89
- unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
90
- dynamic_threshold=None,
91
- **kwargs
92
- ):
93
- if conditioning is not None:
94
- if isinstance(conditioning, dict):
95
- ctmp = conditioning[list(conditioning.keys())[0]]
96
- while isinstance(ctmp, list): ctmp = ctmp[0]
97
- cbs = ctmp.shape[0]
98
- if cbs != batch_size:
99
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
100
-
101
- else:
102
- if conditioning.shape[0] != batch_size:
103
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
104
-
105
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
106
- # sampling
107
- C, H, W = shape
108
- size = (batch_size, C, H, W)
109
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
110
-
111
- samples, intermediates = self.ddim_sampling(conditioning, size,
112
- callback=callback,
113
- img_callback=img_callback,
114
- quantize_denoised=quantize_x0,
115
- mask=mask, x0=x0,
116
- ddim_use_original_steps=False,
117
- noise_dropout=noise_dropout,
118
- temperature=temperature,
119
- score_corrector=score_corrector,
120
- corrector_kwargs=corrector_kwargs,
121
- x_T=x_T,
122
- log_every_t=log_every_t,
123
- unconditional_guidance_scale=unconditional_guidance_scale,
124
- unconditional_conditioning=unconditional_conditioning,
125
- dynamic_threshold=dynamic_threshold,
126
- )
127
- return samples, intermediates
128
-
129
- @torch.no_grad()
130
- def ddim_sampling(self, cond, shape,
131
- x_T=None, ddim_use_original_steps=False,
132
- callback=None, timesteps=None, quantize_denoised=False,
133
- mask=None, x0=None, img_callback=None, log_every_t=100,
134
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
135
- unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
136
- t_start=-1):
137
- device = self.model.betas.device
138
- b = shape[0]
139
- if x_T is None:
140
- img = torch.randn(shape, device=device)
141
- else:
142
- img = x_T
143
-
144
- if timesteps is None:
145
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
146
- elif timesteps is not None and not ddim_use_original_steps:
147
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
148
- timesteps = self.ddim_timesteps[:subset_end]
149
-
150
- timesteps = timesteps[:t_start]
151
-
152
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
153
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
154
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
155
- print(f"Running DDIM Sampling with {total_steps} timesteps")
156
-
157
- iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
158
-
159
- for i, step in enumerate(iterator):
160
- index = total_steps - i - 1
161
- ts = torch.full((b,), step, device=device, dtype=torch.long)
162
-
163
- if mask is not None:
164
- assert x0 is not None
165
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
166
- img = img_orig * mask + (1. - mask) * img
167
-
168
- outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
169
- quantize_denoised=quantize_denoised, temperature=temperature,
170
- noise_dropout=noise_dropout, score_corrector=score_corrector,
171
- corrector_kwargs=corrector_kwargs,
172
- unconditional_guidance_scale=unconditional_guidance_scale,
173
- unconditional_conditioning=unconditional_conditioning,
174
- dynamic_threshold=dynamic_threshold)
175
- img, pred_x0 = outs
176
- if callback:
177
- img = callback(i, img, pred_x0)
178
- if img_callback: img_callback(pred_x0, i)
179
-
180
- if index % log_every_t == 0 or index == total_steps - 1:
181
- intermediates['x_inter'].append(img)
182
- intermediates['pred_x0'].append(pred_x0)
183
-
184
- return img, intermediates
185
-
186
- @torch.no_grad()
187
- def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
188
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
189
- unconditional_guidance_scale=1., unconditional_conditioning=None,
190
- dynamic_threshold=None):
191
- b, *_, device = *x.shape, x.device
192
-
193
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
194
- e_t = self.model.apply_model(x, t, c)
195
- else:
196
- x_in = torch.cat([x] * 2)
197
- t_in = torch.cat([t] * 2)
198
- if isinstance(c, dict):
199
- assert isinstance(unconditional_conditioning, dict)
200
- c_in = dict()
201
- for k in c:
202
- if isinstance(c[k], list):
203
- c_in[k] = [torch.cat([
204
- unconditional_conditioning[k][i],
205
- c[k][i]]) for i in range(len(c[k]))]
206
- else:
207
- c_in[k] = torch.cat([
208
- unconditional_conditioning[k],
209
- c[k]])
210
- else:
211
- c_in = torch.cat([unconditional_conditioning, c])
212
- e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
213
- e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
214
-
215
- if score_corrector is not None:
216
- assert self.model.parameterization == "eps"
217
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
218
-
219
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
220
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
221
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
222
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
223
- # select parameters corresponding to the currently considered timestep
224
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
225
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
226
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
227
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
228
-
229
- # current prediction for x_0
230
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
231
- if quantize_denoised:
232
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
233
-
234
- if dynamic_threshold is not None:
235
- pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
236
-
237
- # direction pointing to x_t
238
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
239
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
240
- if noise_dropout > 0.:
241
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
242
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
243
- return x_prev, pred_x0
244
-
245
- @torch.no_grad()
246
- def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
247
- unconditional_guidance_scale=1.0, unconditional_conditioning=None):
248
- num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
249
-
250
- assert t_enc <= num_reference_steps
251
- num_steps = t_enc
252
-
253
- if use_original_steps:
254
- alphas_next = self.alphas_cumprod[:num_steps]
255
- alphas = self.alphas_cumprod_prev[:num_steps]
256
- else:
257
- alphas_next = self.ddim_alphas[:num_steps]
258
- alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
259
-
260
- x_next = x0
261
- intermediates = []
262
- inter_steps = []
263
- for i in tqdm(range(num_steps), desc='Encoding Image'):
264
- t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
265
- if unconditional_guidance_scale == 1.:
266
- noise_pred = self.model.apply_model(x_next, t, c)
267
- else:
268
- assert unconditional_conditioning is not None
269
- e_t_uncond, noise_pred = torch.chunk(
270
- self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
271
- torch.cat((unconditional_conditioning, c))), 2)
272
- noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
273
-
274
- xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
275
- weighted_noise_pred = alphas_next[i].sqrt() * (
276
- (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
277
- x_next = xt_weighted + weighted_noise_pred
278
- if return_intermediates and i % (
279
- num_steps // return_intermediates) == 0 and i < num_steps - 1:
280
- intermediates.append(x_next)
281
- inter_steps.append(i)
282
- elif return_intermediates and i >= num_steps - 2:
283
- intermediates.append(x_next)
284
- inter_steps.append(i)
285
-
286
- out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
287
- if return_intermediates:
288
- out.update({'intermediates': intermediates})
289
- return x_next, out
290
-
291
- @torch.no_grad()
292
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
293
- # fast, but does not allow for exact reconstruction
294
- # t serves as an index to gather the correct alphas
295
- if use_original_steps:
296
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
297
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
298
- else:
299
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
300
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
301
-
302
- if noise is None:
303
- noise = torch.randn_like(x0)
304
- return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
305
- extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
306
-
307
- @torch.no_grad()
308
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
309
- use_original_steps=False):
310
-
311
- timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
312
- timesteps = timesteps[:t_start]
313
-
314
- time_range = np.flip(timesteps)
315
- total_steps = timesteps.shape[0]
316
- print(f"Running DDIM Sampling with {total_steps} timesteps")
317
-
318
- iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
319
- x_dec = x_latent
320
- for i, step in enumerate(iterator):
321
- index = total_steps - i - 1
322
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
323
- x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
324
- unconditional_guidance_scale=unconditional_guidance_scale,
325
- unconditional_conditioning=unconditional_conditioning)
326
- return x_dec