gheinrich commited on
Commit
c3cd55f
·
verified ·
1 Parent(s): d0099e6

Update vit_patch_generator.py

Browse files
Files changed (1) hide show
  1. vit_patch_generator.py +0 -19
vit_patch_generator.py CHANGED
@@ -116,10 +116,6 @@ class ViTPatchGenerator(nn.Module):
116
  'pos_embed',
117
  ]
118
 
119
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
120
- if self.abs_pos:
121
- self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
122
-
123
  def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
124
  if src_embed.shape != targ_embed.shape:
125
  src_size = int(math.sqrt(src_embed.shape[1]))
@@ -282,18 +278,3 @@ class ViTPatchLinear(nn.Linear):
282
  **factory
283
  )
284
  self.patch_size = patch_size
285
-
286
- def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
287
- if self.bias is not None:
288
- self.bias.data.copy_(state_dict[f'{prefix}bias'])
289
-
290
- chk_weight = state_dict[f'{prefix}weight']
291
- if chk_weight.shape != self.weight.shape:
292
- src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
293
-
294
- assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
295
-
296
- chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
297
- chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
298
- chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
299
- self.weight.data.copy_(chk_weight)
 
116
  'pos_embed',
117
  ]
118
 
 
 
 
 
119
  def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
120
  if src_embed.shape != targ_embed.shape:
121
  src_size = int(math.sqrt(src_embed.shape[1]))
 
278
  **factory
279
  )
280
  self.patch_size = patch_size