from transformers import DynamicCache
import torch
import os


class FinchCache(DynamicCache):
    """DynamicCache variant that can compress each layer's KV entries down to a
    set of selected ("important") token positions, re-rotating the kept keys so
    their RoPE embeddings match their new, compacted positions."""

    def __init__(self) -> None:
        super().__init__()
        self.key_cache = []
        self.value_cache = []

    @staticmethod
    def _rotate_half(x):
        # Standard RoPE helper: (x1, x2) -> (-x2, x1) over the last dimension.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        # Apply a rotary position embedding to the key states.
        return (key_states * cos) + (self._rotate_half(key_states) * sin)

    @staticmethod
    def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
        # Build cos/sin tables for the position delta between each kept token's
        # new (compacted) index and its original index, so already-rotated keys
        # can be shifted to their new positions with a single extra rotation.
        B, H, L = important_pos_batch.shape
        device = important_pos_batch.device
        device_type = x.device.type
        dtype = x.dtype
        idx = torch.arange(0, L, device=device)
        idx = idx.unsqueeze(0)
        inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)  # (B, H, M, 1)
        idx = idx[:, None, :].float().expand(B, H, L)  # (B, H, L)
        delta_pos = idx - important_pos_batch
        delta_pos = delta_pos.unsqueeze(2)  # (B, H, 1, L)

        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

        # Compute the trig tables in full precision, as in standard RoPE implementations.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = delta_pos.float() * inv_freq.float()
            freqs = freqs.transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos().contiguous()
            sin = emb.sin().contiguous()
        return cos.to(dtype=dtype), sin.to(dtype=dtype)

    @staticmethod
    def gather_important_tokens(states, indices):
        # Select the tokens at `indices` (B, H, L_kept) along the sequence dimension.
        return torch.gather(states, 2, indices.unsqueeze(-1).expand(-1, -1, -1, states.size(3))).contiguous()

    def compress_cache(self, layer_index, important_pos, inv_freq):
        """Keep only the tokens at `important_pos` for the given layer,
        re-rotating the gathered keys to their new, compacted positions."""
        new_length = important_pos.size(2)
        new_cos, new_sin = self._rerotate_cos_sin(self.key_cache[layer_index], inv_freq, important_pos)
        gathered_keys = self.gather_important_tokens(self.key_cache[layer_index], important_pos).clone()
        self.key_cache[layer_index] = self._apply_key_rotary_pos_emb(gathered_keys, new_cos, new_sin)
        gathered_values = self.gather_important_tokens(self.value_cache[layer_index], important_pos).clone()
        self.value_cache[layer_index] = gathered_values
        self._seen_tokens = new_length

    def save(self, path: str):
        """Save the cache to disk, moving tensors to CPU."""
        try:
            directory = os.path.dirname(path)
            if directory:
                os.makedirs(directory, exist_ok=True)
            torch.save(
                {"key_cache": [k.cpu() for k in self.key_cache], "value_cache": [v.cpu() for v in self.value_cache]},
                path,
            )
        except Exception as e:
            print(f"Error occurred while saving: {e}")

    @classmethod
    def load(cls, path: str, device: str = "cpu") -> "FinchCache":
        """Load the cache from disk and move tensors to the specified device."""
        data = torch.load(path, map_location=device)
        cache = cls()
        cache.key_cache = [k.to(device) for k in data["key_cache"]]
        cache.value_cache = [v.to(device) for v in data["value_cache"]]
        cache._seen_tokens = cache.value_cache[0].size(2) if cache.value_cache else 0
        return cache
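

# --- Usage sketch (illustration only; not part of the original module) ---
# A minimal, self-contained demo of compress_cache on synthetic tensors. It
# assumes a transformers version where DynamicCache exposes key_cache /
# value_cache lists; the tensor shapes, RoPE base of 10000, and the "keep the
# last 8 positions" selection below are arbitrary choices for the example,
# not values taken from a real model or importance-scoring step.
if __name__ == "__main__":
    B, H, L, D = 1, 4, 16, 64  # batch, heads, sequence length, head dim
    keep = 8                   # number of tokens to retain per head

    cache = FinchCache()
    cache.key_cache.append(torch.randn(B, H, L, D))
    cache.value_cache.append(torch.randn(B, H, L, D))

    # Pretend the importance ranking kept the last `keep` positions, in ascending order.
    important_pos = torch.arange(L - keep, L).view(1, 1, keep).expand(B, H, keep)

    # Standard RoPE inverse frequencies for head dimension D (assumed base 10000).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))

    cache.compress_cache(0, important_pos, inv_freq)
    print(cache.key_cache[0].shape, cache.value_cache[0].shape)
    # -> torch.Size([1, 4, 8, 64]) torch.Size([1, 4, 8, 64])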