Upload 25 files
Browse files- .gitattributes +1 -0
- LICENSE +21 -0
- README.md +101 -14
- ball2envmap.py +152 -0
- environment.yml +8 -0
- example/bed.png +3 -0
- exposure2hdr.py +139 -0
- inpaint.py +363 -0
- models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/optimizer.bin +3 -0
- models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/pytorch_lora_weights.safetensors +3 -0
- models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/random_states_0.pkl +3 -0
- models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/scheduler.bin +3 -0
- relighting/argument.py +43 -0
- relighting/ball_processor.py +60 -0
- relighting/dataset.py +412 -0
- relighting/dist_utils.py +154 -0
- relighting/image_processor.py +141 -0
- relighting/inpainter.py +424 -0
- relighting/mask_utils.py +124 -0
- relighting/pipeline.py +344 -0
- relighting/pipeline_inpaintonly.py +613 -0
- relighting/pipeline_utils.py +185 -0
- relighting/pipeline_xl.py +482 -0
- relighting/tonemapper.py +33 -0
- relighting/utils.py +56 -0
- requirements.txt +25 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
example/bed.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Pakkapon Phongthawee
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,14 +1,101 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiffusionLight: Light Probes for Free by Painting a Chrome Ball
|
2 |
+
|
3 |
+
### [Project Page](https://diffusionlight.github.io/) | [Paper](https://arxiv.org/abs/2312.09168) | [Colab](https://colab.research.google.com/drive/15pC4qb9mEtRYsW3utXkk-jnaeVxUy-0S?usp=sharing&sandboxMode=true) | [HuggingFace](https://huggingface.co/DiffusionLight/DiffusionLight)
|
4 |
+
|
5 |
+
[](https://colab.research.google.com/drive/15pC4qb9mEtRYsW3utXkk-jnaeVxUy-0S?usp=sharing&sandboxMode=true)
|
6 |
+
|
7 |
+

|
8 |
+
|
9 |
+
We present a simple yet effective technique to estimate lighting in a single input image. Current techniques rely heavily on HDR panorama datasets to train neural networks to regress an input with limited field-of-view to a full environment map. However, these approaches often struggle with real-world, uncontrolled settings due to the limited diversity and size of their datasets. To address this problem, we leverage diffusion models trained on billions of standard images to render a chrome ball into the input image. Despite its simplicity, this task remains challenging: the diffusion models often insert incorrect or inconsistent objects and cannot readily generate images in HDR format. Our research uncovers a surprising relationship between the appearance of chrome balls and the initial diffusion noise map, which we utilize to consistently generate high-quality chrome balls. We further fine-tune an LDR diffusion model (Stable Diffusion XL) with LoRA, enabling it to perform exposure bracketing for HDR light estimation. Our method produces convincing light estimates across diverse settings and demonstrates superior generalization to in-the-wild scenarios.
|
10 |
+
|
11 |
+
## Table of contents
|
12 |
+
-----
|
13 |
+
* [TL;DR](#Getting-started)
|
14 |
+
* [Installation](#Installation)
|
15 |
+
* [Prediction](#Prediction)
|
16 |
+
* [Evaluation](#Evaluation)
|
17 |
+
* [Citation](#Citation)
|
18 |
+
------
|
19 |
+
|
20 |
+
## Getting started
|
21 |
+
|
22 |
+
```shell
|
23 |
+
conda env create -f environment.yml
|
24 |
+
conda activate diffusionlight
|
25 |
+
pip install -r requirements.txt
|
26 |
+
python inpaint.py --dataset example --output_dir output
|
27 |
+
python ball2envmap.py --ball_dir output/square --envmap_dir output/envmap
|
28 |
+
python exposure2hdr.py --input_dir output/envmap --output_dir output/hdr
|
29 |
+
```
|
30 |
+
|
31 |
+
## Installation
|
32 |
+
|
33 |
+
To setup the Python environment, you need to run the following commands in both Conda and pip:
|
34 |
+
|
35 |
+
```shell
|
36 |
+
conda env create -f environment.yml
|
37 |
+
conda activate diffusionlight
|
38 |
+
pip install -r requirements.txt
|
39 |
+
```
|
40 |
+
|
41 |
+
Note that Conda is optional. However, if you choose not to use Conda, you must manually install CUDA-toolkit and OpenEXR.
|
42 |
+
|
43 |
+
## Prediction
|
44 |
+
|
45 |
+
### 0. Preparing the image
|
46 |
+
|
47 |
+
Please resize the input image to 1024x1024. If the image is not square, we recommend padding it with a black border.
|
48 |
+
|
49 |
+
### 1. Inpainting the chrome ball
|
50 |
+
|
51 |
+
First, we predict the chrome ball in different exposure values (EV) using the following command:
|
52 |
+
|
53 |
+
```shell
|
54 |
+
python inpaint.py --dataset <input_directory> --output_dir <output_directory>
|
55 |
+
```
|
56 |
+
|
57 |
+
This command outputs three subdirectories: `control`, `raw`, and `square`
|
58 |
+
|
59 |
+
The contents of each directory are:
|
60 |
+
|
61 |
+
- `control`: Conditioned depth map
|
62 |
+
- `raw`: Inpainted image with a chrome ball in the center
|
63 |
+
- `square`: Square-cropped chrome ball (used for the next step)
|
64 |
+
|
65 |
+
|
66 |
+
### 2. Projecting a ball into an environment map
|
67 |
+
|
68 |
+
Next, we project the chrome ball from the previous step to the LDR environment map using the following command:
|
69 |
+
|
70 |
+
```shell
|
71 |
+
python ball2envmap.py --ball_dir <output_directory>/square --envmap_dir <output_directory>/envmap
|
72 |
+
```
|
73 |
+
|
74 |
+
### 3. Compose HDR image
|
75 |
+
|
76 |
+
Finally, we compose an HDR image from multiple LDR environment maps using our custom exposure bracketing:
|
77 |
+
|
78 |
+
```shell
|
79 |
+
python exposure2hdr.py --input_dir <output_directory>/envmap --output_dir <output_directory>/hdr
|
80 |
+
```
|
81 |
+
|
82 |
+
The predicted light estimation will be located at `<output_directory>/hdr` and can be used for downstream tasks such as object insertion. We will also use it to compare with other methods.
|
83 |
+
|
84 |
+
## Evaluation
|
85 |
+
We use the evaluation code from [StyleLight](https://style-light.github.io/) and [Editable Indoor LightEstimation](https://lvsn.github.io/EditableIndoorLight/). You can use their code to measure our score.
|
86 |
+
|
87 |
+
Additionally, we provide a *slightly* modified version of the evaluation code at [DiffusionLight-evaluation](https://github.com/DiffusionLight/DiffusionLight-evaluation) including the test input.
|
88 |
+
|
89 |
+
## Citation
|
90 |
+
|
91 |
+
```
|
92 |
+
@inproceedings{Phongthawee2023DiffusionLight,
|
93 |
+
author = {Phongthawee, Pakkapon and Chinchuthakun, Worameth and Sinsunthithet, Nontaphat and Raj, Amit and Jampani, Varun and Khungurn, Pramook and Suwajanakorn, Supasorn},
|
94 |
+
title = {DiffusionLight: Light Probes for Free by Painting a Chrome Ball},
|
95 |
+
booktitle = {ArXiv},
|
96 |
+
year = {2023},
|
97 |
+
}
|
98 |
+
```
|
99 |
+
|
100 |
+
## Visit us 🦉
|
101 |
+
[](https://vistec.ist/vision) [](https://vistec.ist/)
|
ball2envmap.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# convert the ball to environment map, lat, long format
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from PIL import Image
|
5 |
+
import skimage
|
6 |
+
import time
|
7 |
+
import torch
|
8 |
+
import argparse
|
9 |
+
from multiprocessing import Pool
|
10 |
+
from functools import partial
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
import os
|
13 |
+
|
14 |
+
try:
|
15 |
+
import ezexr
|
16 |
+
except:
|
17 |
+
pass
|
18 |
+
|
19 |
+
def create_argparser():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--ball_dir", type=str, required=True ,help='directory that contain the image')
|
22 |
+
parser.add_argument("--envmap_dir", type=str, required=True ,help='directory to output environment map') #dataset name or directory
|
23 |
+
parser.add_argument("--envmap_height", type=int, default=256, help="size of the environment map height in pixel (height)")
|
24 |
+
parser.add_argument("--scale", type=int, default=4, help="scale factor")
|
25 |
+
parser.add_argument("--threads", type=int, default=8, help="num thread for pararell processing")
|
26 |
+
return parser
|
27 |
+
|
28 |
+
def create_envmap_grid(size: int):
|
29 |
+
"""
|
30 |
+
BLENDER CONVENSION
|
31 |
+
Create the grid of environment map that contain the position in sperical coordinate
|
32 |
+
Top left is (0,0) and bottom right is (pi/2, 2pi)
|
33 |
+
"""
|
34 |
+
|
35 |
+
theta = torch.linspace(0, np.pi * 2, size * 2)
|
36 |
+
phi = torch.linspace(0, np.pi, size)
|
37 |
+
|
38 |
+
#use indexing 'xy' torch match vision's homework 3
|
39 |
+
theta, phi = torch.meshgrid(theta, phi ,indexing='xy')
|
40 |
+
|
41 |
+
|
42 |
+
theta_phi = torch.cat([theta[..., None], phi[..., None]], dim=-1)
|
43 |
+
theta_phi = theta_phi.numpy()
|
44 |
+
return theta_phi
|
45 |
+
|
46 |
+
def get_normal_vector(incoming_vector: np.ndarray, reflect_vector: np.ndarray):
|
47 |
+
"""
|
48 |
+
BLENDER CONVENSION
|
49 |
+
incoming_vector: the vector from the point to the camera
|
50 |
+
reflect_vector: the vector from the point to the light source
|
51 |
+
"""
|
52 |
+
#N = 2(R ⋅ I)R - I
|
53 |
+
N = (incoming_vector + reflect_vector) / np.linalg.norm(incoming_vector + reflect_vector, axis=-1, keepdims=True)
|
54 |
+
return N
|
55 |
+
|
56 |
+
def get_cartesian_from_spherical(theta: np.array, phi: np.array, r = 1.0):
|
57 |
+
"""
|
58 |
+
BLENDER CONVENSION
|
59 |
+
theta: vertical angle
|
60 |
+
phi: horizontal angle
|
61 |
+
r: radius
|
62 |
+
"""
|
63 |
+
x = r * np.sin(theta) * np.cos(phi)
|
64 |
+
y = r * np.sin(theta) * np.sin(phi)
|
65 |
+
z = r * np.cos(theta)
|
66 |
+
return np.concatenate([x[...,None],y[...,None],z[...,None]], axis=-1)
|
67 |
+
|
68 |
+
|
69 |
+
def process_image(args: argparse.Namespace, file_name: str):
|
70 |
+
I = np.array([1,0, 0])
|
71 |
+
|
72 |
+
# check if exist, skip!
|
73 |
+
envmap_output_path = os.path.join(args.envmap_dir, file_name)
|
74 |
+
if os.path.exists(envmap_output_path):
|
75 |
+
return None
|
76 |
+
|
77 |
+
# read ball image
|
78 |
+
ball_path = os.path.join(args.ball_dir, file_name)
|
79 |
+
if file_name.endswith(".exr"):
|
80 |
+
ball_image = ezexr.imread(ball_path)
|
81 |
+
else:
|
82 |
+
try:
|
83 |
+
ball_image = skimage.io.imread(ball_path)
|
84 |
+
ball_image = skimage.img_as_float(ball_image)
|
85 |
+
except:
|
86 |
+
return None
|
87 |
+
|
88 |
+
# compute normal map that create from reflect vector
|
89 |
+
env_grid = create_envmap_grid(args.envmap_height * args.scale)
|
90 |
+
reflect_vec = get_cartesian_from_spherical(env_grid[...,1], env_grid[...,0])
|
91 |
+
normal = get_normal_vector(I[None,None], reflect_vec)
|
92 |
+
|
93 |
+
# turn from normal map to position to lookup [Range: 0,1]
|
94 |
+
pos = (normal + 1.0) / 2
|
95 |
+
pos = 1.0 - pos
|
96 |
+
pos = pos[...,1:]
|
97 |
+
|
98 |
+
env_map = None
|
99 |
+
|
100 |
+
# using pytorch method for bilinear interpolation
|
101 |
+
with torch.no_grad():
|
102 |
+
# convert position to pytorch grid look up
|
103 |
+
grid = torch.from_numpy(pos)[None].float()
|
104 |
+
grid = grid * 2 - 1 # convert to range [-1,1]
|
105 |
+
|
106 |
+
# convert ball to support pytorch
|
107 |
+
ball_image = torch.from_numpy(ball_image[None]).float()
|
108 |
+
ball_image = ball_image.permute(0,3,1,2) # [1,3,H,W]
|
109 |
+
|
110 |
+
env_map = torch.nn.functional.grid_sample(ball_image, grid, mode='bilinear', padding_mode='border', align_corners=True)
|
111 |
+
env_map = env_map[0].permute(1,2,0).numpy()
|
112 |
+
|
113 |
+
env_map_default = skimage.transform.resize(env_map, (args.envmap_height, args.envmap_height*2), anti_aliasing=True)
|
114 |
+
if file_name.endswith(".exr"):
|
115 |
+
ezexr.imwrite(envmap_output_path, env_map_default.astype(np.float32))
|
116 |
+
else:
|
117 |
+
env_map_default = skimage.img_as_ubyte(env_map_default)
|
118 |
+
skimage.io.imsave(envmap_output_path, env_map_default)
|
119 |
+
return None
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
def main():
|
125 |
+
|
126 |
+
# running time measuring
|
127 |
+
start_time = time.time()
|
128 |
+
|
129 |
+
# load arguments
|
130 |
+
args = create_argparser().parse_args()
|
131 |
+
|
132 |
+
# make output directory if not exist
|
133 |
+
os.makedirs(args.envmap_dir, exist_ok=True)
|
134 |
+
|
135 |
+
# get all file in the directory
|
136 |
+
files = sorted(os.listdir(args.ball_dir))
|
137 |
+
|
138 |
+
# create partial function for pararell processing
|
139 |
+
process_func = partial(process_image, args)
|
140 |
+
|
141 |
+
# pararell processing
|
142 |
+
with Pool(args.threads) as p:
|
143 |
+
list(tqdm(p.imap(process_func, files), total=len(files)))
|
144 |
+
|
145 |
+
# print total time
|
146 |
+
print("TOTAL TIME: ", time.time() - start_time)
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
if __name__ == "__main__":
|
151 |
+
main()
|
152 |
+
|
environment.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: diffusionlight
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- python=3.11.6
|
7 |
+
- cudatoolkit=11.8
|
8 |
+
- openexr==3.2.1
|
example/bed.png
ADDED
![]() |
Git LFS Details
|
exposure2hdr.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# covnert exposure bracket to HDR output
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
from functools import partial
|
5 |
+
from multiprocessing import Pool
|
6 |
+
from tqdm import tqdm
|
7 |
+
import numpy as np
|
8 |
+
import skimage
|
9 |
+
import ezexr
|
10 |
+
from relighting.tonemapper import TonemapHDR
|
11 |
+
|
12 |
+
def create_argparser():
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("--input_dir", type=str, required=True, help='directory that contain the image') #dataset name or directory
|
15 |
+
parser.add_argument("--output_dir", type=str, required=True, help='directory that contain the image') #dataset name or directory
|
16 |
+
parser.add_argument("--endwith", type=str, default=".png" ,help='file ending to filter out unwant image')
|
17 |
+
parser.add_argument("--ev_string", type=str, default="_ev" ,help='string that use for search ev value')
|
18 |
+
parser.add_argument("--EV", type=str, default="0, -2.5, -5" ,help='avalible ev value')
|
19 |
+
parser.add_argument("--gamma", default=2.4, help="Gamma value", type=float)
|
20 |
+
parser.add_argument('--preview_output', dest='preview_output', action='store_true')
|
21 |
+
parser.set_defaults(preview_output=False)
|
22 |
+
return parser
|
23 |
+
|
24 |
+
def parse_filename(ev_string, endwith,filename):
|
25 |
+
a = filename.split(ev_string)
|
26 |
+
name = ev_string.join(a[:-1])
|
27 |
+
ev = a[-1].replace(endwith, "")
|
28 |
+
ev = int(ev) / 10
|
29 |
+
return {
|
30 |
+
'name': name,
|
31 |
+
'ev': ev,
|
32 |
+
'filename': filename
|
33 |
+
}
|
34 |
+
|
35 |
+
def process_image(args, info):
|
36 |
+
|
37 |
+
#output directory
|
38 |
+
hdrdir = args.output_dir
|
39 |
+
os.makedirs(hdrdir, exist_ok=True)
|
40 |
+
|
41 |
+
scaler = np.array([0.212671, 0.715160, 0.072169])
|
42 |
+
name = info['name']
|
43 |
+
# ev value for each file
|
44 |
+
evs = [e for e in sorted(info['ev'], reverse = True)]
|
45 |
+
|
46 |
+
# filename
|
47 |
+
files = [info['ev'][e] for e in evs]
|
48 |
+
|
49 |
+
# inital first image
|
50 |
+
image0 = skimage.io.imread(os.path.join(args.input_dir, files[0]))[...,:3]
|
51 |
+
image0 = skimage.img_as_float(image0)
|
52 |
+
image0_linear = np.power(image0, args.gamma)
|
53 |
+
|
54 |
+
# read luminace for every image
|
55 |
+
luminances = []
|
56 |
+
for i in range(len(evs)):
|
57 |
+
# load image
|
58 |
+
path = os.path.join(args.input_dir, files[i])
|
59 |
+
image = skimage.io.imread(path)[...,:3]
|
60 |
+
image = skimage.img_as_float(image)
|
61 |
+
|
62 |
+
# apply gama correction
|
63 |
+
linear_img = np.power(image, args.gamma)
|
64 |
+
|
65 |
+
# convert the brighness
|
66 |
+
linear_img *= 1 / (2 ** evs[i])
|
67 |
+
|
68 |
+
# compute luminace
|
69 |
+
lumi = linear_img @ scaler
|
70 |
+
luminances.append(lumi)
|
71 |
+
|
72 |
+
# start from darkest image
|
73 |
+
out_luminace = luminances[len(evs) - 1]
|
74 |
+
for i in range(len(evs) - 1, 0, -1):
|
75 |
+
# compute mask
|
76 |
+
maxval = 1 / (2 ** evs[i-1])
|
77 |
+
p1 = np.clip((luminances[i-1] - 0.9 * maxval) / (0.1 * maxval), 0, 1)
|
78 |
+
p2 = out_luminace > luminances[i-1]
|
79 |
+
mask = (p1 * p2).astype(np.float32)
|
80 |
+
out_luminace = luminances[i-1] * (1-mask) + out_luminace * mask
|
81 |
+
|
82 |
+
hdr_rgb = image0_linear * (out_luminace / (luminances[0] + 1e-10))[:, :, np.newaxis]
|
83 |
+
|
84 |
+
# tone map for visualization
|
85 |
+
hdr2ldr = TonemapHDR(gamma=args.gamma, percentile=99, max_mapping=0.9)
|
86 |
+
|
87 |
+
|
88 |
+
ldr_rgb, _, _ = hdr2ldr(hdr_rgb)
|
89 |
+
|
90 |
+
ezexr.imwrite(os.path.join(hdrdir, name+".exr"), hdr_rgb.astype(np.float32))
|
91 |
+
if args.preview_output:
|
92 |
+
preview_dir = os.path.join(args.output_dir, "preview")
|
93 |
+
os.makedirs(preview_dir, exist_ok=True)
|
94 |
+
bracket = []
|
95 |
+
for s in 2 ** np.linspace(0, evs[-1], 10): #evs[-1] is -5
|
96 |
+
lumi = np.clip((s * hdr_rgb) ** (1/args.gamma), 0, 1)
|
97 |
+
bracket.append(lumi)
|
98 |
+
bracket = np.concatenate(bracket, axis=1)
|
99 |
+
skimage.io.imsave(os.path.join(preview_dir, name+".png"), skimage.img_as_ubyte(bracket))
|
100 |
+
return None
|
101 |
+
|
102 |
+
def main():
|
103 |
+
# load arguments
|
104 |
+
args = create_argparser().parse_args()
|
105 |
+
|
106 |
+
files = sorted(os.listdir(args.input_dir))
|
107 |
+
|
108 |
+
#filter file out with file ending
|
109 |
+
files = [f for f in files if f.endswith(args.endwith)]
|
110 |
+
evs = [float(e.strip()) for e in args.EV.split(",")]
|
111 |
+
|
112 |
+
# parse into useful data
|
113 |
+
files = [parse_filename(args.ev_string, args.endwith, f) for f in files]
|
114 |
+
|
115 |
+
# filter out unused ev
|
116 |
+
files = [f for f in files if f['ev'] in evs]
|
117 |
+
|
118 |
+
info = {}
|
119 |
+
for f in files:
|
120 |
+
if not f['name'] in info:
|
121 |
+
info[f['name']] = {}
|
122 |
+
info[f['name']][f['ev']] = f['filename']
|
123 |
+
|
124 |
+
infolist = []
|
125 |
+
for k in info:
|
126 |
+
if len(info[k]) != len(evs):
|
127 |
+
print("WARNING: missing ev in ", k)
|
128 |
+
continue
|
129 |
+
# convert to list data
|
130 |
+
infolist.append({'name': k, 'ev': info[k]})
|
131 |
+
|
132 |
+
fn = partial(process_image, args)
|
133 |
+
with Pool(8) as p:
|
134 |
+
r = list(tqdm(p.imap(fn, infolist), total=len(infolist)))
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
if __name__ == "__main__":
|
139 |
+
main()
|
inpaint.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# inpaint the ball on an image
|
2 |
+
# this one is design for general image that does not require special location to place
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import argparse
|
7 |
+
import numpy as np
|
8 |
+
import torch.distributed as dist
|
9 |
+
import os
|
10 |
+
from PIL import Image
|
11 |
+
from tqdm.auto import tqdm
|
12 |
+
import json
|
13 |
+
|
14 |
+
|
15 |
+
from relighting.inpainter import BallInpainter
|
16 |
+
|
17 |
+
from relighting.mask_utils import MaskGenerator
|
18 |
+
from relighting.ball_processor import (
|
19 |
+
get_ideal_normal_ball,
|
20 |
+
crop_ball
|
21 |
+
)
|
22 |
+
from relighting.dataset import GeneralLoader
|
23 |
+
from relighting.utils import name2hash
|
24 |
+
import relighting.dist_utils as dist_util
|
25 |
+
import time
|
26 |
+
|
27 |
+
|
28 |
+
# cross import from inpaint_multi-illum.py
|
29 |
+
from relighting.argument import (
|
30 |
+
SD_MODELS,
|
31 |
+
CONTROLNET_MODELS,
|
32 |
+
VAE_MODELS
|
33 |
+
)
|
34 |
+
|
35 |
+
def create_argparser():
|
36 |
+
parser = argparse.ArgumentParser()
|
37 |
+
parser.add_argument("--dataset", type=str, required=True ,help='directory that contain the image') #dataset name or directory
|
38 |
+
parser.add_argument("--ball_size", type=int, default=256, help="size of the ball in pixel")
|
39 |
+
parser.add_argument("--ball_dilate", type=int, default=20, help="How much pixel to dilate the ball to make a sharper edge")
|
40 |
+
parser.add_argument("--prompt", type=str, default="a perfect mirrored reflective chrome ball sphere")
|
41 |
+
parser.add_argument("--prompt_dark", type=str, default="a perfect black dark mirrored reflective chrome ball sphere")
|
42 |
+
parser.add_argument("--negative_prompt", type=str, default="matte, diffuse, flat, dull")
|
43 |
+
parser.add_argument("--model_option", default="sdxl", help='selecting fancy model option (sd15_old, sd15_new, sd21, sdxl, sdxl_turbo)') # [sd15_old, sd15_new, or sd21]
|
44 |
+
parser.add_argument("--output_dir", required=True, type=str, help="output directory")
|
45 |
+
parser.add_argument("--img_height", type=int, default=1024, help="Dataset Image Height")
|
46 |
+
parser.add_argument("--img_width", type=int, default=1024, help="Dataset Image Width")
|
47 |
+
# some good seed 0, 37, 71, 125, 140, 196, 307, 434, 485, 575 | 9021, 9166, 9560, 9814, but default auto is for fairness
|
48 |
+
parser.add_argument("--seed", default="auto", type=str, help="Seed: right now we use single seed instead to reduce the time, (Auto will use hash file name to generate seed)")
|
49 |
+
parser.add_argument("--denoising_step", default=30, type=int, help="number of denoising step of diffusion model")
|
50 |
+
parser.add_argument("--control_scale", default=0.5, type=float, help="controlnet conditioning scale")
|
51 |
+
parser.add_argument("--guidance_scale", default=5.0, type=float, help="guidance scale (also known as CFG)")
|
52 |
+
|
53 |
+
parser.add_argument('--no_controlnet', dest='use_controlnet', action='store_false', help='by default we using controlnet, we have option to disable to see the different')
|
54 |
+
parser.set_defaults(use_controlnet=True)
|
55 |
+
|
56 |
+
parser.add_argument('--no_force_square', dest='force_square', action='store_false', help='SDXL is trained for square image, we prefered the square input. but you use this option to disable reshape')
|
57 |
+
parser.set_defaults(force_square=True)
|
58 |
+
|
59 |
+
parser.add_argument('--no_random_loader', dest='random_loader', action='store_false', help="by default, we random how dataset load. This make us able to peak into the trend of result without waiting entire dataset. but can disable if prefereed")
|
60 |
+
parser.set_defaults(random_loader=True)
|
61 |
+
|
62 |
+
parser.add_argument('--cpu', dest='is_cpu', action='store_true', help="using CPU inference instead of GPU inference")
|
63 |
+
parser.set_defaults(is_cpu=False)
|
64 |
+
|
65 |
+
parser.add_argument('--offload', dest='offload', action='store_false', help="to enable diffusers cpu offload")
|
66 |
+
parser.set_defaults(offload=False)
|
67 |
+
|
68 |
+
parser.add_argument("--limit_input", default=0, type=int, help="limit number of image to process to n image (0 = no limit), useful for run smallset")
|
69 |
+
|
70 |
+
|
71 |
+
# LoRA stuff
|
72 |
+
parser.add_argument('--no_lora', dest='use_lora', action='store_false', help='by default we using lora, we have option to disable to see the different')
|
73 |
+
parser.set_defaults(use_lora=True)
|
74 |
+
|
75 |
+
parser.add_argument("--lora_path", default="models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500", type=str, help="LoRA Checkpoint path")
|
76 |
+
parser.add_argument("--lora_scale", default=0.75, type=float, help="LoRA scale factor")
|
77 |
+
|
78 |
+
# speed optimization stuff
|
79 |
+
parser.add_argument('--no_torch_compile', dest='use_torch_compile', action='store_false', help='by default we using torch compile for faster processing speed. disable it if your environemnt is lower than pytorch2.0')
|
80 |
+
parser.set_defaults(use_torch_compile=True)
|
81 |
+
|
82 |
+
# algorithm + iterative stuff
|
83 |
+
parser.add_argument("--algorithm", type=str, default="iterative", choices=["iterative", "normal"], help="Selecting between iterative or normal (single pass inpaint) algorithm")
|
84 |
+
|
85 |
+
parser.add_argument("--agg_mode", default="median", type=str)
|
86 |
+
parser.add_argument("--strength", default=0.8, type=float)
|
87 |
+
parser.add_argument("--num_iteration", default=2, type=int)
|
88 |
+
parser.add_argument("--ball_per_iteration", default=30, type=int)
|
89 |
+
parser.add_argument('--no_save_intermediate', dest='save_intermediate', action='store_false')
|
90 |
+
parser.set_defaults(save_intermediate=True)
|
91 |
+
parser.add_argument("--cache_dir", default="./temp_inpaint_iterative", type=str, help="cache directory for iterative inpaint")
|
92 |
+
|
93 |
+
# pararelle processing
|
94 |
+
parser.add_argument("--idx", default=0, type=int, help="index of the current process, useful for running on multiple node")
|
95 |
+
parser.add_argument("--total", default=1, type=int, help="total number of process")
|
96 |
+
|
97 |
+
# for HDR stuff
|
98 |
+
parser.add_argument("--max_negative_ev", default=-5, type=int, help="maximum negative EV for lora")
|
99 |
+
parser.add_argument("--ev", default="0,-2.5,-5", type=str, help="EV: list of EV to generate")
|
100 |
+
|
101 |
+
return parser
|
102 |
+
|
103 |
+
def get_ball_location(image_data, args):
|
104 |
+
if 'boundary' in image_data:
|
105 |
+
# support predefined boundary if need
|
106 |
+
x = image_data["boundary"]["x"]
|
107 |
+
y = image_data["boundary"]["y"]
|
108 |
+
r = image_data["boundary"]["size"]
|
109 |
+
|
110 |
+
# support ball dilation
|
111 |
+
half_dilate = args.ball_dilate // 2
|
112 |
+
|
113 |
+
# check if not left out-of-bound
|
114 |
+
if x - half_dilate < 0: x += half_dilate
|
115 |
+
if y - half_dilate < 0: y += half_dilate
|
116 |
+
|
117 |
+
# check if not right out-of-bound
|
118 |
+
if x + r + half_dilate > args.img_width: x -= half_dilate
|
119 |
+
if y + r + half_dilate > args.img_height: y -= half_dilate
|
120 |
+
|
121 |
+
else:
|
122 |
+
# we use top-left corner notation
|
123 |
+
x, y, r = ((args.img_width // 2) - (args.ball_size // 2), (args.img_height // 2) - (args.ball_size // 2), args.ball_size)
|
124 |
+
return x, y, r
|
125 |
+
|
126 |
+
def interpolate_embedding(pipe, args):
|
127 |
+
print("interpolate embedding...")
|
128 |
+
|
129 |
+
# get list of all EVs
|
130 |
+
ev_list = [float(x) for x in args.ev.split(",")]
|
131 |
+
interpolants = [ev / args.max_negative_ev for ev in ev_list]
|
132 |
+
|
133 |
+
print("EV : ", ev_list)
|
134 |
+
print("EV : ", interpolants)
|
135 |
+
|
136 |
+
# calculate prompt embeddings
|
137 |
+
prompt_normal = args.prompt
|
138 |
+
prompt_dark = args.prompt_dark
|
139 |
+
prompt_embeds_normal, _, pooled_prompt_embeds_normal, _ = pipe.pipeline.encode_prompt(prompt_normal)
|
140 |
+
prompt_embeds_dark, _, pooled_prompt_embeds_dark, _ = pipe.pipeline.encode_prompt(prompt_dark)
|
141 |
+
|
142 |
+
# interpolate embeddings
|
143 |
+
interpolate_embeds = []
|
144 |
+
for t in interpolants:
|
145 |
+
int_prompt_embeds = prompt_embeds_normal + t * (prompt_embeds_dark - prompt_embeds_normal)
|
146 |
+
int_pooled_prompt_embeds = pooled_prompt_embeds_normal + t * (pooled_prompt_embeds_dark - pooled_prompt_embeds_normal)
|
147 |
+
|
148 |
+
interpolate_embeds.append((int_prompt_embeds, int_pooled_prompt_embeds))
|
149 |
+
|
150 |
+
return dict(zip(ev_list, interpolate_embeds))
|
151 |
+
|
152 |
+
def main():
|
153 |
+
# load arguments
|
154 |
+
args = create_argparser().parse_args()
|
155 |
+
|
156 |
+
# get local rank
|
157 |
+
if args.is_cpu:
|
158 |
+
device = torch.device("cpu")
|
159 |
+
torch_dtype = torch.float32
|
160 |
+
else:
|
161 |
+
device = dist_util.dev()
|
162 |
+
torch_dtype = torch.float16
|
163 |
+
|
164 |
+
# so, we need ball_dilate >= 16 (2*vae_scale_factor) to make our mask shape = (272, 272)
|
165 |
+
assert args.ball_dilate % 2 == 0 # ball dilation should be symmetric
|
166 |
+
|
167 |
+
# create controlnet pipeline
|
168 |
+
if args.model_option in ["sdxl", "sdxl_fast", "sdxl_turbo"] and args.use_controlnet:
|
169 |
+
model, controlnet = SD_MODELS[args.model_option], CONTROLNET_MODELS[args.model_option]
|
170 |
+
pipe = BallInpainter.from_sdxl(
|
171 |
+
model=model,
|
172 |
+
controlnet=controlnet,
|
173 |
+
device=device,
|
174 |
+
torch_dtype = torch_dtype,
|
175 |
+
offload = args.offload
|
176 |
+
)
|
177 |
+
elif args.model_option in ["sdxl", "sdxl_fast", "sdxl_turbo"] and not args.use_controlnet:
|
178 |
+
model = SD_MODELS[args.model_option]
|
179 |
+
pipe = BallInpainter.from_sdxl(
|
180 |
+
model=model,
|
181 |
+
controlnet=None,
|
182 |
+
device=device,
|
183 |
+
torch_dtype = torch_dtype,
|
184 |
+
offload = args.offload
|
185 |
+
)
|
186 |
+
elif args.use_controlnet:
|
187 |
+
model, controlnet = SD_MODELS[args.model_option], CONTROLNET_MODELS[args.model_option]
|
188 |
+
pipe = BallInpainter.from_sd(
|
189 |
+
model=model,
|
190 |
+
controlnet=controlnet,
|
191 |
+
device=device,
|
192 |
+
torch_dtype = torch_dtype,
|
193 |
+
offload = args.offload
|
194 |
+
)
|
195 |
+
else:
|
196 |
+
model = SD_MODELS[args.model_option]
|
197 |
+
pipe = BallInpainter.from_sd(
|
198 |
+
model=model,
|
199 |
+
controlnet=None,
|
200 |
+
device=device,
|
201 |
+
torch_dtype = torch_dtype,
|
202 |
+
offload = args.offload
|
203 |
+
)
|
204 |
+
|
205 |
+
if args.model_option in ["sdxl_turbo"]:
|
206 |
+
# Guidance scale is not supported in sdxl_turbo
|
207 |
+
args.guidance_scale = 0.0
|
208 |
+
|
209 |
+
if args.lora_scale > 0 and args.lora_path is None:
|
210 |
+
raise ValueError("lora scale is not 0 but lora path is not set")
|
211 |
+
|
212 |
+
if (args.lora_path is not None) and (args.use_lora):
|
213 |
+
print(f"using lora path {args.lora_path}")
|
214 |
+
print(f"using lora scale {args.lora_scale}")
|
215 |
+
pipe.pipeline.load_lora_weights(args.lora_path)
|
216 |
+
pipe.pipeline.fuse_lora(lora_scale=args.lora_scale) # fuse lora weight w' = w + \alpha \Delta w
|
217 |
+
enabled_lora = True
|
218 |
+
else:
|
219 |
+
enabled_lora = False
|
220 |
+
|
221 |
+
if args.use_torch_compile:
|
222 |
+
try:
|
223 |
+
print("compiling unet model")
|
224 |
+
start_time = time.time()
|
225 |
+
pipe.pipeline.unet = torch.compile(pipe.pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
226 |
+
print("Model compilation time: ", time.time() - start_time)
|
227 |
+
except:
|
228 |
+
pass
|
229 |
+
|
230 |
+
# default height for sdxl is 1024, if not set, we set default height.
|
231 |
+
if args.model_option == "sdxl" and args.img_height == 0 and args.img_width == 0:
|
232 |
+
args.img_height = 1024
|
233 |
+
args.img_width = 1024
|
234 |
+
|
235 |
+
# load dataset
|
236 |
+
dataset = GeneralLoader(
|
237 |
+
root=args.dataset,
|
238 |
+
resolution=(args.img_width, args.img_height),
|
239 |
+
force_square=args.force_square,
|
240 |
+
return_dict=True,
|
241 |
+
random_shuffle=args.random_loader,
|
242 |
+
process_id=args.idx,
|
243 |
+
process_total=args.total,
|
244 |
+
limit_input=args.limit_input,
|
245 |
+
)
|
246 |
+
|
247 |
+
# interpolate embedding
|
248 |
+
embedding_dict = interpolate_embedding(pipe, args)
|
249 |
+
|
250 |
+
# prepare mask and normal ball
|
251 |
+
mask_generator = MaskGenerator()
|
252 |
+
normal_ball, mask_ball = get_ideal_normal_ball(size=args.ball_size+args.ball_dilate)
|
253 |
+
_, mask_ball_for_crop = get_ideal_normal_ball(size=args.ball_size)
|
254 |
+
|
255 |
+
|
256 |
+
# make output directory if not exist
|
257 |
+
raw_output_dir = os.path.join(args.output_dir, "raw")
|
258 |
+
control_output_dir = os.path.join(args.output_dir, "control")
|
259 |
+
square_output_dir = os.path.join(args.output_dir, "square")
|
260 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
261 |
+
os.makedirs(raw_output_dir, exist_ok=True)
|
262 |
+
os.makedirs(control_output_dir, exist_ok=True)
|
263 |
+
os.makedirs(square_output_dir, exist_ok=True)
|
264 |
+
|
265 |
+
# create split seed
|
266 |
+
# please DO NOT manual replace this line, use --seed option instead
|
267 |
+
seeds = args.seed.split(",")
|
268 |
+
|
269 |
+
for image_data in tqdm(dataset):
|
270 |
+
input_image = image_data["image"]
|
271 |
+
image_path = image_data["path"]
|
272 |
+
|
273 |
+
for ev, (prompt_embeds, pooled_prompt_embeds) in embedding_dict.items():
|
274 |
+
# create output file name (we always use png to prevent quality loss)
|
275 |
+
ev_str = str(ev).replace(".", "") if ev != 0 else "-00"
|
276 |
+
outname = os.path.basename(image_path).split(".")[0] + f"_ev{ev_str}"
|
277 |
+
|
278 |
+
# we use top-left corner notation (which is different from aj.aek's center point notation)
|
279 |
+
x, y, r = get_ball_location(image_data, args)
|
280 |
+
|
281 |
+
# create inpaint mask
|
282 |
+
mask = mask_generator.generate_single(
|
283 |
+
input_image, mask_ball,
|
284 |
+
x - (args.ball_dilate // 2),
|
285 |
+
y - (args.ball_dilate // 2),
|
286 |
+
r + args.ball_dilate
|
287 |
+
)
|
288 |
+
|
289 |
+
seeds = tqdm(seeds, desc="seeds") if len(seeds) > 10 else seeds
|
290 |
+
|
291 |
+
#replacely create image with differnt seed
|
292 |
+
for seed in seeds:
|
293 |
+
start_time = time.time()
|
294 |
+
# set seed, if seed auto we use file name as seed
|
295 |
+
if seed == "auto":
|
296 |
+
filename = os.path.basename(image_path).split(".")[0]
|
297 |
+
seed = name2hash(filename)
|
298 |
+
outpng = f"{outname}.png"
|
299 |
+
cache_name = f"{outname}"
|
300 |
+
else:
|
301 |
+
seed = int(seed)
|
302 |
+
outpng = f"{outname}_seed{seed}.png"
|
303 |
+
cache_name = f"{outname}_seed{seed}"
|
304 |
+
# skip if file exist, useful for resuming
|
305 |
+
if os.path.exists(os.path.join(square_output_dir, outpng)):
|
306 |
+
continue
|
307 |
+
generator = torch.Generator().manual_seed(seed)
|
308 |
+
kwargs = {
|
309 |
+
"prompt_embeds": prompt_embeds,
|
310 |
+
"pooled_prompt_embeds": pooled_prompt_embeds,
|
311 |
+
'negative_prompt': args.negative_prompt,
|
312 |
+
'num_inference_steps': args.denoising_step,
|
313 |
+
'generator': generator,
|
314 |
+
'image': input_image,
|
315 |
+
'mask_image': mask,
|
316 |
+
'strength': 1.0,
|
317 |
+
'current_seed': seed, # we still need seed in the pipeline!
|
318 |
+
'controlnet_conditioning_scale': args.control_scale,
|
319 |
+
'height': args.img_height,
|
320 |
+
'width': args.img_width,
|
321 |
+
'normal_ball': normal_ball,
|
322 |
+
'mask_ball': mask_ball,
|
323 |
+
'x': x,
|
324 |
+
'y': y,
|
325 |
+
'r': r,
|
326 |
+
'guidance_scale': args.guidance_scale,
|
327 |
+
}
|
328 |
+
|
329 |
+
if enabled_lora:
|
330 |
+
kwargs["cross_attention_kwargs"] = {"scale": args.lora_scale}
|
331 |
+
|
332 |
+
if args.algorithm == "normal":
|
333 |
+
output_image = pipe.inpaint(**kwargs).images[0]
|
334 |
+
elif args.algorithm == "iterative":
|
335 |
+
# This is still buggy
|
336 |
+
print("using inpainting iterative, this is going to take a while...")
|
337 |
+
kwargs.update({
|
338 |
+
"strength": args.strength,
|
339 |
+
"num_iteration": args.num_iteration,
|
340 |
+
"ball_per_iteration": args.ball_per_iteration,
|
341 |
+
"agg_mode": args.agg_mode,
|
342 |
+
"save_intermediate": args.save_intermediate,
|
343 |
+
"cache_dir": os.path.join(args.cache_dir, cache_name),
|
344 |
+
})
|
345 |
+
output_image = pipe.inpaint_iterative(**kwargs)
|
346 |
+
else:
|
347 |
+
raise NotImplementedError(f"Unknown algorithm {args.algorithm}")
|
348 |
+
|
349 |
+
|
350 |
+
square_image = output_image.crop((x, y, x+r, y+r))
|
351 |
+
|
352 |
+
# return the most recent control_image for sanity check
|
353 |
+
control_image = pipe.get_cache_control_image()
|
354 |
+
if control_image is not None:
|
355 |
+
control_image.save(os.path.join(control_output_dir, outpng))
|
356 |
+
|
357 |
+
# save image
|
358 |
+
output_image.save(os.path.join(raw_output_dir, outpng))
|
359 |
+
square_image.save(os.path.join(square_output_dir, outpng))
|
360 |
+
|
361 |
+
|
362 |
+
if __name__ == "__main__":
|
363 |
+
main()
|
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/optimizer.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b96eb34accf4ce5f33dff729e12669c46e3132973ee2cf0ac2d4f2c993a2af4d
|
3 |
+
size 47392882
|
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/pytorch_lora_weights.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6bf601e30c08ce4f1eea6e09e1780bd8ba2588986eaee8379672350707dcddaa
|
3 |
+
size 23396024
|
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/random_states_0.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec103b2e80b357d29e2ec8355d49f0c289331f0ece810d5bafd464d33e5f4c76
|
3 |
+
size 14280
|
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/scheduler.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e9244a743d4761f975ab2a14d0b7a509a85554dd77bef6490330160a2a639fae
|
3 |
+
size 1000
|
relighting/argument.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from diffusers import DDIMScheduler, DDPMScheduler, UniPCMultistepScheduler
|
3 |
+
|
4 |
+
def get_control_signal_type(controlnet):
|
5 |
+
if "normal" in controlnet:
|
6 |
+
return "normal"
|
7 |
+
elif "depth" in controlnet:
|
8 |
+
return "depth"
|
9 |
+
else:
|
10 |
+
raise NotImplementedError
|
11 |
+
|
12 |
+
SD_MODELS = {
|
13 |
+
"sd15_old": "runwayml/stable-diffusion-inpainting",
|
14 |
+
"sd15_new": "runwayml/stable-diffusion-inpainting",
|
15 |
+
"sd21": "stabilityai/stable-diffusion-2-inpainting",
|
16 |
+
"sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
|
17 |
+
"sdxl_fast": "stabilityai/stable-diffusion-xl-base-1.0",
|
18 |
+
"sdxl_turbo": "stabilityai/sdxl-turbo",
|
19 |
+
"sd15_depth": "runwayml/stable-diffusion-inpainting",
|
20 |
+
}
|
21 |
+
|
22 |
+
VAE_MODELS = {
|
23 |
+
"sdxl": "madebyollin/sdxl-vae-fp16-fix",
|
24 |
+
"sdxl_fast": "madebyollin/sdxl-vae-fp16-fix",
|
25 |
+
}
|
26 |
+
|
27 |
+
CONTROLNET_MODELS = {
|
28 |
+
"sd15_old": "fusing/stable-diffusion-v1-5-controlnet-normal",
|
29 |
+
"sd15_new": "lllyasviel/control_v11p_sd15_normalbae",
|
30 |
+
"sd21": "thibaud/controlnet-sd21-normalbae-diffusers",
|
31 |
+
"sdxl": "diffusers/controlnet-depth-sdxl-1.0",
|
32 |
+
"sdxl_fast": "diffusers/controlnet-depth-sdxl-1.0-small",
|
33 |
+
"sdxl_turbo": "diffusers/controlnet-depth-sdxl-1.0-small",
|
34 |
+
"sd15_depth": "lllyasviel/control_v11f1p_sd15_depth",
|
35 |
+
}
|
36 |
+
|
37 |
+
SAMPLERS = {
|
38 |
+
"ddim": DDIMScheduler,
|
39 |
+
"ddpm": DDPMScheduler,
|
40 |
+
"unipc": UniPCMultistepScheduler,
|
41 |
+
}
|
42 |
+
|
43 |
+
DEPTH_ESTIMATOR = "Intel/dpt-hybrid-midas"
|
relighting/ball_processor.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
from scipy.special import sph_harm
|
5 |
+
|
6 |
+
def crop_ball(image, mask_ball, x, y, size, apply_mask=True, bg_color = (0, 0, 0)):
|
7 |
+
if isinstance(image, Image.Image):
|
8 |
+
result = np.array(image)
|
9 |
+
else:
|
10 |
+
result = image.copy()
|
11 |
+
|
12 |
+
result = result[y:y+size, x:x+size]
|
13 |
+
if apply_mask:
|
14 |
+
result[~mask_ball] = bg_color
|
15 |
+
return result
|
16 |
+
|
17 |
+
def get_ideal_normal_ball(size, flip_x=True):
|
18 |
+
"""
|
19 |
+
Generate normal ball for specific size
|
20 |
+
Normal map is x "left", y up, z into the screen
|
21 |
+
(we flip X to match sobel operator)
|
22 |
+
@params
|
23 |
+
- size (int) - single value of height and width
|
24 |
+
@return:
|
25 |
+
- normal_map (np.array) - normal map [size, size, 3]
|
26 |
+
- mask (np.array) - mask that make a valid normal map [size,size]
|
27 |
+
"""
|
28 |
+
# we flip x to match sobel operator
|
29 |
+
x = torch.linspace(1, -1, size)
|
30 |
+
y = torch.linspace(1, -1, size)
|
31 |
+
x = x.flip(dims=(-1,)) if not flip_x else x
|
32 |
+
|
33 |
+
y, x = torch.meshgrid(y, x)
|
34 |
+
z = (1 - x**2 - y**2)
|
35 |
+
mask = z >= 0
|
36 |
+
|
37 |
+
# clean up invalid value outsize the mask
|
38 |
+
x = x * mask
|
39 |
+
y = y * mask
|
40 |
+
z = z * mask
|
41 |
+
|
42 |
+
# get real z value
|
43 |
+
z = torch.sqrt(z)
|
44 |
+
|
45 |
+
# clean up normal map value outside mask
|
46 |
+
normal_map = torch.cat([x[..., None], y[..., None], z[..., None]], dim=-1)
|
47 |
+
normal_map = normal_map.numpy()
|
48 |
+
mask = mask.numpy()
|
49 |
+
return normal_map, mask
|
50 |
+
|
51 |
+
def get_predicted_normal_ball(size, precomputed_path=None):
|
52 |
+
if precomputed_path is not None:
|
53 |
+
normal_map = Image.open(precomputed_path).resize((size, size))
|
54 |
+
normal_map = np.array(normal_map).astype(np.uint8)
|
55 |
+
_, mask = get_ideal_normal_ball(size)
|
56 |
+
else:
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
normal_map = (normal_map - 127.5) / 127.5 # normalize for compatibility with inpainting pipeline
|
60 |
+
return normal_map, mask
|
relighting/dataset.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import skimage
|
5 |
+
import numpy as np
|
6 |
+
from pathlib import Path
|
7 |
+
from natsort import natsorted
|
8 |
+
from PIL import Image
|
9 |
+
from relighting.image_processor import pil_square_image
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
import random
|
12 |
+
import itertools
|
13 |
+
from abc import ABC, abstractmethod
|
14 |
+
|
15 |
+
class Dataset(ABC):
|
16 |
+
def __init__(self,
|
17 |
+
resolution=(1024, 1024),
|
18 |
+
force_square=True,
|
19 |
+
return_image_path=False,
|
20 |
+
return_dict=False,
|
21 |
+
):
|
22 |
+
"""
|
23 |
+
Resoution is (WIDTH, HEIGHT)
|
24 |
+
"""
|
25 |
+
self.resolution = resolution
|
26 |
+
self.force_square = force_square
|
27 |
+
self.return_image_path = return_image_path
|
28 |
+
self.return_dict = return_dict
|
29 |
+
self.scene_data = []
|
30 |
+
self.meta_data = []
|
31 |
+
self.boundary_info = []
|
32 |
+
|
33 |
+
@abstractmethod
|
34 |
+
def _load_data_path(self):
|
35 |
+
pass
|
36 |
+
|
37 |
+
def __len__(self):
|
38 |
+
return len(self.scene_data)
|
39 |
+
|
40 |
+
def __getitem__(self, idx):
|
41 |
+
image = Image.open(self.scene_data[idx])
|
42 |
+
if self.force_square:
|
43 |
+
image = pil_square_image(image, self.resolution)
|
44 |
+
else:
|
45 |
+
image = image.resize(self.resolution)
|
46 |
+
|
47 |
+
if self.return_dict:
|
48 |
+
d = {
|
49 |
+
"image": image,
|
50 |
+
"path": self.scene_data[idx]
|
51 |
+
}
|
52 |
+
if len(self.boundary_info) > 0:
|
53 |
+
d["boundary"] = self.boundary_info[idx]
|
54 |
+
|
55 |
+
return d
|
56 |
+
elif self.return_image_path:
|
57 |
+
return image, self.scene_data[idx]
|
58 |
+
else:
|
59 |
+
return image
|
60 |
+
|
61 |
+
class GeneralLoader(Dataset):
|
62 |
+
def __init__(self,
|
63 |
+
root=None,
|
64 |
+
num_samples=None,
|
65 |
+
res_threshold=((1024, 1024)),
|
66 |
+
apply_threshold=False,
|
67 |
+
random_shuffle=False,
|
68 |
+
process_id = 0,
|
69 |
+
process_total = 1,
|
70 |
+
limit_input = 0,
|
71 |
+
**kwargs,
|
72 |
+
):
|
73 |
+
super().__init__(**kwargs)
|
74 |
+
self.root = root
|
75 |
+
self.res_threshold = res_threshold
|
76 |
+
self.apply_threshold = apply_threshold
|
77 |
+
self.has_meta = False
|
78 |
+
|
79 |
+
if self.root is not None:
|
80 |
+
if not os.path.exists(self.root):
|
81 |
+
raise Exception(f"Dataset {self.root} does not exist.")
|
82 |
+
|
83 |
+
paths = natsorted(
|
84 |
+
list(glob.glob(os.path.join(self.root, "*.png"))) + \
|
85 |
+
list(glob.glob(os.path.join(self.root, "*.jpg")))
|
86 |
+
)
|
87 |
+
self.scene_data = self._load_data_path(paths, num_samples=num_samples)
|
88 |
+
|
89 |
+
if random_shuffle:
|
90 |
+
SEED = 0
|
91 |
+
random.Random(SEED).shuffle(self.scene_data)
|
92 |
+
random.Random(SEED).shuffle(self.boundary_info)
|
93 |
+
|
94 |
+
if limit_input > 0:
|
95 |
+
self.scene_data = self.scene_data[:limit_input]
|
96 |
+
self.boundary_info = self.boundary_info[:limit_input]
|
97 |
+
|
98 |
+
# please keep this one the last, so, we will filter out scene_data and boundary info
|
99 |
+
if process_total > 1:
|
100 |
+
self.scene_data = self.scene_data[process_id::process_total]
|
101 |
+
self.boundary_info = self.boundary_info[process_id::process_total]
|
102 |
+
print(f"Process {process_id} has {len(self.scene_data)} samples")
|
103 |
+
|
104 |
+
def _load_data_path(self, paths, num_samples=None):
|
105 |
+
if os.path.exists(os.path.splitext(paths[0])[0] + ".json") or os.path.exists(os.path.splitext(paths[-1])[0] + ".json"):
|
106 |
+
self.has_meta = True
|
107 |
+
|
108 |
+
if self.has_meta:
|
109 |
+
# read metadata
|
110 |
+
TARGET_KEY = "chrome_mask256"
|
111 |
+
for path in paths:
|
112 |
+
with open(os.path.splitext(path)[0] + ".json") as f:
|
113 |
+
meta = json.load(f)
|
114 |
+
self.meta_data.append(meta)
|
115 |
+
boundary = {
|
116 |
+
"x": meta[TARGET_KEY]["x"],
|
117 |
+
"y": meta[TARGET_KEY]["y"],
|
118 |
+
"size": meta[TARGET_KEY]["w"],
|
119 |
+
}
|
120 |
+
self.boundary_info.append(boundary)
|
121 |
+
|
122 |
+
|
123 |
+
scene_data = paths
|
124 |
+
if self.apply_threshold:
|
125 |
+
scene_data = []
|
126 |
+
for path in tqdm(paths):
|
127 |
+
img = Image.open(path)
|
128 |
+
if (img.size[0] >= self.res_threshold[0]) and (img.size[1] >= self.res_threshold[1]):
|
129 |
+
scene_data.append(path)
|
130 |
+
|
131 |
+
if num_samples is not None:
|
132 |
+
max_idx = min(num_samples, len(scene_data))
|
133 |
+
scene_data = scene_data[:max_idx]
|
134 |
+
|
135 |
+
return scene_data
|
136 |
+
|
137 |
+
@classmethod
|
138 |
+
def from_image_paths(cls, paths, *args, **kwargs):
|
139 |
+
dataset = cls(*args, **kwargs)
|
140 |
+
dataset.scene_data = dataset._load_data_path(paths)
|
141 |
+
return dataset
|
142 |
+
|
143 |
+
class ALPLoader(Dataset):
|
144 |
+
def __init__(self,
|
145 |
+
root=None,
|
146 |
+
num_samples=None,
|
147 |
+
res_threshold=((1024, 1024)),
|
148 |
+
apply_threshold=False,
|
149 |
+
**kwargs,
|
150 |
+
):
|
151 |
+
super().__init__(**kwargs)
|
152 |
+
self.root = root
|
153 |
+
self.res_threshold = res_threshold
|
154 |
+
self.apply_threshold = apply_threshold
|
155 |
+
self.has_meta = False
|
156 |
+
|
157 |
+
if self.root is not None:
|
158 |
+
if not os.path.exists(self.root):
|
159 |
+
raise Exception(f"Dataset {self.root} does not exist.")
|
160 |
+
|
161 |
+
dirs = natsorted(list(glob.glob(os.path.join(self.root, "*"))))
|
162 |
+
self.scene_data = self._load_data_path(dirs)
|
163 |
+
|
164 |
+
def _load_data_path(self, dirs):
|
165 |
+
self.scene_names = [Path(dir).name for dir in dirs]
|
166 |
+
|
167 |
+
scene_data = []
|
168 |
+
for dir in dirs:
|
169 |
+
pseudo_probe_dirs = natsorted(list(glob.glob(os.path.join(dir, "*"))))
|
170 |
+
pseudo_probe_dirs = [dir for dir in pseudo_probe_dirs if "gt" not in dir]
|
171 |
+
data = [os.path.join(dir, "images", "0.png") for dir in pseudo_probe_dirs]
|
172 |
+
scene_data.append(data)
|
173 |
+
|
174 |
+
scene_data = list(itertools.chain(*scene_data))
|
175 |
+
return scene_data
|
176 |
+
|
177 |
+
class MultiIlluminationLoader(Dataset):
|
178 |
+
def __init__(self,
|
179 |
+
root,
|
180 |
+
mask_probe=True,
|
181 |
+
mask_boundingbox=False,
|
182 |
+
**kwargs,
|
183 |
+
):
|
184 |
+
"""
|
185 |
+
@params resolution (tuple): (width, height) - resolution of the image
|
186 |
+
@params force_square: will add black border to make the image square while keeping the aspect ratio
|
187 |
+
@params mask_probe: mask the probe with the mask in the dataset
|
188 |
+
|
189 |
+
"""
|
190 |
+
super().__init__(**kwargs)
|
191 |
+
self.root = root
|
192 |
+
self.mask_probe = mask_probe
|
193 |
+
self.mask_boundingbox = mask_boundingbox
|
194 |
+
|
195 |
+
if self.root is not None:
|
196 |
+
dirs = natsorted(list(glob.glob(os.path.join(self.root, "*"))))
|
197 |
+
self.scene_data = self._load_data_path(dirs)
|
198 |
+
|
199 |
+
def _load_data_path(self, dirs):
|
200 |
+
self.scene_names = [Path(dir).name for dir in dirs]
|
201 |
+
|
202 |
+
data = {}
|
203 |
+
for dir in dirs:
|
204 |
+
chrome_probes = natsorted(list(glob.glob(os.path.join(dir, "probes", "*chrome*.jpg"))))
|
205 |
+
gray_probes = natsorted(list(glob.glob(os.path.join(dir, "probes", "*gray*.jpg"))))
|
206 |
+
scenes = natsorted(list(glob.glob(os.path.join(dir, "dir_*.jpg"))))
|
207 |
+
|
208 |
+
with open(os.path.join(dir, "meta.json")) as f:
|
209 |
+
meta_data = json.load(f)
|
210 |
+
|
211 |
+
bbox_chrome = meta_data["chrome"]["bounding_box"]
|
212 |
+
bbox_gray = meta_data["gray"]["bounding_box"]
|
213 |
+
|
214 |
+
mask_chrome = os.path.join(dir, "mask_chrome.png")
|
215 |
+
mask_gray = os.path.join(dir, "mask_gray.png")
|
216 |
+
|
217 |
+
scene_name = Path(dir).name
|
218 |
+
data[scene_name] = {
|
219 |
+
"scenes": scenes,
|
220 |
+
"chrome_probes": chrome_probes,
|
221 |
+
"gray_probes": gray_probes,
|
222 |
+
"bbox_chrome": bbox_chrome,
|
223 |
+
"bbox_gray": bbox_gray,
|
224 |
+
"mask_chrome": mask_chrome,
|
225 |
+
"mask_gray": mask_gray,
|
226 |
+
}
|
227 |
+
return data
|
228 |
+
|
229 |
+
def _mask_probe(self, image, mask):
|
230 |
+
"""
|
231 |
+
mask probe with a png file in dataset
|
232 |
+
"""
|
233 |
+
image_anticheat = skimage.img_as_float(np.array(image))
|
234 |
+
mask_np = skimage.img_as_float(np.array(mask))[..., None]
|
235 |
+
image_anticheat = ((1.0 - mask_np) * image_anticheat) + (0.5 * mask_np)
|
236 |
+
image_anticheat = Image.fromarray(skimage.img_as_ubyte(image_anticheat))
|
237 |
+
return image_anticheat
|
238 |
+
|
239 |
+
def _mask_boundingbox(self, image, bbox):
|
240 |
+
"""
|
241 |
+
mask image with the bounding box for anti-cheat
|
242 |
+
"""
|
243 |
+
bbox = {k:int(np.round(v/4.0)) for k,v in bbox.items()}
|
244 |
+
x, y, w, h = bbox["x"], bbox["y"], bbox["w"], bbox["h"]
|
245 |
+
image_anticheat = skimage.img_as_float(np.array(image))
|
246 |
+
image_anticheat[y:y+h, x:x+w] = 0.5
|
247 |
+
image_anticheat = Image.fromarray(skimage.img_as_ubyte(image_anticheat))
|
248 |
+
return image_anticheat
|
249 |
+
|
250 |
+
def __getitem__(self, scene_name):
|
251 |
+
data = self.scene_data[scene_name]
|
252 |
+
|
253 |
+
mask_chrome = Image.open(data["mask_chrome"])
|
254 |
+
mask_gray = Image.open(data["mask_gray"])
|
255 |
+
images = []
|
256 |
+
for path in data["scenes"]:
|
257 |
+
image = Image.open(path)
|
258 |
+
if self.mask_probe:
|
259 |
+
image = self._mask_probe(image, mask_chrome)
|
260 |
+
image = self._mask_probe(image, mask_gray)
|
261 |
+
if self.mask_boundingbox:
|
262 |
+
image = self._mask_boundingbox(image, data["bbox_chrome"])
|
263 |
+
image = self._mask_boundingbox(image, data["bbox_gray"])
|
264 |
+
|
265 |
+
if self.force_square:
|
266 |
+
image = pil_square_image(image, self.resolution)
|
267 |
+
else:
|
268 |
+
image = image.resize(self.resolution)
|
269 |
+
images.append(image)
|
270 |
+
|
271 |
+
chrome_probes = [Image.open(path) for path in data["chrome_probes"]]
|
272 |
+
gray_probes = [Image.open(path) for path in data["gray_probes"]]
|
273 |
+
bbox_chrome = data["bbox_chrome"]
|
274 |
+
bbox_gray = data["bbox_gray"]
|
275 |
+
|
276 |
+
return images, chrome_probes, gray_probes, bbox_chrome, bbox_gray
|
277 |
+
|
278 |
+
|
279 |
+
def calculate_ball_info(self, scene_name):
|
280 |
+
# TODO: remove hard-coded parameters
|
281 |
+
ball_data = []
|
282 |
+
for mtype in ['bbox_chrome', 'bbox_gray']:
|
283 |
+
info = self.scene_data[scene_name][mtype]
|
284 |
+
|
285 |
+
# x-y is top-left corner of the bounding box
|
286 |
+
# meta file is for 4000x6000 image but dataset is 1000x1500
|
287 |
+
x = info['x'] / 4
|
288 |
+
y = info['y'] / 4
|
289 |
+
w = info['w'] / 4
|
290 |
+
h = info['h'] / 4
|
291 |
+
|
292 |
+
|
293 |
+
# we scale data to 512x512 image
|
294 |
+
if self.force_square:
|
295 |
+
h_ratio = (512.0 * 2.0 / 3.0) / 1000.0 #384 because we have black border on the top
|
296 |
+
w_ratio = 512.0 / 1500.0
|
297 |
+
else:
|
298 |
+
h_ratio = self.resolution[0] / 1000.0
|
299 |
+
w_ratio = self.resolution[1] / 1500.0
|
300 |
+
|
301 |
+
x = x * w_ratio
|
302 |
+
y = y * h_ratio
|
303 |
+
w = w * w_ratio
|
304 |
+
h = h * h_ratio
|
305 |
+
|
306 |
+
if self.force_square:
|
307 |
+
# y need to shift due to top black border
|
308 |
+
top_border_height = 512.0 * (1/6)
|
309 |
+
y = y + top_border_height
|
310 |
+
|
311 |
+
|
312 |
+
# Sphere is not circle due to the camera perspective, Need future fix for this
|
313 |
+
# For now, we use the minimum of width and height
|
314 |
+
w = int(np.round(w))
|
315 |
+
h = int(np.round(h))
|
316 |
+
if w > h:
|
317 |
+
r = h
|
318 |
+
x = x + (w - h) / 2.0
|
319 |
+
else:
|
320 |
+
r = w
|
321 |
+
y = y + (h - w) / 2.0
|
322 |
+
|
323 |
+
x = int(np.round(x))
|
324 |
+
y = int(np.round(y))
|
325 |
+
|
326 |
+
ball_data.append((x, y, r))
|
327 |
+
|
328 |
+
return ball_data
|
329 |
+
|
330 |
+
def calculate_bbox_info(self, scene_name):
|
331 |
+
# TODO: remove hard-coded parameters
|
332 |
+
bbox_data = []
|
333 |
+
for mtype in ['bbox_chrome', 'bbox_gray']:
|
334 |
+
info = self.scene_data[scene_name][mtype]
|
335 |
+
|
336 |
+
# x-y is top-left corner of the bounding box
|
337 |
+
# meta file is for 4000x6000 image but dataset is 1000x1500
|
338 |
+
x = info['x'] / 4
|
339 |
+
y = info['y'] / 4
|
340 |
+
w = info['w'] / 4
|
341 |
+
h = info['h'] / 4
|
342 |
+
|
343 |
+
|
344 |
+
# we scale data to 512x512 image
|
345 |
+
if self.force_square:
|
346 |
+
h_ratio = (512.0 * 2.0 / 3.0) / 1000.0 #384 because we have black border on the top
|
347 |
+
w_ratio = 512.0 / 1500.0
|
348 |
+
else:
|
349 |
+
h_ratio = self.resolution[0] / 1000.0
|
350 |
+
w_ratio = self.resolution[1] / 1500.0
|
351 |
+
|
352 |
+
x = x * w_ratio
|
353 |
+
y = y * h_ratio
|
354 |
+
w = w * w_ratio
|
355 |
+
h = h * h_ratio
|
356 |
+
|
357 |
+
if self.force_square:
|
358 |
+
# y need to shift due to top black border
|
359 |
+
top_border_height = 512.0 * (1/6)
|
360 |
+
y = y + top_border_height
|
361 |
+
|
362 |
+
|
363 |
+
w = int(np.round(w))
|
364 |
+
h = int(np.round(h))
|
365 |
+
x = int(np.round(x))
|
366 |
+
y = int(np.round(y))
|
367 |
+
|
368 |
+
bbox_data.append((x, y, w, h))
|
369 |
+
|
370 |
+
return bbox_data
|
371 |
+
|
372 |
+
"""
|
373 |
+
DO NOT remove this!
|
374 |
+
This is for evaluating results from Multi-Illumination generated from the old version
|
375 |
+
"""
|
376 |
+
def calculate_ball_info_legacy(self, scene_name):
|
377 |
+
# TODO: remove hard-coded parameters
|
378 |
+
ball_data = []
|
379 |
+
for mtype in ['bbox_chrome', 'bbox_gray']:
|
380 |
+
info = self.scene_data[scene_name][mtype]
|
381 |
+
|
382 |
+
# x-y is top-left corner of the bounding box
|
383 |
+
# meta file is for 4000x6000 image but dataset is 1000x1500
|
384 |
+
x = info['x'] / 4
|
385 |
+
y = info['y'] / 4
|
386 |
+
w = info['w'] / 4
|
387 |
+
h = info['h'] / 4
|
388 |
+
|
389 |
+
# we scale data to 512x512 image
|
390 |
+
h_ratio = 384.0 / 1000.0 #384 because we have black border on the top
|
391 |
+
w_ratio = 512.0 / 1500.0
|
392 |
+
x = x * w_ratio
|
393 |
+
y = y * h_ratio
|
394 |
+
w = w * w_ratio
|
395 |
+
h = h * h_ratio
|
396 |
+
|
397 |
+
# y need to shift due to top black border
|
398 |
+
top_border_height = 512.0 * (1/8)
|
399 |
+
|
400 |
+
y = y + top_border_height
|
401 |
+
|
402 |
+
# Sphere is not circle due to the camera perspective, Need future fix for this
|
403 |
+
# For now, we use the minimum of width and height
|
404 |
+
r = np.max(np.array([w, h]))
|
405 |
+
|
406 |
+
x = int(np.round(x))
|
407 |
+
y = int(np.round(y))
|
408 |
+
r = int(np.round(r))
|
409 |
+
|
410 |
+
ball_data.append((y, x, r))
|
411 |
+
|
412 |
+
return ball_data
|
relighting/dist_utils.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Helpers for distributed training.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import socket
|
8 |
+
|
9 |
+
try:
|
10 |
+
import blobfile as bf
|
11 |
+
except:
|
12 |
+
pass
|
13 |
+
|
14 |
+
try:
|
15 |
+
from mpi4py import MPI
|
16 |
+
except:
|
17 |
+
pass
|
18 |
+
|
19 |
+
import torch as th
|
20 |
+
import torch.distributed as dist
|
21 |
+
import builtins
|
22 |
+
import datetime
|
23 |
+
|
24 |
+
# Change this to reflect your cluster layout.
|
25 |
+
# The GPU for a given rank is (rank % GPUS_PER_NODE).
|
26 |
+
GPUS_PER_NODE = 8
|
27 |
+
|
28 |
+
SETUP_RETRY_COUNT = 3
|
29 |
+
def synchronize():
|
30 |
+
if not dist.is_available():
|
31 |
+
return
|
32 |
+
|
33 |
+
if not dist.is_initialized():
|
34 |
+
return
|
35 |
+
|
36 |
+
world_size = dist.get_world_size()
|
37 |
+
|
38 |
+
if world_size == 1:
|
39 |
+
return
|
40 |
+
|
41 |
+
dist.barrier()
|
42 |
+
|
43 |
+
def is_dist_avail_and_initialized():
|
44 |
+
if not dist.is_available():
|
45 |
+
return False
|
46 |
+
if not dist.is_initialized():
|
47 |
+
return False
|
48 |
+
return True
|
49 |
+
def get_world_size():
|
50 |
+
if not is_dist_avail_and_initialized():
|
51 |
+
return 1
|
52 |
+
return dist.get_world_size()
|
53 |
+
|
54 |
+
def setup_for_distributed(is_master):
|
55 |
+
"""
|
56 |
+
This function disables printing when not in master process
|
57 |
+
"""
|
58 |
+
builtin_print = builtins.print
|
59 |
+
|
60 |
+
def print(*args, **kwargs):
|
61 |
+
force = kwargs.pop('force', False)
|
62 |
+
force = force or (get_world_size() > 8)
|
63 |
+
if is_master or force:
|
64 |
+
now = datetime.datetime.now().time()
|
65 |
+
builtin_print('[{}] '.format(now), end='') # print with time stamp
|
66 |
+
builtin_print(*args, **kwargs)
|
67 |
+
|
68 |
+
builtins.print = print
|
69 |
+
|
70 |
+
def setup_dist_multinode(args):
|
71 |
+
"""
|
72 |
+
Setup a distributed process group.
|
73 |
+
"""
|
74 |
+
if not dist.is_available() or not dist.is_initialized():
|
75 |
+
th.distributed.init_process_group(backend="nccl", init_method='env://')
|
76 |
+
world_size = dist.get_world_size()
|
77 |
+
local_rank = int(os.getenv('LOCAL_RANK'))
|
78 |
+
print("rank",local_rank)
|
79 |
+
device = local_rank
|
80 |
+
th.cuda.set_device(device)
|
81 |
+
setup_for_distributed(device == 0)
|
82 |
+
|
83 |
+
synchronize()
|
84 |
+
else:
|
85 |
+
print("ddp failed!")
|
86 |
+
exit()
|
87 |
+
|
88 |
+
def setup_dist(global_seed):
|
89 |
+
"""
|
90 |
+
Setup a distributed process group.
|
91 |
+
"""
|
92 |
+
if dist.is_initialized():
|
93 |
+
return
|
94 |
+
th.cuda.set_device(int(os.environ["LOCAL_RANK"]))
|
95 |
+
th.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=5400))
|
96 |
+
|
97 |
+
# fix seed
|
98 |
+
rank = dist.get_rank()
|
99 |
+
device = rank % th.cuda.device_count()
|
100 |
+
seed = global_seed * dist.get_world_size() + rank
|
101 |
+
th.manual_seed(seed)
|
102 |
+
th.cuda.set_device(device)
|
103 |
+
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
|
104 |
+
synchronize()
|
105 |
+
|
106 |
+
def dev():
|
107 |
+
"""
|
108 |
+
Get the device to use for torch.distributed.
|
109 |
+
"""
|
110 |
+
if th.cuda.is_available():
|
111 |
+
return th.device(f"cuda")
|
112 |
+
return th.device("cpu")
|
113 |
+
|
114 |
+
|
115 |
+
def load_state_dict(path, **kwargs):
|
116 |
+
"""
|
117 |
+
Load a PyTorch file without redundant fetches across MPI ranks.
|
118 |
+
"""
|
119 |
+
chunk_size = 2 ** 30 # MPI has a relatively small size limit
|
120 |
+
if MPI.COMM_WORLD.Get_rank() == 0:
|
121 |
+
with bf.BlobFile(path, "rb") as f:
|
122 |
+
data = f.read()
|
123 |
+
num_chunks = len(data) // chunk_size
|
124 |
+
if len(data) % chunk_size:
|
125 |
+
num_chunks += 1
|
126 |
+
MPI.COMM_WORLD.bcast(num_chunks)
|
127 |
+
for i in range(0, len(data), chunk_size):
|
128 |
+
MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
|
129 |
+
else:
|
130 |
+
num_chunks = MPI.COMM_WORLD.bcast(None)
|
131 |
+
data = bytes()
|
132 |
+
for _ in range(num_chunks):
|
133 |
+
data += MPI.COMM_WORLD.bcast(None)
|
134 |
+
|
135 |
+
return th.load(io.BytesIO(data), **kwargs)
|
136 |
+
|
137 |
+
|
138 |
+
def sync_params(params):
|
139 |
+
"""
|
140 |
+
Synchronize a sequence of Tensors across ranks from rank 0.
|
141 |
+
"""
|
142 |
+
for p in params:
|
143 |
+
with th.no_grad():
|
144 |
+
dist.broadcast(p, 0)
|
145 |
+
|
146 |
+
|
147 |
+
def _find_free_port():
|
148 |
+
try:
|
149 |
+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
150 |
+
s.bind(("", 0))
|
151 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
152 |
+
return s.getsockname()[1]
|
153 |
+
finally:
|
154 |
+
s.close()
|
relighting/image_processor.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image, ImageChops
|
4 |
+
import skimage
|
5 |
+
try:
|
6 |
+
import cv2
|
7 |
+
except:
|
8 |
+
pass
|
9 |
+
|
10 |
+
def fill_image(image, mask_ball, x, y, size, color=(255,255,255)):
|
11 |
+
if isinstance(image, Image.Image):
|
12 |
+
result = np.array(image)
|
13 |
+
else:
|
14 |
+
result = image.copy()
|
15 |
+
|
16 |
+
result[y:y+size, x:x+size][mask_ball] = color
|
17 |
+
|
18 |
+
if isinstance(image, Image.Image):
|
19 |
+
result = Image.fromarray(result)
|
20 |
+
|
21 |
+
return result
|
22 |
+
|
23 |
+
def pil_square_image(image, desired_size = (512,512), interpolation=Image.LANCZOS):
|
24 |
+
"""
|
25 |
+
Make top-bottom border
|
26 |
+
"""
|
27 |
+
# Don't resize if already desired size (Avoid aliasing problem)
|
28 |
+
if image.size == desired_size:
|
29 |
+
return image
|
30 |
+
|
31 |
+
# Calculate the scale factor
|
32 |
+
scale_factor = min(desired_size[0] / image.width, desired_size[1] / image.height)
|
33 |
+
|
34 |
+
# Resize the image
|
35 |
+
resized_image = image.resize((int(image.width * scale_factor), int(image.height * scale_factor)), interpolation)
|
36 |
+
|
37 |
+
# Create a new blank image with the desired size and black border
|
38 |
+
new_image = Image.new("RGB", desired_size, color=(0, 0, 0))
|
39 |
+
|
40 |
+
# Paste the resized image onto the new image, centered
|
41 |
+
new_image.paste(resized_image, ((desired_size[0] - resized_image.width) // 2, (desired_size[1] - resized_image.height) // 2))
|
42 |
+
|
43 |
+
return new_image
|
44 |
+
|
45 |
+
# https://stackoverflow.com/questions/19271692/removing-borders-from-an-image-in-python
|
46 |
+
def remove_borders(image):
|
47 |
+
bg = Image.new(image.mode, image.size, image.getpixel((0,0)))
|
48 |
+
diff = ImageChops.difference(image, bg)
|
49 |
+
diff = ImageChops.add(diff, diff, 2.0, -100)
|
50 |
+
bbox = diff.getbbox()
|
51 |
+
if bbox:
|
52 |
+
return image.crop(bbox)
|
53 |
+
|
54 |
+
# Taken from https://huggingface.co/lllyasviel/sd-controlnet-normal
|
55 |
+
def estimate_scene_normal(image, depth_estimator):
|
56 |
+
# can be improve speed do not going back and float between numpy and torch
|
57 |
+
normal_image = depth_estimator(image)['predicted_depth'][0]
|
58 |
+
|
59 |
+
normal_image = normal_image.numpy()
|
60 |
+
|
61 |
+
# upsizing image depth to match input
|
62 |
+
hw = np.array(image).shape[:2]
|
63 |
+
normal_image = skimage.transform.resize(normal_image, hw, preserve_range=True)
|
64 |
+
|
65 |
+
image_depth = normal_image.copy()
|
66 |
+
image_depth -= np.min(image_depth)
|
67 |
+
image_depth /= np.max(image_depth)
|
68 |
+
|
69 |
+
bg_threhold = 0.4
|
70 |
+
|
71 |
+
x = cv2.Sobel(normal_image, cv2.CV_32F, 1, 0, ksize=3)
|
72 |
+
x[image_depth < bg_threhold] = 0
|
73 |
+
|
74 |
+
y = cv2.Sobel(normal_image, cv2.CV_32F, 0, 1, ksize=3)
|
75 |
+
y[image_depth < bg_threhold] = 0
|
76 |
+
|
77 |
+
z = np.ones_like(x) * np.pi * 2.0
|
78 |
+
|
79 |
+
normal_image = np.stack([x, y, z], axis=2)
|
80 |
+
normal_image /= np.sum(normal_image ** 2.0, axis=2, keepdims=True) ** 0.5
|
81 |
+
|
82 |
+
# rescale back to image size
|
83 |
+
return normal_image
|
84 |
+
|
85 |
+
def estimate_scene_depth(image, depth_estimator):
|
86 |
+
#image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
|
87 |
+
#with torch.no_grad(), torch.autocast("cuda"):
|
88 |
+
# depth_map = depth_estimator(image).predicted_depth
|
89 |
+
|
90 |
+
depth_map = depth_estimator(image)['predicted_depth']
|
91 |
+
W, H = image.size
|
92 |
+
depth_map = torch.nn.functional.interpolate(
|
93 |
+
depth_map.unsqueeze(1),
|
94 |
+
size=(H, W),
|
95 |
+
mode="bicubic",
|
96 |
+
align_corners=False,
|
97 |
+
)
|
98 |
+
depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
|
99 |
+
depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
|
100 |
+
depth_map = (depth_map - depth_min) / (depth_max - depth_min)
|
101 |
+
image = torch.cat([depth_map] * 3, dim=1)
|
102 |
+
|
103 |
+
image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
|
104 |
+
image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
|
105 |
+
return image
|
106 |
+
|
107 |
+
def fill_depth_circular(depth_image, x, y, r):
|
108 |
+
depth_image = np.array(depth_image)
|
109 |
+
|
110 |
+
for i in range(depth_image.shape[0]):
|
111 |
+
for j in range(depth_image.shape[1]):
|
112 |
+
xy = (i - x - r//2)**2 + (j - y - r//2)**2
|
113 |
+
# if xy <= rr**2:
|
114 |
+
# depth_image[j, i, :] = 255
|
115 |
+
# depth_image[j, i, :] = int(minv + (maxv - minv) * z)
|
116 |
+
if xy <= (r // 2)**2:
|
117 |
+
depth_image[j, i, :] = 255
|
118 |
+
|
119 |
+
depth_image = Image.fromarray(depth_image)
|
120 |
+
return depth_image
|
121 |
+
|
122 |
+
|
123 |
+
def merge_normal_map(normal_map, normal_ball, mask_ball, x, y):
|
124 |
+
"""
|
125 |
+
Merge a ball to normal map using mask
|
126 |
+
@params
|
127 |
+
normal_amp (np.array) - normal map of the scene [height, width, 3]
|
128 |
+
normal_ball (np.array) - normal map of the ball [ball_height, ball_width, 3]
|
129 |
+
mask_ball (np.array) - mask of the ball [ball_height, ball_width]
|
130 |
+
x (int) - x position of the ball (top-left)
|
131 |
+
y (int) - y position of the ball (top-left)
|
132 |
+
@return
|
133 |
+
normal_mapthe merge normal map [height, width, 3]
|
134 |
+
"""
|
135 |
+
result = normal_map.copy()
|
136 |
+
|
137 |
+
mask_ball = mask_ball[..., None]
|
138 |
+
ball = (normal_ball * mask_ball) # alpha blending the ball
|
139 |
+
unball = (normal_map[y:y+normal_ball.shape[0], x:x+normal_ball.shape[1]] * (1 - mask_ball)) # alpha blending the normal map
|
140 |
+
result[y:y+normal_ball.shape[0], x:x+normal_ball.shape[1]] = ball+unball # add them together
|
141 |
+
return result
|
relighting/inpainter.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers import ControlNetModel, AutoencoderKL
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
from transformers import pipeline as transformers_pipeline
|
8 |
+
|
9 |
+
from relighting.pipeline import CustomStableDiffusionControlNetInpaintPipeline
|
10 |
+
from relighting.pipeline_inpaintonly import CustomStableDiffusionInpaintPipeline, CustomStableDiffusionXLInpaintPipeline
|
11 |
+
from relighting.argument import SAMPLERS, VAE_MODELS, DEPTH_ESTIMATOR, get_control_signal_type
|
12 |
+
from relighting.image_processor import (
|
13 |
+
estimate_scene_depth,
|
14 |
+
estimate_scene_normal,
|
15 |
+
merge_normal_map,
|
16 |
+
fill_depth_circular
|
17 |
+
)
|
18 |
+
from relighting.ball_processor import get_ideal_normal_ball, crop_ball
|
19 |
+
import pickle
|
20 |
+
|
21 |
+
from relighting.pipeline_xl import CustomStableDiffusionXLControlNetInpaintPipeline
|
22 |
+
|
23 |
+
class NoWaterMark:
|
24 |
+
def apply_watermark(self, *args, **kwargs):
|
25 |
+
return args[0]
|
26 |
+
|
27 |
+
class ControlSignalGenerator():
|
28 |
+
def __init__(self, sd_arch, control_signal_type, device):
|
29 |
+
self.sd_arch = sd_arch
|
30 |
+
self.control_signal_type = control_signal_type
|
31 |
+
self.device = device
|
32 |
+
|
33 |
+
def process_sd_depth(self, input_image, normal_ball=None, mask_ball=None, x=None, y=None, r=None):
|
34 |
+
if getattr(self, 'depth_estimator', None) is None:
|
35 |
+
self.depth_estimator = transformers_pipeline("depth-estimation", device=self.device.index)
|
36 |
+
|
37 |
+
control_image = self.depth_estimator(input_image)['depth']
|
38 |
+
control_image = np.array(control_image)
|
39 |
+
control_image = control_image[:, :, None]
|
40 |
+
control_image = np.concatenate([control_image, control_image, control_image], axis=2)
|
41 |
+
control_image = Image.fromarray(control_image)
|
42 |
+
|
43 |
+
control_image = fill_depth_circular(control_image, x, y, r)
|
44 |
+
return control_image
|
45 |
+
|
46 |
+
def process_sdxl_depth(self, input_image, normal_ball=None, mask_ball=None, x=None, y=None, r=None):
|
47 |
+
if getattr(self, 'depth_estimator', None) is None:
|
48 |
+
self.depth_estimator = transformers_pipeline("depth-estimation", model=DEPTH_ESTIMATOR, device=self.device.index)
|
49 |
+
|
50 |
+
control_image = estimate_scene_depth(input_image, depth_estimator=self.depth_estimator)
|
51 |
+
xs = [x] if not isinstance(x, list) else x
|
52 |
+
ys = [y] if not isinstance(y, list) else y
|
53 |
+
rs = [r] if not isinstance(r, list) else r
|
54 |
+
|
55 |
+
for x, y, r in zip(xs, ys, rs):
|
56 |
+
#print(f"depth at {x}, {y}, {r}")
|
57 |
+
control_image = fill_depth_circular(control_image, x, y, r)
|
58 |
+
return control_image
|
59 |
+
|
60 |
+
def process_sd_normal(self, input_image, normal_ball, mask_ball, x, y, r=None, normal_ball_path=None):
|
61 |
+
if getattr(self, 'depth_estimator', None) is None:
|
62 |
+
self.depth_estimator = transformers_pipeline("depth-estimation", model=DEPTH_ESTIMATOR, device=self.device.index)
|
63 |
+
|
64 |
+
normal_scene = estimate_scene_normal(input_image, depth_estimator=self.depth_estimator)
|
65 |
+
normal_image = merge_normal_map(normal_scene, normal_ball, mask_ball, x, y)
|
66 |
+
normal_image = (normal_image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
|
67 |
+
control_image = Image.fromarray(normal_image)
|
68 |
+
return control_image
|
69 |
+
|
70 |
+
def __call__(self, *args, **kwargs):
|
71 |
+
process_fn = getattr(self, f"process_{self.sd_arch}_{self.control_signal_type}", None)
|
72 |
+
if process_fn is None:
|
73 |
+
raise ValueError
|
74 |
+
else:
|
75 |
+
return process_fn(*args, **kwargs)
|
76 |
+
|
77 |
+
|
78 |
+
class BallInpainter():
|
79 |
+
def __init__(self, pipeline, sd_arch, control_generator, disable_water_mask=True):
|
80 |
+
self.pipeline = pipeline
|
81 |
+
self.sd_arch = sd_arch
|
82 |
+
self.control_generator = control_generator
|
83 |
+
self.median = {}
|
84 |
+
if disable_water_mask:
|
85 |
+
self._disable_water_mask()
|
86 |
+
|
87 |
+
def _disable_water_mask(self):
|
88 |
+
if hasattr(self.pipeline, "watermark"):
|
89 |
+
self.pipeline.watermark = NoWaterMark()
|
90 |
+
print("Disabled watermasking")
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def from_sd(cls,
|
94 |
+
model,
|
95 |
+
controlnet=None,
|
96 |
+
device=0,
|
97 |
+
sampler="unipc",
|
98 |
+
torch_dtype=torch.float16,
|
99 |
+
disable_water_mask=True,
|
100 |
+
offload=False
|
101 |
+
):
|
102 |
+
if controlnet is not None:
|
103 |
+
control_signal_type = get_control_signal_type(controlnet)
|
104 |
+
controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch.float16)
|
105 |
+
pipe = CustomStableDiffusionControlNetInpaintPipeline.from_pretrained(
|
106 |
+
model,
|
107 |
+
controlnet=controlnet,
|
108 |
+
torch_dtype=torch_dtype,
|
109 |
+
).to(device)
|
110 |
+
control_generator = ControlSignalGenerator("sd", control_signal_type, device=device)
|
111 |
+
else:
|
112 |
+
pipe = CustomStableDiffusionInpaintPipeline.from_pretrained(
|
113 |
+
model,
|
114 |
+
torch_dtype=torch_dtype,
|
115 |
+
).to(device)
|
116 |
+
control_generator = None
|
117 |
+
|
118 |
+
try:
|
119 |
+
if torch_dtype==torch.float16 and device != torch.device("cpu"):
|
120 |
+
pipe.enable_xformers_memory_efficient_attention()
|
121 |
+
except:
|
122 |
+
pass
|
123 |
+
pipe.set_progress_bar_config(disable=True)
|
124 |
+
|
125 |
+
pipe.scheduler = SAMPLERS[sampler].from_config(pipe.scheduler.config)
|
126 |
+
|
127 |
+
return BallInpainter(pipe, "sd", control_generator, disable_water_mask)
|
128 |
+
|
129 |
+
@classmethod
|
130 |
+
def from_sdxl(cls,
|
131 |
+
model,
|
132 |
+
controlnet=None,
|
133 |
+
device=0,
|
134 |
+
sampler="unipc",
|
135 |
+
torch_dtype=torch.float16,
|
136 |
+
disable_water_mask=True,
|
137 |
+
use_fixed_vae=True,
|
138 |
+
offload=False
|
139 |
+
):
|
140 |
+
vae = VAE_MODELS["sdxl"]
|
141 |
+
vae = AutoencoderKL.from_pretrained(vae, torch_dtype=torch_dtype).to(device) if use_fixed_vae else None
|
142 |
+
extra_kwargs = {"vae": vae} if vae is not None else {}
|
143 |
+
|
144 |
+
if controlnet is not None:
|
145 |
+
control_signal_type = get_control_signal_type(controlnet)
|
146 |
+
controlnet = ControlNetModel.from_pretrained(
|
147 |
+
controlnet,
|
148 |
+
variant="fp16" if torch_dtype == torch.float16 else None,
|
149 |
+
use_safetensors=True,
|
150 |
+
torch_dtype=torch_dtype,
|
151 |
+
).to(device)
|
152 |
+
pipe = CustomStableDiffusionXLControlNetInpaintPipeline.from_pretrained(
|
153 |
+
model,
|
154 |
+
controlnet=controlnet,
|
155 |
+
variant="fp16" if torch_dtype == torch.float16 else None,
|
156 |
+
use_safetensors=True,
|
157 |
+
torch_dtype=torch_dtype,
|
158 |
+
**extra_kwargs,
|
159 |
+
).to(device)
|
160 |
+
control_generator = ControlSignalGenerator("sdxl", control_signal_type, device=device)
|
161 |
+
else:
|
162 |
+
pipe = CustomStableDiffusionXLInpaintPipeline.from_pretrained(
|
163 |
+
model,
|
164 |
+
variant="fp16" if torch_dtype == torch.float16 else None,
|
165 |
+
use_safetensors=True,
|
166 |
+
torch_dtype=torch_dtype,
|
167 |
+
**extra_kwargs,
|
168 |
+
).to(device)
|
169 |
+
control_generator = None
|
170 |
+
|
171 |
+
try:
|
172 |
+
if torch_dtype==torch.float16 and device != torch.device("cpu"):
|
173 |
+
pipe.enable_xformers_memory_efficient_attention()
|
174 |
+
except:
|
175 |
+
pass
|
176 |
+
|
177 |
+
if offload and device != torch.device("cpu"):
|
178 |
+
pipe.enable_model_cpu_offload()
|
179 |
+
pipe.set_progress_bar_config(disable=True)
|
180 |
+
pipe.scheduler = SAMPLERS[sampler].from_config(pipe.scheduler.config)
|
181 |
+
|
182 |
+
return BallInpainter(pipe, "sdxl", control_generator, disable_water_mask)
|
183 |
+
|
184 |
+
# TODO: this method should be replaced by inpaint(), but we'll leave it here for now
|
185 |
+
# otherwise, the existing experiment code will break down
|
186 |
+
def __call__(self, *args, **kwargs):
|
187 |
+
return self.pipeline(*args, **kwargs)
|
188 |
+
|
189 |
+
def _default_height_width(self, height=None, width=None):
|
190 |
+
if (height is not None) and (width is not None):
|
191 |
+
return height, width
|
192 |
+
if self.sd_arch == "sd":
|
193 |
+
return (512, 512)
|
194 |
+
elif self.sd_arch == "sdxl":
|
195 |
+
return (1024, 1024)
|
196 |
+
else:
|
197 |
+
raise NotImplementedError
|
198 |
+
|
199 |
+
# this method is for sanity check only
|
200 |
+
def get_cache_control_image(self):
|
201 |
+
control_image = getattr(self, "cache_control_image", None)
|
202 |
+
return control_image
|
203 |
+
|
204 |
+
def prepare_control_signal(self, image, controlnet_conditioning_scale, extra_kwargs):
|
205 |
+
if self.control_generator is not None:
|
206 |
+
control_image = self.control_generator(image, **extra_kwargs)
|
207 |
+
controlnet_kwargs = {
|
208 |
+
"control_image": control_image,
|
209 |
+
"controlnet_conditioning_scale": controlnet_conditioning_scale
|
210 |
+
}
|
211 |
+
self.cache_control_image = control_image
|
212 |
+
else:
|
213 |
+
controlnet_kwargs = {}
|
214 |
+
|
215 |
+
return controlnet_kwargs
|
216 |
+
|
217 |
+
def get_cache_median(self, it):
|
218 |
+
if it in self.median: return self.median[it]
|
219 |
+
else: return None
|
220 |
+
|
221 |
+
def reset_median(self):
|
222 |
+
self.median = {}
|
223 |
+
print("Reset median")
|
224 |
+
|
225 |
+
def load_median(self, path):
|
226 |
+
if os.path.exists(path):
|
227 |
+
with open(path, "rb") as f:
|
228 |
+
self.median = pickle.load(f)
|
229 |
+
print(f"Loaded median from {path}")
|
230 |
+
else:
|
231 |
+
print(f"Median not found at {path}!")
|
232 |
+
|
233 |
+
def inpaint_iterative(
|
234 |
+
self,
|
235 |
+
prompt=None,
|
236 |
+
negative_prompt="",
|
237 |
+
num_inference_steps=30,
|
238 |
+
generator=None, # TODO: remove this
|
239 |
+
image=None,
|
240 |
+
mask_image=None,
|
241 |
+
height=None,
|
242 |
+
width=None,
|
243 |
+
controlnet_conditioning_scale=0.5,
|
244 |
+
num_images_per_prompt=1,
|
245 |
+
current_seed=0,
|
246 |
+
cross_attention_kwargs={},
|
247 |
+
strength=0.8,
|
248 |
+
num_iteration=2,
|
249 |
+
ball_per_iteration=30,
|
250 |
+
agg_mode="median",
|
251 |
+
save_intermediate=True,
|
252 |
+
cache_dir="./temp_inpaint_iterative",
|
253 |
+
disable_progress=False,
|
254 |
+
prompt_embeds=None,
|
255 |
+
pooled_prompt_embeds=None,
|
256 |
+
use_cache_median=False,
|
257 |
+
guidance_scale=5.0, # In the paper, we use guidance scale to 5.0 (same as pipeline_xl.py)
|
258 |
+
**extra_kwargs,
|
259 |
+
):
|
260 |
+
def computeMedian(ball_images):
|
261 |
+
all = np.stack(ball_images, axis=0)
|
262 |
+
median = np.median(all, axis=0)
|
263 |
+
idx_median = np.argsort(all, axis=0)[all.shape[0]//2]
|
264 |
+
# print(all.shape)
|
265 |
+
# print(idx_median.shape)
|
266 |
+
return median, idx_median
|
267 |
+
|
268 |
+
def generate_balls(avg_image, current_strength, ball_per_iteration, current_iteration):
|
269 |
+
print(f"Inpainting balls for {current_iteration} iteration...")
|
270 |
+
controlnet_kwargs = self.prepare_control_signal(
|
271 |
+
image=avg_image,
|
272 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
273 |
+
extra_kwargs=extra_kwargs,
|
274 |
+
)
|
275 |
+
|
276 |
+
ball_images = []
|
277 |
+
for i in tqdm(range(ball_per_iteration), disable=disable_progress):
|
278 |
+
seed = current_seed + i
|
279 |
+
new_generator = torch.Generator().manual_seed(seed)
|
280 |
+
output_image = self.pipeline(
|
281 |
+
prompt=prompt,
|
282 |
+
negative_prompt=negative_prompt,
|
283 |
+
num_inference_steps=num_inference_steps,
|
284 |
+
generator=new_generator,
|
285 |
+
image=avg_image,
|
286 |
+
mask_image=mask_image,
|
287 |
+
height=height,
|
288 |
+
width=width,
|
289 |
+
num_images_per_prompt=num_images_per_prompt,
|
290 |
+
strength=current_strength,
|
291 |
+
newx=x,
|
292 |
+
newy=y,
|
293 |
+
newr=r,
|
294 |
+
current_seed=seed,
|
295 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
296 |
+
prompt_embeds=prompt_embeds,
|
297 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
298 |
+
guidance_scale=guidance_scale,
|
299 |
+
**controlnet_kwargs
|
300 |
+
).images[0]
|
301 |
+
|
302 |
+
ball_image = crop_ball(output_image, mask_ball_for_crop, x, y, r)
|
303 |
+
ball_images.append(ball_image)
|
304 |
+
|
305 |
+
if save_intermediate:
|
306 |
+
os.makedirs(os.path.join(cache_dir, str(current_iteration)), mode=0o777, exist_ok=True)
|
307 |
+
output_image.save(os.path.join(cache_dir, str(current_iteration), f"raw_{i}.png"))
|
308 |
+
Image.fromarray(ball_image).save(os.path.join(cache_dir, str(current_iteration), f"ball_{i}.png"))
|
309 |
+
# chmod 777
|
310 |
+
os.chmod(os.path.join(cache_dir, str(current_iteration), f"raw_{i}.png"), 0o0777)
|
311 |
+
os.chmod(os.path.join(cache_dir, str(current_iteration), f"ball_{i}.png"), 0o0777)
|
312 |
+
|
313 |
+
|
314 |
+
return ball_images
|
315 |
+
|
316 |
+
if save_intermediate:
|
317 |
+
os.makedirs(cache_dir, exist_ok=True)
|
318 |
+
|
319 |
+
height, width = self._default_height_width(height, width)
|
320 |
+
|
321 |
+
x = extra_kwargs["x"]
|
322 |
+
y = extra_kwargs["y"]
|
323 |
+
r = 256 if "r" not in extra_kwargs else extra_kwargs["r"]
|
324 |
+
_, mask_ball_for_crop = get_ideal_normal_ball(size=r)
|
325 |
+
|
326 |
+
# generate initial average ball
|
327 |
+
avg_image = image
|
328 |
+
ball_images = generate_balls(
|
329 |
+
avg_image,
|
330 |
+
current_strength=1.0,
|
331 |
+
ball_per_iteration=ball_per_iteration,
|
332 |
+
current_iteration=0,
|
333 |
+
)
|
334 |
+
|
335 |
+
# ball refinement loop
|
336 |
+
image = np.array(image)
|
337 |
+
for it in range(1, num_iteration+1):
|
338 |
+
if use_cache_median and (self.get_cache_median(it) is not None):
|
339 |
+
print("Use existing median")
|
340 |
+
all = np.stack(ball_images, axis=0)
|
341 |
+
idx_median = self.get_cache_median(it)
|
342 |
+
avg_ball = all[idx_median,
|
343 |
+
np.arange(idx_median.shape[0])[:, np.newaxis, np.newaxis],
|
344 |
+
np.arange(idx_median.shape[1])[np.newaxis, :, np.newaxis],
|
345 |
+
np.arange(idx_median.shape[2])[np.newaxis, np.newaxis, :]
|
346 |
+
]
|
347 |
+
else:
|
348 |
+
avg_ball, idx_median = computeMedian(ball_images)
|
349 |
+
print("Add new median")
|
350 |
+
self.median[it] = idx_median
|
351 |
+
|
352 |
+
avg_image = merge_normal_map(image, avg_ball, mask_ball_for_crop, x, y)
|
353 |
+
avg_image = Image.fromarray(avg_image.astype(np.uint8))
|
354 |
+
if save_intermediate:
|
355 |
+
avg_image.save(os.path.join(cache_dir, f"average_{it}.png"))
|
356 |
+
# chmod777
|
357 |
+
os.chmod(os.path.join(cache_dir, f"average_{it}.png"), 0o0777)
|
358 |
+
|
359 |
+
ball_images = generate_balls(
|
360 |
+
avg_image,
|
361 |
+
current_strength=strength,
|
362 |
+
ball_per_iteration=ball_per_iteration if it < num_iteration else 1,
|
363 |
+
current_iteration=it,
|
364 |
+
)
|
365 |
+
|
366 |
+
# TODO: add algorithm for select the best ball
|
367 |
+
best_ball = ball_images[0]
|
368 |
+
output_image = merge_normal_map(image, best_ball, mask_ball_for_crop, x, y)
|
369 |
+
return Image.fromarray(output_image.astype(np.uint8))
|
370 |
+
|
371 |
+
def inpaint(
|
372 |
+
self,
|
373 |
+
prompt=None,
|
374 |
+
negative_prompt=None,
|
375 |
+
num_inference_steps=30,
|
376 |
+
generator=None,
|
377 |
+
image=None,
|
378 |
+
mask_image=None,
|
379 |
+
height=None,
|
380 |
+
width=None,
|
381 |
+
controlnet_conditioning_scale=0.5,
|
382 |
+
num_images_per_prompt=1,
|
383 |
+
strength=1.0,
|
384 |
+
current_seed=0,
|
385 |
+
cross_attention_kwargs={},
|
386 |
+
prompt_embeds=None,
|
387 |
+
pooled_prompt_embeds=None,
|
388 |
+
guidance_scale=5.0, # (same as pipeline_xl.py)
|
389 |
+
**extra_kwargs,
|
390 |
+
):
|
391 |
+
height, width = self._default_height_width(height, width)
|
392 |
+
|
393 |
+
controlnet_kwargs = self.prepare_control_signal(
|
394 |
+
image=image,
|
395 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
396 |
+
extra_kwargs=extra_kwargs,
|
397 |
+
)
|
398 |
+
|
399 |
+
if generator is None:
|
400 |
+
generator = torch.Generator().manual_seed(0)
|
401 |
+
|
402 |
+
output_image = self.pipeline(
|
403 |
+
prompt=prompt,
|
404 |
+
negative_prompt=negative_prompt,
|
405 |
+
num_inference_steps=num_inference_steps,
|
406 |
+
generator=generator,
|
407 |
+
image=image,
|
408 |
+
mask_image=mask_image,
|
409 |
+
height=height,
|
410 |
+
width=width,
|
411 |
+
num_images_per_prompt=num_images_per_prompt,
|
412 |
+
strength=strength,
|
413 |
+
newx = extra_kwargs["x"],
|
414 |
+
newy = extra_kwargs["y"],
|
415 |
+
newr = getattr(extra_kwargs, "r", 256), # default to ball_size = 256
|
416 |
+
current_seed=current_seed,
|
417 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
418 |
+
prompt_embeds=prompt_embeds,
|
419 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
420 |
+
guidance_scale=guidance_scale,
|
421 |
+
**controlnet_kwargs
|
422 |
+
)
|
423 |
+
|
424 |
+
return output_image
|
relighting/mask_utils.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
import cv2
|
3 |
+
except:
|
4 |
+
pass
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
from relighting.ball_processor import get_ideal_normal_ball
|
8 |
+
|
9 |
+
def create_grid(image_size, n_ball, size):
|
10 |
+
height, width = image_size
|
11 |
+
nx, ny = n_ball
|
12 |
+
if nx * ny == 1:
|
13 |
+
grid = np.array([[(height-size)//2, (width-size)//2]])
|
14 |
+
else:
|
15 |
+
height_ = np.linspace(0, height-size, nx).astype(int)
|
16 |
+
weight_ = np.linspace(0, width-size, ny).astype(int)
|
17 |
+
hh, ww = np.meshgrid(height_, weight_)
|
18 |
+
grid = np.stack([hh,ww], axis = -1).reshape(-1,2)
|
19 |
+
|
20 |
+
return grid
|
21 |
+
|
22 |
+
class MaskGenerator():
|
23 |
+
def __init__(self, cache_mask=True):
|
24 |
+
self.cache_mask = cache_mask
|
25 |
+
self.all_masks = []
|
26 |
+
|
27 |
+
def clear_cache(self):
|
28 |
+
self.all_masks = []
|
29 |
+
|
30 |
+
def retrieve_masks(self):
|
31 |
+
return self.all_masks
|
32 |
+
|
33 |
+
def generate_grid(self, image, mask_ball, n_ball=16, size=128):
|
34 |
+
ball_positions = create_grid(image.size, n_ball, size)
|
35 |
+
# _, mask_ball = get_normal_ball(size)
|
36 |
+
|
37 |
+
masks = []
|
38 |
+
mask_template = np.zeros(image.size)
|
39 |
+
for x, y in ball_positions:
|
40 |
+
mask = mask_template.copy()
|
41 |
+
mask[y:y+size, x:x+size] = 255 * mask_ball
|
42 |
+
mask = Image.fromarray(mask.astype(np.uint8), "L")
|
43 |
+
masks.append(mask)
|
44 |
+
|
45 |
+
# if self.cache_mask:
|
46 |
+
# self.all_masks.append((x, y, size))
|
47 |
+
|
48 |
+
return masks, ball_positions
|
49 |
+
|
50 |
+
def generate_single(self, image, mask_ball, x, y, size):
|
51 |
+
w,h = image.size # numpy as (h,w) but PIL is (w,h)
|
52 |
+
mask = np.zeros((h,w))
|
53 |
+
mask[y:y+size, x:x+size] = 255 * mask_ball
|
54 |
+
mask = Image.fromarray(mask.astype(np.uint8), "L")
|
55 |
+
|
56 |
+
return mask
|
57 |
+
|
58 |
+
def generate_best(self, image, mask_ball, size):
|
59 |
+
w,h = image.size # numpy as (h,w) but PIL is (w,h)
|
60 |
+
mask = np.zeros((h,w))
|
61 |
+
|
62 |
+
(y, x), _ = find_best_location(np.array(image), ball_size=size)
|
63 |
+
mask[y:y+size, x:x+size] = 255 * mask_ball
|
64 |
+
mask = Image.fromarray(mask.astype(np.uint8), "L")
|
65 |
+
|
66 |
+
return mask, (x, y)
|
67 |
+
|
68 |
+
|
69 |
+
def get_only_high_freqency(image: np.array):
|
70 |
+
"""
|
71 |
+
Get only height freqency image by subtract low freqency (using gaussian blur)
|
72 |
+
@params image: np.array - image in RGB format [h,w,3]
|
73 |
+
@return high_frequency: np.array - high freqnecy image in grayscale format [h,w]
|
74 |
+
"""
|
75 |
+
|
76 |
+
# Convert to grayscale
|
77 |
+
gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
|
78 |
+
|
79 |
+
# Subtract low freqency from high freqency
|
80 |
+
kernel_size = 11 # Adjust this according to your image size
|
81 |
+
high_frequency = gray - cv2.GaussianBlur(gray,(kernel_size, kernel_size), 0)
|
82 |
+
|
83 |
+
return high_frequency
|
84 |
+
|
85 |
+
def find_best_location(image, ball_size=128):
|
86 |
+
"""
|
87 |
+
Find the best location to place the ball (Eg. empty location)
|
88 |
+
@params image: np.array - image in RGB format [h,w,3]
|
89 |
+
@return min_pos: tuple - top left position of the best location (the location is in "Y,X" format)
|
90 |
+
@return min_val: float - the sum value contain in the window
|
91 |
+
"""
|
92 |
+
local_variance = get_only_high_freqency(image)
|
93 |
+
qsum = quicksum2d(local_variance)
|
94 |
+
|
95 |
+
min_val = None
|
96 |
+
min_pos = None
|
97 |
+
k = ball_size
|
98 |
+
for i in range(k-1, qsum.shape[0]):
|
99 |
+
for j in range(k-1, qsum.shape[1]):
|
100 |
+
A = 0 if i-k < 0 else qsum[i-k, j]
|
101 |
+
B = 0 if j-k < 0 else qsum[i, j-k]
|
102 |
+
C = 0 if (i-k < 0) or (j-k < 0) else qsum[i-k, j-k]
|
103 |
+
sum = qsum[i, j] - A - B + C
|
104 |
+
if (min_val is None) or (sum < min_val):
|
105 |
+
min_val = sum
|
106 |
+
min_pos = (i-k+1, j-k+1) # get top left position
|
107 |
+
|
108 |
+
return min_pos, min_val
|
109 |
+
|
110 |
+
def quicksum2d(x: np.array):
|
111 |
+
"""
|
112 |
+
Quick sum algorithm to find the window that have smallest sum with O(n^2) complexity
|
113 |
+
@params x: np.array - image in grayscale [h,w]
|
114 |
+
@return q: np.array - quick sum of the image for future seach in find_best_location [h,w]
|
115 |
+
"""
|
116 |
+
qsum = np.zeros(x.shape)
|
117 |
+
for i in range(x.shape[0]):
|
118 |
+
for j in range(x.shape[1]):
|
119 |
+
A = 0 if i-1 < 0 else qsum[i-1, j]
|
120 |
+
B = 0 if j-1 < 0 else qsum[i, j-1]
|
121 |
+
C = 0 if (i-1 < 0) or (j-1 < 0) else qsum[i-1, j-1]
|
122 |
+
qsum[i, j] = A + B - C + x[i, j]
|
123 |
+
|
124 |
+
return qsum
|
relighting/pipeline.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import List, Union, Dict, Any, Callable, Optional, Tuple
|
3 |
+
|
4 |
+
from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
|
5 |
+
from diffusers.models import ControlNetModel
|
6 |
+
from diffusers.pipelines.controlnet import MultiControlNetModel
|
7 |
+
from diffusers import StableDiffusionControlNetInpaintPipeline
|
8 |
+
from diffusers.image_processor import PipelineImageInput
|
9 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
10 |
+
from relighting.pipeline_utils import custom_prepare_latents, custom_prepare_mask_latents
|
11 |
+
|
12 |
+
class CustomStableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetInpaintPipeline):
|
13 |
+
@torch.no_grad()
|
14 |
+
def __call__(
|
15 |
+
self,
|
16 |
+
prompt: Union[str, List[str]] = None,
|
17 |
+
image: PipelineImageInput = None,
|
18 |
+
mask_image: PipelineImageInput = None,
|
19 |
+
control_image: PipelineImageInput = None,
|
20 |
+
height: Optional[int] = None,
|
21 |
+
width: Optional[int] = None,
|
22 |
+
strength: float = 1.0,
|
23 |
+
num_inference_steps: int = 50,
|
24 |
+
guidance_scale: float = 7.5,
|
25 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
26 |
+
num_images_per_prompt: Optional[int] = 1,
|
27 |
+
eta: float = 0.0,
|
28 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
29 |
+
latents: Optional[torch.FloatTensor] = None,
|
30 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
31 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
32 |
+
output_type: Optional[str] = "pil",
|
33 |
+
return_dict: bool = True,
|
34 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
35 |
+
callback_steps: int = 1,
|
36 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
37 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
|
38 |
+
guess_mode: bool = False,
|
39 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
40 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
41 |
+
newx: int = 0,
|
42 |
+
newy: int = 0,
|
43 |
+
newr: int = 256,
|
44 |
+
current_seed=0,
|
45 |
+
use_noise_moving=True,
|
46 |
+
):
|
47 |
+
# OVERWRITE METHODS
|
48 |
+
self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionControlNetInpaintPipeline)
|
49 |
+
self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionControlNetInpaintPipeline)
|
50 |
+
|
51 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
52 |
+
|
53 |
+
# align format for control guidance
|
54 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
55 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
56 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
57 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
58 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
59 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
60 |
+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
|
61 |
+
control_guidance_end
|
62 |
+
]
|
63 |
+
|
64 |
+
# 1. Check inputs. Raise error if not correct
|
65 |
+
self.check_inputs(
|
66 |
+
prompt,
|
67 |
+
control_image,
|
68 |
+
height,
|
69 |
+
width,
|
70 |
+
callback_steps,
|
71 |
+
negative_prompt,
|
72 |
+
prompt_embeds,
|
73 |
+
negative_prompt_embeds,
|
74 |
+
controlnet_conditioning_scale,
|
75 |
+
control_guidance_start,
|
76 |
+
control_guidance_end,
|
77 |
+
)
|
78 |
+
|
79 |
+
# 2. Define call parameters
|
80 |
+
if prompt is not None and isinstance(prompt, str):
|
81 |
+
batch_size = 1
|
82 |
+
elif prompt is not None and isinstance(prompt, list):
|
83 |
+
batch_size = len(prompt)
|
84 |
+
else:
|
85 |
+
batch_size = prompt_embeds.shape[0]
|
86 |
+
|
87 |
+
device = self._execution_device
|
88 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
89 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
90 |
+
# corresponds to doing no classifier free guidance.
|
91 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
92 |
+
|
93 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
94 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
95 |
+
|
96 |
+
global_pool_conditions = (
|
97 |
+
controlnet.config.global_pool_conditions
|
98 |
+
if isinstance(controlnet, ControlNetModel)
|
99 |
+
else controlnet.nets[0].config.global_pool_conditions
|
100 |
+
)
|
101 |
+
guess_mode = guess_mode or global_pool_conditions
|
102 |
+
|
103 |
+
# 3. Encode input prompt
|
104 |
+
text_encoder_lora_scale = (
|
105 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
106 |
+
)
|
107 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
108 |
+
prompt,
|
109 |
+
device,
|
110 |
+
num_images_per_prompt,
|
111 |
+
do_classifier_free_guidance,
|
112 |
+
negative_prompt,
|
113 |
+
prompt_embeds=prompt_embeds,
|
114 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
115 |
+
lora_scale=text_encoder_lora_scale,
|
116 |
+
)
|
117 |
+
# For classifier free guidance, we need to do two forward passes.
|
118 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
119 |
+
# to avoid doing two forward passes
|
120 |
+
if do_classifier_free_guidance:
|
121 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
122 |
+
|
123 |
+
# 4. Prepare image
|
124 |
+
if isinstance(controlnet, ControlNetModel):
|
125 |
+
control_image = self.prepare_control_image(
|
126 |
+
image=control_image,
|
127 |
+
width=width,
|
128 |
+
height=height,
|
129 |
+
batch_size=batch_size * num_images_per_prompt,
|
130 |
+
num_images_per_prompt=num_images_per_prompt,
|
131 |
+
device=device,
|
132 |
+
dtype=controlnet.dtype,
|
133 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
134 |
+
guess_mode=guess_mode,
|
135 |
+
)
|
136 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
137 |
+
control_images = []
|
138 |
+
|
139 |
+
for control_image_ in control_image:
|
140 |
+
control_image_ = self.prepare_control_image(
|
141 |
+
image=control_image_,
|
142 |
+
width=width,
|
143 |
+
height=height,
|
144 |
+
batch_size=batch_size * num_images_per_prompt,
|
145 |
+
num_images_per_prompt=num_images_per_prompt,
|
146 |
+
device=device,
|
147 |
+
dtype=controlnet.dtype,
|
148 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
149 |
+
guess_mode=guess_mode,
|
150 |
+
)
|
151 |
+
|
152 |
+
control_images.append(control_image_)
|
153 |
+
|
154 |
+
control_image = control_images
|
155 |
+
else:
|
156 |
+
assert False
|
157 |
+
|
158 |
+
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
|
159 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
160 |
+
init_image = init_image.to(dtype=torch.float32)
|
161 |
+
|
162 |
+
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
163 |
+
|
164 |
+
masked_image = init_image * (mask < 0.5)
|
165 |
+
_, _, height, width = init_image.shape
|
166 |
+
|
167 |
+
# 5. Prepare timesteps
|
168 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
169 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
170 |
+
num_inference_steps=num_inference_steps, strength=strength, device=device
|
171 |
+
)
|
172 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
173 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
174 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
175 |
+
is_strength_max = strength == 1.0
|
176 |
+
|
177 |
+
# 6. Prepare latent variables
|
178 |
+
num_channels_latents = self.vae.config.latent_channels
|
179 |
+
num_channels_unet = self.unet.config.in_channels
|
180 |
+
return_image_latents = num_channels_unet == 4
|
181 |
+
|
182 |
+
# EDITED HERE
|
183 |
+
latents_outputs = self.prepare_latents(
|
184 |
+
batch_size * num_images_per_prompt,
|
185 |
+
num_channels_latents,
|
186 |
+
height,
|
187 |
+
width,
|
188 |
+
prompt_embeds.dtype,
|
189 |
+
device,
|
190 |
+
generator,
|
191 |
+
latents,
|
192 |
+
image=init_image,
|
193 |
+
timestep=latent_timestep,
|
194 |
+
is_strength_max=is_strength_max,
|
195 |
+
return_noise=True,
|
196 |
+
return_image_latents=return_image_latents,
|
197 |
+
newx=newx,
|
198 |
+
newy=newy,
|
199 |
+
newr=newr,
|
200 |
+
current_seed=current_seed,
|
201 |
+
use_noise_moving=use_noise_moving,
|
202 |
+
)
|
203 |
+
|
204 |
+
if return_image_latents:
|
205 |
+
latents, noise, image_latents = latents_outputs
|
206 |
+
else:
|
207 |
+
latents, noise = latents_outputs
|
208 |
+
|
209 |
+
# 7. Prepare mask latent variables
|
210 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
211 |
+
mask,
|
212 |
+
masked_image,
|
213 |
+
batch_size * num_images_per_prompt,
|
214 |
+
height,
|
215 |
+
width,
|
216 |
+
prompt_embeds.dtype,
|
217 |
+
device,
|
218 |
+
generator,
|
219 |
+
do_classifier_free_guidance,
|
220 |
+
)
|
221 |
+
|
222 |
+
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
223 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
224 |
+
|
225 |
+
# 7.1 Create tensor stating which controlnets to keep
|
226 |
+
controlnet_keep = []
|
227 |
+
for i in range(len(timesteps)):
|
228 |
+
keeps = [
|
229 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
230 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
231 |
+
]
|
232 |
+
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
233 |
+
|
234 |
+
# 8. Denoising loop
|
235 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
236 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
237 |
+
for i, t in enumerate(timesteps):
|
238 |
+
# expand the latents if we are doing classifier free guidance
|
239 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
240 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
241 |
+
|
242 |
+
# controlnet(s) inference
|
243 |
+
if guess_mode and do_classifier_free_guidance:
|
244 |
+
# Infer ControlNet only for the conditional batch.
|
245 |
+
control_model_input = latents
|
246 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
247 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
248 |
+
else:
|
249 |
+
control_model_input = latent_model_input
|
250 |
+
controlnet_prompt_embeds = prompt_embeds
|
251 |
+
|
252 |
+
if isinstance(controlnet_keep[i], list):
|
253 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
254 |
+
else:
|
255 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
256 |
+
if isinstance(controlnet_cond_scale, list):
|
257 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
258 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
259 |
+
|
260 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
261 |
+
control_model_input,
|
262 |
+
t,
|
263 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
264 |
+
controlnet_cond=control_image,
|
265 |
+
conditioning_scale=cond_scale,
|
266 |
+
guess_mode=guess_mode,
|
267 |
+
return_dict=False,
|
268 |
+
)
|
269 |
+
|
270 |
+
if guess_mode and do_classifier_free_guidance:
|
271 |
+
# Infered ControlNet only for the conditional batch.
|
272 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
273 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
274 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
275 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
276 |
+
|
277 |
+
# predict the noise residual
|
278 |
+
if num_channels_unet == 9:
|
279 |
+
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
280 |
+
|
281 |
+
noise_pred = self.unet(
|
282 |
+
latent_model_input,
|
283 |
+
t,
|
284 |
+
encoder_hidden_states=prompt_embeds,
|
285 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
286 |
+
down_block_additional_residuals=down_block_res_samples,
|
287 |
+
mid_block_additional_residual=mid_block_res_sample,
|
288 |
+
return_dict=False,
|
289 |
+
)[0]
|
290 |
+
|
291 |
+
# perform guidance
|
292 |
+
if do_classifier_free_guidance:
|
293 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
294 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
295 |
+
|
296 |
+
# compute the previous noisy sample x_t -> x_t-1
|
297 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
298 |
+
|
299 |
+
if num_channels_unet == 4:
|
300 |
+
init_latents_proper = image_latents[:1]
|
301 |
+
init_mask = mask[:1]
|
302 |
+
|
303 |
+
if i < len(timesteps) - 1:
|
304 |
+
noise_timestep = timesteps[i + 1]
|
305 |
+
init_latents_proper = self.scheduler.add_noise(
|
306 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
307 |
+
)
|
308 |
+
|
309 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
310 |
+
|
311 |
+
# call the callback, if provided
|
312 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
313 |
+
progress_bar.update()
|
314 |
+
if callback is not None and i % callback_steps == 0:
|
315 |
+
callback(i, t, latents)
|
316 |
+
|
317 |
+
# If we do sequential model offloading, let's offload unet and controlnet
|
318 |
+
# manually for max memory savings
|
319 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
320 |
+
self.unet.to("cpu")
|
321 |
+
self.controlnet.to("cpu")
|
322 |
+
torch.cuda.empty_cache()
|
323 |
+
|
324 |
+
if not output_type == "latent":
|
325 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
326 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
327 |
+
else:
|
328 |
+
image = latents
|
329 |
+
has_nsfw_concept = None
|
330 |
+
|
331 |
+
if has_nsfw_concept is None:
|
332 |
+
do_denormalize = [True] * image.shape[0]
|
333 |
+
else:
|
334 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
335 |
+
|
336 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
337 |
+
|
338 |
+
# Offload all models
|
339 |
+
self.maybe_free_model_hooks()
|
340 |
+
|
341 |
+
if not return_dict:
|
342 |
+
return (image, has_nsfw_concept)
|
343 |
+
|
344 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
relighting/pipeline_inpaintonly.py
ADDED
@@ -0,0 +1,613 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import List, Union, Dict, Any, Callable, Optional, Tuple
|
3 |
+
|
4 |
+
from diffusers.image_processor import PipelineImageInput
|
5 |
+
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
|
6 |
+
from diffusers.models import AsymmetricAutoencoderKL
|
7 |
+
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
8 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
9 |
+
from relighting.pipeline_utils import custom_prepare_latents, custom_prepare_mask_latents, rescale_noise_cfg
|
10 |
+
|
11 |
+
class CustomStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
12 |
+
@torch.no_grad()
|
13 |
+
def __call__(
|
14 |
+
self,
|
15 |
+
prompt: Union[str, List[str]] = None,
|
16 |
+
image: PipelineImageInput = None,
|
17 |
+
mask_image: PipelineImageInput = None,
|
18 |
+
masked_image_latents: torch.FloatTensor = None,
|
19 |
+
height: Optional[int] = None,
|
20 |
+
width: Optional[int] = None,
|
21 |
+
strength: float = 1.0,
|
22 |
+
num_inference_steps: int = 50,
|
23 |
+
guidance_scale: float = 7.5,
|
24 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
25 |
+
num_images_per_prompt: Optional[int] = 1,
|
26 |
+
eta: float = 0.0,
|
27 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
28 |
+
latents: Optional[torch.FloatTensor] = None,
|
29 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
30 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
31 |
+
output_type: Optional[str] = "pil",
|
32 |
+
return_dict: bool = True,
|
33 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
34 |
+
callback_steps: int = 1,
|
35 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
36 |
+
newx: int = 0,
|
37 |
+
newy: int = 0,
|
38 |
+
newr: int = 256,
|
39 |
+
current_seed=0,
|
40 |
+
use_noise_moving=True,
|
41 |
+
):
|
42 |
+
# OVERWRITE METHODS
|
43 |
+
self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionInpaintPipeline)
|
44 |
+
self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionInpaintPipeline)
|
45 |
+
|
46 |
+
# 0. Default height and width to unet
|
47 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
48 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
49 |
+
|
50 |
+
# 1. Check inputs
|
51 |
+
self.check_inputs(
|
52 |
+
prompt,
|
53 |
+
height,
|
54 |
+
width,
|
55 |
+
strength,
|
56 |
+
callback_steps,
|
57 |
+
negative_prompt,
|
58 |
+
prompt_embeds,
|
59 |
+
negative_prompt_embeds,
|
60 |
+
)
|
61 |
+
|
62 |
+
# 2. Define call parameters
|
63 |
+
if prompt is not None and isinstance(prompt, str):
|
64 |
+
batch_size = 1
|
65 |
+
elif prompt is not None and isinstance(prompt, list):
|
66 |
+
batch_size = len(prompt)
|
67 |
+
else:
|
68 |
+
batch_size = prompt_embeds.shape[0]
|
69 |
+
|
70 |
+
device = self._execution_device
|
71 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
72 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
73 |
+
# corresponds to doing no classifier free guidance.
|
74 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
75 |
+
|
76 |
+
# 3. Encode input prompt
|
77 |
+
text_encoder_lora_scale = (
|
78 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
79 |
+
)
|
80 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
81 |
+
prompt,
|
82 |
+
device,
|
83 |
+
num_images_per_prompt,
|
84 |
+
do_classifier_free_guidance,
|
85 |
+
negative_prompt,
|
86 |
+
prompt_embeds=prompt_embeds,
|
87 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
88 |
+
lora_scale=text_encoder_lora_scale,
|
89 |
+
)
|
90 |
+
# For classifier free guidance, we need to do two forward passes.
|
91 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
92 |
+
# to avoid doing two forward passes
|
93 |
+
if do_classifier_free_guidance:
|
94 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
95 |
+
|
96 |
+
# 4. set timesteps
|
97 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
98 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
99 |
+
num_inference_steps=num_inference_steps, strength=strength, device=device
|
100 |
+
)
|
101 |
+
# check that number of inference steps is not < 1 - as this doesn't make sense
|
102 |
+
if num_inference_steps < 1:
|
103 |
+
raise ValueError(
|
104 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
105 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
106 |
+
)
|
107 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
108 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
109 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
110 |
+
is_strength_max = strength == 1.0
|
111 |
+
|
112 |
+
# 5. Preprocess mask and image
|
113 |
+
|
114 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
115 |
+
init_image = init_image.to(dtype=torch.float32)
|
116 |
+
|
117 |
+
# 6. Prepare latent variables
|
118 |
+
num_channels_latents = self.vae.config.latent_channels
|
119 |
+
num_channels_unet = self.unet.config.in_channels
|
120 |
+
return_image_latents = num_channels_unet == 4
|
121 |
+
|
122 |
+
latents_outputs = self.prepare_latents(
|
123 |
+
batch_size * num_images_per_prompt,
|
124 |
+
num_channels_latents,
|
125 |
+
height,
|
126 |
+
width,
|
127 |
+
prompt_embeds.dtype,
|
128 |
+
device,
|
129 |
+
generator,
|
130 |
+
latents,
|
131 |
+
image=init_image,
|
132 |
+
timestep=latent_timestep,
|
133 |
+
is_strength_max=is_strength_max,
|
134 |
+
return_noise=True,
|
135 |
+
return_image_latents=return_image_latents,
|
136 |
+
newx=newx,
|
137 |
+
newy=newy,
|
138 |
+
newr=newr,
|
139 |
+
current_seed=current_seed,
|
140 |
+
use_noise_moving=use_noise_moving,
|
141 |
+
)
|
142 |
+
|
143 |
+
if return_image_latents:
|
144 |
+
latents, noise, image_latents = latents_outputs
|
145 |
+
else:
|
146 |
+
latents, noise = latents_outputs
|
147 |
+
|
148 |
+
# 7. Prepare mask latent variables
|
149 |
+
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
150 |
+
|
151 |
+
if masked_image_latents is None:
|
152 |
+
masked_image = init_image * (mask_condition < 0.5)
|
153 |
+
else:
|
154 |
+
masked_image = masked_image_latents
|
155 |
+
|
156 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
157 |
+
mask_condition,
|
158 |
+
masked_image,
|
159 |
+
batch_size * num_images_per_prompt,
|
160 |
+
height,
|
161 |
+
width,
|
162 |
+
prompt_embeds.dtype,
|
163 |
+
device,
|
164 |
+
generator,
|
165 |
+
do_classifier_free_guidance,
|
166 |
+
)
|
167 |
+
|
168 |
+
# 8. Check that sizes of mask, masked image and latents match
|
169 |
+
if num_channels_unet == 9:
|
170 |
+
# default case for runwayml/stable-diffusion-inpainting
|
171 |
+
num_channels_mask = mask.shape[1]
|
172 |
+
num_channels_masked_image = masked_image_latents.shape[1]
|
173 |
+
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
174 |
+
raise ValueError(
|
175 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
176 |
+
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
177 |
+
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
178 |
+
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
179 |
+
" `pipeline.unet` or your `mask_image` or `image` input."
|
180 |
+
)
|
181 |
+
elif num_channels_unet != 4:
|
182 |
+
raise ValueError(
|
183 |
+
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
184 |
+
)
|
185 |
+
|
186 |
+
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
187 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
188 |
+
|
189 |
+
# 10. Denoising loop
|
190 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
191 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
192 |
+
for i, t in enumerate(timesteps):
|
193 |
+
# expand the latents if we are doing classifier free guidance
|
194 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
195 |
+
|
196 |
+
# concat latents, mask, masked_image_latents in the channel dimension
|
197 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
198 |
+
|
199 |
+
if num_channels_unet == 9:
|
200 |
+
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
201 |
+
|
202 |
+
# predict the noise residual
|
203 |
+
noise_pred = self.unet(
|
204 |
+
latent_model_input,
|
205 |
+
t,
|
206 |
+
encoder_hidden_states=prompt_embeds,
|
207 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
208 |
+
return_dict=False,
|
209 |
+
)[0]
|
210 |
+
|
211 |
+
# perform guidance
|
212 |
+
if do_classifier_free_guidance:
|
213 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
214 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
215 |
+
|
216 |
+
# compute the previous noisy sample x_t -> x_t-1
|
217 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
218 |
+
|
219 |
+
if num_channels_unet == 4:
|
220 |
+
init_latents_proper = image_latents[:1]
|
221 |
+
init_mask = mask[:1]
|
222 |
+
|
223 |
+
if i < len(timesteps) - 1:
|
224 |
+
noise_timestep = timesteps[i + 1]
|
225 |
+
init_latents_proper = self.scheduler.add_noise(
|
226 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
227 |
+
)
|
228 |
+
|
229 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
230 |
+
|
231 |
+
# call the callback, if provided
|
232 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
233 |
+
progress_bar.update()
|
234 |
+
if callback is not None and i % callback_steps == 0:
|
235 |
+
callback(i, t, latents)
|
236 |
+
|
237 |
+
if not output_type == "latent":
|
238 |
+
condition_kwargs = {}
|
239 |
+
if isinstance(self.vae, AsymmetricAutoencoderKL):
|
240 |
+
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
|
241 |
+
init_image_condition = init_image.clone()
|
242 |
+
init_image = self._encode_vae_image(init_image, generator=generator)
|
243 |
+
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
|
244 |
+
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
|
245 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
|
246 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
247 |
+
else:
|
248 |
+
image = latents
|
249 |
+
has_nsfw_concept = None
|
250 |
+
|
251 |
+
if has_nsfw_concept is None:
|
252 |
+
do_denormalize = [True] * image.shape[0]
|
253 |
+
else:
|
254 |
+
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
255 |
+
|
256 |
+
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
257 |
+
|
258 |
+
# Offload all models
|
259 |
+
self.maybe_free_model_hooks()
|
260 |
+
|
261 |
+
if not return_dict:
|
262 |
+
return (image, has_nsfw_concept)
|
263 |
+
|
264 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
265 |
+
|
266 |
+
class CustomStableDiffusionXLInpaintPipeline(StableDiffusionXLInpaintPipeline):
|
267 |
+
@torch.no_grad()
|
268 |
+
def __call__(
|
269 |
+
self,
|
270 |
+
prompt: Union[str, List[str]] = None,
|
271 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
272 |
+
image: PipelineImageInput = None,
|
273 |
+
mask_image: PipelineImageInput = None,
|
274 |
+
masked_image_latents: torch.FloatTensor = None,
|
275 |
+
height: Optional[int] = None,
|
276 |
+
width: Optional[int] = None,
|
277 |
+
strength: float = 0.9999,
|
278 |
+
num_inference_steps: int = 50,
|
279 |
+
denoising_start: Optional[float] = None,
|
280 |
+
denoising_end: Optional[float] = None,
|
281 |
+
guidance_scale: float = 7.5,
|
282 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
283 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
284 |
+
num_images_per_prompt: Optional[int] = 1,
|
285 |
+
eta: float = 0.0,
|
286 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
287 |
+
latents: Optional[torch.FloatTensor] = None,
|
288 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
289 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
290 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
291 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
292 |
+
output_type: Optional[str] = "pil",
|
293 |
+
return_dict: bool = True,
|
294 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
295 |
+
callback_steps: int = 1,
|
296 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
297 |
+
guidance_rescale: float = 0.0,
|
298 |
+
original_size: Tuple[int, int] = None,
|
299 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
300 |
+
target_size: Tuple[int, int] = None,
|
301 |
+
negative_original_size: Optional[Tuple[int, int]] = None,
|
302 |
+
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
303 |
+
negative_target_size: Optional[Tuple[int, int]] = None,
|
304 |
+
aesthetic_score: float = 6.0,
|
305 |
+
negative_aesthetic_score: float = 2.5,
|
306 |
+
newx: int = 0,
|
307 |
+
newy: int = 0,
|
308 |
+
newr: int = 256,
|
309 |
+
current_seed=0,
|
310 |
+
use_noise_moving=True,
|
311 |
+
):
|
312 |
+
# OVERWRITE METHODS
|
313 |
+
self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionXLInpaintPipeline)
|
314 |
+
self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionXLInpaintPipeline)
|
315 |
+
|
316 |
+
# 0. Default height and width to unet
|
317 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
318 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
319 |
+
|
320 |
+
# 1. Check inputs
|
321 |
+
self.check_inputs(
|
322 |
+
prompt,
|
323 |
+
prompt_2,
|
324 |
+
height,
|
325 |
+
width,
|
326 |
+
strength,
|
327 |
+
callback_steps,
|
328 |
+
negative_prompt,
|
329 |
+
negative_prompt_2,
|
330 |
+
prompt_embeds,
|
331 |
+
negative_prompt_embeds,
|
332 |
+
)
|
333 |
+
|
334 |
+
# 2. Define call parameters
|
335 |
+
if prompt is not None and isinstance(prompt, str):
|
336 |
+
batch_size = 1
|
337 |
+
elif prompt is not None and isinstance(prompt, list):
|
338 |
+
batch_size = len(prompt)
|
339 |
+
else:
|
340 |
+
batch_size = prompt_embeds.shape[0]
|
341 |
+
|
342 |
+
device = self._execution_device
|
343 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
344 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
345 |
+
# corresponds to doing no classifier free guidance.
|
346 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
347 |
+
|
348 |
+
# 3. Encode input prompt
|
349 |
+
text_encoder_lora_scale = (
|
350 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
351 |
+
)
|
352 |
+
|
353 |
+
(
|
354 |
+
prompt_embeds,
|
355 |
+
negative_prompt_embeds,
|
356 |
+
pooled_prompt_embeds,
|
357 |
+
negative_pooled_prompt_embeds,
|
358 |
+
) = self.encode_prompt(
|
359 |
+
prompt=prompt,
|
360 |
+
prompt_2=prompt_2,
|
361 |
+
device=device,
|
362 |
+
num_images_per_prompt=num_images_per_prompt,
|
363 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
364 |
+
negative_prompt=negative_prompt,
|
365 |
+
negative_prompt_2=negative_prompt_2,
|
366 |
+
prompt_embeds=prompt_embeds,
|
367 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
368 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
369 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
370 |
+
lora_scale=text_encoder_lora_scale,
|
371 |
+
)
|
372 |
+
|
373 |
+
# 4. set timesteps
|
374 |
+
def denoising_value_valid(dnv):
|
375 |
+
return isinstance(denoising_end, float) and 0 < dnv < 1
|
376 |
+
|
377 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
378 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
379 |
+
num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid else None
|
380 |
+
)
|
381 |
+
# check that number of inference steps is not < 1 - as this doesn't make sense
|
382 |
+
if num_inference_steps < 1:
|
383 |
+
raise ValueError(
|
384 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
385 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
386 |
+
)
|
387 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
388 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
389 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
390 |
+
is_strength_max = strength == 1.0
|
391 |
+
|
392 |
+
# 5. Preprocess mask and image
|
393 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
394 |
+
init_image = init_image.to(dtype=torch.float32)
|
395 |
+
|
396 |
+
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
397 |
+
|
398 |
+
if masked_image_latents is not None:
|
399 |
+
masked_image = masked_image_latents
|
400 |
+
elif init_image.shape[1] == 4:
|
401 |
+
# if images are in latent space, we can't mask it
|
402 |
+
masked_image = None
|
403 |
+
else:
|
404 |
+
masked_image = init_image * (mask < 0.5)
|
405 |
+
|
406 |
+
# 6. Prepare latent variables
|
407 |
+
num_channels_latents = self.vae.config.latent_channels
|
408 |
+
num_channels_unet = self.unet.config.in_channels
|
409 |
+
return_image_latents = num_channels_unet == 4
|
410 |
+
|
411 |
+
# add_noise = True if denoising_start is None else False
|
412 |
+
latents_outputs = self.prepare_latents(
|
413 |
+
batch_size * num_images_per_prompt,
|
414 |
+
num_channels_latents,
|
415 |
+
height,
|
416 |
+
width,
|
417 |
+
prompt_embeds.dtype,
|
418 |
+
device,
|
419 |
+
generator,
|
420 |
+
latents,
|
421 |
+
image=init_image,
|
422 |
+
timestep=latent_timestep,
|
423 |
+
is_strength_max=is_strength_max,
|
424 |
+
return_noise=True,
|
425 |
+
return_image_latents=return_image_latents,
|
426 |
+
newx=newx,
|
427 |
+
newy=newy,
|
428 |
+
newr=newr,
|
429 |
+
current_seed=current_seed,
|
430 |
+
use_noise_moving=use_noise_moving,
|
431 |
+
)
|
432 |
+
|
433 |
+
if return_image_latents:
|
434 |
+
latents, noise, image_latents = latents_outputs
|
435 |
+
else:
|
436 |
+
latents, noise = latents_outputs
|
437 |
+
|
438 |
+
# 7. Prepare mask latent variables
|
439 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
440 |
+
mask,
|
441 |
+
masked_image,
|
442 |
+
batch_size * num_images_per_prompt,
|
443 |
+
height,
|
444 |
+
width,
|
445 |
+
prompt_embeds.dtype,
|
446 |
+
device,
|
447 |
+
generator,
|
448 |
+
do_classifier_free_guidance,
|
449 |
+
)
|
450 |
+
|
451 |
+
# 8. Check that sizes of mask, masked image and latents match
|
452 |
+
if num_channels_unet == 9:
|
453 |
+
# default case for runwayml/stable-diffusion-inpainting
|
454 |
+
num_channels_mask = mask.shape[1]
|
455 |
+
num_channels_masked_image = masked_image_latents.shape[1]
|
456 |
+
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
457 |
+
raise ValueError(
|
458 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
459 |
+
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
460 |
+
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
461 |
+
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
462 |
+
" `pipeline.unet` or your `mask_image` or `image` input."
|
463 |
+
)
|
464 |
+
elif num_channels_unet != 4:
|
465 |
+
raise ValueError(
|
466 |
+
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
467 |
+
)
|
468 |
+
# 8.1 Prepare extra step kwargs.
|
469 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
470 |
+
|
471 |
+
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
472 |
+
height, width = latents.shape[-2:]
|
473 |
+
height = height * self.vae_scale_factor
|
474 |
+
width = width * self.vae_scale_factor
|
475 |
+
|
476 |
+
original_size = original_size or (height, width)
|
477 |
+
target_size = target_size or (height, width)
|
478 |
+
|
479 |
+
# 10. Prepare added time ids & embeddings
|
480 |
+
if negative_original_size is None:
|
481 |
+
negative_original_size = original_size
|
482 |
+
if negative_target_size is None:
|
483 |
+
negative_target_size = target_size
|
484 |
+
|
485 |
+
add_text_embeds = pooled_prompt_embeds
|
486 |
+
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
487 |
+
original_size,
|
488 |
+
crops_coords_top_left,
|
489 |
+
target_size,
|
490 |
+
aesthetic_score,
|
491 |
+
negative_aesthetic_score,
|
492 |
+
negative_original_size,
|
493 |
+
negative_crops_coords_top_left,
|
494 |
+
negative_target_size,
|
495 |
+
dtype=prompt_embeds.dtype,
|
496 |
+
)
|
497 |
+
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
498 |
+
|
499 |
+
if do_classifier_free_guidance:
|
500 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
501 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
502 |
+
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
503 |
+
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
504 |
+
|
505 |
+
prompt_embeds = prompt_embeds.to(device)
|
506 |
+
add_text_embeds = add_text_embeds.to(device)
|
507 |
+
add_time_ids = add_time_ids.to(device)
|
508 |
+
|
509 |
+
# 11. Denoising loop
|
510 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
511 |
+
|
512 |
+
if (
|
513 |
+
denoising_end is not None
|
514 |
+
and denoising_start is not None
|
515 |
+
and denoising_value_valid(denoising_end)
|
516 |
+
and denoising_value_valid(denoising_start)
|
517 |
+
and denoising_start >= denoising_end
|
518 |
+
):
|
519 |
+
raise ValueError(
|
520 |
+
f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
|
521 |
+
+ f" {denoising_end} when using type float."
|
522 |
+
)
|
523 |
+
elif denoising_end is not None and denoising_value_valid(denoising_end):
|
524 |
+
discrete_timestep_cutoff = int(
|
525 |
+
round(
|
526 |
+
self.scheduler.config.num_train_timesteps
|
527 |
+
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
528 |
+
)
|
529 |
+
)
|
530 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
531 |
+
timesteps = timesteps[:num_inference_steps]
|
532 |
+
|
533 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
534 |
+
for i, t in enumerate(timesteps):
|
535 |
+
# expand the latents if we are doing classifier free guidance
|
536 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
537 |
+
|
538 |
+
# concat latents, mask, masked_image_latents in the channel dimension
|
539 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
540 |
+
|
541 |
+
if num_channels_unet == 9:
|
542 |
+
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
543 |
+
|
544 |
+
# predict the noise residual
|
545 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
546 |
+
noise_pred = self.unet(
|
547 |
+
latent_model_input,
|
548 |
+
t,
|
549 |
+
encoder_hidden_states=prompt_embeds,
|
550 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
551 |
+
added_cond_kwargs=added_cond_kwargs,
|
552 |
+
return_dict=False,
|
553 |
+
)[0]
|
554 |
+
|
555 |
+
# perform guidance
|
556 |
+
if do_classifier_free_guidance:
|
557 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
558 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
559 |
+
|
560 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
561 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
562 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
563 |
+
|
564 |
+
# compute the previous noisy sample x_t -> x_t-1
|
565 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
566 |
+
|
567 |
+
if num_channels_unet == 4:
|
568 |
+
init_latents_proper = image_latents[:1]
|
569 |
+
init_mask = mask[:1]
|
570 |
+
|
571 |
+
if i < len(timesteps) - 1:
|
572 |
+
noise_timestep = timesteps[i + 1]
|
573 |
+
init_latents_proper = self.scheduler.add_noise(
|
574 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
575 |
+
)
|
576 |
+
|
577 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
578 |
+
|
579 |
+
# call the callback, if provided
|
580 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
581 |
+
progress_bar.update()
|
582 |
+
if callback is not None and i % callback_steps == 0:
|
583 |
+
callback(i, t, latents)
|
584 |
+
|
585 |
+
if not output_type == "latent":
|
586 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
587 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
588 |
+
|
589 |
+
if needs_upcasting:
|
590 |
+
self.upcast_vae()
|
591 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
592 |
+
|
593 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
594 |
+
|
595 |
+
# cast back to fp16 if needed
|
596 |
+
if needs_upcasting:
|
597 |
+
self.vae.to(dtype=torch.float16)
|
598 |
+
else:
|
599 |
+
return StableDiffusionXLPipelineOutput(images=latents)
|
600 |
+
|
601 |
+
# apply watermark if available
|
602 |
+
if self.watermark is not None:
|
603 |
+
image = self.watermark.apply_watermark(image)
|
604 |
+
|
605 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
606 |
+
|
607 |
+
# Offload all models
|
608 |
+
self.maybe_free_model_hooks()
|
609 |
+
|
610 |
+
if not return_dict:
|
611 |
+
return (image,)
|
612 |
+
|
613 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
relighting/pipeline_utils.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import itertools
|
4 |
+
from diffusers.utils.torch_utils import randn_tensor
|
5 |
+
|
6 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
7 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
8 |
+
"""
|
9 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
10 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
11 |
+
"""
|
12 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
13 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
14 |
+
# rescale the results from guidance (fixes overexposure)
|
15 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
16 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
17 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
18 |
+
return noise_cfg
|
19 |
+
|
20 |
+
def expand_noise(noise, shape, seed, device, dtype):
|
21 |
+
new_generator = torch.Generator().manual_seed(seed)
|
22 |
+
corner_shape = (shape[0], shape[1], shape[2] // 2, shape[3] // 2)
|
23 |
+
vert_border_shape = (shape[0], shape[1], shape[2], shape[3] // 2)
|
24 |
+
hori_border_shape = (shape[0], shape[1], shape[2] // 2, shape[3])
|
25 |
+
|
26 |
+
corners = [randn_tensor(corner_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(4)]
|
27 |
+
vert_borders = [randn_tensor(vert_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)]
|
28 |
+
hori_borders = [randn_tensor(hori_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)]
|
29 |
+
|
30 |
+
# combine
|
31 |
+
big_shape = (shape[0], shape[1], shape[2] * 2, shape[3] * 2)
|
32 |
+
noise_template = randn_tensor(big_shape, generator=new_generator, device=device, dtype=dtype)
|
33 |
+
|
34 |
+
ticks = [(0, 0.25), (0.25, 0.75), (0.75, 1.0)]
|
35 |
+
grid = list(itertools.product(ticks, ticks))
|
36 |
+
noise_list = [
|
37 |
+
corners[0], hori_borders[0], corners[1],
|
38 |
+
vert_borders[0], noise, vert_borders[1],
|
39 |
+
corners[2], hori_borders[1], corners[3],
|
40 |
+
]
|
41 |
+
for current_noise, ((x1, x2), (y1, y2)) in zip(noise_list, grid):
|
42 |
+
top_left = (int(x1 * big_shape[2]), int(y1 * big_shape[3]))
|
43 |
+
bottom_right = (int(x2 * big_shape[2]), int(y2 * big_shape[3]))
|
44 |
+
noise_template[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = current_noise
|
45 |
+
|
46 |
+
return noise_template
|
47 |
+
|
48 |
+
def custom_prepare_latents(
|
49 |
+
self,
|
50 |
+
batch_size,
|
51 |
+
num_channels_latents,
|
52 |
+
height,
|
53 |
+
width,
|
54 |
+
dtype,
|
55 |
+
device,
|
56 |
+
generator,
|
57 |
+
latents=None,
|
58 |
+
image=None,
|
59 |
+
timestep=None,
|
60 |
+
is_strength_max=True,
|
61 |
+
use_noise_moving=True,
|
62 |
+
return_noise=False,
|
63 |
+
return_image_latents=False,
|
64 |
+
newx=0,
|
65 |
+
newy=0,
|
66 |
+
newr=256,
|
67 |
+
current_seed=None,
|
68 |
+
):
|
69 |
+
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
70 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
71 |
+
raise ValueError(
|
72 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
73 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
74 |
+
)
|
75 |
+
|
76 |
+
if (image is None or timestep is None) and not is_strength_max:
|
77 |
+
raise ValueError(
|
78 |
+
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
|
79 |
+
"However, either the image or the noise timestep has not been provided."
|
80 |
+
)
|
81 |
+
|
82 |
+
if image.shape[1] == 4:
|
83 |
+
image_latents = image.to(device=device, dtype=dtype)
|
84 |
+
elif return_image_latents or (latents is None and not is_strength_max):
|
85 |
+
image = image.to(device=device, dtype=dtype)
|
86 |
+
image_latents = self._encode_vae_image(image=image, generator=generator)
|
87 |
+
|
88 |
+
if latents is None and use_noise_moving:
|
89 |
+
# random big noise map
|
90 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
91 |
+
noise = expand_noise(noise, shape, seed=current_seed, device=device, dtype=dtype)
|
92 |
+
|
93 |
+
# ensure noise is the same regardless of inpainting location (top-left corner notation)
|
94 |
+
newys = [newy] if not isinstance(newy, list) else newy
|
95 |
+
newxs = [newx] if not isinstance(newx, list) else newx
|
96 |
+
big_noise = noise.clone()
|
97 |
+
prev_noise = None
|
98 |
+
for newy, newx in zip(newys, newxs):
|
99 |
+
# find patch location within big noise map
|
100 |
+
sy = big_noise.shape[2] // 4 + ((512 - 128) - newy) // self.vae_scale_factor
|
101 |
+
sx = big_noise.shape[3] // 4 + ((512 - 128) - newx) // self.vae_scale_factor
|
102 |
+
|
103 |
+
if prev_noise is not None:
|
104 |
+
new_noise = big_noise[:, :, sy:sy+shape[2], sx:sx+shape[3]]
|
105 |
+
|
106 |
+
ball_mask = torch.zeros(shape, device=device, dtype=bool)
|
107 |
+
top_left = (newy // self.vae_scale_factor, newx // self.vae_scale_factor)
|
108 |
+
bottom_right = (top_left[0] + newr // self.vae_scale_factor, top_left[1] + newr // self.vae_scale_factor) # fixed ball size r = 256
|
109 |
+
ball_mask[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = True
|
110 |
+
|
111 |
+
noise = prev_noise.clone()
|
112 |
+
noise[ball_mask] = new_noise[ball_mask]
|
113 |
+
else:
|
114 |
+
noise = big_noise[:, :, sy:sy+shape[2], sx:sx+shape[3]]
|
115 |
+
|
116 |
+
prev_noise = noise.clone()
|
117 |
+
|
118 |
+
# if strength is 1. then initialise the latents to noise, else initial to image + noise
|
119 |
+
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
|
120 |
+
# if pure noise then scale the initial latents by the Scheduler's init sigma
|
121 |
+
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
|
122 |
+
elif latents is None:
|
123 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
124 |
+
latents = image_latents.to(device)
|
125 |
+
else:
|
126 |
+
noise = latents.to(device)
|
127 |
+
latents = noise * self.scheduler.init_noise_sigma
|
128 |
+
|
129 |
+
outputs = (latents,)
|
130 |
+
|
131 |
+
if return_noise:
|
132 |
+
outputs += (noise,)
|
133 |
+
|
134 |
+
if return_image_latents:
|
135 |
+
outputs += (image_latents,)
|
136 |
+
|
137 |
+
return outputs
|
138 |
+
|
139 |
+
def custom_prepare_mask_latents(
|
140 |
+
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
|
141 |
+
):
|
142 |
+
# resize the mask to latents shape as we concatenate the mask to the latents
|
143 |
+
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
144 |
+
# and half precision
|
145 |
+
mask = torch.nn.functional.interpolate(
|
146 |
+
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
|
147 |
+
mode="bilinear", align_corners=False #PURE: We add this to avoid sharp border of the ball
|
148 |
+
)
|
149 |
+
mask = mask.to(device=device, dtype=dtype)
|
150 |
+
|
151 |
+
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
152 |
+
if mask.shape[0] < batch_size:
|
153 |
+
if not batch_size % mask.shape[0] == 0:
|
154 |
+
raise ValueError(
|
155 |
+
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
156 |
+
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
157 |
+
" of masks that you pass is divisible by the total requested batch size."
|
158 |
+
)
|
159 |
+
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
160 |
+
|
161 |
+
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
|
162 |
+
|
163 |
+
masked_image_latents = None
|
164 |
+
if masked_image is not None:
|
165 |
+
masked_image = masked_image.to(device=device, dtype=dtype)
|
166 |
+
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
|
167 |
+
if masked_image_latents.shape[0] < batch_size:
|
168 |
+
if not batch_size % masked_image_latents.shape[0] == 0:
|
169 |
+
raise ValueError(
|
170 |
+
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
171 |
+
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
172 |
+
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
173 |
+
)
|
174 |
+
masked_image_latents = masked_image_latents.repeat(
|
175 |
+
batch_size // masked_image_latents.shape[0], 1, 1, 1
|
176 |
+
)
|
177 |
+
|
178 |
+
masked_image_latents = (
|
179 |
+
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
|
180 |
+
)
|
181 |
+
|
182 |
+
# aligning device to prevent device errors when concating it with the latent model input
|
183 |
+
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
184 |
+
|
185 |
+
return mask, masked_image_latents
|
relighting/pipeline_xl.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import List, Union, Dict, Any, Callable, Optional, Tuple
|
3 |
+
|
4 |
+
from diffusers.utils.torch_utils import is_compiled_module
|
5 |
+
from diffusers.models import ControlNetModel
|
6 |
+
from diffusers.pipelines.controlnet import MultiControlNetModel
|
7 |
+
from diffusers import StableDiffusionXLControlNetInpaintPipeline
|
8 |
+
from diffusers.image_processor import PipelineImageInput
|
9 |
+
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
10 |
+
from relighting.pipeline_utils import custom_prepare_latents, custom_prepare_mask_latents, rescale_noise_cfg
|
11 |
+
|
12 |
+
class CustomStableDiffusionXLControlNetInpaintPipeline(StableDiffusionXLControlNetInpaintPipeline):
|
13 |
+
@torch.no_grad()
|
14 |
+
def __call__(
|
15 |
+
self,
|
16 |
+
prompt: Union[str, List[str]] = None,
|
17 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
18 |
+
image: PipelineImageInput = None,
|
19 |
+
mask_image: PipelineImageInput = None,
|
20 |
+
control_image: Union[
|
21 |
+
PipelineImageInput,
|
22 |
+
List[PipelineImageInput],
|
23 |
+
] = None,
|
24 |
+
height: Optional[int] = None,
|
25 |
+
width: Optional[int] = None,
|
26 |
+
strength: float = 0.9999,
|
27 |
+
num_inference_steps: int = 50,
|
28 |
+
denoising_start: Optional[float] = None,
|
29 |
+
denoising_end: Optional[float] = None,
|
30 |
+
guidance_scale: float = 5.0,
|
31 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
32 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
33 |
+
num_images_per_prompt: Optional[int] = 1,
|
34 |
+
eta: float = 0.0,
|
35 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
36 |
+
latents: Optional[torch.FloatTensor] = None,
|
37 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
38 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
39 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
40 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
41 |
+
output_type: Optional[str] = "pil",
|
42 |
+
return_dict: bool = True,
|
43 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
44 |
+
callback_steps: int = 1,
|
45 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
46 |
+
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
47 |
+
guess_mode: bool = False,
|
48 |
+
control_guidance_start: Union[float, List[float]] = 0.0,
|
49 |
+
control_guidance_end: Union[float, List[float]] = 1.0,
|
50 |
+
guidance_rescale: float = 0.0,
|
51 |
+
original_size: Tuple[int, int] = None,
|
52 |
+
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
53 |
+
target_size: Tuple[int, int] = None,
|
54 |
+
aesthetic_score: float = 6.0,
|
55 |
+
negative_aesthetic_score: float = 2.5,
|
56 |
+
newx: int = 0,
|
57 |
+
newy: int = 0,
|
58 |
+
newr: int = 256,
|
59 |
+
current_seed=0,
|
60 |
+
use_noise_moving=True,
|
61 |
+
):
|
62 |
+
# OVERWRITE METHODS
|
63 |
+
self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionXLControlNetInpaintPipeline)
|
64 |
+
self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionXLControlNetInpaintPipeline)
|
65 |
+
|
66 |
+
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
67 |
+
|
68 |
+
# align format for control guidance
|
69 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
70 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
71 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
72 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
73 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
74 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
75 |
+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
|
76 |
+
control_guidance_end
|
77 |
+
]
|
78 |
+
|
79 |
+
# # 0.0 Default height and width to unet
|
80 |
+
# height = height or self.unet.config.sample_size * self.vae_scale_factor
|
81 |
+
# width = width or self.unet.config.sample_size * self.vae_scale_factor
|
82 |
+
|
83 |
+
# 0.1 align format for control guidance
|
84 |
+
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
85 |
+
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
86 |
+
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
87 |
+
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
88 |
+
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
89 |
+
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
90 |
+
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
|
91 |
+
control_guidance_end
|
92 |
+
]
|
93 |
+
|
94 |
+
# 1. Check inputs
|
95 |
+
self.check_inputs(
|
96 |
+
prompt,
|
97 |
+
prompt_2,
|
98 |
+
control_image,
|
99 |
+
strength,
|
100 |
+
num_inference_steps,
|
101 |
+
callback_steps,
|
102 |
+
negative_prompt,
|
103 |
+
negative_prompt_2,
|
104 |
+
prompt_embeds,
|
105 |
+
negative_prompt_embeds,
|
106 |
+
pooled_prompt_embeds,
|
107 |
+
negative_pooled_prompt_embeds,
|
108 |
+
controlnet_conditioning_scale,
|
109 |
+
control_guidance_start,
|
110 |
+
control_guidance_end,
|
111 |
+
)
|
112 |
+
|
113 |
+
# 2. Define call parameters
|
114 |
+
if prompt is not None and isinstance(prompt, str):
|
115 |
+
batch_size = 1
|
116 |
+
elif prompt is not None and isinstance(prompt, list):
|
117 |
+
batch_size = len(prompt)
|
118 |
+
else:
|
119 |
+
batch_size = prompt_embeds.shape[0]
|
120 |
+
|
121 |
+
device = self._execution_device
|
122 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
123 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
124 |
+
# corresponds to doing no classifier free guidance.
|
125 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
126 |
+
|
127 |
+
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
128 |
+
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
129 |
+
|
130 |
+
# 3. Encode input prompt
|
131 |
+
text_encoder_lora_scale = (
|
132 |
+
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
133 |
+
)
|
134 |
+
|
135 |
+
(
|
136 |
+
prompt_embeds,
|
137 |
+
negative_prompt_embeds,
|
138 |
+
pooled_prompt_embeds,
|
139 |
+
negative_pooled_prompt_embeds,
|
140 |
+
) = self.encode_prompt(
|
141 |
+
prompt=prompt,
|
142 |
+
prompt_2=prompt_2,
|
143 |
+
device=device,
|
144 |
+
num_images_per_prompt=num_images_per_prompt,
|
145 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
146 |
+
negative_prompt=negative_prompt,
|
147 |
+
negative_prompt_2=negative_prompt_2,
|
148 |
+
prompt_embeds=prompt_embeds,
|
149 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
150 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
151 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
152 |
+
lora_scale=text_encoder_lora_scale,
|
153 |
+
)
|
154 |
+
|
155 |
+
# 4. set timesteps
|
156 |
+
def denoising_value_valid(dnv):
|
157 |
+
return isinstance(denoising_end, float) and 0 < dnv < 1
|
158 |
+
|
159 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
160 |
+
timesteps, num_inference_steps = self.get_timesteps(
|
161 |
+
num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid else None
|
162 |
+
)
|
163 |
+
# check that number of inference steps is not < 1 - as this doesn't make sense
|
164 |
+
if num_inference_steps < 1:
|
165 |
+
raise ValueError(
|
166 |
+
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
167 |
+
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
168 |
+
)
|
169 |
+
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
170 |
+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
171 |
+
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
172 |
+
is_strength_max = strength == 1.0
|
173 |
+
|
174 |
+
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width
|
175 |
+
# 5.1 Prepare init image
|
176 |
+
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
177 |
+
init_image = init_image.to(dtype=torch.float32)
|
178 |
+
|
179 |
+
# 5.2 Prepare control images
|
180 |
+
if isinstance(controlnet, ControlNetModel):
|
181 |
+
control_image = self.prepare_control_image(
|
182 |
+
image=control_image,
|
183 |
+
width=width,
|
184 |
+
height=height,
|
185 |
+
batch_size=batch_size * num_images_per_prompt,
|
186 |
+
num_images_per_prompt=num_images_per_prompt,
|
187 |
+
device=device,
|
188 |
+
dtype=controlnet.dtype,
|
189 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
190 |
+
guess_mode=guess_mode,
|
191 |
+
)
|
192 |
+
elif isinstance(controlnet, MultiControlNetModel):
|
193 |
+
control_images = []
|
194 |
+
|
195 |
+
for control_image_ in control_image:
|
196 |
+
control_image_ = self.prepare_control_image(
|
197 |
+
image=control_image_,
|
198 |
+
width=width,
|
199 |
+
height=height,
|
200 |
+
batch_size=batch_size * num_images_per_prompt,
|
201 |
+
num_images_per_prompt=num_images_per_prompt,
|
202 |
+
device=device,
|
203 |
+
dtype=controlnet.dtype,
|
204 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
205 |
+
guess_mode=guess_mode,
|
206 |
+
)
|
207 |
+
|
208 |
+
control_images.append(control_image_)
|
209 |
+
|
210 |
+
control_image = control_images
|
211 |
+
else:
|
212 |
+
raise ValueError(f"{controlnet.__class__} is not supported.")
|
213 |
+
|
214 |
+
# 5.3 Prepare mask
|
215 |
+
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
216 |
+
|
217 |
+
masked_image = init_image * (mask < 0.5)
|
218 |
+
_, _, height, width = init_image.shape
|
219 |
+
|
220 |
+
# 6. Prepare latent variables
|
221 |
+
num_channels_latents = self.vae.config.latent_channels
|
222 |
+
num_channels_unet = self.unet.config.in_channels
|
223 |
+
return_image_latents = num_channels_unet == 4
|
224 |
+
|
225 |
+
add_noise = True if denoising_start is None else False
|
226 |
+
latents_outputs = self.prepare_latents(
|
227 |
+
batch_size * num_images_per_prompt,
|
228 |
+
num_channels_latents,
|
229 |
+
height,
|
230 |
+
width,
|
231 |
+
prompt_embeds.dtype,
|
232 |
+
device,
|
233 |
+
generator,
|
234 |
+
latents,
|
235 |
+
image=init_image,
|
236 |
+
timestep=latent_timestep,
|
237 |
+
is_strength_max=is_strength_max,
|
238 |
+
return_noise=True,
|
239 |
+
return_image_latents=return_image_latents,
|
240 |
+
newx=newx,
|
241 |
+
newy=newy,
|
242 |
+
newr=newr,
|
243 |
+
current_seed=current_seed,
|
244 |
+
use_noise_moving=use_noise_moving,
|
245 |
+
)
|
246 |
+
|
247 |
+
if return_image_latents:
|
248 |
+
latents, noise, image_latents = latents_outputs
|
249 |
+
else:
|
250 |
+
latents, noise = latents_outputs
|
251 |
+
|
252 |
+
# 7. Prepare mask latent variables
|
253 |
+
mask, masked_image_latents = self.prepare_mask_latents(
|
254 |
+
mask,
|
255 |
+
masked_image,
|
256 |
+
batch_size * num_images_per_prompt,
|
257 |
+
height,
|
258 |
+
width,
|
259 |
+
prompt_embeds.dtype,
|
260 |
+
device,
|
261 |
+
generator,
|
262 |
+
do_classifier_free_guidance,
|
263 |
+
)
|
264 |
+
|
265 |
+
# 8. Check that sizes of mask, masked image and latents match
|
266 |
+
if num_channels_unet == 9:
|
267 |
+
# default case for runwayml/stable-diffusion-inpainting
|
268 |
+
num_channels_mask = mask.shape[1]
|
269 |
+
num_channels_masked_image = masked_image_latents.shape[1]
|
270 |
+
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
271 |
+
raise ValueError(
|
272 |
+
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
273 |
+
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
274 |
+
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
275 |
+
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
276 |
+
" `pipeline.unet` or your `mask_image` or `image` input."
|
277 |
+
)
|
278 |
+
elif num_channels_unet != 4:
|
279 |
+
raise ValueError(
|
280 |
+
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
281 |
+
)
|
282 |
+
# 8.1 Prepare extra step kwargs.
|
283 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
284 |
+
|
285 |
+
# 8.2 Create tensor stating which controlnets to keep
|
286 |
+
controlnet_keep = []
|
287 |
+
for i in range(len(timesteps)):
|
288 |
+
keeps = [
|
289 |
+
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
290 |
+
for s, e in zip(control_guidance_start, control_guidance_end)
|
291 |
+
]
|
292 |
+
if isinstance(self.controlnet, MultiControlNetModel):
|
293 |
+
controlnet_keep.append(keeps)
|
294 |
+
else:
|
295 |
+
controlnet_keep.append(keeps[0])
|
296 |
+
|
297 |
+
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
298 |
+
height, width = latents.shape[-2:]
|
299 |
+
height = height * self.vae_scale_factor
|
300 |
+
width = width * self.vae_scale_factor
|
301 |
+
|
302 |
+
original_size = original_size or (height, width)
|
303 |
+
target_size = target_size or (height, width)
|
304 |
+
|
305 |
+
# 10. Prepare added time ids & embeddings
|
306 |
+
add_text_embeds = pooled_prompt_embeds
|
307 |
+
add_time_ids, add_neg_time_ids = self._get_add_time_ids(
|
308 |
+
original_size,
|
309 |
+
crops_coords_top_left,
|
310 |
+
target_size,
|
311 |
+
aesthetic_score,
|
312 |
+
negative_aesthetic_score,
|
313 |
+
dtype=prompt_embeds.dtype,
|
314 |
+
)
|
315 |
+
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
316 |
+
|
317 |
+
if do_classifier_free_guidance:
|
318 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
319 |
+
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
320 |
+
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
|
321 |
+
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
|
322 |
+
|
323 |
+
prompt_embeds = prompt_embeds.to(device)
|
324 |
+
add_text_embeds = add_text_embeds.to(device)
|
325 |
+
add_time_ids = add_time_ids.to(device)
|
326 |
+
|
327 |
+
# 11. Denoising loop
|
328 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
329 |
+
|
330 |
+
if (
|
331 |
+
denoising_end is not None
|
332 |
+
and denoising_start is not None
|
333 |
+
and denoising_value_valid(denoising_end)
|
334 |
+
and denoising_value_valid(denoising_start)
|
335 |
+
and denoising_start >= denoising_end
|
336 |
+
):
|
337 |
+
raise ValueError(
|
338 |
+
f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
|
339 |
+
+ f" {denoising_end} when using type float."
|
340 |
+
)
|
341 |
+
elif denoising_end is not None and denoising_value_valid(denoising_end):
|
342 |
+
discrete_timestep_cutoff = int(
|
343 |
+
round(
|
344 |
+
self.scheduler.config.num_train_timesteps
|
345 |
+
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
346 |
+
)
|
347 |
+
)
|
348 |
+
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
349 |
+
timesteps = timesteps[:num_inference_steps]
|
350 |
+
|
351 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
352 |
+
for i, t in enumerate(timesteps):
|
353 |
+
# expand the latents if we are doing classifier free guidance
|
354 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
355 |
+
|
356 |
+
# concat latents, mask, masked_image_latents in the channel dimension
|
357 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
358 |
+
|
359 |
+
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
360 |
+
|
361 |
+
# controlnet(s) inference
|
362 |
+
if guess_mode and do_classifier_free_guidance:
|
363 |
+
# Infer ControlNet only for the conditional batch.
|
364 |
+
control_model_input = latents
|
365 |
+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
366 |
+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
|
367 |
+
controlnet_added_cond_kwargs = {
|
368 |
+
"text_embeds": add_text_embeds.chunk(2)[1],
|
369 |
+
"time_ids": add_time_ids.chunk(2)[1],
|
370 |
+
}
|
371 |
+
else:
|
372 |
+
control_model_input = latent_model_input
|
373 |
+
controlnet_prompt_embeds = prompt_embeds
|
374 |
+
controlnet_added_cond_kwargs = added_cond_kwargs
|
375 |
+
|
376 |
+
if isinstance(controlnet_keep[i], list):
|
377 |
+
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
378 |
+
else:
|
379 |
+
controlnet_cond_scale = controlnet_conditioning_scale
|
380 |
+
if isinstance(controlnet_cond_scale, list):
|
381 |
+
controlnet_cond_scale = controlnet_cond_scale[0]
|
382 |
+
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
383 |
+
|
384 |
+
# # Resize control_image to match the size of the input to the controlnet
|
385 |
+
# if control_image.shape[-2:] != control_model_input.shape[-2:]:
|
386 |
+
# control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False)
|
387 |
+
|
388 |
+
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
389 |
+
control_model_input,
|
390 |
+
t,
|
391 |
+
encoder_hidden_states=controlnet_prompt_embeds,
|
392 |
+
controlnet_cond=control_image,
|
393 |
+
conditioning_scale=cond_scale,
|
394 |
+
guess_mode=guess_mode,
|
395 |
+
added_cond_kwargs=controlnet_added_cond_kwargs,
|
396 |
+
return_dict=False,
|
397 |
+
)
|
398 |
+
|
399 |
+
if guess_mode and do_classifier_free_guidance:
|
400 |
+
# Infered ControlNet only for the conditional batch.
|
401 |
+
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
402 |
+
# add 0 to the unconditional batch to keep it unchanged.
|
403 |
+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
404 |
+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
405 |
+
|
406 |
+
if num_channels_unet == 9:
|
407 |
+
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
408 |
+
|
409 |
+
# predict the noise residual
|
410 |
+
noise_pred = self.unet(
|
411 |
+
latent_model_input,
|
412 |
+
t,
|
413 |
+
encoder_hidden_states=prompt_embeds,
|
414 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
415 |
+
down_block_additional_residuals=down_block_res_samples,
|
416 |
+
mid_block_additional_residual=mid_block_res_sample,
|
417 |
+
added_cond_kwargs=added_cond_kwargs,
|
418 |
+
return_dict=False,
|
419 |
+
)[0]
|
420 |
+
|
421 |
+
# perform guidance
|
422 |
+
if do_classifier_free_guidance:
|
423 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
424 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
425 |
+
|
426 |
+
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
427 |
+
print("rescale: ", guidance_rescale)
|
428 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
429 |
+
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
430 |
+
|
431 |
+
# compute the previous noisy sample x_t -> x_t-1
|
432 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
433 |
+
|
434 |
+
if num_channels_unet == 4:
|
435 |
+
init_latents_proper = image_latents[:1]
|
436 |
+
init_mask = mask[:1]
|
437 |
+
|
438 |
+
if i < len(timesteps) - 1:
|
439 |
+
noise_timestep = timesteps[i + 1]
|
440 |
+
init_latents_proper = self.scheduler.add_noise(
|
441 |
+
init_latents_proper, noise, torch.tensor([noise_timestep])
|
442 |
+
)
|
443 |
+
|
444 |
+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
445 |
+
|
446 |
+
# call the callback, if provided
|
447 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
448 |
+
progress_bar.update()
|
449 |
+
if callback is not None and i % callback_steps == 0:
|
450 |
+
callback(i, t, latents)
|
451 |
+
|
452 |
+
# make sure the VAE is in float32 mode, as it overflows in float16
|
453 |
+
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
454 |
+
self.upcast_vae()
|
455 |
+
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
456 |
+
|
457 |
+
# If we do sequential model offloading, let's offload unet and controlnet
|
458 |
+
# manually for max memory savings
|
459 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
460 |
+
self.unet.to("cpu")
|
461 |
+
self.controlnet.to("cpu")
|
462 |
+
torch.cuda.empty_cache()
|
463 |
+
|
464 |
+
if not output_type == "latent":
|
465 |
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
466 |
+
else:
|
467 |
+
return StableDiffusionXLPipelineOutput(images=latents)
|
468 |
+
|
469 |
+
# apply watermark if available
|
470 |
+
if self.watermark is not None:
|
471 |
+
image = self.watermark.apply_watermark(image)
|
472 |
+
|
473 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
474 |
+
|
475 |
+
# Offload last model to CPU
|
476 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
477 |
+
self.final_offload_hook.offload()
|
478 |
+
|
479 |
+
if not return_dict:
|
480 |
+
return (image,)
|
481 |
+
|
482 |
+
return StableDiffusionXLPipelineOutput(images=image)
|
relighting/tonemapper.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
class TonemapHDR(object):
|
4 |
+
"""
|
5 |
+
Tonemap HDR image globally. First, we find alpha that maps the (max(numpy_img) * percentile) to max_mapping.
|
6 |
+
Then, we calculate I_out = alpha * I_in ^ (1/gamma)
|
7 |
+
input : nd.array batch of images : [H, W, C]
|
8 |
+
output : nd.array batch of images : [H, W, C]
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, gamma=2.4, percentile=50, max_mapping=0.5):
|
12 |
+
self.gamma = gamma
|
13 |
+
self.percentile = percentile
|
14 |
+
self.max_mapping = max_mapping # the value to which alpha will map the (max(numpy_img) * percentile) to
|
15 |
+
|
16 |
+
def __call__(self, numpy_img, clip=True, alpha=None, gamma=True):
|
17 |
+
if gamma:
|
18 |
+
power_numpy_img = np.power(numpy_img, 1 / self.gamma)
|
19 |
+
else:
|
20 |
+
power_numpy_img = numpy_img
|
21 |
+
non_zero = power_numpy_img > 0
|
22 |
+
if non_zero.any():
|
23 |
+
r_percentile = np.percentile(power_numpy_img[non_zero], self.percentile)
|
24 |
+
else:
|
25 |
+
r_percentile = np.percentile(power_numpy_img, self.percentile)
|
26 |
+
if alpha is None:
|
27 |
+
alpha = self.max_mapping / (r_percentile + 1e-10)
|
28 |
+
tonemapped_img = np.multiply(alpha, power_numpy_img)
|
29 |
+
|
30 |
+
if clip:
|
31 |
+
tonemapped_img_clip = np.clip(tonemapped_img, 0, 1)
|
32 |
+
|
33 |
+
return tonemapped_img_clip.astype('float32'), alpha, tonemapped_img
|
relighting/utils.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from PIL import Image
|
5 |
+
import hashlib
|
6 |
+
|
7 |
+
def str2bool(v):
|
8 |
+
"""
|
9 |
+
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
|
10 |
+
"""
|
11 |
+
if isinstance(v, bool):
|
12 |
+
return v
|
13 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
14 |
+
return True
|
15 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
16 |
+
return False
|
17 |
+
else:
|
18 |
+
raise argparse.ArgumentTypeError("boolean value expected")
|
19 |
+
|
20 |
+
def add_dict_to_argparser(parser, default_dict):
|
21 |
+
for k, v in default_dict.items():
|
22 |
+
v_type = type(v)
|
23 |
+
if v is None:
|
24 |
+
v_type = str
|
25 |
+
elif isinstance(v, bool):
|
26 |
+
v_type = str2bool
|
27 |
+
parser.add_argument(f"--{k}", default=v, type=v_type)
|
28 |
+
|
29 |
+
def args_to_dict(args, keys):
|
30 |
+
return {k: getattr(args, k) for k in keys}
|
31 |
+
|
32 |
+
def save_result(
|
33 |
+
image, image_path,
|
34 |
+
mask=None, mask_path=None,
|
35 |
+
normal=None, normal_path=None,
|
36 |
+
):
|
37 |
+
assert isinstance(image, Image.Image)
|
38 |
+
os.makedirs(Path(image_path).parent, exist_ok=True)
|
39 |
+
image.save(image_path)
|
40 |
+
|
41 |
+
if (mask is not None) and (mask_path is not None):
|
42 |
+
assert isinstance(mask, Image.Image)
|
43 |
+
os.makedirs(Path(mask_path).parent, exist_ok=True)
|
44 |
+
mask.save(mask_path)
|
45 |
+
|
46 |
+
if (normal is not None) and (normal_path is not None):
|
47 |
+
assert isinstance(normal, Image.Image)
|
48 |
+
os.makedirs(Path(normal_path).parent, exist_ok=True)
|
49 |
+
normal.save(normal_path)
|
50 |
+
|
51 |
+
def name2hash(name: str):
|
52 |
+
"""
|
53 |
+
@see https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
|
54 |
+
"""
|
55 |
+
hash_number = int(hashlib.sha1(name.encode("utf-8")).hexdigest(), 16) % (10 ** 8)
|
56 |
+
return hash_number
|
requirements.txt
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# utility
|
2 |
+
tqdm==4.66.1
|
3 |
+
scikit-image==0.21.0
|
4 |
+
imageio==2.31.1
|
5 |
+
Pillow==10.2.0
|
6 |
+
numpy==1.24.1
|
7 |
+
natsort==8.4.0
|
8 |
+
|
9 |
+
# EXR handling
|
10 |
+
skylibs==0.7.4
|
11 |
+
OpenEXR==1.3.9
|
12 |
+
|
13 |
+
# We use pytorch pip instead because conda is mess up
|
14 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
15 |
+
torch==2.0.1+cu118
|
16 |
+
torchvision==0.15.2+cu118
|
17 |
+
torchaudio==2.0.2+cu118
|
18 |
+
|
19 |
+
# Diffusers dependencies
|
20 |
+
accelerate==0.21.0
|
21 |
+
datasets==2.13.1
|
22 |
+
diffusers==0.21.0
|
23 |
+
transformers==4.36.0
|
24 |
+
xformers==0.0.20
|
25 |
+
huggingface_hub==0.19.4
|