Sony
/

Image-Text-to-Text
Safetensors
English
conversational
SwyWang's picture
Upload 8 files
7eb0198 verified
"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""
import re
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum, nn
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional
from dataclasses import dataclass
@dataclass
class VLMOutputWithPast(CausalLMOutputWithPast):
"""
VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
"""
past_media_locations: Optional[torch.Tensor] = None
past_vision_tokens: Optional[torch.Tensor] = None
def exists(val):
return val is not None
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
class VisionTokenizer(nn.Module):
def __init__(self, dim_media, num_tokens_per_media):
super().__init__()
self.dim_media = dim_media
self.num_tokens_per_media = num_tokens_per_media
# MLP (not used in the current implementation)
class MLPVisionProjector(VisionTokenizer):
def __init__(self, *, dim, dim_inner, num_latents):
super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
self.projector = nn.Sequential(
nn.Linear(dim, dim_inner),
nn.GELU(),
nn.Linear(dim_inner, dim_inner),
)
def forward(self, x):
return self.projector(x)
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm_media = nn.LayerNorm(dim)
self.norm_latents = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, T, n1, D)
latent (torch.Tensor): latent features
shape (b, T, n2, D)
"""
x = self.norm_media(x)
latents = self.norm_latents(latents)
h = self.heads
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
q = q * self.scale
# attention
sim = einsum("... i d, ... j d -> ... i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
return self.to_out(out)
class PerceiverResampler(VisionTokenizer):
def __init__(
self,
*,
dim,
dim_inner=None,
depth=6,
dim_head=64,
heads=8,
num_latents=64,
max_num_media=None,
max_num_frames=None,
ff_mult=4,
):
"""
Perceiver module which takes in image features and outputs image tokens.
Args:
dim (int): dimension of the incoming image features
dim_inner (int, optional): final dimension to project the incoming image features to;
also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
depth (int, optional): number of layers. Defaults to 6.
dim_head (int, optional): dimension of each head. Defaults to 64.
heads (int, optional): number of heads. Defaults to 8.
num_latents (int, optional): number of latent tokens to use in the Perceiver;
also corresponds to number of tokens per sequence to output. Defaults to 64.
max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
and keep positional embeddings for. If None, no positional embeddings are used.
max_num_frames (int, optional): maximum number of frames to input into the Perceiver
and keep positional embeddings for. If None, no positional embeddings are used.
ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
"""
if dim_inner is not None:
projection = nn.Linear(dim, dim_inner)
else:
projection = None
dim_inner = dim
super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
self.projection = projection
self.latents = nn.Parameter(torch.randn(num_latents, dim))
# positional embeddings
self.frame_embs = (
nn.Parameter(torch.randn(max_num_frames, dim))
if exists(max_num_frames)
else None
)
self.media_time_embs = (
nn.Parameter(torch.randn(max_num_media, 1, dim))
if exists(max_num_media)
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
self.norm = nn.LayerNorm(dim)
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (b, T, F, v, D)
Returns:
shape (b, T, n, D) where n is self.num_latents
"""
b, T, F, v = x.shape[:4]
# frame and media time embeddings
if exists(self.frame_embs):
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
x = x + frame_embs
x = rearrange(
x, "b T F v d -> b T (F v) d"
) # flatten the frame and spatial dimensions
if exists(self.media_time_embs):
x = x + self.media_time_embs[:T]
# blocks
latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
if exists(self.projection):
return self.projection(self.norm(latents))
else:
return self.norm(latents)
# gated cross attention
class MaskedCrossAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_visual,
dim_head=64,
heads=8,
only_attend_immediate_media=True,
):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
# whether for text to only attend to immediate preceding image, or all previous images
self.only_attend_immediate_media = only_attend_immediate_media
def forward(self, x, media, media_locations=None, use_cached_media=False):
"""
Args:
x (torch.Tensor): text features
shape (B, T_txt, D_txt)
media (torch.Tensor): image features
shape (B, T_img, n, D_img) where n is the dim of the latents
media_locations: boolean mask identifying the media tokens in x
shape (B, T_txt)
use_cached_media: bool
If true, treat all of x as if they occur after the last media
registered in media_locations. T_txt does not need to exactly
equal media_locations.shape[1] in this case
"""
if not use_cached_media:
assert (
media_locations.shape[1] == x.shape[1]
), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"
T_txt = x.shape[1]
_, T_img, n = media.shape[:3]
h = self.heads
x = self.norm(x)
q = self.to_q(x)
media = rearrange(media, "b t n d -> b (t n) d")
k, v = self.to_kv(media).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
q = q * self.scale
sim = einsum("... i d, ... j d -> ... i j", q, k)
if exists(media_locations):
media_time = torch.arange(T_img, device=x.device) + 1
if use_cached_media:
# text time is set to the last cached media location
text_time = repeat(
torch.count_nonzero(media_locations, dim=1),
"b -> b i",
i=T_txt,
)
else:
# at each boolean of True, increment the time counter (relative to media time)
text_time = media_locations.cumsum(dim=-1)
# text time must equal media time if only attending to most immediate image
# otherwise, as long as text time is greater than media time (if attending to all previous images / media)
mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
text_to_media_mask = mask_op(
rearrange(text_time, "b i -> b 1 i 1"),
repeat(media_time, "j -> 1 1 1 (j n)", n=n),
)
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
if exists(media_locations) and self.only_attend_immediate_media:
# any text without a preceding media needs to have attention zeroed out
text_without_media_mask = text_time == 0
text_without_media_mask = rearrange(
text_without_media_mask, "b i -> b 1 i 1"
)
attn = attn.masked_fill(text_without_media_mask, 0.0)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class GatedCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
dim,
dim_visual,
dim_head=64,
heads=8,
ff_mult=4,
only_attend_immediate_media=True,
):
super().__init__()
self.attn = MaskedCrossAttention(
dim=dim,
dim_visual=dim_visual,
dim_head=dim_head,
heads=heads,
only_attend_immediate_media=only_attend_immediate_media,
)
self.attn_gate = nn.Parameter(torch.tensor([0.0]))
self.ff = FeedForward(dim, mult=ff_mult)
self.ff_gate = nn.Parameter(torch.tensor([0.0]))
def forward(
self,
x,
media,
media_locations=None,
use_cached_media=False,
):
x = (
self.attn(
x,
media,
media_locations=media_locations,
use_cached_media=use_cached_media,
)
* self.attn_gate.tanh()
+ x
)
x = self.ff(x) * self.ff_gate.tanh() + x
return x
# Both DecoupledEmbedding and DecoupledLinear are taken from https://github.com/huggingface/transformers/blob/v4.32.1/src/transformers/models/idefics/modeling_idefics.py and renamed for clarity
class DecoupledEmbedding(nn.Embedding):
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
"""
Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
then it will create `num_additional_embeddings` additional parameters that are always trained. If
`num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
"""
def __init__(
self,
max_original_id: int,
num_additional_embeddings: int = 0,
_weight: torch.Tensor = None,
num_original_embeddings: int = None,
embedding_dim: int = None,
partially_freeze=True,
device=None,
dtype=None,
pad_token_id=None,
) -> None:
"""
Args:
max_original_id (`int`):
The largest token id that should be embedded using the regular embedding (regular `weight`).
This is usually len(tokenizer) - 1 before additional tokens are added.
Note that this may not equal self.weight.shape[0]
num_additional_embeddings (`int`):
Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
_weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
num_original_embeddings (`int`):
self.weight.shape[0]
embedding_dim (`int`):
The size of each embedding vector
partially_freeze: (`bool`, *optional*, defaults to `True`):
If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
padding_idx (`int`, *optional*):
The padding index (needs to be less than num_embeddings)
Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
`max_norm` or `norm_type`. We are not supporting these.
"""
# validate args
if pad_token_id is not None and pad_token_id > max_original_id:
raise ValueError(
f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
+ "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
)
if _weight is not None:
assert (num_original_embeddings is None) or (
_weight.shape[0] == num_original_embeddings
), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
assert (embedding_dim is None) or (
_weight.shape[1] == embedding_dim
), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
num_original_embeddings = _weight.shape[0]
embedding_dim = _weight.shape[1]
else:
assert (
num_original_embeddings is not None
), "num_original_embeddings must be provided if _weight is not provided"
assert (
embedding_dim is not None
), "embedding_dim must be provided if _weight is not provided"
super().__init__(
num_embeddings=num_original_embeddings,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
padding_idx=pad_token_id,
_weight=_weight,
)
self.max_original_id = max_original_id
self.padding_idx = pad_token_id
self.num_additional_embeddings = num_additional_embeddings
if self.num_additional_embeddings > 0:
self.additional_embedding = nn.Embedding(
num_embeddings=self.num_additional_embeddings,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
)
self.set_requires_grad(
require_regular_grad=not partially_freeze, require_additional_grad=True
)
def set_requires_grad(self, require_regular_grad, require_additional_grad):
"""
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
"""
self.weight.requires_grad_(require_regular_grad)
self.additional_embedding.requires_grad_(require_additional_grad)
def forward(self, input_ids):
"""
we have 2 embeddings, with different indices - one pretrained self.weight and another
self.additional_embedding.weight that is being trained.
in order to make a lookup of the input ids, we:
1. find out the indices of the entries belonging to the 2nd embedding
2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
embedding starts from 0 and not num_embeddings
3. perform the 2nd embedding lookup
4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
5. perform the 1st embedding lookup
6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
measure.
"""
if self.num_additional_embeddings == 0:
return F.embedding(input_ids, self.weight)
# Clone so that we don't modify the original input_ids later on
input_ids = input_ids.clone()
additional_vocab_indices = torch.where(input_ids > self.max_original_id)
input_ids_additional_vocab = input_ids[additional_vocab_indices]
additional_embeddings = self.additional_embedding(
input_ids_additional_vocab - self.max_original_id - 1
)
# for successful lookup replace input_ids with 0, the results of these will be discarded anyway
input_ids[additional_vocab_indices] = 0
full_vector = F.embedding(input_ids, self.weight)
# overwrite the records with high indices
full_vector[additional_vocab_indices] = additional_embeddings
return full_vector
def extra_repr(self) -> str:
return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
self.max_original_id + 1,
self.num_additional_embeddings,
self.embedding_dim,
(not self.weight.requires_grad),
)
class DecoupledLinear(nn.Linear):
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
"""
Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
then it will create `additional_out_features * in_features` additional parameters that are always trained. If
`additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
"""
def __init__(
self,
max_original_id: int,
additional_out_features: int = 0,
_weight: torch.Tensor = None,
_bias: torch.Tensor = None,
in_features: int = None,
original_out_features: int = None,
bias: bool = True,
partially_freeze: bool = True,
device=None,
dtype=None,
) -> None:
"""
Args:
max_original_id (`int`): The largest token id that should be extracted from the regular weight.
This is usually len(tokenizer) - 1 before additional tokens are added.
Note that this may not equal original_out_features - 1
_weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
If provided, this sets the `in_features` and `original_out_features` parameters.
_bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
in_features: int. Input hidden size.
original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
additional_out_features: int. Number of additional trainable dimensions.
bias: bool. Whether to include a bias term.
partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
"""
# argument validation
if _weight is not None:
assert (_weight.shape[0] == original_out_features) or (
original_out_features is None
), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
assert (_weight.shape[1] == in_features) or (
in_features is None
), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
in_features = _weight.shape[1]
original_out_features = _weight.shape[0]
else:
assert (
in_features is not None
), "in_features must be provided if _weight is not provided"
assert (
original_out_features is not None
), "original_out_features must be provided if _weight is not provided"
if _bias is not None:
assert bias is True, "bias must be True if _bias is provided"
# initialize original linear
super().__init__(
in_features,
original_out_features,
bias,
device,
dtype)
# set weight and bias manually
if _weight is not None:
self.weight = nn.Parameter(_weight)
if _bias is not None:
self.bias = nn.Parameter(_bias)
self.in_features = in_features
self.original_out_features = original_out_features
self.max_original_id = max_original_id
# initialize additional linear
self.additional_out_features = additional_out_features
self.has_bias = bias
if additional_out_features > 0:
self.additional_fc = nn.Linear(
in_features=in_features,
out_features=additional_out_features,
bias=self.has_bias,
device=device,
dtype=dtype,
)
self.set_requires_grad(
require_regular_grad=not partially_freeze, require_additional_grad=True
)
def set_requires_grad(self, require_regular_grad, require_additional_grad):
"""
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
"""
self.weight.requires_grad_(require_regular_grad)
if self.has_bias:
self.bias.requires_grad_(require_regular_grad)
self.additional_fc.requires_grad_(require_additional_grad)
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = F.linear(input, self.weight, self.bias)
output = output[..., : self.max_original_id + 1]
if self.additional_out_features > 0:
additional_features = F.linear(
input, self.additional_fc.weight, self.additional_fc.bias
)
output = torch.cat((output, additional_features), -1)
return output
def extra_repr(self) -> str:
"""Overwriting `nn.Linear.extra_repr` to include new parameters."""
return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
self.in_features,
self.max_original_id + 1,
self.additional_out_features,
self.bias is not None,
(not self.weight.requires_grad or not self.bias.requires_grad),
)