qcw2333 committed
Commit abd23e5 · verified · 1 Parent(s): 97e7d77

Upload model.py

Files changed (1):
  1. model.py +174 -361
model.py CHANGED
@@ -1,112 +1,35 @@
- """Full definition of a GPT NeoX Language Model, all of it in this single file.
-
- Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
- https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
  """
  import math, random
  import numpy as np
  from typing import Any, List, Optional, Tuple

  import torch
  import torch.nn as nn
- from lightning_utilities.core.imports import RequirementCache
- from typing_extensions import Self
- from flash_attn import flash_attn_func
- # from lit_gpt.config import Config
- from xformers.ops import SwiGLU
-
  import torch.nn.functional as F
- # from .fused_rotary_embedding import apply_rotary_emb_func
- RoPECache = Tuple[torch.Tensor, torch.Tensor]
- KVCache = Tuple[torch.Tensor, torch.Tensor]
- PretokenCache = torch.Tensor
- # Tuple[torch.Tensor, torch.Tensor]
- FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")
- from einops import rearrange
- from transformers import PreTrainedModel, Cache, DynamicCache
-
- from huggingface_hub import PyTorchModelHubMixin
- from .model_config import YingLongConfig
-
- # from torch.distributions import Normal, LowRankMultivariateNormal, kl_divergence,MultivariateNormal

- class quantitleLoss(torch.nn.Module):
-     def __init__(self,
-                  qSize = 99,
-                  patch_size = 16,
-                  *args,**kwargs) -> None:
-
-         super().__init__()
-         self.qSize = qSize
-         self.patch_size = patch_size
-
-         q = np.array([i+1 for i in range(self.qSize)])
-         q = q / (self.qSize + 1)
-         q = q.reshape((1,1,-1))
-
-         q_variance = q*(1-q)
-
-         self.register_buffer('q', torch.tensor(q))
-         self.register_buffer('q_variance', torch.tensor(q_variance))
-
-     def forward(self, input: torch.Tensor, target: torch.Tensor,rel_loss = False) -> torch.Tensor:
-
-         target = target.unsqueeze(-1)
-         input = input[:,:target.shape[1],:,:]
-
-         posPart = input - target
-         negPart = -posPart
-
-         raw_loss = torch.maximum(self.q * negPart, (1-self.q) * posPart)
-
-         target_absmean = torch.mean(target.abs(),dim = (1,2),keepdims = True)
-         raw_loss = raw_loss / torch.sqrt(self.q_variance) / (target_absmean + 1e-4)
-
-         return torch.mean(raw_loss)
-
- def haarMatrix_unnormalized(n):
-     # Allow only size n of power 2
-     n = 2**np.ceil(np.log2(n))
-     if n > 2:
-         h = haarMatrix(n / 2)
-     else:
-         return np.array([[1, 1], [1, -1]])
-
-     # calculate upper haar part
-     h_n = np.kron(h, [1, 1])
-     # calculate lower haar part
-     # if normalized:
-     #     h_i = np.sqrt(n/2)*np.kron(np.eye(len(h)), [1, -1])
-     # else:
-     h_i = np.kron(np.eye(len(h)), [1, -1])
-     # combine parts
-     h = np.vstack((h_n, h_i))
-     return h

- def haarMatrix(n,normalized = 'ortho'):
-     h = haarMatrix_unnormalized(n)
-     scaler = np.diag(1/np.sqrt(np.diag(h@h.transpose())))
-     if normalized == 'ortho':
-         return scaler @ h
-     elif normalized == 'forward':
-         return scaler @ h/ np.sqrt(n)
-
-     else:
-         return scaler @ h * np.sqrt(n)
-     # else:
-     #     scaler = 1

  class Tokenizer(torch.nn.Module):
      def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:
          super().__init__()
@@ -119,6 +42,7 @@ class Tokenizer(torch.nn.Module):

          self.register_buffer('mask_token', torch.zeros(1000))
          if self.config.haar_trans:
+
              self.register_buffer('haar_transform',torch.Tensor(haarMatrix(self.config.patch_size,normalized = self.config.haar_trans_norm)))

@@ -167,7 +91,6 @@ class Tokenizer(torch.nn.Module):

          else:

-
              factor = 1
              more_rows = future_token // self.patch_size + 1
              prev_more_rows = prev_token // self.patch_size + 1
@@ -187,7 +110,6 @@ class Tokenizer(torch.nn.Module):

              masks = [jj for jj in range(x_featured.shape[1])]
              masks = masks[-more_rows:]

-             # if not mean_replace:
              x_featured[:,-more_rows:] = self.mask0(self.mask_token[:len(masks)].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)
              x_featured[:,:prev_more_rows] = self.mask0(self.mask_token[:prev_more_rows].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)

@@ -199,26 +121,30 @@ class Tokenizer(torch.nn.Module):
  class model_tmp(PreTrainedModel):
      config_class = YingLongConfig
      base_model_prefix = "model"
-     # supports_gradient_checkpointing = True
-     # _no_split_modules = ["TimeMoeDecoderLayer"]
-     # _skip_keys_device_placement = "past_key_values"
-     _supports_flash_attn_2 = True
-     _supports_sdpa = False
-     _supports_cache_class = True

- # class GPT(nn.Module,PreTrainedModel,PyTorchModelHubMixin):
  class GPT(model_tmp):
      def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:

-         # config_class = YingLongConfig
-         # base_model_prefix = "model"
-         # # supports_gradient_checkpointing = True
-         # # _no_split_modules = ["TimeMoeDecoderLayer"]
-         # # _skip_keys_device_placement = "past_key_values"
-         # _supports_flash_attn_2 = True
-         # _supports_sdpa = False
-         # _supports_cache_class = True
          super().__init__(config)

          self.config = config
@@ -227,49 +153,28 @@ class GPT(model_tmp):

          if self.config._norm_class == "RMSNorm":
-             # from .model import RMSNorm
              self.config.norm_class = RMSNorm
          elif self.config._norm_class == "FusedRMSNorm":
-             # from .model import FusedRMSNorm
              self.config.norm_class = FusedRMSNorm
          elif self.config._norm_class == 'BatchNorm':
-             # from .model import iBatchNorm
              self.config.norm_class = iBatchNorm

-
          if self.config._mlp_class == "GptNeoxMLP":
-             # from .model import GptNeoxMLP
              self.config.mlp_class = GptNeoxMLP
          elif self.config._mlp_class == "LLaMAMLP":
-             # from .model import LLaMAMLP
              self.config.mlp_class = LLaMAMLP

-
-         if config.stats_encoding:
-             self.stat_tokens = 1
-         else:
-             self.stat_tokens = 0
-
          self.tokenizer = Tokenizer(config)

-         # self.lm_head = nn.Sequential(config.norm_class(config.n_embd, eps=config.norm_eps),
-         #                              nn.Linear(config.n_embd, config.n_embd*4),
-         #                              nn.ReLU(),
-         #                              nn.Linear(config.n_embd*4, 99*self.patch_size),
-         #                              )
-
          self.lm_head = nn.Linear(config.n_embd, 99*self.patch_size)
-
-         # self.gate = nn.Linear(config.n_embd, 1)

-
          self.quantitleLoss = quantitleLoss(99,patch_size = self.patch_size)

@@ -296,48 +201,16 @@ class GPT(model_tmp):

-         self.rope_cache: Optional[RoPECache] = None
-         self.mask_cache: Optional[torch.Tensor] = None
-         self.kv_caches: List[KVCache] = []

-     def _init_weights(self, module: nn.Module) -> None:
-         """Meant to be used with `gpt.apply(gpt._init_weights)`."""
-         # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf
-         if isinstance(module, nn.Embedding):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
-             # RWKV: set it to 1e-4
-             # torch.nn.init.uniform_(module.weight, -1e-4, 1e-4)
-         elif isinstance(module, nn.Linear):
-             torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
-             if module.bias is not None:
-                 torch.nn.init.zeros_(module.bias)
-         # GPT-NeoX
-         for name, p in module.named_parameters():
-             if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, BidirectedlSelfAttention))): #if use xformer swiglu, fc2 layer will be renamed to w3
-                 nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / self.config.n_layer)
-
-     def reset_cache(self) -> None:
-         self.kv_caches.clear()
-         if self.mask_cache is not None and self.mask_cache.device.type == "xla":
-             # https://github.com/Lightning-AI/lit-gpt/pull/83#issuecomment-1558150179
-             self.rope_cache = None
-             self.mask_cache = None

      def forward(
          self, idx: torch.Tensor,
-         max_seq_length: Optional[int] = None,
-         input_pos: Optional[torch.Tensor] = None,
-         next_token: torch.Tensor = None,
          future_token: int = 0,
          prev_token: int = 0,
-         val: bool = False,
-         print_intermediate: bool = False,
-         cot_rounds: int = -1,
-         sequential: bool = False,
          *args,**kwargs,
-     ) -> torch.Tensor:

          if future_token > 0:
              more_rows = future_token // self.patch_size + 1
@@ -348,138 +221,53 @@ class GPT(model_tmp):

          B, T = idx.size()

-         use_kv_cache = input_pos is not None

          block_size = self.config.block_size
-         if max_seq_length is None:
-             max_seq_length = block_size
-
-         if use_kv_cache: # not relevant otherwise
-             assert (
-                 max_seq_length >= T
-             ), f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}"
          assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
-         if self.rope_cache is None:
-             self.rope_cache = self.build_rope_cache(idx)
-         if use_kv_cache and self.mask_cache is None:
-             self.mask_cache = self.build_mask_cache(idx)
          cos, sin = self.rope_cache
-         if use_kv_cache:
-             if self.stat_tokens:
-                 if len(input_pos) == 1:
-                     idx = idx[:,input_pos]
-                     input_pos = input_pos.add_(1)
-                 else:
-                     input_pos = torch.arange(0, input_pos[-1]+2, device=idx.device)
-
-                 cos = cos.index_select(0, input_pos)
-                 sin = sin.index_select(0, input_pos)
-                 mask = self.mask_cache.index_select(2, input_pos)
-                 mask = mask[:, :, :, :max_seq_length]
-
-             else:
-                 cos = cos.index_select(0, input_pos)
-                 sin = sin.index_select(0, input_pos)
-                 idx = idx[:,input_pos]
-         else:
-             cos = cos[:max(T,1024) + self.stat_tokens]
-             sin = sin[:max(T,1024) + self.stat_tokens]
-             mask = None

-         idx_ori = idx

-         if use_kv_cache:
-             pass
-         else:
-             x,x_raw,masks,mean,std,x_0 = self.tokenizer(idx,
-                                                         future_token =future_token,
-                                                         prev_token = prev_token,
-                                                         sequential = sequential,
-                                                         )

          if self.unet:
              skips = []

-         res_intermediate = []
-         target_intermediate = []
-         if not use_kv_cache:
-
-             if cot_rounds <0:
-                 cot_rounds = self.config.n_cot
-
-             res_list = []
-             gate_list = []
-             for rep in range(cot_rounds):
-                 for block_idx in range(len( self.transformer.h)):
-
-                     block = self.transformer.h[block_idx]
-
-                     if self.unet and block_idx >=len(self.transformer.h) //2:
-                         x = self.unet_projection[block_idx - len(self.transformer.h) //2](torch.cat((skips.pop(),x),dim = -1))
-
-                     x, *_ = block(x, (cos, sin), max_seq_length)
-
-                     if self.unet and block_idx <len(self.transformer.h) //2:
-                         skips.append(x)
-                         x_delay = torch.cat((x[:,0,:].unsqueeze(1),x[:,:-1,:]),dim = 1)
-                         x = self.unet_merge[block_idx](torch.cat((x_delay,x),dim = -1))
-                     # if block_idx <len(self.transformer.h) //2:
-                     #     x_delay = torch.cat((x[:,0,:].unsqueeze(1),x[:,:-1,:]),dim = 1)
-                     #     x = self.unet_merge[block_idx](torch.cat((x_delay,x),dim = -1))
-
-             # res_list.append(self.lm_head(x).unsqueeze(-1))
-             # gate_list.append(self.gate(x).unsqueeze(-1))
-             # gate_list.append(self.gate(x))
-             # if print_intermediate:
-             #     res_intermediate.append(res_list[-1])
-             # if print_intermediate:
-             #     res_tmp = self.lm_head(x[:,self.stat_tokens:])
-             #     res_tmp = rearrange(res_tmp,'b c (l1 l2) -> b c l1 l2', l2 = 99)
-             #     if self.config.haar_trans_inv:
-             #         res_tmp = torch.einsum('bcal,ad->bcdl',res_tmp,self.tokenizer.haar_transform)
-             #         if self.config.haar_trans_norm == "backward":
-             #             res_tmp = res_tmp / np.sqrt(res_tmp.shape[-2])
-             #         elif self.config.haar_trans_norm == "forward":
-             #             res_tmp = res_tmp * np.sqrt(res_tmp.shape[-2])
-             #     res_tmp = res_tmp * (std.unsqueeze(-1) + 1e-4) + mean.unsqueeze(-1)
-             #     res_intermediate.append(res_tmp[:,masks,:,:])
-
-         else:
-             self.kv_caches = self.kv_caches or self.build_kv_caches(x, max_seq_length, cos.size(-1) * 2)
-             for block_idx in range(len( self.transformer.h)):
-                 block = self.transformer.h[block_idx]
-                 if self.unet and block_idx >=len(self.transformer.h) //2:
-                     x = F.silu(skips.pop()) * x
-                 x, self.kv_caches[block_idx] = block(x, (cos, sin), max_seq_length, mask, input_pos, self.kv_caches[block_idx])
-                 if self.unet and block_idx <len(self.transformer.h) //2:
-                     skips.append(x)

          res = self.lm_head(x)
-         # gate = torch.cat(gate_list,dim = -1)
-         # gate = F.softmax(gate,dim = -1)
-         # res = torch.cat(res_list,dim = -1) * gate
-         # res = res.sum(dim = -1)

          res = rearrange(res,'b c (l1 l2) -> b c l1 l2', l2 = 99)
@@ -487,7 +275,6 @@ class GPT(model_tmp):

          if self.config.haar_trans_inv:
-             # print('res',res.shape,self.tokenizer.haar_transform.shape)
              res = torch.einsum('bcal,ad->bcdl',res,self.tokenizer.haar_transform)
              if self.config.haar_trans_norm == "backward":
                  res = res / np.sqrt(res.shape[-2])
@@ -495,32 +282,33 @@ class GPT(model_tmp):
                  res = res * np.sqrt(res.shape[-2])

          res = res * (std.unsqueeze(-1) + 1e-4) + mean.unsqueeze(-1)

-
          if future_token == 0:
-             return res[:,masks,:,:], x_raw[:,masks,:],res_intermediate,target_intermediate
          else:
-             return res[:,masks,:,:],res_intermediate

      def generate(self,*args,**kwargs):
-
-         res, _ = self.forward(*args,**kwargs)
-         # logits_all,res_intermediate = model(idx = x_train, future_token = (pred_len//32 + 1)* 32, prev_token = 0,print_intermediate = False,cot_rounds = 1)
-
          res = rearrange(res, 'b l c d -> b (l c) d')
          return res[:,:kwargs['future_token'],:]

      @classmethod
      def from_name(cls, name: str, **kwargs: Any) -> Self:
          return cls(Config.from_name(name, **kwargs))

-     def build_rope_cache(self, idx: torch.Tensor) -> RoPECache:
          return build_rope_cache(
-             seq_len=self.config.block_size + self.stat_tokens,
              n_elem=int(self.config.rotary_percentage * self.config.head_size),
              dtype=torch.bfloat16,
              device=idx.device,
@@ -528,27 +316,6 @@ class GPT(model_tmp):
              condense_ratio=self.config.condense_ratio,
          )

-     def build_mask_cache(self, idx: torch.Tensor) -> torch.Tensor:
-         ones = torch.ones((self.config.block_size+self.stat_tokens, self.config.block_size+self.stat_tokens), device=idx.device, dtype=torch.bool)
-         return torch.tril(ones).unsqueeze(0).unsqueeze(0)
-
-     def build_kv_caches(self, idx: torch.Tensor, max_seq_length: int, rope_cache_length: int) -> List[KVCache]:
-         B = idx.size(0)
-         heads = 1 if self.config.n_query_groups == 1 else self.config.n_query_groups
-
-         k_cache_shape = (
-             B,
-             max_seq_length,
-             heads,
-             rope_cache_length + self.config.head_size - int(self.config.rotary_percentage * self.config.head_size),
-         )
-         v_cache_shape = (B, max_seq_length, heads, self.config.head_size)
-         device = idx.device
-         return [
-             (torch.zeros(k_cache_shape, device=device), torch.zeros(v_cache_shape, device=device))
-             for _ in range(self.config.n_layer)
-         ]
-

  class Block(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:
@@ -562,15 +329,14 @@ class Block(nn.Module):
      def forward(
          self,
          x: torch.Tensor,
-         rope: RoPECache,
          max_seq_length: int,
          mask: Optional[torch.Tensor] = None,
          input_pos: Optional[torch.Tensor] = None,
-         kv_cache: Optional[KVCache] = None,
-     ) -> Tuple[torch.Tensor, Optional[KVCache]]:

          n_1 = self.norm_1(x)
-         h, new_kv_cache = self.attn(n_1, rope, max_seq_length, mask, input_pos, kv_cache)
          if self.config.parallel_residual:
              n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
              x = x + h + self.mlp(n_2)
@@ -583,29 +349,25 @@ class Block(nn.Module):

              x = x + h
              x = x + self.mlp(self.norm_2(x))
-         return x, new_kv_cache

  class BidirectedlSelfAttention(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:
          super().__init__()
          shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
-         # key, query, value projections for all heads, but in a batch
          self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
-         # output projection
          self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
-
          self.config = config

      def forward(
          self,
          x: torch.Tensor,
-         rope: RoPECache,
          max_seq_length: int,
          mask: Optional[torch.Tensor] = None,
          input_pos: Optional[torch.Tensor] = None,
-         kv_cache: Optional[KVCache] = None,
-     ) -> Tuple[torch.Tensor, Optional[KVCache]]:

          B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
@@ -616,52 +378,20 @@ class BidirectedlSelfAttention(nn.Module):
          q_per_kv = self.config.n_head // self.config.n_query_groups
          total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
          qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs)
-         # qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)

          # split batched computation into three
          q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)

-         # repeat k and v if necessary
-         # Peiyuan: we do not need to do this as flash attention 2 already support GQA
-         # if self.config.n_query_groups != 1: # doing this would require a full kv cache with MQA (inefficient!)
-         #     # for MHA this is a no-op
-         #     k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
-         #     v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
-
          q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs)
          k = k.reshape(B, T, -1, self.config.head_size)
          v = v.reshape(B, T, -1, self.config.head_size)

          cos, sin = rope

-         # apply rope in fp32 significanly stabalize training
-         # fused rope expect (batch_size, seqlen, nheads, headdim)
          q = apply_rotary_emb_func(q, cos, sin, False, True)
          k = apply_rotary_emb_func(k, cos, sin, False, True)
-
-         # n_elem = int(self.config.rotary_percentage * self.config.head_size)
-
-         # q_roped = apply_rope(q[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
-         # k_roped = apply_rope(k[..., :n_elem], cos.repeat(1,2), sin.repeat(1,2))
-         # print( (q_roped - q).sum())
-         # q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
-         # k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
-
-         if kv_cache is not None:
-             cache_k, cache_v = kv_cache
-             cache_k, cache_v = cache_k.to(dtype=k.dtype), cache_v.to(dtype=v.dtype)
-             # check if reached token limit
-             if input_pos[-1] >= max_seq_length:
-                 input_pos = torch.tensor(max_seq_length - 1, device=input_pos.device)
-                 # shift 1 position to the left
-                 cache_k = torch.roll(cache_k, -1, dims=1)
-                 cache_v = torch.roll(cache_v, -1, dims=1)
-
-             k = cache_k.index_copy_(1, input_pos, k)
-             v = cache_v.index_copy_(1, input_pos, v)
-             kv_cache = k, v

-
          y = self.scaled_dot_product_attention(q, k, v, mask=mask)

@@ -670,7 +400,9 @@ class BidirectedlSelfAttention(nn.Module):
          # output projection
          y = self.proj(y)

-         return y, kv_cache

      def scaled_dot_product_attention(
          self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
@@ -690,14 +422,83 @@ class BidirectedlSelfAttention(nn.Module):
          k = k.transpose(1, 2)
          v = v.transpose(1, 2)
          if q.size() != k.size():
-             k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1)
-             v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1)
          y = torch.nn.functional.scaled_dot_product_attention(
              q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=False
          )
          return y.transpose(1, 2)

  class GptNeoxMLP(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:
          super().__init__()
@@ -713,21 +514,15 @@ class GptNeoxMLP(nn.Module):
  class LLaMAMLP(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:
          super().__init__()
-         # self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-         # self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-         # self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
          self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False)
      def forward(self, x: torch.Tensor) -> torch.Tensor:
-         # x_fc_1 = self.fc_1(x)
-         # x_fc_2 = self.fc_2(x)
-         # x = torch.nn.functional.silu(x_fc_1) * x_fc_2
-         # return self.proj(x)
          return self.swiglu(x)

  def build_rope_cache(
      seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
- ) -> RoPECache:
      """Enhanced Transformer with Rotary Position Embedding.

      Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
@@ -745,7 +540,6 @@ def build_rope_cache(

      cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)

-     # print(' print(seq_idx.shape,theta.shape,sin.shape,cos.shape,idx_theta.shape)',seq_idx.shape,theta.shape,sin.shape,cos.shape,idx_theta.shape)
      # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding
      if dtype == torch.bfloat16:
          return cos.bfloat16(), sin.bfloat16()
@@ -766,6 +560,14 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:

  import torch
  # Copyright (c) 2022, Tri Dao.
  # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16
@@ -1611,7 +1413,18 @@ class RMSNorm(torch.nn.Module):

-

  # Copyright (c) 2023, Tri Dao.
 
+ """
+
+ Based on the tinyllama implementation: https://github.com/jzhang38/TinyLlama

  """
+
  import math, random
  import numpy as np
  from typing import Any, List, Optional, Tuple
+ from typing_extensions import Self
+

  import torch
  import torch.nn as nn
  import torch.nn.functional as F

+ from lightning_utilities.core.imports import RequirementCache
+ FlashAttention2Available = RequirementCache("flash-attn>=2.0.0.post1")

+ from flash_attn import flash_attn_func
+ from xformers.ops import SwiGLU
+ from einops import rearrange

+ from transformers import PreTrainedModel
+ from .model_config import YingLongConfig

  class Tokenizer(torch.nn.Module):
      def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:
          super().__init__()
 

          self.register_buffer('mask_token', torch.zeros(1000))
          if self.config.haar_trans:
+
              self.register_buffer('haar_transform',torch.Tensor(haarMatrix(self.config.patch_size,normalized = self.config.haar_trans_norm)))

          else:

              factor = 1
              more_rows = future_token // self.patch_size + 1
              prev_more_rows = prev_token // self.patch_size + 1

              masks = [jj for jj in range(x_featured.shape[1])]
              masks = masks[-more_rows:]

              x_featured[:,-more_rows:] = self.mask0(self.mask_token[:len(masks)].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)
              x_featured[:,:prev_more_rows] = self.mask0(self.mask_token[:prev_more_rows].unsqueeze(-1)).repeat(x_featured.shape[0],1,1)
 
 
  class model_tmp(PreTrainedModel):
      config_class = YingLongConfig
      base_model_prefix = "model"

+
+     def _init_weights(self, module: nn.Module) -> None:
+         if isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
+         elif isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         for name, p in module.named_parameters():
+             if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, BidirectedlSelfAttention))):
+                 nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / self.config.n_layer)
+
  class GPT(model_tmp):
      def __init__(self, config: YingLongConfig, *args,**kwargs) -> None:

          super().__init__(config)

          self.config = config

          if self.config._norm_class == "RMSNorm":
+
              self.config.norm_class = RMSNorm
          elif self.config._norm_class == "FusedRMSNorm":
              self.config.norm_class = FusedRMSNorm
          elif self.config._norm_class == 'BatchNorm':
              self.config.norm_class = iBatchNorm

          if self.config._mlp_class == "GptNeoxMLP":
              self.config.mlp_class = GptNeoxMLP
          elif self.config._mlp_class == "LLaMAMLP":
              self.config.mlp_class = LLaMAMLP

+
          self.tokenizer = Tokenizer(config)

+
          self.lm_head = nn.Linear(config.n_embd, 99*self.patch_size)
+

          self.quantitleLoss = quantitleLoss(99,patch_size = self.patch_size)

+         self.rope_cache = None

      def forward(
          self, idx: torch.Tensor,
          future_token: int = 0,
          prev_token: int = 0,
          *args,**kwargs,
+     ) -> torch.Tensor:

          if future_token > 0:
              more_rows = future_token // self.patch_size + 1
 

          B, T = idx.size()

+
          block_size = self.config.block_size
+         max_seq_length = T

          assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}"
+
+         self.rope_cache = self.build_rope_cache(idx)
          cos, sin = self.rope_cache

+         cos = cos[:max(T,1024)]
+         sin = sin[:max(T,1024)]

+         x,x_raw,masks,mean,std,_ = self.tokenizer(idx, future_token =future_token,prev_token = prev_token)

          if self.unet:
              skips = []

+         for block_idx in range(len( self.transformer.h)):

+             block = self.transformer.h[block_idx]

+             if self.unet and block_idx >=len(self.transformer.h) //2:
+                 x = self.unet_projection[block_idx - len(self.transformer.h) //2](torch.cat((skips.pop(),x),dim = -1))

+             x = block(x, (cos, sin), max_seq_length)

+             if self.unet and block_idx <len(self.transformer.h) //2:
+                 skips.append(x)
+                 x_delay = torch.cat((x[:,0,:].unsqueeze(1),x[:,:-1,:]),dim = 1)
+                 x = self.unet_merge[block_idx](torch.cat((x_delay,x),dim = -1))
+

          res = self.lm_head(x)
+

          res = rearrange(res,'b c (l1 l2) -> b c l1 l2', l2 = 99)
 

          if self.config.haar_trans_inv:
              res = torch.einsum('bcal,ad->bcdl',res,self.tokenizer.haar_transform)
              if self.config.haar_trans_norm == "backward":
                  res = res / np.sqrt(res.shape[-2])

                  res = res * np.sqrt(res.shape[-2])

+
          res = res * (std.unsqueeze(-1) + 1e-4) + mean.unsqueeze(-1)

+
          if future_token == 0:
+             return res[:,masks,:,:], x_raw[:,masks,:]
          else:
+             return res[:,masks,:,:]

      def generate(self,*args,**kwargs):
+         res = self.forward(*args,**kwargs)
          res = rearrange(res, 'b l c d -> b (l c) d')
          return res[:,:kwargs['future_token'],:]

+
      @classmethod
      def from_name(cls, name: str, **kwargs: Any) -> Self:
          return cls(Config.from_name(name, **kwargs))

+     def build_rope_cache(self, idx: torch.Tensor) :
          return build_rope_cache(
+             seq_len=self.config.block_size,
              n_elem=int(self.config.rotary_percentage * self.config.head_size),
              dtype=torch.bfloat16,
              device=idx.device,

              condense_ratio=self.config.condense_ratio,
          )

  class Block(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:

      def forward(
          self,
          x: torch.Tensor,
+         rope: Optional[Tuple[torch.Tensor, torch.Tensor]],
          max_seq_length: int,
          mask: Optional[torch.Tensor] = None,
          input_pos: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:

          n_1 = self.norm_1(x)
+         h = self.attn(n_1, rope, max_seq_length, mask, input_pos)
          if self.config.parallel_residual:
              n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
              x = x + h + self.mlp(n_2)

              x = x + h
              x = x + self.mlp(self.norm_2(x))
+         return x

  class BidirectedlSelfAttention(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:
          super().__init__()
          shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
          self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
          self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
          self.config = config

      def forward(
          self,
          x: torch.Tensor,
+         rope: Tuple[torch.Tensor, torch.Tensor],
          max_seq_length: int,
          mask: Optional[torch.Tensor] = None,
          input_pos: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:

          B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

          q_per_kv = self.config.n_head // self.config.n_query_groups
          total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
          qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) # (B, T, n_query_groups, total_qkv, hs)
+

          # split batched computation into three
          q, k, v = qkv.split((q_per_kv, 1, 1), dim=-2)

          q = q.reshape(B, T, -1, self.config.head_size) # (B, T, nh_q, hs)
          k = k.reshape(B, T, -1, self.config.head_size)
          v = v.reshape(B, T, -1, self.config.head_size)

          cos, sin = rope

          q = apply_rotary_emb_func(q, cos, sin, False, True)
          k = apply_rotary_emb_func(k, cos, sin, False, True)

          y = self.scaled_dot_product_attention(q, k, v, mask=mask)

          # output projection
          y = self.proj(y)

+         return y
+

      def scaled_dot_product_attention(
          self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None

          k = k.transpose(1, 2)
          v = v.transpose(1, 2)
          if q.size() != k.size():
+             k = k.repeat_interleave(q.shape[1]//k.shape[1], dim=1)
+             v = v.repeat_interleave(q.shape[1]//v.shape[1], dim=1)
          y = torch.nn.functional.scaled_dot_product_attention(
              q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=False
          )
          return y.transpose(1, 2)

+
+ class quantitleLoss(torch.nn.Module):
+     def __init__(self,
+                  qSize = 99,
+                  patch_size = 16,
+                  *args,**kwargs):
+
+         super().__init__()
+         self.qSize = qSize
+         self.patch_size = patch_size
+
+         q = np.array([i+1 for i in range(self.qSize)])
+         q = q / (self.qSize + 1)
+         q = q.reshape((1,1,-1))
+
+         q_variance = q*(1-q)
+
+         self.register_buffer('q', torch.tensor(q))
+         self.register_buffer('q_variance', torch.tensor(q_variance))
+
+     def forward(self, input: torch.Tensor, target: torch.Tensor,rel_loss = False):
+
+         target = target.unsqueeze(-1)
+         input = input[:,:target.shape[1],:,:]
+
+         posPart = input - target
+         negPart = -posPart
+
+         raw_loss = torch.maximum(self.q * negPart, (1-self.q) * posPart)
+
+         target_absmean = torch.mean(target.abs(),dim = (1,2),keepdims = True)
+         raw_loss = raw_loss / torch.sqrt(self.q_variance) / (target_absmean + 1e-4)
+
+         return torch.mean(raw_loss)
+
+ def haarMatrix_unnormalized(n):
+
+     n = 2**np.ceil(np.log2(n))
+     if n > 2:
+         h = haarMatrix(n / 2)
+     else:
+         return np.array([[1, 1], [1, -1]])
+     h_n = np.kron(h, [1, 1])
+     h_i = np.kron(np.eye(len(h)), [1, -1])
+     h = np.vstack((h_n, h_i))
+     return h
+
+ def haarMatrix(n,normalized = 'ortho'):
+     h = haarMatrix_unnormalized(n)
+     scaler = np.diag(1/np.sqrt(np.diag(h@h.transpose())))
+     if normalized == 'ortho':
+         return scaler @ h
+     elif normalized == 'forward':
+         return scaler @ h/ np.sqrt(n)
+
+     else:
+         return scaler @ h * np.sqrt(n)
+
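The added block above carries the training objective and the patch transform: quantitleLoss is a variance-normalized pinball (quantile) loss over 99 quantile levels, and haarMatrix(..., 'ortho') builds a Haar matrix with unit-norm rows, which is why GPT.forward can invert the transform with a single matrix product. A minimal sketch of how the two helpers behave, assuming they are importable from this file (the import path below is illustrative, not part of the commit):

import numpy as np
import torch
from model import haarMatrix, quantitleLoss  # illustrative import path

# For normalized='ortho' the rows are rescaled to unit norm, so H @ H.T should be
# (numerically) the identity matrix.
H = haarMatrix(16, normalized='ortho')
print(np.allclose(H @ H.T, np.eye(16)))  # expected: True

# Pinball loss: predictions carry one value per quantile level (last dim = 99),
# targets are the raw patches; the result is a scalar tensor.
loss_fn = quantitleLoss(99, patch_size=16)
pred = torch.randn(2, 4, 16, 99)    # (batch, patches, patch_size, quantiles)
target = torch.randn(2, 4, 16)      # (batch, patches, patch_size)
print(loss_fn(pred, target))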
  class GptNeoxMLP(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:
          super().__init__()

  class LLaMAMLP(nn.Module):
      def __init__(self, config:YingLongConfig) -> None:
          super().__init__()
+
          self.swiglu = SwiGLU(config.n_embd,config.intermediate_size, bias=False, _pack_weights=False)
      def forward(self, x: torch.Tensor) -> torch.Tensor:
          return self.swiglu(x)

  def build_rope_cache(
      seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000, condense_ratio: int = 1
+ ) -> Tuple[torch.Tensor,torch.Tensor]:
      """Enhanced Transformer with Rotary Position Embedding.

      Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/

      cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)

      # added by peiyuan to ensure same data type with q, k, to use fused rotary embedding
      if dtype == torch.bfloat16:
          return cos.bfloat16(), sin.bfloat16()
 

+
+ ######################################
+ #layernorm
+ ######################################
+
  import torch
  # Copyright (c) 2022, Tri Dao.
  # Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py AND https://github.com/Dao-AILab/flash-attention/blob/7a983df74215e035e566e37125b0a71e3618f39d/flash_attn/ops/layer_norm.py#L16

+
+ ######################################
+ #rope_emb
+ ######################################
+

  # Copyright (c) 2023, Tri Dao.
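After this commit, forecasting goes through GPT.forward / GPT.generate only (the KV-cache and chain-of-thought paths were removed). A minimal usage sketch, assuming a valid YingLongConfig and an environment with flash-attn and xformers installed; the names config, history, and horizon are illustrative and not part of the commit:

import torch
from model import GPT                      # illustrative import path
from model_config import YingLongConfig    # model.py itself uses a relative import

config = YingLongConfig()                  # hypothetical default config; real use would load a released checkpoint
model = GPT(config).eval()

history = torch.randn(4, 512)              # (batch, past time steps)
horizon = 96                               # steps to forecast; handled internally in whole patches

with torch.no_grad():
    # forward() returns per-patch forecasts for 99 quantile levels;
    # generate() flattens the patch dimension and trims to `future_token` steps.
    preds = model.generate(idx=history, future_token=horizon)

print(preds.shape)                         # expected: (4, 96, 99)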