Update vit_patch_generator.py
vit_patch_generator.py CHANGED (+0 -19)
@@ -116,10 +116,6 @@ class ViTPatchGenerator(nn.Module):
         'pos_embed',
     ]
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        if self.abs_pos:
-            self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
-
     def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
         if src_embed.shape != targ_embed.shape:
             src_size = int(math.sqrt(src_embed.shape[1]))
@@ -282,18 +278,3 @@ class ViTPatchLinear(nn.Linear):
             **factory
         )
         self.patch_size = patch_size
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        if self.bias is not None:
-            self.bias.data.copy_(state_dict[f'{prefix}bias'])
-
-        chk_weight = state_dict[f'{prefix}weight']
-        if chk_weight.shape != self.weight.shape:
-            src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
-
-            assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
-
-            chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
-            chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
-            chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
-            self.weight.data.copy_(chk_weight)
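For context on what this commit removes: the `ViTPatchLinear._load_from_state_dict` override bicubically resized the checkpoint's flattened patch-embedding weight whenever the checkpoint's patch size differed from the model's, and the `ViTPatchGenerator` override delegated absolute position embeddings to `_load_embed`, which remains in the class. Below is a minimal standalone sketch of the deleted weight-resizing step; the helper name `resize_patch_linear_weight` is hypothetical, and it assumes 3-channel (RGB) patches, exactly as the removed code did.

import math

import torch
import torch.nn.functional as F
from einops import rearrange


def resize_patch_linear_weight(weight: torch.Tensor, patch_size: int) -> torch.Tensor:
    # weight: (embed_dim, 3 * src_patch_size ** 2), the flattened per-patch
    # projection stored by ViTPatchLinear. Infer the checkpoint's patch size.
    src_patch_size = int(math.sqrt(weight.shape[1] // 3))
    assert (src_patch_size ** 2) * 3 == weight.shape[1], \
        'Unable to interpolate non-square patch size'

    # Unflatten each row into a (3, p, p) patch kernel, resize spatially,
    # then flatten back to the linear layer's expected shape.
    weight = rearrange(weight, 'b (c h w) -> b c h w',
                       c=3, h=src_patch_size, w=src_patch_size)
    weight = F.interpolate(weight, size=(patch_size, patch_size),
                           mode='bicubic', align_corners=True, antialias=False)
    return rearrange(weight, 'b c h w -> b (c h w)')

With the override gone, cross-patch-size checkpoints no longer load transparently; if that behavior is still needed, a caller could apply a helper like the one above to the checkpoint's weight tensor before calling `load_state_dict`.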