Yanrui95 committed on
Commit fc13e66 · verified · 1 Parent(s): 3f350cf

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.gif filter=lfs diff=lfs merge=lfs -text
37
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,175 @@
1
+ ### Python template
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ #
11
+ .gradio
12
+ .github
13
+ demo_output
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ .idea/
166
+
167
+ /logs
168
+ /gin-config
169
+ *.json
170
+ /eval/*csv
171
+ *__pycache__
172
+ scripts/
173
+ eval/
174
+ *.DS_Store
175
+ benchmark/datasets
LICENSE ADDED
@@ -0,0 +1,32 @@
1
+ Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
2
+
3
+ License Terms of the inference code of NormalCrafter:
4
+ --------------------------------------------------------------------
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
7
+
8
+ - You agree to use the NormalCrafter only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
9
+
10
+ - The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
11
+
12
+ For avoidance of doubts, “Software” means the NormalCrafter model inference code and weights made available under this license excluding any pre-trained data and other AI components.
13
+
14
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
15
+
16
+
17
+ Other dependencies and licenses:
18
+
19
+ Open Source Software Licensed under the MIT License:
20
+ --------------------------------------------------------------------
21
+ 1. Stability AI - Code
22
+ Copyright (c) 2023 Stability AI
23
+
24
+ Terms of the MIT License:
25
+ --------------------------------------------------------------------
26
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
27
+
28
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
29
+
30
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
31
+
32
+ **You may find the code license of Stability AI at the following links: https://github.com/Stability-AI/generative-models/blob/main/LICENSE-CODE
README.md CHANGED
@@ -1,14 +1,65 @@
1
  ---
2
  title: NormalCrafter
3
- emoji: 📉
4
- colorFrom: green
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.23.1
8
  app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
- short_description: NormalCrafter
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
  title: NormalCrafter
3
  app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.23.2
6
  ---
7
+ ## ___***NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors***___
8
+
9
+ _**[Yanrui Bin<sup>1</sup>](https://scholar.google.com/citations?user=_9fN3mEAAAAJ&hl=zh-CN), [Wenbo Hu<sup>2*</sup>](https://wbhu.github.io),
10
+ [Haoyuan Wang<sup>3</sup>](https://www.whyy.site/),
11
+ [Xinya Chen<sup>3</sup>](https://xinyachen21.github.io/),
12
+ [Bing Wang<sup>2 &dagger;</sup>](https://bingcs.github.io/)**_
13
+ <br><br>
14
+ <sup>1</sup>Spatial Intelligence Group, The Hong Kong Polytechnic University
15
+ <sup>2</sup>Tencent AI Lab
16
+ <sup>3</sup>City University of Hong Kong
17
+ <sup>4</sup>Huazhong University of Science and Technology
18
+ </div>
19
+
20
+ ## 🔆 Notice
21
+ We recommend that everyone use English to communicate on issues, as this helps developers from around the world discuss, share experiences, and answer questions together.
22
+
23
+ For business licensing and other related inquiries, don't hesitate to contact `binyanrui@gmail.com`.
24
+
25
+ ## 🔆 Introduction
26
+ 🤗 If you find NormalCrafter useful, **please help ⭐ this repo**, which is important to Open-Source projects. Thanks!
27
+
28
+ 🔥 NormalCrafter can generate temporally consistent normal sequences
29
+ with fine-grained details from open-world videos of arbitrary length.
30
+
31
+ - `[24-04-01]` 🔥🔥🔥 **NormalCrafter** is released now, have fun!
32
+ ## 🚀 Quick Start
33
+
34
+ ### 🤖 Gradio Demo
35
+ - Online demo: [NormalCrafter](https://huggingface.co/spaces/Yanrui95/NormalCrafter)
36
+ - Local demo:
37
+ ```bash
38
+ gradio app.py
39
+ ```
40
+
41
+ ### 🛠️ Installation
42
+ 1. Clone this repo:
43
+ ```bash
44
+ git clone git@github.com:Binyr/NormalCrafter.git
45
+ ```
46
+ 2. Install dependencies (please refer to [requirements.txt](requirements.txt)):
47
+ ```bash
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+
52
+
53
+ ### 🤗 Model Zoo
54
+ [NormalCrafter](https://huggingface.co/Yanrui95/NormalCrafter) is available in the Hugging Face Model Hub.
55
+
56
+ ### 🏃‍♂️ Inference
57
+ #### 1. High-resolution inference requires a GPU with ~20GB of memory at 1024x576 resolution:
58
+ ```bash
59
+ python run.py --video-path examples/example_01.mp4
60
+ ```
61
 
62
+ #### 2. Low-resolution inference requires a GPU with ~6GB of memory at 512x256 resolution:
63
+ ```bash
64
+ python run.py --video-path examples/example_01.mp4 --max-res 512
65
+ ```
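
For reference, the same inference can also be driven directly from Python. The sketch below condenses the model loading and pipeline call found in `run.py` and `app.py` of this commit; the example video path, output location, and the window/stride values are illustrative assumptions rather than required settings.

```python
# Minimal programmatic-inference sketch based on run.py/app.py in this commit.
# The paths and parameter values below are illustrative assumptions.
import os
import torch
from diffusers import AutoencoderKLTemporalDecoder

from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video

weight_dtype = torch.float16
unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
    "Yanrui95/NormalCrafter", subfolder="unet", low_cpu_mem_usage=True
).to(dtype=weight_dtype)
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "Yanrui95/NormalCrafter", subfolder="vae"
).to(dtype=weight_dtype)

# The remaining components (image encoder, scheduler, ...) come from the SVD checkpoint.
pipe = NormalCrafterPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    unet=unet,
    vae=vae,
    torch_dtype=weight_dtype,
    variant="fp16",
).to("cuda")

# Read frames (optionally capped in length/FPS/resolution), run the sliding-window
# pipeline, and save a visualization of the predicted normals.
frames, fps = read_video_frames("examples/example_01.mp4", process_length=-1, target_fps=-1, max_res=1024)
with torch.inference_mode():
    normals = pipe(frames, decode_chunk_size=7, time_step_size=10, window_size=14).frames[0]

os.makedirs("./demo_output", exist_ok=True)
save_video(vis_sequence_normal(normals), "./demo_output/example_01_vis.mp4", fps=fps)
```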
app.py ADDED
@@ -0,0 +1,210 @@
1
+ import gc
2
+ import os
3
+
4
+ import numpy as np
5
+ import spaces
6
+ import gradio as gr
7
+ import torch
8
+ from diffusers.training_utils import set_seed
9
+ from diffusers import AutoencoderKLTemporalDecoder
10
+
11
+ from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
12
+ from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
13
+
14
+ import uuid
15
+ import random
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video
19
+
20
+ examples = [
21
+ ["examples/example_01.mp4", 1024, -1, -1],
22
+ ["examples/example_02.mp4", 1024, -1, -1],
23
+ ["examples/example_03.mp4", 1024, -1, -1],
24
+ ["examples/example_04.mp4", 1024, -1, -1],
25
+ ["examples/example_05.mp4", 1024, -1, -1],
26
+ ["examples/example_06.mp4", 1024, -1, -1],
27
+ ]
28
+
29
+ pretrained_model_name_or_path = "Yanrui95/NormalCrafter"
30
+ weight_dtype = torch.float16
31
+ unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
32
+ pretrained_model_name_or_path,
33
+ subfolder="unet",
34
+ low_cpu_mem_usage=True,
35
+ )
36
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
37
+ pretrained_model_name_or_path, subfolder="vae")
38
+
39
+ vae.to(dtype=weight_dtype)
40
+ unet.to(dtype=weight_dtype)
41
+
42
+ pipe = NormalCrafterPipeline.from_pretrained(
43
+ "stabilityai/stable-video-diffusion-img2vid-xt",
44
+ unet=unet,
45
+ vae=vae,
46
+ torch_dtype=weight_dtype,
47
+ variant="fp16",
48
+ )
49
+ pipe.to("cuda")
50
+
51
+
52
+ @spaces.GPU(duration=120)
53
+ def infer_depth(
54
+ video: str,
55
+ max_res: int = 1024,
56
+ process_length: int = -1,
57
+ target_fps: int = -1,
58
+ #
59
+ save_folder: str = "./demo_output",
60
+ window_size: int = 14,
61
+ time_step_size: int = 10,
62
+ decode_chunk_size: int = 7,
63
+ seed: int = 42,
64
+ save_npz: bool = False,
65
+ ):
66
+ set_seed(seed)
67
+ pipe.enable_xformers_memory_efficient_attention()
68
+
69
+ frames, target_fps = read_video_frames(video, process_length, target_fps, max_res)
70
+
71
+ # run normal-map inference with the NormalCrafter pipeline
72
+ with torch.inference_mode():
73
+ res = pipe(
74
+ frames,
75
+ decode_chunk_size=decode_chunk_size,
76
+ time_step_size=time_step_size,
77
+ window_size=window_size,
78
+ ).frames[0]
79
+
80
+ # visualize the normal maps and save the results
81
+ vis = vis_sequence_normal(res)
82
+ # save the normal maps and the visualization at the target FPS
83
+ save_path = os.path.join(save_folder, os.path.splitext(os.path.basename(video))[0])
84
+ print(f"==> saving results to {save_path}")
85
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
86
+ if save_npz:
87
+ np.savez_compressed(save_path + ".npz", normal=res)
88
+ save_video(vis, save_path + "_vis.mp4", fps=target_fps)
89
+ save_video(frames, save_path + "_input.mp4", fps=target_fps)
90
+
91
+ # clear the cache for the next video
92
+ gc.collect()
93
+ torch.cuda.empty_cache()
94
+
95
+ return [
96
+ save_path + "_input.mp4",
97
+ save_path + "_vis.mp4",
98
+
99
+ ]
100
+
101
+
102
+ def construct_demo():
103
+ with gr.Blocks(analytics_enabled=False) as depthcrafter_iface:
104
+ gr.Markdown(
105
+ """
106
+ <div align='center'> <h1> NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors </h1> \
107
+ <a style='font-size:18px;color: #000000'>If you find NormalCrafter useful, please help ⭐ the </a>\
108
+ <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Binyr/NormalCrafter'>[Github Repo]</a>\
109
+ <a style='font-size:18px;color: #000000'>, which is important to Open-Source projects. Thanks!</a>\
110
+ <a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02095'> [ArXiv] </a>\
111
+ <a style='font-size:18px;color: #000000' href='https://normalcrafter.github.io/'> [Project Page] </a> </div>
112
+ """
113
+ )
114
+
115
+ with gr.Row(equal_height=True):
116
+ with gr.Column(scale=1):
117
+ input_video = gr.Video(label="Input Video")
118
+
119
+ # with gr.Tab(label="Output"):
120
+ with gr.Column(scale=2):
121
+ with gr.Row(equal_height=True):
122
+ output_video_1 = gr.Video(
123
+ label="Preprocessed video",
124
+ interactive=False,
125
+ autoplay=True,
126
+ loop=True,
127
+ show_share_button=True,
128
+ scale=5,
129
+ )
130
+ output_video_2 = gr.Video(
131
+ label="Generated Normal Video",
132
+ interactive=False,
133
+ autoplay=True,
134
+ loop=True,
135
+ show_share_button=True,
136
+ scale=5,
137
+ )
138
+
139
+ with gr.Row(equal_height=True):
140
+ with gr.Column(scale=1):
141
+ with gr.Row(equal_height=False):
142
+ with gr.Accordion("Advanced Settings", open=False):
143
+ max_res = gr.Slider(
144
+ label="max resolution",
145
+ minimum=512,
146
+ maximum=1024,
147
+ value=1024,
148
+ step=64,
149
+ )
150
+ process_length = gr.Slider(
151
+ label="process length",
152
+ minimum=-1,
153
+ maximum=280,
154
+ value=60,
155
+ step=1,
156
+ )
157
+ process_target_fps = gr.Slider(
158
+ label="target FPS",
159
+ minimum=-1,
160
+ maximum=30,
161
+ value=15,
162
+ step=1,
163
+ )
164
+ generate_btn = gr.Button("Generate")
165
+ with gr.Column(scale=2):
166
+ pass
167
+
168
+ gr.Examples(
169
+ examples=examples,
170
+ inputs=[
171
+ input_video,
172
+ max_res,
173
+ process_length,
174
+ process_target_fps,
175
+ ],
176
+ outputs=[output_video_1, output_video_2],
177
+ fn=infer_depth,
178
+ cache_examples="lazy",
179
+ )
180
+ # gr.Markdown(
181
+ # """
182
+ # <span style='font-size:18px;color: #E7CCCC'>Note:
183
+ # For time quota consideration, we set the default parameters to be more efficient here,
184
+ # with a trade-off of shorter video length and slightly lower quality.
185
+ # You may adjust the parameters according to our
186
+ # <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Tencent/DepthCrafter'>[Github Repo]</a>
187
+ # for better results if you have enough time quota.
188
+ # </span>
189
+ # """
190
+ # )
191
+
192
+ generate_btn.click(
193
+ fn=infer_depth,
194
+ inputs=[
195
+ input_video,
196
+ max_res,
197
+ process_length,
198
+ process_target_fps,
199
+ ],
200
+ outputs=[output_video_1, output_video_2],
201
+ )
202
+
203
+ return depthcrafter_iface
204
+
205
+
206
+ if __name__ == "__main__":
207
+ demo = construct_demo()
208
+ demo.queue()
209
+ # demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
210
+ demo.launch(share=True)
examples/example_01.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3eb7fefd157bd9b403cf0b524c7c4f3cb6d9f82b9d6a48eba2146412fc9e64a2
3
+ size 5727137
examples/example_02.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ea3c4e4c8cd9682d92c25170d8df333fead210118802fbe22198dde478dc5489
3
+ size 3150525
examples/example_03.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d332877a98bb41ff86a639139a03e383e91880bca722bba7e2518878fca54f6
3
+ size 3013435
examples/example_04.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2aa4962216adce71b1c47f395be435b23105df35f3892646e237b935ac1c74f
3
+ size 3591374
examples/example_05.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8d2319060f9a1d3cfcb9de317e4a5b138657fd741c530ed3983f6565c2eda44
3
+ size 3553683
examples/example_06.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3a2619b029129f34884c761cc278b6842620bfed96d4bb52c8aa07bc1d82a8b
3
+ size 5596872
normalcrafter/__init__.py ADDED
File without changes
normalcrafter/normal_crafter_ppl.py ADDED
@@ -0,0 +1,494 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ import math
10
+
11
+ from diffusers.utils import BaseOutput, logging
12
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
13
+ from diffusers import DiffusionPipeline
14
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import StableVideoDiffusionPipelineOutput, StableVideoDiffusionPipeline
15
+ from PIL import Image
16
+ import cv2
17
+
18
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+ class NormalCrafterPipeline(StableVideoDiffusionPipeline):
21
+
22
+ def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, scale=1, image_size=None):
23
+ dtype = next(self.image_encoder.parameters()).dtype
24
+
25
+ if not isinstance(image, torch.Tensor):
26
+ image = self.video_processor.pil_to_numpy(image) # (0, 255) -> (0, 1)
27
+ image = self.video_processor.numpy_to_pt(image) # (n, h, w, c) -> (n, c, h, w)
28
+
29
+ # We normalize the image before resizing to match with the original implementation.
30
+ # Then we unnormalize it after resizing.
31
+ pixel_values = image
32
+ B, C, H, W = pixel_values.shape
33
+ patches = [pixel_values]
34
+ # patches = []
35
+ for i in range(1, scale):
36
+ num_patches_HW_this_level = i + 1
37
+ patch_H = H // num_patches_HW_this_level + 1
38
+ patch_W = W // num_patches_HW_this_level + 1
39
+ for j in range(num_patches_HW_this_level):
40
+ for k in range(num_patches_HW_this_level):
41
+ patches.append(pixel_values[:, :, j*patch_H:(j+1)*patch_H, k*patch_W:(k+1)*patch_W])
42
+
43
+ def encode_image(image):
44
+ image = image * 2.0 - 1.0
45
+ if image_size is not None:
46
+ image = _resize_with_antialiasing(image, image_size)
47
+ else:
48
+ image = _resize_with_antialiasing(image, (224, 224))
49
+ image = (image + 1.0) / 2.0
50
+
51
+ # Normalize the image for CLIP input
52
+ image = self.feature_extractor(
53
+ images=image,
54
+ do_normalize=True,
55
+ do_center_crop=False,
56
+ do_resize=False,
57
+ do_rescale=False,
58
+ return_tensors="pt",
59
+ ).pixel_values
60
+
61
+ image = image.to(device=device, dtype=dtype)
62
+ image_embeddings = self.image_encoder(image).image_embeds
63
+ if len(image_embeddings.shape) < 3:
64
+ image_embeddings = image_embeddings.unsqueeze(1)
65
+ return image_embeddings
66
+
67
+ image_embeddings = []
68
+ for patch in patches:
69
+ image_embeddings.append(encode_image(patch))
70
+ image_embeddings = torch.cat(image_embeddings, dim=1)
71
+
72
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
73
+ # import pdb
74
+ # pdb.set_trace()
75
+ bs_embed, seq_len, _ = image_embeddings.shape
76
+ image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
77
+ image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
78
+
79
+ if do_classifier_free_guidance:
80
+ negative_image_embeddings = torch.zeros_like(image_embeddings)
81
+
82
+ # For classifier free guidance, we need to do two forward passes.
83
+ # Here we concatenate the unconditional and text embeddings into a single batch
84
+ # to avoid doing two forward passes
85
+ image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
86
+
87
+ return image_embeddings
88
+
89
+ def ecnode_video_vae(self, images, chunk_size: int = 14):
90
+ if isinstance(images, list):
91
+ width, height = images[0].size
92
+ else:
93
+ height, width = images[0].shape[:2]
94
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
95
+ if needs_upcasting:
96
+ self.vae.to(dtype=torch.float32)
97
+
98
+ device = self._execution_device
99
+ images = self.video_processor.preprocess_video(images, height=height, width=width).to(device, self.vae.dtype) # torch type in range(-1, 1) with (1,3,h,w)
100
+ images = images.squeeze(0) # from (1, c, t, h, w) -> (c, t, h, w)
101
+ images = images.permute(1,0,2,3) # c, t, h, w -> (t, c, h, w)
102
+
103
+ video_latents = []
104
+ # chunk_size = 14
105
+ for i in range(0, images.shape[0], chunk_size):
106
+ video_latents.append(self.vae.encode(images[i : i + chunk_size]).latent_dist.mode())
107
+ image_latents = torch.cat(video_latents)
108
+
109
+ # cast back to fp16 if needed
110
+ if needs_upcasting:
111
+ self.vae.to(dtype=torch.float16)
112
+
113
+ return image_latents
114
+
115
+ def pad_image(self, images, scale=64):
116
+ def get_pad(newW, W):
117
+ pad_W = (newW - W) // 2
118
+ if W % 2 == 1:
119
+ pad_Ws = [pad_W, pad_W + 1]
120
+ else:
121
+ pad_Ws = [pad_W, pad_W]
122
+ return pad_Ws
123
+
124
+ if type(images[0]) is np.ndarray:
125
+ H, W = images[0].shape[:2]
126
+ else:
127
+ W, H = images[0].size
128
+
129
+ if W % scale == 0 and H % scale == 0:
130
+ return images, None
131
+ newW = int(np.ceil(W / scale) * scale)
132
+ newH = int(np.ceil(H / scale) * scale)
133
+
134
+ pad_Ws = get_pad(newW, W)
135
+ pad_Hs = get_pad(newH, H)
136
+
137
+ new_images = []
138
+ for image in images:
139
+ if type(image) is np.ndarray:
140
+ image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(1.,1.,1.))
141
+ new_images.append(image)
142
+ else:
143
+ image = np.array(image)
144
+ image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(255,255,255))
145
+ new_images.append(Image.fromarray(image))
146
+ return new_images, pad_Hs+pad_Ws
147
+
148
+ def unpad_image(self, v, pad_HWs):
149
+ t, b, l, r = pad_HWs
150
+ if t > 0 or b > 0:
151
+ v = v[:, :, t:-b]
152
+ if l > 0 or r > 0:
153
+ v = v[:, :, :, l:-r]
154
+ return v
155
+
156
+ @torch.no_grad()
157
+ def __call__(
158
+ self,
159
+ images: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
160
+ decode_chunk_size: Optional[int] = None,
161
+ time_step_size: Optional[int] = 1,
162
+ window_size: Optional[int] = 1,
163
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
164
+ return_dict: bool = True
165
+ ):
166
+ images, pad_HWs = self.pad_image(images)
167
+
168
+ # 0. Default height and width to unet
169
+ width, height = images[0].size
170
+ num_frames = len(images)
171
+
172
+ # 1. Check inputs. Raise error if not correct
173
+ self.check_inputs(images, height, width)
174
+
175
+ # 2. Define call parameters
176
+ batch_size = 1
177
+ device = self._execution_device
178
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
179
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
180
+ # corresponds to doing no classifier free guidance.
181
+ self._guidance_scale = 1.0
182
+ num_videos_per_prompt = 1
183
+ do_classifier_free_guidance = False
184
+ num_inference_steps = 1
185
+ fps = 7
186
+ motion_bucket_id = 127
187
+ noise_aug_strength = 0.
188
+ num_videos_per_prompt = 1
189
+ output_type = "np"
190
+ data_keys = ["normal"]
191
+ use_linear_merge = True
192
+ determineTrain = True
193
+ encode_image_scale = 1
194
+ encode_image_WH = None
195
+
196
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 7
197
+
198
+ # 3. Encode the input images using CLIP. (num_image * num_videos_per_prompt, 1, 1024)
199
+ image_embeddings = self._encode_image(images, device, num_videos_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, scale=encode_image_scale, image_size=encode_image_WH)
200
+ # 4. Encode input image using VAE
201
+ image_latents = self.ecnode_video_vae(images, chunk_size=decode_chunk_size).to(image_embeddings.dtype)
202
+
203
+ # image_latents [num_frames, channels, height, width] ->[1, num_frames, channels, height, width]
204
+ image_latents = image_latents.unsqueeze(0)
205
+
206
+ # 5. Get Added Time IDs
207
+ added_time_ids = self._get_add_time_ids(
208
+ fps,
209
+ motion_bucket_id,
210
+ noise_aug_strength,
211
+ image_embeddings.dtype,
212
+ batch_size,
213
+ num_videos_per_prompt,
214
+ do_classifier_free_guidance,
215
+ )
216
+ added_time_ids = added_time_ids.to(device)
217
+
218
+ # get Start and End frame idx for each window
219
+ def get_ses(num_frames):
220
+ ses = []
221
+ for i in range(0, num_frames, time_step_size):
222
+ ses.append([i, i+window_size])
223
+ num_to_remain = 0
224
+ for se in ses:
225
+ if se[1] > num_frames:
226
+ continue
227
+ num_to_remain += 1
228
+ ses = ses[:num_to_remain]
229
+
230
+ if ses[-1][-1] < num_frames:
231
+ ses.append([num_frames - window_size, num_frames])
232
+ return ses
233
+ ses = get_ses(num_frames)
234
+
235
+ pred = None
236
+ for i, se in enumerate(ses):
237
+ window_num_frames = window_size
238
+ window_image_embeddings = image_embeddings[se[0]:se[1]]
239
+ window_image_latents = image_latents[:, se[0]:se[1]]
240
+ window_added_time_ids = added_time_ids
241
+ # import pdb
242
+ # pdb.set_trace()
243
+ if i == 0 or time_step_size == window_size:
244
+ to_replace_latents = None
245
+ else:
246
+ last_se = ses[i-1]
247
+ num_to_replace_latents = last_se[1] - se[0]
248
+ to_replace_latents = pred[:, -num_to_replace_latents:]
249
+
250
+ latents = self.generate(
251
+ num_inference_steps,
252
+ device,
253
+ batch_size,
254
+ num_videos_per_prompt,
255
+ window_num_frames,
256
+ height,
257
+ width,
258
+ window_image_embeddings,
259
+ generator,
260
+ determineTrain,
261
+ to_replace_latents,
262
+ do_classifier_free_guidance,
263
+ window_image_latents,
264
+ window_added_time_ids
265
+ )
266
+
267
+ # merge last_latents and current latents in overlap window
268
+ if to_replace_latents is not None and use_linear_merge:
269
+ num_img_condition = to_replace_latents.shape[1]
270
+ weight = torch.linspace(1., 0., num_img_condition+2)[1:-1].to(device)
271
+ weight = weight[None, :, None, None, None]
272
+ latents[:, :num_img_condition] = to_replace_latents * weight + latents[:, :num_img_condition] * (1 - weight)
273
+
274
+ if pred is None:
275
+ pred = latents
276
+ else:
277
+ pred = torch.cat([pred[:, :se[0]], latents], dim=1)
278
+
279
+ if not output_type == "latent":
280
+ # cast back to fp16 if needed
281
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
282
+ if needs_upcasting:
283
+ self.vae.to(dtype=torch.float16)
284
+ # latents has shape (1, num_frames, 12, h, w)
285
+
286
+ def decode_latents(latents, num_frames, decode_chunk_size):
287
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size) # in range(-1, 1)
288
+ frames = self.video_processor.postprocess_video(video=frames, output_type="np")
289
+ frames = frames * 2 - 1 # from range(0, 1) -> range(-1, 1)
290
+ return frames
291
+
292
+ frames = decode_latents(pred, num_frames, decode_chunk_size)
293
+ if pad_HWs is not None:
294
+ frames = self.unpad_image(frames, pad_HWs)
295
+ else:
296
+ frames = pred
297
+
298
+ self.maybe_free_model_hooks()
299
+
300
+ if not return_dict:
301
+ return frames
302
+
303
+ return StableVideoDiffusionPipelineOutput(frames=frames)
304
+
305
+
306
+ def generate(
307
+ self,
308
+ num_inference_steps,
309
+ device,
310
+ batch_size,
311
+ num_videos_per_prompt,
312
+ num_frames,
313
+ height,
314
+ width,
315
+ image_embeddings,
316
+ generator,
317
+ determineTrain,
318
+ to_replace_latents,
319
+ do_classifier_free_guidance,
320
+ image_latents,
321
+ added_time_ids,
322
+ latents=None,
323
+ ):
324
+ # 6. Prepare timesteps
325
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
326
+ timesteps = self.scheduler.timesteps
327
+
328
+ # 7. Prepare latent variables
329
+ num_channels_latents = self.unet.config.in_channels
330
+ latents = self.prepare_latents(
331
+ batch_size * num_videos_per_prompt,
332
+ num_frames,
333
+ num_channels_latents,
334
+ height,
335
+ width,
336
+ image_embeddings.dtype,
337
+ device,
338
+ generator,
339
+ latents,
340
+ )
341
+ if determineTrain:
342
+ latents[...] = 0.
343
+
344
+ # 8. Denoising loop
345
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
346
+ self._num_timesteps = len(timesteps)
347
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
348
+ for i, t in enumerate(timesteps):
349
+ # replace part of the latents with conditions. TODO: the t embedding should also be replaced
350
+ if to_replace_latents is not None:
351
+ num_img_condition = to_replace_latents.shape[1]
352
+ if not determineTrain:
353
+ _noise = randn_tensor(to_replace_latents.shape, generator=generator, device=device, dtype=image_embeddings.dtype)
354
+ noisy_to_replace_latents = self.scheduler.add_noise(to_replace_latents, _noise, t.unsqueeze(0))
355
+ latents[:, :num_img_condition] = noisy_to_replace_latents
356
+ else:
357
+ latents[:, :num_img_condition] = to_replace_latents
358
+
359
+
360
+ # expand the latents if we are doing classifier free guidance
361
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
362
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
363
+ timestep = t
364
+ # Concatenate image_latents over the channels dimension
365
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
366
+ # predict the noise residual
367
+ noise_pred = self.unet(
368
+ latent_model_input,
369
+ timestep,
370
+ encoder_hidden_states=image_embeddings,
371
+ added_time_ids=added_time_ids,
372
+ return_dict=False,
373
+ )[0]
374
+
375
+ # perform guidance
376
+ if do_classifier_free_guidance:
377
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
378
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
379
+
380
+ # compute the previous noisy sample x_t -> x_t-1
381
+ scheduler_output = self.scheduler.step(noise_pred, t, latents)
382
+ latents = scheduler_output.prev_sample
383
+
384
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
385
+ progress_bar.update()
386
+
387
+ return latents
388
+ # resizing utils
389
+ # TODO: clean up later
390
+ def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
391
+ h, w = input.shape[-2:]
392
+ factors = (h / size[0], w / size[1])
393
+
394
+ # First, we have to determine sigma
395
+ # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
396
+ sigmas = (
397
+ max((factors[0] - 1.0) / 2.0, 0.001),
398
+ max((factors[1] - 1.0) / 2.0, 0.001),
399
+ )
400
+
401
+ # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
402
+ # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
403
+ # But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
404
+ ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
405
+
406
+ # Make sure it is odd
407
+ if (ks[0] % 2) == 0:
408
+ ks = ks[0] + 1, ks[1]
409
+
410
+ if (ks[1] % 2) == 0:
411
+ ks = ks[0], ks[1] + 1
412
+
413
+ input = _gaussian_blur2d(input, ks, sigmas)
414
+
415
+ output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
416
+ return output
417
+
418
+
419
+ def _compute_padding(kernel_size):
420
+ """Compute padding tuple."""
421
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
422
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
423
+ if len(kernel_size) < 2:
424
+ raise AssertionError(kernel_size)
425
+ computed = [k - 1 for k in kernel_size]
426
+
427
+ # for even kernels we need to do asymmetric padding :(
428
+ out_padding = 2 * len(kernel_size) * [0]
429
+
430
+ for i in range(len(kernel_size)):
431
+ computed_tmp = computed[-(i + 1)]
432
+
433
+ pad_front = computed_tmp // 2
434
+ pad_rear = computed_tmp - pad_front
435
+
436
+ out_padding[2 * i + 0] = pad_front
437
+ out_padding[2 * i + 1] = pad_rear
438
+
439
+ return out_padding
440
+
441
+
442
+ def _filter2d(input, kernel):
443
+ # prepare kernel
444
+ b, c, h, w = input.shape
445
+ tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
446
+
447
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
448
+
449
+ height, width = tmp_kernel.shape[-2:]
450
+
451
+ padding_shape: list[int] = _compute_padding([height, width])
452
+ input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
453
+
454
+ # kernel and input tensor reshape to align element-wise or batch-wise params
455
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
456
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
457
+
458
+ # convolve the tensor with the kernel.
459
+ output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
460
+
461
+ out = output.view(b, c, h, w)
462
+ return out
463
+
464
+
465
+ def _gaussian(window_size: int, sigma):
466
+ if isinstance(sigma, float):
467
+ sigma = torch.tensor([[sigma]])
468
+
469
+ batch_size = sigma.shape[0]
470
+
471
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
472
+
473
+ if window_size % 2 == 0:
474
+ x = x + 0.5
475
+
476
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
477
+
478
+ return gauss / gauss.sum(-1, keepdim=True)
479
+
480
+
481
+ def _gaussian_blur2d(input, kernel_size, sigma):
482
+ if isinstance(sigma, tuple):
483
+ sigma = torch.tensor([sigma], dtype=input.dtype)
484
+ else:
485
+ sigma = sigma.to(dtype=input.dtype)
486
+
487
+ ky, kx = int(kernel_size[0]), int(kernel_size[1])
488
+ bs = sigma.shape[0]
489
+ kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
490
+ kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
491
+ out_x = _filter2d(input, kernel_x[..., None, :])
492
+ out = _filter2d(out_x, kernel_y[..., None])
493
+
494
+ return out
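
The part of `NormalCrafterPipeline.__call__` above that is easiest to misread is the sliding-window scheduling (`get_ses`) and the linear blending of overlapping windows. The self-contained sketch below reproduces both for an example clip of 30 frames with the default `window_size=14` and `time_step_size=10`; the frame count is an illustrative assumption.

```python
# Standalone illustration of the window scheduling and overlap blending used in
# NormalCrafterPipeline.__call__ above (30 frames is an arbitrary example value).
import torch

def get_ses(num_frames: int, window_size: int, time_step_size: int):
    # Start/end indices of each processing window, stepping by time_step_size;
    # windows running past the clip are dropped and a final window is appended
    # so that the last frames are always covered.
    ses = [[i, i + window_size] for i in range(0, num_frames, time_step_size)]
    ses = [se for se in ses if se[1] <= num_frames]
    if ses[-1][-1] < num_frames:
        ses.append([num_frames - window_size, num_frames])
    return ses

print(get_ses(30, window_size=14, time_step_size=10))
# [[0, 14], [10, 24], [16, 30]] -> regular windows overlap by
# window_size - time_step_size = 4 frames (the tail window may overlap more).

# Inside the overlap, the previous window's latents and the new ones are blended
# with linearly decaying weights, exactly as in the pipeline:
overlap = 4
weight = torch.linspace(1.0, 0.0, overlap + 2)[1:-1]
print(weight)  # tensor([0.8000, 0.6000, 0.4000, 0.2000])
# merged = previous * weight + current * (1 - weight), per overlapping frame
```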
normalcrafter/unet.py ADDED
@@ -0,0 +1,368 @@
1
+ from diffusers import UNetSpatioTemporalConditionModel
2
+ from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
3
+ from diffusers.utils import is_torch_version
4
+ import torch
5
+ from typing import Any, Dict, Optional, Tuple, Union
6
+
7
+ def create_custom_forward(module, return_dict=None):
8
+ def custom_forward(*inputs):
9
+ if return_dict is not None:
10
+ return module(*inputs, return_dict=return_dict)
11
+ else:
12
+ return module(*inputs)
13
+
14
+ return custom_forward
15
+ CKPT_KWARGS = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
16
+
17
+
18
+ class DiffusersUNetSpatioTemporalConditionModelNormalCrafter(UNetSpatioTemporalConditionModel):
19
+
20
+ @staticmethod
21
+ def forward_crossattn_down_block_dino(
22
+ module,
23
+ hidden_states: torch.Tensor,
24
+ temb: Optional[torch.Tensor] = None,
25
+ encoder_hidden_states: Optional[torch.Tensor] = None,
26
+ image_only_indicator: Optional[torch.Tensor] = None,
27
+ dino_down_block_res_samples = None,
28
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
29
+ output_states = ()
30
+ self = module
31
+ blocks = list(zip(self.resnets, self.attentions))
32
+ for resnet, attn in blocks:
33
+ if self.training and self.gradient_checkpointing: # TODO
34
+ hidden_states = torch.utils.checkpoint.checkpoint(
35
+ create_custom_forward(resnet),
36
+ hidden_states,
37
+ temb,
38
+ image_only_indicator,
39
+ **CKPT_KWARGS,
40
+ )
41
+
42
+ hidden_states = torch.utils.checkpoint.checkpoint(
43
+ create_custom_forward(attn),
44
+ hidden_states,
45
+ encoder_hidden_states,
46
+ image_only_indicator,
47
+ False,
48
+ **CKPT_KWARGS,
49
+ )[0]
50
+ else:
51
+ hidden_states = resnet(
52
+ hidden_states,
53
+ temb,
54
+ image_only_indicator=image_only_indicator,
55
+ )
56
+ hidden_states = attn(
57
+ hidden_states,
58
+ encoder_hidden_states=encoder_hidden_states,
59
+ image_only_indicator=image_only_indicator,
60
+ return_dict=False,
61
+ )[0]
62
+
63
+ if dino_down_block_res_samples is not None:
64
+ hidden_states += dino_down_block_res_samples.pop(0)
65
+
66
+ output_states = output_states + (hidden_states,)
67
+
68
+ if self.downsamplers is not None:
69
+ for downsampler in self.downsamplers:
70
+ hidden_states = downsampler(hidden_states)
71
+ if dino_down_block_res_samples is not None:
72
+ hidden_states += dino_down_block_res_samples.pop(0)
73
+
74
+ output_states = output_states + (hidden_states,)
75
+
76
+ return hidden_states, output_states
77
+ @staticmethod
78
+ def forward_down_block_dino(
79
+ module,
80
+ hidden_states: torch.Tensor,
81
+ temb: Optional[torch.Tensor] = None,
82
+ image_only_indicator: Optional[torch.Tensor] = None,
83
+ dino_down_block_res_samples = None,
84
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
85
+ self = module
86
+ output_states = ()
87
+ for resnet in self.resnets:
88
+ if self.training and self.gradient_checkpointing:
89
+ if is_torch_version(">=", "1.11.0"):
90
+ hidden_states = torch.utils.checkpoint.checkpoint(
91
+ create_custom_forward(resnet),
92
+ hidden_states,
93
+ temb,
94
+ image_only_indicator,
95
+ use_reentrant=False,
96
+ )
97
+ else:
98
+ hidden_states = torch.utils.checkpoint.checkpoint(
99
+ create_custom_forward(resnet),
100
+ hidden_states,
101
+ temb,
102
+ image_only_indicator,
103
+ )
104
+ else:
105
+ hidden_states = resnet(
106
+ hidden_states,
107
+ temb,
108
+ image_only_indicator=image_only_indicator,
109
+ )
110
+ if dino_down_block_res_samples is not None:
111
+ hidden_states += dino_down_block_res_samples.pop(0)
112
+ output_states = output_states + (hidden_states,)
113
+
114
+ if self.downsamplers is not None:
115
+ for downsampler in self.downsamplers:
116
+ hidden_states = downsampler(hidden_states)
117
+ if dino_down_block_res_samples is not None:
118
+ hidden_states += dino_down_block_res_samples.pop(0)
119
+ output_states = output_states + (hidden_states,)
120
+
121
+ return hidden_states, output_states
122
+
123
+
124
+ def forward(
125
+ self,
126
+ sample: torch.FloatTensor,
127
+ timestep: Union[torch.Tensor, float, int],
128
+ encoder_hidden_states: torch.Tensor,
129
+ added_time_ids: torch.Tensor,
130
+ return_dict: bool = True,
131
+ image_controlnet_down_block_res_samples = None,
132
+ image_controlnet_mid_block_res_sample = None,
133
+ dino_down_block_res_samples = None,
134
+
135
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
136
+ r"""
137
+ The [`UNetSpatioTemporalConditionModel`] forward method.
138
+
139
+ Args:
140
+ sample (`torch.FloatTensor`):
141
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
142
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
143
+ encoder_hidden_states (`torch.FloatTensor`):
144
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
145
+ added_time_ids: (`torch.FloatTensor`):
146
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
147
+ embeddings and added to the time embeddings.
148
+ return_dict (`bool`, *optional*, defaults to `True`):
149
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
150
+ tuple.
151
+ Returns:
152
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
153
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
154
+ a `tuple` is returned where the first element is the sample tensor.
155
+ """
156
+ if not hasattr(self, "custom_gradient_checkpointing"):
157
+ self.custom_gradient_checkpointing = False
158
+
159
+ # 1. time
160
+ timesteps = timestep
161
+ if not torch.is_tensor(timesteps):
162
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
163
+ # This would be a good case for the `match` statement (Python 3.10+)
164
+ is_mps = sample.device.type == "mps"
165
+ if isinstance(timestep, float):
166
+ dtype = torch.float32 if is_mps else torch.float64
167
+ else:
168
+ dtype = torch.int32 if is_mps else torch.int64
169
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
170
+ elif len(timesteps.shape) == 0:
171
+ timesteps = timesteps[None].to(sample.device)
172
+
173
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
174
+ batch_size, num_frames = sample.shape[:2]
175
+ if len(timesteps.shape) == 1:
176
+ timesteps = timesteps.expand(batch_size)
177
+ else:
178
+ timesteps = timesteps.reshape(batch_size * num_frames)
179
+ t_emb = self.time_proj(timesteps) # (B, C)
180
+
181
+ # `Timesteps` does not contain any weights and will always return f32 tensors
182
+ # but time_embedding might actually be running in fp16. so we need to cast here.
183
+ # there might be better ways to encapsulate this.
184
+ t_emb = t_emb.to(dtype=sample.dtype)
185
+
186
+ emb = self.time_embedding(t_emb) # (B, C)
187
+
188
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
189
+ time_embeds = time_embeds.reshape((batch_size, -1))
190
+ time_embeds = time_embeds.to(emb.dtype)
191
+ aug_emb = self.add_embedding(time_embeds)
192
+ if emb.shape[0] == 1:
193
+ emb = emb + aug_emb
194
+ # Repeat the embeddings num_video_frames times
195
+ # emb: [batch, channels] -> [batch * frames, channels]
196
+ emb = emb.repeat_interleave(num_frames, dim=0)
197
+ else:
198
+ aug_emb = aug_emb.repeat_interleave(num_frames, dim=0)
199
+ emb = emb + aug_emb
200
+
201
+ # Flatten the batch and frames dimensions
202
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
203
+ sample = sample.flatten(0, 1)
204
+
205
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
206
+ # here, our encoder_hidden_states is [batch * frames, 1, channels]
207
+
208
+ if not sample.shape[0] == encoder_hidden_states.shape[0]:
209
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
210
+ # 2. pre-process
211
+ sample = self.conv_in(sample)
212
+
213
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
214
+
215
+ if dino_down_block_res_samples is not None:
216
+ dino_down_block_res_samples = [x for x in dino_down_block_res_samples]
217
+ sample += dino_down_block_res_samples.pop(0)
218
+
219
+ down_block_res_samples = (sample,)
220
+ for downsample_block in self.down_blocks:
221
+ if dino_down_block_res_samples is None:
222
+ if self.custom_gradient_checkpointing:
223
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
224
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
225
+ create_custom_forward(downsample_block),
226
+ sample,
227
+ emb,
228
+ encoder_hidden_states,
229
+ image_only_indicator,
230
+ **CKPT_KWARGS,
231
+ )
232
+ else:
233
+ sample, res_samples = torch.utils.checkpoint.checkpoint(
234
+ create_custom_forward(downsample_block),
235
+ sample,
236
+ emb,
237
+ image_only_indicator,
238
+ **CKPT_KWARGS,
239
+ )
240
+ else:
241
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
242
+ sample, res_samples = downsample_block(
243
+ hidden_states=sample,
244
+ temb=emb,
245
+ encoder_hidden_states=encoder_hidden_states,
246
+ image_only_indicator=image_only_indicator,
247
+ )
248
+ else:
249
+ sample, res_samples = downsample_block(
250
+ hidden_states=sample,
251
+ temb=emb,
252
+ image_only_indicator=image_only_indicator,
253
+ )
254
+ else:
255
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
256
+ sample, res_samples = self.forward_crossattn_down_block_dino(
257
+ downsample_block,
258
+ sample,
259
+ emb,
260
+ encoder_hidden_states,
261
+ image_only_indicator,
262
+ dino_down_block_res_samples,
263
+ )
264
+ else:
265
+ sample, res_samples = self.forward_down_block_dino(
266
+ downsample_block,
267
+ sample,
268
+ emb,
269
+ image_only_indicator,
270
+ dino_down_block_res_samples,
271
+ )
272
+ down_block_res_samples += res_samples
273
+
274
+ if image_controlnet_down_block_res_samples is not None:
275
+ new_down_block_res_samples = ()
276
+
277
+ for down_block_res_sample, image_controlnet_down_block_res_sample in zip(
278
+ down_block_res_samples, image_controlnet_down_block_res_samples
279
+ ):
280
+ down_block_res_sample = (down_block_res_sample + image_controlnet_down_block_res_sample) / 2
281
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
282
+
283
+ down_block_res_samples = new_down_block_res_samples
284
+
285
+ # 4. mid
286
+ if self.custom_gradient_checkpointing:
287
+ sample = torch.utils.checkpoint.checkpoint(
288
+ create_custom_forward(self.mid_block),
289
+ sample,
290
+ emb,
291
+ encoder_hidden_states,
292
+ image_only_indicator,
293
+ **CKPT_KWARGS,
294
+ )
295
+ else:
296
+ sample = self.mid_block(
297
+ hidden_states=sample,
298
+ temb=emb,
299
+ encoder_hidden_states=encoder_hidden_states,
300
+ image_only_indicator=image_only_indicator,
301
+ )
302
+
303
+ if image_controlnet_mid_block_res_sample is not None:
304
+ sample = (sample + image_controlnet_mid_block_res_sample) / 2
305
+
306
+ # 5. up
307
+ mid_up_block_out_samples = [sample, ]
308
+ down_block_out_sampels = []
309
+ for i, upsample_block in enumerate(self.up_blocks):
310
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
311
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
312
+ down_block_out_sampels.append(res_samples[-1])
313
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
314
+ if self.custom_gradient_checkpointing:
315
+ sample = torch.utils.checkpoint.checkpoint(
316
+ create_custom_forward(upsample_block),
317
+ sample,
318
+ res_samples,
319
+ emb,
320
+ encoder_hidden_states,
321
+ image_only_indicator,
322
+ **CKPT_KWARGS
323
+ )
324
+ else:
325
+ sample = upsample_block(
326
+ hidden_states=sample,
327
+ temb=emb,
328
+ res_hidden_states_tuple=res_samples,
329
+ encoder_hidden_states=encoder_hidden_states,
330
+ image_only_indicator=image_only_indicator,
331
+ )
332
+ else:
333
+ if self.custom_gradient_checkpointing:
334
+ sample = torch.utils.checkpoint.checkpoint(
335
+ create_custom_forward(upsample_block),
336
+ sample,
337
+ res_samples,
338
+ emb,
339
+ image_only_indicator,
340
+ **CKPT_KWARGS
341
+ )
342
+ else:
343
+ sample = upsample_block(
344
+ hidden_states=sample,
345
+ temb=emb,
346
+ res_hidden_states_tuple=res_samples,
347
+ image_only_indicator=image_only_indicator,
348
+ )
349
+ mid_up_block_out_samples.append(sample)
350
+ # 6. post-process
351
+ sample = self.conv_norm_out(sample)
352
+ sample = self.conv_act(sample)
353
+ if self.custom_gradient_checkpointing:
354
+ sample = torch.utils.checkpoint.checkpoint(
355
+ create_custom_forward(self.conv_out),
356
+ sample,
357
+ **CKPT_KWARGS
358
+ )
359
+ else:
360
+ sample = self.conv_out(sample)
361
+
362
+ # 7. Reshape back to original shape
363
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
364
+
365
+ if not return_dict:
366
+ return (sample, down_block_out_sampels[::-1], mid_up_block_out_samples)
367
+
368
+ return UNetSpatioTemporalConditionOutput(sample=sample)
normalcrafter/utils.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Union, List
2
+ import tempfile
3
+ import numpy as np
4
+ import PIL.Image
5
+ import matplotlib.cm as cm
6
+ import mediapy
7
+ import torch
8
+ from decord import VideoReader, cpu
9
+
10
+
11
+ def read_video_frames(video_path, process_length, target_fps, max_res):
12
+ print("==> processing video: ", video_path)
13
+ vid = VideoReader(video_path, ctx=cpu(0))
14
+ print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
15
+ original_height, original_width = vid.get_batch([0]).shape[1:3]
16
+
17
+ if max(original_height, original_width) > max_res:
18
+ scale = max_res / max(original_height, original_width)
19
+ height = round(original_height * scale)
20
+ width = round(original_width * scale)
21
+ else:
22
+ height = original_height
23
+ width = original_width
24
+
25
+ vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
26
+
27
+ fps = vid.get_avg_fps() if target_fps == -1 else target_fps
28
+ stride = round(vid.get_avg_fps() / fps)
29
+ stride = max(stride, 1)
30
+ frames_idx = list(range(0, len(vid), stride))
31
+ print(
32
+ f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
33
+ )
34
+ if process_length != -1 and process_length < len(frames_idx):
35
+ frames_idx = frames_idx[:process_length]
36
+ print(
37
+ f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
38
+ )
39
+ frames = vid.get_batch(frames_idx).asnumpy().astype(np.uint8)
40
+ frames = [PIL.Image.fromarray(x) for x in frames]
41
+
42
+ return frames, fps
43
+
44
+ def save_video(
45
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
46
+ output_video_path: str = None,
47
+ fps: int = 10,
48
+ crf: int = 18,
49
+ ) -> str:
50
+ if output_video_path is None:
51
+ output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
52
+
53
+ if isinstance(video_frames[0], np.ndarray):
54
+ video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
55
+
56
+ elif isinstance(video_frames[0], PIL.Image.Image):
57
+ video_frames = [np.array(frame) for frame in video_frames]
58
+ mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
59
+ return output_video_path
60
+
61
+ def vis_sequence_normal(normals: np.ndarray):
62
+ normals = normals.clip(-1., 1.)
63
+ normals = normals * 0.5 + 0.5
64
+ return normals
requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch==2.0.1
2
+ diffusers==0.29.1
3
+ numpy==1.26.4
4
+ matplotlib==3.8.4
5
+ transformers==4.41.2
6
+ accelerate==0.30.1
7
+ xformers==0.0.20
8
+ mediapy==1.2.0
9
+ fire==0.6.0
10
+ decord==0.6.0
11
+ OpenEXR==3.2.4
run.py ADDED
@@ -0,0 +1,174 @@
1
+ import gc
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+
6
+ from diffusers.training_utils import set_seed
7
+ from diffusers import AutoencoderKLTemporalDecoder
8
+ from fire import Fire
9
+
10
+ from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
11
+ from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
12
+ from normalcrafter.utils import vis_sequence_normal, save_video, read_video_frames
13
+
14
+
15
+ class DepthCrafterDemo:
16
+ def __init__(
17
+ self,
18
+ unet_path: str,
19
+ pre_train_path: str,
20
+ cpu_offload: str = "model",
21
+ ):
22
+ unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
23
+ unet_path,
24
+ subfolder="unet",
25
+ low_cpu_mem_usage=True,
26
+ )
27
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
28
+ unet_path, subfolder="vae"
29
+ )
30
+ weight_dtype = torch.float16
31
+ vae.to(dtype=weight_dtype)
32
+ unet.to(dtype=weight_dtype)
33
+ # load weights of other components from the provided checkpoint
34
+ self.pipe = NormalCrafterPipeline.from_pretrained(
35
+ pre_train_path,
36
+ unet=unet,
37
+ vae=vae,
38
+ torch_dtype=weight_dtype,
39
+ variant="fp16",
40
+ )
41
+
42
+ # for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
43
+ if cpu_offload is not None:
44
+ if cpu_offload == "sequential":
45
+ # This will be slower, but saves more memory
46
+ self.pipe.enable_sequential_cpu_offload()
47
+ elif cpu_offload == "model":
48
+ self.pipe.enable_model_cpu_offload()
49
+ else:
50
+ raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
51
+ else:
52
+ self.pipe.to("cuda")
53
+ # enable attention slicing and xformers memory efficient attention
54
+ try:
55
+ self.pipe.enable_xformers_memory_efficient_attention()
56
+ except Exception as e:
57
+ print(e)
58
+ print("Xformers is not enabled")
59
+ # self.pipe.enable_attention_slicing()
60
+
61
+ def infer(
62
+ self,
63
+ video: str,
64
+ save_folder: str = "./demo_output",
65
+ window_size: int = 14,
66
+ time_step_size: int = 10,
67
+ process_length: int = 195,
68
+ decode_chunk_size: int = 7,
69
+ max_res: int = 1024,
70
+ dataset: str = "open",
71
+ target_fps: int = 15,
72
+ seed: int = 42,
73
+ save_npz: bool = False,
74
+ ):
75
+ set_seed(seed)
76
+
77
+ frames, target_fps = read_video_frames(
78
+ video,
79
+ process_length,
80
+ target_fps,
81
+ max_res,
82
+ )
83
+ # run normal-map inference with the NormalCrafter pipeline
84
+ with torch.inference_mode():
85
+ res = self.pipe(
86
+ frames,
87
+ decode_chunk_size=decode_chunk_size,
88
+ time_step_size=time_step_size,
89
+ window_size=window_size,
90
+ ).frames[0]
91
+ # visualize the normal maps and save the results
92
+ vis = vis_sequence_normal(res)
93
+ # save the normal maps and the visualization at the target FPS
94
+ save_path = os.path.join(
95
+ save_folder, os.path.splitext(os.path.basename(video))[0]
96
+ )
97
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
98
+ save_video(vis, save_path + "_vis.mp4", fps=target_fps)
99
+ save_video(frames, save_path + "_input.mp4", fps=target_fps)
100
+ if save_npz:
101
+ np.savez_compressed(save_path + ".npz", depth=res)
102
+
103
+ return [
104
+ save_path + "_input.mp4",
105
+ save_path + "_vis.mp4",
106
+ ]
107
+
108
+ def run(
109
+ self,
110
+ input_video,
111
+ num_denoising_steps,
112
+ guidance_scale,
113
+ max_res=1024,
114
+ process_length=195,
115
+ ):
116
+ res_path = self.infer(
117
+ input_video,
118
+ num_denoising_steps,
119
+ guidance_scale,
120
+ max_res=max_res,
121
+ process_length=process_length,
122
+ )
123
+ # clear the cache for the next video
124
+ gc.collect()
125
+ torch.cuda.empty_cache()
126
+ return res_path[:2]
127
+
128
+
129
+ def main(
130
+ video_path: str,
131
+ save_folder: str = "./demo_output",
132
+ unet_path: str = "Yanrui95/NormalCrafter",
133
+ pre_train_path: str = "stabilityai/stable-video-diffusion-img2vid-xt",
134
+ process_length: int = -1,
135
+ cpu_offload: str = "model",
136
+ target_fps: int = -1,
137
+ seed: int = 42,
138
+ window_size: int = 14,
139
+ time_step_size: int = 10,
140
+ max_res: int = 1024,
141
+ dataset: str = "open",
142
+ save_npz: bool = False
143
+ ):
144
+ depthcrafter_demo = DepthCrafterDemo(
145
+ unet_path=unet_path,
146
+ pre_train_path=pre_train_path,
147
+ cpu_offload=cpu_offload,
148
+ )
149
+ # process the videos; multiple video paths are separated by commas
150
+ video_paths = video_path.split(",")
151
+ for video in video_paths:
152
+ depthcrafter_demo.infer(
153
+ video,
154
+ save_folder=save_folder,
155
+ window_size=window_size,
156
+ process_length=process_length,
157
+ time_step_size=time_step_size,
158
+ max_res=max_res,
159
+ dataset=dataset,
160
+ target_fps=target_fps,
161
+ seed=seed,
162
+ save_npz=save_npz,
163
+ )
164
+ # clear the cache for the next video
165
+ gc.collect()
166
+ torch.cuda.empty_cache()
167
+
168
+
169
+ if __name__ == "__main__":
170
+ # running configs
171
+ # the most important arguments for memory saving are `cpu_offload`, `enable_xformers`, `max_res`, and `window_size`
172
+ # the most important arguments for trade-off between quality and speed are
173
+ # `num_inference_steps`, `guidance_scale`, and `max_res`
174
+ Fire(main)
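
Since `main` is exposed through `Fire`, it can also be called directly from Python. A small usage sketch under the comma-separated batching and CPU-offload options described in the comments above follows; the file names and parameter values are illustrative assumptions.

```python
# Hypothetical batch run, roughly equivalent to:
#   python run.py --video-path examples/example_01.mp4,examples/example_02.mp4 --max-res 512
from run import main

main(
    video_path="examples/example_01.mp4,examples/example_02.mp4",  # comma-separated clips
    save_folder="./demo_output",
    max_res=512,            # lower resolution to reduce GPU memory use
    window_size=14,
    time_step_size=10,
    cpu_offload="model",    # or "sequential" for slower, lower-memory inference
    save_npz=False,
)
```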