YaohuiW committed · verified
Commit c42db24 · Parent: 49ee0a3

Upload 19 files
gradio_tabs/__init__.py ADDED
File without changes
gradio_tabs/animation.py ADDED
@@ -0,0 +1,295 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import torchvision
5
+ from PIL import Image
6
+ import numpy as np
7
+ import imageio
8
+
9
+ extensions_dir = "./torch_extension/"
10
+ os.environ["TORCH_EXTENSIONS_DIR"] = extensions_dir
11
+
12
+ from networks.generator import Generator
13
+
14
+ device = torch.device("cuda")
15
+ ckpt_path = './models/lia-x.pt'
16
+ gen = Generator(size=512, motion_dim=40, scale=2).to(device)
17
+ gen.load_state_dict(torch.load(ckpt_path, weights_only=False))
18
+ gen.eval()
19
+
20
+ output_dir = "./res_gradio"
21
+ os.makedirs(output_dir, exist_ok=True)
22
+
23
+ # labels
24
+ labels_k = [
25
+ 'yaw1',
26
+ 'yaw2',
27
+ 'pitch',
28
+ 'roll1',
29
+ 'roll2',
30
+ 'neck',
31
+
32
+ 'pout',
33
+ 'open->close',
34
+ '"O" Mouth',
35
+ 'apple cheek',
36
+
37
+ 'close->open',
38
+ 'eyebrows',
39
+ 'eyeballs1',
40
+ 'eyeballs2',
41
+
42
+ ]
43
+
44
+ labels_v = [
45
+ 37, 39, 28, 15, 33, 31,
46
+ 6, 25, 16, 19,
47
+ 13, 24, 17, 26
48
+ ]
49
+
50
+
51
+ def load_image(img, size):
52
+ # img = Image.open(filename).convert('RGB')
53
+ if not isinstance(img, np.ndarray):
54
+ img = Image.open(img).convert('RGB')
55
+ img = img.resize((size, size))
56
+ img = np.asarray(img)
57
+ img = np.transpose(img, (2, 0, 1)) # HWC -> CHW
58
+
59
+ return img / 255.0
60
+
61
+
62
+ def img_preprocessing(img_path, size):
63
+ img = load_image(img_path, size) # [0, 1]
64
+ img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
65
+ imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
66
+
67
+ return imgs_norm
68
+
69
+
70
+ def resize(img, size):
71
+ transform = torchvision.transforms.Compose([
72
+ torchvision.transforms.Resize(size, antialias=True),
73
+ torchvision.transforms.CenterCrop(size)
74
+ ])
75
+
76
+ return transform(img)
77
+
78
+
79
+ def vid_preprocessing(vid_path, size):
80
+ vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
81
+ vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0) # btchw
82
+ fps = vid_dict[2]['video_fps']
83
+ vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
84
+
85
+ vid_norm = torch.cat([
86
+ resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
87
+ ], dim=1)
88
+
89
+ return vid_norm, fps
90
+
91
+
92
+ def img_denorm(img):
93
+ img = img.clamp(-1, 1).cpu()
94
+ img = (img - img.min()) / (img.max() - img.min())
95
+
96
+ return img
97
+
98
+
99
+ def vid_denorm(vid):
100
+ vid = vid.clamp(-1, 1).cpu()
101
+ vid = (vid - vid.min()) / (vid.max() - vid.min())
102
+
103
+ return vid
104
+
105
+
106
+ def img_postprocessing(image, output_path=output_dir + "/output_img.png"):
107
+
108
+ image = image.permute(0, 2, 3, 1)
109
+ edited_image = img_denorm(image)
110
+ img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
111
+ imageio.imwrite(output_path, img_output, quality=6)
112
+
113
+ return output_path
114
+
115
+
116
+ def vid_postprocessing(video, fps, output_path=output_dir + "/output_vid.mp4"):
117
+ # video: BCTHW
118
+
119
+ vid = video.permute(0, 2, 3, 4, 1) # B T H W C
120
+ vid_np = (vid_denorm(vid[0]).numpy() * 255).astype('uint8')
121
+ imageio.mimwrite(output_path, vid_np, fps=fps, codec='libx264', quality=10)
122
+
123
+ return output_path
124
+
125
+
126
+ @torch.no_grad()
127
+ def edit_media(image, *selected_s):
128
+
129
+ image_tensor = img_preprocessing(image, 512)
130
+ image_tensor = image_tensor.to(device)
131
+
132
+ edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
133
+
134
+ # de-norm
135
+ edited_image = img_postprocessing(edited_image_tensor)
136
+
137
+ return edited_image
138
+
139
+
140
+ @torch.no_grad()
141
+ def animate_media(image, video, *selected_s):
142
+
143
+ image_tensor = img_preprocessing(image, 512)
144
+ vid_target_tensor, fps = vid_preprocessing(video, 512)
145
+ image_tensor = image_tensor.to(device)
146
+ video_target_tensor = vid_target_tensor.to(device)
147
+
148
+ animated_video = gen.animate(image_tensor, video_target_tensor, labels_v, selected_s)
149
+
150
+ # postprocessing
151
+ animated_video = vid_postprocessing(animated_video, fps)
152
+
153
+ return animated_video
154
+
155
+
156
+ def clear_media():
157
+ return None, None, *([0] * len(labels_k))
158
+
159
+
160
+ image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
161
+ video_output = gr.Video(label="Output Video", width=512)
162
+
163
+
164
+ def animation():
165
+ with gr.Tab("Animation & Image Editing"):
166
+
167
+ inputs_s = []
168
+
169
+ with gr.Row():
170
+ with gr.Column(scale=1):
171
+ with gr.Row():
172
+ with gr.Accordion(open=True, label="Source Image"):
173
+ image_input = gr.Image(type="filepath", width=512) # , height=550)
174
+ gr.Examples(
175
+ examples=[
176
+ ["./data/source/macron.png"],
177
+ ["./data/source/einstein.png"],
178
+ ["./data/source/taylor.png"],
179
+ ["./data/source/portrait1.png"],
180
+ ["./data/source/portrait2.png"],
181
+ ["./data/source/portrait3.png"],
182
+ ],
183
+ inputs=[image_input],
184
+ cache_examples=False,
185
+ visible=True,
186
+ )
187
+
188
+ with gr.Accordion(open=True, label="Driving Video"):
189
+ video_input = gr.Video(width=512) # , height=550)
190
+ gr.Examples(
191
+ examples=[
192
+ ["./data/driving/driving1.mp4"],
193
+ ["./data/driving/driving2.mp4"],
194
+ ["./data/driving/driving4.mp4"],
195
+ #["./data/driving/driving5.mp4"],
196
+ ["./data/driving/driving6.mp4"],
197
+ #["./data/driving/driving7.mp4"],
198
+ ["./data/driving/driving8.mov"],
199
+ ],
200
+ inputs=[video_input],
201
+ cache_examples=False,
202
+ visible=True,
203
+ )
204
+
205
+ with gr.Row():
206
+ with gr.Column(scale=1):
207
+ with gr.Row(): # Buttons now within a single Row
208
+ edit_btn = gr.Button("Edit")
209
+ clear_btn = gr.Button("Clear")
210
+ with gr.Row():
211
+ animate_btn = gr.Button("Animate")
212
+
213
+
214
+
215
+ with gr.Column(scale=1):
216
+
217
+ with gr.Row():
218
+ with gr.Accordion(open=True, label="Edited Source Image"):
219
+ image_output.render()
220
+
221
+ with gr.Accordion(open=True, label="Animated Video"):
222
+ video_output.render()
223
+
224
+ with gr.Accordion("Control Panel", open=True):
225
+ with gr.Tab("Head"):
226
+ with gr.Row():
227
+ for k in labels_k[:3]:
228
+ slider = gr.Slider(minimum=-1.0, maximum=0.5, value=0, label=k)
229
+ inputs_s.append(slider)
230
+ with gr.Row():
231
+ for k in labels_k[3:6]:
232
+ slider = gr.Slider(minimum=-0.5, maximum=0.5, value=0, label=k)
233
+ inputs_s.append(slider)
234
+
235
+ with gr.Tab("Mouth"):
236
+ with gr.Row():
237
+ for k in labels_k[6:8]:
238
+ slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
239
+ inputs_s.append(slider)
240
+ with gr.Row():
241
+ for k in labels_k[8:10]:
242
+ slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
243
+ inputs_s.append(slider)
244
+
245
+ with gr.Tab("Eyes"):
246
+ with gr.Row():
247
+ for k in labels_k[10:12]:
248
+ slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
249
+ inputs_s.append(slider)
250
+ with gr.Row():
251
+ for k in labels_k[12:14]:
252
+ slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k)
253
+ inputs_s.append(slider)
254
+
255
+
256
+ edit_btn.click(
257
+ fn=edit_media,
258
+ inputs=[image_input] + inputs_s,
259
+ outputs=[image_output],
260
+ show_progress=True
261
+ )
262
+
263
+ animate_btn.click(
264
+ fn=animate_media,
265
+ inputs=[image_input, video_input] + inputs_s, # [image_input, video_input] + inputs_s,
266
+ outputs=[video_output],
267
+ )
268
+
269
+ clear_btn.click(
270
+ fn=clear_media,
271
+ outputs=[image_output, video_output] + inputs_s
272
+ )
273
+
274
+ gr.Examples(
275
+ examples=[
276
+ ['./data/source/macron.png', './data/driving/driving1.mp4', 0.14,0,-0.26,-0.29,-0.11,0,-0.13,-0.18,0,0,0,0,-0.02,0.07],
277
+ ['./data/source/portrait1.png', './data/driving/driving2.mp4', -0.1, 0, 0, 0.17, 0.16, 0, 0.01, 0, 0.17,0.17, 0, 0, 0, 0],
278
+ ['./data/source/macron.png', './data/driving/driving4.mp4', -0.24, -0.17, -0.15, 0, 0, 0, 0, -0.16,
279
+ 0.08, 0, 0, 0, 0, 0],
280
+ ['./data/source/portrait2.png', './data/driving/driving3.mp4', 0.33, 0.38, -0.22, 0.25, -0.23, 0, -0.16,
281
+ 0, 0.06, 0, 0, 0, 0, 0],
282
+ ['./data/source/portrait2.png', './data/driving/driving6.mp4', -0.27, -0.25, 0, 0, 0, 0, 0, 0, 0, 0, 0,
283
+ 0, 0, 0],
284
+ ['./data/source/portrait2.png','./data/driving/driving1.mp4',-0.03,0.21,-0.41,-0.29,-0.11,0,0,-0.23,0,0,0,0,-0.02,0.07],
285
+ ['./data/source/portrait3.png','./data/driving/driving1.mp4',-0.03,0.21,-0.31,-0.12,-0.11,0,-0.05,-0.16,0,0,0,0,-0.02,0.07],
286
+ ['./data/source/portrait1.png','./data/driving/driving1.mp4',-0.03,0.21,-0.31,-0.12,-0.11,0,-0.1,-0.12,0,0.11,0,0,-0.02,0.07],
287
+ ['./data/source/einstein.png','./data/driving/driving2.mp4',-0.31,0,0,0.16,0.08,0,-0.07,0,0.13,0,0,0,0,0],
288
+ ['./data/source/einstein.png', './data/driving/driving4.mp4',0,0,0,0,0,0,0,-0.14,0.1,0,0,0,0,0],
289
+ ['./data/source/portrait1.png', './data/driving/driving4.mp4',0,0,0,0,0,0,0,-0.1,0.19,0,0,0,0,0],
290
+ ['./data/source/macron.png', './data/driving/driving6.mp4',-0.37,-0.34,0,0,0,0,0,0,0,0,0,0,0,0],
291
+ ],
292
+ inputs=[image_input, video_input] + inputs_s
293
+ )
294
+
295
+
gradio_tabs/vid_edit.py ADDED
@@ -0,0 +1,293 @@
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import torchvision
5
+ from PIL import Image
6
+ import numpy as np
7
+ import imageio
8
+ from einops import rearrange
9
+
10
+ extensions_dir = "./torch_extension/"
11
+ os.environ["TORCH_EXTENSIONS_DIR"] = extensions_dir
12
+
13
+ from networks.generator import Generator
14
+
15
+ device = torch.device("cuda")
16
+ ckpt_path = './models/lia-x.pt'
17
+ gen = Generator(size=512, motion_dim=40, scale=2).to(device)
18
+ gen.load_state_dict(torch.load(ckpt_path, weights_only=False))
19
+ gen.eval()
20
+
21
+ output_dir = "./res_gradio"
22
+ os.makedirs(output_dir, exist_ok=True)
23
+
24
+ # labels
25
+ labels_k = [
26
+ 'yaw1',
27
+ 'yaw2',
28
+ 'pitch',
29
+ 'roll1',
30
+ 'roll2',
31
+ 'neck',
32
+
33
+ 'pout',
34
+ 'open->close',
35
+ '"O" mouth',
36
+ 'apple cheek',
37
+
38
+ 'close->open',
39
+ 'eyebrows',
40
+ 'eyeballs1',
41
+ 'eyeballs2',
42
+
43
+ ]
44
+
45
+ labels_v = [
46
+ 37, 39, 28, 15, 33, 31,
47
+ 6, 25, 16, 19,
48
+ 13, 24, 17, 26
49
+ ]
50
+
51
+
52
+ def load_image(img, size):
53
+ # img = Image.open(filename).convert('RGB')
54
+ if not isinstance(img, np.ndarray):
55
+ img = Image.open(img).convert('RGB')
56
+ img = img.resize((size, size))
57
+ img = np.asarray(img)
58
+ img = np.transpose(img, (2, 0, 1)) # HWC -> CHW
59
+
60
+ return img / 255.0
61
+
62
+
63
+ def img_preprocessing(img_path, size):
64
+ img = load_image(img_path, size) # [0, 1]
65
+ img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
66
+ imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
67
+
68
+ return imgs_norm
69
+
70
+
71
+ def resize(img, size):
72
+ transform = torchvision.transforms.Compose([
73
+ torchvision.transforms.Resize(size, antialias=True),
74
+ torchvision.transforms.CenterCrop(size)
75
+ ])
76
+
77
+ return transform(img)
78
+
79
+
80
+ def vid_preprocessing(vid_path, size):
81
+ vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
82
+ vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0) # btchw
83
+ fps = vid_dict[2]['video_fps']
84
+ vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
85
+
86
+ vid_norm = torch.cat([
87
+ resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
88
+ ], dim=1)
89
+
90
+ return vid_norm, fps
91
+
92
+
93
+ def img_denorm(img):
94
+ img = img.clamp(-1, 1).cpu()
95
+ img = (img - img.min()) / (img.max() - img.min())
96
+
97
+ return img
98
+
99
+
100
+ def vid_denorm(vid):
101
+ vid = vid.clamp(-1, 1).cpu()
102
+ vid = (vid - vid.min()) / (vid.max() - vid.min())
103
+
104
+ return vid
105
+
106
+
107
+ def img_postprocessing(image, output_path=output_dir + "/output_img.png"):
108
+ image = image.permute(0, 2, 3, 1)
109
+ edited_image = img_denorm(image)
110
+ img_output = (edited_image[0].numpy() * 255).astype(np.uint8)
111
+ imageio.imwrite(output_path, img_output, quality=6)
112
+
113
+ return output_path
114
+
115
+
116
+ def vid_all_save(vid_d, vid_a, fps, output_path=output_dir + "/output_vid.mp4", output_all_path=output_dir + "/output_all_vid.mp4"):
117
+
118
+ vid_d = rearrange(vid_d, 'b t c h w -> b t h w c')
119
+ vid_a = rearrange(vid_a, 'b c t h w -> b t h w c')
120
+ vid_all = torch.cat([vid_d, vid_a], dim=3)
121
+
122
+ vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8')
123
+ vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8')
124
+
125
+ imageio.mimwrite(output_path, vid_a_np, fps=fps, codec='libx264', quality=8)
126
+ imageio.mimwrite(output_all_path, vid_all_np, fps=fps, codec='libx264', quality=8)
127
+
128
+ return output_path, output_all_path
129
+
130
+
131
+ @torch.no_grad()
132
+ def edit_img(video, *selected_s):
133
+
134
+ vid_target_tensor, fps = vid_preprocessing(video, 512)
135
+ video_target_tensor = vid_target_tensor.to(device)
136
+ image_tensor = video_target_tensor[:,0,:,:,:]
137
+
138
+ edited_image_tensor = gen.edit_img(image_tensor, labels_v, selected_s)
139
+
140
+ # de-norm
141
+ edited_image = img_postprocessing(edited_image_tensor)
142
+
143
+ return edited_image
144
+
145
+
146
+ @torch.no_grad()
147
+ def edit_vid(video, *selected_s):
148
+
149
+ video_target_tensor, fps = vid_preprocessing(video, 512)
150
+ video_target_tensor = video_target_tensor.to(device)
151
+
152
+ edited_video_tensor = gen.edit_vid(video_target_tensor, labels_v, selected_s)
153
+
154
+ # de-norm
155
+ animated_video, animated_all_video = vid_all_save(video_target_tensor, edited_video_tensor, fps)
156
+
157
+ return animated_video, animated_all_video
158
+
159
+
160
+ def clear_media():
161
+ return None, None, None, *([0] * len(labels_k))
162
+
163
+
164
+ image_output = gr.Image(label="Image", type='numpy', interactive=False, width=512)
165
+ video_output = gr.Video(label="Video", width=512)
166
+ video_all_output = gr.Video(label="Videos")
167
+
168
+
169
+ def vid_edit():
170
+ with gr.Tab("Video Editing"):
171
+
172
+ inputs_c = []
173
+ inputs_s = []
174
+
175
+ with gr.Row():
176
+ with gr.Column(scale=1):
177
+ with gr.Row():
178
+ with gr.Accordion(open=True, label="Video"):
179
+ video_input = gr.Video(width=512) # , height=550)
180
+ gr.Examples(
181
+ examples=[
182
+ ["./data/driving/driving1.mp4"],
183
+ ["./data/driving/driving2.mp4"],
184
+ ["./data/driving/driving4.mp4"],
185
+ #["./data/driving/driving5.mp4"],
186
+ #["./data/driving/driving6.mp4"],
187
+ #["./data/driving/driving7.mp4"],
188
+ ["./data/driving/driving3.mp4"],
189
+ ["./data/driving/driving8.mov"],
190
+ ["./data/driving/driving9.mov"],
191
+ ],
192
+ inputs=[video_input],
193
+ cache_examples=False,
194
+ visible=True,
195
+ )
196
+
197
+ # with gr.Row():
198
+ # with gr.Column(scale=1):
199
+ # with gr.Row(): # Buttons now within a single Row
200
+ # edit_btn = gr.Button("Edit")
201
+ # clear_btn = gr.Button("Clear")
202
+ # with gr.Row():
203
+ # animate_btn = gr.Button("Generate")
204
+
205
+ with gr.Column(scale=2):
206
+
207
+ with gr.Row():
208
+ with gr.Accordion(open=True, label="Edited First Frame"):
209
+ image_output.render()
210
+
211
+ with gr.Accordion(open=True, label="Edited Video"):
212
+ video_output.render()
213
+
214
+ with gr.Row():
215
+ with gr.Accordion(open=True, label="Original & Edited Videos"):
216
+ video_all_output.render()
217
+
218
+ with gr.Column(scale=1):
219
+ with gr.Accordion("Control Panel", open=True):
220
+ with gr.Tab("Head"):
221
+ with gr.Row():
222
+ for k in labels_k[:3]:
223
+ slider = gr.Slider(minimum=-1.0, maximum=0.5, value=0, label=k)
224
+ inputs_s.append(slider)
225
+ with gr.Row():
226
+ for k in labels_k[3:6]:
227
+ slider = gr.Slider(minimum=-0.5, maximum=0.5, value=0, label=k)
228
+ inputs_s.append(slider)
229
+
230
+ with gr.Tab("Mouth"):
231
+ with gr.Row():
232
+ for k in labels_k[6:8]:
233
+ slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
234
+ inputs_s.append(slider)
235
+ with gr.Row():
236
+ for k in labels_k[8:10]:
237
+ slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
238
+ inputs_s.append(slider)
239
+
240
+ with gr.Tab("Eyes"):
241
+ with gr.Row():
242
+ for k in labels_k[10:12]:
243
+ slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
244
+ inputs_s.append(slider)
245
+ with gr.Row():
246
+ for k in labels_k[12:14]:
247
+ slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k)
248
+ inputs_s.append(slider)
249
+
250
+ with gr.Row():
251
+ with gr.Column(scale=1):
252
+ with gr.Row(): # Buttons now within a single Row
253
+ edit_btn = gr.Button("Edit")
254
+ clear_btn = gr.Button("Clear")
255
+ with gr.Row():
256
+ animate_btn = gr.Button("Generate")
257
+
258
+ edit_btn.click(
259
+ fn=edit_img,
260
+ inputs=[video_input] + inputs_s,
261
+ outputs=[image_output],
262
+ show_progress=True
263
+ )
264
+
265
+ animate_btn.click(
266
+ fn=edit_vid,
267
+ inputs=[video_input] + inputs_s, # [image_input, video_input] + inputs_s,
268
+ outputs=[video_output, video_all_output],
269
+ )
270
+
271
+ clear_btn.click(
272
+ fn=clear_media,
273
+ outputs=[image_output, video_output, video_all_output] + inputs_s
274
+ )
275
+
276
+ gr.Examples(
277
+ examples=[
278
+ ['./data/driving/driving1.mp4', 0.5, 0.5, 0, 0, 0, 0, 0,
279
+ 0, 0, 0, 0, 0, 0, 0],
280
+ ['./data/driving/driving2.mp4', 0.5, 0.5, 0, 0, 0, 0, 0, 0, 0,
281
+ 0, 0, 0, 0, 0],
282
+ ['./data/driving/driving1.mp4', 0, 0, 0, 0, 0, 0, 0,
283
+ 0, 0, 0, 0, -0.3, 0, 0],
284
+ ['./data/driving/driving3.mp4', -0.6, 0, 0, 0, 0, 0, 0,
285
+ 0, 0, 0, 0, 0, 0, 0],
286
+ ['./data/driving/driving9.mov', 0, 0, 0, 0, 0, 0, 0,
287
+ 0, 0, 0, 0, 0, -0.1, 0.07],
288
+ ],
289
+ inputs=[video_input] + inputs_s
290
+ )
291
+
292
+
293
+
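Note: the two tab builders above (animation() and vid_edit()) only construct gr.Tab blocks; the launcher itself is not part of this commit. A minimal sketch of a hypothetical app.py, assuming a CUDA host with the checkpoint at ./models/lia-x.pt:

    import gradio as gr

    from gradio_tabs.animation import animation   # builds the "Animation & Image Editing" tab
    from gradio_tabs.vid_edit import vid_edit     # builds the "Video Editing" tab

    # Each import also loads its own Generator onto CUDA, so this assumes a GPU machine.
    with gr.Blocks(title="LIA-X") as demo:
        animation()
        vid_edit()

    if __name__ == "__main__":
        demo.launch()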
networks/__init__.py ADDED
File without changes
networks/decoder.py ADDED
@@ -0,0 +1,287 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ from .ops import (ConstantInput, ConvLayer, StyledConv, ToFlow, ToRGB, Direction)
6
+
7
+
8
+ class FlowResBlock(nn.Module):
9
+ def __init__(self, in_channel, out_channel, style_dim):
10
+ super().__init__()
11
+
12
+ self.norm = nn.GroupNorm(32, out_channel)
13
+
14
+ self.conv1 = StyledConv(in_channel, out_channel, 3, style_dim, False)
15
+ self.conv2 = StyledConv(out_channel, out_channel, 3, style_dim, False)
16
+
17
+ self.gamma = nn.Parameter(1e-5 * torch.ones([1, out_channel, 1, 1]))
18
+
19
+ def forward(self, x, style):
20
+ h = x
21
+ h = self.conv1(h, style)
22
+ skip = h
23
+
24
+ h = self.norm(h)
25
+ h = self.conv2(h, style)
26
+ h = self.gamma * h
27
+
28
+ return h + skip
29
+
30
+
31
+ class ResBlock(nn.Module):
32
+ def __init__(self, in_channel, out_channel):
33
+ super().__init__()
34
+
35
+ self.conv1 = ConvLayer(in_channel, out_channel, 3, upsample=False)
36
+ self.conv2 = ConvLayer(out_channel, out_channel, 3, upsample=False)
37
+
38
+ if in_channel != out_channel:
39
+ self.skip = ConvLayer(in_channel, out_channel, 1, upsample=False, activate=False, bias=False)
40
+ else:
41
+ self.skip = torch.nn.Identity()
42
+
43
+ def forward(self, x):
44
+
45
+ h = x
46
+ h = self.conv1(h)
47
+ h = self.conv2(h)
48
+ skip = self.skip(x)
49
+
50
+ return (h + skip) / math.sqrt(2)
51
+
52
+
53
+ class Decoder(nn.Module):
54
+ def __init__(self, style_dim, motion_dim, scale=1):
55
+ super().__init__()
56
+
57
+ channels = [512*scale, 256 * scale, 128 * scale, 64 * scale]
58
+
59
+ self.direction = Direction(style_dim, motion_dim)
60
+
61
+ self.input = ConstantInput(channels[0], size=4) # 4
62
+
63
+ # block1, 4
64
+ self.conv1 = StyledConv(channels[0], channels[0], 3, style_dim, False)
65
+
66
+ # for 512
67
+ self.conv_512_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
68
+ self.conv_512_2 = nn.ModuleList([
69
+ FlowResBlock(channels[0], channels[0], style_dim),
70
+ FlowResBlock(channels[0], channels[0], style_dim),
71
+ FlowResBlock(channels[0], channels[0], style_dim),
72
+ FlowResBlock(channels[0], channels[0], style_dim),
73
+ ])
74
+ self.conv_512_2_rgb = nn.ModuleList([
75
+ ResBlock(channels[0], channels[0]),
76
+ ResBlock(channels[0], channels[0]),
77
+ ResBlock(channels[0], channels[0]),
78
+ ResBlock(channels[0], channels[0]),
79
+ ])
80
+ self.rgb_512 = ToRGB(channels[0])
81
+ self.flow_512 = ToFlow(channels[0], style_dim) # 16
82
+
83
+ # block2, 8
84
+ self.conv2_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
85
+ self.conv2_2 = nn.ModuleList([
86
+ FlowResBlock(channels[0], channels[0], style_dim),
87
+ FlowResBlock(channels[0], channels[0], style_dim),
88
+ FlowResBlock(channels[0], channels[0], style_dim),
89
+ FlowResBlock(channels[0], channels[0], style_dim),
90
+ ])
91
+ self.conv2_2_up = ConvLayer(channels[0], channels[0], 3, upsample=True)
92
+ self.conv2_2_rgb = nn.ModuleList([
93
+ ResBlock(channels[0], channels[0]),
94
+ ResBlock(channels[0], channels[0]),
95
+ ResBlock(channels[0], channels[0]),
96
+ ResBlock(channels[0], channels[0]),
97
+ ])
98
+ self.rgb2 = ToRGB(channels[0])
99
+ self.flow2 = ToFlow(channels[0], style_dim) # 16
100
+
101
+ # block3, 16
102
+ self.conv3_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
103
+ self.conv3_2 = nn.ModuleList([
104
+ FlowResBlock(channels[0], channels[0], style_dim),
105
+ FlowResBlock(channels[0], channels[0], style_dim),
106
+ FlowResBlock(channels[0], channels[0], style_dim),
107
+ FlowResBlock(channels[0], channels[0], style_dim),
108
+ ])
109
+ self.conv3_2_up = ConvLayer(channels[0], channels[0], 3, upsample=True)
110
+ self.conv3_2_rgb = nn.ModuleList([
111
+ ResBlock(channels[0], channels[0]),
112
+ ResBlock(channels[0], channels[0]),
113
+ ResBlock(channels[0], channels[0]),
114
+ ResBlock(channels[0], channels[0]),
115
+ ])
116
+ self.rgb3 = ToRGB(channels[0])
117
+ self.flow3 = ToFlow(channels[0], style_dim) # 32
118
+
119
+ # block4, 32
120
+ self.conv4_1 = StyledConv(channels[0], channels[0], 3, style_dim, True)
121
+ self.conv4_2 = nn.ModuleList([
122
+ FlowResBlock(channels[0], channels[0], style_dim),
123
+ FlowResBlock(channels[0], channels[0], style_dim),
124
+ FlowResBlock(channels[0], channels[0], style_dim),
125
+ FlowResBlock(channels[0], channels[0], style_dim),
126
+ ])
127
+ self.conv4_2_up = ConvLayer(channels[0], channels[0], 3, upsample=True)
128
+ self.conv4_2_rgb = nn.ModuleList([
129
+ ResBlock(channels[0], channels[0]),
130
+ ResBlock(channels[0], channels[0]),
131
+ ResBlock(channels[0], channels[0]),
132
+ ResBlock(channels[0], channels[0]),
133
+ ])
134
+ self.rgb4 = ToRGB(channels[0])
135
+ self.flow4 = ToFlow(channels[0], style_dim) # 64
136
+
137
+ # block5, 64
138
+ self.conv5_1 = StyledConv(channels[0], channels[1], 3, style_dim, True)
139
+ self.conv5_2 = nn.ModuleList([
140
+ FlowResBlock(channels[1], channels[1], style_dim),
141
+ FlowResBlock(channels[1], channels[1], style_dim),
142
+ FlowResBlock(channels[1], channels[1], style_dim),
143
+ FlowResBlock(channels[1], channels[1], style_dim),
144
+ ])
145
+ self.conv5_2_up = ConvLayer(channels[0], channels[1], 3, upsample=True)
146
+ self.conv5_2_rgb = nn.ModuleList([
147
+ ResBlock(channels[1], channels[1]),
148
+ ResBlock(channels[1], channels[1]),
149
+ ResBlock(channels[1], channels[1]),
150
+ ResBlock(channels[1], channels[1]),
151
+ ])
152
+ self.rgb5 = ToRGB(channels[1])
153
+ self.flow5 = ToFlow(channels[1], style_dim) # 128
154
+
155
+ # block6, 128
156
+ self.conv6_1 = StyledConv(channels[1], channels[2], 3, style_dim, True)
157
+ self.conv6_2 = nn.ModuleList([
158
+ FlowResBlock(channels[2], channels[2], style_dim),
159
+ FlowResBlock(channels[2], channels[2], style_dim),
160
+ FlowResBlock(channels[2], channels[2], style_dim),
161
+ FlowResBlock(channels[2], channels[2], style_dim),
162
+ ])
163
+ self.conv6_2_up = ConvLayer(channels[1], channels[2], 3, upsample=True)
164
+ self.conv6_2_rgb = nn.ModuleList([
165
+ ResBlock(channels[2], channels[2]),
166
+ ResBlock(channels[2], channels[2]),
167
+ ResBlock(channels[2], channels[2]),
168
+ ResBlock(channels[2], channels[2]),
169
+ ])
170
+ self.rgb6 = ToRGB(channels[2])
171
+ self.flow6 = ToFlow(channels[2], style_dim) # 128
172
+
173
+ # block7, 256
174
+ self.conv7_1 = StyledConv(channels[2], channels[3], 3, style_dim, True)
175
+ self.conv7_2 = nn.ModuleList([
176
+ FlowResBlock(channels[3], channels[3], style_dim),
177
+ FlowResBlock(channels[3], channels[3], style_dim),
178
+ FlowResBlock(channels[3], channels[3], style_dim),
179
+ FlowResBlock(channels[3], channels[3], style_dim),
180
+ ])
181
+ self.conv7_2_up = ConvLayer(channels[2], channels[3], 3, upsample=True)
182
+ self.conv7_2_rgb = nn.ModuleList([
183
+ ResBlock(channels[3], channels[3]),
184
+ ResBlock(channels[3], channels[3]),
185
+ ResBlock(channels[3], channels[3]),
186
+ ResBlock(channels[3], channels[3]),
187
+ ])
188
+ self.rgb7 = ToRGB(channels[3])
189
+ self.flow7 = ToFlow(channels[3], style_dim) # 128
190
+
191
+ def navigation(self, z_s2r, alpha):
192
+
193
+ if alpha is not None:
194
+ # generating moving directions
195
+ if len(alpha) > 1:
196
+ z_r2t = self.direction(alpha[0]) # target
197
+ z_r2s = self.direction(alpha[1]) # source
198
+ z_start = self.direction(alpha[2]) # start
199
+ z_s2t = z_s2r + (z_r2t - z_start) + z_r2s
200
+ else:
201
+ z_r2t = self.direction(alpha[0])
202
+ z_s2t = z_s2r + z_r2t # wa + directions
203
+ else:
204
+ z_s2t = z_s2r
205
+
206
+ return z_s2t
207
+
208
+ def apply_flow(self, h, mask, flow, feat):
209
+
210
+ feat_warp = F.grid_sample(feat, flow) * mask
211
+ h = feat_warp + (1 - mask) * h
212
+
213
+ return feat_warp, h
214
+
215
+ def forward(self, z_s2r, alpha, feats):
216
+ # z_s2r: bs x style_dim
217
+ # alpha: bs x style_dim
218
+
219
+ z_s2t = self.navigation(z_s2r, alpha)
220
+
221
+ h = self.input(z_s2t)
222
+ h = self.conv1(h, z_s2t)
223
+
224
+ #for 512
225
+ h = self.conv_512_1(h, z_s2t)
226
+ for conv in self.conv_512_2:
227
+ h = conv(h, z_s2t)
228
+ h_warp_512, h, h_flow_512 = self.flow_512(h, z_s2t, feats[0])
229
+ for conv in self.conv_512_2_rgb:
230
+ h_warp_512 = conv(h_warp_512)
231
+ rgb_512 = self.rgb_512(h_warp_512)
232
+
233
+ h = self.conv2_1(h, z_s2t)
234
+ for conv in self.conv2_2:
235
+ h = conv(h, z_s2t)
236
+ h_warp2, h, h_flow2 = self.flow2(h, z_s2t, feats[1], h_flow_512)
237
+ h_warp2 = h_warp2 + self.conv2_2_up(h_warp_512)
238
+ for conv in self.conv2_2_rgb:
239
+ h_warp2 = conv(h_warp2)
240
+ rgb2 = self.rgb2(h_warp2, rgb_512)
241
+
242
+ h = self.conv3_1(h, z_s2t)
243
+ for conv in self.conv3_2:
244
+ h = conv(h, z_s2t)
245
+ h_warp3, h, h_flow3 = self.flow3(h, z_s2t, feats[2], h_flow2)
246
+ h_warp3 = h_warp3 + self.conv3_2_up(h_warp2)
247
+ for conv in self.conv3_2_rgb:
248
+ h_warp3 = conv(h_warp3)
249
+ rgb3 = self.rgb3(h_warp3, rgb2)
250
+
251
+ h = self.conv4_1(h, z_s2t)
252
+ for conv in self.conv4_2:
253
+ h = conv(h, z_s2t)
254
+ h_warp4, h, h_flow4 = self.flow4(h, z_s2t, feats[3], h_flow3)
255
+ h_warp4 = h_warp4 + self.conv4_2_up(h_warp3)
256
+ for conv in self.conv4_2_rgb:
257
+ h_warp4 = conv(h_warp4)
258
+ rgb4 = self.rgb4(h_warp4, rgb3)
259
+
260
+ h = self.conv5_1(h, z_s2t)
261
+ for conv in self.conv5_2:
262
+ h = conv(h, z_s2t)
263
+ h_warp5, h, h_flow5 = self.flow5(h, z_s2t, feats[4], h_flow4)
264
+ h_warp5 = h_warp5 + self.conv5_2_up(h_warp4)
265
+ for conv in self.conv5_2_rgb:
266
+ h_warp5 = conv(h_warp5)
267
+ rgb5 = self.rgb5(h_warp5, rgb4)
268
+
269
+ h = self.conv6_1(h, z_s2t)
270
+ for conv in self.conv6_2:
271
+ h = conv(h, z_s2t)
272
+ h_warp6, h, h_flow6 = self.flow6(h, z_s2t, feats[5], h_flow5)
273
+ h_warp6 = h_warp6 + self.conv6_2_up(h_warp5)
274
+ for conv in self.conv6_2_rgb:
275
+ h_warp6 = conv(h_warp6)
276
+ rgb6 = self.rgb6(h_warp6, rgb5)
277
+
278
+ h = self.conv7_1(h, z_s2t)
279
+ for conv in self.conv7_2:
280
+ h = conv(h, z_s2t)
281
+ h_warp7, h, h_flow7 = self.flow7(h, z_s2t, feats[6], h_flow6)
282
+ h_warp7 = h_warp7 + self.conv7_2_up(h_warp6)
283
+ for conv in self.conv7_2_rgb:
284
+ h_warp7 = conv(h_warp7)
285
+ out = self.rgb7(h_warp7, rgb6)
286
+
287
+ return out
networks/encoder.py ADDED
@@ -0,0 +1,141 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+ from .ops import (EqualConv2d, EqualLinear, ConvLayer)
6
+
7
+
8
+ class ResBlock(nn.Module):
9
+ def __init__(self, in_channel, out_channel):
10
+ super().__init__()
11
+
12
+ self.conv1 = ConvLayer(in_channel, out_channel, 3)
13
+ self.conv2 = ConvLayer(out_channel, out_channel, 3, downsample=True)
14
+
15
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
16
+
17
+
18
+ def forward(self, x):
19
+
20
+ h = x
21
+
22
+ h = self.conv1(h)
23
+ h = self.conv2(h)
24
+
25
+ skip = self.skip(x)
26
+ h = (h + skip) / math.sqrt(2)
27
+
28
+ return h
29
+
30
+
31
+ class Encoder2R(nn.Module):
32
+ def __init__(self, latent_dim=512, scale=1):
33
+ super(Encoder2R, self).__init__()
34
+
35
+ channels = [64*scale, 128*scale, 256*scale, 512*scale]
36
+
37
+ # version1
38
+ self.block1 = ConvLayer(3, channels[0], 1) # 256, 3 -> 64
39
+ self.block2 = nn.Sequential(
40
+ ResBlock(channels[0], channels[1])
41
+ ) # 64 -> 128
42
+ self.block3 = nn.Sequential(
43
+ ResBlock(channels[1], channels[2])
44
+ ) # 128 -> 256
45
+ self.block4 = nn.Sequential(
46
+ ResBlock(channels[2], channels[3])
47
+ ) # 256 -> 512
48
+ self.block5 = nn.Sequential(
49
+ ResBlock(channels[3], channels[3])
50
+ ) # 512 -> 512
51
+ self.block6 = nn.Sequential(
52
+ ResBlock(channels[3], channels[3])
53
+ ) # 512 -> 512
54
+ self.block7 = nn.Sequential(
55
+ ResBlock(channels[3], channels[3])
56
+ ) # 512 -> 512
57
+
58
+ self.block_512 = ResBlock(channels[3], channels[3])
59
+ self.block8 = EqualConv2d(channels[3], latent_dim, 4, padding=0, bias=False)
60
+
61
+ def forward(self, x):
62
+
63
+ res = []
64
+ h = x
65
+ h = self.block1(h) # 256
66
+ res.append(h)
67
+ h = self.block2(h) # 128
68
+ res.append(h)
69
+ h = self.block3(h) # 64
70
+ res.append(h)
71
+ h = self.block4(h) # 32
72
+ res.append(h)
73
+ h = self.block5(h) # 16
74
+ res.append(h)
75
+ h = self.block6(h) # 8
76
+ res.append(h)
77
+ h = self.block7(h) # 4
78
+ res.append(h)
79
+ h = self.block_512(h)
80
+ h = self.block8(h) # 1
81
+
82
+ return h.squeeze(-1).squeeze(-1), res[::-1]
83
+
84
+
85
+ class Encoder(nn.Module):
86
+ def __init__(self, dim=512, dim_motion=20, scale=1):
87
+ super(Encoder, self).__init__()
88
+
89
+ # 2R network
90
+ self.enc_2r = Encoder2R(dim, scale)
91
+
92
+ # R2T
93
+ self.enc_r2t = nn.Sequential(
94
+ EqualLinear(dim, dim_motion)
95
+ )
96
+
97
+ def enc_motion(self, x):
98
+
99
+ z_t2r, _ = self.enc_2r(x)
100
+ alpha_r2t = self.enc_r2t(z_t2r)
101
+
102
+ return alpha_r2t
103
+
104
+
105
+ def enc_transfer_img(self, z_s2r, d_l, s_l):
106
+
107
+ alpha_r2s = self.enc_r2t(z_s2r)
108
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(s_l).unsqueeze(0).to('cuda')
109
+ alpha = [alpha_r2s]
110
+
111
+ return alpha
112
+
113
+ def enc_transfer_vid(self, alpha_r2s, input_target, alpha_start):
114
+
115
+ z_t2r, _ = self.enc_2r(input_target)
116
+ alpha_r2t = self.enc_r2t(z_t2r)
117
+ alpha = [alpha_r2t, alpha_r2s, alpha_start]
118
+
119
+ return alpha
120
+
121
+
122
+ def forward(self, input_source, input_target, alpha_start=None):
123
+
124
+ if input_target is not None:
125
+
126
+ z_s2r, feats = self.enc_2r(input_source)
127
+ z_t2r, _ = self.enc_2r(input_target)
128
+
129
+ alpha_r2t = self.enc_r2t(z_t2r)
130
+
131
+ if alpha_start is not None:
132
+ alpha_r2s = self.enc_r2t(z_s2r)
133
+ alpha = [alpha_r2t, alpha_r2s, alpha_start]
134
+ else:
135
+ alpha = [alpha_r2t]
136
+
137
+ return z_s2r, alpha, feats
138
+ else:
139
+ z_s2r, feats = self.enc_2r(input_source)
140
+
141
+ return z_s2r, None, feats
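Note: Encoder.enc_2r maps a frame to a reference latent plus a multi-scale feature pyramid, while enc_r2t projects that latent to a motion_dim-sized motion code. A shape sketch (illustrative only; assumes the CUDA extensions under networks/op build in your environment, and dim = 512*scale as in Generator(size=512, scale=2)):

    import torch
    from networks.encoder import Encoder

    enc = Encoder(dim=1024, dim_motion=40, scale=2).cuda().eval()
    src = torch.randn(1, 3, 512, 512, device="cuda")
    drv = torch.randn(1, 3, 512, 512, device="cuda")

    with torch.no_grad():
        z_s2r, alpha, feats = enc(src, drv)

    print(z_s2r.shape)     # source-to-reference latent, (1, 1024)
    print(alpha[0].shape)  # driving-frame motion code, (1, 40)
    print(len(feats))      # 7-level feature pyramid consumed by the decoder's warping stages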
networks/generator.py ADDED
@@ -0,0 +1,121 @@
1
+ import torch
2
+ from torch import nn
3
+ from networks.encoder import Encoder
4
+ from networks.decoder import Decoder
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+
8
+
9
+ class Generator(nn.Module):
10
+ def __init__(self, size, style_dim=512, motion_dim=40, scale=1):
11
+ super(Generator, self).__init__()
12
+
13
+ style_dim = style_dim * scale
14
+
15
+ # encoder
16
+ self.enc = Encoder(style_dim, motion_dim, scale)
17
+ self.dec = Decoder(style_dim, motion_dim, scale)
18
+
19
+ def get_alpha(self, x):
20
+ return self.enc.enc_motion(x)
21
+
22
+ def edit_img(self, img_source, d_l, v_l):
23
+
24
+ z_s2r, feat_rgb = self.enc.enc_2r(img_source)
25
+ alpha_r2s = self.enc.enc_r2t(z_s2r)
26
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
27
+ img_recon = self.dec(z_s2r, [alpha_r2s], feat_rgb)
28
+
29
+ return img_recon
30
+
31
+ def animate(self, img_source, vid_target, d_l, v_l):
32
+
33
+ alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
34
+
35
+ vid_target_recon = []
36
+ z_s2r, feat_rgb = self.enc.enc_2r(img_source)
37
+ alpha_r2s = self.enc.enc_r2t(z_s2r)
38
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
39
+
40
+ for i in tqdm(range(vid_target.size(1))):
41
+ img_target = vid_target[:, i, :, :, :]
42
+ alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
43
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
44
+ vid_target_recon.append(img_recon.unsqueeze(2))
45
+ vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
46
+
47
+ return vid_target_recon
48
+
49
+ def edit_vid(self, vid_target, d_l, v_l):
50
+
51
+ img_source = vid_target[:, 0, :, :, :]
52
+ alpha_start = self.get_alpha(vid_target[:, 0, :, :, :])
53
+
54
+ vid_target_recon = []
55
+ z_s2r, feat_rgb = self.enc.enc_2r(img_source)
56
+ alpha_r2s = self.enc.enc_r2t(z_s2r)
57
+ alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + torch.FloatTensor(v_l).unsqueeze(0).to('cuda')
58
+
59
+ for i in tqdm(range(vid_target.size(1))):
60
+ img_target = vid_target[:, i, :, :, :]
61
+ alpha = self.enc.enc_transfer_vid(alpha_r2s, img_target, alpha_start)
62
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
63
+ vid_target_recon.append(img_recon.unsqueeze(2))
64
+ vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
65
+
66
+ return vid_target_recon
67
+
68
+
69
+ def interpolate_img(self, img_source, d_l, v_l):
70
+
71
+ vid_target_recon = []
72
+
73
+ step = 16
74
+ v_start = np.array([0.] * len(v_l))
75
+ v_end = np.array(v_l)
76
+ stride = (v_end - v_start) / step
77
+
78
+ z_s2r, feat_rgb = self.enc.enc_2r(img_source)
79
+
80
+ v_tmp = v_start
81
+ for i in range(step):
82
+ v_tmp = v_tmp + stride
83
+ alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
84
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
85
+ vid_target_recon.append(img_recon.unsqueeze(2))
86
+
87
+ for i in range(step):
88
+ v_tmp = v_tmp - stride
89
+ alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
90
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
91
+ vid_target_recon.append(img_recon.unsqueeze(2))
92
+
93
+ if (v_l[6]!=0) or (v_l[7]!=0) or (v_l[8]!=0) or (v_l[9]!=0):
94
+ for i in range(step):
95
+ v_tmp = v_tmp + stride
96
+ alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
97
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
98
+ vid_target_recon.append(img_recon.unsqueeze(2))
99
+
100
+ for i in range(step):
101
+ v_tmp = v_tmp - stride
102
+ alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
103
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
104
+ vid_target_recon.append(img_recon.unsqueeze(2))
105
+ else:
106
+ for i in range(step):
107
+ v_tmp = v_tmp - stride
108
+ alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
109
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
110
+ vid_target_recon.append(img_recon.unsqueeze(2))
111
+
112
+ for i in range(step):
113
+ v_tmp = v_tmp + stride
114
+ alpha = self.enc.enc_transfer_img(z_s2r, d_l, v_tmp)
115
+ img_recon = self.dec(z_s2r, alpha, feat_rgb)
116
+ vid_target_recon.append(img_recon.unsqueeze(2))
117
+
118
+ vid_target_recon = torch.cat(vid_target_recon, dim=2) # BCTHW
119
+
120
+ return vid_target_recon
121
+
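Note: interpolate_img is defined above but not wired into either Gradio tab in this commit. A hedged sketch of how it could be driven with the helpers from gradio_tabs/animation.py (the slider value below is illustrative, not taken from the repository):

    import torch
    from gradio_tabs.animation import gen, labels_v, img_preprocessing, vid_postprocessing

    @torch.no_grad()
    def interpolate_demo(image_path, fps=25):
        img = img_preprocessing(image_path, 512).to("cuda")
        v_l = [0.0] * len(labels_v)
        v_l[0] = -0.4                                  # sweep the 'yaw1' direction back and forth
        vid = gen.interpolate_img(img, labels_v, v_l)  # B x C x T x H x W in [-1, 1]
        return vid_postprocessing(vid, fps)            # writes ./res_gradio/output_vid.mp4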
networks/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
networks/op/conv2d_gradfix.py ADDED
@@ -0,0 +1,227 @@
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ warnings.warn(
89
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ )
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d
networks/op/fused_act.py ADDED
@@ -0,0 +1,127 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ if bias:
39
+ grad_bias = grad_input.sum(dim).detach()
40
+
41
+ else:
42
+ grad_bias = empty
43
+
44
+ return grad_input, grad_bias
45
+
46
+ @staticmethod
47
+ def backward(ctx, gradgrad_input, gradgrad_bias):
48
+ out, = ctx.saved_tensors
49
+ gradgrad_out = fused.fused_bias_act(
50
+ gradgrad_input.contiguous(),
51
+ gradgrad_bias,
52
+ out,
53
+ 3,
54
+ 1,
55
+ ctx.negative_slope,
56
+ ctx.scale,
57
+ )
58
+
59
+ return gradgrad_out, None, None, None, None
60
+
61
+
62
+ class FusedLeakyReLUFunction(Function):
63
+ @staticmethod
64
+ def forward(ctx, input, bias, negative_slope, scale):
65
+ empty = input.new_empty(0)
66
+
67
+ ctx.bias = bias is not None
68
+
69
+ if bias is None:
70
+ bias = empty
71
+
72
+ out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
73
+ ctx.save_for_backward(out)
74
+ ctx.negative_slope = negative_slope
75
+ ctx.scale = scale
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def backward(ctx, grad_output):
81
+ out, = ctx.saved_tensors
82
+
83
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84
+ grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85
+ )
86
+
87
+ if not ctx.bias:
88
+ grad_bias = None
89
+
90
+ return grad_input, grad_bias, None, None
91
+
92
+
93
+ class FusedLeakyReLU(nn.Module):
94
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95
+ super().__init__()
96
+
97
+ if bias:
98
+ self.bias = nn.Parameter(torch.zeros(channel))
99
+
100
+ else:
101
+ self.bias = None
102
+
103
+ self.negative_slope = negative_slope
104
+ self.scale = scale
105
+
106
+ def forward(self, input):
107
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108
+
109
+
110
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111
+ if input.device.type == "cpu":
112
+ if bias is not None:
113
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
114
+ return (
115
+ F.leaky_relu(
116
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117
+ )
118
+ * scale
119
+ )
120
+
121
+ else:
122
+ return F.leaky_relu(input, negative_slope=0.2) * scale
123
+
124
+ else:
125
+ return FusedLeakyReLUFunction.apply(
126
+ input.contiguous(), bias, negative_slope, scale
127
+ )
networks/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,32 @@
1
+
2
+ #include <ATen/ATen.h>
3
+ #include <torch/extension.h>
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
+ const torch::Tensor &bias,
7
+ const torch::Tensor &refer, int act, int grad,
8
+ float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) \
11
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) \
13
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+ #define CHECK_INPUT(x) \
15
+ CHECK_CUDA(x); \
16
+ CHECK_CONTIGUOUS(x)
17
+
18
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
19
+ const torch::Tensor &bias,
20
+ const torch::Tensor &refer, int act, int grad,
21
+ float alpha, float scale) {
22
+ CHECK_INPUT(input);
23
+ CHECK_INPUT(bias);
24
+
25
+ at::DeviceGuard guard(input.device());
26
+
27
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
+ }
29
+
30
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
+ }
networks/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,105 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void
20
+ fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
+ const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
+ scalar_t scale, int loop_x, int size_x, int step_b,
23
+ int size_b, int use_bias, int use_ref) {
24
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
+
26
+ scalar_t zero = 0.0;
27
+
28
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
+ loop_idx++, xi += blockDim.x) {
30
+ scalar_t x = p_x[xi];
31
+
32
+ if (use_bias) {
33
+ x += p_b[(xi / step_b) % size_b];
34
+ }
35
+
36
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
37
+
38
+ scalar_t y;
39
+
40
+ switch (act * 10 + grad) {
41
+ default:
42
+ case 10:
43
+ y = x;
44
+ break;
45
+ case 11:
46
+ y = x;
47
+ break;
48
+ case 12:
49
+ y = 0.0;
50
+ break;
51
+
52
+ case 30:
53
+ y = (x > 0.0) ? x : x * alpha;
54
+ break;
55
+ case 31:
56
+ y = (ref > 0.0) ? x : x * alpha;
57
+ break;
58
+ case 32:
59
+ y = 0.0;
60
+ break;
61
+ }
62
+
63
+ out[xi] = y * scale;
64
+ }
65
+ }
66
+
67
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
+ const torch::Tensor &bias,
69
+ const torch::Tensor &refer, int act, int grad,
70
+ float alpha, float scale) {
71
+ int curDevice = -1;
72
+ cudaGetDevice(&curDevice);
73
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
+
75
+ auto x = input.contiguous();
76
+ auto b = bias.contiguous();
77
+ auto ref = refer.contiguous();
78
+
79
+ int use_bias = b.numel() ? 1 : 0;
80
+ int use_ref = ref.numel() ? 1 : 0;
81
+
82
+ int size_x = x.numel();
83
+ int size_b = b.numel();
84
+ int step_b = 1;
85
+
86
+ for (int i = 1 + 1; i < x.dim(); i++) {
87
+ step_b *= x.size(i);
88
+ }
89
+
90
+ int loop_x = 4;
91
+ int block_size = 4 * 32;
92
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
+
94
+ auto y = torch::empty_like(x);
95
+
96
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
+ x.scalar_type(), "fused_bias_act_kernel", [&] {
98
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
+ y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
+ b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
+ scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
+ });
103
+
104
+ return y;
105
+ }
networks/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,31 @@
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
+ const torch::Tensor &kernel, int up_x, int up_y,
6
+ int down_x, int down_y, int pad_x0, int pad_x1,
7
+ int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) \
10
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) \
12
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x)
16
+
17
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
+ int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
+ int pad_x1, int pad_y0, int pad_y1) {
20
+ CHECK_INPUT(input);
21
+ CHECK_INPUT(kernel);
22
+
23
+ at::DeviceGuard guard(input.device());
24
+
25
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
+ pad_y0, pad_y1);
27
+ }
28
+
29
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
+ }
networks/op/upfirdn2d.py ADDED
@@ -0,0 +1,209 @@
1
+ from collections import abc
2
+ import os
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ upfirdn2d_op = load(
12
+ "upfirdn2d",
13
+ sources=[
14
+ os.path.join(module_path, "upfirdn2d.cpp"),
15
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class UpFirDn2dBackward(Function):
21
+ @staticmethod
22
+ def forward(
23
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
+ ):
25
+
26
+ up_x, up_y = up
27
+ down_x, down_y = down
28
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
+
30
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
+
32
+ grad_input = upfirdn2d_op.upfirdn2d(
33
+ grad_output,
34
+ grad_kernel,
35
+ down_x,
36
+ down_y,
37
+ up_x,
38
+ up_y,
39
+ g_pad_x0,
40
+ g_pad_x1,
41
+ g_pad_y0,
42
+ g_pad_y1,
43
+ )
44
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
+
46
+ ctx.save_for_backward(kernel)
47
+
48
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
+
50
+ ctx.up_x = up_x
51
+ ctx.up_y = up_y
52
+ ctx.down_x = down_x
53
+ ctx.down_y = down_y
54
+ ctx.pad_x0 = pad_x0
55
+ ctx.pad_x1 = pad_x1
56
+ ctx.pad_y0 = pad_y0
57
+ ctx.pad_y1 = pad_y1
58
+ ctx.in_size = in_size
59
+ ctx.out_size = out_size
60
+
61
+ return grad_input
62
+
63
+ @staticmethod
64
+ def backward(ctx, gradgrad_input):
65
+ kernel, = ctx.saved_tensors
66
+
67
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
+
69
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
70
+ gradgrad_input,
71
+ kernel,
72
+ ctx.up_x,
73
+ ctx.up_y,
74
+ ctx.down_x,
75
+ ctx.down_y,
76
+ ctx.pad_x0,
77
+ ctx.pad_x1,
78
+ ctx.pad_y0,
79
+ ctx.pad_y1,
80
+ )
81
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
+ gradgrad_out = gradgrad_out.view(
83
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
+ )
85
+
86
+ return gradgrad_out, None, None, None, None, None, None, None, None
87
+
88
+
89
+ class UpFirDn2d(Function):
90
+ @staticmethod
91
+ def forward(ctx, input, kernel, up, down, pad):
92
+ up_x, up_y = up
93
+ down_x, down_y = down
94
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
+
96
+ kernel_h, kernel_w = kernel.shape
97
+ batch, channel, in_h, in_w = input.shape
98
+ ctx.in_size = input.shape
99
+
100
+ input = input.reshape(-1, in_h, in_w, 1)
101
+
102
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
+
104
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
+ ctx.out_size = (out_h, out_w)
107
+
108
+ ctx.up = (up_x, up_y)
109
+ ctx.down = (down_x, down_y)
110
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
+
112
+ g_pad_x0 = kernel_w - pad_x0 - 1
113
+ g_pad_y0 = kernel_h - pad_y0 - 1
114
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
+
117
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
+
119
+ out = upfirdn2d_op.upfirdn2d(
120
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
+ )
122
+ # out = out.view(major, out_h, out_w, minor)
123
+ out = out.view(-1, channel, out_h, out_w)
124
+
125
+ return out
126
+
127
+ @staticmethod
128
+ def backward(ctx, grad_output):
129
+ kernel, grad_kernel = ctx.saved_tensors
130
+
131
+ grad_input = None
132
+
133
+ if ctx.needs_input_grad[0]:
134
+ grad_input = UpFirDn2dBackward.apply(
135
+ grad_output,
136
+ kernel,
137
+ grad_kernel,
138
+ ctx.up,
139
+ ctx.down,
140
+ ctx.pad,
141
+ ctx.g_pad,
142
+ ctx.in_size,
143
+ ctx.out_size,
144
+ )
145
+
146
+ return grad_input, None, None, None, None
147
+
148
+
149
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
+ if not isinstance(up, abc.Iterable):
151
+ up = (up, up)
152
+
153
+ if not isinstance(down, abc.Iterable):
154
+ down = (down, down)
155
+
156
+ if len(pad) == 2:
157
+ pad = (pad[0], pad[1], pad[0], pad[1])
158
+
159
+ if input.device.type == "cpu":
160
+ out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
+
162
+ else:
163
+ out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
+
165
+ return out
166
+
167
+
168
+ def upfirdn2d_native(
169
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
+ ):
171
+ _, channel, in_h, in_w = input.shape
172
+ input = input.reshape(-1, in_h, in_w, 1)
173
+
174
+ _, in_h, in_w, minor = input.shape
175
+ kernel_h, kernel_w = kernel.shape
176
+
177
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
178
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
+
181
+ out = F.pad(
182
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
+ )
184
+ out = out[
185
+ :,
186
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
+ :,
189
+ ]
190
+
191
+ out = out.permute(0, 3, 1, 2)
192
+ out = out.reshape(
193
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
+ )
195
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
+ out = F.conv2d(out, w)
197
+ out = out.reshape(
198
+ -1,
199
+ minor,
200
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
+ )
203
+ out = out.permute(0, 2, 3, 1)
204
+ out = out[:, ::down_y, ::down_x, :]
205
+
206
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
+
209
+ return out.view(-1, channel, out_h, out_w)
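
The output size always follows the closed-form expression used for out_h / out_w above, so the padding chosen by the Upsample / Downsample wrappers in networks/ops.py (later in this upload) can be sanity-checked without building the extension. A small sketch, pure arithmetic only:

def upfirdn2d_out_size(in_size, up, down, pad0, pad1, kernel):
    # same formula as out_h / out_w above, applied independently per spatial axis
    return (in_size * up + pad0 + pad1 - kernel + down) // down

# Upsample(factor=2) with the 4-tap [1, 3, 3, 1] kernel uses pad=(2, 1): 8 -> 16
assert upfirdn2d_out_size(8, up=2, down=1, pad0=2, pad1=1, kernel=4) == 16
# Downsample(factor=2) with the same kernel uses pad=(1, 1): 16 -> 8
assert upfirdn2d_out_size(16, up=1, down=2, pad0=1, pad1=1, kernel=4) == 8
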
networks/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
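
The dispatch at the end of upfirdn2d_op selects a tile-specialized kernel only for the up/down/kernel-size combinations that blur and 2x resampling layers actually produce, and falls back to upfirdn2d_kernel_large for everything else. A compact restatement of the branches above (checked top to bottom like the if-chain, so a later, tighter match wins):

# (up_x, up_y, down_x, down_y, kernel_h<=, kernel_w<=) -> (mode, tile_out_h, tile_out_w)
SPECIALIZED_MODES = [
    ((1, 1, 1, 1, 4, 4), (1, 16, 64)),
    ((1, 1, 1, 1, 3, 3), (2, 16, 64)),
    ((2, 2, 1, 1, 4, 4), (3, 16, 64)),
    ((2, 2, 1, 1, 2, 2), (4, 16, 64)),
    ((1, 1, 2, 2, 4, 4), (5, 8, 32)),
    ((1, 1, 2, 2, 2, 2), (6, 8, 32)),
]
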
networks/op/upfirdn2d_new.py ADDED
@@ -0,0 +1,230 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.autograd import Function
6
+ from torch.utils.cpp_extension import load
7
+ #from util import is_custom_kernel_supported as is_custom_kernel_supported
8
+
9
+ def is_custom_kernel_supported():
10
+ version_str = str(torch.version.cuda).split(".")
11
+ major = version_str[0]
12
+ minor = version_str[1]
13
+ return (int(major), int(minor)) >= (10, 1)  # tuple comparison, so e.g. CUDA 11.0 also qualifies
14
+
15
+
16
+ if is_custom_kernel_supported():
17
+ print("Loading custom kernel...")
18
+ module_path = os.path.dirname(__file__)
19
+ upfirdn2d_op = load(
20
+ 'upfirdn2d',
21
+ sources=[
22
+ os.path.join(module_path, 'upfirdn2d.cpp'),
23
+ os.path.join(module_path, 'upfirdn2d_kernel.cu'),
24
+ ],
25
+ verbose=True
26
+ )
27
+
28
+ use_custom_kernel = is_custom_kernel_supported()
29
+
30
+
31
+ class UpFirDn2dBackward(Function):
32
+ @staticmethod
33
+ def forward(
34
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
35
+ ):
36
+
37
+ up_x, up_y = up
38
+ down_x, down_y = down
39
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
40
+
41
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
42
+
43
+ grad_input = upfirdn2d_op.upfirdn2d(
44
+ grad_output,
45
+ grad_kernel,
46
+ down_x,
47
+ down_y,
48
+ up_x,
49
+ up_y,
50
+ g_pad_x0,
51
+ g_pad_x1,
52
+ g_pad_y0,
53
+ g_pad_y1,
54
+ )
55
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
56
+
57
+ ctx.save_for_backward(kernel)
58
+
59
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
60
+
61
+ ctx.up_x = up_x
62
+ ctx.up_y = up_y
63
+ ctx.down_x = down_x
64
+ ctx.down_y = down_y
65
+ ctx.pad_x0 = pad_x0
66
+ ctx.pad_x1 = pad_x1
67
+ ctx.pad_y0 = pad_y0
68
+ ctx.pad_y1 = pad_y1
69
+ ctx.in_size = in_size
70
+ ctx.out_size = out_size
71
+
72
+ return grad_input
73
+
74
+ @staticmethod
75
+ def backward(ctx, gradgrad_input):
76
+ kernel, = ctx.saved_tensors
77
+
78
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
79
+
80
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
81
+ gradgrad_input,
82
+ kernel,
83
+ ctx.up_x,
84
+ ctx.up_y,
85
+ ctx.down_x,
86
+ ctx.down_y,
87
+ ctx.pad_x0,
88
+ ctx.pad_x1,
89
+ ctx.pad_y0,
90
+ ctx.pad_y1,
91
+ )
92
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
93
+ gradgrad_out = gradgrad_out.view(
94
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
95
+ )
96
+
97
+ return gradgrad_out, None, None, None, None, None, None, None, None
98
+
99
+
100
+ class UpFirDn2d(Function):
101
+ @staticmethod
102
+ def forward(ctx, input, kernel, up, down, pad):
103
+ up_x, up_y = up
104
+ down_x, down_y = down
105
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
106
+
107
+ kernel_h, kernel_w = kernel.shape
108
+ batch, channel, in_h, in_w = input.shape
109
+ ctx.in_size = input.shape
110
+
111
+ input = input.reshape(-1, in_h, in_w, 1)
112
+
113
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
114
+
115
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
116
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
117
+ ctx.out_size = (out_h, out_w)
118
+
119
+ ctx.up = (up_x, up_y)
120
+ ctx.down = (down_x, down_y)
121
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
122
+
123
+ g_pad_x0 = kernel_w - pad_x0 - 1
124
+ g_pad_y0 = kernel_h - pad_y0 - 1
125
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
126
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
127
+
128
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
129
+
130
+ out = upfirdn2d_op.upfirdn2d(
131
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
132
+ )
133
+ # out = out.view(major, out_h, out_w, minor)
134
+ out = out.view(-1, channel, out_h, out_w)
135
+
136
+ return out
137
+
138
+ @staticmethod
139
+ def backward(ctx, grad_output):
140
+ kernel, grad_kernel = ctx.saved_tensors
141
+
142
+ grad_input = UpFirDn2dBackward.apply(
143
+ grad_output,
144
+ kernel,
145
+ grad_kernel,
146
+ ctx.up,
147
+ ctx.down,
148
+ ctx.pad,
149
+ ctx.g_pad,
150
+ ctx.in_size,
151
+ ctx.out_size,
152
+ )
153
+
154
+ return grad_input, None, None, None, None
155
+
156
+
157
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
158
+ global use_custom_kernel
159
+ if use_custom_kernel:
160
+ out = UpFirDn2d.apply(
161
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
162
+ )
163
+ else:
164
+ out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
165
+
166
+ return out
167
+
168
+
169
+ def upfirdn2d_native(
170
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
171
+ ):
172
+ bs, ch, in_h, in_w = input.shape
173
+ minor = 1
174
+ kernel_h, kernel_w = kernel.shape
175
+
176
+ #assert kernel_h == 1 and kernel_w == 1
177
+
178
+ #print("original shape ", input.shape, up_x, down_x, pad_x0, pad_x1)
179
+
180
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
181
+ if up_x > 1 or up_y > 1:
182
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
183
+
184
+ #print("after padding ", out.shape)
185
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
186
+
187
+ #print("after reshaping ", out.shape)
188
+
189
+ if pad_x0 > 0 or pad_x1 > 0 or pad_y0 > 0 or pad_y1 > 0:
190
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
191
+
192
+ #print("after second padding ", out.shape)
193
+ out = out[
194
+ :,
195
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
196
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
197
+ :,
198
+ ]
199
+
200
+ #print("after trimming ", out.shape)
201
+
202
+ out = out.permute(0, 3, 1, 2)
203
+ out = out.reshape(
204
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
205
+ )
206
+
207
+ #print("after reshaping", out.shape)
208
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
209
+ out = F.conv2d(out, w)
210
+
211
+ #print("after conv ", out.shape)
212
+ out = out.reshape(
213
+ -1,
214
+ minor,
215
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
216
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
217
+ )
218
+
219
+ out = out.permute(0, 2, 3, 1)
220
+
221
+ #print("after permuting ", out.shape)
222
+
223
+ out = out[:, ::down_y, ::down_x, :]
224
+
225
+ out = out.view(bs, ch, out.size(1), out.size(2))
226
+
227
+ #print("final shape ", out.shape)
228
+
229
+ return out
230
+
networks/ops.py ADDED
@@ -0,0 +1,490 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from .op import (FusedLeakyReLU, fused_leaky_relu, upfirdn2d)
6
+ import numpy as np
7
+
8
+
9
+ def make_kernel(k):
10
+ k = torch.tensor(k, dtype=torch.float32)
11
+
12
+ if k.ndim == 1:
13
+ k = k[None, :] * k[:, None]
14
+
15
+ k /= k.sum()
16
+
17
+ return k
18
+
19
+
20
+ class Blur(nn.Module):
21
+ def __init__(self, kernel, pad, upsample_factor=1):
22
+ super().__init__()
23
+
24
+ kernel = make_kernel(kernel)
25
+
26
+ if upsample_factor > 1:
27
+ kernel = kernel * (upsample_factor ** 2)
28
+
29
+ self.register_buffer('kernel', kernel)
30
+
31
+ self.pad = pad
32
+
33
+ def forward(self, input):
34
+ return upfirdn2d(input, self.kernel, pad=self.pad)
35
+
36
+
37
+ class ScaledLeakyReLU(nn.Module):
38
+ def __init__(self, negative_slope=0.2):
39
+ super().__init__()
40
+
41
+ self.negative_slope = negative_slope
42
+
43
+ def forward(self, input):
44
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
45
+
46
+
47
+ class EqualConv2d(nn.Module):
48
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
49
+ super().__init__()
50
+
51
+ self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
52
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
53
+
54
+ self.stride = stride
55
+ self.padding = padding
56
+
57
+ if bias:
58
+ self.bias = nn.Parameter(torch.zeros(out_channel))
59
+ else:
60
+ self.bias = None
61
+
62
+ def forward(self, input):
63
+
64
+ return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
65
+
66
+ def __repr__(self):
67
+ return (
68
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
69
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
70
+ )
71
+
72
+
73
+ class EqualLinear(nn.Module):
74
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
75
+ super().__init__()
76
+
77
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
78
+
79
+ bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_dim])
80
+ if bias:
81
+ self.bias = nn.Parameter(torch.from_numpy(bias_init / lr_mul))
82
+ #self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
83
+ else:
84
+ self.bias = None
85
+
86
+ self.activation = activation
87
+
88
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
89
+ self.lr_mul = lr_mul
90
+
91
+ def forward(self, input):
92
+
93
+ if self.activation:
94
+ out = F.linear(input, self.weight * self.scale)
95
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
96
+ else:
97
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
98
+
99
+ return out
100
+
101
+ def __repr__(self):
102
+ return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
103
+
104
+
105
+ class ConvLayer(nn.Sequential):
106
+ def __init__(
107
+ self,
108
+ in_channel,
109
+ out_channel,
110
+ kernel_size,
111
+ downsample=False,
112
+ upsample=False,
113
+ blur_kernel=[1, 3, 3, 1],
114
+ bias=True,
115
+ activate=True,
116
+ ):
117
+ layers = []
118
+
119
+ if downsample:
120
+ factor = 2
121
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
122
+ pad0 = (p + 1) // 2
123
+ pad1 = p // 2
124
+
125
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
126
+
127
+ stride = 2
128
+ self.padding = 0
129
+
130
+ elif upsample:
131
+ layers.append(Upsample(blur_kernel))
132
+
133
+ stride = 1
134
+ self.padding = kernel_size // 2
135
+ else:
136
+ stride = 1
137
+ self.padding = kernel_size // 2
138
+
139
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
140
+ bias=bias and not activate))
141
+
142
+ if activate:
143
+ if bias:
144
+ layers.append(FusedLeakyReLU(out_channel))
145
+ else:
146
+ layers.append(ScaledLeakyReLU(0.2))
147
+
148
+ super().__init__(*layers)
149
+
150
+
151
+ class ResBlock(nn.Module):
152
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
153
+ super().__init__()
154
+
155
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
156
+ self.conv2 = ConvLayer(in_channel, out_channel, 3)
157
+ self.skip = nn.Identity()
158
+
159
+ def forward(self, input):
160
+ out = self.conv1(input)
161
+ out = self.conv2(out)
162
+
163
+ skip = self.skip(input)
164
+ out = (out + skip) / math.sqrt(2)
165
+
166
+ return out
167
+
168
+
169
+ class ResDownBlock(nn.Module):
170
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
171
+ super().__init__()
172
+
173
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
174
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
175
+
176
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
177
+
178
+ def forward(self, input):
179
+ out = self.conv1(input)
180
+ out = self.conv2(out)
181
+
182
+ skip = self.skip(input)
183
+ out = (out + skip) / math.sqrt(2)
184
+
185
+ return out
186
+
187
+
188
+ class ResUpBlock(nn.Module):
189
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
190
+ super().__init__()
191
+
192
+ self.conv1 = ConvLayer(in_channel, out_channel, 3, upsample=True)
193
+ self.conv2 = ConvLayer(out_channel, out_channel, 3, upsample=False)
194
+
195
+ if in_channel != out_channel:
196
+ self.skip = ConvLayer(in_channel, out_channel, 1, upsample=True, activate=False, bias=False)
197
+ else:
198
+ self.skip = torch.nn.Identity()
199
+
200
+ def forward(self, x):
201
+ out = self.conv1(x)
202
+ out = self.conv2(out)
203
+
204
+ skip = self.skip(x)
205
+ out = (out + skip) / math.sqrt(2)
206
+
207
+ return out
208
+
209
+
210
+ class Upsample(nn.Module):
211
+ def __init__(self, kernel, factor=2):
212
+ super().__init__()
213
+
214
+ self.factor = factor
215
+ kernel = make_kernel(kernel) * (factor ** 2)
216
+ self.register_buffer('kernel', kernel)
217
+
218
+ p = kernel.shape[0] - factor
219
+
220
+ pad0 = (p + 1) // 2 + factor - 1
221
+ pad1 = p // 2
222
+
223
+ self.pad = (pad0, pad1)
224
+
225
+ def forward(self, input):
226
+ return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
227
+
228
+
229
+ class Downsample(nn.Module):
230
+ def __init__(self, kernel, factor=2):
231
+ super().__init__()
232
+
233
+ self.factor = factor
234
+ kernel = make_kernel(kernel)
235
+ self.register_buffer('kernel', kernel)
236
+
237
+ p = kernel.shape[0] - factor
238
+
239
+ pad0 = (p + 1) // 2
240
+ pad1 = p // 2
241
+
242
+ self.pad = (pad0, pad1)
243
+
244
+ def forward(self, input):
245
+ return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
246
+
247
+
248
+ class ModulatedConv2d(nn.Module):
249
+ def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False,
250
+ downsample=False, blur_kernel=[1, 3, 3, 1], ):
251
+ super().__init__()
252
+
253
+ self.eps = 1e-8
254
+ self.kernel_size = kernel_size
255
+ self.in_channel = in_channel
256
+ self.out_channel = out_channel
257
+ self.upsample = upsample
258
+ self.downsample = downsample
259
+
260
+ if upsample:
261
+ factor = 2
262
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
263
+ pad0 = (p + 1) // 2 + factor - 1
264
+ pad1 = p // 2 + 1
265
+
266
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
267
+
268
+ if downsample:
269
+ factor = 2
270
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
271
+ pad0 = (p + 1) // 2
272
+ pad1 = p // 2
273
+
274
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
275
+
276
+ fan_in = in_channel * kernel_size ** 2
277
+ self.scale = 1 / math.sqrt(fan_in)
278
+ self.padding = kernel_size // 2
279
+
280
+ self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))
281
+
282
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
283
+ self.demodulate = demodulate
284
+
285
+ def __repr__(self):
286
+ return (
287
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
288
+ f'upsample={self.upsample}, downsample={self.downsample})'
289
+ )
290
+
291
+ def forward(self, input, style):
292
+ batch, in_channel, height, width = input.shape
293
+
294
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
295
+ weight = self.scale * self.weight * style
296
+
297
+ if self.demodulate:
298
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
299
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
300
+
301
+ weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size)
302
+
303
+ if self.upsample:
304
+ input = input.view(1, batch * in_channel, height, width)
305
+ weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size)
306
+ weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel, self.kernel_size,
307
+ self.kernel_size)
308
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
309
+ _, _, height, width = out.shape
310
+ out = out.view(batch, self.out_channel, height, width)
311
+ out = self.blur(out)
312
+ elif self.downsample:
313
+ input = self.blur(input)
314
+ _, _, height, width = input.shape
315
+ input = input.view(1, batch * in_channel, height, width)
316
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
317
+ _, _, height, width = out.shape
318
+ out = out.view(batch, self.out_channel, height, width)
319
+ else:
320
+ input = input.view(1, batch * in_channel, height, width)
321
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
322
+ _, _, height, width = out.shape
323
+ out = out.view(batch, self.out_channel, height, width)
324
+
325
+ return out
326
+
327
+
328
+ class ConstantInput(nn.Module):
329
+ def __init__(self, channel, size=4):
330
+ super().__init__()
331
+
332
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
333
+
334
+ def forward(self, input):
335
+ batch = input.shape[0]
336
+ out = self.input.repeat(batch, 1, 1, 1)
337
+
338
+ return out
339
+
340
+ class StyledConv(nn.Module):
341
+ def __init__(self, in_channel, out_channel, kernel_size, style_dim, upsample=False, demodulate=True):
342
+ super().__init__()
343
+
344
+ self.conv = ModulatedConv2d(
345
+ in_channel,
346
+ out_channel,
347
+ kernel_size,
348
+ style_dim,
349
+ upsample=upsample,
350
+ blur_kernel=[1,3,3,1],
351
+ demodulate=demodulate,
352
+ )
353
+
354
+ self.activate = FusedLeakyReLU(out_channel)
355
+
356
+ def forward(self, input, style):
357
+ out = self.conv(input, style)
358
+ out = self.activate(out)
359
+
360
+ return out
361
+
362
+ class ToRGB(nn.Module):
363
+ def __init__(self, in_channel, upsample=True, blur_kernel=[1, 3, 3, 1]):
364
+ super().__init__()
365
+
366
+ self.upsample = upsample
367
+
368
+ if upsample:
369
+ self.up = Upsample(blur_kernel)
370
+
371
+ self.conv = ConvLayer(in_channel, 3, 1)
372
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
373
+
374
+ def forward(self, input, skip=None):
375
+ out = self.conv(input)
376
+ out = out + self.bias
377
+
378
+ if skip is not None:
379
+ skip = self.up(skip)
380
+ out = out + skip
381
+
382
+ return out
383
+
384
+ class ToFlow(nn.Module):
385
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
386
+ super().__init__()
387
+
388
+ self.upsample = upsample
389
+ if upsample:
390
+ self.up = Upsample(blur_kernel)
391
+
392
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
393
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
394
+
395
+ def forward(self, h, style, feat, skip=None):
396
+
397
+ out = self.conv(h, style)
398
+ out = out + self.bias
399
+
400
+ if skip is not None:
401
+ if self.upsample:
402
+ skip = self.up(skip)
403
+ out = out + skip
404
+
405
+ xs = torch.linspace(-1, 1, out.size(2)).to(h.device)
406
+ xs = torch.meshgrid(xs, xs, indexing='xy')
407
+ xs = torch.stack(xs, 2)
408
+ xs = xs.unsqueeze(0).repeat(out.size(0), 1, 1, 1)
409
+
410
+ sampler = torch.tanh(out[:, 0:2, :, :])
411
+ mask = torch.sigmoid(out[:, 2:3, :, :])
412
+ flow = sampler.permute(0, 2, 3, 1) + xs
413
+
414
+ feat_warp = F.grid_sample(feat, flow, align_corners=True) * mask
415
+ h = feat_warp + (1 - mask) * h
416
+
417
+ #return h, out
418
+ return feat_warp, h, out
419
+
420
+
421
+ class Direction(nn.Module):
422
+ def __init__(self, style_dim, motion_dim):
423
+ super(Direction, self).__init__()
424
+
425
+ self.weight = nn.Parameter(torch.randn(style_dim, motion_dim))
426
+
427
+ def forward(self, input):
428
+ # input: (bs*t) x 512
429
+
430
+ weight = self.weight + 1e-8
431
+ Q, R = torch.linalg.qr(weight)  # orthonormal motion directions [n1, n2, n3, n4] via QR
432
+
433
+ input_diag = torch.diag_embed(input) # alpha, diagonal matrix
434
+ out = torch.matmul(input_diag, Q.T)
435
+ out = torch.sum(out, dim=1)
436
+
437
+ return out
438
+
439
+
440
+
441
+
442
+
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+
452
+
453
+
454
+
455
+
456
+
457
+
458
+
459
+
460
+
461
+
462
+
463
+
464
+
465
+
466
+
467
+
468
+
469
+
470
+
471
+
472
+
473
+
474
+
475
+
476
+
477
+
478
+
479
+
480
+
481
+
482
+
483
+
484
+
485
+
486
+
487
+
488
+
489
+
490
+
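
Most of the blocks above mirror standard StyleGAN2 components (equalized layers, blur-based resampling, modulated convolutions); the pieces specific to this model are ToFlow, which predicts a masked warping field, and Direction, which maps per-direction magnitudes onto QR-orthonormalized motion directions in latent space. A minimal shape sketch for Direction (assuming the fused CUDA ops build, since importing networks.ops pulls them in; style_dim=512 and motion_dim=40 match the generator used by the demo):

import torch
from networks.ops import Direction

direction = Direction(style_dim=512, motion_dim=40)
alpha = torch.randn(2, 40)   # per-sample magnitudes for the 40 motion directions
latent = direction(alpha)    # (2, 512): sum of the magnitude-scaled orthonormal directions
print(latent.shape)
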
utils/__init__.py ADDED
File without changes
utils/data_processing.py ADDED
@@ -0,0 +1,141 @@
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from PIL import Image
5
+ import numpy as np
6
+ import imageio
7
+ from einops import rearrange, repeat
8
+
9
+
10
+ def load_image(img, size):
11
+ # img = Image.open(filename).convert('RGB')
12
+ if not isinstance(img, np.ndarray):
13
+ img = Image.open(img).convert('RGB')
14
+ img = img.resize((size, size))
15
+ img = np.asarray(img)
16
+ img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
17
+
18
+ return img / 255.0
19
+
20
+
21
+ def img_preprocessing(img_path, size):
22
+ img = load_image(img_path, size) # [0, 1]
23
+ img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
24
+ imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
25
+
26
+ return imgs_norm
27
+
28
+
29
+ def resize(img, size):
30
+ transform = torchvision.transforms.Compose([
31
+ torchvision.transforms.Resize(size, antialias=True),
32
+ torchvision.transforms.CenterCrop(size)
33
+ ])
34
+
35
+ return transform(img)
36
+
37
+
38
+ def vid_preprocessing(vid_path, size):
39
+ vid_dict = torchvision.io.read_video(vid_path, pts_unit='sec')
40
+ vid = vid_dict[0].permute(0, 3, 1, 2).unsqueeze(0) # btchw
41
+ fps = vid_dict[2]['video_fps']
42
+ vid_norm = (vid / 255.0 - 0.5) * 2.0 # [-1, 1]
43
+
44
+ vid_norm = torch.cat([
45
+ resize(vid_norm[:, i, :, :, :], size).unsqueeze(1) for i in range(vid.size(1))
46
+ ], dim=1)
47
+
48
+ return vid_norm, fps
49
+
50
+
51
+ def img_denorm(img):
52
+ img = img.clamp(-1, 1).cpu()
53
+ img = (img - img.min()) / (img.max() - img.min())
54
+
55
+ return img
56
+
57
+
58
+ def vid_denorm(vid):
59
+ vid = vid.clamp(-1, 1).cpu()
60
+ vid = (vid - vid.min()) / (vid.max() - vid.min())
61
+
62
+ return vid
63
+
64
+
65
+ def save_img_edit(save_dir, img, img_e):
66
+ # img: BCHW
67
+ # img_e: BCHW
68
+
69
+ output_img_path = os.path.join(save_dir, "img_edit.png")
70
+ output_img_all_path = os.path.join(save_dir, "img_all.png")
71
+
72
+ img = rearrange(img, 'b c h w -> b h w c')
73
+ img_e = rearrange(img_e, 'b c h w -> b h w c')
74
+ img_all = torch.cat([img, img_e], dim=2)
75
+
76
+ img_e_np = (img_denorm(img_e[0]).numpy() * 255).astype('uint8')
77
+ img_all_np = (img_denorm(img_all[0]).numpy() * 255).astype('uint8')
78
+
79
+ imageio.imwrite(output_img_path, img_e_np, quality=8)
80
+ imageio.imwrite(output_img_all_path, img_all_np, quality=8)
81
+
82
+ return
83
+
84
+
85
+ def save_vid_edit(save_dir, vid_d, vid_a, fps):
87
+ # vid_d: BTCHW
88
+ # vid_a: BCTHW
89
+
90
+ output_vid_a_path = os.path.join(save_dir, "vid_animation.mp4")
91
+ output_vid_all_path = os.path.join(save_dir, "vid_all.mp4")
92
+
93
+ vid_d = rearrange(vid_d, 'b t c h w -> b t h w c')
94
+ vid_a = rearrange(vid_a, 'b c t h w -> b t h w c')
95
+ vid_all = torch.cat([vid_d, vid_a], dim=3)
96
+
97
+ vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8')
98
+ vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8')
99
+
100
+ imageio.mimwrite(output_vid_a_path, vid_a_np, fps=fps, codec='libx264', quality=8)
101
+ imageio.mimwrite(output_vid_all_path, vid_all_np, fps=fps, codec='libx264', quality=8)
102
+
103
+ return
104
+
105
+
106
+ def save_animation(save_dir, img_s, vid_d, vid_a, fps):
107
+ # img_s: BCHW
108
+ # vid_d: BTCHW
109
+ # vid_a: BCTHW
110
+
111
+ output_vid_a_path = os.path.join(save_dir, "vid_animation.mp4")
112
+ output_img_e_path = os.path.join(save_dir, "img_edit.png")
113
+ output_vid_all_path = os.path.join(save_dir, "vid_all.mp4")
114
+
115
+ vid_d = rearrange(vid_d, 'b t c h w -> b t h w c')
116
+ vid_a = rearrange(vid_a, 'b c t h w -> b t h w c')
117
+ img_s = repeat(rearrange(img_s, 'b c h w -> b h w c'), 'b h w c -> b t h w c', t=vid_d.size(1))
118
+ vid_all = torch.cat([img_s, vid_d, vid_a], dim=3)
119
+
120
+ vid_a_np = (vid_denorm(vid_a[0]).numpy() * 255).astype('uint8')
121
+ img_e_np = vid_a_np[0]
122
+ vid_all_np = (vid_denorm(vid_all[0]).numpy() * 255).astype('uint8')
123
+
124
+ imageio.mimwrite(output_vid_a_path, vid_a_np, fps=fps, codec='libx264', quality=8)
125
+ imageio.mimwrite(output_vid_all_path, vid_all_np, fps=fps, codec='libx264', quality=8)
126
+ imageio.imwrite(output_img_e_path, img_e_np, quality=8)
127
+
128
+ return
129
+
130
+
131
+ def save_linear_manipulation(save_dir, vid, fps):
132
+ # vid: BCTHW
133
+
134
+ output_vid_path = os.path.join(save_dir, "vid_interpolation.mp4")
135
+
136
+ vid = rearrange(vid, 'b c t h w -> b t h w c')
137
+ vid_np = (vid_denorm(vid[0]).numpy() * 255).astype('uint8')
138
+
139
+ imageio.mimwrite(output_vid_path, vid_np, fps=fps, codec='libx264', quality=8)
140
+
141
+ return
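
The save helpers above assume fixed tensor layouts and values in [-1, 1]; a minimal usage sketch with dummy tensors (shapes only; assumes an imageio backend with libx264 support, e.g. imageio-ffmpeg, and an existing output directory):

import torch
from utils.data_processing import save_animation

img_s = torch.rand(1, 3, 512, 512) * 2 - 1      # source image, BCHW in [-1, 1]
vid_d = torch.rand(1, 16, 3, 512, 512) * 2 - 1  # driving video, BTCHW
vid_a = torch.rand(1, 3, 16, 512, 512) * 2 - 1  # animated output, BCTHW
save_animation("./res_gradio", img_s, vid_d, vid_a, fps=25.0)
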