Commit d42522e
Parent(s): 95d1f02

remove stack_hidden

Files changed:
- config.json +0 -1
- configuration_muddpythia.py +0 -2
- modeling_muddpythia.py +4 -8
config.json (CHANGED)

```diff
@@ -26,7 +26,6 @@
   "rotary_pct": 0.25,
   "round64": true,
   "sepln": true,
-  "stack_hidden": false,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
   "transformers_version": "4.35.0",
```
configuration_muddpythia.py (CHANGED)

```diff
@@ -31,7 +31,6 @@ class MUDDPythiaConfig(PretrainedConfig):
         eos_token_id: int =2,
         tie_word_embeddings: bool =False,
         use_layer_cache: bool = True,
-        stack_hidden: bool = False,
         dense: bool = True,
         dynamic_dense: bool = True,
         sepln: bool = True,
@@ -58,7 +57,6 @@ class MUDDPythiaConfig(PretrainedConfig):
         self.rotary_pct = rotary_pct
 
         self.use_layer_cache= use_layer_cache
-        self.stack_hidden= stack_hidden
         self.dense= dense
         self.dynamic_dense= dynamic_dense
         self.sepln= sepln
```
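With `stack_hidden` gone from the constructor signature, the config is built from the remaining flags. Below is a minimal usage sketch, not taken from the repo: it assumes `configuration_muddpythia.py` is importable from the working directory, and the keyword values are simply the defaults visible in the diff.

```python
from configuration_muddpythia import MUDDPythiaConfig

# Build the config from the flags that remain after this commit; the values
# below are the defaults shown in the diff, passed explicitly for clarity.
config = MUDDPythiaConfig(
    use_layer_cache=True,
    dense=True,
    dynamic_dense=True,
    sepln=True,
)

# The modeling code no longer reads a stack_hidden attribute, so the old flag
# has no effect even if an older config.json still carries it.
print(config.dense, config.dynamic_dense, config.sepln)
```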
modeling_muddpythia.py (CHANGED)

```diff
@@ -85,7 +85,6 @@ class MUDDPythia(PreTrainedModel):
 
         self.layer_cache = None
         self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
-        self.stack_hidden = self.config.stack_hidden
         self.dynamic = self.config.dynamic_dense
         self.dense = self.config.dense
         if self.dynamic:
@@ -167,11 +166,11 @@ class MUDDPythia(PreTrainedModel):
                 _hidden = self.layer_cache.update(x, i+1) # LBTD
             else:
                 hiddens.append(x)
-                _hidden = ...
+                _hidden = torch.stack(hiddens)
             if self.dynamic and self.dense:
                 dw = self.dynamic_dense[i](x) # BTD -> CBTL
                 dw = dw + self.dense_bs[i][:,None,None,:] # CBTL
-                if ...:
+                if seqlen > 1:
                     x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
                 else:
                     x = self.dynamic_dense[i].layer_mix(_hidden, dw)
@@ -207,7 +206,7 @@ class TransformerBlock(nn.Module):
             normed_x = self.attention_norm(x)
         elif self.config.dense_type == 'qkvr':
             res = x[-1] # for mlp
-            if ...:
+            if not self.config.sepln:
                 normed_x = self.attention_norm(x[:3])
             else:
                 normed_x = tuple([norm_fn(_x) for norm_fn, _x in zip(self.attention_norms, x[:3])])
@@ -259,10 +258,7 @@ class Attention(nn.Module):
         if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
             bsz, seqlen, _ = x.shape
         else:
-            if ...:
-                C, bsz, seqlen, _ = x.shape
-            else:
-                C, (bsz, seqlen, _) = len(x), x[0].shape
+            C, (bsz, seqlen, _) = len(x), x[0].shape
         kv_size = self.n_local_heads * self.head_dim
 
         if self.config.dense_type == 'l' or not self.config.dense:
```
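After this change `_hidden` is always the stacked `LBTD` tensor, and the einsum path is taken whenever `seqlen > 1`. The following is a self-contained sketch of that contraction, not taken from the repo: the sizes and the `dummy_*` names are invented, and only the shapes and the einsum string come from the diff's comments.

```python
import torch

# Shapes follow the comments in the diff: L hidden states collected so far,
# batch B, sequence length T, model width D, and C mixed output streams.
L, B, T, D, C = 5, 2, 8, 16, 4

dummy_hiddens = [torch.randn(B, T, D) for _ in range(L)]  # one BTD tensor per layer
_hidden = torch.stack(dummy_hiddens)                      # LBTD, as in the new torch.stack(hiddens) line
dw = torch.randn(C, B, T, L)                              # CBTL: per-token dynamic mixing weights

# Every output stream c is, for each (b, t), a weighted sum over the L cached
# hidden states -- the contraction kept in the diff for the seqlen > 1 path.
x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
assert x.shape == (C, B, T, D)

# The same contraction spelled out with broadcasting, as a shape/value check.
x_ref = (dw.permute(0, 3, 1, 2)[..., None] * _hidden[None]).sum(dim=1)
assert torch.allclose(x, x_ref, atol=1e-5)
```

When `seqlen == 1` (single-token decoding) the diff instead calls `self.dynamic_dense[i].layer_mix(_hidden, dw)`, presumably a cheaper per-step mix; its implementation is not shown in this commit.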