zhiyuan8 committed
Commit b099d00 · verified · 1 Parent(s): ba4f8ad

Upload hybrid_cache.py

Files changed (1)
  1. hybrid_cache.py +31 -110
hybrid_cache.py CHANGED
@@ -3,109 +3,69 @@ from typing import Any, Dict, Optional, Union
 from transformers.cache_utils import DynamicCache
 
 
-class TimeMixState:
+class AttnState:
     def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
         self.shift_state = shift_state
         self.wkv_state = wkv_state
 
 
-class ChannelMixState:
+class FfnState:
     def __init__(self, shift_state: torch.Tensor):
         self.shift_state = shift_state
 
 
 class BlockState:
-    def __init__(self, time_mix_state: TimeMixState,
-                 channel_mix_state: ChannelMixState):
-        self.time_mix_state = time_mix_state
-        self.channel_mix_state = channel_mix_state
-
-
-class BlockStateList:
-    def __init__(self, shift_states, wkv_states):
-        self.wkv_states = wkv_states
-        self.shift_states = shift_states
-
-    @staticmethod
-    def create(N, B, C, H, device, dtype):
-        result = BlockStateList.empty(N, B, C, H, device, dtype)
-        result.wkv_states[:] = 0
-        result.wkv_states[:] = 0
-        result.shift_states[:] = 0
-        return result
-
-    @staticmethod
-    def empty(N, B, C, H, device, dtype):
-        wkv_states = torch.empty((N, B, H, C//H, C//H),
-                                 device=device,
-                                 dtype=torch.bfloat16)
-        shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
-        return BlockStateList(shift_states, wkv_states)
-
-    def __getitem__(self, layer: int):
-        return BlockState(
-            TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
-            ChannelMixState(self.shift_states[layer, 1]))
-
-    def __setitem__(self, layer: int, state: BlockState):
-        self.shift_states[layer, 0] = state.time_mix_state.shift_state
-        self.wkv_states[layer] = state.time_mix_state.wkv_state
-        self.shift_states[layer, 1] = state.channel_mix_state.shift_state
-
+    def __init__(
+        self,
+        attn_state: AttnState,
+        ffn_state: FfnState
+    ):
+        self.attn_state = attn_state
+        self.ffn_state = ffn_state
 
 class HybridCache(DynamicCache):
     def __init__(self) -> None:
         super().__init__()
         self.rwkv_layers = set()
-
-    def __repr__(self) -> str:
-        rwkv_layers = f"HybridCache(rwkv_layers={self.rwkv_layers})"
-        # count the number of key_cache and value_cache
-        key_cache_count = sum(len(cache) for cache in self.key_cache)
-        value_cache_count = sum(len(cache) for cache in self.value_cache)
-        count_info = rwkv_layers + \
-            f", key_cache_count={key_cache_count}, value_cache_count={value_cache_count}"
-        memories = 0
-        seq_length = self.get_seq_length()
-        for cache in self.value_cache:
-            for data in cache:
-                if not isinstance(data, torch.Tensor):
-                    memories += data.time_mix_state.wkv_state.numel()
-                else:
-                    memories += data.numel()
-        count_info += f", memories={memories / 1024/1024}MB, seq_length={seq_length}"
-        return count_info
-
-    def update(self,
-               key_states: Union[int, torch.Tensor],
-               value_states: Union[torch.Tensor, BlockState],
-               layer_idx: int,
-               cache_kwargs: Optional[Dict[str, Any]] = None):
-        if isinstance(key_states, int) and not isinstance(value_states, torch.Tensor):
+        self.key_cache_nums = 0
+        self.v_first_cache = None
+
+    def update(
+        self,
+        key_states: Union[int, torch.Tensor],
+        value_states: Union[torch.Tensor, BlockState],
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        if isinstance(key_states, int) and isinstance(value_states, BlockState):
             self.rwkv_layers.add(layer_idx)
-            if layer_idx >= len(self.key_cache):
+
+            if layer_idx >= self.key_cache_nums:
                 self.key_cache.append([])
                 self.value_cache.append([])
-
-            if len(self.key_cache[layer_idx]) == 0:
                 self.key_cache[layer_idx].append(key_states)
                 self.value_cache[layer_idx].append(value_states)
+                self.key_cache_nums += 1
+
             else:
-                self.key_cache[layer_idx][0] = self.key_cache[layer_idx][0]+key_states
+                self.key_cache[layer_idx][0] += key_states
                 self.value_cache[layer_idx][0] = value_states
 
             return key_states, value_states
 
         return super().update(key_states, value_states, layer_idx, cache_kwargs)
 
+    def update_v_first(self, v_first: torch.Tensor):
+        self.v_first_cache = v_first
+
+    def get_v_first(self):
+        return self.v_first_cache
+
     def get_seq_length(self, layer_idx: Optional[int] = 0):
         if layer_idx in self.rwkv_layers:
             return self.key_cache[layer_idx][0]
         return super().get_seq_length(layer_idx)
 
-    def get_max_length(self):
-        return super().get_max_length()
-
     def reorder_cache(self, beam_idx):
         return super().reorder_cache(beam_idx)
 
@@ -113,42 +73,3 @@ class HybridCache(DynamicCache):
         if item in self.rwkv_layers:
             return self.value_cache[item]
         return super().__getitem__(item)
-
-    def offload_to_cpu(self):
-        for cache in self.value_cache:
-            for data in cache:
-                if isinstance(data, torch.Tensor):
-                    data.cpu()
-                else:
-                    data.time_mix_state.wkv_state.cpu()
-                    data.time_mix_state.shift_state.cpu()
-
-    def offload_to_cuda(self, device: str):
-        for cache in self.value_cache:
-            for data in cache:
-                if isinstance(data, torch.Tensor):
-                    data.cuda(device)
-                else:
-                    data.time_mix_state.wkv_state.cuda(device)
-                    data.time_mix_state.shift_state.cuda(device)
-
-    def offload_to_device(self, device_type: str, device_id: int = 0):
-        for cache in self.value_cache:
-            for data in cache:
-                if isinstance(data, torch.Tensor):
-                    method = getattr(data, device_type)
-                    if device_type == 'cpu':
-                        method()
-                    else:
-                        method(device_id)
-                else:
-                    wkv_state_method = getattr(
-                        data.time_mix_state.wkv_state, device_type)
-                    shift_state_method = getattr(
-                        data.time_mix_state.shift_state, device_type)
-                    if device_type == 'cpu':
-                        wkv_state_method()
-                        shift_state_method()
-                    else:
-                        wkv_state_method(device_id)
-                        shift_state_method(device_id)
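For illustration, here is a minimal sketch of how the per-layer state objects introduced above could be built. The tensor shapes are assumptions borrowed from the removed BlockStateList.empty (batch size B, hidden size C, H heads of size C//H); the actual model code may allocate them differently, and the import assumes the file is loadable as a plain hybrid_cache module.

    import torch
    from hybrid_cache import AttnState, FfnState, BlockState, HybridCache

    B, C, H = 1, 2048, 32        # assumed batch size, hidden size, head count
    head_dim = C // H

    attn_state = AttnState(
        shift_state=torch.zeros(B, C),                    # token-shift buffer
        wkv_state=torch.zeros(B, H, head_dim, head_dim),  # per-head WKV matrix state
    )
    ffn_state = FfnState(shift_state=torch.zeros(B, C))
    block_state = BlockState(attn_state, ffn_state)       # one RWKV layer's recurrent state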
 
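Continuing the same sketch, this is roughly how the updated HybridCache might be driven for a model that mixes RWKV and standard attention layers. It assumes a transformers version whose DynamicCache still exposes the key_cache / value_cache lists this file relies on; the layer indices, token counts, and the v_first placeholder tensor are illustrative only.

    cache = HybridCache()

    # RWKV layer: key_states is the number of new tokens (an int) and value_states
    # is the layer's BlockState, so the cache keeps a running token count in
    # key_cache and only the latest recurrent state in value_cache.
    cache.update(16, block_state, layer_idx=0)   # prefill with 16 tokens
    cache.update(1, block_state, layer_idx=0)    # one decode step
    print(cache.get_seq_length(0))               # 17 for an RWKV layer

    # v_first is stashed once per forward pass and read back by later layers;
    # the zeros tensor only stands in for whatever the first layer produced.
    cache.update_v_first(torch.zeros(B, 17, C))
    v_first = cache.get_v_first()

    # Standard attention layer: plain key/value tensors fall through to
    # DynamicCache.update and are cached as usual.
    k = torch.randn(B, H, 17, head_dim)
    v = torch.randn(B, H, 17, head_dim)
    cache.update(k, v, layer_idx=1)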