Luisgust committed
Commit b125e70 · verified · 1 Parent(s): 91ad38b

Create vtoonify/model/stylegan/op_gpu/fused_act.py

vtoonify/model/stylegan/op_gpu/fused_act.py ADDED
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load

# JIT-compile the fused bias + leaky ReLU CUDA extension from the adjacent
# C++/CUDA sources on first import.
module_path = os.path.dirname(__file__)
fused = load(
    "fused",
    sources=[
        os.path.join(module_path, "fused_bias_act.cpp"),
        os.path.join(module_path, "fused_bias_act_kernel.cu"),
    ],
)


class FusedLeakyReLUFunctionBackward(Function):
    @staticmethod
    def forward(ctx, grad_output, out, bias, negative_slope, scale):
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        empty = grad_output.new_empty(0)

        # act=3 selects leaky ReLU in the CUDA kernel; grad=1 computes the
        # first derivative given the saved forward output.
        grad_input = fused.fused_bias_act(
            grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
        )

        # Reduce over the batch and spatial dimensions (keeping the channel
        # dimension) to obtain the bias gradient.
        dim = [0]

        if grad_input.ndim > 2:
            dim += list(range(2, grad_input.ndim))

        if bias:
            grad_bias = grad_input.sum(dim).detach()

        else:
            grad_bias = empty

        return grad_input, grad_bias

    @staticmethod
    def backward(ctx, gradgrad_input, gradgrad_bias):
        out, = ctx.saved_tensors
        gradgrad_out = fused.fused_bias_act(
            gradgrad_input.contiguous(),
            gradgrad_bias,
            out,
            3,
            1,
            ctx.negative_slope,
            ctx.scale,
        )

        return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):
    @staticmethod
    def forward(ctx, input, bias, negative_slope, scale):
        empty = input.new_empty(0)

        ctx.bias = bias is not None

        if bias is None:
            bias = empty

        # act=3 selects leaky ReLU; grad=0 computes the forward pass.
        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
        ctx.save_for_backward(out)
        ctx.negative_slope = negative_slope
        ctx.scale = scale

        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors

        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
            grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    # On CPU, fall back to a plain PyTorch implementation: add the bias,
    # apply leaky ReLU, and rescale. On GPU, dispatch to the fused kernel.
    if input.device.type == "cpu":
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (
                F.leaky_relu(
                    input + bias.view(1, bias.shape[0], *rest_dim),
                    negative_slope=negative_slope,
                )
                * scale
            )

        else:
            return F.leaky_relu(input, negative_slope=negative_slope) * scale

    else:
        return FusedLeakyReLUFunction.apply(input.contiguous(), bias, negative_slope, scale)
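For context, a minimal usage sketch of the module added above. It assumes the CUDA sources compile on import, that a GPU is available, and that the file is importable under the path shown in the commit; the channel count and tensor shapes are illustrative, not taken from the original repository.

import torch
from vtoonify.model.stylegan.op_gpu.fused_act import FusedLeakyReLU, fused_leaky_relu

# Module form: learnable per-channel bias of shape (512,), then leaky ReLU
# scaled by sqrt(2).
act = FusedLeakyReLU(channel=512).cuda()
x = torch.randn(4, 512, 32, 32, device="cuda")  # e.g. a StyleGAN feature map
y = act(x)

# Functional form: on CPU tensors it falls back to F.leaky_relu instead of
# the fused CUDA kernel.
z = fused_leaky_relu(torch.randn(4, 512, 32, 32), bias=None)

The scale factor of sqrt(2) compensates for the variance reduction introduced by the activation, which is why the module multiplies the activation output by `scale` rather than returning it directly.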