Upload hybrid_cache.py
Browse files- hybrid_cache.py +31 -110
hybrid_cache.py
CHANGED
@@ -3,109 +3,69 @@ from typing import Any, Dict, Optional, Union
|
|
3 |
from transformers.cache_utils import DynamicCache
|
4 |
|
5 |
|
6 |
-
class TimeMixState:
    """Holder for the two state tensors of a time-mix sub-layer.

    The class name is restored from the call site in this revision that
    builds it as ``TimeMixState(shift_state, wkv_state)``; the header in
    this copy of the file had been truncated to a bare ``class``.

    Both arguments are stored exactly as passed (no copy, no validation).
    NOTE(review): presumably the RWKV token-shift state and WKV recurrent
    state -- confirm against the model code.
    """

    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state = shift_state
        self.wkv_state = wkv_state
|
10 |
|
11 |
|
12 |
-
class ChannelMixState:
    """Holder for the single state tensor of a channel-mix sub-layer.

    The class name is restored from the call site in this revision that
    builds it as ``ChannelMixState(shift_state)``; the header in this copy
    of the file had been truncated to a bare ``class``.

    The argument is stored exactly as passed (no copy, no validation).
    """

    def __init__(self, shift_state: torch.Tensor):
        self.shift_state = shift_state
|
15 |
|
16 |
|
17 |
class BlockState:
|
18 |
-
def __init__(
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
def __init__(self, shift_states, wkv_states):
    """Wrap two pre-allocated state tensors.

    Args:
        shift_states: stored as ``self.shift_states``.
        wkv_states: stored as ``self.wkv_states``.

    Neither argument is copied or validated.
    """
    # Same assignment order as before: wkv first, then shift.
    self.wkv_states, self.shift_states = wkv_states, shift_states
|
28 |
-
|
29 |
-
@staticmethod
def create(N, B, C, H, device, dtype):
    """Allocate a state list and zero-fill both of its tensors.

    All arguments are forwarded unchanged to ``BlockStateList.empty``,
    which performs the allocation.

    Returns:
        The freshly allocated ``BlockStateList`` with ``wkv_states`` and
        ``shift_states`` set to zero.
    """
    result = BlockStateList.empty(N, B, C, H, device, dtype)
    # The previous version zeroed wkv_states twice; one pass is sufficient.
    result.wkv_states[:] = 0
    result.shift_states[:] = 0
    return result
|
36 |
-
|
37 |
-
@staticmethod
def empty(N, B, C, H, device, dtype):
    """Allocate uninitialised state tensors and wrap them.

    Shapes produced:
        wkv_states:   (N, B, H, C // H, C // H), always ``torch.bfloat16``
        shift_states: (N, 2, B, C), using the requested ``dtype``

    NOTE(review): ``dtype`` is not applied to ``wkv_states`` -- confirm the
    fixed bfloat16 is intentional.
    NOTE(review): N/B/C/H look like layers/batch/channels/heads -- confirm.

    Returns:
        ``BlockStateList(shift_states, wkv_states)``; contents are whatever
        ``torch.empty`` leaves in memory.
    """
    per_head = C // H
    wkv_shape = (N, B, H, per_head, per_head)
    shift_shape = (N, 2, B, C)

    wkv_states = torch.empty(wkv_shape, device=device, dtype=torch.bfloat16)
    shift_states = torch.empty(shift_shape, device=device, dtype=dtype)
    return BlockStateList(shift_states, wkv_states)
|
44 |
-
|
45 |
-
def __getitem__(self, layer: int):
    """Return the state of one layer as a ``BlockState``.

    The time-mix part is built from ``shift_states[layer, 0]`` and
    ``wkv_states[layer]``; the channel-mix part from
    ``shift_states[layer, 1]``.
    """
    time_mix = TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer])
    channel_mix = ChannelMixState(self.shift_states[layer, 1])
    return BlockState(time_mix, channel_mix)
