azhongai666666 committed · Commit 2532eca · verified · 1 Parent(s): 6ab9bc6

Create model.py

Files changed (1): model.py (+509, -0)

model.py ADDED
@@ -0,0 +1,509 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.models.layers import to_2tuple, trunc_normal_

# Note: the imports below are unused in this module; they appear to be
# intended for the demo app that drives the model.
import torchvision.transforms as transforms
from torchvision import models

import gradio as gr
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt

class RLN(nn.Module):
    r"""Revised LayerNorm: normalizes each sample over (C, H, W) and also
    returns per-channel rescale/rebias maps derived from the statistics,
    which the caller re-applies after the attention/MLP branch."""
    def __init__(self, dim, eps=1e-5, detach_grad=False):
        super(RLN, self).__init__()
        self.eps = eps
        self.detach_grad = detach_grad

        self.weight = nn.Parameter(torch.ones((1, dim, 1, 1)))
        self.bias = nn.Parameter(torch.zeros((1, dim, 1, 1)))

        # 1x1 convs map the scalar std/mean statistics to per-channel
        # rescale and rebias terms
        self.meta1 = nn.Conv2d(1, dim, 1)
        self.meta2 = nn.Conv2d(1, dim, 1)

        trunc_normal_(self.meta1.weight, std=.02)
        nn.init.constant_(self.meta1.bias, 1)

        trunc_normal_(self.meta2.weight, std=.02)
        nn.init.constant_(self.meta2.bias, 0)

    def forward(self, input):
        mean = torch.mean(input, dim=(1, 2, 3), keepdim=True)
        std = torch.sqrt((input - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.eps)

        normalized_input = (input - mean) / std

        if self.detach_grad:
            rescale, rebias = self.meta1(std.detach()), self.meta2(mean.detach())
        else:
            rescale, rebias = self.meta1(std), self.meta2(mean)

        out = normalized_input * self.weight + self.bias
        return out, rescale, rebias

class Mlp(nn.Module):
    def __init__(self, network_depth, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.network_depth = network_depth

        self.mlp = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1),
            nn.ReLU(True),
            nn.Conv2d(hidden_features, out_features, 1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            # depth-scaled init: shrink the std by (8 * network_depth)^(-1/4)
            # to keep deep residual stacks stable
            gain = (8 * self.network_depth) ** (-1/4)
            fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
            std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
            trunc_normal_(m.weight, std=std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.mlp(x)

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows*B, window_size**2, C)
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, C)
    return windows


def window_reverse(windows, window_size, H, W):
    # inverse of window_partition: (num_windows*B, window_size**2, C) -> (B, H, W, C)
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

def get_relative_positions(window_size):
    coords_h = torch.arange(window_size)
    coords_w = torch.arange(window_size)

    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_positions = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww

    relative_positions = relative_positions.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_positions_log = torch.sign(relative_positions) * torch.log(1. + relative_positions.abs())

    return relative_positions_log

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # continuous relative position bias: a small MLP maps log-spaced
        # relative coordinates to one bias value per head
        relative_positions = get_relative_positions(self.window_size)
        self.register_buffer("relative_positions", relative_positions)
        self.meta = nn.Sequential(
            nn.Linear(2, 256, bias=True),
            nn.ReLU(True),
            nn.Linear(256, num_heads, bias=True)
        )

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, qkv):
        B_, N, _ = qkv.shape

        qkv = qkv.reshape(B_, N, 3, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.meta(self.relative_positions)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        attn = self.softmax(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)
        return x

class Attention(nn.Module):
    def __init__(self, network_depth, dim, num_heads, window_size, shift_size, use_attn=False, conv_type=None):
        super().__init__()
        self.dim = dim
        self.head_dim = int(dim // num_heads)
        self.num_heads = num_heads

        self.window_size = window_size
        self.shift_size = shift_size

        self.network_depth = network_depth
        self.use_attn = use_attn
        self.conv_type = conv_type

        if self.conv_type == 'Conv':
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect'),
                nn.ReLU(True),
                nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect')
            )

        if self.conv_type == 'DWConv':
            self.conv = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim, padding_mode='reflect')

        if self.conv_type == 'DWConv' or self.use_attn:
            self.V = nn.Conv2d(dim, dim, 1)
            self.proj = nn.Conv2d(dim, dim, 1)

        if self.use_attn:
            self.QK = nn.Conv2d(dim, dim * 2, 1)
            self.attn = WindowAttention(dim, window_size, num_heads)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            w_shape = m.weight.shape

            if w_shape[0] == self.dim * 2:  # QK projection: plain Xavier-style init
                fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
                std = math.sqrt(2.0 / float(fan_in + fan_out))
                trunc_normal_(m.weight, std=std)
            else:  # all other convs: depth-scaled init, as in Mlp
                gain = (8 * self.network_depth) ** (-1/4)
                fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
                std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
                trunc_normal_(m.weight, std=std)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def check_size(self, x, shift=False):
        # reflect-pad so H and W become multiples of window_size; when
        # shifting, also offset by shift_size (shifted-window attention)
        _, _, h, w = x.size()
        mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
        mod_pad_w = (self.window_size - w % self.window_size) % self.window_size

        if shift:
            x = F.pad(x, (self.shift_size, (self.window_size-self.shift_size+mod_pad_w) % self.window_size,
                          self.shift_size, (self.window_size-self.shift_size+mod_pad_h) % self.window_size), mode='reflect')
        else:
            x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward(self, X):
        B, C, H, W = X.shape

        if self.conv_type == 'DWConv' or self.use_attn:
            V = self.V(X)

        if self.use_attn:
            QK = self.QK(X)
            QKV = torch.cat([QK, V], dim=1)

            # shift (via padding rather than cyclic roll)
            shifted_QKV = self.check_size(QKV, self.shift_size > 0)
            Ht, Wt = shifted_QKV.shape[2:]

            # partition windows
            shifted_QKV = shifted_QKV.permute(0, 2, 3, 1)
            qkv = window_partition(shifted_QKV, self.window_size)  # nW*B, window_size**2, C

            attn_windows = self.attn(qkv)

            # merge windows
            shifted_out = window_reverse(attn_windows, self.window_size, Ht, Wt)  # B H' W' C

            # undo the shift/padding by cropping back to the input size
            out = shifted_out[:, self.shift_size:(self.shift_size+H), self.shift_size:(self.shift_size+W), :]
            attn_out = out.permute(0, 3, 1, 2)

            if self.conv_type in ['Conv', 'DWConv']:
                conv_out = self.conv(V)
                out = self.proj(conv_out + attn_out)
            else:
                out = self.proj(attn_out)

        else:
            if self.conv_type == 'Conv':
                out = self.conv(X)  # no attention and use conv, no projection
            elif self.conv_type == 'DWConv':
                out = self.proj(self.conv(V))

        return out

class TransformerBlock(nn.Module):
    def __init__(self, network_depth, dim, num_heads, mlp_ratio=4.,
                 norm_layer=nn.LayerNorm, mlp_norm=False,
                 window_size=8, shift_size=0, use_attn=True, conv_type=None):
        super().__init__()
        self.use_attn = use_attn
        self.mlp_norm = mlp_norm

        # NOTE: when use_attn is True, norm_layer must behave like RLN and
        # return (x, rescale, rebias); plain nn.LayerNorm will not work here
        self.norm1 = norm_layer(dim) if use_attn else nn.Identity()
        self.attn = Attention(network_depth, dim, num_heads=num_heads, window_size=window_size,
                              shift_size=shift_size, use_attn=use_attn, conv_type=conv_type)

        self.norm2 = norm_layer(dim) if use_attn and mlp_norm else nn.Identity()
        self.mlp = Mlp(network_depth, dim, hidden_features=int(dim * mlp_ratio))

    def forward(self, x):
        identity = x
        if self.use_attn: x, rescale, rebias = self.norm1(x)
        x = self.attn(x)
        if self.use_attn: x = x * rescale + rebias
        x = identity + x

        identity = x
        if self.use_attn and self.mlp_norm: x, rescale, rebias = self.norm2(x)
        x = self.mlp(x)
        if self.use_attn and self.mlp_norm: x = x * rescale + rebias
        x = identity + x
        return x

class BasicLayer(nn.Module):
    def __init__(self, network_depth, dim, depth, num_heads, mlp_ratio=4.,
                 norm_layer=nn.LayerNorm, window_size=8,
                 attn_ratio=0., attn_loc='last', conv_type=None):
        super().__init__()
        self.dim = dim
        self.depth = depth

        # attn_ratio sets the fraction of blocks that use window attention;
        # the remaining blocks rely on their convolution branch only
        attn_depth = attn_ratio * depth

        if attn_loc == 'last':
            use_attns = [i >= depth-attn_depth for i in range(depth)]
        elif attn_loc == 'first':
            use_attns = [i < attn_depth for i in range(depth)]
        elif attn_loc == 'middle':
            use_attns = [i >= (depth-attn_depth)//2 and i < (depth+attn_depth)//2 for i in range(depth)]

        # build blocks; even blocks use unshifted windows, odd blocks use
        # windows shifted by half the window size (Swin-style)
        self.blocks = nn.ModuleList([
            TransformerBlock(network_depth=network_depth,
                             dim=dim,
                             num_heads=num_heads,
                             mlp_ratio=mlp_ratio,
                             norm_layer=norm_layer,
                             window_size=window_size,
                             shift_size=0 if (i % 2 == 0) else window_size // 2,
                             use_attn=use_attns[i], conv_type=conv_type)
            for i in range(depth)])

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

class PatchEmbed(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None):
        super().__init__()
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if kernel_size is None:
            kernel_size = patch_size

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=patch_size,
                              padding=(kernel_size-patch_size+1)//2, padding_mode='reflect')

    def forward(self, x):
        x = self.proj(x)
        return x

class PatchUnEmbed(nn.Module):
    def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None):
        super().__init__()
        self.out_chans = out_chans
        self.embed_dim = embed_dim

        if kernel_size is None:
            kernel_size = 1

        self.proj = nn.Sequential(
            nn.Conv2d(embed_dim, out_chans*patch_size**2, kernel_size=kernel_size,
                      padding=kernel_size//2, padding_mode='reflect'),
            nn.PixelShuffle(patch_size)
        )

    def forward(self, x):
        x = self.proj(x)
        return x

class SKFusion(nn.Module):
    def __init__(self, dim, height=2, reduction=8):
        super(SKFusion, self).__init__()

        self.height = height
        d = max(int(dim/reduction), 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, d, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(d, dim*height, 1, bias=False)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, in_feats):
        # selective-kernel-style fusion: softmax-weighted sum of `height`
        # feature maps, with weights predicted from their pooled sum
        B, C, H, W = in_feats[0].shape

        in_feats = torch.cat(in_feats, dim=1)
        in_feats = in_feats.view(B, self.height, C, H, W)

        feats_sum = torch.sum(in_feats, dim=1)
        attn = self.mlp(self.avg_pool(feats_sum))
        attn = self.softmax(attn.view(B, self.height, C, 1, 1))

        out = torch.sum(in_feats*attn, dim=1)
        return out

class DehazeFormer(nn.Module):
    def __init__(self, in_chans=3, out_chans=4, window_size=8,
                 embed_dims=[24, 48, 96, 48, 24],
                 mlp_ratios=[2., 4., 4., 2., 2.],
                 depths=[16, 16, 16, 8, 8],
                 num_heads=[2, 4, 6, 1, 1],
                 attn_ratio=[1/4, 1/2, 3/4, 0, 0],
                 conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
                 norm_layer=[RLN, RLN, RLN, RLN, RLN]):
        super(DehazeFormer, self).__init__()

        # setting
        self.patch_size = 4
        self.window_size = window_size
        self.mlp_ratios = mlp_ratios

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3)

        # backbone: 5-stage U-Net-like encoder-decoder with SK skip fusion
        self.layer1 = BasicLayer(network_depth=sum(depths), dim=embed_dims[0], depth=depths[0],
                                 num_heads=num_heads[0], mlp_ratio=mlp_ratios[0],
                                 norm_layer=norm_layer[0], window_size=window_size,
                                 attn_ratio=attn_ratio[0], attn_loc='last', conv_type=conv_type[0])

        self.patch_merge1 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])

        self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)

        self.layer2 = BasicLayer(network_depth=sum(depths), dim=embed_dims[1], depth=depths[1],
                                 num_heads=num_heads[1], mlp_ratio=mlp_ratios[1],
                                 norm_layer=norm_layer[1], window_size=window_size,
                                 attn_ratio=attn_ratio[1], attn_loc='last', conv_type=conv_type[1])

        self.patch_merge2 = PatchEmbed(
            patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])

        self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)

        self.layer3 = BasicLayer(network_depth=sum(depths), dim=embed_dims[2], depth=depths[2],
                                 num_heads=num_heads[2], mlp_ratio=mlp_ratios[2],
                                 norm_layer=norm_layer[2], window_size=window_size,
                                 attn_ratio=attn_ratio[2], attn_loc='last', conv_type=conv_type[2])

        self.patch_split1 = PatchUnEmbed(
            patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2])

        assert embed_dims[1] == embed_dims[3]
        self.fusion1 = SKFusion(embed_dims[3])

        self.layer4 = BasicLayer(network_depth=sum(depths), dim=embed_dims[3], depth=depths[3],
                                 num_heads=num_heads[3], mlp_ratio=mlp_ratios[3],
                                 norm_layer=norm_layer[3], window_size=window_size,
                                 attn_ratio=attn_ratio[3], attn_loc='last', conv_type=conv_type[3])

        self.patch_split2 = PatchUnEmbed(
            patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3])

        assert embed_dims[0] == embed_dims[4]
        self.fusion2 = SKFusion(embed_dims[4])

        self.layer5 = BasicLayer(network_depth=sum(depths), dim=embed_dims[4], depth=depths[4],
                                 num_heads=num_heads[4], mlp_ratio=mlp_ratios[4],
                                 norm_layer=norm_layer[4], window_size=window_size,
                                 attn_ratio=attn_ratio[4], attn_loc='last', conv_type=conv_type[4])

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3)

    def check_image_size(self, x):
        # NOTE: for I2I test; pad so H and W are multiples of patch_size
        _, _, h, w = x.size()
        mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
        mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.layer1(x)
        skip1 = x

        x = self.patch_merge1(x)
        x = self.layer2(x)
        skip2 = x

        x = self.patch_merge2(x)
        x = self.layer3(x)
        x = self.patch_split1(x)

        x = self.fusion1([x, self.skip2(skip2)]) + x
        x = self.layer4(x)
        x = self.patch_split2(x)

        x = self.fusion2([x, self.skip1(skip1)]) + x
        x = self.layer5(x)
        x = self.patch_unembed(x)
        return x

    def forward(self, x):
        H, W = x.shape[2:]
        x = self.check_image_size(x)

        feat = self.forward_features(x)
        K, B = torch.split(feat, (1, 3), dim=1)

        # soft reconstruction: predict a one-channel gain K and a
        # three-channel bias B, then output K * x - B + x
        x = K * x - B + x
        x = x[:, :, :H, :W]
        return x

def dehazeformer_t():
    return DehazeFormer(
        embed_dims=[24, 48, 96, 48, 24],
        mlp_ratios=[2., 4., 4., 2., 2.],
        depths=[4, 4, 4, 2, 2],
        num_heads=[2, 4, 6, 1, 1],
        attn_ratio=[0, 1/2, 1, 0, 0],
        conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])
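

if __name__ == '__main__':
    # Minimal smoke-test sketch (an addition for illustration, not part of
    # the committed file): build the tiny variant and run a dummy forward
    # pass on CPU. The input size is deliberately not a multiple of the
    # patch/window size, to exercise the reflect padding in
    # check_image_size() and Attention.check_size().
    model = dehazeformer_t()
    model.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 250, 250)
        y = model(x)
    assert y.shape == x.shape  # output is cropped back to the input size
    print(y.shape)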