import math
import torch
from cache import FinchCache
from utils import repeat_kv
from transformers.models.llama.modeling_llama import rotate_half
import spaces


def _split_context(context_ids, context_attention_mask, sink_tokens, num_chunks):
    """Split the context along dim 1 into ``num_chunks`` contiguous slices.

    The first slice also carries the ``sink_tokens`` leading tokens and the last
    slice absorbs the remainder of the integer division. If the context has at
    most ``sink_tokens`` tokens, or a single chunk is requested, the whole
    context is returned as one slice.

    Returns:
        A pair ``(ids_chunks, mask_chunks)`` of equally long lists of slices.
    """
    total_len = context_ids.size(1)
    if total_len <= sink_tokens or num_chunks == 1:
        return [context_ids], [context_attention_mask]

    base, leftover = divmod(total_len - sink_tokens, num_chunks)
    # Here num_chunks >= 2: first chunk = sinks + base, middle chunks = base,
    # last chunk = base + leftover, so the sizes sum to total_len.
    chunk_sizes = [sink_tokens + base] + [base] * (num_chunks - 2) + [base + leftover]

    ids_chunks, mask_chunks = [], []
    offset = 0
    for size in chunk_sizes:
        end = offset + size
        ids_chunks.append(context_ids[:, offset:end])
        mask_chunks.append(context_attention_mask[:, offset:end])
        offset = end
    return ids_chunks, mask_chunks


def _register_query_hooks(model, query_store, hook_state):
    """Attach a forward hook to every layer's ``q_proj`` that captures queries.

    After each forward pass, ``query_store[layer_idx]`` holds the projected
    (pre-RoPE) queries for positions ``hook_state["offset"]:`` reshaped to
    ``(bsz, num_heads, len, head_dim)``. The caller sets the offset to the
    chunk length, so only the question positions are captured.

    Returns:
        The list of hook handles; the caller must ``remove()`` them.
    """

    def query_hook_fn(module, _inputs, output):
        layer_idx = getattr(module, "layer_idx", None)
        if layer_idx is None:
            return
        query_states = output.detach()
        bsz, seq_len, hidden_dim = query_states.size()
        num_query_heads = module.num_query_heads
        head_dim = hidden_dim // num_query_heads
        query_states = (
            query_states.view(bsz, seq_len, num_query_heads, head_dim)
            .transpose(1, 2)
            .contiguous()
        )
        query_store[layer_idx] = query_states[:, :, hook_state["offset"]:, :].clone()

    hooks = []
    for i, layer in enumerate(model.model.layers):
        q_proj = layer.self_attn.q_proj
        q_proj.layer_idx = i  # read back by the hook to key the captured queries
        q_proj.num_query_heads = layer.self_attn.config.num_attention_heads
        hooks.append(q_proj.register_forward_hook(query_hook_fn))
    return hooks