|
49 |
-
|
50 |
-
def __setitem__(self, layer: int, state: BlockState):
    """Write one layer's state back into the stacked tensors.

    Args:
        layer: index into the first dimension of both tensors.
        state: object exposing ``time_mix_state`` (with ``shift_state`` and
            ``wkv_state``) and ``channel_mix_state`` (with ``shift_state``).
    """
    time_mix = state.time_mix_state
    # Slot 0 of the shift tensor holds the time-mix shift, slot 1 the channel-mix shift.
    self.shift_states[layer, 0] = time_mix.shift_state
    self.wkv_states[layer] = time_mix.wkv_state
    self.shift_states[layer, 1] = state.channel_mix_state.shift_state
|
54 |
-
|
55 |
|
56 |
class HybridCache(DynamicCache):
|
57 |
def __init__(self) -> None:
    """Create an empty cache with no layer registered as an RWKV layer."""
    super().__init__()
    # Layer indices that were stored through the RWKV branch of update().
    self.rwkv_layers = set()
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
for data in cache:
|
72 |
-
if not isinstance(data, torch.Tensor):
|
73 |
-
memories += data.time_mix_state.wkv_state.numel()
|
74 |
-
else:
|
75 |
-
memories += data.numel()
|
76 |
-
count_info += f", memories={memories / 1024/1024}MB, seq_length={seq_length}"
|
77 |
-
return count_info
|
78 |
-
|
79 |
-
def update(self,
|
80 |
-
key_states: Union[int, torch.Tensor],
|
81 |
-
value_states: Union[torch.Tensor, BlockState],
|
82 |
-
layer_idx: int,
|
83 |
-
cache_kwargs: Optional[Dict[str, Any]] = None):
|
84 |
-
if isinstance(key_states, int) and not isinstance(value_states, torch.Tensor):
|
85 |
self.rwkv_layers.add(layer_idx)
|
86 |
-
|
|
|
87 |
self.key_cache.append([])
|
88 |
self.value_cache.append([])
|
89 |
-
|
90 |
-
if len(self.key_cache[layer_idx]) == 0:
|
91 |
self.key_cache[layer_idx].append(key_states)
|
92 |
self.value_cache[layer_idx].append(value_states)
|
|
|
|
|
93 |
else:
|
94 |
-
self.key_cache[layer_idx][0]
|
95 |
self.value_cache[layer_idx][0] = value_states
|
96 |
|
97 |
return key_states, value_states
|
98 |
|
99 |
return super().update(key_states, value_states, layer_idx, cache_kwargs)
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
def get_seq_length(self, layer_idx: Optional[int] = 0):
    """Return the cached sequence length for ``layer_idx``.

    For a layer listed in ``self.rwkv_layers`` this is the first element of
    that layer's ``key_cache`` slot; any other layer is delegated to the
    parent class.
    """
    if layer_idx not in self.rwkv_layers:
        return super().get_seq_length(layer_idx)
    return self.key_cache[layer_idx][0]
|
105 |
|
106 |
-
def get_max_length(self):
    """Return the parent class's maximum cache length, unchanged."""
    # Pure pass-through; kept so the method is visible on this class.
    return super().get_max_length()
|
108 |
-
|
109 |
def reorder_cache(self, beam_idx):
    """Delegate beam reordering to the parent class.

    NOTE(review): slots written by the RWKV branch of ``update`` are Python
    lists, not tensors -- confirm the parent implementation copes with them
    before using beam search with this cache.
    """
    return super().reorder_cache(beam_idx)
|
111 |
|
@@ -113,42 +73,3 @@ class HybridCache(DynamicCache):
|
|
113 |
if item in self.rwkv_layers:
|
114 |
return self.value_cache[item]
|
115 |
return super().__getitem__(item)
|
116 |
-
|
117 |
-
def offload_to_cpu(self):
    """Move the cached values in ``self.value_cache`` to CPU memory.

    ``Tensor.cpu()`` returns a tensor instead of modifying its receiver, so
    every result is written back to where the original lived. The previous
    version discarded the results, which left everything on its old device.

    Handled entries:
        * a slot that is itself a tensor -> the slot is replaced;
        * a tensor inside a list slot    -> that list element is replaced;
        * any other object               -> its ``time_mix_state.wkv_state``
          and ``time_mix_state.shift_state`` are replaced.

    NOTE(review): ``channel_mix_state`` and ``self.key_cache`` are left
    untouched, as in the original -- confirm whether they should move too.
    """
    for layer_idx, cache in enumerate(self.value_cache):
        if isinstance(cache, torch.Tensor):
            # Whole-layer tensor: swap the slot rather than iterating its rows.
            self.value_cache[layer_idx] = cache.cpu()
            continue
        for idx, data in enumerate(cache):
            if isinstance(data, torch.Tensor):
                cache[idx] = data.cpu()
            else:
                time_mix = data.time_mix_state
                time_mix.wkv_state = time_mix.wkv_state.cpu()
                time_mix.shift_state = time_mix.shift_state.cpu()
