import torch
from typing import Any, Dict, Optional, Union
from transformers.cache_utils import DynamicCache


class AttnState:
    """Recurrent state of an RWKV attention (time-mixing) block: token-shift state and WKV state."""

    def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
        self.shift_state = shift_state
        self.wkv_state = wkv_state


class FfnState:
    """Recurrent state of an RWKV feed-forward (channel-mixing) block: token-shift state."""

    def __init__(self, shift_state: torch.Tensor):
        self.shift_state = shift_state


class BlockState:
    """Bundle of the attention and feed-forward states for one RWKV layer."""

    def __init__(self, attn_state: AttnState, ffn_state: FfnState):
        self.attn_state = attn_state
        self.ffn_state = ffn_state


class HybridCache(DynamicCache):
    """Cache for hybrid models that mix RWKV layers with standard attention layers.

    Attention layers use the regular ``DynamicCache`` key/value path, while RWKV
    layers store an integer token count in ``key_cache`` and a ``BlockState`` in
    ``value_cache``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.rwkv_layers = set()
        self.key_cache_nums = 0
        self.v_first_cache = None

    def update(
        self,
        key_states: Union[int, torch.Tensor],
        value_states: Union[torch.Tensor, BlockState],
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        # RWKV layers report (token_count, BlockState) instead of key/value tensors.
        if isinstance(key_states, int) and isinstance(value_states, BlockState):
            self.rwkv_layers.add(layer_idx)
            if layer_idx >= self.key_cache_nums:
                # First update for this layer: create its cache slots.
                self.key_cache.append([])
                self.value_cache.append([])
                self.key_cache[layer_idx].append(key_states)
                self.value_cache[layer_idx].append(value_states)
                self.key_cache_nums += 1
            else:
                # Subsequent updates: accumulate the token count and replace the state.
                self.key_cache[layer_idx][0] += key_states
                self.value_cache[layer_idx][0] = value_states
            return key_states, value_states
        # Attention layers fall back to the standard DynamicCache behaviour.
        return super().update(key_states, value_states, layer_idx, cache_kwargs)

    def update_v_first(self, v_first: torch.Tensor):
        self.v_first_cache = v_first

    def get_v_first(self):
        return self.v_first_cache

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # For RWKV layers the stored entry is the accumulated token count.
        if layer_idx in self.rwkv_layers:
            return self.key_cache[layer_idx][0]
        return super().get_seq_length(layer_idx)

    def reorder_cache(self, beam_idx):
        return super().reorder_cache(beam_idx)

    def __getitem__(self, item):
        # Indexing an RWKV layer returns its stored BlockState list rather than (key, value).
        if item in self.rwkv_layers:
            return self.value_cache[item]
        return super().__getitem__(item)
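

# Minimal usage sketch of the RWKV path of HybridCache (illustrative only).
# Assumptions: a transformers version whose DynamicCache exposes mutable
# ``key_cache``/``value_cache`` lists (the class above already relies on this),
# and made-up tensor shapes; the surrounding model code is not shown here.
if __name__ == "__main__":
    cache = HybridCache()

    hidden = torch.zeros(1, 1, 64)
    block = BlockState(
        AttnState(shift_state=hidden, wkv_state=torch.zeros(1, 2, 32, 32)),
        FfnState(shift_state=hidden),
    )

    # An RWKV layer passes the number of tokens it just processed (an int)
    # plus its recurrent BlockState instead of key/value tensors.
    cache.update(key_states=16, value_states=block, layer_idx=0)
    cache.update(key_states=4, value_states=block, layer_idx=0)

    print(cache.get_seq_length(0))  # 20 -> accumulated token count for the RWKV layer
    print(cache[0][0] is block)     # True -> latest recurrent state for layer 0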