# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Tuple

import torch
from fairseq import utils
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.nn import Parameter

from fairseq.modules.multihead_attention import MultiheadAttention
from ..modules.multihead_functional import multi_head_attention_forward


class MultiheadAttentionSelection(MultiheadAttention):

    def __init__(
        self,
        embed_dim,
        total_num_heads,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        layer_idx=0,
        attn_head_selector=None,
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim=kdim,
            vdim=vdim,
            dropout=dropout,
            bias=bias,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=self_attention,
            encoder_decoder_attention=encoder_decoder_attention,
            q_noise=q_noise,
            qn_block_size=qn_block_size,
        )
        self.layer_idx = layer_idx
        self.attn_head_selector = attn_head_selector
        self.total_num_heads = total_num_heads
        self.total_embed_dim = self.head_dim * total_num_heads
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, self.total_embed_dim, bias=bias), q_noise, qn_block_size
        )
        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, self.total_embed_dim))
        else:
            self.bias_k = self.bias_v = None
        self.reset_parameters()
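    # Note on sizing: the q/k/v projections above produce total_num_heads
    # candidate heads (total_embed_dim = head_dim * total_num_heads), while
    # the inherited out_proj still expects num_heads heads
    # (embed_dim = head_dim * num_heads). forward() therefore has to reduce
    # the total_num_heads candidates to the num_heads heads chosen for this
    # layer by attn_head_selector(layer_idx), which returns a
    # (subset_heads, subset_weights) pair.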

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        # subset_heads: Optional[Tensor] = None,
        # subset_weights: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        subset_heads, subset_weights = self.attn_head_selector(self.layer_idx)

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert list(query.size()) == [tgt_len, bsz, self.embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert value.shape[:2] == (src_len, bsz)
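        # Fast path: outside of ONNX export, TPU execution, TorchScript and
        # incremental decoding, the whole computation is delegated to the
        # modified functional multi_head_attention_forward, which receives
        # both total_num_heads and num_heads together with the
        # (subset_heads, subset_weights) selection, and uses the separate
        # q/k/v projection weights sized for total_num_heads heads.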
        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
        ):
            assert key is not None and value is not None
            return multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.total_num_heads,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
                subset_heads=subset_heads,
                subset_weights=subset_weights,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.total_num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.total_num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.total_num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, total_num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.total_num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(
                    bsz * self.total_num_heads, -1, self.head_dim
                )
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.total_num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.total_num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [
            bsz * self.total_num_heads,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.total_num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            # fold back to (bsz * total_num_heads, ...): all candidate heads
            # are still present at this point
            attn_weights = attn_weights.view(bsz * self.total_num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
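        # Head selection: attn_probs and v are still laid out per candidate
        # head, i.e. with a leading dimension of bsz * total_num_heads. When
        # subset_heads is given, the num_heads selected candidate heads are
        # gathered per example from the mixed attention output, scaled by
        # subset_weights, and folded back into the standard
        # (bsz * num_heads, tgt_len, head_dim) layout expected by out_proj.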

        # evaluation
        if subset_heads is not None and subset_heads.numel() == 1:
            subset_heads = subset_heads.repeat(bsz)
            subset_weights = subset_weights.repeat(bsz)

        if subset_heads is None:
            attn = torch.bmm(attn_probs, v)
        else:
            # training with head selection
            mixed_attn = (
                torch.bmm(attn_probs, v)
                .contiguous()
                .view(bsz, self.total_num_heads, tgt_len, self.head_dim)
            )
            attn = torch.stack(
                [
                    mixed_attn[torch.arange(bsz), subset_heads[:, col], :, :]
                    for col in range(subset_heads.size(1))
                ],
                dim=1,
            )
            attn = attn * subset_weights.unsqueeze(2).unsqueeze(3)
            attn = attn.contiguous().view(bsz * self.num_heads, tgt_len, self.head_dim)

        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights: Optional[Tensor] = None
        if need_weights:
            if subset_heads is None:
                attn_weights = attn_weights_float.view(
                    bsz, self.num_heads, tgt_len, src_len
                ).transpose(1, 0)
            else:
                mixed_attn_weights = attn_weights_float.view(
                    bsz, self.total_num_heads, tgt_len, src_len
                )
                attn_weights = torch.stack(
                    [
                        mixed_attn_weights[torch.arange(bsz), subset_heads[:, col], :, :]
                        for col in range(subset_heads.size(1))
                    ],
                    dim=1,
                ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
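
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# an attn_head_selector object whose __call__(layer_idx) returns a
# (subset_heads, subset_weights) pair with shapes along the lines of
# (1, num_heads); the concrete selector used by this example suite may differ,
# so the stub below is hypothetical.
#
#   import torch
#
#   class DummySelector:
#       """Hypothetical stand-in that always keeps the first two heads."""
#       def __call__(self, layer_idx):
#           subset_heads = torch.tensor([[0, 1]])     # head indices, (1, num_heads)
#           subset_weights = torch.full((1, 2), 0.5)  # mixing weights
#           return subset_heads, subset_weights
#
#   attn = MultiheadAttentionSelection(
#       embed_dim=16,
#       total_num_heads=8,
#       num_heads=2,
#       self_attention=True,
#       layer_idx=0,
#       attn_head_selector=DummySelector(),
#   )
#   x = torch.randn(5, 3, 16)  # (tgt_len, bsz, embed_dim)
#   out, attn_weights = attn(x, x, x)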