Commit d42522e by Hilbertmeng · 1 Parent(s): 95d1f02

remove stack_hidden

config.json CHANGED
@@ -26,7 +26,6 @@
   "rotary_pct": 0.25,
   "round64": true,
   "sepln": true,
-  "stack_hidden": false,
   "tie_word_embeddings": false,
   "torch_dtype": "bfloat16",
   "transformers_version": "4.35.0",
 
configuration_muddpythia.py CHANGED
@@ -31,7 +31,6 @@ class MUDDPythiaConfig(PretrainedConfig):
         eos_token_id: int =2,
         tie_word_embeddings: bool =False,
         use_layer_cache: bool = True,
-        stack_hidden: bool = False,
         dense: bool = True,
         dynamic_dense: bool = True,
         sepln: bool = True,
@@ -58,7 +57,6 @@ class MUDDPythiaConfig(PretrainedConfig):
         self.rotary_pct = rotary_pct
 
         self.use_layer_cache= use_layer_cache
-        self.stack_hidden= stack_hidden
         self.dense= dense
         self.dynamic_dense= dynamic_dense
         self.sepln= sepln
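With the flag removed from both the constructor signature and the stored attributes, a config is now built from the remaining switches alone. A minimal sketch, assuming the repo's configuration_muddpythia.py is on the import path and the other defaults shown in the class are unchanged:

from configuration_muddpythia import MUDDPythiaConfig

# Sketch only: stack_hidden is no longer a declared field, so it is not set on the
# instance; the surviving switches keep the defaults shown in the diff above.
cfg = MUDDPythiaConfig(use_layer_cache=True, dense=True, dynamic_dense=True, sepln=True)
print(cfg.use_layer_cache, cfg.dense, cfg.dynamic_dense, cfg.sepln)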
 
modeling_muddpythia.py CHANGED
@@ -85,7 +85,6 @@ class MUDDPythia(PreTrainedModel):
 
         self.layer_cache = None
         self.use_layer_cache = False if self.is_training else self.config.use_layer_cache
-        self.stack_hidden = self.config.stack_hidden
         self.dynamic = self.config.dynamic_dense
         self.dense = self.config.dense
         if self.dynamic:
@@ -167,11 +166,11 @@
                 _hidden = self.layer_cache.update(x, i+1) # LBTD
             else:
                 hiddens.append(x)
-                _hidden = hiddens if not self.stack_hidden else hiddens
+                _hidden = torch.stack(hiddens)
             if self.dynamic and self.dense:
                 dw = self.dynamic_dense[i](x) # BTD -> CBTL
                 dw = dw + self.dense_bs[i][:,None,None,:] # CBTL
-                if self.stack_hidden:
+                if seqlen > 1:
                     x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)
                 else:
                     x = self.dynamic_dense[i].layer_mix(_hidden, dw)
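This hunk is the heart of the commit: the collected per-layer hidden states are now always stacked into one LBTD tensor, the einsum path is taken for multi-token (prefill) steps, and single-token decoding (seqlen == 1) falls back to layer_mix. A standalone shape check of that einsum with hypothetical sizes (L cached layers, C compositions, batch B, length T, width D):

import torch

L, C, B, T, D = 5, 4, 2, 16, 64                      # hypothetical sizes
hiddens = [torch.randn(B, T, D) for _ in range(L)]   # one hidden state per layer so far
_hidden = torch.stack(hiddens)                       # LBTD, as in the added line above
dw = torch.randn(C, B, T, L)                         # CBTL dynamic dense weights (+ bias)
x = torch.einsum('LBTD, CBTL -> CBTD', _hidden, dw)  # mix the L layers once per composition
assert x.shape == (C, B, T, D)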
@@ -207,7 +206,7 @@ class TransformerBlock(nn.Module):
             normed_x = self.attention_norm(x)
         elif self.config.dense_type == 'qkvr':
             res = x[-1] # for mlp
-            if self.config.stack_hidden or not self.config.sepln:
+            if not self.config.sepln:
                 normed_x = self.attention_norm(x[:3])
             else:
                 normed_x = tuple([norm_fn(_x) for norm_fn, _x in zip(self.attention_norms, x[:3])])
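After the change, the norm choice in the 'qkvr' branch depends only on sepln: one shared attention_norm over the q/k/v streams, or one norm per stream. A small illustration of the per-stream branch, with plain nn.LayerNorm modules standing in for the model's attention_norms (hypothetical sizes and norm type):

import torch
from torch import nn

D = 64
attention_norms = nn.ModuleList([nn.LayerNorm(D) for _ in range(3)])  # stand-ins, one per q/k/v stream
x = tuple(torch.randn(2, 16, D) for _ in range(4))                    # (q, k, v, residual) streams
res = x[-1]                                                           # residual kept for the MLP
normed_x = tuple(norm_fn(_x) for norm_fn, _x in zip(attention_norms, x[:3]))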
@@ -259,10 +258,7 @@ class Attention(nn.Module):
         if self.lidx == 0 or self.config.dense_type == 'l' or not self.config.dense:
             bsz, seqlen, _ = x.shape
         else:
-            if self.config.stack_hidden:
-                C, bsz, seqlen, _ = x.shape
-            else:
-                C, (bsz, seqlen, _) = len(x), x[0].shape
+            C, (bsz, seqlen, _) = len(x), x[0].shape
         kv_size = self.n_local_heads * self.head_dim
 
         if self.config.dense_type == 'l' or not self.config.dense:
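With the stacked-tensor case dropped, attention past layer 0 (dense enabled, dense_type 'qkvr') always receives x as a tuple of per-stream tensors, so the shapes are read from its first element. A minimal sketch of that unpacking (hypothetical sizes):

import torch

x = tuple(torch.randn(2, 16, 64) for _ in range(3))   # hypothetical q, k, v streams, each BTD
C, (bsz, seqlen, _) = len(x), x[0].shape
assert (C, bsz, seqlen) == (3, 2, 16)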
 