Image Feature Extraction
Transformers
Safetensors
feature-extraction
custom_code
adaptor_generic.py CHANGED
@@ -19,9 +19,23 @@ class GenericAdaptor(AdaptorBase):
     def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
         super().__init__()
 
+        extra_args = dict()
+        ups = None
+        ups_rank = None
+        if adaptor_config is not None:
+            ups = adaptor_config.get('fd_upsample_factor', None)
+            ups_rank = adaptor_config.get('fd_upsample_rank', None)
+        elif mlp_config is not None:
+            ups = mlp_config["feature"].get('upsample_factor', None)
+            ups_rank = mlp_config["feature"].get('upsample_rank', None)
+        if ups is not None:
+            extra_args['upsample_factor'] = ups
+            extra_args['upsample_rank'] = ups_rank
+
         if state is not None:
-            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.')
-            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.')
+            spectral_heads = getattr(main_config, 'spectral_heads', False)
+            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)
+            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)
         else:
             assert mlp_config is not None, "Config must not be None if state is None"
 
@@ -38,16 +52,17 @@ class GenericAdaptor(AdaptorBase):
                 mlp_config["feature"]["hidden_dim"],
                 mlp_config["feature"]["output_dim"],
                 mlp_config["feature"]["num_inner"],
+                **extra_args
             )
 
     def forward(self, input: AdaptorInput) -> RadioOutput:
         # Convert input'd type to the type of the first parameter of the adaptor.
         first_param = next(self.parameters())
         summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)
-        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype)).to(dtype=input.features.dtype)
+        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)
 
         if input.feature_fmt == 'NCHW':
-            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size, input.images.shape[-1] // input.patch_size, feat.shape[2])
+            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])
                     .permute(0, 3, 1, 2)
             )
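To make the new NCHW branch concrete, here is a standalone shape sketch. The numbers are made up; upsample_factor stands in for the attribute read from self.feat_mlp above.

import torch

B, H_img, W_img, patch_size, upsample_factor, C = 2, 224, 224, 16, 2, 32
rows = H_img // patch_size * upsample_factor   # 28
cols = W_img // patch_size * upsample_factor   # 28

feat = torch.randn(B, rows * cols, C)          # tokens as produced by the feature MLP
feat = (feat.reshape(B, rows, cols, C)
            .permute(0, 3, 1, 2))              # -> (B, C, rows, cols), i.e. NCHW
assert feat.shape == (B, C, rows, cols)
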
adaptor_mlp.py CHANGED
@@ -6,7 +6,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 import math
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 from torch import nn
@@ -14,6 +14,8 @@ from torch import nn
 from einops import rearrange
 from timm.models.vision_transformer import Block
 
+from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
+
 
 class MLP(nn.Module):
     def __init__(self, input_size: int, hidden_size: int, output_size: int,
@@ -51,6 +53,8 @@ class MLP2(nn.Module):
                  num_inner: int = 0,
                  pre_norm: bool = False, device: torch.device = None,
                  upsample_factor: int = 1,
+                 upsample_rank: int = None,
+                 from_config: bool = False,
                  **kwargs):
         super().__init__()
 
@@ -60,10 +64,12 @@ class MLP2(nn.Module):
         ) if pre_norm else nn.Identity()
 
         self.upsample_factor = upsample_factor
-        self._real_output_dim = output_size
+        sq_ups = upsample_factor ** 2
+
+        self._real_output_dim = output_size // sq_ups
 
-        hidden_size *= upsample_factor
-        output_size *= (upsample_factor ** 2)
+        # hidden_size *= upsample_factor
+        # output_size *= (upsample_factor ** 2)
 
         self.fc1 = nn.Linear(input_size, hidden_size, device=device)
 
@@ -82,7 +88,7 @@ class MLP2(nn.Module):
             nn.Linear(hidden_size, output_size, device=device),
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
         x = self.pre_norm(x)
         x = self.fc1(x)
         for block in self.blocks:
@@ -90,8 +96,12 @@ class MLP2(nn.Module):
         x = self.final(x)
 
         if self.upsample_factor > 1:
-            h = w = int(math.sqrt(x.shape[1]))
-            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (u1 h u2 w) c',
+            if images is None:
+                raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
+            if patch_size is None:
+                raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
+            h, w = tuple(d // patch_size for d in images.shape[-2:])
+            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                           h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
                           c=self._real_output_dim)
 
@@ -113,20 +123,22 @@ def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
     return state
 
 
-def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
     state = strip_prefix(state, prefix)
 
+    weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'
+
     if version == 'v1':
-        hidden_dim, input_dim = state['fc1.weight'].shape
-        output_dim = state['fc2.weight'].shape[0]
+        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
+        output_dim = state[f'fc2.{weight_suffix}'].shape[0]
 
         for num_inner in range(1000):
            k = f'inner.{num_inner}.0.weight'
            if k not in state:
                break
     elif version == 'v2':
-        hidden_dim, input_dim = state['fc1.weight'].shape
-        output_dim = state['final.2.weight'].shape[0]
+        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
+        output_dim = state[f'final.2.{weight_suffix}'].shape[0]
 
         for num_inner in range(1000):
             k = f'blocks.{num_inner}.0.weight'
@@ -138,19 +150,25 @@ def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix
     return input_dim, hidden_dim, output_dim, num_inner
 
 
-def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int):
-    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner)
+def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):
+    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
 
     return ret
 
 
-def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):
    state = strip_prefix(state, prefix)
 
-    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state)
-
-    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner)
+    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)
+
+    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)
+
+    if spectral_weights:
+        enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
 
     ret.load_state_dict(state)
 
+    if spectral_weights:
+        disable_spectral_reparam(ret)
+
    return ret
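
The upsampling path in MLP2.forward is a token-space pixel shuffle: each input token expands into a u1 x u2 block of output tokens. A standalone sanity check of the einops pattern, with toy sizes and independent of the repo code:

import torch
from einops import rearrange

b, h, w, u, c = 1, 2, 2, 2, 3
x = torch.arange(b * h * w * u * u * c, dtype=torch.float32).reshape(b, h * w, u * u * c)

# Each of the h*w input tokens becomes a u x u block of c-dimensional output tokens.
y = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c', h=h, w=w, u1=u, u2=u, c=c)
assert y.shape == (b, h * u * w * u, c)
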
common.py CHANGED
@@ -94,6 +94,15 @@ RESOURCE_MAP = {
         max_resolution=2048,
         preferred_resolution=Resolution(512, 512),
     ),
+    # C-RADIO
+    "c-radio_v3-l": RadioResource(
+        # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L
+        # and accept the license terms.
+        "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true",
+        patch_size=16,
+        max_resolution=2048,
+        preferred_resolution=Resolution(512, 512),
+    ),
 }
 
 DEFAULT_VERSION = "radio_v2.5-h"
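
As the NOTE in the new entry says, this checkpoint is meant to be consumed through the transformers API rather than TorchHub. A sketch of that route, assuming access to the gated nvidia/C-RADIOv3-L repository has been granted; the call signature follows the pattern documented on the RADIO model cards, and the exact preprocessing is described there:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("nvidia/C-RADIOv3-L", trust_remote_code=True).eval()

x = torch.rand(1, 3, 512, 512)      # values in [0, 1], resolution divisible by the patch size
with torch.no_grad():
    summary, features = model(x)    # pooled summary vector and per-patch spatial features
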
config.json CHANGED
@@ -224,7 +224,7 @@
     768
   ],
   "torch_dtype": "float32",
-  "transformers_version": "4.47.0.dev0",
+  "transformers_version": "4.51.2",
   "version": "c-radio_v2.5-g",
   "vitdet_window_size": null
 }
dual_hybrid_vit.py ADDED
@@ -0,0 +1,213 @@
+from logging import getLogger
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from timm.models import register_model
+from timm.models import vision_transformer as tvit
+from timm.models import convnext as tconv
+
+from einops import rearrange
+
+from . import extra_timm_models as et
+
+
+class Fuser(nn.Module):
+    def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
+        super().__init__()
+        self.gated = gated
+
+        mid_dim = max(src_dim, tgt_dim) * 2
+
+        self.fwd = nn.Sequential(
+            nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),
+            nn.GELU(),
+            nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1),
+        )
+
+    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
+        if src.ndim == 3:
+            shape = tgt.shape[-2:]
+        else:
+            shape = src.shape[-2:]
+
+        nd = shape[0] * shape[1]
+
+        if src.ndim == 3:
+            src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape)
+
+        if tgt.ndim == 3:
+            tgt_pre = tgt[:, :-nd]
+            tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape)
+        else:
+            tgt_pre = None
+
+        pred = self.fwd(src)
+
+        if self.gated:
+            g, pred = torch.chunk(pred, 2, dim=1)
+
+            g = F.sigmoid(g)
+
+            pred = g * pred
+
+        tgt = tgt + pred
+
+        if tgt_pre is not None:
+            tgt = rearrange(tgt, 'b c h w -> b (h w) c')
+            tgt = torch.cat([tgt_pre, tgt], dim=1)
+
+        return tgt
+
+
+class AttnDownsample(nn.Module):
+    def __init__(self, dim: int, window_size: int, num_heads: int = 16):
+        super().__init__()
+        self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)
+        self.kv = nn.Linear(dim, dim * 2)
+        self.proj = nn.Linear(dim, dim)
+        self.window_size = window_size
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+
+    def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:
+        ntok = twod_shape[0] * twod_shape[1]
+        x_pre = x[:, :-ntok]
+
+        B = x.shape[0]
+        ds_hw = tuple(s // self.window_size for s in twod_shape)
+
+        x_spat = rearrange(
+            x[:, -ntok:],
+            'b (h d1 w d2) c -> (b h w) (d1 d2) c',
+            h=ds_hw[0], w=ds_hw[1],
+            d1=self.window_size, d2=self.window_size,
+        )
+
+        B, N, C = x_spat.shape
+
+        k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+
+        q = (self.q * self.scale).expand(B, -1, -1, -1)
+        attn = q @ k.transpose(-2, -1)
+        attn = F.softmax(attn, dim=-1)
+        x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, C)
+        x = self.proj(x)
+
+        x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])
+
+        x = torch.cat([x_pre, x], dim=1)
+        return x
+
+
+class HybridModel(nn.Module):
+    def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,
+                 concatenate: bool = False, **kwargs):
+        super().__init__()
+        self.conv = conv
+        self.vit = vit
+        self.concatenate = concatenate
+
+        conv.stages = nn.ModuleList(conv.stages)
+        vit.blocks = nn.ModuleList(vit.blocks)
+
+        self._half_vit_idx = len(vit.blocks) // 2 + 1
+
+        self._half_conv_idx = None
+        x = torch.empty(1, 3, 256, 256)
+        x = self.conv.stem(x)
+        for i in range(len(conv.stages)):
+            x = conv.stages[i](x)
+            if self._half_conv_idx is None and x.shape[-2:] == (16, 16):
+                self._half_conv_idx = i + 1
+                half_conv_dim = x.shape[1]
+        final_conv_dim = x.shape[1]
+
+        self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)
+        self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)
+        self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)
+
+        embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)
+        if not concatenate:
+            self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)
+        self.final_block = tvit.Block(embed_dim, num_heads=16)
+
+        self.embed_dim = embed_dim
+
+    @property
+    def patch_size(self):
+        return 32
+
+    @property
+    def no_fsdp_wrap_types(self):
+        return {tvit.VisionTransformer, tconv.ConvNeXt}
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_features(x)
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        y_vit = self.vit.patch_generator(x)
+
+        for i in range(self._half_vit_idx):
+            y_vit = self.vit.blocks[i](y_vit)
+
+        y_conv = self.conv.stem(x)
+        for i in range(self._half_conv_idx):
+            y_conv = self.conv.stages[i](y_conv)
+
+        y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)
+
+        y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])
+
+        for i in range(self._half_vit_idx, len(self.vit.blocks)):
+            y_vit = self.vit.blocks[i](y_vit)
+
+        for i in range(self._half_conv_idx, len(self.conv.stages)):
+            y_conv = self.conv.stages[i](y_conv)
+
+        if self.concatenate:
+            y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')
+            # Average pool across the board, and replicate for each cls/register token
+            conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)
+            y_conv = torch.cat([conv_summary, y_conv], dim=1)
+            y = torch.cat([y_vit, y_conv], dim=2)
+        else:
+            y = self.final_fuse(y_conv, y_vit)
+        y = self.final_block(y)
+
+        summary = y[:, :self.vit.patch_generator.num_cls_tokens]
+        features = y[:, self.vit.patch_generator.num_cls_patches:]
+
+        return summary, features
+
+
+@register_model
+def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
+    cfg = dict(num_classes=0, **kwargs)
+    conv = tconv.convnextv2_base(pretrained=pretrained, **cfg)
+    vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
+
+    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
+
+
+@register_model
+def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
+    cfg = dict(num_classes=0, **kwargs)
+    conv = tconv.convnextv2_large(pretrained=pretrained, **cfg)
+    vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
+
+    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
+
+
+@register_model
+def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
+    cfg = dict(num_classes=0, **kwargs)
+    conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg)
+    vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)
+
+    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
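
A minimal construction sketch for the new hybrid backbone. Assumptions: a timm version providing the referenced ConvNeXt-V2 and ViT entrypoints, and that this module has been imported so hybrid_base is in timm's registry. Running an actual forward pass additionally requires the patch_generator that enable_cpe() attaches to the inner ViT (see enable_cpe_support.py below).

import timm

model = timm.create_model('hybrid_base', concatenate=False)
print(model.embed_dim)    # ViT width (768 for hybrid_base when not concatenating)
print(model.patch_size)   # effective output stride of 32
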
enable_cpe_support.py CHANGED
@@ -19,6 +19,7 @@ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermedi
 from .extra_models import DinoWrapper
 from .vit_patch_generator import ViTPatchGenerator
 from .forward_intermediates import forward_intermediates
+from .dual_hybrid_vit import HybridModel
 
 
 def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +162,9 @@ def enable_cpe(model: nn.Module,
               ):
     if isinstance(model, VisionTransformer):
         _enable_cpe_for_timm_vit(model, *args, **kwargs)
-    elif True: # isinstance(model, DinoWrapper):
+    elif isinstance(model, DinoWrapper):
         _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs)
+    elif isinstance(model, HybridModel):
+        _enable_cpe_for_timm_vit(model.vit, *args, **kwargs)
     else:
         raise ValueError(f'CPE not supported for this model type: {type(model)}')
enable_spectral_reparam.py CHANGED
@@ -155,7 +155,7 @@ def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
             return True
 
         p_name = f'{name}.parametrizations'
-        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name))
+        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))
         return is_prm
 
     def parametrize_linear(linear: nn.Linear):
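
For context on the keys this guidance check inspects, here is a sketch using PyTorch's built-in spectral_norm parametrization, which stores the raw weight under 'parametrizations.weight.original'. The repo's enable_spectral_reparam registers its own parametrization that additionally keeps a buffer whose key ends in '_sn_version'; the narrowed predicate above now requires such a key before treating a module as spectrally parametrized.

import torch
from torch import nn
from torch.nn.utils.parametrizations import spectral_norm

lin = spectral_norm(nn.Linear(8, 4))
print(sorted(lin.state_dict().keys()))
# keys include 'bias', 'parametrizations.weight.original' and the power-iteration buffers
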
extra_timm_models.py CHANGED
@@ -6,7 +6,12 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+import math
+import warnings
+
+import torch
 from torch import nn
+from torch.nn import functional as F
 
 from timm.models import register_model
 from timm.models.vision_transformer import (
@@ -17,6 +22,7 @@ from timm.models.vision_transformer import (
     LayerScale as TIMMLayerScale,
 )
 
+# Import these to also register them
 from . import dinov2_arch
 
 
@@ -24,7 +30,7 @@ from . import dinov2_arch
 def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Tiny (Vit-Ti/16)
     """
-    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3, weight_init='skip')
+    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)
     model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
@@ -33,7 +39,7 @@ def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
     """
-    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, weight_init='skip')
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)
     model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
@@ -43,16 +49,44 @@ def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, weight_init='skip')
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)
     model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
+@register_model
+def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    name = 'vit_large_patch14_reg4_dinov2'
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
+    )
+    model = _create_vision_transformer(name, pretrained=pretrained, **dict(model_args, **kwargs))
+
+    return model
+
 @register_model
 def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16, weight_init='skip')
+    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
     if pretrained:
         # There is no pretrained version of ViT-H/16, but we can adapt a ViT-H/14 for this purpose
         model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs))
@@ -65,7 +99,7 @@ def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model = vit_huge_patch16_224(pretrained=pretrained, weight_init='skip', **kwargs)
+    model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)
 
     for m in model.modules():
         if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm):
@@ -75,17 +109,19 @@ def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransforme
 
 
 @register_model
-def vit_giant_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24, weight_init='skip')
+    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)
     model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
+    if scaled_ln:
+        _apply_scaled_ln(model)
     return model
 
 
 @register_model
 def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
-    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6, weight_init='skip')
+    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
     model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
     return model
 
@@ -112,3 +148,59 @@ def _patch_layer_scale(model: VisionTransformer):
             mod.ls2 = replace_ls(mod.ls2)
     pass
 
+
+class ScaledLayerNorm(nn.LayerNorm):
+    '''
+    https://arxiv.org/pdf/2502.05795v1
+    '''
+    def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):
+        super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)
+        self.load_state_dict(ln_base.state_dict())
+        self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)
+
+    def forward(self, x):
+        y = super().forward(x)
+        y = y * self.ln_scale
+        return y
+
+
+class DyT(nn.Module):
+    def __init__(self, C: int, init_alpha: float):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.full((1,), init_alpha))
+        self.gamma = nn.Parameter(torch.ones(C))
+        self.beta = nn.Parameter(torch.zeros(C))
+
+    def forward(self, x: torch.Tensor):
+        x = F.tanh(self.alpha * x)
+        return self.gamma * x + self.beta
+
+
+@register_model
+def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+    def _replace_ln_with_dyt(ln: nn.LayerNorm, depth: int):
+        return DyT(ln.normalized_shape[0], init_alpha=0.9)
+    _replace_ln(model, _replace_ln_with_dyt)
+
+    return model
+
+
+def _apply_scaled_ln(model: VisionTransformer):
+    warnings.warn('Post-LayerNorm scaling activated!')
+
+    _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth))
+
+
+def _replace_ln(model: VisionTransformer, fn):
+    def _inner_replace_ln(block: Block, depth: int, key: str):
+        prev = getattr(block, key)
+        if isinstance(prev, nn.LayerNorm):
+            setattr(block, key, fn(prev, depth=depth))
+
+    for i, block in enumerate(model.blocks):
+        _inner_replace_ln(block, i + 1, 'norm1')
+        _inner_replace_ln(block, i + 1, 'norm2')
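
A usage sketch for the new norm-replacement helpers. The import path is hypothetical; inside the hub repo these live in extra_timm_models.py and are imported as part of the package.

import timm
import torch

from extra_timm_models import DyT, _replace_ln  # hypothetical top-level import

model = timm.create_model('vit_tiny_patch16_224', num_classes=0)

# Swap every block LayerNorm (norm1/norm2) for the tanh-based DyT normalizer.
_replace_ln(model, lambda ln, depth: DyT(ln.normalized_shape[0], init_alpha=0.9))

with torch.no_grad():
    tokens = model.forward_features(torch.randn(1, 3, 224, 224))
print(tokens.shape)  # (1, 197, 192) for ViT-Ti/16
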
forward_intermediates.py CHANGED
@@ -6,7 +6,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-from typing import Callable, List, Optional, Set, Tuple, Union, Any, Iterable
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable
 from types import MethodType
 
 import torch
@@ -42,6 +42,7 @@ def forward_intermediates(
         aggregation: Optional[str] = "sparse",
         inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None,
         norm_alpha_scheme = "post-alpha",
+        block_kwargs: Dict = None,
 ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
     """ Forward features that returns intermediates.
 
@@ -65,6 +66,8 @@ def forward_intermediates(
     reshape = output_fmt == 'NCHW'
     intermediates = []
 
+    block_kwargs = block_kwargs or dict()
+
     blocks = model.blocks
 
     take_indices, max_index = _take_indices(len(blocks), indices)
@@ -90,7 +93,7 @@ def forward_intermediates(
     take_off = 0
 
     for i, blk in enumerate(blocks):
-        x = blk(x)
+        x = blk(x, **block_kwargs)
         if aggregation == "dense":
             # Arbitrarily use the rotation matrix from the final layer in the dense group
             y, alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens)
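
A toy illustration of the new block_kwargs plumbing, independent of the repo code (names are made up): any extra keyword arguments are forwarded verbatim to every block.

import torch
from torch import nn

class ToyBlock(nn.Module):
    def forward(self, x, scale: float = 1.0):
        return x * scale

blocks = nn.ModuleList([ToyBlock(), ToyBlock()])
block_kwargs = dict(scale=0.5)
block_kwargs = block_kwargs or dict()

x = torch.ones(1, 4)
for blk in blocks:
    x = blk(x, **block_kwargs)
print(x)  # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])
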
radio_model.py CHANGED
@@ -18,6 +18,7 @@ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
 from . import eradio_model
 from .enable_spectral_reparam import configure_spectral_reparam_from_args
 from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
+from . import dual_hybrid_vit
 
 
 class Resolution(NamedTuple):
@@ -69,7 +70,7 @@ class RADIOModel(nn.Module):
         patch_gen = getattr(self.model, "patch_generator", None)
         if patch_gen is not None:
             return patch_gen.num_skip
-        elif self.model.global_pool == 'avg':
+        elif getattr(self.model, 'global_pool', None) == 'avg':
             return 0
         return 1
 
@@ -81,7 +82,7 @@ class RADIOModel(nn.Module):
         patch_gen = getattr(self.model, 'patch_generator', None)
         if patch_gen is not None:
             return patch_gen.num_cls_tokens
-        elif self.model.global_pool == 'avg':
+        elif getattr(self.model, 'global_pool', None) == 'avg':
             return 0
         return 1
 
@@ -218,7 +219,10 @@ class RADIOModel(nn.Module):
             ret = dict(backbone=ret)
             for name, adaptor in self.adaptors.items():
                 if all_summary.ndim == 3:
-                    summary = all_summary[:, adaptor.head_idx]
+                    if all_summary.shape[1] == 1:
+                        summary = all_summary[:, 0]
+                    else:
+                        summary = all_summary[:, adaptor.head_idx]
                 else:
                     summary = all_summary
                 ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
@@ -326,10 +330,6 @@ def create_model_from_args(args) -> nn.Module:
 
     model.head = nn.Identity()
 
-    assert (
-        not args.cls_token_per_teacher or args.cpe_max_size is not None
-    ), "CPE must be enabled for multiple CLS tokens!"
-
     if args.cpe_max_size is not None:
         uq_teachers = set(t['name'] for t in args.teachers)
         enable_cpe(
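
A toy illustration of the revised summary selection (made-up shapes): when the backbone produces a single summary token, every adaptor now receives that token instead of indexing its own head.

import torch

all_summary = torch.randn(2, 1, 8)         # (batch, num_summary_tokens, dim)
head_idx = 3                               # some adaptor's head index

if all_summary.ndim == 3:
    if all_summary.shape[1] == 1:
        summary = all_summary[:, 0]        # shared single summary token
    else:
        summary = all_summary[:, head_idx] # per-adaptor summary token
else:
    summary = all_summary
print(summary.shape)                       # torch.Size([2, 8])
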
vit_patch_generator.py CHANGED
@@ -106,6 +106,10 @@ class ViTPatchGenerator(nn.Module):
     def num_cls_tokens(self):
         return self.cls_token.num_tokens
 
+    @property
+    def num_cls_patches(self):
+        return self.cls_token.num_patches
+
     @property
     def num_registers(self):
         return self.cls_token.num_registers
@@ -119,10 +123,6 @@ class ViTPatchGenerator(nn.Module):
         'pos_embed',
     ]
 
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        if self.abs_pos:
-            self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
-
     def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
         if src_embed.shape != targ_embed.shape:
             src_size = int(math.sqrt(src_embed.shape[1]))
@@ -285,18 +285,3 @@ class ViTPatchLinear(nn.Linear):
             **factory
         )
         self.patch_size = patch_size
-
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-        if self.bias is not None:
-            self.bias.data.copy_(state_dict[f'{prefix}bias'])
-
-        chk_weight = state_dict[f'{prefix}weight']
-        if chk_weight.shape != self.weight.shape:
-            src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
-
-            assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
-
-            chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
-            chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
-            chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
-            self.weight.data.copy_(chk_weight)