|
""" |
|
Based on: https://github.com/lucidrains/flamingo-pytorch |
|
""" |
|
|
|
import re |
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
from einops_exts import rearrange_many |
|
from torch import einsum, nn |
|
from transformers.modeling_outputs import CausalLMOutputWithPast |
|
from typing import Optional |
|
from dataclasses import dataclass |
|
|
|
|
|
@dataclass
class VLMOutputWithPast(CausalLMOutputWithPast):
    """
    `CausalLMOutputWithPast` extended with two extra optional tensor fields.

    Attributes (in addition to those inherited from `CausalLMOutputWithPast`):
        past_media_locations (Optional[torch.Tensor], defaults to None):
            Media-location tensor to carry across calls. Presumably the boolean
            media-token mask over previously seen text positions (the same kind of
            mask `MaskedCrossAttention` consumes) — TODO confirm against the producer.
        past_vision_tokens (Optional[torch.Tensor], defaults to None):
            Vision tokens to carry across calls so they need not be recomputed.
            Shape is not fixed by this class — confirm against the producer.
    """

    # Extra cached state; both fields default to None like the inherited ones.
    past_media_locations: Optional[torch.Tensor] = None
    past_vision_tokens: Optional[torch.Tensor] = None
|
|
|
|
|
def exists(val):
    """Return True unless *val* is None (falsy values such as 0 or "" still count)."""
    if val is None:
        return False
    return True
|
|
|
|
|
def FeedForward(dim, mult=4):
    """
    Build a pre-norm, bias-free two-layer MLP.

    Args:
        dim (int): input and output feature size.
        mult (int or float, optional): hidden-size multiplier; the hidden width is
            ``int(dim * mult)``. Defaults to 4.

    Returns:
        nn.Sequential: LayerNorm -> Linear(dim, hidden) -> GELU -> Linear(hidden, dim).
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
class VisionTokenizer(nn.Module):
    """
    Common base for modules that turn media features into tokens.

    It only records two sizes for subclasses and callers to read:
        dim_media: feature dimension of the incoming media features.
        num_tokens_per_media: number of tokens the subclass emits per media item.
    """

    def __init__(self, dim_media, num_tokens_per_media):
        super().__init__()
        self.dim_media, self.num_tokens_per_media = dim_media, num_tokens_per_media
|
|
|
|
|
|
|
class MLPVisionProjector(VisionTokenizer):
    """
    Project vision features with a Linear -> GELU -> Linear MLP.

    Every feature vector along the last axis is mapped from `dim` to `dim_inner`.
    The number of tokens is not changed by this module; `num_latents` is only
    recorded as `num_tokens_per_media` (presumably callers pass the incoming
    token count here — confirm).
    """

    def __init__(self, *, dim, dim_inner, num_latents):
        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        layers = [
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Linear(dim_inner, dim_inner),
        ]
        self.projector = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of `x` (size `dim` -> `dim_inner`)."""
        projected = self.projector(x)
        return projected
|
|
|
class PerceiverAttention(nn.Module):
    """
    Perceiver-style attention: learned latents query the media features.

    Queries come from the latents. Keys and values come from the media features
    concatenated with the latents along the token axis, so the latents also attend
    to each other.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5

        # Separate pre-norms for the two input streams.
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features of shape (b, T, n1, D)
            latents (torch.Tensor): latent features of shape (b, T, n2, D)

        Returns:
            torch.Tensor of shape (b, T, n2, D): the attention output for each latent.
        """
        media = self.norm_media(x)
        latents = self.norm_latents(latents)

        queries = self.to_q(latents)
        # Keys/values cover both the media tokens and the latents themselves.
        keys_values = self.to_kv(torch.cat((media, latents), dim=-2))
        keys, values = keys_values.chunk(2, dim=-1)
        queries, keys, values = rearrange_many(
            (queries, keys, values), "b t n (h d) -> b h t n d", h=self.heads
        )

        scores = einsum("... i d, ... j d -> ... i j", queries * self.scale, keys)
        # Subtract the (gradient-detached) row max for numerical stability.
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)

        attended = einsum("... i j, ... j d -> ... i d", weights, values)
        attended = rearrange(attended, "b h t n d -> b t n (h d)", h=self.heads)
        return self.to_out(attended)
