Commit 51d254e: fix k_mask
Parent: 9ee8b6b
Files changed: modeling_dcformer.py (+8, -6)

modeling_dcformer.py
@@ -123,7 +123,7 @@ class DCFormer(PreTrainedModel):
     def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None):
         batch_size, seq_length = input_ids.shape
         input_pos = torch.arange(seq_length, device=self.device)
-        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate
+        generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device)
         generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int)
         logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True)
         _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
@@ -162,12 +162,14 @@ class DCFormer(PreTrainedModel):
         for i, layer in enumerate(self.layers):
             if self.is_training or self.window_size is None :
                 layer_mask = mask
+                gen_mask = None
             elif self.window_size is not None:
                 layer_mask = mask[:,:,1] if layer.attention.window_size is None else mask[:,:,0]
+                gen_mask = mask[:,:,1] if layer.attention.window_size is not None else None
             if self.use_gradient_checkpointing:
                 x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask)
             else:
-                x = layer(x, input_pos, freqs_cis, layer_mask)
+                x = layer(x, input_pos, freqs_cis, layer_mask, gen_mask=gen_mask)
         x = self.norm(x)
         logits = self.output(x)
         if return_tensor:
@@ -185,8 +187,8 @@ class DCFormerBlock(nn.Module):
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)

-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
-        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True)
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, gen_mask=None) -> Tensor:
+        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, gen_mask=gen_mask, fast_infer=True)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

@@ -416,7 +418,7 @@ class DCMHAttention(nn.Module):
         y = probs @ v
         return y

-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True, gen_mask=None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
@@ -483,7 +485,7 @@ class DCMHAttention(nn.Module):
                y[:,:,start:stop] = _o
        else: # inference
            if seqlen == 1: # one-token generation
-                k_mask = mask if self.window_size is None else
+                k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length]
                if fast_infer:
                    y = self._generate_fast(x, input_pos, q, k, v, k_mask)
                else:
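
For context, the selection logic in the diff assumes a mask tensor with two stacked channels: channel 0 a sliding-window causal mask and channel 1 a plain causal mask, which is what mask[:,:,0] and mask[:,:,1] index into. The rough sketch below illustrates that layout and the decode-time slice the fix introduces; the make_masks helper, the tensor shapes, and the window size are illustrative assumptions, not code from modeling_dcformer.py.

import torch

# Illustrative sketch only: shapes, the make_masks helper, and the window size
# are assumptions; only the channel indexing and the final slice mirror the diff.
def make_masks(seq_len: int, window_size: int) -> torch.Tensor:
    pos = torch.arange(seq_len)
    causal = pos[None, :] <= pos[:, None]                             # (T, T) full causal mask
    windowed = causal & (pos[:, None] - pos[None, :] < window_size)   # sliding-window causal mask
    # Stack as (1, 1, 2, T, T) so mask[:, :, 0] / mask[:, :, 1] match the diff's indexing.
    return torch.stack([windowed, causal], dim=0)[None, None]

mask = make_masks(seq_len=16, window_size=4)

# Per-layer selection as in DCFormer.forward after this commit: a layer whose
# attention has a window uses channel 0 as layer_mask and carries channel 1 as gen_mask.
layer_has_window = True
layer_mask = mask[:, :, 1] if not layer_has_window else mask[:, :, 0]
gen_mask = mask[:, :, 1] if layer_has_window else None

# One-token generation (seqlen == 1): the key mask comes from gen_mask sliced to
# the cached key length, mirroring
#   k_mask = mask if self.window_size is None else gen_mask[:, :, :, :self.kv_cache.seq_length]
kv_len = 8  # stand-in for self.kv_cache.seq_length
k_mask = layer_mask if gen_mask is None else gen_mask[:, :, :, :kv_len]
print(k_mask.shape)  # torch.Size([1, 1, 16, 8]) for the windowed-layer branch

In short, the commit threads gen_mask from DCFormer.forward through DCFormerBlock.forward into DCMHAttention.forward so that, for layers with a sliding window, the one-token decode path builds k_mask from the full-length causal channel sliced to self.kv_cache.seq_length rather than from the prefill mask.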
|