from transformers.cache_utils import Cache
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch

from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig


logger = logging.get_logger(__name__)


class HybridCache(Cache):
    """
    Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window
    attention and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"]
    for sliding window attention and ["StaticCache"] for global attention. For more information, see the documentation
    of each subcomponent cache class.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        batch_size (`int`):
            The batch size with which the model will be used. Note that a new instance must be instantiated if a
            smaller batch size is used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. If you're using more than 1 computation device, you
            should pass the `layer_device_map` argument instead.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
        layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
            Mapping between the layers and their devices. This is required when you are manually initializing the
            cache and the model is split between different GPUs. You can check which layer is mapped to which device
            with the associated device map: `model.hf_device_map`.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache

        >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

        >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
        >>> max_generated_length = inputs.input_ids.shape[1] + 10
        >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        HybridCache()
        ```
    """

    # TODO (joao): dive deeper into gemma2 and paligemma -- there are reports of speed loss with compilation. Revert
    # ALL changes from the PR that commented the line below when reactivating it.
    # is_compileable = True

    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
    def __init__(
        self,
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: Union[torch.device, str] = None,
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super().__init__()
        if batch_size is not None:
            logger.warning_once(
                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
            )
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
        self.max_batch_size = batch_size or max_batch_size
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
        self.is_sliding = torch.tensor(
            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.chunk_cache = {}
        global_cache_shape = (self.max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
            self.max_batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
        )
        device = torch.device(device) if device is not None else None
        for i in range(config.num_hidden_layers):
            if layer_device_map is not None:
                layer_device = layer_device_map[i]
            else:
                layer_device = device
            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
            cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)

    def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        if cache_position.shape[0] > max_cache_len:
            k_out = key_states[:, :, -max_cache_len:, :]
            v_out = value_states[:, :, -max_cache_len:, :]
            # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
            self.key_cache[layer_idx] += k_out
            self.value_cache[layer_idx] += v_out
            # we should return the whole states instead of k_out, v_out to take the whole prompt
            # into consideration when building kv cache instead of just throwing away tokens outside of the window
            return key_states, value_states

        slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
        cache_position = cache_position.clamp(0, max_cache_len - 1)
        to_shift = cache_position >= max_cache_len - 1
        indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
        k_out = k_out[:, :, indices]
        v_out = v_out[:, :, indices]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        # `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()

        self.key_cache[layer_idx] += k_out
        self.value_cache[layer_idx] += v_out

        return k_out, v_out
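    # Worked example of the rolling update above (illustrative sizes, not taken from any real model config):
    # with a sliding window of 4, after prefilling positions 0..3 and decoding position 4, the slots hold
    # tokens [1, 2, 3, 4]. When decoding position 5:
    #   cache_position = [5] is clamped to [3], so to_shift[-1] is True
    #   slicing = [1, 2, 3, 4]  ->  indices = (slicing + 1 - 1) % 4 = [1, 2, 3, 0]
    #   gathering with `indices` rolls the buffer left to [2, 3, 4, 1], and the scatter at the clamped
    #   position overwrites slot 3, leaving [2, 3, 4, 5], i.e. the most recent window, oldest token first.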
    def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states

        self.key_cache[layer_idx] = k_out
        self.value_cache[layer_idx] = v_out
        return k_out, v_out

    def _set_chunk_cache(self, layer_idx, cache):
        # Stores an auxiliary, externally computed chunk cache for `layer_idx`; counterpart of `_get_chunk_cache`.
        self.chunk_cache[layer_idx] = cache

    def _get_chunk_cache(self, layer_idx):
        # Returns the chunk cache stored for `layer_idx`, or None if nothing has been stored yet.
        self.chunk_cache.setdefault(layer_idx, None)
        return self.chunk_cache[layer_idx]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor]:
        cache_position = cache_kwargs.get("cache_position")
        sliding_window = cache_kwargs.get("sliding_window")

        # These two `if` blocks are only reached in multigpu and if `layer_device_map` is not passed. They are used
        # when the cache is initialized in the forward pass (e.g. Gemma2)
        if self.key_cache[layer_idx].device != key_states.device:
            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(key_states.device)
        if self.value_cache[layer_idx].device != value_states.device:
            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(value_states.device)

        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        key_states = key_states.to(k_out.dtype)
        value_states = value_states.to(v_out.dtype)

        if sliding_window:
            update_fn = self._sliding_update
        else:
            update_fn = self._static_update

        return update_fn(
            cache_position,
            layer_idx,
            key_states,
            value_states,
            k_out,
            v_out,
            k_out.shape[2],
        )

    def get_max_cache_shape(self) -> Optional[int]:
        return self.max_cache_len

    def get_seq_length(self, layer_idx: Optional[int] = 0):
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: deprecate this function in favor of `cache_position`
        if layer_idx != 0:
            raise ValueError(
                "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
                "Using the `layer_idx` argument is not supported."
            )
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            self.value_cache[layer_idx].zero_()

    @property
    def batch_size(self):
        logger.warning_once(
            f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
            "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
        )
        return self.max_batch_size
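

if __name__ == "__main__":
    # Minimal usage sketch: drives the cache directly with a toy config rather than a real model.
    # All sizes and tensor values below are illustrative assumptions, not taken from any checkpoint.
    toy_config = PretrainedConfig(
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        hidden_size=32,  # head_dim = 32 // 4 = 8
        sliding_window=8,
    )
    cache = HybridCache(config=toy_config, max_batch_size=1, max_cache_len=16, device="cpu")

    # Fake projected key/value states for a 4-token prefill: (batch, kv_heads, seq_len, head_dim)
    key_states = torch.randn(1, 2, 4, 8)
    value_states = torch.randn(1, 2, 4, 8)
    cache_position = torch.arange(4)

    # Layer 0 is a sliding-window layer (pattern of 2), so the window size is passed in `cache_kwargs`.
    k0, v0 = cache.update(
        key_states, value_states, layer_idx=0,
        cache_kwargs={"cache_position": cache_position, "sliding_window": 8},
    )
    # Layer 1 is a global-attention layer; omitting `sliding_window` routes the update to the static path.
    k1, v1 = cache.update(
        key_states, value_states, layer_idx=1,
        cache_kwargs={"cache_position": cache_position},
    )

    print(k0.shape)  # torch.Size([1, 2, 8, 8])  -> sliding layer, capped at the window size
    print(k1.shape)  # torch.Size([1, 2, 16, 8]) -> global layer, full max_cache_len
    print(int(cache.get_seq_length()))  # 4 occupied slots in layer 0
    print(cache._get_chunk_cache(0))  # None until a chunk cache is stored for the layer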