@spaces.GPU
def get_compressed_kv_cache(
    model,
    sink_tokens,
    step_size,
    target_token_size,
    context_ids,
    context_attention_mask,
    question_ids,
    question_attention_mask,
):
    """Build a question-aware compressed KV cache for ``context_ids``.

    The context is processed chunk by chunk. For each chunk the model runs on
    ``chunk + question``. The question's query vectors, captured by hooks on
    every ``q_proj``, are re-rotated and scored against each layer's cached
    keys. Only the top-``k`` positions per layer are kept via
    ``FinchCache.compress_cache``. The first ``sink_tokens`` positions are
    always kept, and on the last chunk the question positions are kept as well.

    Args:
        model: Llama-style causal LM. As used below, it must expose
            ``model.model.layers``, ``model.model.rotary_emb`` and
            ``model.config.head_dim``.
        sink_tokens: Number of leading context tokens that are always kept.
        step_size: Requested number of chunks. It is raised if the context plus
            question would exceed ``max_position_embeddings``; values < 1 are
            treated as 1.
        target_token_size: Number of non-sink context tokens to retain in the
            final cache. Question tokens are kept on top of this.
        context_ids: Context token tensor, sequence on dim 1.
        context_attention_mask: Mask matching ``context_ids``.
        question_ids: Question token tensor, sequence on dim 1.
        question_attention_mask: Mask matching ``question_ids``.

    Returns:
        The compressed ``FinchCache``.
    """
    device = model.device
    dtype = model.dtype

    context_ids = context_ids.to(device)
    context_attention_mask = context_attention_mask.to(device)
    question_ids = question_ids.to(device)
    question_attention_mask = question_attention_mask.to(device)

    len_question = question_ids.size(1)
    total_len = context_ids.size(1)

    # Every forward pass sees one chunk plus the question, so each chunk must
    # leave room for the question within the model's position budget.
    num_chunks = max(1, step_size)
    max_context_tokens_allowed = model.config.max_position_embeddings - len_question
    if total_len > max_context_tokens_allowed:
        num_chunks = max(num_chunks, math.ceil(total_len / max_context_tokens_allowed))

    ids_chunks, mask_chunks = _split_context(
        context_ids, context_attention_mask, sink_tokens, num_chunks
    )

    len_rest = max(total_len - sink_tokens, 1)
    compression_factor = max(1, len_rest // target_token_size)

    num_chunks = len(ids_chunks)
    print("Number of chunks: ", num_chunks)

    rotary_emb = model.model.rotary_emb.to(device)
    inv_freq = rotary_emb.inv_freq
    batch_size = question_ids.size(0)
    mask_dtype = question_attention_mask.dtype
    # NOTE(review): an extra always-on mask column sits between the chunk and
    # the question in the scoring mask; the model forward below receives no
    # attention_mask, so this only shapes the scoring mask. Confirm intent.
    ones_mask = torch.ones(batch_size, 1, dtype=mask_dtype, device=device)

    num_layers = len(model.model.layers)
    num_kv_heads = model.config.num_key_value_heads
    num_repeats = model.config.num_attention_heads // num_kv_heads
    scaling = math.sqrt(model.config.head_dim)
    tol = 1e-8

    cache = FinchCache()
    past_cache_len = 0
    past_attention_mask = torch.zeros(batch_size, 0, dtype=mask_dtype, device=device)

    query_context_matrices = {}
    hook_state = {"offset": 0}
    hooks = _register_query_hooks(model, query_context_matrices, hook_state)
    # The hooks live on the caller's model: always detach them, even on error.
    try:
        for j, (chunk_input_ids, chunk_attention_mask) in enumerate(zip(ids_chunks, mask_chunks)):
            is_last_chunk = j == num_chunks - 1
            current_seq_length = chunk_input_ids.size(1)
            # Hooks keep only positions after the chunk, i.e. the question.
            hook_state["offset"] = current_seq_length
            query_context_matrices.clear()

            chunk_input_ids = chunk_input_ids.contiguous()
            chunk_attention_mask = chunk_attention_mask.contiguous()
            segment_attention_mask = torch.cat(
                [past_attention_mask, chunk_attention_mask, ones_mask], dim=-1
            ).contiguous()
            current_input_ids = torch.cat([chunk_input_ids, question_ids], dim=-1).contiguous()
            current_attention_mask = torch.cat(
                [segment_attention_mask, question_attention_mask], dim=-1
            ).contiguous()

            # Cache positions of the question rows only (they are the queries).
            past_seen_tokens = cache.get_seq_length()
            cache_position = torch.arange(
                past_seen_tokens + current_seq_length,
                past_seen_tokens + current_input_ids.shape[1],
                device=device,
            )
            causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(
                current_attention_mask,
                sequence_length=len_question,
                target_length=current_attention_mask.size(-1),
                dtype=dtype,
                device=device,
                cache_position=cache_position,
                batch_size=current_input_ids.size(0),
            ).contiguous()

            with torch.no_grad():
                outputs = model.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=cache,
                )
            cache = outputs.past_key_values

            # Absolute positions of the question tokens for this pass; the same
            # for every layer, so computed once per chunk.
            position_ids = torch.arange(
                past_cache_len + current_seq_length,
                past_cache_len + current_seq_length + len_question,
                device=device,
            ).unsqueeze(0)
            # Scored columns: past + chunk, plus the question on the last chunk.
            scored_len = past_cache_len + current_seq_length + (len_question if is_last_chunk else 0)

            for layer_idx in range(num_layers):
                key_matrix = repeat_kv(cache.key_cache[layer_idx], num_repeats)
                query_matrix = query_context_matrices[layer_idx]

                # The hook captured pre-RoPE queries; apply RoPE so they match
                # the (rotated) cached keys.
                cos, sin = rotary_emb(query_matrix, position_ids)
                cos = cos.unsqueeze(1)
                sin = sin.unsqueeze(1)
                query_matrix = (query_matrix * cos) + (rotate_half(query_matrix) * sin)

                attention_matrix = torch.matmul(query_matrix, key_matrix.transpose(2, 3)) / scaling
                causal_mask_sliced = causal_mask[:, :, :, : key_matrix.shape[-2]]
                attention_matrix = attention_matrix + causal_mask_sliced
                attention_matrix = torch.nn.functional.softmax(
                    attention_matrix, dim=-1, dtype=torch.float32
                ).to(query_matrix.dtype)

                # Divide each row by its number of visible (unmasked) keys.
                binary_mask = (torch.abs(causal_mask_sliced.to(torch.float32)) < tol).to(torch.float32)
                non_zero_counts = binary_mask.sum(dim=3, keepdim=True)
                non_zero_counts = torch.clamp_min(non_zero_counts, 1.0).to(attention_matrix.dtype)
                attention_matrix = attention_matrix / non_zero_counts

                attention_matrix = attention_matrix[:, :, :, :scored_len].clone().contiguous()
                # Sum over question rows, then over query heads sharing a KV head.
                attention_matrix = torch.sum(attention_matrix, dim=-2)
                attention_matrix = attention_matrix.view(
                    attention_matrix.size(0), num_kv_heads, num_repeats, -1
                ).sum(dim=2)
                full_context_size = attention_matrix.size(-1)

                # Force sinks (and, on the last chunk, the question) into top-k.
                attention_matrix[..., :sink_tokens] = float("inf")
                if is_last_chunk and len_question > 0:
                    attention_matrix[..., -len_question:] = float("inf")

                if is_last_chunk:
                    # Checked first so a single-chunk run also budgets for the
                    # question tokens it forces into the cache.
                    k = sink_tokens + target_token_size + len_question
                elif j == 0:
                    k = past_cache_len + sink_tokens + (
                        max(0, current_seq_length - sink_tokens) // compression_factor
                    )
                else:
                    k = past_cache_len + current_seq_length // compression_factor
                # Never drop below the sinks, never request more than exist.
                k = min(max(int(k), sink_tokens), full_context_size)

                selected_indices = torch.topk(attention_matrix, k, dim=-1).indices
                selected_indices, _ = torch.sort(selected_indices, dim=-1)
                cache.compress_cache(layer_idx, selected_indices, inv_freq)

            past_cache_len = cache._seen_tokens
            past_attention_mask = torch.ones(
                batch_size, past_cache_len, dtype=mask_dtype, device=device
            )
    finally:
        for hook in hooks:
            hook.remove()
    return cache