|
125 |
-
|
126 |
-
def offload_to_cuda(self, device: str):
    """Move the cached values in ``self.value_cache`` to a CUDA device.

    Args:
        device: passed straight to ``.cuda(device)`` on each tensor.

    ``Tensor.cuda()`` returns a tensor instead of modifying its receiver, so
    every result is written back to where the original lived. The previous
    version discarded the results, which left everything on its old device.

    NOTE(review): ``channel_mix_state`` and ``self.key_cache`` are left
    untouched, as in the original -- confirm whether they should move too.
    """
    for layer_idx, cache in enumerate(self.value_cache):
        if isinstance(cache, torch.Tensor):
            # Whole-layer tensor: swap the slot rather than iterating its rows.
            self.value_cache[layer_idx] = cache.cuda(device)
            continue
        for idx, data in enumerate(cache):
            if isinstance(data, torch.Tensor):
                cache[idx] = data.cuda(device)
            else:
                time_mix = data.time_mix_state
                time_mix.wkv_state = time_mix.wkv_state.cuda(device)
                time_mix.shift_state = time_mix.shift_state.cuda(device)
|
134 |
-
|
135 |
-
def offload_to_device(self, device_type: str, device_id: int = 0):
    """Move the cached values in ``self.value_cache`` to another device.

    Args:
        device_type: name of the tensor method to call (``'cpu'``,
            ``'cuda'``, ...). It is looked up with ``getattr``.
        device_id: passed to that method for every ``device_type`` except
            ``'cpu'``, which is called without arguments.

    The device methods return a tensor instead of modifying their receiver,
    so every result is written back to where the original lived. The
    previous version discarded the results, which left everything in place.

    NOTE(review): ``channel_mix_state`` and ``self.key_cache`` are left
    untouched, as in the original -- confirm whether they should move too.
    """

    def _move(tensor):
        # Resolve e.g. tensor.cuda / tensor.cpu by name, then call it.
        method = getattr(tensor, device_type)
        if device_type == 'cpu':
            return method()
        return method(device_id)

    for layer_idx, cache in enumerate(self.value_cache):
        if isinstance(cache, torch.Tensor):
            # Whole-layer tensor: swap the slot rather than iterating its rows.
            self.value_cache[layer_idx] = _move(cache)
            continue
        for idx, data in enumerate(cache):
            if isinstance(data, torch.Tensor):
                cache[idx] = _move(data)
            else:
                time_mix = data.time_mix_state
                time_mix.wkv_state = _move(time_mix.wkv_state)
                time_mix.shift_state = _move(time_mix.shift_state)
|
|
|
3 |
from transformers.cache_utils import DynamicCache
|
4 |
|
5 |
|
6 |
+
class AttnState:
    """Holder for the two state tensors of an attention-style sub-layer.

    Both constructor arguments are stored exactly as passed (no copy, no
    validation).
    NOTE(review): presumably the RWKV token-shift state and WKV recurrent
    state -- confirm against the model code.
    """

    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state, self.wkv_state = shift_state, wkv_state
|
10 |
|
11 |
|
12 |
+
class FfnState:
    """Holder for the single state tensor of a feed-forward sub-layer.

    The constructor argument is stored exactly as passed (no copy, no
    validation).
    """

    def __init__(self, shift_state: torch.Tensor):
        # Only one piece of state is tracked for this sub-layer.
        self.shift_state = shift_state
