sakinlesh committed on
Commit dd06d6b · verified · 1 Parent(s): 09edcbe

Upload 25 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ example/bed.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+ 
+ Copyright (c) 2023 Pakkapon Phongthawee
+ 
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+ 
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+ 
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,101 @@
- ---
- title: Deneme
- emoji: 👁
- colorFrom: blue
- colorTo: green
- sdk: gradio
- sdk_version: 5.13.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: deneme deneme
- ---
- 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DiffusionLight: Light Probes for Free by Painting a Chrome Ball
+ 
+ ### [Project Page](https://diffusionlight.github.io/) | [Paper](https://arxiv.org/abs/2312.09168) | [Colab](https://colab.research.google.com/drive/15pC4qb9mEtRYsW3utXkk-jnaeVxUy-0S?usp=sharing&sandboxMode=true) | [HuggingFace](https://huggingface.co/DiffusionLight/DiffusionLight)
+ 
+ [![Open DiffusionLight in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/15pC4qb9mEtRYsW3utXkk-jnaeVxUy-0S?usp=sharing&sandboxMode=true)
+ 
+ ![](https://diffusionlight.github.io/assets/images/thumbnail.jpg)
+ 
+ We present a simple yet effective technique to estimate lighting in a single input image. Current techniques rely heavily on HDR panorama datasets to train neural networks to regress an input with limited field-of-view to a full environment map. However, these approaches often struggle with real-world, uncontrolled settings due to the limited diversity and size of their datasets. To address this problem, we leverage diffusion models trained on billions of standard images to render a chrome ball into the input image. Despite its simplicity, this task remains challenging: the diffusion models often insert incorrect or inconsistent objects and cannot readily generate images in HDR format. Our research uncovers a surprising relationship between the appearance of chrome balls and the initial diffusion noise map, which we utilize to consistently generate high-quality chrome balls. We further fine-tune an LDR diffusion model (Stable Diffusion XL) with LoRA, enabling it to perform exposure bracketing for HDR light estimation. Our method produces convincing light estimates across diverse settings and demonstrates superior generalization to in-the-wild scenarios.
+ 
+ ## Table of contents
+ -----
+ * [TL;DR](#getting-started)
+ * [Installation](#installation)
+ * [Prediction](#prediction)
+ * [Evaluation](#evaluation)
+ * [Citation](#citation)
+ ------
+ 
+ ## Getting started
+ 
+ ```shell
+ conda env create -f environment.yml
+ conda activate diffusionlight
+ pip install -r requirements.txt
+ python inpaint.py --dataset example --output_dir output
+ python ball2envmap.py --ball_dir output/square --envmap_dir output/envmap
+ python exposure2hdr.py --input_dir output/envmap --output_dir output/hdr
+ ```
+ 
+ ## Installation
+ 
+ To set up the Python environment, run the following Conda and pip commands:
+ 
+ ```shell
+ conda env create -f environment.yml
+ conda activate diffusionlight
+ pip install -r requirements.txt
+ ```
+ 
+ Note that Conda is optional. However, if you choose not to use Conda, you must manually install the CUDA toolkit and OpenEXR.
+ 
+ ## Prediction
+ 
+ ### 0. Preparing the image
+ 
+ Please resize the input image to 1024x1024. If the image is not square, we recommend padding it with a black border.
+ 
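As a rough illustration (not part of the committed README), padding an arbitrary photo to a 1024x1024 square with a black border could look like the sketch below; the function name and file paths are made up for the example, and `pil_square_image` in `relighting/image_processor.py` from this commit does essentially the same thing at load time.

```python
from PIL import Image

def pad_to_square(src_path, dst_path, size=1024):
    """Scale the longer side to `size`, then pad the remainder with black."""
    image = Image.open(src_path).convert("RGB")
    scale = size / max(image.size)
    resized = image.resize(
        (round(image.width * scale), round(image.height * scale)), Image.LANCZOS
    )
    canvas = Image.new("RGB", (size, size), (0, 0, 0))  # black border
    canvas.paste(
        resized,
        ((size - resized.width) // 2, (size - resized.height) // 2),
    )
    canvas.save(dst_path)

pad_to_square("photo.jpg", "example/photo.png")  # hypothetical input/output paths
```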
+ ### 1. Inpainting the chrome ball
+ 
+ First, we predict the chrome ball at different exposure values (EV) using the following command:
+ 
+ ```shell
+ python inpaint.py --dataset <input_directory> --output_dir <output_directory>
+ ```
+ 
+ This command outputs three subdirectories: `control`, `raw`, and `square`.
+ 
+ The contents of each directory are:
+ 
+ - `control`: the conditioned depth map
+ - `raw`: the inpainted image with a chrome ball in the center
+ - `square`: the square-cropped chrome ball (used in the next step)
+ 
+ 
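For reference, the images in these subdirectories are suffixed with the EV value with its decimal point removed (e.g. `bed_ev-00.png`, `bed_ev-25.png`, `bed_ev-50.png` for the default `--ev "0,-2.5,-5"`), and `exposure2hdr.py` later recovers the EV by dividing by ten. A minimal sketch of that convention (the file names here are just examples):

```python
# Sketch of the EV-suffix convention shared by inpaint.py and exposure2hdr.py.
def ev_from_filename(filename, ev_string="_ev", endwith=".png"):
    suffix = filename.split(ev_string)[-1].replace(endwith, "")
    return int(suffix) / 10  # "_ev-25" -> -2.5

for name in ["bed_ev-00.png", "bed_ev-25.png", "bed_ev-50.png"]:
    print(name, "->", ev_from_filename(name))  # 0.0, -2.5, -5.0
```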
+ ### 2. Projecting a ball into an environment map
+ 
+ Next, we project the chrome ball from the previous step onto an LDR environment map using the following command:
+ 
+ ```shell
+ python ball2envmap.py --ball_dir <output_directory>/square --envmap_dir <output_directory>/envmap
+ ```
+ 
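Conceptually (a sketch mirroring the math inside `ball2envmap.py`, not a replacement for it), every environment-map direction is treated as a reflected ray; the matching point on the mirrored ball is found by taking the normalized half vector between the viewing direction and that ray as the surface normal, whose last two components become the lookup position in the square ball crop:

```python
import numpy as np

def ball_lookup_position(azimuth, polar):
    """Map an envmap direction to a [0, 1]^2 position on the chrome-ball crop."""
    # direction of the reflected ray (Blender-style axes, as in ball2envmap.py)
    reflect = np.array([
        np.sin(polar) * np.cos(azimuth),
        np.sin(polar) * np.sin(azimuth),
        np.cos(polar),
    ])
    incoming = np.array([1.0, 0.0, 0.0])      # viewer sits along +x
    normal = incoming + reflect
    normal = normal / np.linalg.norm(normal)  # half vector = mirror-ball normal
    pos = 1.0 - (normal + 1.0) / 2.0          # map [-1, 1] -> [0, 1] and flip
    return pos[1:]                            # 2-D lookup position inside the ball crop

print(ball_lookup_position(azimuth=0.0, polar=np.pi / 2))  # -> [0.5 0.5], the ball center
```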
+ ### 3. Composing the HDR image
+ 
+ Finally, we compose an HDR image from the multiple LDR environment maps using our custom exposure bracketing:
+ 
+ ```shell
+ python exposure2hdr.py --input_dir <output_directory>/envmap --output_dir <output_directory>/hdr
+ ```
+ 
+ The predicted light estimate is located at `<output_directory>/hdr` and can be used for downstream tasks such as object insertion. We also use it to compare against other methods.
+ 
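As an illustrative follow-up (a sketch, not part of this commit), the resulting `.exr` can be read back and re-exposed for a quick preview. `ezexr` is the EXR reader already used by the scripts in this commit; the path below assumes the bundled `example/bed.png` was processed with the default settings.

```python
import numpy as np
import skimage
import skimage.io
import ezexr  # EXR reader already used by ball2envmap.py / exposure2hdr.py

hdr = ezexr.imread("output/hdr/bed.exr")   # assumed output location
for ev in (0.0, -2.5, -5.0):               # re-expose like the EV brackets
    ldr = np.clip((hdr * (2.0 ** ev)) ** (1 / 2.4), 0.0, 1.0)  # gamma 2.4 as in exposure2hdr.py
    skimage.io.imsave(f"preview_ev{ev}.png", skimage.img_as_ubyte(ldr))
```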
+ ## Evaluation
+ We use the evaluation code from [StyleLight](https://style-light.github.io/) and [Editable Indoor LightEstimation](https://lvsn.github.io/EditableIndoorLight/). You can use their code to measure our score.
+ 
+ Additionally, we provide a *slightly* modified version of the evaluation code at [DiffusionLight-evaluation](https://github.com/DiffusionLight/DiffusionLight-evaluation), including the test input.
+ 
+ ## Citation
+ 
+ ```
+ @inproceedings{Phongthawee2023DiffusionLight,
+     author = {Phongthawee, Pakkapon and Chinchuthakun, Worameth and Sinsunthithet, Nontaphat and Raj, Amit and Jampani, Varun and Khungurn, Pramook and Suwajanakorn, Supasorn},
+     title = {DiffusionLight: Light Probes for Free by Painting a Chrome Ball},
+     booktitle = {ArXiv},
+     year = {2023},
+ }
+ ```
+ 
+ ## Visit us 🦉
+ [![Vision & Learning Laboratory](https://i.imgur.com/hQhkKhG.png)](https://vistec.ist/vision) [![VISTEC - Vidyasirimedhi Institute of Science and Technology](https://i.imgur.com/4wh8HQd.png)](https://vistec.ist/)
ball2envmap.py ADDED
@@ -0,0 +1,152 @@
+ # convert the ball to an environment map in lat-long format
+ 
+ import numpy as np
+ from PIL import Image
+ import skimage
+ import time
+ import torch
+ import argparse
+ from multiprocessing import Pool
+ from functools import partial
+ from tqdm.auto import tqdm
+ import os
+ 
+ try:
+     import ezexr
+ except:
+     pass
+ 
+ def create_argparser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--ball_dir", type=str, required=True, help='directory that contains the ball images')
+     parser.add_argument("--envmap_dir", type=str, required=True, help='directory to output the environment maps')
+     parser.add_argument("--envmap_height", type=int, default=256, help="height of the environment map in pixels")
+     parser.add_argument("--scale", type=int, default=4, help="scale factor")
+     parser.add_argument("--threads", type=int, default=8, help="number of threads for parallel processing")
+     return parser
+ 
+ def create_envmap_grid(size: int):
+     """
+     BLENDER CONVENTION
+     Create the environment map grid that contains the position in spherical coordinates.
+     Top left is (0, 0) and bottom right is (2*pi, pi).
+     """
+ 
+     theta = torch.linspace(0, np.pi * 2, size * 2)
+     phi = torch.linspace(0, np.pi, size)
+ 
+     # use indexing 'xy' to match the vision convention
+     theta, phi = torch.meshgrid(theta, phi, indexing='xy')
+ 
+ 
+     theta_phi = torch.cat([theta[..., None], phi[..., None]], dim=-1)
+     theta_phi = theta_phi.numpy()
+     return theta_phi
+ 
+ def get_normal_vector(incoming_vector: np.ndarray, reflect_vector: np.ndarray):
+     """
+     BLENDER CONVENTION
+     incoming_vector: the vector from the point to the camera
+     reflect_vector: the vector from the point to the light source
+     """
+     # the normal is the normalized half vector between the incoming and reflected directions
+     N = (incoming_vector + reflect_vector) / np.linalg.norm(incoming_vector + reflect_vector, axis=-1, keepdims=True)
+     return N
+ 
+ def get_cartesian_from_spherical(theta: np.array, phi: np.array, r = 1.0):
+     """
+     BLENDER CONVENTION
+     theta: vertical angle
+     phi: horizontal angle
+     r: radius
+     """
+     x = r * np.sin(theta) * np.cos(phi)
+     y = r * np.sin(theta) * np.sin(phi)
+     z = r * np.cos(theta)
+     return np.concatenate([x[..., None], y[..., None], z[..., None]], axis=-1)
+ 
+ 
+ def process_image(args: argparse.Namespace, file_name: str):
+     I = np.array([1, 0, 0])
+ 
+     # skip if the output already exists
+     envmap_output_path = os.path.join(args.envmap_dir, file_name)
+     if os.path.exists(envmap_output_path):
+         return None
+ 
+     # read the ball image
+     ball_path = os.path.join(args.ball_dir, file_name)
+     if file_name.endswith(".exr"):
+         ball_image = ezexr.imread(ball_path)
+     else:
+         try:
+             ball_image = skimage.io.imread(ball_path)
+             ball_image = skimage.img_as_float(ball_image)
+         except:
+             return None
+ 
+     # compute the normal map created from the reflect vector
+     env_grid = create_envmap_grid(args.envmap_height * args.scale)
+     reflect_vec = get_cartesian_from_spherical(env_grid[..., 1], env_grid[..., 0])
+     normal = get_normal_vector(I[None, None], reflect_vec)
+ 
+     # turn the normal map into a lookup position [range: 0, 1]
+     pos = (normal + 1.0) / 2
+     pos = 1.0 - pos
+     pos = pos[..., 1:]
+ 
+     env_map = None
+ 
+     # use pytorch for bilinear interpolation
+     with torch.no_grad():
+         # convert the position to a pytorch grid lookup
+         grid = torch.from_numpy(pos)[None].float()
+         grid = grid * 2 - 1  # convert to range [-1, 1]
+ 
+         # convert the ball image to a pytorch tensor
+         ball_image = torch.from_numpy(ball_image[None]).float()
+         ball_image = ball_image.permute(0, 3, 1, 2)  # [1,3,H,W]
+ 
+         env_map = torch.nn.functional.grid_sample(ball_image, grid, mode='bilinear', padding_mode='border', align_corners=True)
+         env_map = env_map[0].permute(1, 2, 0).numpy()
+ 
+     env_map_default = skimage.transform.resize(env_map, (args.envmap_height, args.envmap_height * 2), anti_aliasing=True)
+     if file_name.endswith(".exr"):
+         ezexr.imwrite(envmap_output_path, env_map_default.astype(np.float32))
+     else:
+         env_map_default = skimage.img_as_ubyte(env_map_default)
+         skimage.io.imsave(envmap_output_path, env_map_default)
+     return None
+ 
+ 
+ 
+ 
+ def main():
+ 
+     # measure running time
+     start_time = time.time()
+ 
+     # load arguments
+     args = create_argparser().parse_args()
+ 
+     # make the output directory if it does not exist
+     os.makedirs(args.envmap_dir, exist_ok=True)
+ 
+     # get all files in the directory
+     files = sorted(os.listdir(args.ball_dir))
+ 
+     # create a partial function for parallel processing
+     process_func = partial(process_image, args)
+ 
+     # parallel processing
+     with Pool(args.threads) as p:
+         list(tqdm(p.imap(process_func, files), total=len(files)))
+ 
+     # print the total time
+     print("TOTAL TIME: ", time.time() - start_time)
+ 
+ 
+ 
+ if __name__ == "__main__":
+     main()
+ 
environment.yml ADDED
@@ -0,0 +1,8 @@
+ name: diffusionlight
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - python=3.11.6
+   - cudatoolkit=11.8
+   - openexr==3.2.1
example/bed.png ADDED

Git LFS Details

  • SHA256: 5d832f4c8a4954d7d05611c3b5ed39f86517953dd4d9bfed1753ad2402bbb090
  • Pointer size: 132 Bytes
  • Size of remote file: 1.79 MB
exposure2hdr.py ADDED
@@ -0,0 +1,139 @@
+ # convert an exposure bracket to HDR output
+ import argparse
+ import os
+ from functools import partial
+ from multiprocessing import Pool
+ from tqdm import tqdm
+ import numpy as np
+ import skimage
+ import ezexr
+ from relighting.tonemapper import TonemapHDR
+ 
+ def create_argparser():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--input_dir", type=str, required=True, help='directory that contains the input LDR environment maps')
+     parser.add_argument("--output_dir", type=str, required=True, help='directory to output the HDR environment maps')
+     parser.add_argument("--endwith", type=str, default=".png", help='file ending used to filter out unwanted images')
+     parser.add_argument("--ev_string", type=str, default="_ev", help='string used to locate the EV value in the filename')
+     parser.add_argument("--EV", type=str, default="0, -2.5, -5", help='available EV values')
+     parser.add_argument("--gamma", default=2.4, help="Gamma value", type=float)
+     parser.add_argument('--preview_output', dest='preview_output', action='store_true')
+     parser.set_defaults(preview_output=False)
+     return parser
+ 
+ def parse_filename(ev_string, endwith, filename):
+     a = filename.split(ev_string)
+     name = ev_string.join(a[:-1])
+     ev = a[-1].replace(endwith, "")
+     ev = int(ev) / 10
+     return {
+         'name': name,
+         'ev': ev,
+         'filename': filename
+     }
+ 
+ def process_image(args, info):
+ 
+     # output directory
+     hdrdir = args.output_dir
+     os.makedirs(hdrdir, exist_ok=True)
+ 
+     scaler = np.array([0.212671, 0.715160, 0.072169])
+     name = info['name']
+     # EV values for the files, sorted from brightest to darkest
+     evs = [e for e in sorted(info['ev'], reverse = True)]
+ 
+     # filenames
+     files = [info['ev'][e] for e in evs]
+ 
+     # initialize from the first (EV 0) image
+     image0 = skimage.io.imread(os.path.join(args.input_dir, files[0]))[...,:3]
+     image0 = skimage.img_as_float(image0)
+     image0_linear = np.power(image0, args.gamma)
+ 
+     # read the luminance of every image
+     luminances = []
+     for i in range(len(evs)):
+         # load image
+         path = os.path.join(args.input_dir, files[i])
+         image = skimage.io.imread(path)[...,:3]
+         image = skimage.img_as_float(image)
+ 
+         # apply gamma correction
+         linear_img = np.power(image, args.gamma)
+ 
+         # scale the brightness by the exposure value
+         linear_img *= 1 / (2 ** evs[i])
+ 
+         # compute the luminance
+         lumi = linear_img @ scaler
+         luminances.append(lumi)
+ 
+     # start from the darkest image
+     out_luminace = luminances[len(evs) - 1]
+     for i in range(len(evs) - 1, 0, -1):
+         # compute mask
+         maxval = 1 / (2 ** evs[i-1])
+         p1 = np.clip((luminances[i-1] - 0.9 * maxval) / (0.1 * maxval), 0, 1)
+         p2 = out_luminace > luminances[i-1]
+         mask = (p1 * p2).astype(np.float32)
+         out_luminace = luminances[i-1] * (1-mask) + out_luminace * mask
+ 
+     hdr_rgb = image0_linear * (out_luminace / (luminances[0] + 1e-10))[:, :, np.newaxis]
+ 
+     # tone map for visualization
+     hdr2ldr = TonemapHDR(gamma=args.gamma, percentile=99, max_mapping=0.9)
+ 
+ 
+     ldr_rgb, _, _ = hdr2ldr(hdr_rgb)
+ 
+     ezexr.imwrite(os.path.join(hdrdir, name+".exr"), hdr_rgb.astype(np.float32))
+     if args.preview_output:
+         preview_dir = os.path.join(args.output_dir, "preview")
+         os.makedirs(preview_dir, exist_ok=True)
+         bracket = []
+         for s in 2 ** np.linspace(0, evs[-1], 10): # evs[-1] is -5
+             lumi = np.clip((s * hdr_rgb) ** (1/args.gamma), 0, 1)
+             bracket.append(lumi)
+         bracket = np.concatenate(bracket, axis=1)
+         skimage.io.imsave(os.path.join(preview_dir, name+".png"), skimage.img_as_ubyte(bracket))
+     return None
+ 
+ def main():
+     # load arguments
+     args = create_argparser().parse_args()
+ 
+     files = sorted(os.listdir(args.input_dir))
+ 
+     # filter out files by ending
+     files = [f for f in files if f.endswith(args.endwith)]
+     evs = [float(e.strip()) for e in args.EV.split(",")]
+ 
+     # parse into useful data
+     files = [parse_filename(args.ev_string, args.endwith, f) for f in files]
+ 
+     # filter out unused EVs
+     files = [f for f in files if f['ev'] in evs]
+ 
+     info = {}
+     for f in files:
+         if not f['name'] in info:
+             info[f['name']] = {}
+         info[f['name']][f['ev']] = f['filename']
+ 
+     infolist = []
+     for k in info:
+         if len(info[k]) != len(evs):
+             print("WARNING: missing ev in ", k)
+             continue
+         # convert to list data
+         infolist.append({'name': k, 'ev': info[k]})
+ 
+     fn = partial(process_image, args)
+     with Pool(8) as p:
+         r = list(tqdm(p.imap(fn, infolist), total=len(infolist)))
+ 
+ 
+ 
+ if __name__ == "__main__":
+     main()
inpaint.py ADDED
@@ -0,0 +1,363 @@
1
+ # inpaint the ball on an image
2
+ # this one is designed for general images that do not require a special location for the ball
3
+
4
+
5
+ import torch
6
+ import argparse
7
+ import numpy as np
8
+ import torch.distributed as dist
9
+ import os
10
+ from PIL import Image
11
+ from tqdm.auto import tqdm
12
+ import json
13
+
14
+
15
+ from relighting.inpainter import BallInpainter
16
+
17
+ from relighting.mask_utils import MaskGenerator
18
+ from relighting.ball_processor import (
19
+ get_ideal_normal_ball,
20
+ crop_ball
21
+ )
22
+ from relighting.dataset import GeneralLoader
23
+ from relighting.utils import name2hash
24
+ import relighting.dist_utils as dist_util
25
+ import time
26
+
27
+
28
+ # cross import from inpaint_multi-illum.py
29
+ from relighting.argument import (
30
+ SD_MODELS,
31
+ CONTROLNET_MODELS,
32
+ VAE_MODELS
33
+ )
34
+
35
+ def create_argparser():
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument("--dataset", type=str, required=True, help='directory that contains the input images') # dataset name or directory
38
+ parser.add_argument("--ball_size", type=int, default=256, help="size of the ball in pixel")
39
+ parser.add_argument("--ball_dilate", type=int, default=20, help="How much pixel to dilate the ball to make a sharper edge")
40
+ parser.add_argument("--prompt", type=str, default="a perfect mirrored reflective chrome ball sphere")
41
+ parser.add_argument("--prompt_dark", type=str, default="a perfect black dark mirrored reflective chrome ball sphere")
42
+ parser.add_argument("--negative_prompt", type=str, default="matte, diffuse, flat, dull")
43
+ parser.add_argument("--model_option", default="sdxl", help='selecting fancy model option (sd15_old, sd15_new, sd21, sdxl, sdxl_turbo)') # [sd15_old, sd15_new, or sd21]
44
+ parser.add_argument("--output_dir", required=True, type=str, help="output directory")
45
+ parser.add_argument("--img_height", type=int, default=1024, help="Dataset Image Height")
46
+ parser.add_argument("--img_width", type=int, default=1024, help="Dataset Image Width")
47
+ # some good seed 0, 37, 71, 125, 140, 196, 307, 434, 485, 575 | 9021, 9166, 9560, 9814, but default auto is for fairness
48
+ parser.add_argument("--seed", default="auto", type=str, help="Seed: right now we use single seed instead to reduce the time, (Auto will use hash file name to generate seed)")
49
+ parser.add_argument("--denoising_step", default=30, type=int, help="number of denoising step of diffusion model")
50
+ parser.add_argument("--control_scale", default=0.5, type=float, help="controlnet conditioning scale")
51
+ parser.add_argument("--guidance_scale", default=5.0, type=float, help="guidance scale (also known as CFG)")
52
+
53
+ parser.add_argument('--no_controlnet', dest='use_controlnet', action='store_false', help='ControlNet is used by default; pass this option to disable it and see the difference')
54
+ parser.set_defaults(use_controlnet=True)
55
+
56
+ parser.add_argument('--no_force_square', dest='force_square', action='store_false', help='SDXL is trained on square images, so we prefer square input; pass this option to disable the reshape')
57
+ parser.set_defaults(force_square=True)
58
+
59
+ parser.add_argument('--no_random_loader', dest='random_loader', action='store_false', help="the dataset order is shuffled by default so we can peek at the trend of the results without waiting for the entire dataset; pass this option to disable it")
60
+ parser.set_defaults(random_loader=True)
61
+
62
+ parser.add_argument('--cpu', dest='is_cpu', action='store_true', help="using CPU inference instead of GPU inference")
63
+ parser.set_defaults(is_cpu=False)
64
+
65
+ parser.add_argument('--offload', dest='offload', action='store_true', help="enable diffusers CPU offload")
66
+ parser.set_defaults(offload=False)
67
+
68
+ parser.add_argument("--limit_input", default=0, type=int, help="limit number of image to process to n image (0 = no limit), useful for run smallset")
69
+
70
+
71
+ # LoRA stuff
72
+ parser.add_argument('--no_lora', dest='use_lora', action='store_false', help='by default we using lora, we have option to disable to see the different')
73
+ parser.set_defaults(use_lora=True)
74
+
75
+ parser.add_argument("--lora_path", default="models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500", type=str, help="LoRA Checkpoint path")
76
+ parser.add_argument("--lora_scale", default=0.75, type=float, help="LoRA scale factor")
77
+
78
+ # speed optimization stuff
79
+ parser.add_argument('--no_torch_compile', dest='use_torch_compile', action='store_false', help='torch.compile is used by default for faster processing; disable it if your environment runs PyTorch older than 2.0')
80
+ parser.set_defaults(use_torch_compile=True)
81
+
82
+ # algorithm + iterative stuff
83
+ parser.add_argument("--algorithm", type=str, default="iterative", choices=["iterative", "normal"], help="Selecting between iterative or normal (single pass inpaint) algorithm")
84
+
85
+ parser.add_argument("--agg_mode", default="median", type=str)
86
+ parser.add_argument("--strength", default=0.8, type=float)
87
+ parser.add_argument("--num_iteration", default=2, type=int)
88
+ parser.add_argument("--ball_per_iteration", default=30, type=int)
89
+ parser.add_argument('--no_save_intermediate', dest='save_intermediate', action='store_false')
90
+ parser.set_defaults(save_intermediate=True)
91
+ parser.add_argument("--cache_dir", default="./temp_inpaint_iterative", type=str, help="cache directory for iterative inpaint")
92
+
93
+ # pararelle processing
94
+ parser.add_argument("--idx", default=0, type=int, help="index of the current process, useful for running on multiple node")
95
+ parser.add_argument("--total", default=1, type=int, help="total number of process")
96
+
97
+ # for HDR stuff
98
+ parser.add_argument("--max_negative_ev", default=-5, type=int, help="maximum negative EV for lora")
99
+ parser.add_argument("--ev", default="0,-2.5,-5", type=str, help="EV: list of EV to generate")
100
+
101
+ return parser
102
+
103
+ def get_ball_location(image_data, args):
104
+ if 'boundary' in image_data:
105
+ # support predefined boundary if need
106
+ x = image_data["boundary"]["x"]
107
+ y = image_data["boundary"]["y"]
108
+ r = image_data["boundary"]["size"]
109
+
110
+ # support ball dilation
111
+ half_dilate = args.ball_dilate // 2
112
+
113
+ # check if not left out-of-bound
114
+ if x - half_dilate < 0: x += half_dilate
115
+ if y - half_dilate < 0: y += half_dilate
116
+
117
+ # check if not right out-of-bound
118
+ if x + r + half_dilate > args.img_width: x -= half_dilate
119
+ if y + r + half_dilate > args.img_height: y -= half_dilate
120
+
121
+ else:
122
+ # we use top-left corner notation
123
+ x, y, r = ((args.img_width // 2) - (args.ball_size // 2), (args.img_height // 2) - (args.ball_size // 2), args.ball_size)
124
+ return x, y, r
125
+
126
+ def interpolate_embedding(pipe, args):
127
+ print("interpolate embedding...")
128
+
129
+ # get list of all EVs
130
+ ev_list = [float(x) for x in args.ev.split(",")]
131
+ interpolants = [ev / args.max_negative_ev for ev in ev_list]
132
+
133
+ print("EV : ", ev_list)
134
+ print("EV : ", interpolants)
135
+
136
+ # calculate prompt embeddings
137
+ prompt_normal = args.prompt
138
+ prompt_dark = args.prompt_dark
139
+ prompt_embeds_normal, _, pooled_prompt_embeds_normal, _ = pipe.pipeline.encode_prompt(prompt_normal)
140
+ prompt_embeds_dark, _, pooled_prompt_embeds_dark, _ = pipe.pipeline.encode_prompt(prompt_dark)
141
+
142
+ # interpolate embeddings
143
+ interpolate_embeds = []
144
+ for t in interpolants:
145
+ int_prompt_embeds = prompt_embeds_normal + t * (prompt_embeds_dark - prompt_embeds_normal)
146
+ int_pooled_prompt_embeds = pooled_prompt_embeds_normal + t * (pooled_prompt_embeds_dark - pooled_prompt_embeds_normal)
147
+
148
+ interpolate_embeds.append((int_prompt_embeds, int_pooled_prompt_embeds))
149
+
150
+ return dict(zip(ev_list, interpolate_embeds))
151
+
152
+ def main():
153
+ # load arguments
154
+ args = create_argparser().parse_args()
155
+
156
+ # get local rank
157
+ if args.is_cpu:
158
+ device = torch.device("cpu")
159
+ torch_dtype = torch.float32
160
+ else:
161
+ device = dist_util.dev()
162
+ torch_dtype = torch.float16
163
+
164
+ # so, we need ball_dilate >= 16 (2*vae_scale_factor) to make our mask shape = (272, 272)
165
+ assert args.ball_dilate % 2 == 0 # ball dilation should be symmetric
166
+
167
+ # create controlnet pipeline
168
+ if args.model_option in ["sdxl", "sdxl_fast", "sdxl_turbo"] and args.use_controlnet:
169
+ model, controlnet = SD_MODELS[args.model_option], CONTROLNET_MODELS[args.model_option]
170
+ pipe = BallInpainter.from_sdxl(
171
+ model=model,
172
+ controlnet=controlnet,
173
+ device=device,
174
+ torch_dtype = torch_dtype,
175
+ offload = args.offload
176
+ )
177
+ elif args.model_option in ["sdxl", "sdxl_fast", "sdxl_turbo"] and not args.use_controlnet:
178
+ model = SD_MODELS[args.model_option]
179
+ pipe = BallInpainter.from_sdxl(
180
+ model=model,
181
+ controlnet=None,
182
+ device=device,
183
+ torch_dtype = torch_dtype,
184
+ offload = args.offload
185
+ )
186
+ elif args.use_controlnet:
187
+ model, controlnet = SD_MODELS[args.model_option], CONTROLNET_MODELS[args.model_option]
188
+ pipe = BallInpainter.from_sd(
189
+ model=model,
190
+ controlnet=controlnet,
191
+ device=device,
192
+ torch_dtype = torch_dtype,
193
+ offload = args.offload
194
+ )
195
+ else:
196
+ model = SD_MODELS[args.model_option]
197
+ pipe = BallInpainter.from_sd(
198
+ model=model,
199
+ controlnet=None,
200
+ device=device,
201
+ torch_dtype = torch_dtype,
202
+ offload = args.offload
203
+ )
204
+
205
+ if args.model_option in ["sdxl_turbo"]:
206
+ # Guidance scale is not supported in sdxl_turbo
207
+ args.guidance_scale = 0.0
208
+
209
+ if args.lora_scale > 0 and args.lora_path is None:
210
+ raise ValueError("lora scale is not 0 but lora path is not set")
211
+
212
+ if (args.lora_path is not None) and (args.use_lora):
213
+ print(f"using lora path {args.lora_path}")
214
+ print(f"using lora scale {args.lora_scale}")
215
+ pipe.pipeline.load_lora_weights(args.lora_path)
216
+ pipe.pipeline.fuse_lora(lora_scale=args.lora_scale) # fuse lora weight w' = w + \alpha \Delta w
217
+ enabled_lora = True
218
+ else:
219
+ enabled_lora = False
220
+
221
+ if args.use_torch_compile:
222
+ try:
223
+ print("compiling unet model")
224
+ start_time = time.time()
225
+ pipe.pipeline.unet = torch.compile(pipe.pipeline.unet, mode="reduce-overhead", fullgraph=True)
226
+ print("Model compilation time: ", time.time() - start_time)
227
+ except:
228
+ pass
229
+
230
+ # default height for sdxl is 1024, if not set, we set default height.
231
+ if args.model_option == "sdxl" and args.img_height == 0 and args.img_width == 0:
232
+ args.img_height = 1024
233
+ args.img_width = 1024
234
+
235
+ # load dataset
236
+ dataset = GeneralLoader(
237
+ root=args.dataset,
238
+ resolution=(args.img_width, args.img_height),
239
+ force_square=args.force_square,
240
+ return_dict=True,
241
+ random_shuffle=args.random_loader,
242
+ process_id=args.idx,
243
+ process_total=args.total,
244
+ limit_input=args.limit_input,
245
+ )
246
+
247
+ # interpolate embedding
248
+ embedding_dict = interpolate_embedding(pipe, args)
249
+
250
+ # prepare mask and normal ball
251
+ mask_generator = MaskGenerator()
252
+ normal_ball, mask_ball = get_ideal_normal_ball(size=args.ball_size+args.ball_dilate)
253
+ _, mask_ball_for_crop = get_ideal_normal_ball(size=args.ball_size)
254
+
255
+
256
+ # make output directory if not exist
257
+ raw_output_dir = os.path.join(args.output_dir, "raw")
258
+ control_output_dir = os.path.join(args.output_dir, "control")
259
+ square_output_dir = os.path.join(args.output_dir, "square")
260
+ os.makedirs(args.output_dir, exist_ok=True)
261
+ os.makedirs(raw_output_dir, exist_ok=True)
262
+ os.makedirs(control_output_dir, exist_ok=True)
263
+ os.makedirs(square_output_dir, exist_ok=True)
264
+
265
+ # create split seed
266
+ # please DO NOT manually replace this line; use the --seed option instead
267
+ seeds = args.seed.split(",")
268
+
269
+ for image_data in tqdm(dataset):
270
+ input_image = image_data["image"]
271
+ image_path = image_data["path"]
272
+
273
+ for ev, (prompt_embeds, pooled_prompt_embeds) in embedding_dict.items():
274
+ # create output file name (we always use png to prevent quality loss)
275
+ ev_str = str(ev).replace(".", "") if ev != 0 else "-00"
276
+ outname = os.path.basename(image_path).split(".")[0] + f"_ev{ev_str}"
277
+
278
+ # we use top-left corner notation (which is different from aj.aek's center point notation)
279
+ x, y, r = get_ball_location(image_data, args)
280
+
281
+ # create inpaint mask
282
+ mask = mask_generator.generate_single(
283
+ input_image, mask_ball,
284
+ x - (args.ball_dilate // 2),
285
+ y - (args.ball_dilate // 2),
286
+ r + args.ball_dilate
287
+ )
288
+
289
+ seeds = tqdm(seeds, desc="seeds") if len(seeds) > 10 else seeds
290
+
291
+ # repeatedly create the image with different seeds
292
+ for seed in seeds:
293
+ start_time = time.time()
294
+ # set seed, if seed auto we use file name as seed
295
+ if seed == "auto":
296
+ filename = os.path.basename(image_path).split(".")[0]
297
+ seed = name2hash(filename)
298
+ outpng = f"{outname}.png"
299
+ cache_name = f"{outname}"
300
+ else:
301
+ seed = int(seed)
302
+ outpng = f"{outname}_seed{seed}.png"
303
+ cache_name = f"{outname}_seed{seed}"
304
+ # skip if file exist, useful for resuming
305
+ if os.path.exists(os.path.join(square_output_dir, outpng)):
306
+ continue
307
+ generator = torch.Generator().manual_seed(seed)
308
+ kwargs = {
309
+ "prompt_embeds": prompt_embeds,
310
+ "pooled_prompt_embeds": pooled_prompt_embeds,
311
+ 'negative_prompt': args.negative_prompt,
312
+ 'num_inference_steps': args.denoising_step,
313
+ 'generator': generator,
314
+ 'image': input_image,
315
+ 'mask_image': mask,
316
+ 'strength': 1.0,
317
+ 'current_seed': seed, # we still need seed in the pipeline!
318
+ 'controlnet_conditioning_scale': args.control_scale,
319
+ 'height': args.img_height,
320
+ 'width': args.img_width,
321
+ 'normal_ball': normal_ball,
322
+ 'mask_ball': mask_ball,
323
+ 'x': x,
324
+ 'y': y,
325
+ 'r': r,
326
+ 'guidance_scale': args.guidance_scale,
327
+ }
328
+
329
+ if enabled_lora:
330
+ kwargs["cross_attention_kwargs"] = {"scale": args.lora_scale}
331
+
332
+ if args.algorithm == "normal":
333
+ output_image = pipe.inpaint(**kwargs).images[0]
334
+ elif args.algorithm == "iterative":
335
+ # This is still buggy
336
+ print("using inpainting iterative, this is going to take a while...")
337
+ kwargs.update({
338
+ "strength": args.strength,
339
+ "num_iteration": args.num_iteration,
340
+ "ball_per_iteration": args.ball_per_iteration,
341
+ "agg_mode": args.agg_mode,
342
+ "save_intermediate": args.save_intermediate,
343
+ "cache_dir": os.path.join(args.cache_dir, cache_name),
344
+ })
345
+ output_image = pipe.inpaint_iterative(**kwargs)
346
+ else:
347
+ raise NotImplementedError(f"Unknown algorithm {args.algorithm}")
348
+
349
+
350
+ square_image = output_image.crop((x, y, x+r, y+r))
351
+
352
+ # return the most recent control_image for sanity check
353
+ control_image = pipe.get_cache_control_image()
354
+ if control_image is not None:
355
+ control_image.save(os.path.join(control_output_dir, outpng))
356
+
357
+ # save image
358
+ output_image.save(os.path.join(raw_output_dir, outpng))
359
+ square_image.save(os.path.join(square_output_dir, outpng))
360
+
361
+
362
+ if __name__ == "__main__":
363
+ main()
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b96eb34accf4ce5f33dff729e12669c46e3132973ee2cf0ac2d4f2c993a2af4d
+ size 47392882
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6bf601e30c08ce4f1eea6e09e1780bd8ba2588986eaee8379672350707dcddaa
+ size 23396024
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec103b2e80b357d29e2ec8355d49f0c289331f0ece810d5bafd464d33e5f4c76
+ size 14280
models/ThisIsTheFinal-lora-hdr-continuous-largeT@900/0_-5/checkpoint-2500/scheduler.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e9244a743d4761f975ab2a14d0b7a509a85554dd77bef6490330160a2a639fae
+ size 1000
relighting/argument.py ADDED
@@ -0,0 +1,43 @@
+ import argparse
+ from diffusers import DDIMScheduler, DDPMScheduler, UniPCMultistepScheduler
+ 
+ def get_control_signal_type(controlnet):
+     if "normal" in controlnet:
+         return "normal"
+     elif "depth" in controlnet:
+         return "depth"
+     else:
+         raise NotImplementedError
+ 
+ SD_MODELS = {
+     "sd15_old": "runwayml/stable-diffusion-inpainting",
+     "sd15_new": "runwayml/stable-diffusion-inpainting",
+     "sd21": "stabilityai/stable-diffusion-2-inpainting",
+     "sdxl": "stabilityai/stable-diffusion-xl-base-1.0",
+     "sdxl_fast": "stabilityai/stable-diffusion-xl-base-1.0",
+     "sdxl_turbo": "stabilityai/sdxl-turbo",
+     "sd15_depth": "runwayml/stable-diffusion-inpainting",
+ }
+ 
+ VAE_MODELS = {
+     "sdxl": "madebyollin/sdxl-vae-fp16-fix",
+     "sdxl_fast": "madebyollin/sdxl-vae-fp16-fix",
+ }
+ 
+ CONTROLNET_MODELS = {
+     "sd15_old": "fusing/stable-diffusion-v1-5-controlnet-normal",
+     "sd15_new": "lllyasviel/control_v11p_sd15_normalbae",
+     "sd21": "thibaud/controlnet-sd21-normalbae-diffusers",
+     "sdxl": "diffusers/controlnet-depth-sdxl-1.0",
+     "sdxl_fast": "diffusers/controlnet-depth-sdxl-1.0-small",
+     "sdxl_turbo": "diffusers/controlnet-depth-sdxl-1.0-small",
+     "sd15_depth": "lllyasviel/control_v11f1p_sd15_depth",
+ }
+ 
+ SAMPLERS = {
+     "ddim": DDIMScheduler,
+     "ddpm": DDPMScheduler,
+     "unipc": UniPCMultistepScheduler,
+ }
+ 
+ DEPTH_ESTIMATOR = "Intel/dpt-hybrid-midas"
relighting/ball_processor.py ADDED
@@ -0,0 +1,60 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from scipy.special import sph_harm
+ 
+ def crop_ball(image, mask_ball, x, y, size, apply_mask=True, bg_color = (0, 0, 0)):
+     if isinstance(image, Image.Image):
+         result = np.array(image)
+     else:
+         result = image.copy()
+ 
+     result = result[y:y+size, x:x+size]
+     if apply_mask:
+         result[~mask_ball] = bg_color
+     return result
+ 
+ def get_ideal_normal_ball(size, flip_x=True):
+     """
+     Generate a normal ball for a specific size
+     Normal map is x "left", y up, z into the screen
+     (we flip x to match the Sobel operator)
+     @params
+     - size (int) - single value of height and width
+     @return:
+     - normal_map (np.array) - normal map [size, size, 3]
+     - mask (np.array) - mask that marks the valid normal map region [size, size]
+     """
+     # we flip x to match the Sobel operator
+     x = torch.linspace(1, -1, size)
+     y = torch.linspace(1, -1, size)
+     x = x.flip(dims=(-1,)) if not flip_x else x
+ 
+     y, x = torch.meshgrid(y, x)
+     z = (1 - x**2 - y**2)
+     mask = z >= 0
+ 
+     # clean up invalid values outside the mask
+     x = x * mask
+     y = y * mask
+     z = z * mask
+ 
+     # get the real z value
+     z = torch.sqrt(z)
+ 
+     # clean up normal map values outside the mask
+     normal_map = torch.cat([x[..., None], y[..., None], z[..., None]], dim=-1)
+     normal_map = normal_map.numpy()
+     mask = mask.numpy()
+     return normal_map, mask
+ 
+ def get_predicted_normal_ball(size, precomputed_path=None):
+     if precomputed_path is not None:
+         normal_map = Image.open(precomputed_path).resize((size, size))
+         normal_map = np.array(normal_map).astype(np.uint8)
+         _, mask = get_ideal_normal_ball(size)
+     else:
+         raise NotImplementedError
+ 
+     normal_map = (normal_map - 127.5) / 127.5 # normalize for compatibility with the inpainting pipeline
+     return normal_map, mask
relighting/dataset.py ADDED
@@ -0,0 +1,412 @@
1
+ import glob
2
+ import json
3
+ import os
4
+ import skimage
5
+ import numpy as np
6
+ from pathlib import Path
7
+ from natsort import natsorted
8
+ from PIL import Image
9
+ from relighting.image_processor import pil_square_image
10
+ from tqdm.auto import tqdm
11
+ import random
12
+ import itertools
13
+ from abc import ABC, abstractmethod
14
+
15
+ class Dataset(ABC):
16
+ def __init__(self,
17
+ resolution=(1024, 1024),
18
+ force_square=True,
19
+ return_image_path=False,
20
+ return_dict=False,
21
+ ):
22
+ """
23
+ Resolution is (WIDTH, HEIGHT)
24
+ """
25
+ self.resolution = resolution
26
+ self.force_square = force_square
27
+ self.return_image_path = return_image_path
28
+ self.return_dict = return_dict
29
+ self.scene_data = []
30
+ self.meta_data = []
31
+ self.boundary_info = []
32
+
33
+ @abstractmethod
34
+ def _load_data_path(self):
35
+ pass
36
+
37
+ def __len__(self):
38
+ return len(self.scene_data)
39
+
40
+ def __getitem__(self, idx):
41
+ image = Image.open(self.scene_data[idx])
42
+ if self.force_square:
43
+ image = pil_square_image(image, self.resolution)
44
+ else:
45
+ image = image.resize(self.resolution)
46
+
47
+ if self.return_dict:
48
+ d = {
49
+ "image": image,
50
+ "path": self.scene_data[idx]
51
+ }
52
+ if len(self.boundary_info) > 0:
53
+ d["boundary"] = self.boundary_info[idx]
54
+
55
+ return d
56
+ elif self.return_image_path:
57
+ return image, self.scene_data[idx]
58
+ else:
59
+ return image
60
+
61
+ class GeneralLoader(Dataset):
62
+ def __init__(self,
63
+ root=None,
64
+ num_samples=None,
65
+ res_threshold=((1024, 1024)),
66
+ apply_threshold=False,
67
+ random_shuffle=False,
68
+ process_id = 0,
69
+ process_total = 1,
70
+ limit_input = 0,
71
+ **kwargs,
72
+ ):
73
+ super().__init__(**kwargs)
74
+ self.root = root
75
+ self.res_threshold = res_threshold
76
+ self.apply_threshold = apply_threshold
77
+ self.has_meta = False
78
+
79
+ if self.root is not None:
80
+ if not os.path.exists(self.root):
81
+ raise Exception(f"Dataset {self.root} does not exist.")
82
+
83
+ paths = natsorted(
84
+ list(glob.glob(os.path.join(self.root, "*.png"))) + \
85
+ list(glob.glob(os.path.join(self.root, "*.jpg")))
86
+ )
87
+ self.scene_data = self._load_data_path(paths, num_samples=num_samples)
88
+
89
+ if random_shuffle:
90
+ SEED = 0
91
+ random.Random(SEED).shuffle(self.scene_data)
92
+ random.Random(SEED).shuffle(self.boundary_info)
93
+
94
+ if limit_input > 0:
95
+ self.scene_data = self.scene_data[:limit_input]
96
+ self.boundary_info = self.boundary_info[:limit_input]
97
+
98
+ # please keep this step last, so scene_data and boundary_info get filtered consistently
99
+ if process_total > 1:
100
+ self.scene_data = self.scene_data[process_id::process_total]
101
+ self.boundary_info = self.boundary_info[process_id::process_total]
102
+ print(f"Process {process_id} has {len(self.scene_data)} samples")
103
+
104
+ def _load_data_path(self, paths, num_samples=None):
105
+ if os.path.exists(os.path.splitext(paths[0])[0] + ".json") or os.path.exists(os.path.splitext(paths[-1])[0] + ".json"):
106
+ self.has_meta = True
107
+
108
+ if self.has_meta:
109
+ # read metadata
110
+ TARGET_KEY = "chrome_mask256"
111
+ for path in paths:
112
+ with open(os.path.splitext(path)[0] + ".json") as f:
113
+ meta = json.load(f)
114
+ self.meta_data.append(meta)
115
+ boundary = {
116
+ "x": meta[TARGET_KEY]["x"],
117
+ "y": meta[TARGET_KEY]["y"],
118
+ "size": meta[TARGET_KEY]["w"],
119
+ }
120
+ self.boundary_info.append(boundary)
121
+
122
+
123
+ scene_data = paths
124
+ if self.apply_threshold:
125
+ scene_data = []
126
+ for path in tqdm(paths):
127
+ img = Image.open(path)
128
+ if (img.size[0] >= self.res_threshold[0]) and (img.size[1] >= self.res_threshold[1]):
129
+ scene_data.append(path)
130
+
131
+ if num_samples is not None:
132
+ max_idx = min(num_samples, len(scene_data))
133
+ scene_data = scene_data[:max_idx]
134
+
135
+ return scene_data
136
+
137
+ @classmethod
138
+ def from_image_paths(cls, paths, *args, **kwargs):
139
+ dataset = cls(*args, **kwargs)
140
+ dataset.scene_data = dataset._load_data_path(paths)
141
+ return dataset
142
+
143
+ class ALPLoader(Dataset):
144
+ def __init__(self,
145
+ root=None,
146
+ num_samples=None,
147
+ res_threshold=((1024, 1024)),
148
+ apply_threshold=False,
149
+ **kwargs,
150
+ ):
151
+ super().__init__(**kwargs)
152
+ self.root = root
153
+ self.res_threshold = res_threshold
154
+ self.apply_threshold = apply_threshold
155
+ self.has_meta = False
156
+
157
+ if self.root is not None:
158
+ if not os.path.exists(self.root):
159
+ raise Exception(f"Dataset {self.root} does not exist.")
160
+
161
+ dirs = natsorted(list(glob.glob(os.path.join(self.root, "*"))))
162
+ self.scene_data = self._load_data_path(dirs)
163
+
164
+ def _load_data_path(self, dirs):
165
+ self.scene_names = [Path(dir).name for dir in dirs]
166
+
167
+ scene_data = []
168
+ for dir in dirs:
169
+ pseudo_probe_dirs = natsorted(list(glob.glob(os.path.join(dir, "*"))))
170
+ pseudo_probe_dirs = [dir for dir in pseudo_probe_dirs if "gt" not in dir]
171
+ data = [os.path.join(dir, "images", "0.png") for dir in pseudo_probe_dirs]
172
+ scene_data.append(data)
173
+
174
+ scene_data = list(itertools.chain(*scene_data))
175
+ return scene_data
176
+
177
+ class MultiIlluminationLoader(Dataset):
178
+ def __init__(self,
179
+ root,
180
+ mask_probe=True,
181
+ mask_boundingbox=False,
182
+ **kwargs,
183
+ ):
184
+ """
185
+ @params resolution (tuple): (width, height) - resolution of the image
186
+ @params force_square: will add black border to make the image square while keeping the aspect ratio
187
+ @params mask_probe: mask the probe with the mask in the dataset
188
+
189
+ """
190
+ super().__init__(**kwargs)
191
+ self.root = root
192
+ self.mask_probe = mask_probe
193
+ self.mask_boundingbox = mask_boundingbox
194
+
195
+ if self.root is not None:
196
+ dirs = natsorted(list(glob.glob(os.path.join(self.root, "*"))))
197
+ self.scene_data = self._load_data_path(dirs)
198
+
199
+ def _load_data_path(self, dirs):
200
+ self.scene_names = [Path(dir).name for dir in dirs]
201
+
202
+ data = {}
203
+ for dir in dirs:
204
+ chrome_probes = natsorted(list(glob.glob(os.path.join(dir, "probes", "*chrome*.jpg"))))
205
+ gray_probes = natsorted(list(glob.glob(os.path.join(dir, "probes", "*gray*.jpg"))))
206
+ scenes = natsorted(list(glob.glob(os.path.join(dir, "dir_*.jpg"))))
207
+
208
+ with open(os.path.join(dir, "meta.json")) as f:
209
+ meta_data = json.load(f)
210
+
211
+ bbox_chrome = meta_data["chrome"]["bounding_box"]
212
+ bbox_gray = meta_data["gray"]["bounding_box"]
213
+
214
+ mask_chrome = os.path.join(dir, "mask_chrome.png")
215
+ mask_gray = os.path.join(dir, "mask_gray.png")
216
+
217
+ scene_name = Path(dir).name
218
+ data[scene_name] = {
219
+ "scenes": scenes,
220
+ "chrome_probes": chrome_probes,
221
+ "gray_probes": gray_probes,
222
+ "bbox_chrome": bbox_chrome,
223
+ "bbox_gray": bbox_gray,
224
+ "mask_chrome": mask_chrome,
225
+ "mask_gray": mask_gray,
226
+ }
227
+ return data
228
+
229
+ def _mask_probe(self, image, mask):
230
+ """
231
+ mask probe with a png file in dataset
232
+ """
233
+ image_anticheat = skimage.img_as_float(np.array(image))
234
+ mask_np = skimage.img_as_float(np.array(mask))[..., None]
235
+ image_anticheat = ((1.0 - mask_np) * image_anticheat) + (0.5 * mask_np)
236
+ image_anticheat = Image.fromarray(skimage.img_as_ubyte(image_anticheat))
237
+ return image_anticheat
238
+
239
+ def _mask_boundingbox(self, image, bbox):
240
+ """
241
+ mask image with the bounding box for anti-cheat
242
+ """
243
+ bbox = {k:int(np.round(v/4.0)) for k,v in bbox.items()}
244
+ x, y, w, h = bbox["x"], bbox["y"], bbox["w"], bbox["h"]
245
+ image_anticheat = skimage.img_as_float(np.array(image))
246
+ image_anticheat[y:y+h, x:x+w] = 0.5
247
+ image_anticheat = Image.fromarray(skimage.img_as_ubyte(image_anticheat))
248
+ return image_anticheat
249
+
250
+ def __getitem__(self, scene_name):
251
+ data = self.scene_data[scene_name]
252
+
253
+ mask_chrome = Image.open(data["mask_chrome"])
254
+ mask_gray = Image.open(data["mask_gray"])
255
+ images = []
256
+ for path in data["scenes"]:
257
+ image = Image.open(path)
258
+ if self.mask_probe:
259
+ image = self._mask_probe(image, mask_chrome)
260
+ image = self._mask_probe(image, mask_gray)
261
+ if self.mask_boundingbox:
262
+ image = self._mask_boundingbox(image, data["bbox_chrome"])
263
+ image = self._mask_boundingbox(image, data["bbox_gray"])
264
+
265
+ if self.force_square:
266
+ image = pil_square_image(image, self.resolution)
267
+ else:
268
+ image = image.resize(self.resolution)
269
+ images.append(image)
270
+
271
+ chrome_probes = [Image.open(path) for path in data["chrome_probes"]]
272
+ gray_probes = [Image.open(path) for path in data["gray_probes"]]
273
+ bbox_chrome = data["bbox_chrome"]
274
+ bbox_gray = data["bbox_gray"]
275
+
276
+ return images, chrome_probes, gray_probes, bbox_chrome, bbox_gray
277
+
278
+
279
+ def calculate_ball_info(self, scene_name):
280
+ # TODO: remove hard-coded parameters
281
+ ball_data = []
282
+ for mtype in ['bbox_chrome', 'bbox_gray']:
283
+ info = self.scene_data[scene_name][mtype]
284
+
285
+ # x-y is top-left corner of the bounding box
286
+ # meta file is for 4000x6000 image but dataset is 1000x1500
287
+ x = info['x'] / 4
288
+ y = info['y'] / 4
289
+ w = info['w'] / 4
290
+ h = info['h'] / 4
291
+
292
+
293
+ # we scale data to 512x512 image
294
+ if self.force_square:
295
+ h_ratio = (512.0 * 2.0 / 3.0) / 1000.0 #384 because we have black border on the top
296
+ w_ratio = 512.0 / 1500.0
297
+ else:
298
+ h_ratio = self.resolution[0] / 1000.0
299
+ w_ratio = self.resolution[1] / 1500.0
300
+
301
+ x = x * w_ratio
302
+ y = y * h_ratio
303
+ w = w * w_ratio
304
+ h = h * h_ratio
305
+
306
+ if self.force_square:
307
+ # y need to shift due to top black border
308
+ top_border_height = 512.0 * (1/6)
309
+ y = y + top_border_height
310
+
311
+
312
+ # Sphere is not circle due to the camera perspective, Need future fix for this
313
+ # For now, we use the minimum of width and height
314
+ w = int(np.round(w))
315
+ h = int(np.round(h))
316
+ if w > h:
317
+ r = h
318
+ x = x + (w - h) / 2.0
319
+ else:
320
+ r = w
321
+ y = y + (h - w) / 2.0
322
+
323
+ x = int(np.round(x))
324
+ y = int(np.round(y))
325
+
326
+ ball_data.append((x, y, r))
327
+
328
+ return ball_data
329
+
330
+ def calculate_bbox_info(self, scene_name):
331
+ # TODO: remove hard-coded parameters
332
+ bbox_data = []
333
+ for mtype in ['bbox_chrome', 'bbox_gray']:
334
+ info = self.scene_data[scene_name][mtype]
335
+
336
+ # x-y is top-left corner of the bounding box
337
+ # meta file is for 4000x6000 image but dataset is 1000x1500
338
+ x = info['x'] / 4
339
+ y = info['y'] / 4
340
+ w = info['w'] / 4
341
+ h = info['h'] / 4
342
+
343
+
344
+ # we scale data to 512x512 image
345
+ if self.force_square:
346
+ h_ratio = (512.0 * 2.0 / 3.0) / 1000.0 #384 because we have black border on the top
347
+ w_ratio = 512.0 / 1500.0
348
+ else:
349
+ h_ratio = self.resolution[0] / 1000.0
350
+ w_ratio = self.resolution[1] / 1500.0
351
+
352
+ x = x * w_ratio
353
+ y = y * h_ratio
354
+ w = w * w_ratio
355
+ h = h * h_ratio
356
+
357
+ if self.force_square:
358
+ # y need to shift due to top black border
359
+ top_border_height = 512.0 * (1/6)
360
+ y = y + top_border_height
361
+
362
+
363
+ w = int(np.round(w))
364
+ h = int(np.round(h))
365
+ x = int(np.round(x))
366
+ y = int(np.round(y))
367
+
368
+ bbox_data.append((x, y, w, h))
369
+
370
+ return bbox_data
371
+
372
+ """
373
+ DO NOT remove this!
374
+ This is for evaluating results from Multi-Illumination generated from the old version
375
+ """
376
+ def calculate_ball_info_legacy(self, scene_name):
377
+ # TODO: remove hard-coded parameters
378
+ ball_data = []
379
+ for mtype in ['bbox_chrome', 'bbox_gray']:
380
+ info = self.scene_data[scene_name][mtype]
381
+
382
+ # x-y is top-left corner of the bounding box
383
+ # meta file is for 4000x6000 image but dataset is 1000x1500
384
+ x = info['x'] / 4
385
+ y = info['y'] / 4
386
+ w = info['w'] / 4
387
+ h = info['h'] / 4
388
+
389
+ # we scale data to 512x512 image
390
+ h_ratio = 384.0 / 1000.0 #384 because we have black border on the top
391
+ w_ratio = 512.0 / 1500.0
392
+ x = x * w_ratio
393
+ y = y * h_ratio
394
+ w = w * w_ratio
395
+ h = h * h_ratio
396
+
397
+ # y need to shift due to top black border
398
+ top_border_height = 512.0 * (1/8)
399
+
400
+ y = y + top_border_height
401
+
402
+ # Sphere is not circle due to the camera perspective, Need future fix for this
403
+ # For now, we use the minimum of width and height
404
+ r = np.max(np.array([w, h]))
405
+
406
+ x = int(np.round(x))
407
+ y = int(np.round(y))
408
+ r = int(np.round(r))
409
+
410
+ ball_data.append((y, x, r))
411
+
412
+ return ball_data
relighting/dist_utils.py ADDED
@@ -0,0 +1,154 @@
1
+ """
2
+ Helpers for distributed training.
3
+ """
4
+
5
+ import io
6
+ import os
7
+ import socket
8
+
9
+ try:
10
+ import blobfile as bf
11
+ except:
12
+ pass
13
+
14
+ try:
15
+ from mpi4py import MPI
16
+ except:
17
+ pass
18
+
19
+ import torch as th
20
+ import torch.distributed as dist
21
+ import builtins
22
+ import datetime
23
+
24
+ # Change this to reflect your cluster layout.
25
+ # The GPU for a given rank is (rank % GPUS_PER_NODE).
26
+ GPUS_PER_NODE = 8
27
+
28
+ SETUP_RETRY_COUNT = 3
29
+ def synchronize():
30
+ if not dist.is_available():
31
+ return
32
+
33
+ if not dist.is_initialized():
34
+ return
35
+
36
+ world_size = dist.get_world_size()
37
+
38
+ if world_size == 1:
39
+ return
40
+
41
+ dist.barrier()
42
+
43
+ def is_dist_avail_and_initialized():
44
+ if not dist.is_available():
45
+ return False
46
+ if not dist.is_initialized():
47
+ return False
48
+ return True
49
+ def get_world_size():
50
+ if not is_dist_avail_and_initialized():
51
+ return 1
52
+ return dist.get_world_size()
53
+
54
+ def setup_for_distributed(is_master):
55
+ """
56
+ This function disables printing when not in master process
57
+ """
58
+ builtin_print = builtins.print
59
+
60
+ def print(*args, **kwargs):
61
+ force = kwargs.pop('force', False)
62
+ force = force or (get_world_size() > 8)
63
+ if is_master or force:
64
+ now = datetime.datetime.now().time()
65
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
66
+ builtin_print(*args, **kwargs)
67
+
68
+ builtins.print = print
69
+
70
+ def setup_dist_multinode(args):
71
+ """
72
+ Setup a distributed process group.
73
+ """
74
+ if not dist.is_available() or not dist.is_initialized():
75
+ th.distributed.init_process_group(backend="nccl", init_method='env://')
76
+ world_size = dist.get_world_size()
77
+ local_rank = int(os.getenv('LOCAL_RANK'))
78
+ print("rank",local_rank)
79
+ device = local_rank
80
+ th.cuda.set_device(device)
81
+ setup_for_distributed(device == 0)
82
+
83
+ synchronize()
84
+ else:
85
+ print("ddp failed!")
86
+ exit()
87
+
88
+ def setup_dist(global_seed):
89
+ """
90
+ Setup a distributed process group.
91
+ """
92
+ if dist.is_initialized():
93
+ return
94
+ th.cuda.set_device(int(os.environ["LOCAL_RANK"]))
95
+ th.distributed.init_process_group(backend="nccl", init_method="env://", timeout=datetime.timedelta(seconds=5400))
96
+
97
+ # fix seed
98
+ rank = dist.get_rank()
99
+ device = rank % th.cuda.device_count()
100
+ seed = global_seed * dist.get_world_size() + rank
101
+ th.manual_seed(seed)
102
+ th.cuda.set_device(device)
103
+ print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
104
+ synchronize()
105
+
106
+ def dev():
107
+ """
108
+ Get the device to use for torch.distributed.
109
+ """
110
+ if th.cuda.is_available():
111
+ return th.device("cuda")
112
+ return th.device("cpu")
113
+
114
+
115
+ def load_state_dict(path, **kwargs):
116
+ """
117
+ Load a PyTorch file without redundant fetches across MPI ranks.
118
+ """
119
+ chunk_size = 2 ** 30 # MPI has a relatively small size limit
120
+ if MPI.COMM_WORLD.Get_rank() == 0:
121
+ with bf.BlobFile(path, "rb") as f:
122
+ data = f.read()
123
+ num_chunks = len(data) // chunk_size
124
+ if len(data) % chunk_size:
125
+ num_chunks += 1
126
+ MPI.COMM_WORLD.bcast(num_chunks)
127
+ for i in range(0, len(data), chunk_size):
128
+ MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
129
+ else:
130
+ num_chunks = MPI.COMM_WORLD.bcast(None)
131
+ data = bytes()
132
+ for _ in range(num_chunks):
133
+ data += MPI.COMM_WORLD.bcast(None)
134
+
135
+ return th.load(io.BytesIO(data), **kwargs)
136
+
137
+
138
+ def sync_params(params):
139
+ """
140
+ Synchronize a sequence of Tensors across ranks from rank 0.
141
+ """
142
+ for p in params:
143
+ with th.no_grad():
144
+ dist.broadcast(p, 0)
145
+
146
+
147
+ def _find_free_port():
148
+ try:
149
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
150
+ s.bind(("", 0))
151
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
152
+ return s.getsockname()[1]
153
+ finally:
154
+ s.close()
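A hedged usage sketch for dist_utils (not part of the upload): `setup_dist` relies on torchrun-style environment variables (RANK, LOCAL_RANK, WORLD_SIZE), e.g. a launch like `torchrun --nproc_per_node=8 train.py` with something along these lines inside the script:

```python
# Illustrative wiring of the helpers above under torchrun; train.py is hypothetical.
import torch.distributed as dist
from relighting import dist_utils

def main():
    dist_utils.setup_dist(global_seed=0)  # init NCCL process group, per-rank seed
    device = dist_utils.dev()             # cuda if available, else cpu
    # ... build model / dataloader on `device` ...
    dist_utils.synchronize()              # barrier across ranks
    if dist.get_rank() == 0:
        print("all ranks ready")

if __name__ == "__main__":
    main()
```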
relighting/image_processor.py ADDED
@@ -0,0 +1,141 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image, ImageChops
4
+ import skimage
5
+ try:
6
+ import cv2
7
+ except:
8
+ pass
9
+
10
+ def fill_image(image, mask_ball, x, y, size, color=(255,255,255)):
11
+ if isinstance(image, Image.Image):
12
+ result = np.array(image)
13
+ else:
14
+ result = image.copy()
15
+
16
+ result[y:y+size, x:x+size][mask_ball] = color
17
+
18
+ if isinstance(image, Image.Image):
19
+ result = Image.fromarray(result)
20
+
21
+ return result
22
+
23
+ def pil_square_image(image, desired_size = (512,512), interpolation=Image.LANCZOS):
24
+ """
25
+ Make top-bottom border
26
+ """
27
+ # Don't resize if already desired size (Avoid aliasing problem)
28
+ if image.size == desired_size:
29
+ return image
30
+
31
+ # Calculate the scale factor
32
+ scale_factor = min(desired_size[0] / image.width, desired_size[1] / image.height)
33
+
34
+ # Resize the image
35
+ resized_image = image.resize((int(image.width * scale_factor), int(image.height * scale_factor)), interpolation)
36
+
37
+ # Create a new blank image with the desired size and black border
38
+ new_image = Image.new("RGB", desired_size, color=(0, 0, 0))
39
+
40
+ # Paste the resized image onto the new image, centered
41
+ new_image.paste(resized_image, ((desired_size[0] - resized_image.width) // 2, (desired_size[1] - resized_image.height) // 2))
42
+
43
+ return new_image
44
+
45
+ # https://stackoverflow.com/questions/19271692/removing-borders-from-an-image-in-python
46
+ def remove_borders(image):
47
+ bg = Image.new(image.mode, image.size, image.getpixel((0,0)))
48
+ diff = ImageChops.difference(image, bg)
49
+ diff = ImageChops.add(diff, diff, 2.0, -100)
50
+ bbox = diff.getbbox()
51
+ if bbox:
52
+ return image.crop(bbox)
53
+
54
+ # Taken from https://huggingface.co/lllyasviel/sd-controlnet-normal
55
+ def estimate_scene_normal(image, depth_estimator):
56
+ # speed could be improved by avoiding the repeated conversion between numpy and torch
57
+ normal_image = depth_estimator(image)['predicted_depth'][0]
58
+
59
+ normal_image = normal_image.numpy()
60
+
61
+ # upsizing image depth to match input
62
+ hw = np.array(image).shape[:2]
63
+ normal_image = skimage.transform.resize(normal_image, hw, preserve_range=True)
64
+
65
+ image_depth = normal_image.copy()
66
+ image_depth -= np.min(image_depth)
67
+ image_depth /= np.max(image_depth)
68
+
69
+ bg_threshold = 0.4
70
+
71
+ x = cv2.Sobel(normal_image, cv2.CV_32F, 1, 0, ksize=3)
72
+ x[image_depth < bg_threshold] = 0
73
+
74
+ y = cv2.Sobel(normal_image, cv2.CV_32F, 0, 1, ksize=3)
75
+ y[image_depth < bg_threshold] = 0
76
+
77
+ z = np.ones_like(x) * np.pi * 2.0
78
+
79
+ normal_image = np.stack([x, y, z], axis=2)
80
+ normal_image /= np.sum(normal_image ** 2.0, axis=2, keepdims=True) ** 0.5
81
+
82
+ # rescale back to image size
83
+ return normal_image
84
+
85
+ def estimate_scene_depth(image, depth_estimator):
86
+ #image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
87
+ #with torch.no_grad(), torch.autocast("cuda"):
88
+ # depth_map = depth_estimator(image).predicted_depth
89
+
90
+ depth_map = depth_estimator(image)['predicted_depth']
91
+ W, H = image.size
92
+ depth_map = torch.nn.functional.interpolate(
93
+ depth_map.unsqueeze(1),
94
+ size=(H, W),
95
+ mode="bicubic",
96
+ align_corners=False,
97
+ )
98
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
99
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
100
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
101
+ image = torch.cat([depth_map] * 3, dim=1)
102
+
103
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
104
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
105
+ return image
106
+
107
+ def fill_depth_circular(depth_image, x, y, r):
108
+ depth_image = np.array(depth_image)
109
+
110
+ for i in range(depth_image.shape[0]):
111
+ for j in range(depth_image.shape[1]):
112
+ xy = (i - x - r//2)**2 + (j - y - r//2)**2
113
+ # if xy <= rr**2:
114
+ # depth_image[j, i, :] = 255
115
+ # depth_image[j, i, :] = int(minv + (maxv - minv) * z)
116
+ if xy <= (r // 2)**2:
117
+ depth_image[j, i, :] = 255
118
+
119
+ depth_image = Image.fromarray(depth_image)
120
+ return depth_image
121
+
122
+
123
+ def merge_normal_map(normal_map, normal_ball, mask_ball, x, y):
124
+ """
125
+ Merge a ball to normal map using mask
126
+ @params
127
+ normal_map (np.array) - normal map of the scene [height, width, 3]
128
+ normal_ball (np.array) - normal map of the ball [ball_height, ball_width, 3]
129
+ mask_ball (np.array) - mask of the ball [ball_height, ball_width]
130
+ x (int) - x position of the ball (top-left)
131
+ y (int) - y position of the ball (top-left)
132
+ @return
133
+ normal_map (np.array) - the merged normal map [height, width, 3]
134
+ """
135
+ result = normal_map.copy()
136
+
137
+ mask_ball = mask_ball[..., None]
138
+ ball = (normal_ball * mask_ball) # alpha blending the ball
139
+ unball = (normal_map[y:y+normal_ball.shape[0], x:x+normal_ball.shape[1]] * (1 - mask_ball)) # alpha blending the normal map
140
+ result[y:y+normal_ball.shape[0], x:x+normal_ball.shape[1]] = ball+unball # add them together
141
+ return result
relighting/inpainter.py ADDED
@@ -0,0 +1,424 @@
1
+ import torch
2
+ from diffusers import ControlNetModel, AutoencoderKL
3
+ from PIL import Image
4
+ import numpy as np
5
+ import os
6
+ from tqdm.auto import tqdm
7
+ from transformers import pipeline as transformers_pipeline
8
+
9
+ from relighting.pipeline import CustomStableDiffusionControlNetInpaintPipeline
10
+ from relighting.pipeline_inpaintonly import CustomStableDiffusionInpaintPipeline, CustomStableDiffusionXLInpaintPipeline
11
+ from relighting.argument import SAMPLERS, VAE_MODELS, DEPTH_ESTIMATOR, get_control_signal_type
12
+ from relighting.image_processor import (
13
+ estimate_scene_depth,
14
+ estimate_scene_normal,
15
+ merge_normal_map,
16
+ fill_depth_circular
17
+ )
18
+ from relighting.ball_processor import get_ideal_normal_ball, crop_ball
19
+ import pickle
20
+
21
+ from relighting.pipeline_xl import CustomStableDiffusionXLControlNetInpaintPipeline
22
+
23
+ class NoWaterMark:
24
+ def apply_watermark(self, *args, **kwargs):
25
+ return args[0]
26
+
27
+ class ControlSignalGenerator():
28
+ def __init__(self, sd_arch, control_signal_type, device):
29
+ self.sd_arch = sd_arch
30
+ self.control_signal_type = control_signal_type
31
+ self.device = device
32
+
33
+ def process_sd_depth(self, input_image, normal_ball=None, mask_ball=None, x=None, y=None, r=None):
34
+ if getattr(self, 'depth_estimator', None) is None:
35
+ self.depth_estimator = transformers_pipeline("depth-estimation", device=self.device.index)
36
+
37
+ control_image = self.depth_estimator(input_image)['depth']
38
+ control_image = np.array(control_image)
39
+ control_image = control_image[:, :, None]
40
+ control_image = np.concatenate([control_image, control_image, control_image], axis=2)
41
+ control_image = Image.fromarray(control_image)
42
+
43
+ control_image = fill_depth_circular(control_image, x, y, r)
44
+ return control_image
45
+
46
+ def process_sdxl_depth(self, input_image, normal_ball=None, mask_ball=None, x=None, y=None, r=None):
47
+ if getattr(self, 'depth_estimator', None) is None:
48
+ self.depth_estimator = transformers_pipeline("depth-estimation", model=DEPTH_ESTIMATOR, device=self.device.index)
49
+
50
+ control_image = estimate_scene_depth(input_image, depth_estimator=self.depth_estimator)
51
+ xs = [x] if not isinstance(x, list) else x
52
+ ys = [y] if not isinstance(y, list) else y
53
+ rs = [r] if not isinstance(r, list) else r
54
+
55
+ for x, y, r in zip(xs, ys, rs):
56
+ #print(f"depth at {x}, {y}, {r}")
57
+ control_image = fill_depth_circular(control_image, x, y, r)
58
+ return control_image
59
+
60
+ def process_sd_normal(self, input_image, normal_ball, mask_ball, x, y, r=None, normal_ball_path=None):
61
+ if getattr(self, 'depth_estimator', None) is None:
62
+ self.depth_estimator = transformers_pipeline("depth-estimation", model=DEPTH_ESTIMATOR, device=self.device.index)
63
+
64
+ normal_scene = estimate_scene_normal(input_image, depth_estimator=self.depth_estimator)
65
+ normal_image = merge_normal_map(normal_scene, normal_ball, mask_ball, x, y)
66
+ normal_image = (normal_image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
67
+ control_image = Image.fromarray(normal_image)
68
+ return control_image
69
+
70
+ def __call__(self, *args, **kwargs):
71
+ process_fn = getattr(self, f"process_{self.sd_arch}_{self.control_signal_type}", None)
72
+ if process_fn is None:
73
+ raise ValueError
74
+ else:
75
+ return process_fn(*args, **kwargs)
76
+
77
+
78
+ class BallInpainter():
79
+ def __init__(self, pipeline, sd_arch, control_generator, disable_water_mask=True):
80
+ self.pipeline = pipeline
81
+ self.sd_arch = sd_arch
82
+ self.control_generator = control_generator
83
+ self.median = {}
84
+ if disable_water_mask:
85
+ self._disable_water_mask()
86
+
87
+ def _disable_water_mask(self):
88
+ if hasattr(self.pipeline, "watermark"):
89
+ self.pipeline.watermark = NoWaterMark()
90
+ print("Disabled watermasking")
91
+
92
+ @classmethod
93
+ def from_sd(cls,
94
+ model,
95
+ controlnet=None,
96
+ device=0,
97
+ sampler="unipc",
98
+ torch_dtype=torch.float16,
99
+ disable_water_mask=True,
100
+ offload=False
101
+ ):
102
+ if controlnet is not None:
103
+ control_signal_type = get_control_signal_type(controlnet)
104
+ controlnet = ControlNetModel.from_pretrained(controlnet, torch_dtype=torch.float16)
105
+ pipe = CustomStableDiffusionControlNetInpaintPipeline.from_pretrained(
106
+ model,
107
+ controlnet=controlnet,
108
+ torch_dtype=torch_dtype,
109
+ ).to(device)
110
+ control_generator = ControlSignalGenerator("sd", control_signal_type, device=device)
111
+ else:
112
+ pipe = CustomStableDiffusionInpaintPipeline.from_pretrained(
113
+ model,
114
+ torch_dtype=torch_dtype,
115
+ ).to(device)
116
+ control_generator = None
117
+
118
+ try:
119
+ if torch_dtype==torch.float16 and device != torch.device("cpu"):
120
+ pipe.enable_xformers_memory_efficient_attention()
121
+ except:
122
+ pass
123
+ pipe.set_progress_bar_config(disable=True)
124
+
125
+ pipe.scheduler = SAMPLERS[sampler].from_config(pipe.scheduler.config)
126
+
127
+ return BallInpainter(pipe, "sd", control_generator, disable_water_mask)
128
+
129
+ @classmethod
130
+ def from_sdxl(cls,
131
+ model,
132
+ controlnet=None,
133
+ device=0,
134
+ sampler="unipc",
135
+ torch_dtype=torch.float16,
136
+ disable_water_mask=True,
137
+ use_fixed_vae=True,
138
+ offload=False
139
+ ):
140
+ vae = VAE_MODELS["sdxl"]
141
+ vae = AutoencoderKL.from_pretrained(vae, torch_dtype=torch_dtype).to(device) if use_fixed_vae else None
142
+ extra_kwargs = {"vae": vae} if vae is not None else {}
143
+
144
+ if controlnet is not None:
145
+ control_signal_type = get_control_signal_type(controlnet)
146
+ controlnet = ControlNetModel.from_pretrained(
147
+ controlnet,
148
+ variant="fp16" if torch_dtype == torch.float16 else None,
149
+ use_safetensors=True,
150
+ torch_dtype=torch_dtype,
151
+ ).to(device)
152
+ pipe = CustomStableDiffusionXLControlNetInpaintPipeline.from_pretrained(
153
+ model,
154
+ controlnet=controlnet,
155
+ variant="fp16" if torch_dtype == torch.float16 else None,
156
+ use_safetensors=True,
157
+ torch_dtype=torch_dtype,
158
+ **extra_kwargs,
159
+ ).to(device)
160
+ control_generator = ControlSignalGenerator("sdxl", control_signal_type, device=device)
161
+ else:
162
+ pipe = CustomStableDiffusionXLInpaintPipeline.from_pretrained(
163
+ model,
164
+ variant="fp16" if torch_dtype == torch.float16 else None,
165
+ use_safetensors=True,
166
+ torch_dtype=torch_dtype,
167
+ **extra_kwargs,
168
+ ).to(device)
169
+ control_generator = None
170
+
171
+ try:
172
+ if torch_dtype==torch.float16 and device != torch.device("cpu"):
173
+ pipe.enable_xformers_memory_efficient_attention()
174
+ except:
175
+ pass
176
+
177
+ if offload and device != torch.device("cpu"):
178
+ pipe.enable_model_cpu_offload()
179
+ pipe.set_progress_bar_config(disable=True)
180
+ pipe.scheduler = SAMPLERS[sampler].from_config(pipe.scheduler.config)
181
+
182
+ return BallInpainter(pipe, "sdxl", control_generator, disable_water_mask)
183
+
184
+ # TODO: this method should be replaced by inpaint(), but we'll leave it here for now
185
+ # otherwise, the existing experiment code would break
186
+ def __call__(self, *args, **kwargs):
187
+ return self.pipeline(*args, **kwargs)
188
+
189
+ def _default_height_width(self, height=None, width=None):
190
+ if (height is not None) and (width is not None):
191
+ return height, width
192
+ if self.sd_arch == "sd":
193
+ return (512, 512)
194
+ elif self.sd_arch == "sdxl":
195
+ return (1024, 1024)
196
+ else:
197
+ raise NotImplementedError
198
+
199
+ # this method is for sanity check only
200
+ def get_cache_control_image(self):
201
+ control_image = getattr(self, "cache_control_image", None)
202
+ return control_image
203
+
204
+ def prepare_control_signal(self, image, controlnet_conditioning_scale, extra_kwargs):
205
+ if self.control_generator is not None:
206
+ control_image = self.control_generator(image, **extra_kwargs)
207
+ controlnet_kwargs = {
208
+ "control_image": control_image,
209
+ "controlnet_conditioning_scale": controlnet_conditioning_scale
210
+ }
211
+ self.cache_control_image = control_image
212
+ else:
213
+ controlnet_kwargs = {}
214
+
215
+ return controlnet_kwargs
216
+
217
+ def get_cache_median(self, it):
218
+ if it in self.median: return self.median[it]
219
+ else: return None
220
+
221
+ def reset_median(self):
222
+ self.median = {}
223
+ print("Reset median")
224
+
225
+ def load_median(self, path):
226
+ if os.path.exists(path):
227
+ with open(path, "rb") as f:
228
+ self.median = pickle.load(f)
229
+ print(f"Loaded median from {path}")
230
+ else:
231
+ print(f"Median not found at {path}!")
232
+
233
+ def inpaint_iterative(
234
+ self,
235
+ prompt=None,
236
+ negative_prompt="",
237
+ num_inference_steps=30,
238
+ generator=None, # TODO: remove this
239
+ image=None,
240
+ mask_image=None,
241
+ height=None,
242
+ width=None,
243
+ controlnet_conditioning_scale=0.5,
244
+ num_images_per_prompt=1,
245
+ current_seed=0,
246
+ cross_attention_kwargs={},
247
+ strength=0.8,
248
+ num_iteration=2,
249
+ ball_per_iteration=30,
250
+ agg_mode="median",
251
+ save_intermediate=True,
252
+ cache_dir="./temp_inpaint_iterative",
253
+ disable_progress=False,
254
+ prompt_embeds=None,
255
+ pooled_prompt_embeds=None,
256
+ use_cache_median=False,
257
+ guidance_scale=5.0, # In the paper, we set the guidance scale to 5.0 (same as pipeline_xl.py)
258
+ **extra_kwargs,
259
+ ):
260
+ def computeMedian(ball_images):
261
+ all = np.stack(ball_images, axis=0)
262
+ median = np.median(all, axis=0)
263
+ idx_median = np.argsort(all, axis=0)[all.shape[0]//2]
264
+ # print(all.shape)
265
+ # print(idx_median.shape)
266
+ return median, idx_median
267
+
268
+ def generate_balls(avg_image, current_strength, ball_per_iteration, current_iteration):
269
+ print(f"Inpainting balls for {current_iteration} iteration...")
270
+ controlnet_kwargs = self.prepare_control_signal(
271
+ image=avg_image,
272
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
273
+ extra_kwargs=extra_kwargs,
274
+ )
275
+
276
+ ball_images = []
277
+ for i in tqdm(range(ball_per_iteration), disable=disable_progress):
278
+ seed = current_seed + i
279
+ new_generator = torch.Generator().manual_seed(seed)
280
+ output_image = self.pipeline(
281
+ prompt=prompt,
282
+ negative_prompt=negative_prompt,
283
+ num_inference_steps=num_inference_steps,
284
+ generator=new_generator,
285
+ image=avg_image,
286
+ mask_image=mask_image,
287
+ height=height,
288
+ width=width,
289
+ num_images_per_prompt=num_images_per_prompt,
290
+ strength=current_strength,
291
+ newx=x,
292
+ newy=y,
293
+ newr=r,
294
+ current_seed=seed,
295
+ cross_attention_kwargs=cross_attention_kwargs,
296
+ prompt_embeds=prompt_embeds,
297
+ pooled_prompt_embeds=pooled_prompt_embeds,
298
+ guidance_scale=guidance_scale,
299
+ **controlnet_kwargs
300
+ ).images[0]
301
+
302
+ ball_image = crop_ball(output_image, mask_ball_for_crop, x, y, r)
303
+ ball_images.append(ball_image)
304
+
305
+ if save_intermediate:
306
+ os.makedirs(os.path.join(cache_dir, str(current_iteration)), mode=0o777, exist_ok=True)
307
+ output_image.save(os.path.join(cache_dir, str(current_iteration), f"raw_{i}.png"))
308
+ Image.fromarray(ball_image).save(os.path.join(cache_dir, str(current_iteration), f"ball_{i}.png"))
309
+ # chmod 777
310
+ os.chmod(os.path.join(cache_dir, str(current_iteration), f"raw_{i}.png"), 0o0777)
311
+ os.chmod(os.path.join(cache_dir, str(current_iteration), f"ball_{i}.png"), 0o0777)
312
+
313
+
314
+ return ball_images
315
+
316
+ if save_intermediate:
317
+ os.makedirs(cache_dir, exist_ok=True)
318
+
319
+ height, width = self._default_height_width(height, width)
320
+
321
+ x = extra_kwargs["x"]
322
+ y = extra_kwargs["y"]
323
+ r = 256 if "r" not in extra_kwargs else extra_kwargs["r"]
324
+ _, mask_ball_for_crop = get_ideal_normal_ball(size=r)
325
+
326
+ # generate initial average ball
327
+ avg_image = image
328
+ ball_images = generate_balls(
329
+ avg_image,
330
+ current_strength=1.0,
331
+ ball_per_iteration=ball_per_iteration,
332
+ current_iteration=0,
333
+ )
334
+
335
+ # ball refinement loop
336
+ image = np.array(image)
337
+ for it in range(1, num_iteration+1):
338
+ if use_cache_median and (self.get_cache_median(it) is not None):
339
+ print("Use existing median")
340
+ all = np.stack(ball_images, axis=0)
341
+ idx_median = self.get_cache_median(it)
342
+ avg_ball = all[idx_median,
343
+ np.arange(idx_median.shape[0])[:, np.newaxis, np.newaxis],
344
+ np.arange(idx_median.shape[1])[np.newaxis, :, np.newaxis],
345
+ np.arange(idx_median.shape[2])[np.newaxis, np.newaxis, :]
346
+ ]
347
+ else:
348
+ avg_ball, idx_median = computeMedian(ball_images)
349
+ print("Add new median")
350
+ self.median[it] = idx_median
351
+
352
+ avg_image = merge_normal_map(image, avg_ball, mask_ball_for_crop, x, y)
353
+ avg_image = Image.fromarray(avg_image.astype(np.uint8))
354
+ if save_intermediate:
355
+ avg_image.save(os.path.join(cache_dir, f"average_{it}.png"))
356
+ # chmod777
357
+ os.chmod(os.path.join(cache_dir, f"average_{it}.png"), 0o0777)
358
+
359
+ ball_images = generate_balls(
360
+ avg_image,
361
+ current_strength=strength,
362
+ ball_per_iteration=ball_per_iteration if it < num_iteration else 1,
363
+ current_iteration=it,
364
+ )
365
+
366
+ # TODO: add an algorithm to select the best ball
367
+ best_ball = ball_images[0]
368
+ output_image = merge_normal_map(image, best_ball, mask_ball_for_crop, x, y)
369
+ return Image.fromarray(output_image.astype(np.uint8))
370
+
371
+ def inpaint(
372
+ self,
373
+ prompt=None,
374
+ negative_prompt=None,
375
+ num_inference_steps=30,
376
+ generator=None,
377
+ image=None,
378
+ mask_image=None,
379
+ height=None,
380
+ width=None,
381
+ controlnet_conditioning_scale=0.5,
382
+ num_images_per_prompt=1,
383
+ strength=1.0,
384
+ current_seed=0,
385
+ cross_attention_kwargs={},
386
+ prompt_embeds=None,
387
+ pooled_prompt_embeds=None,
388
+ guidance_scale=5.0, # (same as pipeline_xl.py)
389
+ **extra_kwargs,
390
+ ):
391
+ height, width = self._default_height_width(height, width)
392
+
393
+ controlnet_kwargs = self.prepare_control_signal(
394
+ image=image,
395
+ controlnet_conditioning_scale=controlnet_conditioning_scale,
396
+ extra_kwargs=extra_kwargs,
397
+ )
398
+
399
+ if generator is None:
400
+ generator = torch.Generator().manual_seed(0)
401
+
402
+ output_image = self.pipeline(
403
+ prompt=prompt,
404
+ negative_prompt=negative_prompt,
405
+ num_inference_steps=num_inference_steps,
406
+ generator=generator,
407
+ image=image,
408
+ mask_image=mask_image,
409
+ height=height,
410
+ width=width,
411
+ num_images_per_prompt=num_images_per_prompt,
412
+ strength=strength,
413
+ newx = extra_kwargs["x"],
414
+ newy = extra_kwargs["y"],
415
+ newr = extra_kwargs.get("r", 256), # default to ball_size = 256; extra_kwargs is a dict, so use .get() rather than getattr()
416
+ current_seed=current_seed,
417
+ cross_attention_kwargs=cross_attention_kwargs,
418
+ prompt_embeds=prompt_embeds,
419
+ pooled_prompt_embeds=pooled_prompt_embeds,
420
+ guidance_scale=guidance_scale,
421
+ **controlnet_kwargs
422
+ )
423
+
424
+ return output_image
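A hedged end-to-end sketch of the BallInpainter API above. The model and ControlNet identifiers, the prompt, and the placement are illustrative only (the project's own entry script wires these pieces together and adds the LoRA and exposure bracketing on top), and `get_control_signal_type` is assumed to recognize a depth ControlNet from its name:

```python
import torch
from PIL import Image
from relighting.inpainter import BallInpainter
from relighting.mask_utils import MaskGenerator
from relighting.ball_processor import get_ideal_normal_ball
from relighting.image_processor import pil_square_image

# Illustrative model choices; substitute whatever the entry script actually uses.
inpainter = BallInpainter.from_sdxl(
    model="stabilityai/stable-diffusion-xl-base-1.0",
    controlnet="diffusers/controlnet-depth-sdxl-1.0",
    device=torch.device("cuda", 0),
    torch_dtype=torch.float16,
)

image = pil_square_image(Image.open("input.png").convert("RGB"), (1024, 1024))
x = y = (1024 - 256) // 2                        # centered 256 px ball
_, mask_ball = get_ideal_normal_ball(size=256)
mask = MaskGenerator().generate_single(image, mask_ball, x, y, size=256)

result = inpainter.inpaint_iterative(
    prompt="a perfect mirrored reflective chrome ball sphere",  # illustrative prompt
    negative_prompt="matte, diffuse, flat, dull",
    image=image,
    mask_image=mask,
    x=x, y=y, r=256,
)
result.save("chrome_ball.png")
```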
relighting/mask_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ try:
2
+ import cv2
3
+ except:
4
+ pass
5
+ import numpy as np
6
+ from PIL import Image
7
+ from relighting.ball_processor import get_ideal_normal_ball
8
+
9
+ def create_grid(image_size, n_ball, size):
10
+ height, width = image_size
11
+ nx, ny = n_ball
12
+ if nx * ny == 1:
13
+ grid = np.array([[(height-size)//2, (width-size)//2]])
14
+ else:
15
+ height_ = np.linspace(0, height-size, nx).astype(int)
16
+ width_ = np.linspace(0, width-size, ny).astype(int)
17
+ hh, ww = np.meshgrid(height_, width_)
18
+ grid = np.stack([hh,ww], axis = -1).reshape(-1,2)
19
+
20
+ return grid
21
+
22
+ class MaskGenerator():
23
+ def __init__(self, cache_mask=True):
24
+ self.cache_mask = cache_mask
25
+ self.all_masks = []
26
+
27
+ def clear_cache(self):
28
+ self.all_masks = []
29
+
30
+ def retrieve_masks(self):
31
+ return self.all_masks
32
+
33
+ def generate_grid(self, image, mask_ball, n_ball=16, size=128):
34
+ ball_positions = create_grid(image.size, n_ball, size)
35
+ # _, mask_ball = get_normal_ball(size)
36
+
37
+ masks = []
38
+ mask_template = np.zeros(image.size)
39
+ for x, y in ball_positions:
40
+ mask = mask_template.copy()
41
+ mask[y:y+size, x:x+size] = 255 * mask_ball
42
+ mask = Image.fromarray(mask.astype(np.uint8), "L")
43
+ masks.append(mask)
44
+
45
+ # if self.cache_mask:
46
+ # self.all_masks.append((x, y, size))
47
+
48
+ return masks, ball_positions
49
+
50
+ def generate_single(self, image, mask_ball, x, y, size):
51
+ w,h = image.size # numpy as (h,w) but PIL is (w,h)
52
+ mask = np.zeros((h,w))
53
+ mask[y:y+size, x:x+size] = 255 * mask_ball
54
+ mask = Image.fromarray(mask.astype(np.uint8), "L")
55
+
56
+ return mask
57
+
58
+ def generate_best(self, image, mask_ball, size):
59
+ w,h = image.size # numpy as (h,w) but PIL is (w,h)
60
+ mask = np.zeros((h,w))
61
+
62
+ (y, x), _ = find_best_location(np.array(image), ball_size=size)
63
+ mask[y:y+size, x:x+size] = 255 * mask_ball
64
+ mask = Image.fromarray(mask.astype(np.uint8), "L")
65
+
66
+ return mask, (x, y)
67
+
68
+
69
+ def get_only_high_freqency(image: np.array):
70
+ """
71
+ Get only height freqency image by subtract low freqency (using gaussian blur)
72
+ @params image: np.array - image in RGB format [h,w,3]
73
+ @return high_frequency: np.array - high freqnecy image in grayscale format [h,w]
74
+ """
75
+
76
+ # Convert to grayscale
77
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
78
+
79
+ # Subtract the low-frequency (blurred) image to keep only the high-frequency content
80
+ kernel_size = 11 # Adjust this according to your image size
81
+ high_frequency = gray - cv2.GaussianBlur(gray,(kernel_size, kernel_size), 0)
82
+
83
+ return high_frequency
84
+
85
+ def find_best_location(image, ball_size=128):
86
+ """
87
+ Find the best location to place the ball (e.g., an empty, low-texture region)
88
+ @params image: np.array - image in RGB format [h,w,3]
89
+ @return min_pos: tuple - top left position of the best location (the location is in "Y,X" format)
90
+ @return min_val: float - the sum of values contained in the window
91
+ """
92
+ local_variance = get_only_high_freqency(image)
93
+ qsum = quicksum2d(local_variance)
94
+
95
+ min_val = None
96
+ min_pos = None
97
+ k = ball_size
98
+ for i in range(k-1, qsum.shape[0]):
99
+ for j in range(k-1, qsum.shape[1]):
100
+ A = 0 if i-k < 0 else qsum[i-k, j]
101
+ B = 0 if j-k < 0 else qsum[i, j-k]
102
+ C = 0 if (i-k < 0) or (j-k < 0) else qsum[i-k, j-k]
103
+ sum = qsum[i, j] - A - B + C
104
+ if (min_val is None) or (sum < min_val):
105
+ min_val = sum
106
+ min_pos = (i-k+1, j-k+1) # get top left position
107
+
108
+ return min_pos, min_val
109
+
110
+ def quicksum2d(x: np.array):
111
+ """
112
+ Build a 2D prefix sum (summed-area table) in O(n^2); find_best_location uses it to evaluate window sums in O(1)
113
+ @params x: np.array - image in grayscale [h,w]
114
+ @return qsum: np.array - prefix sum of the image, used for the window search in find_best_location [h,w]
115
+ """
116
+ qsum = np.zeros(x.shape)
117
+ for i in range(x.shape[0]):
118
+ for j in range(x.shape[1]):
119
+ A = 0 if i-1 < 0 else qsum[i-1, j]
120
+ B = 0 if j-1 < 0 else qsum[i, j-1]
121
+ C = 0 if (i-1 < 0) or (j-1 < 0) else qsum[i-1, j-1]
122
+ qsum[i, j] = A + B - C + x[i, j]
123
+
124
+ return qsum
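A quick sanity check (illustrative, random data) of the summed-area-table identity that find_best_location relies on: the sum over the k x k window ending at (i, j) equals `qsum[i, j] - A - B + C`:

```python
import numpy as np
from relighting.mask_utils import quicksum2d

rng = np.random.default_rng(0)
x = rng.random((32, 32))
qsum = quicksum2d(x)

k, i, j = 8, 20, 25
A, B, C = qsum[i - k, j], qsum[i, j - k], qsum[i - k, j - k]
window_sum = qsum[i, j] - A - B + C
naive_sum = x[i - k + 1:i + 1, j - k + 1:j + 1].sum()
assert np.isclose(window_sum, naive_sum)
```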
relighting/pipeline.py ADDED
@@ -0,0 +1,344 @@
1
+ import torch
2
+ from typing import List, Union, Dict, Any, Callable, Optional, Tuple
3
+
4
+ from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
5
+ from diffusers.models import ControlNetModel
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from diffusers import StableDiffusionControlNetInpaintPipeline
8
+ from diffusers.image_processor import PipelineImageInput
9
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
10
+ from relighting.pipeline_utils import custom_prepare_latents, custom_prepare_mask_latents
11
+
12
+ class CustomStableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetInpaintPipeline):
13
+ @torch.no_grad()
14
+ def __call__(
15
+ self,
16
+ prompt: Union[str, List[str]] = None,
17
+ image: PipelineImageInput = None,
18
+ mask_image: PipelineImageInput = None,
19
+ control_image: PipelineImageInput = None,
20
+ height: Optional[int] = None,
21
+ width: Optional[int] = None,
22
+ strength: float = 1.0,
23
+ num_inference_steps: int = 50,
24
+ guidance_scale: float = 7.5,
25
+ negative_prompt: Optional[Union[str, List[str]]] = None,
26
+ num_images_per_prompt: Optional[int] = 1,
27
+ eta: float = 0.0,
28
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
29
+ latents: Optional[torch.FloatTensor] = None,
30
+ prompt_embeds: Optional[torch.FloatTensor] = None,
31
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
32
+ output_type: Optional[str] = "pil",
33
+ return_dict: bool = True,
34
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
35
+ callback_steps: int = 1,
36
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
37
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
38
+ guess_mode: bool = False,
39
+ control_guidance_start: Union[float, List[float]] = 0.0,
40
+ control_guidance_end: Union[float, List[float]] = 1.0,
41
+ newx: int = 0,
42
+ newy: int = 0,
43
+ newr: int = 256,
44
+ current_seed=0,
45
+ use_noise_moving=True,
46
+ ):
47
+ # OVERWRITE METHODS
48
+ self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionControlNetInpaintPipeline)
49
+ self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionControlNetInpaintPipeline)
50
+
51
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
52
+
53
+ # align format for control guidance
54
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
55
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
56
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
57
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
58
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
59
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
60
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
61
+ control_guidance_end
62
+ ]
63
+
64
+ # 1. Check inputs. Raise error if not correct
65
+ self.check_inputs(
66
+ prompt,
67
+ control_image,
68
+ height,
69
+ width,
70
+ callback_steps,
71
+ negative_prompt,
72
+ prompt_embeds,
73
+ negative_prompt_embeds,
74
+ controlnet_conditioning_scale,
75
+ control_guidance_start,
76
+ control_guidance_end,
77
+ )
78
+
79
+ # 2. Define call parameters
80
+ if prompt is not None and isinstance(prompt, str):
81
+ batch_size = 1
82
+ elif prompt is not None and isinstance(prompt, list):
83
+ batch_size = len(prompt)
84
+ else:
85
+ batch_size = prompt_embeds.shape[0]
86
+
87
+ device = self._execution_device
88
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
89
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
90
+ # corresponds to doing no classifier free guidance.
91
+ do_classifier_free_guidance = guidance_scale > 1.0
92
+
93
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
94
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
95
+
96
+ global_pool_conditions = (
97
+ controlnet.config.global_pool_conditions
98
+ if isinstance(controlnet, ControlNetModel)
99
+ else controlnet.nets[0].config.global_pool_conditions
100
+ )
101
+ guess_mode = guess_mode or global_pool_conditions
102
+
103
+ # 3. Encode input prompt
104
+ text_encoder_lora_scale = (
105
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
106
+ )
107
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
108
+ prompt,
109
+ device,
110
+ num_images_per_prompt,
111
+ do_classifier_free_guidance,
112
+ negative_prompt,
113
+ prompt_embeds=prompt_embeds,
114
+ negative_prompt_embeds=negative_prompt_embeds,
115
+ lora_scale=text_encoder_lora_scale,
116
+ )
117
+ # For classifier free guidance, we need to do two forward passes.
118
+ # Here we concatenate the unconditional and text embeddings into a single batch
119
+ # to avoid doing two forward passes
120
+ if do_classifier_free_guidance:
121
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
122
+
123
+ # 4. Prepare image
124
+ if isinstance(controlnet, ControlNetModel):
125
+ control_image = self.prepare_control_image(
126
+ image=control_image,
127
+ width=width,
128
+ height=height,
129
+ batch_size=batch_size * num_images_per_prompt,
130
+ num_images_per_prompt=num_images_per_prompt,
131
+ device=device,
132
+ dtype=controlnet.dtype,
133
+ do_classifier_free_guidance=do_classifier_free_guidance,
134
+ guess_mode=guess_mode,
135
+ )
136
+ elif isinstance(controlnet, MultiControlNetModel):
137
+ control_images = []
138
+
139
+ for control_image_ in control_image:
140
+ control_image_ = self.prepare_control_image(
141
+ image=control_image_,
142
+ width=width,
143
+ height=height,
144
+ batch_size=batch_size * num_images_per_prompt,
145
+ num_images_per_prompt=num_images_per_prompt,
146
+ device=device,
147
+ dtype=controlnet.dtype,
148
+ do_classifier_free_guidance=do_classifier_free_guidance,
149
+ guess_mode=guess_mode,
150
+ )
151
+
152
+ control_images.append(control_image_)
153
+
154
+ control_image = control_images
155
+ else:
156
+ assert False
157
+
158
+ # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
159
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
160
+ init_image = init_image.to(dtype=torch.float32)
161
+
162
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
163
+
164
+ masked_image = init_image * (mask < 0.5)
165
+ _, _, height, width = init_image.shape
166
+
167
+ # 5. Prepare timesteps
168
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
169
+ timesteps, num_inference_steps = self.get_timesteps(
170
+ num_inference_steps=num_inference_steps, strength=strength, device=device
171
+ )
172
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
173
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
174
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
175
+ is_strength_max = strength == 1.0
176
+
177
+ # 6. Prepare latent variables
178
+ num_channels_latents = self.vae.config.latent_channels
179
+ num_channels_unet = self.unet.config.in_channels
180
+ return_image_latents = num_channels_unet == 4
181
+
182
+ # EDITED HERE
183
+ latents_outputs = self.prepare_latents(
184
+ batch_size * num_images_per_prompt,
185
+ num_channels_latents,
186
+ height,
187
+ width,
188
+ prompt_embeds.dtype,
189
+ device,
190
+ generator,
191
+ latents,
192
+ image=init_image,
193
+ timestep=latent_timestep,
194
+ is_strength_max=is_strength_max,
195
+ return_noise=True,
196
+ return_image_latents=return_image_latents,
197
+ newx=newx,
198
+ newy=newy,
199
+ newr=newr,
200
+ current_seed=current_seed,
201
+ use_noise_moving=use_noise_moving,
202
+ )
203
+
204
+ if return_image_latents:
205
+ latents, noise, image_latents = latents_outputs
206
+ else:
207
+ latents, noise = latents_outputs
208
+
209
+ # 7. Prepare mask latent variables
210
+ mask, masked_image_latents = self.prepare_mask_latents(
211
+ mask,
212
+ masked_image,
213
+ batch_size * num_images_per_prompt,
214
+ height,
215
+ width,
216
+ prompt_embeds.dtype,
217
+ device,
218
+ generator,
219
+ do_classifier_free_guidance,
220
+ )
221
+
222
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
223
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
224
+
225
+ # 7.1 Create tensor stating which controlnets to keep
226
+ controlnet_keep = []
227
+ for i in range(len(timesteps)):
228
+ keeps = [
229
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
230
+ for s, e in zip(control_guidance_start, control_guidance_end)
231
+ ]
232
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
233
+
234
+ # 8. Denoising loop
235
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
236
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
237
+ for i, t in enumerate(timesteps):
238
+ # expand the latents if we are doing classifier free guidance
239
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
240
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
241
+
242
+ # controlnet(s) inference
243
+ if guess_mode and do_classifier_free_guidance:
244
+ # Infer ControlNet only for the conditional batch.
245
+ control_model_input = latents
246
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
247
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
248
+ else:
249
+ control_model_input = latent_model_input
250
+ controlnet_prompt_embeds = prompt_embeds
251
+
252
+ if isinstance(controlnet_keep[i], list):
253
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
254
+ else:
255
+ controlnet_cond_scale = controlnet_conditioning_scale
256
+ if isinstance(controlnet_cond_scale, list):
257
+ controlnet_cond_scale = controlnet_cond_scale[0]
258
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
259
+
260
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
261
+ control_model_input,
262
+ t,
263
+ encoder_hidden_states=controlnet_prompt_embeds,
264
+ controlnet_cond=control_image,
265
+ conditioning_scale=cond_scale,
266
+ guess_mode=guess_mode,
267
+ return_dict=False,
268
+ )
269
+
270
+ if guess_mode and do_classifier_free_guidance:
271
+ # Inferred ControlNet only for the conditional batch.
272
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
273
+ # add 0 to the unconditional batch to keep it unchanged.
274
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
275
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
276
+
277
+ # predict the noise residual
278
+ if num_channels_unet == 9:
279
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
280
+
281
+ noise_pred = self.unet(
282
+ latent_model_input,
283
+ t,
284
+ encoder_hidden_states=prompt_embeds,
285
+ cross_attention_kwargs=cross_attention_kwargs,
286
+ down_block_additional_residuals=down_block_res_samples,
287
+ mid_block_additional_residual=mid_block_res_sample,
288
+ return_dict=False,
289
+ )[0]
290
+
291
+ # perform guidance
292
+ if do_classifier_free_guidance:
293
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
294
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
295
+
296
+ # compute the previous noisy sample x_t -> x_t-1
297
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
298
+
299
+ if num_channels_unet == 4:
300
+ init_latents_proper = image_latents[:1]
301
+ init_mask = mask[:1]
302
+
303
+ if i < len(timesteps) - 1:
304
+ noise_timestep = timesteps[i + 1]
305
+ init_latents_proper = self.scheduler.add_noise(
306
+ init_latents_proper, noise, torch.tensor([noise_timestep])
307
+ )
308
+
309
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
310
+
311
+ # call the callback, if provided
312
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
313
+ progress_bar.update()
314
+ if callback is not None and i % callback_steps == 0:
315
+ callback(i, t, latents)
316
+
317
+ # If we do sequential model offloading, let's offload unet and controlnet
318
+ # manually for max memory savings
319
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
320
+ self.unet.to("cpu")
321
+ self.controlnet.to("cpu")
322
+ torch.cuda.empty_cache()
323
+
324
+ if not output_type == "latent":
325
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
326
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
327
+ else:
328
+ image = latents
329
+ has_nsfw_concept = None
330
+
331
+ if has_nsfw_concept is None:
332
+ do_denormalize = [True] * image.shape[0]
333
+ else:
334
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
335
+
336
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
337
+
338
+ # Offload all models
339
+ self.maybe_free_model_hooks()
340
+
341
+ if not return_dict:
342
+ return (image, has_nsfw_concept)
343
+
344
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
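A standalone illustration (toy numbers, not tied to the pipeline) of the `controlnet_keep` schedule computed in the denoising loop above: ControlNet guidance is applied only at steps whose normalized position falls inside [control_guidance_start, control_guidance_end]:

```python
# With 10 steps and guidance active for the first half of the schedule:
num_steps = 10
control_guidance_start, control_guidance_end = [0.0], [0.5]

controlnet_keep = []
for i in range(num_steps):
    keeps = [
        1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0])

print(controlnet_keep)  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```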
relighting/pipeline_inpaintonly.py ADDED
@@ -0,0 +1,613 @@
1
+ import torch
2
+ from typing import List, Union, Dict, Any, Callable, Optional, Tuple
3
+
4
+ from diffusers.image_processor import PipelineImageInput
5
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionXLInpaintPipeline
6
+ from diffusers.models import AsymmetricAutoencoderKL
7
+ from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
8
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
9
+ from relighting.pipeline_utils import custom_prepare_latents, custom_prepare_mask_latents, rescale_noise_cfg
10
+
11
+ class CustomStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
12
+ @torch.no_grad()
13
+ def __call__(
14
+ self,
15
+ prompt: Union[str, List[str]] = None,
16
+ image: PipelineImageInput = None,
17
+ mask_image: PipelineImageInput = None,
18
+ masked_image_latents: torch.FloatTensor = None,
19
+ height: Optional[int] = None,
20
+ width: Optional[int] = None,
21
+ strength: float = 1.0,
22
+ num_inference_steps: int = 50,
23
+ guidance_scale: float = 7.5,
24
+ negative_prompt: Optional[Union[str, List[str]]] = None,
25
+ num_images_per_prompt: Optional[int] = 1,
26
+ eta: float = 0.0,
27
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
28
+ latents: Optional[torch.FloatTensor] = None,
29
+ prompt_embeds: Optional[torch.FloatTensor] = None,
30
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
31
+ output_type: Optional[str] = "pil",
32
+ return_dict: bool = True,
33
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
34
+ callback_steps: int = 1,
35
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
36
+ newx: int = 0,
37
+ newy: int = 0,
38
+ newr: int = 256,
39
+ current_seed=0,
40
+ use_noise_moving=True,
41
+ ):
42
+ # OVERWRITE METHODS
43
+ self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionInpaintPipeline)
44
+ self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionInpaintPipeline)
45
+
46
+ # 0. Default height and width to unet
47
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
48
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
49
+
50
+ # 1. Check inputs
51
+ self.check_inputs(
52
+ prompt,
53
+ height,
54
+ width,
55
+ strength,
56
+ callback_steps,
57
+ negative_prompt,
58
+ prompt_embeds,
59
+ negative_prompt_embeds,
60
+ )
61
+
62
+ # 2. Define call parameters
63
+ if prompt is not None and isinstance(prompt, str):
64
+ batch_size = 1
65
+ elif prompt is not None and isinstance(prompt, list):
66
+ batch_size = len(prompt)
67
+ else:
68
+ batch_size = prompt_embeds.shape[0]
69
+
70
+ device = self._execution_device
71
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
72
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
73
+ # corresponds to doing no classifier free guidance.
74
+ do_classifier_free_guidance = guidance_scale > 1.0
75
+
76
+ # 3. Encode input prompt
77
+ text_encoder_lora_scale = (
78
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
79
+ )
80
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
81
+ prompt,
82
+ device,
83
+ num_images_per_prompt,
84
+ do_classifier_free_guidance,
85
+ negative_prompt,
86
+ prompt_embeds=prompt_embeds,
87
+ negative_prompt_embeds=negative_prompt_embeds,
88
+ lora_scale=text_encoder_lora_scale,
89
+ )
90
+ # For classifier free guidance, we need to do two forward passes.
91
+ # Here we concatenate the unconditional and text embeddings into a single batch
92
+ # to avoid doing two forward passes
93
+ if do_classifier_free_guidance:
94
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
95
+
96
+ # 4. set timesteps
97
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
98
+ timesteps, num_inference_steps = self.get_timesteps(
99
+ num_inference_steps=num_inference_steps, strength=strength, device=device
100
+ )
101
+ # check that number of inference steps is not < 1 - as this doesn't make sense
102
+ if num_inference_steps < 1:
103
+ raise ValueError(
104
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
105
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
106
+ )
107
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
108
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
109
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
110
+ is_strength_max = strength == 1.0
111
+
112
+ # 5. Preprocess mask and image
113
+
114
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
115
+ init_image = init_image.to(dtype=torch.float32)
116
+
117
+ # 6. Prepare latent variables
118
+ num_channels_latents = self.vae.config.latent_channels
119
+ num_channels_unet = self.unet.config.in_channels
120
+ return_image_latents = num_channels_unet == 4
121
+
122
+ latents_outputs = self.prepare_latents(
123
+ batch_size * num_images_per_prompt,
124
+ num_channels_latents,
125
+ height,
126
+ width,
127
+ prompt_embeds.dtype,
128
+ device,
129
+ generator,
130
+ latents,
131
+ image=init_image,
132
+ timestep=latent_timestep,
133
+ is_strength_max=is_strength_max,
134
+ return_noise=True,
135
+ return_image_latents=return_image_latents,
136
+ newx=newx,
137
+ newy=newy,
138
+ newr=newr,
139
+ current_seed=current_seed,
140
+ use_noise_moving=use_noise_moving,
141
+ )
142
+
143
+ if return_image_latents:
144
+ latents, noise, image_latents = latents_outputs
145
+ else:
146
+ latents, noise = latents_outputs
147
+
148
+ # 7. Prepare mask latent variables
149
+ mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
150
+
151
+ if masked_image_latents is None:
152
+ masked_image = init_image * (mask_condition < 0.5)
153
+ else:
154
+ masked_image = masked_image_latents
155
+
156
+ mask, masked_image_latents = self.prepare_mask_latents(
157
+ mask_condition,
158
+ masked_image,
159
+ batch_size * num_images_per_prompt,
160
+ height,
161
+ width,
162
+ prompt_embeds.dtype,
163
+ device,
164
+ generator,
165
+ do_classifier_free_guidance,
166
+ )
167
+
168
+ # 8. Check that sizes of mask, masked image and latents match
169
+ if num_channels_unet == 9:
170
+ # default case for runwayml/stable-diffusion-inpainting
171
+ num_channels_mask = mask.shape[1]
172
+ num_channels_masked_image = masked_image_latents.shape[1]
173
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
174
+ raise ValueError(
175
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
176
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
177
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
178
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
179
+ " `pipeline.unet` or your `mask_image` or `image` input."
180
+ )
181
+ elif num_channels_unet != 4:
182
+ raise ValueError(
183
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
184
+ )
185
+
186
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
187
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
188
+
189
+ # 10. Denoising loop
190
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
191
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
192
+ for i, t in enumerate(timesteps):
193
+ # expand the latents if we are doing classifier free guidance
194
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
195
+
196
+ # concat latents, mask, masked_image_latents in the channel dimension
197
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
198
+
199
+ if num_channels_unet == 9:
200
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
201
+
202
+ # predict the noise residual
203
+ noise_pred = self.unet(
204
+ latent_model_input,
205
+ t,
206
+ encoder_hidden_states=prompt_embeds,
207
+ cross_attention_kwargs=cross_attention_kwargs,
208
+ return_dict=False,
209
+ )[0]
210
+
211
+ # perform guidance
212
+ if do_classifier_free_guidance:
213
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
214
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
215
+
216
+ # compute the previous noisy sample x_t -> x_t-1
217
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
218
+
219
+ if num_channels_unet == 4:
220
+ init_latents_proper = image_latents[:1]
221
+ init_mask = mask[:1]
222
+
223
+ if i < len(timesteps) - 1:
224
+ noise_timestep = timesteps[i + 1]
225
+ init_latents_proper = self.scheduler.add_noise(
226
+ init_latents_proper, noise, torch.tensor([noise_timestep])
227
+ )
228
+
229
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
230
+
231
+ # call the callback, if provided
232
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
233
+ progress_bar.update()
234
+ if callback is not None and i % callback_steps == 0:
235
+ callback(i, t, latents)
236
+
237
+ if not output_type == "latent":
238
+ condition_kwargs = {}
239
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
240
+ init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
241
+ init_image_condition = init_image.clone()
242
+ init_image = self._encode_vae_image(init_image, generator=generator)
243
+ mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
244
+ condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
245
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
246
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
247
+ else:
248
+ image = latents
249
+ has_nsfw_concept = None
250
+
251
+ if has_nsfw_concept is None:
252
+ do_denormalize = [True] * image.shape[0]
253
+ else:
254
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
255
+
256
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
257
+
258
+ # Offload all models
259
+ self.maybe_free_model_hooks()
260
+
261
+ if not return_dict:
262
+ return (image, has_nsfw_concept)
263
+
264
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
265
+
266
+ class CustomStableDiffusionXLInpaintPipeline(StableDiffusionXLInpaintPipeline):
267
+ @torch.no_grad()
268
+ def __call__(
269
+ self,
270
+ prompt: Union[str, List[str]] = None,
271
+ prompt_2: Optional[Union[str, List[str]]] = None,
272
+ image: PipelineImageInput = None,
273
+ mask_image: PipelineImageInput = None,
274
+ masked_image_latents: torch.FloatTensor = None,
275
+ height: Optional[int] = None,
276
+ width: Optional[int] = None,
277
+ strength: float = 0.9999,
278
+ num_inference_steps: int = 50,
279
+ denoising_start: Optional[float] = None,
280
+ denoising_end: Optional[float] = None,
281
+ guidance_scale: float = 7.5,
282
+ negative_prompt: Optional[Union[str, List[str]]] = None,
283
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
284
+ num_images_per_prompt: Optional[int] = 1,
285
+ eta: float = 0.0,
286
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
287
+ latents: Optional[torch.FloatTensor] = None,
288
+ prompt_embeds: Optional[torch.FloatTensor] = None,
289
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
290
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
291
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
292
+ output_type: Optional[str] = "pil",
293
+ return_dict: bool = True,
294
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
295
+ callback_steps: int = 1,
296
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
297
+ guidance_rescale: float = 0.0,
298
+ original_size: Tuple[int, int] = None,
299
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
300
+ target_size: Tuple[int, int] = None,
301
+ negative_original_size: Optional[Tuple[int, int]] = None,
302
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
303
+ negative_target_size: Optional[Tuple[int, int]] = None,
304
+ aesthetic_score: float = 6.0,
305
+ negative_aesthetic_score: float = 2.5,
306
+ newx: int = 0,
307
+ newy: int = 0,
308
+ newr: int = 256,
309
+ current_seed=0,
310
+ use_noise_moving=True,
311
+ ):
312
+ # OVERWRITE METHODS
313
+ self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionXLInpaintPipeline)
314
+ self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionXLInpaintPipeline)
315
+
316
+ # 0. Default height and width to unet
317
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
318
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
319
+
320
+ # 1. Check inputs
321
+ self.check_inputs(
322
+ prompt,
323
+ prompt_2,
324
+ height,
325
+ width,
326
+ strength,
327
+ callback_steps,
328
+ negative_prompt,
329
+ negative_prompt_2,
330
+ prompt_embeds,
331
+ negative_prompt_embeds,
332
+ )
333
+
334
+ # 2. Define call parameters
335
+ if prompt is not None and isinstance(prompt, str):
336
+ batch_size = 1
337
+ elif prompt is not None and isinstance(prompt, list):
338
+ batch_size = len(prompt)
339
+ else:
340
+ batch_size = prompt_embeds.shape[0]
341
+
342
+ device = self._execution_device
343
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
344
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
345
+ # corresponds to doing no classifier free guidance.
346
+ do_classifier_free_guidance = guidance_scale > 1.0
347
+
348
+ # 3. Encode input prompt
349
+ text_encoder_lora_scale = (
350
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
351
+ )
352
+
353
+ (
354
+ prompt_embeds,
355
+ negative_prompt_embeds,
356
+ pooled_prompt_embeds,
357
+ negative_pooled_prompt_embeds,
358
+ ) = self.encode_prompt(
359
+ prompt=prompt,
360
+ prompt_2=prompt_2,
361
+ device=device,
362
+ num_images_per_prompt=num_images_per_prompt,
363
+ do_classifier_free_guidance=do_classifier_free_guidance,
364
+ negative_prompt=negative_prompt,
365
+ negative_prompt_2=negative_prompt_2,
366
+ prompt_embeds=prompt_embeds,
367
+ negative_prompt_embeds=negative_prompt_embeds,
368
+ pooled_prompt_embeds=pooled_prompt_embeds,
369
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
370
+ lora_scale=text_encoder_lora_scale,
371
+ )
372
+
373
+ # 4. set timesteps
374
+ def denoising_value_valid(dnv):
375
+ return isinstance(dnv, float) and 0 < dnv < 1
376
+
377
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
378
+ timesteps, num_inference_steps = self.get_timesteps(
379
+ num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid(denoising_start) else None
380
+ )
381
+ # check that number of inference steps is not < 1 - as this doesn't make sense
382
+ if num_inference_steps < 1:
383
+ raise ValueError(
384
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
385
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
386
+ )
387
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
388
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
389
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
390
+ is_strength_max = strength == 1.0
391
+
392
+ # 5. Preprocess mask and image
393
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
394
+ init_image = init_image.to(dtype=torch.float32)
395
+
396
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
397
+
398
+ if masked_image_latents is not None:
399
+ masked_image = masked_image_latents
400
+ elif init_image.shape[1] == 4:
401
+ # if images are in latent space, we can't mask it
402
+ masked_image = None
403
+ else:
404
+ masked_image = init_image * (mask < 0.5)
405
+
406
+ # 6. Prepare latent variables
407
+ num_channels_latents = self.vae.config.latent_channels
408
+ num_channels_unet = self.unet.config.in_channels
409
+ return_image_latents = num_channels_unet == 4
410
+
411
+ # add_noise = True if denoising_start is None else False
412
+ latents_outputs = self.prepare_latents(
413
+ batch_size * num_images_per_prompt,
414
+ num_channels_latents,
415
+ height,
416
+ width,
417
+ prompt_embeds.dtype,
418
+ device,
419
+ generator,
420
+ latents,
421
+ image=init_image,
422
+ timestep=latent_timestep,
423
+ is_strength_max=is_strength_max,
424
+ return_noise=True,
425
+ return_image_latents=return_image_latents,
426
+ newx=newx,
427
+ newy=newy,
428
+ newr=newr,
429
+ current_seed=current_seed,
430
+ use_noise_moving=use_noise_moving,
431
+ )
432
+
433
+ if return_image_latents:
434
+ latents, noise, image_latents = latents_outputs
435
+ else:
436
+ latents, noise = latents_outputs
437
+
438
+ # 7. Prepare mask latent variables
439
+ mask, masked_image_latents = self.prepare_mask_latents(
440
+ mask,
441
+ masked_image,
442
+ batch_size * num_images_per_prompt,
443
+ height,
444
+ width,
445
+ prompt_embeds.dtype,
446
+ device,
447
+ generator,
448
+ do_classifier_free_guidance,
449
+ )
450
+
451
+ # 8. Check that sizes of mask, masked image and latents match
452
+ if num_channels_unet == 9:
453
+ # default case for runwayml/stable-diffusion-inpainting
454
+ num_channels_mask = mask.shape[1]
455
+ num_channels_masked_image = masked_image_latents.shape[1]
456
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
457
+ raise ValueError(
458
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
459
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
460
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
461
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
462
+ " `pipeline.unet` or your `mask_image` or `image` input."
463
+ )
464
+ elif num_channels_unet != 4:
465
+ raise ValueError(
466
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
467
+ )
468
+ # 8.1 Prepare extra step kwargs.
469
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
470
+
471
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
472
+ height, width = latents.shape[-2:]
473
+ height = height * self.vae_scale_factor
474
+ width = width * self.vae_scale_factor
475
+
476
+ original_size = original_size or (height, width)
477
+ target_size = target_size or (height, width)
478
+
479
+ # 10. Prepare added time ids & embeddings
480
+ if negative_original_size is None:
481
+ negative_original_size = original_size
482
+ if negative_target_size is None:
483
+ negative_target_size = target_size
484
+
485
+ add_text_embeds = pooled_prompt_embeds
486
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
487
+ original_size,
488
+ crops_coords_top_left,
489
+ target_size,
490
+ aesthetic_score,
491
+ negative_aesthetic_score,
492
+ negative_original_size,
493
+ negative_crops_coords_top_left,
494
+ negative_target_size,
495
+ dtype=prompt_embeds.dtype,
496
+ )
497
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
498
+
499
+ if do_classifier_free_guidance:
500
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
501
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
502
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
503
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
504
+
505
+ prompt_embeds = prompt_embeds.to(device)
506
+ add_text_embeds = add_text_embeds.to(device)
507
+ add_time_ids = add_time_ids.to(device)
508
+
509
+ # 11. Denoising loop
510
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
511
+
512
+ if (
513
+ denoising_end is not None
514
+ and denoising_start is not None
515
+ and denoising_value_valid(denoising_end)
516
+ and denoising_value_valid(denoising_start)
517
+ and denoising_start >= denoising_end
518
+ ):
519
+ raise ValueError(
520
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
521
+ + f" {denoising_end} when using type float."
522
+ )
523
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
524
+ discrete_timestep_cutoff = int(
525
+ round(
526
+ self.scheduler.config.num_train_timesteps
527
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
528
+ )
529
+ )
530
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
531
+ timesteps = timesteps[:num_inference_steps]
532
+
533
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
534
+ for i, t in enumerate(timesteps):
535
+ # expand the latents if we are doing classifier free guidance
536
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
537
+
538
+ # concat latents, mask, masked_image_latents in the channel dimension
539
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
540
+
541
+ if num_channels_unet == 9:
542
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
543
+
544
+ # predict the noise residual
545
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
546
+ noise_pred = self.unet(
547
+ latent_model_input,
548
+ t,
549
+ encoder_hidden_states=prompt_embeds,
550
+ cross_attention_kwargs=cross_attention_kwargs,
551
+ added_cond_kwargs=added_cond_kwargs,
552
+ return_dict=False,
553
+ )[0]
554
+
555
+ # perform guidance
556
+ if do_classifier_free_guidance:
557
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
558
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
559
+
560
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
561
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
562
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
563
+
564
+ # compute the previous noisy sample x_t -> x_t-1
565
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
566
+
567
+ if num_channels_unet == 4:
568
+ init_latents_proper = image_latents[:1]
569
+ init_mask = mask[:1]
570
+
571
+ if i < len(timesteps) - 1:
572
+ noise_timestep = timesteps[i + 1]
573
+ init_latents_proper = self.scheduler.add_noise(
574
+ init_latents_proper, noise, torch.tensor([noise_timestep])
575
+ )
576
+
577
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
578
+
579
+ # call the callback, if provided
580
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
581
+ progress_bar.update()
582
+ if callback is not None and i % callback_steps == 0:
583
+ callback(i, t, latents)
584
+
585
+ if not output_type == "latent":
586
+ # make sure the VAE is in float32 mode, as it overflows in float16
587
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
588
+
589
+ if needs_upcasting:
590
+ self.upcast_vae()
591
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
592
+
593
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
594
+
595
+ # cast back to fp16 if needed
596
+ if needs_upcasting:
597
+ self.vae.to(dtype=torch.float16)
598
+ else:
599
+ return StableDiffusionXLPipelineOutput(images=latents)
600
+
601
+ # apply watermark if available
602
+ if self.watermark is not None:
603
+ image = self.watermark.apply_watermark(image)
604
+
605
+ image = self.image_processor.postprocess(image, output_type=output_type)
606
+
607
+ # Offload all models
608
+ self.maybe_free_model_hooks()
609
+
610
+ if not return_dict:
611
+ return (image,)
612
+
613
+ return StableDiffusionXLPipelineOutput(images=image)
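The inpainting pipeline above keeps the stock SDXL denoising loop and differs mainly in the extra placement arguments (`newx`, `newy`, `newr`, `current_seed`, `use_noise_moving`) that are forwarded to the overridden `prepare_latents`. A minimal usage sketch follows; the checkpoint name, module path, prompt, and file names are illustrative assumptions, not part of this commit.

```python
# Hedged usage sketch for CustomStableDiffusionXLInpaintPipeline (paths and checkpoint are assumptions).
import torch
from PIL import Image
# module path is assumed; import the class from wherever this repo defines it
from relighting.pipeline_inpaintonly import CustomStableDiffusionXLInpaintPipeline

pipe = CustomStableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",   # assumed base checkpoint
    torch_dtype=torch.float16,
).to("cuda")

image = Image.open("input.png").convert("RGB").resize((1024, 1024))    # assumed input image
mask = Image.open("ball_mask.png").convert("L").resize((1024, 1024))   # assumed ball mask

out = pipe(
    prompt="a perfect mirrored reflective chrome ball sphere",          # illustrative prompt
    image=image,
    mask_image=mask,
    num_inference_steps=30,
    guidance_scale=5.0,
    newx=384, newy=384, newr=256,   # ball top-left corner and size in pixels, as read by custom_prepare_latents
    current_seed=0,                 # seed forwarded to the noise-expansion helper
    use_noise_moving=True,          # keep the ball's noise patch fixed regardless of placement
)
out.images[0].save("ball.png")
```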
relighting/pipeline_utils.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ import numpy as np
3
+ import itertools
4
+ from diffusers.utils.torch_utils import randn_tensor
5
+
6
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
7
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
8
+ """
9
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
10
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
11
+ """
12
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
13
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
14
+ # rescale the results from guidance (fixes overexposure)
15
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
16
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
17
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
18
+ return noise_cfg
19
+
20
+ def expand_noise(noise, shape, seed, device, dtype):
21
+ new_generator = torch.Generator().manual_seed(seed)
22
+ corner_shape = (shape[0], shape[1], shape[2] // 2, shape[3] // 2)
23
+ vert_border_shape = (shape[0], shape[1], shape[2], shape[3] // 2)
24
+ hori_border_shape = (shape[0], shape[1], shape[2] // 2, shape[3])
25
+
26
+ corners = [randn_tensor(corner_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(4)]
27
+ vert_borders = [randn_tensor(vert_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)]
28
+ hori_borders = [randn_tensor(hori_border_shape, generator=new_generator, device=device, dtype=dtype) for _ in range(2)]
29
+
30
+ # combine
31
+ big_shape = (shape[0], shape[1], shape[2] * 2, shape[3] * 2)
32
+ noise_template = randn_tensor(big_shape, generator=new_generator, device=device, dtype=dtype)
33
+
34
+ ticks = [(0, 0.25), (0.25, 0.75), (0.75, 1.0)]
35
+ grid = list(itertools.product(ticks, ticks))
36
+ noise_list = [
37
+ corners[0], hori_borders[0], corners[1],
38
+ vert_borders[0], noise, vert_borders[1],
39
+ corners[2], hori_borders[1], corners[3],
40
+ ]
41
+ for current_noise, ((x1, x2), (y1, y2)) in zip(noise_list, grid):
42
+ top_left = (int(x1 * big_shape[2]), int(y1 * big_shape[3]))
43
+ bottom_right = (int(x2 * big_shape[2]), int(y2 * big_shape[3]))
44
+ noise_template[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = current_noise
45
+
46
+ return noise_template
47
+
48
+ def custom_prepare_latents(
49
+ self,
50
+ batch_size,
51
+ num_channels_latents,
52
+ height,
53
+ width,
54
+ dtype,
55
+ device,
56
+ generator,
57
+ latents=None,
58
+ image=None,
59
+ timestep=None,
60
+ is_strength_max=True,
61
+ use_noise_moving=True,
62
+ return_noise=False,
63
+ return_image_latents=False,
64
+ newx=0,
65
+ newy=0,
66
+ newr=256,
67
+ current_seed=None,
68
+ ):
69
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
70
+ if isinstance(generator, list) and len(generator) != batch_size:
71
+ raise ValueError(
72
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
73
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
74
+ )
75
+
76
+ if (image is None or timestep is None) and not is_strength_max:
77
+ raise ValueError(
78
+ "Since strength < 1, the initial latents need to be initialised as a combination of image + noise. "
79
+ "However, either the image or the noise timestep has not been provided."
80
+ )
81
+
82
+ if image.shape[1] == 4:
83
+ image_latents = image.to(device=device, dtype=dtype)
84
+ elif return_image_latents or (latents is None and not is_strength_max):
85
+ image = image.to(device=device, dtype=dtype)
86
+ image_latents = self._encode_vae_image(image=image, generator=generator)
87
+
88
+ if latents is None and use_noise_moving:
89
+ # random big noise map
90
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
91
+ noise = expand_noise(noise, shape, seed=current_seed, device=device, dtype=dtype)
92
+
93
+ # ensure noise is the same regardless of inpainting location (top-left corner notation)
94
+ newys = [newy] if not isinstance(newy, list) else newy
95
+ newxs = [newx] if not isinstance(newx, list) else newx
96
+ big_noise = noise.clone()
97
+ prev_noise = None
98
+ for newy, newx in zip(newys, newxs):
99
+ # find patch location within big noise map
100
+ sy = big_noise.shape[2] // 4 + ((512 - 128) - newy) // self.vae_scale_factor
101
+ sx = big_noise.shape[3] // 4 + ((512 - 128) - newx) // self.vae_scale_factor
102
+
103
+ if prev_noise is not None:
104
+ new_noise = big_noise[:, :, sy:sy+shape[2], sx:sx+shape[3]]
105
+
106
+ ball_mask = torch.zeros(shape, device=device, dtype=bool)
107
+ top_left = (newy // self.vae_scale_factor, newx // self.vae_scale_factor)
108
+ bottom_right = (top_left[0] + newr // self.vae_scale_factor, top_left[1] + newr // self.vae_scale_factor) # fixed ball size r = 256
109
+ ball_mask[:, :, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = True
110
+
111
+ noise = prev_noise.clone()
112
+ noise[ball_mask] = new_noise[ball_mask]
113
+ else:
114
+ noise = big_noise[:, :, sy:sy+shape[2], sx:sx+shape[3]]
115
+
116
+ prev_noise = noise.clone()
117
+
118
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
119
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
120
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
121
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
122
+ elif latents is None:
123
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
124
+ latents = image_latents.to(device)
125
+ else:
126
+ noise = latents.to(device)
127
+ latents = noise * self.scheduler.init_noise_sigma
128
+
129
+ outputs = (latents,)
130
+
131
+ if return_noise:
132
+ outputs += (noise,)
133
+
134
+ if return_image_latents:
135
+ outputs += (image_latents,)
136
+
137
+ return outputs
138
+
139
+ def custom_prepare_mask_latents(
140
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
141
+ ):
142
+ # resize the mask to latents shape as we concatenate the mask to the latents
143
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
144
+ # and half precision
145
+ mask = torch.nn.functional.interpolate(
146
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
147
+ mode="bilinear", align_corners=False #PURE: We add this to avoid a sharp border around the ball
148
+ )
149
+ mask = mask.to(device=device, dtype=dtype)
150
+
151
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
152
+ if mask.shape[0] < batch_size:
153
+ if not batch_size % mask.shape[0] == 0:
154
+ raise ValueError(
155
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
156
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
157
+ " of masks that you pass is divisible by the total requested batch size."
158
+ )
159
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
160
+
161
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
162
+
163
+ masked_image_latents = None
164
+ if masked_image is not None:
165
+ masked_image = masked_image.to(device=device, dtype=dtype)
166
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
167
+ if masked_image_latents.shape[0] < batch_size:
168
+ if not batch_size % masked_image_latents.shape[0] == 0:
169
+ raise ValueError(
170
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
171
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
172
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
173
+ )
174
+ masked_image_latents = masked_image_latents.repeat(
175
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
176
+ )
177
+
178
+ masked_image_latents = (
179
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
180
+ )
181
+
182
+ # aligning device to prevent device errors when concating it with the latent model input
183
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
184
+
185
+ return mask, masked_image_latents
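The key trick in `expand_noise` is that the original noise map occupies the central 50% of a canvas twice as large, with freshly seeded noise filling the surrounding borders and corners; `custom_prepare_latents` then crops a window out of this canvas based on `newx`/`newy`, so the noise under the ball stays identical no matter where the ball is placed. A small sanity-check sketch of that layout (shapes are illustrative):

```python
# Sanity check of expand_noise's 3x3 tiling: the input noise sits in the centre of the doubled canvas.
import torch
from relighting.pipeline_utils import expand_noise

shape = (1, 4, 128, 128)                    # e.g. SDXL latent shape for a 1024x1024 image
noise = torch.randn(shape)
big = expand_noise(noise, shape, seed=0, device=torch.device("cpu"), dtype=torch.float32)

assert big.shape == (1, 4, 256, 256)        # canvas is 2x larger in both spatial dims
h, w = shape[2], shape[3]
centre = big[:, :, h // 2 : h // 2 + h, w // 2 : w // 2 + w]
assert torch.equal(centre, noise)           # ticks (0.25, 0.75) place the original noise in the middle
```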
relighting/pipeline_xl.py ADDED
@@ -0,0 +1,482 @@
1
+ import torch
2
+ from typing import List, Union, Dict, Any, Callable, Optional, Tuple
3
+
4
+ from diffusers.utils.torch_utils import is_compiled_module
5
+ from diffusers.models import ControlNetModel
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from diffusers import StableDiffusionXLControlNetInpaintPipeline
8
+ from diffusers.image_processor import PipelineImageInput
9
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
10
+ from relighting.pipeline_utils import custom_prepare_latents, custom_prepare_mask_latents, rescale_noise_cfg
11
+
12
+ class CustomStableDiffusionXLControlNetInpaintPipeline(StableDiffusionXLControlNetInpaintPipeline):
13
+ @torch.no_grad()
14
+ def __call__(
15
+ self,
16
+ prompt: Union[str, List[str]] = None,
17
+ prompt_2: Optional[Union[str, List[str]]] = None,
18
+ image: PipelineImageInput = None,
19
+ mask_image: PipelineImageInput = None,
20
+ control_image: Union[
21
+ PipelineImageInput,
22
+ List[PipelineImageInput],
23
+ ] = None,
24
+ height: Optional[int] = None,
25
+ width: Optional[int] = None,
26
+ strength: float = 0.9999,
27
+ num_inference_steps: int = 50,
28
+ denoising_start: Optional[float] = None,
29
+ denoising_end: Optional[float] = None,
30
+ guidance_scale: float = 5.0,
31
+ negative_prompt: Optional[Union[str, List[str]]] = None,
32
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
33
+ num_images_per_prompt: Optional[int] = 1,
34
+ eta: float = 0.0,
35
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
36
+ latents: Optional[torch.FloatTensor] = None,
37
+ prompt_embeds: Optional[torch.FloatTensor] = None,
38
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
39
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
40
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
41
+ output_type: Optional[str] = "pil",
42
+ return_dict: bool = True,
43
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
44
+ callback_steps: int = 1,
45
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
46
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
47
+ guess_mode: bool = False,
48
+ control_guidance_start: Union[float, List[float]] = 0.0,
49
+ control_guidance_end: Union[float, List[float]] = 1.0,
50
+ guidance_rescale: float = 0.0,
51
+ original_size: Tuple[int, int] = None,
52
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
53
+ target_size: Tuple[int, int] = None,
54
+ aesthetic_score: float = 6.0,
55
+ negative_aesthetic_score: float = 2.5,
56
+ newx: int = 0,
57
+ newy: int = 0,
58
+ newr: int = 256,
59
+ current_seed=0,
60
+ use_noise_moving=True,
61
+ ):
62
+ # OVERWRITE METHODS
63
+ self.prepare_mask_latents = custom_prepare_mask_latents.__get__(self, CustomStableDiffusionXLControlNetInpaintPipeline)
64
+ self.prepare_latents = custom_prepare_latents.__get__(self, CustomStableDiffusionXLControlNetInpaintPipeline)
65
+
66
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
67
+
68
+ # align format for control guidance
69
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
70
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
71
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
72
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
73
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
74
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
75
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
76
+ control_guidance_end
77
+ ]
78
+
79
+ # # 0.0 Default height and width to unet
80
+ # height = height or self.unet.config.sample_size * self.vae_scale_factor
81
+ # width = width or self.unet.config.sample_size * self.vae_scale_factor
82
+
83
+ # 0.1 align format for control guidance
84
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
85
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
86
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
87
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
88
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
89
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
90
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
91
+ control_guidance_end
92
+ ]
93
+
94
+ # 1. Check inputs
95
+ self.check_inputs(
96
+ prompt,
97
+ prompt_2,
98
+ control_image,
99
+ strength,
100
+ num_inference_steps,
101
+ callback_steps,
102
+ negative_prompt,
103
+ negative_prompt_2,
104
+ prompt_embeds,
105
+ negative_prompt_embeds,
106
+ pooled_prompt_embeds,
107
+ negative_pooled_prompt_embeds,
108
+ controlnet_conditioning_scale,
109
+ control_guidance_start,
110
+ control_guidance_end,
111
+ )
112
+
113
+ # 2. Define call parameters
114
+ if prompt is not None and isinstance(prompt, str):
115
+ batch_size = 1
116
+ elif prompt is not None and isinstance(prompt, list):
117
+ batch_size = len(prompt)
118
+ else:
119
+ batch_size = prompt_embeds.shape[0]
120
+
121
+ device = self._execution_device
122
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
123
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
124
+ # corresponds to doing no classifier free guidance.
125
+ do_classifier_free_guidance = guidance_scale > 1.0
126
+
127
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
128
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
129
+
130
+ # 3. Encode input prompt
131
+ text_encoder_lora_scale = (
132
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
133
+ )
134
+
135
+ (
136
+ prompt_embeds,
137
+ negative_prompt_embeds,
138
+ pooled_prompt_embeds,
139
+ negative_pooled_prompt_embeds,
140
+ ) = self.encode_prompt(
141
+ prompt=prompt,
142
+ prompt_2=prompt_2,
143
+ device=device,
144
+ num_images_per_prompt=num_images_per_prompt,
145
+ do_classifier_free_guidance=do_classifier_free_guidance,
146
+ negative_prompt=negative_prompt,
147
+ negative_prompt_2=negative_prompt_2,
148
+ prompt_embeds=prompt_embeds,
149
+ negative_prompt_embeds=negative_prompt_embeds,
150
+ pooled_prompt_embeds=pooled_prompt_embeds,
151
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
152
+ lora_scale=text_encoder_lora_scale,
153
+ )
154
+
155
+ # 4. set timesteps
156
+ def denoising_value_valid(dnv):
157
+ return isinstance(dnv, float) and 0 < dnv < 1
158
+
159
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
160
+ timesteps, num_inference_steps = self.get_timesteps(
161
+ num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid(denoising_start) else None
162
+ )
163
+ # check that number of inference steps is not < 1 - as this doesn't make sense
164
+ if num_inference_steps < 1:
165
+ raise ValueError(
166
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
167
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
168
+ )
169
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
170
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
171
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
172
+ is_strength_max = strength == 1.0
173
+
174
+ # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
175
+ # 5.1 Prepare init image
176
+ init_image = self.image_processor.preprocess(image, height=height, width=width)
177
+ init_image = init_image.to(dtype=torch.float32)
178
+
179
+ # 5.2 Prepare control images
180
+ if isinstance(controlnet, ControlNetModel):
181
+ control_image = self.prepare_control_image(
182
+ image=control_image,
183
+ width=width,
184
+ height=height,
185
+ batch_size=batch_size * num_images_per_prompt,
186
+ num_images_per_prompt=num_images_per_prompt,
187
+ device=device,
188
+ dtype=controlnet.dtype,
189
+ do_classifier_free_guidance=do_classifier_free_guidance,
190
+ guess_mode=guess_mode,
191
+ )
192
+ elif isinstance(controlnet, MultiControlNetModel):
193
+ control_images = []
194
+
195
+ for control_image_ in control_image:
196
+ control_image_ = self.prepare_control_image(
197
+ image=control_image_,
198
+ width=width,
199
+ height=height,
200
+ batch_size=batch_size * num_images_per_prompt,
201
+ num_images_per_prompt=num_images_per_prompt,
202
+ device=device,
203
+ dtype=controlnet.dtype,
204
+ do_classifier_free_guidance=do_classifier_free_guidance,
205
+ guess_mode=guess_mode,
206
+ )
207
+
208
+ control_images.append(control_image_)
209
+
210
+ control_image = control_images
211
+ else:
212
+ raise ValueError(f"{controlnet.__class__} is not supported.")
213
+
214
+ # 5.3 Prepare mask
215
+ mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
216
+
217
+ masked_image = init_image * (mask < 0.5)
218
+ _, _, height, width = init_image.shape
219
+
220
+ # 6. Prepare latent variables
221
+ num_channels_latents = self.vae.config.latent_channels
222
+ num_channels_unet = self.unet.config.in_channels
223
+ return_image_latents = num_channels_unet == 4
224
+
225
+ add_noise = True if denoising_start is None else False
226
+ latents_outputs = self.prepare_latents(
227
+ batch_size * num_images_per_prompt,
228
+ num_channels_latents,
229
+ height,
230
+ width,
231
+ prompt_embeds.dtype,
232
+ device,
233
+ generator,
234
+ latents,
235
+ image=init_image,
236
+ timestep=latent_timestep,
237
+ is_strength_max=is_strength_max,
238
+ return_noise=True,
239
+ return_image_latents=return_image_latents,
240
+ newx=newx,
241
+ newy=newy,
242
+ newr=newr,
243
+ current_seed=current_seed,
244
+ use_noise_moving=use_noise_moving,
245
+ )
246
+
247
+ if return_image_latents:
248
+ latents, noise, image_latents = latents_outputs
249
+ else:
250
+ latents, noise = latents_outputs
251
+
252
+ # 7. Prepare mask latent variables
253
+ mask, masked_image_latents = self.prepare_mask_latents(
254
+ mask,
255
+ masked_image,
256
+ batch_size * num_images_per_prompt,
257
+ height,
258
+ width,
259
+ prompt_embeds.dtype,
260
+ device,
261
+ generator,
262
+ do_classifier_free_guidance,
263
+ )
264
+
265
+ # 8. Check that sizes of mask, masked image and latents match
266
+ if num_channels_unet == 9:
267
+ # default case for runwayml/stable-diffusion-inpainting
268
+ num_channels_mask = mask.shape[1]
269
+ num_channels_masked_image = masked_image_latents.shape[1]
270
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
271
+ raise ValueError(
272
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
273
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
274
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
275
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
276
+ " `pipeline.unet` or your `mask_image` or `image` input."
277
+ )
278
+ elif num_channels_unet != 4:
279
+ raise ValueError(
280
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
281
+ )
282
+ # 8.1 Prepare extra step kwargs.
283
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
284
+
285
+ # 8.2 Create tensor stating which controlnets to keep
286
+ controlnet_keep = []
287
+ for i in range(len(timesteps)):
288
+ keeps = [
289
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
290
+ for s, e in zip(control_guidance_start, control_guidance_end)
291
+ ]
292
+ if isinstance(self.controlnet, MultiControlNetModel):
293
+ controlnet_keep.append(keeps)
294
+ else:
295
+ controlnet_keep.append(keeps[0])
296
+
297
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
298
+ height, width = latents.shape[-2:]
299
+ height = height * self.vae_scale_factor
300
+ width = width * self.vae_scale_factor
301
+
302
+ original_size = original_size or (height, width)
303
+ target_size = target_size or (height, width)
304
+
305
+ # 10. Prepare added time ids & embeddings
306
+ add_text_embeds = pooled_prompt_embeds
307
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
308
+ original_size,
309
+ crops_coords_top_left,
310
+ target_size,
311
+ aesthetic_score,
312
+ negative_aesthetic_score,
313
+ dtype=prompt_embeds.dtype,
314
+ )
315
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
316
+
317
+ if do_classifier_free_guidance:
318
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
319
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
320
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
321
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
322
+
323
+ prompt_embeds = prompt_embeds.to(device)
324
+ add_text_embeds = add_text_embeds.to(device)
325
+ add_time_ids = add_time_ids.to(device)
326
+
327
+ # 11. Denoising loop
328
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
329
+
330
+ if (
331
+ denoising_end is not None
332
+ and denoising_start is not None
333
+ and denoising_value_valid(denoising_end)
334
+ and denoising_value_valid(denoising_start)
335
+ and denoising_start >= denoising_end
336
+ ):
337
+ raise ValueError(
338
+ f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
339
+ + f" {denoising_end} when using type float."
340
+ )
341
+ elif denoising_end is not None and denoising_value_valid(denoising_end):
342
+ discrete_timestep_cutoff = int(
343
+ round(
344
+ self.scheduler.config.num_train_timesteps
345
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
346
+ )
347
+ )
348
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
349
+ timesteps = timesteps[:num_inference_steps]
350
+
351
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
352
+ for i, t in enumerate(timesteps):
353
+ # expand the latents if we are doing classifier free guidance
354
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
355
+
356
+ # concat latents, mask, masked_image_latents in the channel dimension
357
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
358
+
359
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
360
+
361
+ # controlnet(s) inference
362
+ if guess_mode and do_classifier_free_guidance:
363
+ # Infer ControlNet only for the conditional batch.
364
+ control_model_input = latents
365
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
366
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
367
+ controlnet_added_cond_kwargs = {
368
+ "text_embeds": add_text_embeds.chunk(2)[1],
369
+ "time_ids": add_time_ids.chunk(2)[1],
370
+ }
371
+ else:
372
+ control_model_input = latent_model_input
373
+ controlnet_prompt_embeds = prompt_embeds
374
+ controlnet_added_cond_kwargs = added_cond_kwargs
375
+
376
+ if isinstance(controlnet_keep[i], list):
377
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
378
+ else:
379
+ controlnet_cond_scale = controlnet_conditioning_scale
380
+ if isinstance(controlnet_cond_scale, list):
381
+ controlnet_cond_scale = controlnet_cond_scale[0]
382
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
383
+
384
+ # # Resize control_image to match the size of the input to the controlnet
385
+ # if control_image.shape[-2:] != control_model_input.shape[-2:]:
386
+ # control_image = F.interpolate(control_image, size=control_model_input.shape[-2:], mode="bilinear", align_corners=False)
387
+
388
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
389
+ control_model_input,
390
+ t,
391
+ encoder_hidden_states=controlnet_prompt_embeds,
392
+ controlnet_cond=control_image,
393
+ conditioning_scale=cond_scale,
394
+ guess_mode=guess_mode,
395
+ added_cond_kwargs=controlnet_added_cond_kwargs,
396
+ return_dict=False,
397
+ )
398
+
399
+ if guess_mode and do_classifier_free_guidance:
400
+ # Inferred ControlNet only for the conditional batch.
401
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
402
+ # add 0 to the unconditional batch to keep it unchanged.
403
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
404
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
405
+
406
+ if num_channels_unet == 9:
407
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
408
+
409
+ # predict the noise residual
410
+ noise_pred = self.unet(
411
+ latent_model_input,
412
+ t,
413
+ encoder_hidden_states=prompt_embeds,
414
+ cross_attention_kwargs=cross_attention_kwargs,
415
+ down_block_additional_residuals=down_block_res_samples,
416
+ mid_block_additional_residual=mid_block_res_sample,
417
+ added_cond_kwargs=added_cond_kwargs,
418
+ return_dict=False,
419
+ )[0]
420
+
421
+ # perform guidance
422
+ if do_classifier_free_guidance:
423
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
424
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
425
+
426
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
427
+ print("rescale: ", guidance_rescale)
428
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
429
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
430
+
431
+ # compute the previous noisy sample x_t -> x_t-1
432
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
433
+
434
+ if num_channels_unet == 4:
435
+ init_latents_proper = image_latents[:1]
436
+ init_mask = mask[:1]
437
+
438
+ if i < len(timesteps) - 1:
439
+ noise_timestep = timesteps[i + 1]
440
+ init_latents_proper = self.scheduler.add_noise(
441
+ init_latents_proper, noise, torch.tensor([noise_timestep])
442
+ )
443
+
444
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
445
+
446
+ # call the callback, if provided
447
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
448
+ progress_bar.update()
449
+ if callback is not None and i % callback_steps == 0:
450
+ callback(i, t, latents)
451
+
452
+ # make sure the VAE is in float32 mode, as it overflows in float16
453
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
454
+ self.upcast_vae()
455
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
456
+
457
+ # If we do sequential model offloading, let's offload unet and controlnet
458
+ # manually for max memory savings
459
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
460
+ self.unet.to("cpu")
461
+ self.controlnet.to("cpu")
462
+ torch.cuda.empty_cache()
463
+
464
+ if not output_type == "latent":
465
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
466
+ else:
467
+ return StableDiffusionXLPipelineOutput(images=latents)
468
+
469
+ # apply watermark if available
470
+ if self.watermark is not None:
471
+ image = self.watermark.apply_watermark(image)
472
+
473
+ image = self.image_processor.postprocess(image, output_type=output_type)
474
+
475
+ # Offload last model to CPU
476
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
477
+ self.final_offload_hook.offload()
478
+
479
+ if not return_dict:
480
+ return (image,)
481
+
482
+ return StableDiffusionXLPipelineOutput(images=image)
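Both custom pipelines swap in the helpers from `relighting/pipeline_utils.py` at call time via `custom_prepare_latents.__get__(self, <PipelineClass>)`. This is plain Python descriptor behaviour rather than anything diffusers-specific: calling `__get__` on a function binds it to an instance so the result can be assigned over an existing method of that one object. A generic illustration (all names here are made up for the example):

```python
# Generic illustration of the __get__ method-override trick used in the pipelines above.
class Greeter:
    name = "world"

    def greet(self):
        return f"hello {self.name}"

def loud_greet(self):
    # free function written like a method; `self` becomes the instance it is bound to
    return f"HELLO {self.name.upper()}!"

g = Greeter()
# functions are descriptors: func.__get__(instance, owner) returns a bound method,
# which can then be assigned over the existing method on that single instance
g.greet = loud_greet.__get__(g, Greeter)
print(g.greet())  # -> "HELLO WORLD!"
```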
relighting/tonemapper.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+
3
+ class TonemapHDR(object):
4
+ """
5
+ Tonemap an HDR image globally. First, we find alpha such that the chosen `percentile` of the
6
+ gamma-compressed input maps to max_mapping. Then, we compute I_out = alpha * I_in ^ (1/gamma)
7
+ input : np.ndarray image : [H, W, C]
8
+ output : np.ndarray image : [H, W, C]
9
+ """
10
+
11
+ def __init__(self, gamma=2.4, percentile=50, max_mapping=0.5):
12
+ self.gamma = gamma
13
+ self.percentile = percentile
14
+ self.max_mapping = max_mapping # the value that alpha maps the chosen percentile of the input to
15
+
16
+ def __call__(self, numpy_img, clip=True, alpha=None, gamma=True):
17
+ if gamma:
18
+ power_numpy_img = np.power(numpy_img, 1 / self.gamma)
19
+ else:
20
+ power_numpy_img = numpy_img
21
+ non_zero = power_numpy_img > 0
22
+ if non_zero.any():
23
+ r_percentile = np.percentile(power_numpy_img[non_zero], self.percentile)
24
+ else:
25
+ r_percentile = np.percentile(power_numpy_img, self.percentile)
26
+ if alpha is None:
27
+ alpha = self.max_mapping / (r_percentile + 1e-10)
28
+ tonemapped_img = np.multiply(alpha, power_numpy_img)
29
+
30
+ if clip:
31
+ tonemapped_img_clip = np.clip(tonemapped_img, 0, 1)
+ else:
+ tonemapped_img_clip = tonemapped_img # avoid returning an unbound name when clip=False
32
+
33
+ return tonemapped_img_clip.astype('float32'), alpha, tonemapped_img
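In short, `TonemapHDR` gamma-compresses the input, picks `alpha` so the chosen percentile of the compressed image lands at `max_mapping`, and returns both the clipped LDR image and the `alpha` it used, so the same scaling can be reused on related images. A hedged usage sketch with synthetic data (real inputs would typically be HDR arrays loaded from .exr files):

```python
# Usage sketch for TonemapHDR on synthetic HDR data (values and shapes are illustrative).
import numpy as np
from relighting.tonemapper import TonemapHDR

hdr = np.random.rand(256, 256, 3).astype(np.float32) * 8.0   # fake HDR radiance
tonemap = TonemapHDR(gamma=2.4, percentile=50, max_mapping=0.5)

ldr, alpha, unclipped = tonemap(hdr)   # ldr is clipped to [0, 1]; alpha is the scale that was applied
print(ldr.dtype, float(ldr.max()), alpha)

# Reuse the same alpha to tonemap a second image consistently with the first
hdr2 = np.random.rand(256, 256, 3).astype(np.float32) * 8.0
ldr2, _, _ = tonemap(hdr2, alpha=alpha)
```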
relighting/utils.py ADDED
@@ -0,0 +1,56 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+ from PIL import Image
5
+ import hashlib
6
+
7
+ def str2bool(v):
8
+ """
9
+ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
10
+ """
11
+ if isinstance(v, bool):
12
+ return v
13
+ if v.lower() in ("yes", "true", "t", "y", "1"):
14
+ return True
15
+ elif v.lower() in ("no", "false", "f", "n", "0"):
16
+ return False
17
+ else:
18
+ raise argparse.ArgumentTypeError("boolean value expected")
19
+
20
+ def add_dict_to_argparser(parser, default_dict):
21
+ for k, v in default_dict.items():
22
+ v_type = type(v)
23
+ if v is None:
24
+ v_type = str
25
+ elif isinstance(v, bool):
26
+ v_type = str2bool
27
+ parser.add_argument(f"--{k}", default=v, type=v_type)
28
+
29
+ def args_to_dict(args, keys):
30
+ return {k: getattr(args, k) for k in keys}
31
+
32
+ def save_result(
33
+ image, image_path,
34
+ mask=None, mask_path=None,
35
+ normal=None, normal_path=None,
36
+ ):
37
+ assert isinstance(image, Image.Image)
38
+ os.makedirs(Path(image_path).parent, exist_ok=True)
39
+ image.save(image_path)
40
+
41
+ if (mask is not None) and (mask_path is not None):
42
+ assert isinstance(mask, Image.Image)
43
+ os.makedirs(Path(mask_path).parent, exist_ok=True)
44
+ mask.save(mask_path)
45
+
46
+ if (normal is not None) and (normal_path is not None):
47
+ assert isinstance(normal, Image.Image)
48
+ os.makedirs(Path(normal_path).parent, exist_ok=True)
49
+ normal.save(normal_path)
50
+
51
+ def name2hash(name: str):
52
+ """
53
+ @see https://stackoverflow.com/questions/16008670/how-to-hash-a-string-into-8-digits
54
+ """
55
+ hash_number = int(hashlib.sha1(name.encode("utf-8")).hexdigest(), 16) % (10 ** 8)
56
+ return hash_number
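`add_dict_to_argparser` turns a dict of defaults into CLI flags, inferring each flag's type from its default (`str2bool` for booleans, `str` for `None`), and `args_to_dict` collects the parsed values back into a dict. A small sketch of how these helpers fit together (the defaults shown are illustrative, not the repo's actual configuration):

```python
# Sketch wiring add_dict_to_argparser / args_to_dict together (defaults are made up for the example).
import argparse
from relighting.utils import add_dict_to_argparser, args_to_dict

defaults = {"img_width": 1024, "img_height": 1024, "use_controlnet": True, "output_dir": None}

parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)

args = parser.parse_args(["--use_controlnet", "false", "--output_dir", "output/"])
print(args_to_dict(args, defaults.keys()))
# {'img_width': 1024, 'img_height': 1024, 'use_controlnet': False, 'output_dir': 'output/'}
```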
requirements.txt ADDED
@@ -0,0 +1,25 @@
1
+ # utility
2
+ tqdm==4.66.1
3
+ scikit-image==0.21.0
4
+ imageio==2.31.1
5
+ Pillow==10.2.0
6
+ numpy==1.24.1
7
+ natsort==8.4.0
8
+
9
+ # EXR handling
10
+ skylibs==0.7.4
11
+ OpenEXR==1.3.9
12
+
13
+ # We install PyTorch from pip rather than conda because the conda setup is messy
14
+ --extra-index-url https://download.pytorch.org/whl/cu118
15
+ torch==2.0.1+cu118
16
+ torchvision==0.15.2+cu118
17
+ torchaudio==2.0.2+cu118
18
+
19
+ # Diffusers dependencies
20
+ accelerate==0.21.0
21
+ datasets==2.13.1
22
+ diffusers==0.21.0
23
+ transformers==4.36.0
24
+ xformers==0.0.20
25
+ huggingface_hub==0.19.4
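Because the custom pipelines above copy and override diffusers internals (`prepare_latents`, `prepare_mask_latents`, the SDXL `__call__` bodies), they are sensitive to the exact pinned versions. An optional runtime check (not part of this commit) could look like:

```python
# Optional sanity check that the installed versions match the pins above.
import diffusers
import transformers

assert diffusers.__version__ == "0.21.0", diffusers.__version__
assert transformers.__version__ == "4.36.0", transformers.__version__
```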