|
|
|
|
|
class PerceiverResampler(VisionTokenizer):
    """
    Perceiver resampler: compresses a variable number of visual tokens per media item
    into a fixed number (`num_latents`) of output tokens.
    """

    def __init__(
        self,
        *,
        dim,
        dim_inner=None,
        depth=6,
        dim_head=64,
        heads=8,
        num_latents=64,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Perceiver module which takes in image features and outputs image tokens.

        Args:
            dim (int): dimension of the incoming image features
            dim_inner (int, optional): final dimension to project the outputted features to.
                If None, no projection is used, and dim_inner = dim.
            depth (int, optional): number of layers. Defaults to 6.
            dim_head (int, optional): dimension of each head. Defaults to 64.
            heads (int, optional): number of heads. Defaults to 8.
            num_latents (int, optional): number of latent tokens to use in the Perceiver;
                also corresponds to number of tokens per media item to output. Defaults to 64.
            max_num_media (int, optional): maximum number of media per sequence to input into
                the Perceiver and keep positional embeddings for. If None, no media positional
                embeddings are used.
            max_num_frames (int, optional): maximum number of frames to input into the Perceiver
                and keep positional embeddings for. If None, no frame positional embeddings are used.
            ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
        """
        # The projection is created before super().__init__ only as a local; it is
        # registered as a submodule once nn.Module is initialised.
        if dim_inner is not None:
            projection = nn.Linear(dim, dim_inner)
        else:
            projection = None
            dim_inner = dim
        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        self.projection = projection
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # Optional learned positional embeddings over frames and over media items.
        self.frame_embs = (
            nn.Parameter(torch.randn(max_num_frames, dim))
            if exists(max_num_frames)
            else None
        )
        self.media_time_embs = (
            nn.Parameter(torch.randn(max_num_media, 1, dim))
            if exists(max_num_media)
            else None
        )

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): image features of shape (b, T, F, v, D), where T is the number
                of media items, F the number of frames per item and v the tokens per frame.

        Returns:
            torch.Tensor of shape (b, T, n, D_out), where n == self.num_tokens_per_media
            (the number of latents) and D_out is `dim_inner` when a projection was
            configured, otherwise `dim`.

        Raises:
            ValueError: if F exceeds `max_num_frames` or T exceeds `max_num_media` when the
                corresponding positional embeddings are enabled. Without this check a
                size-1 embedding table would silently broadcast over every frame/media item.
        """
        # `num_frames` rather than `F` so the torch.nn.functional alias is not shadowed.
        b, T, num_frames, v = x.shape[:4]

        if exists(self.frame_embs):
            if num_frames > self.frame_embs.shape[0]:
                raise ValueError(
                    f"PerceiverResampler got {num_frames} frames per media item but only has "
                    f"frame embeddings for {self.frame_embs.shape[0]} (max_num_frames)."
                )
            frame_embs = repeat(
                self.frame_embs[:num_frames], "F d -> b T F v d", b=b, T=T, v=v
            )
            x = x + frame_embs
        # Flatten frames and per-frame tokens into a single token axis per media item.
        x = rearrange(x, "b T F v d -> b T (F v) d")
        if exists(self.media_time_embs):
            if T > self.media_time_embs.shape[0]:
                raise ValueError(
                    f"PerceiverResampler got {T} media items but only has media embeddings "
                    f"for {self.media_time_embs.shape[0]} (max_num_media)."
                )
            x = x + self.media_time_embs[:T]

        latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.norm(latents)
        if exists(self.projection):
            return self.projection(latents)
        return latents
|
|
|
|
|
|
|
class MaskedCrossAttention(nn.Module):
    """
    Cross-attention from text tokens (queries) to media tokens (keys/values).

    When `media_locations` is given, attention is masked so that each text token only
    sees media at or before its position. With `only_attend_immediate_media`, that is
    restricted to the most recent such media item.
    """

    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        """
        Args:
            dim (int): dimension of the text features.
            dim_visual (int): dimension of the media features.
            dim_head (int, optional): dimension of each attention head. Defaults to 64.
            heads (int, optional): number of attention heads. Defaults to 8.
            only_attend_immediate_media (bool, optional): if True, each text token attends only
                to the latest media item at or before it; otherwise to all of them. Defaults to True.
        """
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.only_attend_immediate_media = only_attend_immediate_media

    def forward(self, x, media, media_locations=None, use_cached_media=False):
        """
        Args:
            x (torch.Tensor): text features
                shape (B, T_txt, D_txt)
            media (torch.Tensor): image features
                shape (B, T_img, n, D_img) where n is the number of tokens per media item
            media_locations (torch.Tensor, optional): boolean mask identifying the media tokens in x
                shape (B, T_txt). If None, no masking is applied and every text token
                attends to every media token.
            use_cached_media: bool
                If true, treat all of x as if they occur after the last media
                registered in media_locations. T_txt does not need to exactly
                equal media_locations.shape[1] in this case

        Returns:
            torch.Tensor of shape (B, T_txt, D_txt).

        Raises:
            ValueError: if media_locations is given, use_cached_media is False, and its
                sequence length differs from x's.
        """
        # Only validate when a mask was actually passed; None is a supported value
        # (handled by the `exists(media_locations)` branches below).
        if media_locations is not None and not use_cached_media:
            if media_locations.shape[1] != x.shape[1]:
                raise ValueError(
                    f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"
                )

        T_txt = x.shape[1]
        _, T_img, n = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = einsum("... i d, ... j d -> ... i j", q, k)

        if exists(media_locations):
            # 1-based ordinal of each media item.
            media_time = torch.arange(T_img, device=x.device) + 1

            if use_cached_media:
                # Every text token is placed after the last registered media item.
                text_time = repeat(
                    torch.count_nonzero(media_locations, dim=1),
                    "b -> b i",
                    i=T_txt,
                )
            else:
                # Number of media tokens seen so far (inclusive) at each text position.
                text_time = media_locations.cumsum(dim=-1)

            # eq: only the latest media item; ge: all media items up to that point.
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            text_to_media_mask = mask_op(
                rearrange(text_time, "b i -> b 1 i 1"),
                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
            )
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        # Subtract the (detached) row max for numerical stability before softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # Text tokens that precede every media item have nothing valid to attend to;
            # zero their attention instead of keeping the uniform softmax over masked keys.
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(
                text_without_media_mask, "b i -> b 1 i 1"
            )
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
|
|
|
|
class GatedCrossAttentionBlock(nn.Module):
    """
    Tanh-gated residual cross-attention followed by a tanh-gated residual feed-forward.

    Both gates are scalar parameters initialised to 0, so tanh(gate) == 0 and the block
    starts out as the identity on `x`.
    """

    def __init__(
        self,
        *,
        dim,
        dim_visual,
        dim_head=64,
        heads=8,
        ff_mult=4,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            dim_visual=dim_visual,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x,
        media,
        media_locations=None,
        use_cached_media=False,
    ):
        """Apply the gated cross-attention and gated feed-forward residual updates to `x`."""
        attended = self.attn(
            x,
            media,
            media_locations=media_locations,
            use_cached_media=use_cached_media,
        )
        x = x + self.attn_gate.tanh() * attended
        x = x + self.ff_gate.tanh() * self.ff(x)
        return x
|
|
|
|
|
|
|
class DecoupledEmbedding(nn.Embedding):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        max_original_id: int,
        num_additional_embeddings: int = 0,
        _weight: torch.Tensor = None,
        num_original_embeddings: int = None,
        embedding_dim: int = None,
        partially_freeze=True,
        device=None,
        dtype=None,
        pad_token_id=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`):
                The largest token id that should be embedded using the regular embedding (regular `weight`).
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal self.weight.shape[0]
            num_additional_embeddings (`int`):
                Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
            _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
                If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
            num_original_embeddings (`int`):
                self.weight.shape[0]
            embedding_dim (`int`):
                The size of each embedding vector
            partially_freeze: (`bool`, *optional*, defaults to `True`):
                If `True` and `num_additional_embeddings > 0`, the regular `weight` will be frozen.
                `additional_weight` is never frozen.
            pad_token_id (`int`, *optional*):
                The padding index; must be <= max_original_id. Passed to `nn.Embedding` as
                `padding_idx`, and honored in `forward` so the padding row receives no gradient.

        Note: other `nn.Embedding` options such as `max_norm` or `norm_type` are not supported.

        Raises:
            ValueError: if pad_token_id > max_original_id.
        """
        if pad_token_id is not None and pad_token_id > max_original_id:
            raise ValueError(
                f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
                + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
            )
        if _weight is not None:
            assert (num_original_embeddings is None) or (
                _weight.shape[0] == num_original_embeddings
            ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
            assert (embedding_dim is None) or (
                _weight.shape[1] == embedding_dim
            ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
            num_original_embeddings = _weight.shape[0]
            embedding_dim = _weight.shape[1]
        else:
            assert (
                num_original_embeddings is not None
            ), "num_original_embeddings must be provided if _weight is not provided"
            assert (
                embedding_dim is not None
            ), "embedding_dim must be provided if _weight is not provided"

        super().__init__(
            num_embeddings=num_original_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=pad_token_id,
            _weight=_weight,
        )
        self.max_original_id = max_original_id
        self.padding_idx = pad_token_id
        self.num_additional_embeddings = num_additional_embeddings
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )
            self.set_requires_grad(
                require_regular_grad=not partially_freeze, require_additional_grad=True
            )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.

        `require_additional_grad` is ignored when there are no additional embeddings.
        """
        self.weight.requires_grad_(require_regular_grad)
        # `additional_embedding` only exists when num_additional_embeddings > 0.
        if self.num_additional_embeddings > 0:
            self.additional_embedding.requires_grad_(require_additional_grad)

    def forward(self, input_ids):
        """
        Look up embeddings; ids > max_original_id come from `additional_embedding`.

        Steps:
        1. Build a mask of ids belonging to the additional embedding (id > max_original_id).
        2. Look those up in `additional_embedding` after shifting them by max_original_id + 1,
           since the additional table starts at 0.
        3. Replace those ids with 0 (in a new tensor; `input_ids` is not modified) so the
           regular lookup stays in range. The values fetched for them are discarded.
        4. Perform the regular lookup and overwrite the masked positions with the additional
           embeddings.

        Both lookups pass `padding_idx` so, like `nn.Embedding`, the padding row gets no gradient.
        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight, self.padding_idx)

        additional_mask = input_ids > self.max_original_id
        additional_embeddings = self.additional_embedding(
            input_ids[additional_mask] - self.max_original_id - 1
        )

        # masked_fill returns a new tensor, so the caller's input_ids are untouched.
        regular_ids = input_ids.masked_fill(additional_mask, 0)
        full_vector = F.embedding(regular_ids, self.weight, self.padding_idx)

        # Overwrite the placeholder rows with the additional embeddings.
        full_vector[additional_mask] = additional_embeddings

        return full_vector

    def extra_repr(self) -> str:
        return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.max_original_id + 1,
            self.num_additional_embeddings,
            self.embedding_dim,
            (not self.weight.requires_grad),
        )
|
|
|
|
|
class DecoupledLinear(nn.Linear):
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
    then it will create `additional_out_features * in_features` additional parameters that are always trained. If
    `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
    """

    def __init__(
        self,
        max_original_id: int,
        additional_out_features: int = 0,
        _weight: torch.Tensor = None,
        _bias: torch.Tensor = None,
        in_features: int = None,
        original_out_features: int = None,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`): The largest token id that should be extracted from the regular weight.
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal original_out_features - 1
            _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
                If provided, this sets the `in_features` and `original_out_features` parameters.
            _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
            in_features: int. Input hidden size.
            original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
            additional_out_features: int. Number of additional trainable dimensions.
            bias: bool. Whether to include a bias term.
            partially_freeze: (`bool`, *optional*, defaults to `True`): If `True` and
                `additional_out_features > 0`, the regular `weight` (and bias) will be frozen.
        """
        if _weight is not None:
            assert (_weight.shape[0] == original_out_features) or (
                original_out_features is None
            ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
            assert (_weight.shape[1] == in_features) or (
                in_features is None
            ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
            in_features = _weight.shape[1]
            original_out_features = _weight.shape[0]
        else:
            assert (
                in_features is not None
            ), "in_features must be provided if _weight is not provided"
            assert (
                original_out_features is not None
            ), "original_out_features must be provided if _weight is not provided"

        if _bias is not None:
            assert bias is True, "bias must be True if _bias is provided"

        super().__init__(
            in_features=in_features,
            out_features=original_out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

        # Replace the freshly initialised parameters with the provided tensors, if any.
        if _weight is not None:
            self.weight = nn.Parameter(_weight)
        if _bias is not None:
            self.bias = nn.Parameter(_bias)

        self.in_features = in_features
        self.original_out_features = original_out_features
        self.max_original_id = max_original_id

        self.additional_out_features = additional_out_features
        self.has_bias = bias
        if additional_out_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=additional_out_features,
                bias=self.has_bias,
                device=device,
                dtype=dtype,
            )
            self.set_requires_grad(
                require_regular_grad=not partially_freeze, require_additional_grad=True
            )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.

        `require_additional_grad` is ignored when there is no additional layer.
        """
        self.weight.requires_grad_(require_regular_grad)
        if self.has_bias:
            self.bias.requires_grad_(require_regular_grad)
        # `additional_fc` only exists when additional_out_features > 0.
        if self.additional_out_features > 0:
            self.additional_fc.requires_grad_(require_additional_grad)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Compute logits for ids 0..max_original_id from the regular weight, then append
        the outputs of the additional layer (if any) along the last dimension.
        """
        output = F.linear(input, self.weight, self.bias)
        # Drop regular outputs beyond max_original_id (weight may have extra rows).
        output = output[..., : self.max_original_id + 1]

        if self.additional_out_features > 0:
            additional_features = F.linear(
                input, self.additional_fc.weight, self.additional_fc.bias
            )
            output = torch.cat((output, additional_features), -1)
        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        # bias is None when bias=False, so only inspect it when present.
        frozen = not self.weight.requires_grad
        if self.bias is not None:
            frozen = frozen or not self.bias.requires_grad
        return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.max_original_id + 1,
            self.additional_out_features,
            self.bias is not None,
            frozen,
        )