|
15 |
|
16 |
|
17 |
class BlockState:
    """Bundle of one block's two sub-layer states.

    Attributes (set from the constructor arguments, stored as given):
        attn_state: the ``AttnState`` of the block.
        ffn_state: the ``FfnState`` of the block.
    """

    def __init__(self, attn_state: AttnState, ffn_state: FfnState):
        self.attn_state, self.ffn_state = attn_state, ffn_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
class HybridCache(DynamicCache):
|
27 |
def __init__(self) -> None:
    """Create an empty cache: no RWKV layers, zero counter, no v_first."""
    super().__init__()
    # Value held for update_v_first()/get_v_first(); None until first set.
    self.v_first_cache = None
    # Incremented by update() each time an RWKV layer gets its first entry.
    self.key_cache_nums = 0
    # Layer indices that were stored through the RWKV branch of update().
    self.rwkv_layers = set()
|
32 |
+
|
33 |
+
def update(
    self,
    key_states: Union[int, torch.Tensor],
    value_states: Union[torch.Tensor, BlockState],
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None
):
    """Store new state for ``layer_idx``.

    Two call shapes are supported:

    * RWKV layer -- ``key_states`` is an ``int`` and ``value_states`` is a
      ``BlockState``. The layer is recorded in ``self.rwkv_layers``. Its
      ``key_cache`` slot is a one-element list holding a running total of
      the ``key_states`` integers; its ``value_cache`` slot is a
      one-element list holding the latest ``BlockState``.
    * anything else -- forwarded untouched to the parent ``update``.

    Args:
        key_states: token-count increment (RWKV) or key tensor.
        value_states: block state (RWKV) or value tensor.
        layer_idx: index of the layer being updated.
        cache_kwargs: forwarded to the parent implementation only.

    Returns:
        ``(key_states, value_states)`` as passed in for RWKV layers,
        otherwise whatever the parent ``update`` returns.
    """
    if isinstance(key_states, int) and isinstance(value_states, BlockState):
        self.rwkv_layers.add(layer_idx)

        # Size the slot lists from their real length. The parent update()
        # also appends to them for non-RWKV layers, so a private counter of
        # RWKV insertions can drift out of step with the list length.
        # Padding in a loop also covers a layer index that skips ahead.
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append([])
            self.value_cache.append([])

        if len(self.key_cache[layer_idx]) == 0:
            # First visit of this layer: start the running total.
            self.key_cache[layer_idx] = [key_states]
            self.value_cache[layer_idx] = [value_states]
            # Kept up to date for any external reader of the counter.
            self.key_cache_nums += 1
        else:
            # Later visits: accumulate the count, replace the state.
            self.key_cache[layer_idx][0] += key_states
            self.value_cache[layer_idx][0] = value_states

        return key_states, value_states

    return super().update(key_states, value_states, layer_idx, cache_kwargs)
|
57 |
|
58 |
+
def update_v_first(self, v_first: torch.Tensor):
    """Replace the stored ``v_first`` value with ``v_first`` (kept by reference)."""
    self.v_first_cache = v_first
|
60 |
+
|
61 |
+
def get_v_first(self):
    """Return whatever is currently stored in ``self.v_first_cache``."""
    return self.v_first_cache
|
63 |
+
|
64 |
def get_seq_length(self, layer_idx: Optional[int] = 0):
    """Return the cached sequence length for ``layer_idx``.

    For a layer listed in ``self.rwkv_layers`` this is the first element of
    that layer's ``key_cache`` slot; any other layer is delegated to the
    parent class.
    """
    if layer_idx not in self.rwkv_layers:
        return super().get_seq_length(layer_idx)
    return self.key_cache[layer_idx][0]
|
68 |
|
|
|
|
|
|
|
69 |
def reorder_cache(self, beam_idx):
    """Delegate beam reordering to the parent class.

    NOTE(review): slots written by the RWKV branch of ``update`` are Python
    lists, not tensors -- confirm the parent implementation copes with them
    before using beam search with this cache.
    """
    return super().reorder_cache(beam_idx)
|
71 |
|
|
|
73 |
if item in self.rwkv_layers:
|
74 |
return self.value_cache[item]
|
75 |
return super().__getitem__(item)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|