Upload model #4
opened by gheinrich
- adaptor_generic.py +19 -4
- adaptor_mlp.py +35 -17
- common.py +9 -0
- config.json +1 -1
- dual_hybrid_vit.py +213 -0
- enable_cpe_support.py +4 -1
- enable_spectral_reparam.py +1 -1
- extra_timm_models.py +100 -8
- forward_intermediates.py +5 -2
- radio_model.py +7 -7
- vit_patch_generator.py +4 -19
adaptor_generic.py
CHANGED
@@ -19,9 +19,23 @@ class GenericAdaptor(AdaptorBase):
     def __init__(self, main_config: Namespace, adaptor_config, state, mlp_config=None):
         super().__init__()
 
+        extra_args = dict()
+        ups = None
+        ups_rank = None
+        if adaptor_config is not None:
+            ups = adaptor_config.get('fd_upsample_factor', None)
+            ups_rank = adaptor_config.get('fd_upsample_rank', None)
+        elif mlp_config is not None:
+            ups = mlp_config["feature"].get('upsample_factor', None)
+            ups_rank = mlp_config["feature"].get('upsample_rank', None)
+        if ups is not None:
+            extra_args['upsample_factor'] = ups
+            extra_args['upsample_rank'] = ups_rank
+
         if state is not None:
-            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.')
-            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.')
+            spectral_heads = getattr(main_config, 'spectral_heads', False)
+            self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.', spectral_weights=spectral_heads)
+            self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.', spectral_weights=spectral_heads, **extra_args)
         else:
             assert mlp_config is not None, "Config must not be None if state is None"
 
@@ -38,16 +52,17 @@ class GenericAdaptor(AdaptorBase):
                 mlp_config["feature"]["hidden_dim"],
                 mlp_config["feature"]["output_dim"],
                 mlp_config["feature"]["num_inner"],
+                **extra_args
             )
 
     def forward(self, input: AdaptorInput) -> RadioOutput:
         # Convert input'd type to the type of the first parameter of the adaptor.
         first_param = next(self.parameters())
         summary = self.head_mlp(input.summary.to(dtype=first_param.dtype)).to(dtype=input.summary.dtype)
-        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype)).to(dtype=input.features.dtype)
+        feat = self.feat_mlp(input.features.to(dtype=first_param.dtype), images=input.images, patch_size=input.patch_size).to(dtype=input.features.dtype)
 
         if input.feature_fmt == 'NCHW':
-            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size, input.images.shape[-1] // input.patch_size, feat.shape[2])
+            feat = (feat.reshape(feat.shape[0], input.images.shape[-2] // input.patch_size * self.feat_mlp.upsample_factor, input.images.shape[-1] // input.patch_size * self.feat_mlp.upsample_factor, feat.shape[2])
                     .permute(0, 3, 1, 2)
             )
 
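For context on the new NCHW reshape: an upsampling feature head returns upsample_factor**2 times more tokens, so the spatial grid grows by upsample_factor on each side. A rough sketch of the shape bookkeeping, with hypothetical sizes not taken from this repo:

    # Hypothetical sizes: 512x512 input, patch 16, upsample_factor 2.
    H = W = 512
    patch_size = 16
    upsample_factor = 2
    rows = H // patch_size * upsample_factor   # 64
    cols = W // patch_size * upsample_factor   # 64
    # feat arrives as (B, rows*cols, C) and is reshaped to NCHW:
    # feat.reshape(B, rows, cols, C).permute(0, 3, 1, 2) -> (B, C, 64, 64)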
adaptor_mlp.py
CHANGED
@@ -6,7 +6,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 import math
-from typing import Dict
+from typing import Dict, Optional
 
 import torch
 from torch import nn
@@ -14,6 +14,8 @@ from torch import nn
 from einops import rearrange
 from timm.models.vision_transformer import Block
 
+from .enable_spectral_reparam import disable_spectral_reparam, enable_spectral_reparam
+
 
 class MLP(nn.Module):
     def __init__(self, input_size: int, hidden_size: int, output_size: int,
@@ -51,6 +53,8 @@ class MLP2(nn.Module):
                  num_inner: int = 0,
                  pre_norm: bool = False, device: torch.device = None,
                  upsample_factor: int = 1,
+                 upsample_rank: int = None,
+                 from_config: bool = False,
                  **kwargs):
         super().__init__()
 
@@ -60,10 +64,12 @@ class MLP2(nn.Module):
         ) if pre_norm else nn.Identity()
 
         self.upsample_factor = upsample_factor
-        self._real_output_dim = output_size
+        sq_ups = upsample_factor ** 2
+
+        self._real_output_dim = output_size // sq_ups
 
-        hidden_size *= upsample_factor
-        output_size *= (upsample_factor ** 2)
+        # hidden_size *= upsample_factor
+        # output_size *= (upsample_factor ** 2)
 
         self.fc1 = nn.Linear(input_size, hidden_size, device=device)
 
@@ -82,7 +88,7 @@ class MLP2(nn.Module):
             nn.Linear(hidden_size, output_size, device=device),
         )
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, images: Optional[torch.Tensor] = None, patch_size: Optional[int] = None) -> torch.Tensor:
         x = self.pre_norm(x)
         x = self.fc1(x)
         for block in self.blocks:
@@ -90,8 +96,12 @@ class MLP2(nn.Module):
         x = self.final(x)
 
         if self.upsample_factor > 1:
-
-
+            if images is None:
+                raise ValueError(f'`images` cannot be `None` when the head\'s `upsample_factor > 1`!')
+            if patch_size is None:
+                raise ValueError(f'`patch_size` cannot be `None` when the head\'s `upsample_factor > 1`!')
+            h, w = tuple(d // patch_size for d in images.shape[-2:])
+            x = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                           h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
                           c=self._real_output_dim)
 
@@ -113,20 +123,22 @@ def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
     return state
 
 
-def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False):
     state = strip_prefix(state, prefix)
 
+    weight_suffix = 'weight' if not spectral_weights else 'parametrizations.weight.original'
+
     if version == 'v1':
-        hidden_dim, input_dim = state['fc1.weight'].shape
-        output_dim = state['fc2.weight'].shape[0]
+        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
+        output_dim = state[f'fc2.{weight_suffix}'].shape[0]
 
         for num_inner in range(1000):
            k = f'inner.{num_inner}.0.weight'
            if k not in state:
                break
     elif version == 'v2':
-        hidden_dim, input_dim = state['fc1.weight'].shape
-        output_dim = state['final.2.weight'].shape[0]
+        hidden_dim, input_dim = state[f'fc1.{weight_suffix}'].shape
+        output_dim = state[f'final.2.{weight_suffix}'].shape[0]
 
         for num_inner in range(1000):
            k = f'blocks.{num_inner}.0.weight'
@@ -138,19 +150,25 @@ def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix
     return input_dim, hidden_dim, output_dim, num_inner
 
 
-def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int):
-    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner)
+def create_mlp_from_config(version: str, input_dim: int, hidden_dim: int, output_dim: int, num_inner: int, **kwargs):
+    ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner, from_config=True, **kwargs)
 
     return ret
 
 
-def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = '', spectral_weights: bool = False, **kwargs):
     state = strip_prefix(state, prefix)
 
-    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state)
-    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner)
+    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state, spectral_weights=spectral_weights)
+
+    ret: nn.Module = create_mlp_from_config(version, input_dim, hidden_dim, output_dim, num_inner, **kwargs)
 
+    if spectral_weights:
+        enable_spectral_reparam(ret, init_norm_to_current=False, state_dict_guidance=state)
+
     ret.load_state_dict(state)
 
+    if spectral_weights:
+        disable_spectral_reparam(ret)
+
     return ret
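A quick, self-contained check of the new rearrange in MLP2.forward (sizes chosen arbitrarily): each token's channels are split into u1*u2 sub-tokens, so a 32x32 token grid with upsample_factor=2 becomes a 64x64 grid with c channels per token.

    import torch
    from einops import rearrange

    b, h, w, u, c = 2, 32, 32, 2, 768          # arbitrary example sizes
    x = torch.randn(b, h * w, u * u * c)        # MLP output before upsampling
    y = rearrange(x, 'b (h w) (u1 u2 c) -> b (h u1 w u2) c',
                  h=h, w=w, u1=u, u2=u, c=c)
    print(y.shape)                              # torch.Size([2, 4096, 768]) == (b, (h*u) * (w*u), c)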
common.py
CHANGED
@@ -94,6 +94,15 @@ RESOURCE_MAP = {
         max_resolution=2048,
         preferred_resolution=Resolution(512, 512),
     ),
+    # C-RADIO
+    "c-radio_v3-l": RadioResource(
+        # NOTE: Currently, this model cannot be loaded via TorchHub. Instead, use the transformers API at https://huggingface.co/nvidia/C-RADIOv3-L
+        # and accept the license terms.
+        "https://huggingface.co/nvidia/C-RADIOv3-L/resolve/main/c-radio-v3_l_half.pth.tar?download=true",
+        patch_size=16,
+        max_resolution=2048,
+        preferred_resolution=Resolution(512, 512),
+    ),
 }
 
 DEFAULT_VERSION = "radio_v2.5-h"
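Per the note in the new entry, c-radio_v3-l is not meant for TorchHub; the transformers path is roughly the following. This is a sketch assuming the usual trust_remote_code flow of the RADIO Hugging Face repos; check the model card at https://huggingface.co/nvidia/C-RADIOv3-L for the exact snippet.

    import torch
    from transformers import AutoModel

    # Requires accepting the license on https://huggingface.co/nvidia/C-RADIOv3-L first.
    model = AutoModel.from_pretrained('nvidia/C-RADIOv3-L', trust_remote_code=True)
    model.eval()

    x = torch.rand(1, 3, 512, 512)        # values in [0, 1]; preferred resolution is 512x512
    with torch.no_grad():
        summary, features = model(x)      # assumed output layout: (summary, spatial features)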
config.json
CHANGED
@@ -224,7 +224,7 @@
     768
   ],
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.51.2",
   "version": "c-radio_v2.5-g",
   "vitdet_window_size": null
 }
dual_hybrid_vit.py
ADDED
@@ -0,0 +1,213 @@
from logging import getLogger
from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F

from timm.models import register_model
from timm.models import vision_transformer as tvit
from timm.models import convnext as tconv

from einops import rearrange

from . import extra_timm_models as et


class Fuser(nn.Module):
    def __init__(self, src_dim: int, tgt_dim: int, gated: bool = True):
        super().__init__()
        self.gated = gated

        mid_dim = max(src_dim, tgt_dim) * 2

        self.fwd = nn.Sequential(
            nn.Conv2d(src_dim, mid_dim, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Conv2d(mid_dim, tgt_dim * (2 if gated else 1), kernel_size=3, stride=1, padding=1),
        )

    def forward(self, src: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        if src.ndim == 3:
            shape = tgt.shape[-2:]
        else:
            shape = src.shape[-2:]

        nd = shape[0] * shape[1]

        if src.ndim == 3:
            src = src[:, -nd:].reshape(src.shape[0], src.shape[2], *shape)

        if tgt.ndim == 3:
            tgt_pre = tgt[:, :-nd]
            tgt = tgt[:, -nd:].reshape(tgt.shape[0], tgt.shape[2], *shape)
        else:
            tgt_pre = None

        pred = self.fwd(src)

        if self.gated:
            g, pred = torch.chunk(pred, 2, dim=1)

            g = F.sigmoid(g)

            pred = g * pred

        tgt = tgt + pred

        if tgt_pre is not None:
            tgt = rearrange(tgt, 'b c h w -> b (h w) c')
            tgt = torch.cat([tgt_pre, tgt], dim=1)

        return tgt


class AttnDownsample(nn.Module):
    def __init__(self, dim: int, window_size: int, num_heads: int = 16):
        super().__init__()
        self.q = nn.Parameter(torch.randn(1, num_heads, 1, dim // num_heads) * 0.01)
        self.kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, twod_shape: Tuple[int, int]) -> torch.Tensor:
        ntok = twod_shape[0] * twod_shape[1]
        x_pre = x[:, :-ntok]

        B = x.shape[0]
        ds_hw = tuple(s // self.window_size for s in twod_shape)

        x_spat = rearrange(
            x[:, -ntok:],
            'b (h d1 w d2) c -> (b h w) (d1 d2) c',
            h=ds_hw[0], w=ds_hw[1],
            d1=self.window_size, d2=self.window_size,
        )

        B, N, C = x_spat.shape

        k, v = self.kv(x_spat).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)

        q = (self.q * self.scale).expand(B, -1, -1, -1)
        attn = q @ k.transpose(-2, -1)
        attn = F.softmax(attn, dim=-1)
        x = attn @ v

        x = x.transpose(1, 2).reshape(B, C)
        x = self.proj(x)

        x = rearrange(x, '(b h w) c -> b (h w) c', b=x_pre.shape[0], h=ds_hw[0], w=ds_hw[1])

        x = torch.cat([x_pre, x], dim=1)
        return x


class HybridModel(nn.Module):
    def __init__(self, vit: tvit.VisionTransformer, conv: tconv.ConvNeXt, pretrained: bool = False,
                 concatenate: bool = False, **kwargs):
        super().__init__()
        self.conv = conv
        self.vit = vit
        self.concatenate = concatenate

        conv.stages = nn.ModuleList(conv.stages)
        vit.blocks = nn.ModuleList(vit.blocks)

        self._half_vit_idx = len(vit.blocks) // 2 + 1

        self._half_conv_idx = None
        x = torch.empty(1, 3, 256, 256)
        x = self.conv.stem(x)
        for i in range(len(conv.stages)):
            x = conv.stages[i](x)
            if self._half_conv_idx is None and x.shape[-2:] == (16, 16):
                self._half_conv_idx = i + 1
                half_conv_dim = x.shape[1]
        final_conv_dim = x.shape[1]

        self.vit_to_conv_fusion = Fuser(vit.embed_dim, half_conv_dim)
        self.conv_to_vit_fusion = Fuser(half_conv_dim, vit.embed_dim)
        self.vit_ds = AttnDownsample(vit.embed_dim, window_size=2)

        embed_dim = vit.embed_dim + (final_conv_dim if concatenate else 0)
        if not concatenate:
            self.final_fuse = Fuser(final_conv_dim, vit.embed_dim, gated=False)
        self.final_block = tvit.Block(embed_dim, num_heads=16)

        self.embed_dim = embed_dim

    @property
    def patch_size(self):
        return 32

    @property
    def no_fsdp_wrap_types(self):
        return {tvit.VisionTransformer, tconv.ConvNeXt}

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_features(x)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        y_vit = self.vit.patch_generator(x)

        for i in range(self._half_vit_idx):
            y_vit = self.vit.blocks[i](y_vit)

        y_conv = self.conv.stem(x)
        for i in range(self._half_conv_idx):
            y_conv = self.conv.stages[i](y_conv)

        y_vit, y_conv = self.conv_to_vit_fusion(y_conv, y_vit), self.vit_to_conv_fusion(y_vit, y_conv)

        y_vit = self.vit_ds(y_vit, y_conv.shape[-2:])

        for i in range(self._half_vit_idx, len(self.vit.blocks)):
            y_vit = self.vit.blocks[i](y_vit)

        for i in range(self._half_conv_idx, len(self.conv.stages)):
            y_conv = self.conv.stages[i](y_conv)

        if self.concatenate:
            y_conv = rearrange(y_conv, 'b c h w -> b (h w) c')
            # Average pool across the board, and replicate for each cls/register token
            conv_summary = y_conv.mean(dim=1, keepdim=True).expand(-1, self.vit.patch_generator.num_cls_patches, -1)
            y_conv = torch.cat([conv_summary, y_conv], dim=1)
            y = torch.cat([y_vit, y_conv], dim=2)
        else:
            y = self.final_fuse(y_conv, y_vit)
        y = self.final_block(y)

        summary = y[:, :self.vit.patch_generator.num_cls_tokens]
        features = y[:, self.vit.patch_generator.num_cls_patches:]

        return summary, features


@register_model
def hybrid_base(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    cfg = dict(num_classes=0, **kwargs)
    conv = tconv.convnextv2_base(pretrained=pretrained, **cfg)
    vit = tvit.vit_base_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)

    return HybridModel(vit, conv, pretrained, concatenate=concatenate)


@register_model
def hybrid_large(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    cfg = dict(num_classes=0, **kwargs)
    conv = tconv.convnextv2_large(pretrained=pretrained, **cfg)
    vit = tvit.vit_large_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)

    return HybridModel(vit, conv, pretrained, concatenate=concatenate)


@register_model
def hybrid_huge(pretrained=False, concatenate: bool = False, weight_init: str = 'skip', **kwargs):
    cfg = dict(num_classes=0, **kwargs)
    conv = tconv.convnextv2_huge(pretrained=pretrained, **cfg)
    vit = et.vit_huge_patch16_224(pretrained=pretrained, weight_init=weight_init, **cfg)

    return HybridModel(vit, conv, pretrained, concatenate=concatenate)
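The two Fuser directions exchange information mid-network: conv features gate-modulate the ViT patch tokens and vice versa, while cls/register tokens pass through untouched. A small shape-level sketch of one direction, assuming Fuser from this file is importable and with made-up dimensions:

    import torch

    fuser = Fuser(src_dim=512, tgt_dim=768)      # e.g. conv features -> ViT tokens
    conv_feat = torch.randn(2, 512, 16, 16)      # NCHW ConvNeXt activation
    vit_tokens = torch.randn(2, 4 + 256, 768)    # 4 cls/register tokens + 16*16 patch tokens

    fused = fuser(conv_feat, vit_tokens)         # gated residual update of the patch tokens only
    print(fused.shape)                           # torch.Size([2, 260, 768])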
enable_cpe_support.py
CHANGED
@@ -19,6 +19,7 @@ from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermedi
 from .extra_models import DinoWrapper
 from .vit_patch_generator import ViTPatchGenerator
 from .forward_intermediates import forward_intermediates
+from .dual_hybrid_vit import HybridModel
 
 
 def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
@@ -161,7 +162,9 @@ def enable_cpe(model: nn.Module,
 ):
     if isinstance(model, VisionTransformer):
         _enable_cpe_for_timm_vit(model, *args, **kwargs)
-    elif
+    elif isinstance(model, DinoWrapper):
         _enable_cpe_for_dv2_reg_vit(model, *args, **kwargs)
+    elif isinstance(model, HybridModel):
+        _enable_cpe_for_timm_vit(model.vit, *args, **kwargs)
     else:
         raise ValueError(f'CPE not supported for this model type: {type(model)}')
enable_spectral_reparam.py
CHANGED
@@ -155,7 +155,7 @@ def enable_spectral_reparam(model: Union[nn.Module, List[nn.Module]],
             return True
 
         p_name = f'{name}.parametrizations'
-        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name))
+        is_prm = any(k for k in state_dict_guidance if k.startswith(p_name) and k.endswith('_sn_version'))
         return is_prm
 
     def parametrize_linear(linear: nn.Linear):
extra_timm_models.py
CHANGED
@@ -6,7 +6,12 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+import math
+import warnings
+
+import torch
 from torch import nn
+from torch.nn import functional as F
 
 from timm.models import register_model
 from timm.models.vision_transformer import (
@@ -17,6 +22,7 @@ from timm.models.vision_transformer import (
     LayerScale as TIMMLayerScale,
 )
 
+# Import these to also register them
 from . import dinov2_arch
 
 
@@ -24,7 +30,7 @@ from . import dinov2_arch
 def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Tiny (Vit-Ti/16)
     """
-    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3
+    model_args = dict(patch_size=14, embed_dim=192, depth=12, num_heads=3)
     model = _create_vision_transformer('vit_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
@@ -33,7 +39,7 @@ def vit_tiny_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_small_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
     """
-    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6)
     model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
@@ -43,16 +49,44 @@ def vit_base_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Base (ViT-B/14) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12)
     model = _create_vision_transformer('vit_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
+@register_model
+def vit_base_patch16_v2_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_v2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    name = 'vit_large_patch14_reg4_dinov2'
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True, img_size=518 * 16 // 14
+    )
+    model = _create_vision_transformer(name, pretrained=pretrained, **dict(model_args, **kwargs))
+
+    return model
+
 @register_model
 def vit_huge_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16
+    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16)
     if pretrained:
         # There is no pretrained version of ViT-H/16, but we can adapt a ViT-H/14 for this purpose
         model = _create_vision_transformer('vit_huge_patch14_224', pretrained=True, **dict(model_args, **kwargs))
@@ -65,7 +99,7 @@ def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransforme
 def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-Huge model (ViT-H/16) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model = vit_huge_patch16_224(pretrained=pretrained,
+    model = vit_huge_patch16_224(pretrained=pretrained, **kwargs)
 
     for m in model.modules():
         if isinstance(m, Mlp) and not isinstance(m.norm, nn.LayerNorm):
@@ -75,17 +109,19 @@ def vit_huge_patch16_224_mlpnorm(pretrained=False, **kwargs) -> VisionTransforme
 
 
 @register_model
-def vit_giant_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+def vit_giant_patch16_224(pretrained=False, scaled_ln: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-giant model (ViT-g/16) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24
+    model_args = dict(patch_size=16, embed_dim=1536, depth=40, num_heads=24)
     model = _create_vision_transformer('vit_giant_patch16_224', pretrained=False, **dict(model_args, **kwargs))
+    if scaled_ln:
+        _apply_scaled_ln(model)
     return model
 
 
 @register_model
 def vit_bigG_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
-    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6
+    model_args = dict(patch_size=14, embed_dim=1664, depth=48, num_heads=16, init_values=1e-6)
     model = _create_vision_transformer('vit_bigG_patch14', pretrained=False, **dict(model_args, **kwargs))
     return model
 
@@ -112,3 +148,59 @@ def _patch_layer_scale(model: VisionTransformer):
         mod.ls2 = replace_ls(mod.ls2)
     pass
 
+
+class ScaledLayerNorm(nn.LayerNorm):
+    '''
+    https://arxiv.org/pdf/2502.05795v1
+    '''
+    def __init__(self, ln_base: nn.LayerNorm, depth: int = 0):
+        super().__init__(ln_base.normalized_shape, eps=ln_base.eps, elementwise_affine=ln_base.elementwise_affine)
+        self.load_state_dict(ln_base.state_dict())
+        self.register_buffer('ln_scale', torch.tensor(1.0 / math.sqrt(depth)), persistent=False)
+
+    def forward(self, x):
+        y = super().forward(x)
+        y = y * self.ln_scale
+        return y
+
+
+class DyT(nn.Module):
+    def __init__(self, C: int, init_alpha: float):
+        super().__init__()
+        self.alpha = nn.Parameter(torch.full((1,), init_alpha))
+        self.gamma = nn.Parameter(torch.ones(C))
+        self.beta = nn.Parameter(torch.zeros(C))
+
+    def forward(self, x: torch.Tensor):
+        x = F.tanh(self.alpha * x)
+        return self.gamma * x + self.beta
+
+@register_model
+def vit_large_dyt_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+    ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+    """
+    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_dyt_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+
+    def _replace_ln_with_dyt(ln: nn.LayerNorm, depth: int):
+        return DyT(ln.normalized_shape[0], init_alpha=0.9)
+    _replace_ln(model, _replace_ln_with_dyt)
+
+    return model
+
+
+def _apply_scaled_ln(model: VisionTransformer):
+    warnings.warn('Post-LayerNorm scaling activated!')
+
+    _replace_ln(model, lambda ln, depth: ScaledLayerNorm(ln, depth=depth))
+
+def _replace_ln(model: VisionTransformer, fn):
+    def _inner_replace_ln(block: Block, depth: int, key: str):
+        prev = getattr(block, key)
+        if isinstance(prev, nn.LayerNorm):
+            setattr(block, key, fn(prev, depth=depth))
+
+    for i, block in enumerate(model.blocks):
+        _inner_replace_ln(block, i + 1, 'norm1')
+        _inner_replace_ln(block, i + 1, 'norm2')
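ScaledLayerNorm simply rescales a standard LayerNorm's output by 1/sqrt(depth), per the linked paper. A minimal sanity check, assuming the class above is importable:

    import torch
    from torch import nn

    base = nn.LayerNorm(8)
    sln = ScaledLayerNorm(base, depth=4)         # copies base's weights, scales output by 1/sqrt(4) = 0.5

    x = torch.randn(2, 3, 8)
    assert torch.allclose(sln(x), base(x) * 0.5)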
forward_intermediates.py
CHANGED
@@ -6,7 +6,7 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
-from typing import Callable, List, Optional, Set, Tuple, Union, Any, Iterable
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union, Any, Iterable
 from types import MethodType
 
 import torch
@@ -42,6 +42,7 @@ def forward_intermediates(
         aggregation: Optional[str] = "sparse",
         inter_feature_normalizer: Optional[IntermediateFeatureNormalizerBase] = None,
         norm_alpha_scheme = "post-alpha",
+        block_kwargs: Dict = None,
 ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
     """ Forward features that returns intermediates.
 
@@ -65,6 +66,8 @@ def forward_intermediates(
     reshape = output_fmt == 'NCHW'
     intermediates = []
 
+    block_kwargs = block_kwargs or dict()
+
     blocks = model.blocks
 
     take_indices, max_index = _take_indices(len(blocks), indices)
@@ -90,7 +93,7 @@ def forward_intermediates(
     take_off = 0
 
     for i, blk in enumerate(blocks):
-        x = blk(x)
+        x = blk(x, **block_kwargs)
         if aggregation == "dense":
             # Arbitrarily use the rotation matrix from the final layer in the dense group
             y, alpha = inter_feature_normalizer(x, i, rot_index=take_indices[take_off], skip=num_summary_tokens)
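The new block_kwargs argument is forwarded verbatim to every block call. The pattern is equivalent to this toy loop (illustrative only, not this repo's modules):

    import torch
    from torch import nn

    class ToyBlock(nn.Module):
        def forward(self, x, scale: float = 1.0):
            return x * scale

    blocks = nn.ModuleList([ToyBlock(), ToyBlock()])
    block_kwargs = dict(scale=0.5)               # an extra argument every block understands

    x = torch.ones(1, 4)
    for blk in blocks:
        x = blk(x, **(block_kwargs or dict()))   # same forwarding as in forward_intermediates
    print(x)                                     # tensor([[0.2500, 0.2500, 0.2500, 0.2500]])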
radio_model.py
CHANGED
@@ -18,6 +18,7 @@ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
 from . import eradio_model
 from .enable_spectral_reparam import configure_spectral_reparam_from_args
 from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
+from . import dual_hybrid_vit
 
 
 class Resolution(NamedTuple):
@@ -69,7 +70,7 @@ class RADIOModel(nn.Module):
        patch_gen = getattr(self.model, "patch_generator", None)
        if patch_gen is not None:
            return patch_gen.num_skip
-       elif self.model.global_pool == 'avg':
+       elif getattr(self.model, 'global_pool', None) == 'avg':
            return 0
        return 1
 
@@ -81,7 +82,7 @@ class RADIOModel(nn.Module):
        patch_gen = getattr(self.model, 'patch_generator', None)
        if patch_gen is not None:
            return patch_gen.num_cls_tokens
-       elif self.model.global_pool == 'avg':
+       elif getattr(self.model, 'global_pool', None) == 'avg':
            return 0
        return 1
 
@@ -218,7 +219,10 @@ class RADIOModel(nn.Module):
            ret = dict(backbone=ret)
            for name, adaptor in self.adaptors.items():
                if all_summary.ndim == 3:
-                   summary = all_summary[:, adaptor.head_idx]
+                   if all_summary.shape[1] == 1:
+                       summary = all_summary[:, 0]
+                   else:
+                       summary = all_summary[:, adaptor.head_idx]
                else:
                    summary = all_summary
                ada_input = AdaptorInput(images=x, summary=summary.float(), features=all_feat, feature_fmt=feature_fmt, patch_size=self.patch_size)
@@ -326,10 +330,6 @@ def create_model_from_args(args) -> nn.Module:
 
    model.head = nn.Identity()
 
-   assert (
-       not args.cls_token_per_teacher or args.cpe_max_size is not None
-   ), "CPE must be enabled for multiple CLS tokens!"
-
    if args.cpe_max_size is not None:
        uq_teachers = set(t['name'] for t in args.teachers)
        enable_cpe(
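The new branch covers checkpoints that carry a single shared summary token as well as the older per-adaptor layout. In tensor terms, with hypothetical shapes and a hypothetical head_idx:

    import torch

    head_idx = 1                                  # hypothetical adaptor.head_idx
    per_adaptor = torch.randn(2, 3, 1024)         # one summary token per teacher head
    shared = torch.randn(2, 1, 1024)              # single shared summary token

    pick = lambda s: s[:, 0] if s.shape[1] == 1 else s[:, head_idx]
    print(pick(per_adaptor).shape, pick(shared).shape)   # torch.Size([2, 1024]) torch.Size([2, 1024])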
vit_patch_generator.py
CHANGED
@@ -106,6 +106,10 @@ class ViTPatchGenerator(nn.Module):
    def num_cls_tokens(self):
        return self.cls_token.num_tokens
 
+   @property
+   def num_cls_patches(self):
+       return self.cls_token.num_patches
+
    @property
    def num_registers(self):
        return self.cls_token.num_registers
@@ -119,10 +123,6 @@ class ViTPatchGenerator(nn.Module):
        'pos_embed',
    ]
 
-   def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-       if self.abs_pos:
-           self._load_embed(state_dict[f'{prefix}pos_embed'], self.pos_embed)
-
    def _load_embed(self, src_embed: torch.Tensor, targ_embed: nn.Parameter):
        if src_embed.shape != targ_embed.shape:
            src_size = int(math.sqrt(src_embed.shape[1]))
@@ -285,18 +285,3 @@ class ViTPatchLinear(nn.Linear):
            **factory
        )
        self.patch_size = patch_size
-
-   def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
-       if self.bias is not None:
-           self.bias.data.copy_(state_dict[f'{prefix}bias'])
-
-       chk_weight = state_dict[f'{prefix}weight']
-       if chk_weight.shape != self.weight.shape:
-           src_patch_size = int(math.sqrt(chk_weight.shape[1] // 3))
-
-           assert (src_patch_size ** 2) * 3 == chk_weight.shape[1], 'Unable to interpolate non-square patch size'
-
-           chk_weight = rearrange(chk_weight, 'b (c h w) -> b c h w', c=3, h=src_patch_size, w=src_patch_size)
-           chk_weight = F.interpolate(chk_weight, size=(self.patch_size, self.patch_size), mode='bicubic', align_corners=True, antialias=False)
-           chk_weight = rearrange(chk_weight, 'b c h w -> b (c h w)')
-           self.weight.data.copy_(chk_weight)