Spaces:
Running
on
Zero
Running
on
Zero
Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- .gitignore +175 -0
- LICENSE +32 -0
- README.md +60 -9
- app.py +210 -0
- examples/example_01.mp4 +3 -0
- examples/example_02.mp4 +3 -0
- examples/example_03.mp4 +3 -0
- examples/example_04.mp4 +3 -0
- examples/example_05.mp4 +3 -0
- examples/example_06.mp4 +3 -0
- normalcrafter/__init__.py +0 -0
- normalcrafter/normal_crafter_ppl.py +494 -0
- normalcrafter/unet.py +368 -0
- normalcrafter/utils.py +64 -0
- requirements.txt +11 -0
- run.py +174 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### Python template
|
2 |
+
# Byte-compiled / optimized / DLL files
|
3 |
+
__pycache__/
|
4 |
+
*.py[cod]
|
5 |
+
*$py.class
|
6 |
+
|
7 |
+
# C extensions
|
8 |
+
*.so
|
9 |
+
|
10 |
+
#
|
11 |
+
.gradio
|
12 |
+
.github
|
13 |
+
demo_output
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/#use-with-ide
|
115 |
+
.pdm.toml
|
116 |
+
|
117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118 |
+
__pypackages__/
|
119 |
+
|
120 |
+
# Celery stuff
|
121 |
+
celerybeat-schedule
|
122 |
+
celerybeat.pid
|
123 |
+
|
124 |
+
# SageMath parsed files
|
125 |
+
*.sage.py
|
126 |
+
|
127 |
+
# Environments
|
128 |
+
.env
|
129 |
+
.venv
|
130 |
+
env/
|
131 |
+
venv/
|
132 |
+
ENV/
|
133 |
+
env.bak/
|
134 |
+
venv.bak/
|
135 |
+
|
136 |
+
# Spyder project settings
|
137 |
+
.spyderproject
|
138 |
+
.spyproject
|
139 |
+
|
140 |
+
# Rope project settings
|
141 |
+
.ropeproject
|
142 |
+
|
143 |
+
# mkdocs documentation
|
144 |
+
/site
|
145 |
+
|
146 |
+
# mypy
|
147 |
+
.mypy_cache/
|
148 |
+
.dmypy.json
|
149 |
+
dmypy.json
|
150 |
+
|
151 |
+
# Pyre type checker
|
152 |
+
.pyre/
|
153 |
+
|
154 |
+
# pytype static type analyzer
|
155 |
+
.pytype/
|
156 |
+
|
157 |
+
# Cython debug symbols
|
158 |
+
cython_debug/
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
.idea/
|
166 |
+
|
167 |
+
/logs
|
168 |
+
/gin-config
|
169 |
+
*.json
|
170 |
+
/eval/*csv
|
171 |
+
*__pycache__
|
172 |
+
scripts/
|
173 |
+
eval/
|
174 |
+
*.DS_Store
|
175 |
+
benchmark/datasets
|
LICENSE
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software in this distribution may have been modified by THL A29 Limited ("Tencent Modifications").
|
2 |
+
|
3 |
+
License Terms of the inference code of NormalCrafter:
|
4 |
+
--------------------------------------------------------------------
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
7 |
+
|
8 |
+
- You agree to use the NormalCrafter only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances.
|
9 |
+
|
10 |
+
- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
11 |
+
|
12 |
+
For avoidance of doubts, “Software” means the NormalCrafter model inference code and weights made available under this license excluding any pre-trained data and other AI components.
|
13 |
+
|
14 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
15 |
+
|
16 |
+
|
17 |
+
Other dependencies and licenses:
|
18 |
+
|
19 |
+
Open Source Software Licensed under the MIT License:
|
20 |
+
--------------------------------------------------------------------
|
21 |
+
1. Stability AI - Code
|
22 |
+
Copyright (c) 2023 Stability AI
|
23 |
+
|
24 |
+
Terms of the MIT License:
|
25 |
+
--------------------------------------------------------------------
|
26 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
27 |
+
|
28 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
29 |
+
|
30 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
31 |
+
|
32 |
+
**You may find the code license of Stability AI at the following links: https://github.com/Stability-AI/generative-models/blob/main/LICENSE-CODE
|
README.md
CHANGED
@@ -1,14 +1,65 @@
|
|
1 |
---
|
2 |
title: NormalCrafter
|
3 |
-
emoji: 📉
|
4 |
-
colorFrom: green
|
5 |
-
colorTo: pink
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.23.1
|
8 |
app_file: app.py
|
9 |
-
|
10 |
-
|
11 |
-
short_description: NormalCrafter
|
12 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: NormalCrafter
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 5.23.2
|
|
|
6 |
---
|
7 |
+
## ___***NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors***___
|
8 |
+
|
9 |
+
_**[Yanrui Bin<sup>1</sup>](https://scholar.google.com/citations?user=_9fN3mEAAAAJ&hl=zh-CN),[Wenbo Hu<sup>2*](https://wbhu.github.io),
|
10 |
+
[Haoyuan Wang<sup>3](https://www.whyy.site/),
|
11 |
+
[Xinya Chen<sup>3](https://xinyachen21.github.io/),
|
12 |
+
[Bing Wang<sup>2 †</sup>](https://bingcs.github.io/)**_
|
13 |
+
<br><br>
|
14 |
+
<sup>1</sup>Spatial Intelligence Group, The Hong Kong Polytechnic University
|
15 |
+
<sup>2</sup>Tencent AI Lab
|
16 |
+
<sup>3</sup>City University of Hong Kong
|
17 |
+
<sup>4</sup>Huazhong University of Science and Technology
|
18 |
+
</div>
|
19 |
+
|
20 |
+
## 🔆 Notice
|
21 |
+
We recommend that everyone use English to communicate on issues, as this helps developers from around the world discuss, share experiences, and answer questions together.
|
22 |
+
|
23 |
+
For business licensing and other related inquiries, don't hesitate to contact `binyanrui@gmail.com`.
|
24 |
+
|
25 |
+
## 🔆 Introduction
|
26 |
+
🤗 If you find NormalCrafter useful, **please help ⭐ this repo**, which is important to Open-Source projects. Thanks!
|
27 |
+
|
28 |
+
🔥 NormalCrafter can generate temporally consistent normal sequences
|
29 |
+
with fine-grained details from open-world videos with arbitrary lengths.
|
30 |
+
|
31 |
+
- `[24-04-01]` 🔥🔥🔥 **NormalCrafter** is released now, have fun!
|
32 |
+
## 🚀 Quick Start
|
33 |
+
|
34 |
+
### 🤖 Gradio Demo
|
35 |
+
- Online demo: [NormalCrafter](https://huggingface.co/spaces/Yanrui95/NormalCrafter)
|
36 |
+
- Local demo:
|
37 |
+
```bash
|
38 |
+
gradio app.py
|
39 |
+
```
|
40 |
+
|
41 |
+
### 🛠️ Installation
|
42 |
+
1. Clone this repo:
|
43 |
+
```bash
|
44 |
+
git clone git@github.com:Binyr/NormalCrafter.git
|
45 |
+
```
|
46 |
+
2. Install dependencies (please refer to [requirements.txt](requirements.txt)):
|
47 |
+
```bash
|
48 |
+
pip install -r requirements.txt
|
49 |
+
```
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
### 🤗 Model Zoo
|
54 |
+
[NormalCrafter](https://huggingface.co/Yanrui95/NormalCrafter) is available in the Hugging Face Model Hub.
|
55 |
+
|
56 |
+
### 🏃♂️ Inference
|
57 |
+
#### 1. High-resolution inference, requires a GPU with ~20GB memory for 1024x576 resolution:
|
58 |
+
```bash
|
59 |
+
python run.py --video-path examples/example_01.mp4
|
60 |
+
```
|
61 |
|
62 |
+
#### 2. Low-resolution inference requires a GPU with ~6GB memory for 512x256 resolution:
|
63 |
+
```bash
|
64 |
+
python run.py --video-path examples/example_01.mp4 --max-res 512
|
65 |
+
```
|
app.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import spaces
|
6 |
+
import gradio as gr
|
7 |
+
import torch
|
8 |
+
from diffusers.training_utils import set_seed
|
9 |
+
from diffusers import AutoencoderKLTemporalDecoder
|
10 |
+
|
11 |
+
from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
|
12 |
+
from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
|
13 |
+
|
14 |
+
import uuid
|
15 |
+
import random
|
16 |
+
from huggingface_hub import hf_hub_download
|
17 |
+
|
18 |
+
from normalcrafter.utils import read_video_frames, vis_sequence_normal, save_video
|
19 |
+
|
20 |
+
examples = [
|
21 |
+
["examples/example_01.mp4", 1024, -1, -1],
|
22 |
+
["examples/example_02.mp4", 1024, -1, -1],
|
23 |
+
["examples/example_03.mp4", 1024, -1, -1],
|
24 |
+
["examples/example_04.mp4", 1024, -1, -1],
|
25 |
+
["examples/example_05.mp4", 1024, -1, -1],
|
26 |
+
["examples/example_06.mp4", 1024, -1, -1],
|
27 |
+
]
|
28 |
+
|
29 |
+
pretrained_model_name_or_path = "Yanrui95/NormalCrafter"
|
30 |
+
weight_dtype = torch.float16
|
31 |
+
unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
|
32 |
+
pretrained_model_name_or_path,
|
33 |
+
subfolder="unet",
|
34 |
+
low_cpu_mem_usage=True,
|
35 |
+
)
|
36 |
+
vae = AutoencoderKLTemporalDecoder.from_pretrained(
|
37 |
+
pretrained_model_name_or_path, subfolder="vae")
|
38 |
+
|
39 |
+
vae.to(dtype=weight_dtype)
|
40 |
+
unet.to(dtype=weight_dtype)
|
41 |
+
|
42 |
+
pipe = NormalCrafterPipeline.from_pretrained(
|
43 |
+
"stabilityai/stable-video-diffusion-img2vid-xt",
|
44 |
+
unet=unet,
|
45 |
+
vae=vae,
|
46 |
+
torch_dtype=weight_dtype,
|
47 |
+
variant="fp16",
|
48 |
+
)
|
49 |
+
pipe.to("cuda")
|
50 |
+
|
51 |
+
|
52 |
+
@spaces.GPU(duration=120)
|
53 |
+
def infer_depth(
|
54 |
+
video: str,
|
55 |
+
max_res: int = 1024,
|
56 |
+
process_length: int = -1,
|
57 |
+
target_fps: int = -1,
|
58 |
+
#
|
59 |
+
save_folder: str = "./demo_output",
|
60 |
+
window_size: int = 14,
|
61 |
+
time_step_size: int = 10,
|
62 |
+
decode_chunk_size: int = 7,
|
63 |
+
seed: int = 42,
|
64 |
+
save_npz: bool = False,
|
65 |
+
):
|
66 |
+
set_seed(seed)
|
67 |
+
pipe.enable_xformers_memory_efficient_attention()
|
68 |
+
|
69 |
+
frames, target_fps = read_video_frames(video, process_length, target_fps, max_res)
|
70 |
+
|
71 |
+
# inference the depth map using the DepthCrafter pipeline
|
72 |
+
with torch.inference_mode():
|
73 |
+
res = pipe(
|
74 |
+
frames,
|
75 |
+
decode_chunk_size=decode_chunk_size,
|
76 |
+
time_step_size=time_step_size,
|
77 |
+
window_size=window_size,
|
78 |
+
).frames[0]
|
79 |
+
|
80 |
+
# visualize the depth map and save the results
|
81 |
+
vis = vis_sequence_normal(res)
|
82 |
+
# save the depth map and visualization with the target FPS
|
83 |
+
save_path = os.path.join(save_folder, os.path.splitext(os.path.basename(video))[0])
|
84 |
+
print(f"==> saving results to {save_path}")
|
85 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
86 |
+
if save_npz:
|
87 |
+
np.savez_compressed(save_path + ".npz", normal=res)
|
88 |
+
save_video(vis, save_path + "_vis.mp4", fps=target_fps)
|
89 |
+
save_video(frames, save_path + "_input.mp4", fps=target_fps)
|
90 |
+
|
91 |
+
# clear the cache for the next video
|
92 |
+
gc.collect()
|
93 |
+
torch.cuda.empty_cache()
|
94 |
+
|
95 |
+
return [
|
96 |
+
save_path + "_input.mp4",
|
97 |
+
save_path + "_vis.mp4",
|
98 |
+
|
99 |
+
]
|
100 |
+
|
101 |
+
|
102 |
+
def construct_demo():
|
103 |
+
with gr.Blocks(analytics_enabled=False) as depthcrafter_iface:
|
104 |
+
gr.Markdown(
|
105 |
+
"""
|
106 |
+
<div align='center'> <h1> NormalCrafter: Learning Temporally Consistent Video Normal from Video Diffusion Priors </span> </h1> \
|
107 |
+
<a style='font-size:18px;color: #000000'>If you find NormalCrafter useful, please help ⭐ the </a>\
|
108 |
+
<a style='font-size:18px;color: #FF5DB0' href='https://github.com/Binyr/NormalCrafter'>[Github Repo]</a>\
|
109 |
+
<a style='font-size:18px;color: #000000'>, which is important to Open-Source projects. Thanks!</a>\
|
110 |
+
<a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2409.02095'> [ArXiv] </a>\
|
111 |
+
<a style='font-size:18px;color: #000000' href='https://normalcrafter.github.io/'> [Project Page] </a> </div>
|
112 |
+
"""
|
113 |
+
)
|
114 |
+
|
115 |
+
with gr.Row(equal_height=True):
|
116 |
+
with gr.Column(scale=1):
|
117 |
+
input_video = gr.Video(label="Input Video")
|
118 |
+
|
119 |
+
# with gr.Tab(label="Output"):
|
120 |
+
with gr.Column(scale=2):
|
121 |
+
with gr.Row(equal_height=True):
|
122 |
+
output_video_1 = gr.Video(
|
123 |
+
label="Preprocessed video",
|
124 |
+
interactive=False,
|
125 |
+
autoplay=True,
|
126 |
+
loop=True,
|
127 |
+
show_share_button=True,
|
128 |
+
scale=5,
|
129 |
+
)
|
130 |
+
output_video_2 = gr.Video(
|
131 |
+
label="Generated Depth Video",
|
132 |
+
interactive=False,
|
133 |
+
autoplay=True,
|
134 |
+
loop=True,
|
135 |
+
show_share_button=True,
|
136 |
+
scale=5,
|
137 |
+
)
|
138 |
+
|
139 |
+
with gr.Row(equal_height=True):
|
140 |
+
with gr.Column(scale=1):
|
141 |
+
with gr.Row(equal_height=False):
|
142 |
+
with gr.Accordion("Advanced Settings", open=False):
|
143 |
+
max_res = gr.Slider(
|
144 |
+
label="max resolution",
|
145 |
+
minimum=512,
|
146 |
+
maximum=1024,
|
147 |
+
value=1024,
|
148 |
+
step=64,
|
149 |
+
)
|
150 |
+
process_length = gr.Slider(
|
151 |
+
label="process length",
|
152 |
+
minimum=-1,
|
153 |
+
maximum=280,
|
154 |
+
value=60,
|
155 |
+
step=1,
|
156 |
+
)
|
157 |
+
process_target_fps = gr.Slider(
|
158 |
+
label="target FPS",
|
159 |
+
minimum=-1,
|
160 |
+
maximum=30,
|
161 |
+
value=15,
|
162 |
+
step=1,
|
163 |
+
)
|
164 |
+
generate_btn = gr.Button("Generate")
|
165 |
+
with gr.Column(scale=2):
|
166 |
+
pass
|
167 |
+
|
168 |
+
gr.Examples(
|
169 |
+
examples=examples,
|
170 |
+
inputs=[
|
171 |
+
input_video,
|
172 |
+
max_res,
|
173 |
+
process_length,
|
174 |
+
process_target_fps,
|
175 |
+
],
|
176 |
+
outputs=[output_video_1, output_video_2],
|
177 |
+
fn=infer_depth,
|
178 |
+
cache_examples="lazy",
|
179 |
+
)
|
180 |
+
# gr.Markdown(
|
181 |
+
# """
|
182 |
+
# <span style='font-size:18px;color: #E7CCCC'>Note:
|
183 |
+
# For time quota consideration, we set the default parameters to be more efficient here,
|
184 |
+
# with a trade-off of shorter video length and slightly lower quality.
|
185 |
+
# You may adjust the parameters according to our
|
186 |
+
# <a style='font-size:18px;color: #FF5DB0' href='https://github.com/Tencent/DepthCrafter'>[Github Repo]</a>
|
187 |
+
# for better results if you have enough time quota.
|
188 |
+
# </span>
|
189 |
+
# """
|
190 |
+
# )
|
191 |
+
|
192 |
+
generate_btn.click(
|
193 |
+
fn=infer_depth,
|
194 |
+
inputs=[
|
195 |
+
input_video,
|
196 |
+
max_res,
|
197 |
+
process_length,
|
198 |
+
process_target_fps,
|
199 |
+
],
|
200 |
+
outputs=[output_video_1, output_video_2],
|
201 |
+
)
|
202 |
+
|
203 |
+
return depthcrafter_iface
|
204 |
+
|
205 |
+
|
206 |
+
if __name__ == "__main__":
|
207 |
+
demo = construct_demo()
|
208 |
+
demo.queue()
|
209 |
+
# demo.launch(server_name="0.0.0.0", server_port=12345, debug=True, share=False)
|
210 |
+
demo.launch(share=True)
|
examples/example_01.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3eb7fefd157bd9b403cf0b524c7c4f3cb6d9f82b9d6a48eba2146412fc9e64a2
|
3 |
+
size 5727137
|
examples/example_02.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ea3c4e4c8cd9682d92c25170d8df333fead210118802fbe22198dde478dc5489
|
3 |
+
size 3150525
|
examples/example_03.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d332877a98bb41ff86a639139a03e383e91880bca722bba7e2518878fca54f6
|
3 |
+
size 3013435
|
examples/example_04.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b2aa4962216adce71b1c47f395be435b23105df35f3892646e237b935ac1c74f
|
3 |
+
size 3591374
|
examples/example_05.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8d2319060f9a1d3cfcb9de317e4a5b138657fd741c530ed3983f6565c2eda44
|
3 |
+
size 3553683
|
examples/example_06.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e3a2619b029129f34884c761cc278b6842620bfed96d4bb52c8aa07bc1d82a8b
|
3 |
+
size 5596872
|
normalcrafter/__init__.py
ADDED
File without changes
|
normalcrafter/normal_crafter_ppl.py
ADDED
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Callable, Dict, List, Optional, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import PIL.Image
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange
|
9 |
+
import math
|
10 |
+
|
11 |
+
from diffusers.utils import BaseOutput, logging
|
12 |
+
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
13 |
+
from diffusers import DiffusionPipeline
|
14 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import StableVideoDiffusionPipelineOutput, StableVideoDiffusionPipeline
|
15 |
+
from PIL import Image
|
16 |
+
import cv2
|
17 |
+
|
18 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
19 |
+
|
20 |
+
class NormalCrafterPipeline(StableVideoDiffusionPipeline):
|
21 |
+
|
22 |
+
def _encode_image(self, image, device, num_videos_per_prompt, do_classifier_free_guidance, scale=1, image_size=None):
|
23 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
24 |
+
|
25 |
+
if not isinstance(image, torch.Tensor):
|
26 |
+
image = self.video_processor.pil_to_numpy(image) # (0, 255) -> (0, 1)
|
27 |
+
image = self.video_processor.numpy_to_pt(image) # (n, h, w, c) -> (n, c, h, w)
|
28 |
+
|
29 |
+
# We normalize the image before resizing to match with the original implementation.
|
30 |
+
# Then we unnormalize it after resizing.
|
31 |
+
pixel_values = image
|
32 |
+
B, C, H, W = pixel_values.shape
|
33 |
+
patches = [pixel_values]
|
34 |
+
# patches = []
|
35 |
+
for i in range(1, scale):
|
36 |
+
num_patches_HW_this_level = i + 1
|
37 |
+
patch_H = H // num_patches_HW_this_level + 1
|
38 |
+
patch_W = W // num_patches_HW_this_level + 1
|
39 |
+
for j in range(num_patches_HW_this_level):
|
40 |
+
for k in range(num_patches_HW_this_level):
|
41 |
+
patches.append(pixel_values[:, :, j*patch_H:(j+1)*patch_H, k*patch_W:(k+1)*patch_W])
|
42 |
+
|
43 |
+
def encode_image(image):
|
44 |
+
image = image * 2.0 - 1.0
|
45 |
+
if image_size is not None:
|
46 |
+
image = _resize_with_antialiasing(image, image_size)
|
47 |
+
else:
|
48 |
+
image = _resize_with_antialiasing(image, (224, 224))
|
49 |
+
image = (image + 1.0) / 2.0
|
50 |
+
|
51 |
+
# Normalize the image with for CLIP input
|
52 |
+
image = self.feature_extractor(
|
53 |
+
images=image,
|
54 |
+
do_normalize=True,
|
55 |
+
do_center_crop=False,
|
56 |
+
do_resize=False,
|
57 |
+
do_rescale=False,
|
58 |
+
return_tensors="pt",
|
59 |
+
).pixel_values
|
60 |
+
|
61 |
+
image = image.to(device=device, dtype=dtype)
|
62 |
+
image_embeddings = self.image_encoder(image).image_embeds
|
63 |
+
if len(image_embeddings.shape) < 3:
|
64 |
+
image_embeddings = image_embeddings.unsqueeze(1)
|
65 |
+
return image_embeddings
|
66 |
+
|
67 |
+
image_embeddings = []
|
68 |
+
for patch in patches:
|
69 |
+
image_embeddings.append(encode_image(patch))
|
70 |
+
image_embeddings = torch.cat(image_embeddings, dim=1)
|
71 |
+
|
72 |
+
# duplicate image embeddings for each generation per prompt, using mps friendly method
|
73 |
+
# import pdb
|
74 |
+
# pdb.set_trace()
|
75 |
+
bs_embed, seq_len, _ = image_embeddings.shape
|
76 |
+
image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
|
77 |
+
image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
78 |
+
|
79 |
+
if do_classifier_free_guidance:
|
80 |
+
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
81 |
+
|
82 |
+
# For classifier free guidance, we need to do two forward passes.
|
83 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
84 |
+
# to avoid doing two forward passes
|
85 |
+
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
|
86 |
+
|
87 |
+
return image_embeddings
|
88 |
+
|
89 |
+
def ecnode_video_vae(self, images, chunk_size: int = 14):
|
90 |
+
if isinstance(images, list):
|
91 |
+
width, height = images[0].size
|
92 |
+
else:
|
93 |
+
height, width = images[0].shape[:2]
|
94 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
95 |
+
if needs_upcasting:
|
96 |
+
self.vae.to(dtype=torch.float32)
|
97 |
+
|
98 |
+
device = self._execution_device
|
99 |
+
images = self.video_processor.preprocess_video(images, height=height, width=width).to(device, self.vae.dtype) # torch type in range(-1, 1) with (1,3,h,w)
|
100 |
+
images = images.squeeze(0) # from (1, c, t, h, w) -> (c, t, h, w)
|
101 |
+
images = images.permute(1,0,2,3) # c, t, h, w -> (t, c, h, w)
|
102 |
+
|
103 |
+
video_latents = []
|
104 |
+
# chunk_size = 14
|
105 |
+
for i in range(0, images.shape[0], chunk_size):
|
106 |
+
video_latents.append(self.vae.encode(images[i : i + chunk_size]).latent_dist.mode())
|
107 |
+
image_latents = torch.cat(video_latents)
|
108 |
+
|
109 |
+
# cast back to fp16 if needed
|
110 |
+
if needs_upcasting:
|
111 |
+
self.vae.to(dtype=torch.float16)
|
112 |
+
|
113 |
+
return image_latents
|
114 |
+
|
115 |
+
def pad_image(self, images, scale=64):
|
116 |
+
def get_pad(newW, W):
|
117 |
+
pad_W = (newW - W) // 2
|
118 |
+
if W % 2 == 1:
|
119 |
+
pad_Ws = [pad_W, pad_W + 1]
|
120 |
+
else:
|
121 |
+
pad_Ws = [pad_W, pad_W]
|
122 |
+
return pad_Ws
|
123 |
+
|
124 |
+
if type(images[0]) is np.ndarray:
|
125 |
+
H, W = images[0].shape[:2]
|
126 |
+
else:
|
127 |
+
W, H = images[0].size
|
128 |
+
|
129 |
+
if W % scale == 0 and H % scale == 0:
|
130 |
+
return images, None
|
131 |
+
newW = int(np.ceil(W / scale) * scale)
|
132 |
+
newH = int(np.ceil(H / scale) * scale)
|
133 |
+
|
134 |
+
pad_Ws = get_pad(newW, W)
|
135 |
+
pad_Hs = get_pad(newH, H)
|
136 |
+
|
137 |
+
new_images = []
|
138 |
+
for image in images:
|
139 |
+
if type(image) is np.ndarray:
|
140 |
+
image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(1.,1.,1.))
|
141 |
+
new_images.append(image)
|
142 |
+
else:
|
143 |
+
image = np.array(image)
|
144 |
+
image = cv2.copyMakeBorder(image, *pad_Hs, *pad_Ws, cv2.BORDER_CONSTANT, value=(255,255,255))
|
145 |
+
new_images.append(Image.fromarray(image))
|
146 |
+
return new_images, pad_Hs+pad_Ws
|
147 |
+
|
148 |
+
def unpad_image(self, v, pad_HWs):
|
149 |
+
t, b, l, r = pad_HWs
|
150 |
+
if t > 0 or b > 0:
|
151 |
+
v = v[:, :, t:-b]
|
152 |
+
if l > 0 or r > 0:
|
153 |
+
v = v[:, :, :, l:-r]
|
154 |
+
return v
|
155 |
+
|
156 |
+
@torch.no_grad()
|
157 |
+
def __call__(
|
158 |
+
self,
|
159 |
+
images: Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor],
|
160 |
+
decode_chunk_size: Optional[int] = None,
|
161 |
+
time_step_size: Optional[int] = 1,
|
162 |
+
window_size: Optional[int] = 1,
|
163 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
164 |
+
return_dict: bool = True
|
165 |
+
):
|
166 |
+
images, pad_HWs = self.pad_image(images)
|
167 |
+
|
168 |
+
# 0. Default height and width to unet
|
169 |
+
width, height = images[0].size
|
170 |
+
num_frames = len(images)
|
171 |
+
|
172 |
+
# 1. Check inputs. Raise error if not correct
|
173 |
+
self.check_inputs(images, height, width)
|
174 |
+
|
175 |
+
# 2. Define call parameters
|
176 |
+
batch_size = 1
|
177 |
+
device = self._execution_device
|
178 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
179 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
180 |
+
# corresponds to doing no classifier free guidance.
|
181 |
+
self._guidance_scale = 1.0
|
182 |
+
num_videos_per_prompt = 1
|
183 |
+
do_classifier_free_guidance = False
|
184 |
+
num_inference_steps = 1
|
185 |
+
fps = 7
|
186 |
+
motion_bucket_id = 127
|
187 |
+
noise_aug_strength = 0.
|
188 |
+
num_videos_per_prompt = 1
|
189 |
+
output_type = "np"
|
190 |
+
data_keys = ["normal"]
|
191 |
+
use_linear_merge = True
|
192 |
+
determineTrain = True
|
193 |
+
encode_image_scale = 1
|
194 |
+
encode_image_WH = None
|
195 |
+
|
196 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 7
|
197 |
+
|
198 |
+
# 3. Encode input image using using clip. (num_image * num_videos_per_prompt, 1, 1024)
|
199 |
+
image_embeddings = self._encode_image(images, device, num_videos_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, scale=encode_image_scale, image_size=encode_image_WH)
|
200 |
+
# 4. Encode input image using VAE
|
201 |
+
image_latents = self.ecnode_video_vae(images, chunk_size=decode_chunk_size).to(image_embeddings.dtype)
|
202 |
+
|
203 |
+
# image_latents [num_frames, channels, height, width] ->[1, num_frames, channels, height, width]
|
204 |
+
image_latents = image_latents.unsqueeze(0)
|
205 |
+
|
206 |
+
# 5. Get Added Time IDs
|
207 |
+
added_time_ids = self._get_add_time_ids(
|
208 |
+
fps,
|
209 |
+
motion_bucket_id,
|
210 |
+
noise_aug_strength,
|
211 |
+
image_embeddings.dtype,
|
212 |
+
batch_size,
|
213 |
+
num_videos_per_prompt,
|
214 |
+
do_classifier_free_guidance,
|
215 |
+
)
|
216 |
+
added_time_ids = added_time_ids.to(device)
|
217 |
+
|
218 |
+
# get Start and End frame idx for each window
|
219 |
+
def get_ses(num_frames):
|
220 |
+
ses = []
|
221 |
+
for i in range(0, num_frames, time_step_size):
|
222 |
+
ses.append([i, i+window_size])
|
223 |
+
num_to_remain = 0
|
224 |
+
for se in ses:
|
225 |
+
if se[1] > num_frames:
|
226 |
+
continue
|
227 |
+
num_to_remain += 1
|
228 |
+
ses = ses[:num_to_remain]
|
229 |
+
|
230 |
+
if ses[-1][-1] < num_frames:
|
231 |
+
ses.append([num_frames - window_size, num_frames])
|
232 |
+
return ses
|
233 |
+
ses = get_ses(num_frames)
|
234 |
+
|
235 |
+
pred = None
|
236 |
+
for i, se in enumerate(ses):
|
237 |
+
window_num_frames = window_size
|
238 |
+
window_image_embeddings = image_embeddings[se[0]:se[1]]
|
239 |
+
window_image_latents = image_latents[:, se[0]:se[1]]
|
240 |
+
window_added_time_ids = added_time_ids
|
241 |
+
# import pdb
|
242 |
+
# pdb.set_trace()
|
243 |
+
if i == 0 or time_step_size == window_size:
|
244 |
+
to_replace_latents = None
|
245 |
+
else:
|
246 |
+
last_se = ses[i-1]
|
247 |
+
num_to_replace_latents = last_se[1] - se[0]
|
248 |
+
to_replace_latents = pred[:, -num_to_replace_latents:]
|
249 |
+
|
250 |
+
latents = self.generate(
|
251 |
+
num_inference_steps,
|
252 |
+
device,
|
253 |
+
batch_size,
|
254 |
+
num_videos_per_prompt,
|
255 |
+
window_num_frames,
|
256 |
+
height,
|
257 |
+
width,
|
258 |
+
window_image_embeddings,
|
259 |
+
generator,
|
260 |
+
determineTrain,
|
261 |
+
to_replace_latents,
|
262 |
+
do_classifier_free_guidance,
|
263 |
+
window_image_latents,
|
264 |
+
window_added_time_ids
|
265 |
+
)
|
266 |
+
|
267 |
+
# merge last_latents and current latents in overlap window
|
268 |
+
if to_replace_latents is not None and use_linear_merge:
|
269 |
+
num_img_condition = to_replace_latents.shape[1]
|
270 |
+
weight = torch.linspace(1., 0., num_img_condition+2)[1:-1].to(device)
|
271 |
+
weight = weight[None, :, None, None, None]
|
272 |
+
latents[:, :num_img_condition] = to_replace_latents * weight + latents[:, :num_img_condition] * (1 - weight)
|
273 |
+
|
274 |
+
if pred is None:
|
275 |
+
pred = latents
|
276 |
+
else:
|
277 |
+
pred = torch.cat([pred[:, :se[0]], latents], dim=1)
|
278 |
+
|
279 |
+
if not output_type == "latent":
|
280 |
+
# cast back to fp16 if needed
|
281 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
282 |
+
if needs_upcasting:
|
283 |
+
self.vae.to(dtype=torch.float16)
|
284 |
+
# latents has shape (1, num_frames, 12, h, w)
|
285 |
+
|
286 |
+
def decode_latents(latents, num_frames, decode_chunk_size):
|
287 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size) # in range(-1, 1)
|
288 |
+
frames = self.video_processor.postprocess_video(video=frames, output_type="np")
|
289 |
+
frames = frames * 2 - 1 # from range(0, 1) -> range(-1, 1)
|
290 |
+
return frames
|
291 |
+
|
292 |
+
frames = decode_latents(pred, num_frames, decode_chunk_size)
|
293 |
+
if pad_HWs is not None:
|
294 |
+
frames = self.unpad_image(frames, pad_HWs)
|
295 |
+
else:
|
296 |
+
frames = pred
|
297 |
+
|
298 |
+
self.maybe_free_model_hooks()
|
299 |
+
|
300 |
+
if not return_dict:
|
301 |
+
return frames
|
302 |
+
|
303 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
304 |
+
|
305 |
+
|
306 |
+
def generate(
|
307 |
+
self,
|
308 |
+
num_inference_steps,
|
309 |
+
device,
|
310 |
+
batch_size,
|
311 |
+
num_videos_per_prompt,
|
312 |
+
num_frames,
|
313 |
+
height,
|
314 |
+
width,
|
315 |
+
image_embeddings,
|
316 |
+
generator,
|
317 |
+
determineTrain,
|
318 |
+
to_replace_latents,
|
319 |
+
do_classifier_free_guidance,
|
320 |
+
image_latents,
|
321 |
+
added_time_ids,
|
322 |
+
latents=None,
|
323 |
+
):
|
324 |
+
# 6. Prepare timesteps
|
325 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
326 |
+
timesteps = self.scheduler.timesteps
|
327 |
+
|
328 |
+
# 7. Prepare latent variables
|
329 |
+
num_channels_latents = self.unet.config.in_channels
|
330 |
+
latents = self.prepare_latents(
|
331 |
+
batch_size * num_videos_per_prompt,
|
332 |
+
num_frames,
|
333 |
+
num_channels_latents,
|
334 |
+
height,
|
335 |
+
width,
|
336 |
+
image_embeddings.dtype,
|
337 |
+
device,
|
338 |
+
generator,
|
339 |
+
latents,
|
340 |
+
)
|
341 |
+
if determineTrain:
|
342 |
+
latents[...] = 0.
|
343 |
+
|
344 |
+
# 8. Denoising loop
|
345 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
346 |
+
self._num_timesteps = len(timesteps)
|
347 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
348 |
+
for i, t in enumerate(timesteps):
|
349 |
+
# replace part of latents with conditons. ToDo: t embedding should also replace
|
350 |
+
if to_replace_latents is not None:
|
351 |
+
num_img_condition = to_replace_latents.shape[1]
|
352 |
+
if not determineTrain:
|
353 |
+
_noise = randn_tensor(to_replace_latents.shape, generator=generator, device=device, dtype=image_embeddings.dtype)
|
354 |
+
noisy_to_replace_latents = self.scheduler.add_noise(to_replace_latents, _noise, t.unsqueeze(0))
|
355 |
+
latents[:, :num_img_condition] = noisy_to_replace_latents
|
356 |
+
else:
|
357 |
+
latents[:, :num_img_condition] = to_replace_latents
|
358 |
+
|
359 |
+
|
360 |
+
# expand the latents if we are doing classifier free guidance
|
361 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
362 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
363 |
+
timestep = t
|
364 |
+
# Concatenate image_latents over channels dimention
|
365 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
366 |
+
# predict the noise residual
|
367 |
+
noise_pred = self.unet(
|
368 |
+
latent_model_input,
|
369 |
+
timestep,
|
370 |
+
encoder_hidden_states=image_embeddings,
|
371 |
+
added_time_ids=added_time_ids,
|
372 |
+
return_dict=False,
|
373 |
+
)[0]
|
374 |
+
|
375 |
+
# perform guidance
|
376 |
+
if do_classifier_free_guidance:
|
377 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
378 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
379 |
+
|
380 |
+
# compute the previous noisy sample x_t -> x_t-1
|
381 |
+
scheduler_output = self.scheduler.step(noise_pred, t, latents)
|
382 |
+
latents = scheduler_output.prev_sample
|
383 |
+
|
384 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
385 |
+
progress_bar.update()
|
386 |
+
|
387 |
+
return latents
|
388 |
+
# resizing utils
|
389 |
+
# TODO: clean up later
|
390 |
+
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
|
391 |
+
h, w = input.shape[-2:]
|
392 |
+
factors = (h / size[0], w / size[1])
|
393 |
+
|
394 |
+
# First, we have to determine sigma
|
395 |
+
# Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
|
396 |
+
sigmas = (
|
397 |
+
max((factors[0] - 1.0) / 2.0, 0.001),
|
398 |
+
max((factors[1] - 1.0) / 2.0, 0.001),
|
399 |
+
)
|
400 |
+
|
401 |
+
# Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma
|
402 |
+
# https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
|
403 |
+
# But they do it in the 2 passes, which gives better results. Let's try 2 sigmas for now
|
404 |
+
ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3))
|
405 |
+
|
406 |
+
# Make sure it is odd
|
407 |
+
if (ks[0] % 2) == 0:
|
408 |
+
ks = ks[0] + 1, ks[1]
|
409 |
+
|
410 |
+
if (ks[1] % 2) == 0:
|
411 |
+
ks = ks[0], ks[1] + 1
|
412 |
+
|
413 |
+
input = _gaussian_blur2d(input, ks, sigmas)
|
414 |
+
|
415 |
+
output = torch.nn.functional.interpolate(input, size=size, mode=interpolation, align_corners=align_corners)
|
416 |
+
return output
|
417 |
+
|
418 |
+
|
419 |
+
def _compute_padding(kernel_size):
|
420 |
+
"""Compute padding tuple."""
|
421 |
+
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
|
422 |
+
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
|
423 |
+
if len(kernel_size) < 2:
|
424 |
+
raise AssertionError(kernel_size)
|
425 |
+
computed = [k - 1 for k in kernel_size]
|
426 |
+
|
427 |
+
# for even kernels we need to do asymmetric padding :(
|
428 |
+
out_padding = 2 * len(kernel_size) * [0]
|
429 |
+
|
430 |
+
for i in range(len(kernel_size)):
|
431 |
+
computed_tmp = computed[-(i + 1)]
|
432 |
+
|
433 |
+
pad_front = computed_tmp // 2
|
434 |
+
pad_rear = computed_tmp - pad_front
|
435 |
+
|
436 |
+
out_padding[2 * i + 0] = pad_front
|
437 |
+
out_padding[2 * i + 1] = pad_rear
|
438 |
+
|
439 |
+
return out_padding
|
440 |
+
|
441 |
+
|
442 |
+
def _filter2d(input, kernel):
|
443 |
+
# prepare kernel
|
444 |
+
b, c, h, w = input.shape
|
445 |
+
tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
|
446 |
+
|
447 |
+
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
|
448 |
+
|
449 |
+
height, width = tmp_kernel.shape[-2:]
|
450 |
+
|
451 |
+
padding_shape: list[int] = _compute_padding([height, width])
|
452 |
+
input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
|
453 |
+
|
454 |
+
# kernel and input tensor reshape to align element-wise or batch-wise params
|
455 |
+
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
|
456 |
+
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
|
457 |
+
|
458 |
+
# convolve the tensor with the kernel.
|
459 |
+
output = torch.nn.functional.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
|
460 |
+
|
461 |
+
out = output.view(b, c, h, w)
|
462 |
+
return out
|
463 |
+
|
464 |
+
|
465 |
+
def _gaussian(window_size: int, sigma):
|
466 |
+
if isinstance(sigma, float):
|
467 |
+
sigma = torch.tensor([[sigma]])
|
468 |
+
|
469 |
+
batch_size = sigma.shape[0]
|
470 |
+
|
471 |
+
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
|
472 |
+
|
473 |
+
if window_size % 2 == 0:
|
474 |
+
x = x + 0.5
|
475 |
+
|
476 |
+
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
|
477 |
+
|
478 |
+
return gauss / gauss.sum(-1, keepdim=True)
|
479 |
+
|
480 |
+
|
481 |
+
def _gaussian_blur2d(input, kernel_size, sigma):
|
482 |
+
if isinstance(sigma, tuple):
|
483 |
+
sigma = torch.tensor([sigma], dtype=input.dtype)
|
484 |
+
else:
|
485 |
+
sigma = sigma.to(dtype=input.dtype)
|
486 |
+
|
487 |
+
ky, kx = int(kernel_size[0]), int(kernel_size[1])
|
488 |
+
bs = sigma.shape[0]
|
489 |
+
kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
|
490 |
+
kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))
|
491 |
+
out_x = _filter2d(input, kernel_x[..., None, :])
|
492 |
+
out = _filter2d(out_x, kernel_y[..., None])
|
493 |
+
|
494 |
+
return out
|
normalcrafter/unet.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import UNetSpatioTemporalConditionModel
|
2 |
+
from diffusers.models.unets.unet_spatio_temporal_condition import UNetSpatioTemporalConditionOutput
|
3 |
+
from diffusers.utils import is_torch_version
|
4 |
+
import torch
|
5 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
6 |
+
|
7 |
+
def create_custom_forward(module, return_dict=None):
|
8 |
+
def custom_forward(*inputs):
|
9 |
+
if return_dict is not None:
|
10 |
+
return module(*inputs, return_dict=return_dict)
|
11 |
+
else:
|
12 |
+
return module(*inputs)
|
13 |
+
|
14 |
+
return custom_forward
|
15 |
+
CKPT_KWARGS = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
16 |
+
|
17 |
+
|
18 |
+
class DiffusersUNetSpatioTemporalConditionModelNormalCrafter(UNetSpatioTemporalConditionModel):
|
19 |
+
|
20 |
+
@staticmethod
|
21 |
+
def forward_crossattn_down_block_dino(
|
22 |
+
module,
|
23 |
+
hidden_states: torch.Tensor,
|
24 |
+
temb: Optional[torch.Tensor] = None,
|
25 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
26 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
27 |
+
dino_down_block_res_samples = None,
|
28 |
+
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
|
29 |
+
output_states = ()
|
30 |
+
self = module
|
31 |
+
blocks = list(zip(self.resnets, self.attentions))
|
32 |
+
for resnet, attn in blocks:
|
33 |
+
if self.training and self.gradient_checkpointing: # TODO
|
34 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
35 |
+
create_custom_forward(resnet),
|
36 |
+
hidden_states,
|
37 |
+
temb,
|
38 |
+
image_only_indicator,
|
39 |
+
**CKPT_KWARGS,
|
40 |
+
)
|
41 |
+
|
42 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
43 |
+
create_custom_forward(attn),
|
44 |
+
hidden_states,
|
45 |
+
encoder_hidden_states,
|
46 |
+
image_only_indicator,
|
47 |
+
False,
|
48 |
+
**CKPT_KWARGS,
|
49 |
+
)[0]
|
50 |
+
else:
|
51 |
+
hidden_states = resnet(
|
52 |
+
hidden_states,
|
53 |
+
temb,
|
54 |
+
image_only_indicator=image_only_indicator,
|
55 |
+
)
|
56 |
+
hidden_states = attn(
|
57 |
+
hidden_states,
|
58 |
+
encoder_hidden_states=encoder_hidden_states,
|
59 |
+
image_only_indicator=image_only_indicator,
|
60 |
+
return_dict=False,
|
61 |
+
)[0]
|
62 |
+
|
63 |
+
if dino_down_block_res_samples is not None:
|
64 |
+
hidden_states += dino_down_block_res_samples.pop(0)
|
65 |
+
|
66 |
+
output_states = output_states + (hidden_states,)
|
67 |
+
|
68 |
+
if self.downsamplers is not None:
|
69 |
+
for downsampler in self.downsamplers:
|
70 |
+
hidden_states = downsampler(hidden_states)
|
71 |
+
if dino_down_block_res_samples is not None:
|
72 |
+
hidden_states += dino_down_block_res_samples.pop(0)
|
73 |
+
|
74 |
+
output_states = output_states + (hidden_states,)
|
75 |
+
|
76 |
+
return hidden_states, output_states
|
77 |
+
@staticmethod
|
78 |
+
def forward_down_block_dino(
|
79 |
+
module,
|
80 |
+
hidden_states: torch.Tensor,
|
81 |
+
temb: Optional[torch.Tensor] = None,
|
82 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
83 |
+
dino_down_block_res_samples = None,
|
84 |
+
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
|
85 |
+
self = module
|
86 |
+
output_states = ()
|
87 |
+
for resnet in self.resnets:
|
88 |
+
if self.training and self.gradient_checkpointing:
|
89 |
+
if is_torch_version(">=", "1.11.0"):
|
90 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
91 |
+
create_custom_forward(resnet),
|
92 |
+
hidden_states,
|
93 |
+
temb,
|
94 |
+
image_only_indicator,
|
95 |
+
use_reentrant=False,
|
96 |
+
)
|
97 |
+
else:
|
98 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
99 |
+
create_custom_forward(resnet),
|
100 |
+
hidden_states,
|
101 |
+
temb,
|
102 |
+
image_only_indicator,
|
103 |
+
)
|
104 |
+
else:
|
105 |
+
hidden_states = resnet(
|
106 |
+
hidden_states,
|
107 |
+
temb,
|
108 |
+
image_only_indicator=image_only_indicator,
|
109 |
+
)
|
110 |
+
if dino_down_block_res_samples is not None:
|
111 |
+
hidden_states += dino_down_block_res_samples.pop(0)
|
112 |
+
output_states = output_states + (hidden_states,)
|
113 |
+
|
114 |
+
if self.downsamplers is not None:
|
115 |
+
for downsampler in self.downsamplers:
|
116 |
+
hidden_states = downsampler(hidden_states)
|
117 |
+
if dino_down_block_res_samples is not None:
|
118 |
+
hidden_states += dino_down_block_res_samples.pop(0)
|
119 |
+
output_states = output_states + (hidden_states,)
|
120 |
+
|
121 |
+
return hidden_states, output_states
|
122 |
+
|
123 |
+
|
124 |
+
def forward(
|
125 |
+
self,
|
126 |
+
sample: torch.FloatTensor,
|
127 |
+
timestep: Union[torch.Tensor, float, int],
|
128 |
+
encoder_hidden_states: torch.Tensor,
|
129 |
+
added_time_ids: torch.Tensor,
|
130 |
+
return_dict: bool = True,
|
131 |
+
image_controlnet_down_block_res_samples = None,
|
132 |
+
image_controlnet_mid_block_res_sample = None,
|
133 |
+
dino_down_block_res_samples = None,
|
134 |
+
|
135 |
+
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
|
136 |
+
r"""
|
137 |
+
The [`UNetSpatioTemporalConditionModel`] forward method.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
sample (`torch.FloatTensor`):
|
141 |
+
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
|
142 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
143 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
144 |
+
The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
|
145 |
+
added_time_ids: (`torch.FloatTensor`):
|
146 |
+
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
|
147 |
+
embeddings and added to the time embeddings.
|
148 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
149 |
+
Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
|
150 |
+
tuple.
|
151 |
+
Returns:
|
152 |
+
[`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
|
153 |
+
If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
|
154 |
+
a `tuple` is returned where the first element is the sample tensor.
|
155 |
+
"""
|
156 |
+
if not hasattr(self, "custom_gradient_checkpointing"):
|
157 |
+
self.custom_gradient_checkpointing = False
|
158 |
+
|
159 |
+
# 1. time
|
160 |
+
timesteps = timestep
|
161 |
+
if not torch.is_tensor(timesteps):
|
162 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
163 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
164 |
+
is_mps = sample.device.type == "mps"
|
165 |
+
if isinstance(timestep, float):
|
166 |
+
dtype = torch.float32 if is_mps else torch.float64
|
167 |
+
else:
|
168 |
+
dtype = torch.int32 if is_mps else torch.int64
|
169 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
170 |
+
elif len(timesteps.shape) == 0:
|
171 |
+
timesteps = timesteps[None].to(sample.device)
|
172 |
+
|
173 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
174 |
+
batch_size, num_frames = sample.shape[:2]
|
175 |
+
if len(timesteps.shape) == 1:
|
176 |
+
timesteps = timesteps.expand(batch_size)
|
177 |
+
else:
|
178 |
+
timesteps = timesteps.reshape(batch_size * num_frames)
|
179 |
+
t_emb = self.time_proj(timesteps) # (B, C)
|
180 |
+
|
181 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
182 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
183 |
+
# there might be better ways to encapsulate this.
|
184 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
185 |
+
|
186 |
+
emb = self.time_embedding(t_emb) # (B, C)
|
187 |
+
|
188 |
+
time_embeds = self.add_time_proj(added_time_ids.flatten())
|
189 |
+
time_embeds = time_embeds.reshape((batch_size, -1))
|
190 |
+
time_embeds = time_embeds.to(emb.dtype)
|
191 |
+
aug_emb = self.add_embedding(time_embeds)
|
192 |
+
if emb.shape[0] == 1:
|
193 |
+
emb = emb + aug_emb
|
194 |
+
# Repeat the embeddings num_video_frames times
|
195 |
+
# emb: [batch, channels] -> [batch * frames, channels]
|
196 |
+
emb = emb.repeat_interleave(num_frames, dim=0)
|
197 |
+
else:
|
198 |
+
aug_emb = aug_emb.repeat_interleave(num_frames, dim=0)
|
199 |
+
emb = emb + aug_emb
|
200 |
+
|
201 |
+
# Flatten the batch and frames dimensions
|
202 |
+
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
|
203 |
+
sample = sample.flatten(0, 1)
|
204 |
+
|
205 |
+
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
|
206 |
+
# here, our encoder_hidden_states is [batch * frames, 1, channels]
|
207 |
+
|
208 |
+
if not sample.shape[0] == encoder_hidden_states.shape[0]:
|
209 |
+
encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
|
210 |
+
# 2. pre-process
|
211 |
+
sample = self.conv_in(sample)
|
212 |
+
|
213 |
+
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
|
214 |
+
|
215 |
+
if dino_down_block_res_samples is not None:
|
216 |
+
dino_down_block_res_samples = [x for x in dino_down_block_res_samples]
|
217 |
+
sample += dino_down_block_res_samples.pop(0)
|
218 |
+
|
219 |
+
down_block_res_samples = (sample,)
|
220 |
+
for downsample_block in self.down_blocks:
|
221 |
+
if dino_down_block_res_samples is None:
|
222 |
+
if self.custom_gradient_checkpointing:
|
223 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
224 |
+
sample, res_samples = torch.utils.checkpoint.checkpoint(
|
225 |
+
create_custom_forward(downsample_block),
|
226 |
+
sample,
|
227 |
+
emb,
|
228 |
+
encoder_hidden_states,
|
229 |
+
image_only_indicator,
|
230 |
+
**CKPT_KWARGS,
|
231 |
+
)
|
232 |
+
else:
|
233 |
+
sample, res_samples = torch.utils.checkpoint.checkpoint(
|
234 |
+
create_custom_forward(downsample_block),
|
235 |
+
sample,
|
236 |
+
emb,
|
237 |
+
image_only_indicator,
|
238 |
+
**CKPT_KWARGS,
|
239 |
+
)
|
240 |
+
else:
|
241 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
242 |
+
sample, res_samples = downsample_block(
|
243 |
+
hidden_states=sample,
|
244 |
+
temb=emb,
|
245 |
+
encoder_hidden_states=encoder_hidden_states,
|
246 |
+
image_only_indicator=image_only_indicator,
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
sample, res_samples = downsample_block(
|
250 |
+
hidden_states=sample,
|
251 |
+
temb=emb,
|
252 |
+
image_only_indicator=image_only_indicator,
|
253 |
+
)
|
254 |
+
else:
|
255 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
256 |
+
sample, res_samples = self.forward_crossattn_down_block_dino(
|
257 |
+
downsample_block,
|
258 |
+
sample,
|
259 |
+
emb,
|
260 |
+
encoder_hidden_states,
|
261 |
+
image_only_indicator,
|
262 |
+
dino_down_block_res_samples,
|
263 |
+
)
|
264 |
+
else:
|
265 |
+
sample, res_samples = self.forward_down_block_dino(
|
266 |
+
downsample_block,
|
267 |
+
sample,
|
268 |
+
emb,
|
269 |
+
image_only_indicator,
|
270 |
+
dino_down_block_res_samples,
|
271 |
+
)
|
272 |
+
down_block_res_samples += res_samples
|
273 |
+
|
274 |
+
if image_controlnet_down_block_res_samples is not None:
|
275 |
+
new_down_block_res_samples = ()
|
276 |
+
|
277 |
+
for down_block_res_sample, image_controlnet_down_block_res_sample in zip(
|
278 |
+
down_block_res_samples, image_controlnet_down_block_res_samples
|
279 |
+
):
|
280 |
+
down_block_res_sample = (down_block_res_sample + image_controlnet_down_block_res_sample) / 2
|
281 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
282 |
+
|
283 |
+
down_block_res_samples = new_down_block_res_samples
|
284 |
+
|
285 |
+
# 4. mid
|
286 |
+
if self.custom_gradient_checkpointing:
|
287 |
+
sample = torch.utils.checkpoint.checkpoint(
|
288 |
+
create_custom_forward(self.mid_block),
|
289 |
+
sample,
|
290 |
+
emb,
|
291 |
+
encoder_hidden_states,
|
292 |
+
image_only_indicator,
|
293 |
+
**CKPT_KWARGS,
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
sample = self.mid_block(
|
297 |
+
hidden_states=sample,
|
298 |
+
temb=emb,
|
299 |
+
encoder_hidden_states=encoder_hidden_states,
|
300 |
+
image_only_indicator=image_only_indicator,
|
301 |
+
)
|
302 |
+
|
303 |
+
if image_controlnet_mid_block_res_sample is not None:
|
304 |
+
sample = (sample + image_controlnet_mid_block_res_sample) / 2
|
305 |
+
|
306 |
+
# 5. up
|
307 |
+
mid_up_block_out_samples = [sample, ]
|
308 |
+
down_block_out_sampels = []
|
309 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
310 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
311 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
312 |
+
down_block_out_sampels.append(res_samples[-1])
|
313 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
314 |
+
if self.custom_gradient_checkpointing:
|
315 |
+
sample = torch.utils.checkpoint.checkpoint(
|
316 |
+
create_custom_forward(upsample_block),
|
317 |
+
sample,
|
318 |
+
res_samples,
|
319 |
+
emb,
|
320 |
+
encoder_hidden_states,
|
321 |
+
image_only_indicator,
|
322 |
+
**CKPT_KWARGS
|
323 |
+
)
|
324 |
+
else:
|
325 |
+
sample = upsample_block(
|
326 |
+
hidden_states=sample,
|
327 |
+
temb=emb,
|
328 |
+
res_hidden_states_tuple=res_samples,
|
329 |
+
encoder_hidden_states=encoder_hidden_states,
|
330 |
+
image_only_indicator=image_only_indicator,
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
if self.custom_gradient_checkpointing:
|
334 |
+
sample = torch.utils.checkpoint.checkpoint(
|
335 |
+
create_custom_forward(upsample_block),
|
336 |
+
sample,
|
337 |
+
res_samples,
|
338 |
+
emb,
|
339 |
+
image_only_indicator,
|
340 |
+
**CKPT_KWARGS
|
341 |
+
)
|
342 |
+
else:
|
343 |
+
sample = upsample_block(
|
344 |
+
hidden_states=sample,
|
345 |
+
temb=emb,
|
346 |
+
res_hidden_states_tuple=res_samples,
|
347 |
+
image_only_indicator=image_only_indicator,
|
348 |
+
)
|
349 |
+
mid_up_block_out_samples.append(sample)
|
350 |
+
# 6. post-process
|
351 |
+
sample = self.conv_norm_out(sample)
|
352 |
+
sample = self.conv_act(sample)
|
353 |
+
if self.custom_gradient_checkpointing:
|
354 |
+
sample = torch.utils.checkpoint.checkpoint(
|
355 |
+
create_custom_forward(self.conv_out),
|
356 |
+
sample,
|
357 |
+
**CKPT_KWARGS
|
358 |
+
)
|
359 |
+
else:
|
360 |
+
sample = self.conv_out(sample)
|
361 |
+
|
362 |
+
# 7. Reshape back to original shape
|
363 |
+
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
|
364 |
+
|
365 |
+
if not return_dict:
|
366 |
+
return (sample, down_block_out_sampels[::-1], mid_up_block_out_samples)
|
367 |
+
|
368 |
+
return UNetSpatioTemporalConditionOutput(sample=sample)
|
normalcrafter/utils.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union, List
|
2 |
+
import tempfile
|
3 |
+
import numpy as np
|
4 |
+
import PIL.Image
|
5 |
+
import matplotlib.cm as cm
|
6 |
+
import mediapy
|
7 |
+
import torch
|
8 |
+
from decord import VideoReader, cpu
|
9 |
+
|
10 |
+
|
11 |
+
def read_video_frames(video_path, process_length, target_fps, max_res):
|
12 |
+
print("==> processing video: ", video_path)
|
13 |
+
vid = VideoReader(video_path, ctx=cpu(0))
|
14 |
+
print("==> original video shape: ", (len(vid), *vid.get_batch([0]).shape[1:]))
|
15 |
+
original_height, original_width = vid.get_batch([0]).shape[1:3]
|
16 |
+
|
17 |
+
if max(original_height, original_width) > max_res:
|
18 |
+
scale = max_res / max(original_height, original_width)
|
19 |
+
height = round(original_height * scale)
|
20 |
+
width = round(original_width * scale)
|
21 |
+
else:
|
22 |
+
height = original_height
|
23 |
+
width = original_width
|
24 |
+
|
25 |
+
vid = VideoReader(video_path, ctx=cpu(0), width=width, height=height)
|
26 |
+
|
27 |
+
fps = vid.get_avg_fps() if target_fps == -1 else target_fps
|
28 |
+
stride = round(vid.get_avg_fps() / fps)
|
29 |
+
stride = max(stride, 1)
|
30 |
+
frames_idx = list(range(0, len(vid), stride))
|
31 |
+
print(
|
32 |
+
f"==> downsampled shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}, with stride: {stride}"
|
33 |
+
)
|
34 |
+
if process_length != -1 and process_length < len(frames_idx):
|
35 |
+
frames_idx = frames_idx[:process_length]
|
36 |
+
print(
|
37 |
+
f"==> final processing shape: {len(frames_idx), *vid.get_batch([0]).shape[1:]}"
|
38 |
+
)
|
39 |
+
frames = vid.get_batch(frames_idx).asnumpy().astype(np.uint8)
|
40 |
+
frames = [PIL.Image.fromarray(x) for x in frames]
|
41 |
+
|
42 |
+
return frames, fps
|
43 |
+
|
44 |
+
def save_video(
|
45 |
+
video_frames: Union[List[np.ndarray], List[PIL.Image.Image]],
|
46 |
+
output_video_path: str = None,
|
47 |
+
fps: int = 10,
|
48 |
+
crf: int = 18,
|
49 |
+
) -> str:
|
50 |
+
if output_video_path is None:
|
51 |
+
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
|
52 |
+
|
53 |
+
if isinstance(video_frames[0], np.ndarray):
|
54 |
+
video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
|
55 |
+
|
56 |
+
elif isinstance(video_frames[0], PIL.Image.Image):
|
57 |
+
video_frames = [np.array(frame) for frame in video_frames]
|
58 |
+
mediapy.write_video(output_video_path, video_frames, fps=fps, crf=crf)
|
59 |
+
return output_video_path
|
60 |
+
|
61 |
+
def vis_sequence_normal(normals: np.ndarray):
|
62 |
+
normals = normals.clip(-1., 1.)
|
63 |
+
normals = normals * 0.5 + 0.5
|
64 |
+
return normals
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.0.1
|
2 |
+
diffusers==0.29.1
|
3 |
+
numpy==1.26.4
|
4 |
+
matplotlib==3.8.4
|
5 |
+
transformers==4.41.2
|
6 |
+
accelerate==0.30.1
|
7 |
+
xformers==0.0.20
|
8 |
+
mediapy==1.2.0
|
9 |
+
fire==0.6.0
|
10 |
+
decord==0.6.0
|
11 |
+
OpenEXR==3.2.4
|
run.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from diffusers.training_utils import set_seed
|
7 |
+
from diffusers import AutoencoderKLTemporalDecoder
|
8 |
+
from fire import Fire
|
9 |
+
|
10 |
+
from normalcrafter.normal_crafter_ppl import NormalCrafterPipeline
|
11 |
+
from normalcrafter.unet import DiffusersUNetSpatioTemporalConditionModelNormalCrafter
|
12 |
+
from normalcrafter.utils import vis_sequence_normal, save_video, read_video_frames
|
13 |
+
|
14 |
+
|
15 |
+
class DepthCrafterDemo:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
unet_path: str,
|
19 |
+
pre_train_path: str,
|
20 |
+
cpu_offload: str = "model",
|
21 |
+
):
|
22 |
+
unet = DiffusersUNetSpatioTemporalConditionModelNormalCrafter.from_pretrained(
|
23 |
+
unet_path,
|
24 |
+
subfolder="unet",
|
25 |
+
low_cpu_mem_usage=True,
|
26 |
+
)
|
27 |
+
vae = AutoencoderKLTemporalDecoder.from_pretrained(
|
28 |
+
unet_path, subfolder="vae"
|
29 |
+
)
|
30 |
+
weight_dtype = torch.float16
|
31 |
+
vae.to(dtype=weight_dtype)
|
32 |
+
unet.to(dtype=weight_dtype)
|
33 |
+
# load weights of other components from the provided checkpoint
|
34 |
+
self.pipe = NormalCrafterPipeline.from_pretrained(
|
35 |
+
pre_train_path,
|
36 |
+
unet=unet,
|
37 |
+
vae=vae,
|
38 |
+
torch_dtype=weight_dtype,
|
39 |
+
variant="fp16",
|
40 |
+
)
|
41 |
+
|
42 |
+
# for saving memory, we can offload the model to CPU, or even run the model sequentially to save more memory
|
43 |
+
if cpu_offload is not None:
|
44 |
+
if cpu_offload == "sequential":
|
45 |
+
# This will slow, but save more memory
|
46 |
+
self.pipe.enable_sequential_cpu_offload()
|
47 |
+
elif cpu_offload == "model":
|
48 |
+
self.pipe.enable_model_cpu_offload()
|
49 |
+
else:
|
50 |
+
raise ValueError(f"Unknown cpu offload option: {cpu_offload}")
|
51 |
+
else:
|
52 |
+
self.pipe.to("cuda")
|
53 |
+
# enable attention slicing and xformers memory efficient attention
|
54 |
+
try:
|
55 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
56 |
+
except Exception as e:
|
57 |
+
print(e)
|
58 |
+
print("Xformers is not enabled")
|
59 |
+
# self.pipe.enable_attention_slicing()
|
60 |
+
|
61 |
+
def infer(
|
62 |
+
self,
|
63 |
+
video: str,
|
64 |
+
save_folder: str = "./demo_output",
|
65 |
+
window_size: int = 14,
|
66 |
+
time_step_size: int = 10,
|
67 |
+
process_length: int = 195,
|
68 |
+
decode_chunk_size: int = 7,
|
69 |
+
max_res: int = 1024,
|
70 |
+
dataset: str = "open",
|
71 |
+
target_fps: int = 15,
|
72 |
+
seed: int = 42,
|
73 |
+
save_npz: bool = False,
|
74 |
+
):
|
75 |
+
set_seed(seed)
|
76 |
+
|
77 |
+
frames, target_fps = read_video_frames(
|
78 |
+
video,
|
79 |
+
process_length,
|
80 |
+
target_fps,
|
81 |
+
max_res,
|
82 |
+
)
|
83 |
+
# inference the depth map using the DepthCrafter pipeline
|
84 |
+
with torch.inference_mode():
|
85 |
+
res = self.pipe(
|
86 |
+
frames,
|
87 |
+
decode_chunk_size=decode_chunk_size,
|
88 |
+
time_step_size=time_step_size,
|
89 |
+
window_size=window_size,
|
90 |
+
).frames[0]
|
91 |
+
# visualize the depth map and save the results
|
92 |
+
vis = vis_sequence_normal(res)
|
93 |
+
# save the depth map and visualization with the target FPS
|
94 |
+
save_path = os.path.join(
|
95 |
+
save_folder, os.path.splitext(os.path.basename(video))[0]
|
96 |
+
)
|
97 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
98 |
+
save_video(vis, save_path + "_vis.mp4", fps=target_fps)
|
99 |
+
save_video(frames, save_path + "_input.mp4", fps=target_fps)
|
100 |
+
if save_npz:
|
101 |
+
np.savez_compressed(save_path + ".npz", depth=res)
|
102 |
+
|
103 |
+
return [
|
104 |
+
save_path + "_input.mp4",
|
105 |
+
save_path + "_vis.mp4",
|
106 |
+
]
|
107 |
+
|
108 |
+
def run(
|
109 |
+
self,
|
110 |
+
input_video,
|
111 |
+
num_denoising_steps,
|
112 |
+
guidance_scale,
|
113 |
+
max_res=1024,
|
114 |
+
process_length=195,
|
115 |
+
):
|
116 |
+
res_path = self.infer(
|
117 |
+
input_video,
|
118 |
+
num_denoising_steps,
|
119 |
+
guidance_scale,
|
120 |
+
max_res=max_res,
|
121 |
+
process_length=process_length,
|
122 |
+
)
|
123 |
+
# clear the cache for the next video
|
124 |
+
gc.collect()
|
125 |
+
torch.cuda.empty_cache()
|
126 |
+
return res_path[:2]
|
127 |
+
|
128 |
+
|
129 |
+
def main(
|
130 |
+
video_path: str,
|
131 |
+
save_folder: str = "./demo_output",
|
132 |
+
unet_path: str = "Yanrui95/NormalCrafter",
|
133 |
+
pre_train_path: str = "stabilityai/stable-video-diffusion-img2vid-xt",
|
134 |
+
process_length: int = -1,
|
135 |
+
cpu_offload: str = "model",
|
136 |
+
target_fps: int = -1,
|
137 |
+
seed: int = 42,
|
138 |
+
window_size: int = 14,
|
139 |
+
time_step_size: int = 10,
|
140 |
+
max_res: int = 1024,
|
141 |
+
dataset: str = "open",
|
142 |
+
save_npz: bool = False
|
143 |
+
):
|
144 |
+
depthcrafter_demo = DepthCrafterDemo(
|
145 |
+
unet_path=unet_path,
|
146 |
+
pre_train_path=pre_train_path,
|
147 |
+
cpu_offload=cpu_offload,
|
148 |
+
)
|
149 |
+
# process the videos, the video paths are separated by comma
|
150 |
+
video_paths = video_path.split(",")
|
151 |
+
for video in video_paths:
|
152 |
+
depthcrafter_demo.infer(
|
153 |
+
video,
|
154 |
+
save_folder=save_folder,
|
155 |
+
window_size=window_size,
|
156 |
+
process_length=process_length,
|
157 |
+
time_step_size=time_step_size,
|
158 |
+
max_res=max_res,
|
159 |
+
dataset=dataset,
|
160 |
+
target_fps=target_fps,
|
161 |
+
seed=seed,
|
162 |
+
save_npz=save_npz,
|
163 |
+
)
|
164 |
+
# clear the cache for the next video
|
165 |
+
gc.collect()
|
166 |
+
torch.cuda.empty_cache()
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == "__main__":
|
170 |
+
# running configs
|
171 |
+
# the most important arguments for memory saving are `cpu_offload`, `enable_xformers`, `max_res`, and `window_size`
|
172 |
+
# the most important arguments for trade-off between quality and speed are
|
173 |
+
# `num_inference_steps`, `guidance_scale`, and `max_res`
|
174 |
+
Fire(main)
|