# drama-base / modeling_drama_nested.py
from __future__ import annotations

import warnings
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import nested_from_padded
from transformers import (
LlamaConfig,
LlamaModel,
LlamaPreTrainedModel,
PreTrainedTokenizer,
)
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaMLP,
LlamaRMSNorm,
LlamaRotaryEmbedding,
rotate_half,
)
from transformers.processing_utils import Unpack
class ModifiedLlamaAttention(LlamaAttention):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
        if kwargs.get("output_attentions", False):
            warnings.warn(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`; "
                "attention weights will not be returned by this SDPA-based implementation."
            )
attn_output, attn_weights = sdpa_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0,
scaling=self.scaling,
is_causal=False,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
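
# Shape sketch for the bidirectional attention above (illustrative sizes, not taken from
# any particular config): hidden_states (2, 7, 2048) is projected and reshaped to
# q (2, 16, 7, 128) and k/v (2, num_kv_heads, 7, 128), rotary embeddings are applied,
# SDPA runs with is_causal=False so every token attends to every other token, and the
# output is reshaped back to (2, 7, 2048) before o_proj.
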
def sdpa_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
dropout: float = 0.0,
scaling: Optional[float] = None,
is_causal: Optional[bool] = None,
**kwargs: Any,
) -> Tuple[torch.Tensor, None]:
if hasattr(module, "num_key_value_groups"):
if key.is_nested:
key = repeat_jagged_kv(key, module.num_key_value_groups)
value = repeat_jagged_kv(value, module.num_key_value_groups)
else:
key = repeat_dense_kv(key, module.num_key_value_groups)
value = repeat_dense_kv(value, module.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None and causal_mask.ndim == 4:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
if is_causal is None:
is_causal = query.shape[2] > 1 and causal_mask is None
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
is_causal = is_causal.item()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=causal_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, None
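
# Minimal sketch of calling sdpa_attention_forward directly on dense tensors
# (hypothetical shapes; the module argument only needs a `num_key_value_groups` attribute):
#
#   mod = type("M", (), {"num_key_value_groups": 2})()
#   q = torch.randn(1, 8, 5, 64)   # (batch, num_heads, seq_len, head_dim)
#   k = torch.randn(1, 4, 5, 64)   # (batch, num_kv_heads, seq_len, head_dim)
#   v = torch.randn(1, 4, 5, 64)
#   out, weights = sdpa_attention_forward(mod, q, k, v, attention_mask=None, is_causal=False)
#   # out: (1, 5, 8, 64) after the final transpose; weights is always None
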
def repeat_jagged_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
expand_shape = (batch, num_key_value_heads, -1, n_rep, head_dim)
if n_rep == 1:
return hidden_states
hidden_states = (
hidden_states.unsqueeze(3)
.expand(expand_shape)
.transpose(1, 2)
.flatten(2, 3)
.transpose(1, 2)
)
return hidden_states
def repeat_dense_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
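
# Sanity-check sketch for the GQA head expansion above (illustrative shapes):
#
#   kv = torch.randn(2, 4, 16, 64)                      # (batch, num_kv_heads, seq_len, head_dim)
#   out = repeat_dense_kv(kv, n_rep=2)                  # -> (2, 8, 16, 64)
#   assert torch.equal(out, kv.repeat_interleave(2, dim=1))
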
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
unsqueeze_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
if q.is_nested and k.is_nested:
if q.layout != torch.jagged:
raise NotImplementedError(f"Unsupported layout: {q.layout}")
if k.layout != torch.jagged:
raise NotImplementedError(f"Unsupported layout: {k.layout}")
return _jagged_tensor_forward(q, k, cos, sin)
else:
return _padded_tensor_forward(q, k, cos, sin)
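
# Dense-path sketch for apply_rotary_pos_emb (illustrative shapes): cos/sin typically come
# from LlamaRotaryEmbedding with shape (batch, seq_len, head_dim) and are unsqueezed at
# dim 1 to broadcast over the head dimension.
#
#   q = torch.randn(1, 8, 5, 64)     # (batch, heads, seq_len, head_dim)
#   k = torch.randn(1, 2, 5, 64)
#   cos, sin = torch.randn(1, 5, 64), torch.randn(1, 5, 64)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # shapes unchanged
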
def _jagged_tensor_forward(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
q_dense = q.to_padded_tensor(0.0)
k_dense = k.to_padded_tensor(0.0)
q_dense_embed = (q_dense * cos) + (rotate_half(q_dense) * sin)
k_dense_embed = (k_dense * cos) + (rotate_half(k_dense) * sin)
q_jagged_embed = convert_dense_to_jagged(q, q_dense_embed)
k_jagged_embed = convert_dense_to_jagged(k, k_dense_embed)
return q_jagged_embed, k_jagged_embed
def _padded_tensor_forward(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
def convert_dense_to_jagged(nested_q: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
padded_max_S = nested_q._get_max_seqlen()
total_L = nested_q._values.shape[nested_q._ragged_idx - 1]
if padded_max_S is None:
# use upper bound on max seqlen if it's not present
padded_max_S = total_L
# convert dense tensor -> jagged
q = q.expand(
[
x if i != nested_q._ragged_idx else padded_max_S
for i, x in enumerate(q.shape)
]
)
nested_result = nested_from_padded(
q,
offsets=nested_q._offsets,
ragged_idx=nested_q._ragged_idx,
sum_S=total_L,
min_seqlen=nested_q._get_min_seqlen(),
max_seqlen=padded_max_S,
)
return nested_result
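
# Rough round-trip sketch for the jagged helpers above (nested_from_padded is a private
# PyTorch API, so details may vary across torch versions):
#
#   nt = torch.nested.nested_tensor(
#       [torch.randn(3, 64), torch.randn(5, 64)], layout=torch.jagged
#   )                                    # ragged along dim 1
#   dense = nt.to_padded_tensor(0.0)     # (2, 5, 64), zero-padded
#   back = convert_dense_to_jagged(nt, dense)   # same offsets/ragged structure as nt
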
class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LlamaConfig, layer_idx: int) -> None:
nn.Module.__init__(self)
self.hidden_size: int = config.hidden_size
self.self_attn = ModifiedLlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
class LlamaBiModel(LlamaModel):
def __init__(self, config: LlamaConfig) -> None:
LlamaPreTrainedModel.__init__(self, config)
self.padding_idx: int = config.pad_token_id
self.vocab_size: int = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
ModifiedLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens=None,
output_attentions=False,
):
"""
Updates the causal mask for attention computations.
"""
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and (attention_mask == 0.0).any():
return attention_mask
return None
if attention_mask is None or attention_mask.dim() == 4:
return attention_mask
return AttentionMaskConverter._expand_mask(
mask=attention_mask,
dtype=input_tensor.dtype,
)
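    # Mask sketch: for a 2D padding mask, AttentionMaskConverter._expand_mask returns an
    # additive 4D mask of shape (batch, 1, seq_len, seq_len), with 0.0 at attended
    # positions and the dtype's most negative value at padded positions, e.g.
    #
    #   mask = torch.tensor([[1, 1, 0]])
    #   expanded = AttentionMaskConverter._expand_mask(mask=mask, dtype=torch.float32)
    #   # expanded.shape == (1, 1, 3, 3); expanded[0, 0, :, 2] == torch.finfo(torch.float32).min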
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = False
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if self.gradient_checkpointing and self.training and use_cache:
            warnings.warn(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.",
                UserWarning,
                stacklevel=2,
            )
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
warnings.warn(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)",
DeprecationWarning,
stacklevel=2,
)
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
if inputs_embeds.is_nested:
seq_len = inputs_embeds._get_max_seqlen()
else:
seq_len = inputs_embeds.shape[1]
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + seq_len,
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
if not inputs_embeds.is_nested:
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
)
else:
causal_mask = None
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class DramaModel(LlamaBiModel):
"""
DramaModel is a modified version of the LlamaModel that supports bi-directional attention
and provides query and document encoding functionalities.
"""
def __init__(self, config: LlamaConfig):
"""
Initializes the DramaModel by disabling causal masking in self-attention layers.
"""
super().__init__(config)
for layer in self.layers:
layer.self_attn.is_causal = False
# query prefix
self.query_prefix = "Query: "
self.max_seq_len = 8192
self.hidden_size = config.hidden_size
def _average_pool(
self, last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Computes the average pooled representation of the last hidden states.
"""
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
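    # Worked example for the masked mean pooling above (illustrative numbers, `model`
    # being any DramaModel instance):
    #
    #   hidden = torch.tensor([[[2.0], [4.0], [9.0]]])   # (batch=1, seq_len=3, hidden=1)
    #   mask = torch.tensor([[1, 1, 0]])                  # last position is padding
    #   model._average_pool(hidden, mask)                 # -> tensor([[3.0]]) = (2 + 4) / 2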
    def _tokenize(
        self,
        tokenizer: PreTrainedTokenizer,
        texts: list[str],
        max_seq_len: Optional[int] = None,
        use_nested: bool = False,
    ) -> dict[str, Optional[torch.Tensor]]:
"""
Tokenizes input text sequences with optional sequence length restriction.
"""
if max_seq_len is None:
max_seq_len = self.max_seq_len
if use_nested:
tokenized = tokenizer(
texts,
truncation=True,
max_length=max_seq_len,
return_length=True,
)
tokenized.input_ids = torch.nested.nested_tensor(
tokenized.input_ids, layout=torch.jagged
).to(self.device)
tokenized.attention_mask = None
else:
tokenized = tokenizer(
texts,
padding=True,
truncation=True,
max_length=max_seq_len,
return_tensors="pt",
).to(self.device)
        tokenizer_output = {}
        tokenizer_output["input_ids"] = tokenized.input_ids
        tokenizer_output["attention_mask"] = tokenized.attention_mask
        return tokenizer_output
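    # Tokenization sketch (hypothetical model/tokenizer objects): with use_nested=True the
    # returned input_ids is a jagged nested tensor and attention_mask is None (no padding);
    # with use_nested=False it is a padded (batch, max_len) tensor plus a dense mask.
    #
    #   batch = model._tokenize(tokenizer, ["short", "a longer text"], use_nested=True)
    #   # batch["input_ids"].is_nested -> True, batch["attention_mask"] is None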
    def encode(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        dim: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> torch.Tensor:
        """
        Pass inputs through the model and compute L2-normalized embeddings.
        Args:
            input_ids (torch.Tensor): Input token IDs (dense or jagged nested tensor).
            attention_mask (torch.Tensor, optional): Attention mask; None for jagged nested inputs.
            dim (int, optional): Dimensionality to truncate the output embeddings to.
        Returns:
            torch.Tensor: Normalized output embeddings.
        """
outputs = self.forward(
input_ids, attention_mask, *args, **kwargs
).last_hidden_state
if not outputs.is_nested:
if dim is not None:
outputs = outputs[:, :, :dim]
embeddings = self._average_pool(outputs, attention_mask)
else:
if dim is not None:
outputs, _ = outputs.split_with_sizes(
split_sizes=[dim, outputs.shape[-1] - dim], dim=-1
)
            # No padding in the jagged case, so a plain sum over the sequence dimension
            # suffices; the scale difference vs. mean pooling is removed by the L2
            # normalization below.
            embeddings = outputs.sum(dim=-2)
# normalize embeddings
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings
    def encode_queries(
        self,
        tokenizer: PreTrainedTokenizer,
        queries: list[str],
        max_seq_len: Optional[int] = None,
        dim: Optional[int] = None,
        use_nested: bool = False,
    ) -> torch.Tensor:
"""
Encodes a list of queries into embeddings.
Args:
tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
queries (list[str]): List of query texts.
max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.
            use_nested (bool, optional): Whether to tokenize into jagged nested tensors instead of padded tensors.
Returns:
torch.Tensor: Encoded query embeddings in shape (num_queries, dim).
"""
if not queries:
raise ValueError("queries must not be empty.")
if not isinstance(queries, list) or not all(
isinstance(q, str) for q in queries
):
raise ValueError("queries must be a list of strings.")
if tokenizer is None:
raise ValueError("tokenizer must not be None.")
if dim is not None and (dim < 1 or dim > self.hidden_size):
raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
queries = [self.query_prefix + query for query in queries]
tokenized_queries = self._tokenize(tokenizer, queries, max_seq_len, use_nested)
embeddings = self.encode(**tokenized_queries, dim=dim)
return embeddings
    def encode_documents(
        self,
        tokenizer: PreTrainedTokenizer,
        documents: list[str],
        max_seq_len: Optional[int] = None,
        dim: Optional[int] = None,
        use_nested: bool = False,
    ) -> torch.Tensor:
"""
Encodes a list of documents into embeddings.
Args:
tokenizer (PreTrainedTokenizer): Tokenizer for text processing.
documents (list[str]): List of document texts.
max_seq_len (int, optional): Maximum sequence length.
            dim (int, optional): Dimensionality for output embeddings.
            use_nested (bool, optional): Whether to tokenize into jagged nested tensors instead of padded tensors.
Returns:
torch.Tensor: Encoded document embeddings in shape (num_documents, dim).
"""
if not documents:
raise ValueError("documents must not be empty.")
if not isinstance(documents, list) or not all(
isinstance(d, str) for d in documents
):
raise ValueError("documents must be a list of strings.")
if tokenizer is None:
raise ValueError("tokenizer must not be None.")
if dim is not None and (dim < 1 or dim > self.hidden_size):
raise ValueError(f"dim must be in range [1, {self.hidden_size}].")
tokenized_documents = self._tokenize(
tokenizer, documents, max_seq_len, use_nested
)
embeddings = self.encode(**tokenized_documents, dim=dim)
return embeddings
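
# End-to-end usage sketch (checkpoint name and texts are illustrative; DRAMA checkpoints
# on the Hugging Face Hub load this module via `trust_remote_code=True`):
#
#   from transformers import AutoModel, AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("facebook/drama-base")
#   model = AutoModel.from_pretrained("facebook/drama-base", trust_remote_code=True)
#   query_embs = model.encode_queries(tokenizer, ["how do rotary position embeddings work?"])
#   doc_embs = model.encode_documents(tokenizer, ["RoPE rotates query/key feature pairs ..."])
#   scores = query_embs @ doc_embs.T   # cosine similarities, since embeddings are L2-normalized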