import math

import torch


def _power_of_two_slopes(n):
    """Return the ``n`` geometric ALiBi slopes for a power-of-two ``n``.

    The first slope is ``2 ** (-8 / n)`` and each following slope is the
    previous one multiplied by that same value.
    """
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    ratio = start
    return [start * ratio ** i for i in range(n)]


def _head_slopes(n):
    """Return one ALiBi slope per attention head for any positive ``n``."""
    # In the ALiBi paper only models with 2^a heads were trained, and the
    # slope sequence has some good properties that only occur when the
    # number of heads is a power of 2. To keep those properties for other
    # head counts, take the slopes of the closest lower power of 2 and
    # top them up with every other slope of the next power of 2.
    if math.log2(n).is_integer():
        return _power_of_two_slopes(n)

    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    fill_in = _power_of_two_slopes(2 * closest_power_of_2)[0::2]
    return (
        _power_of_two_slopes(closest_power_of_2)
        + fill_in[: n - closest_power_of_2]
    )


def get_alibi(
    max_positions: int,
    attention_heads: int,
) -> torch.Tensor:
    """Build a symmetric ALiBi attention bias.

    Args:
        max_positions: Number of positions ``T`` covered by the bias.
        attention_heads: Number of attention heads ``H``; must be >= 1.

    Returns:
        A tensor of shape ``(H, T, T)`` whose entry ``[h, i, j]`` equals
        ``-slope_h * |i - j|``: zero on the diagonal and linearly
        decreasing away from it.

    Raises:
        ValueError: If ``attention_heads`` is smaller than 1.
    """
    if attention_heads < 1:
        raise ValueError(
            f"attention_heads must be a positive integer, got {attention_heads}"
        )

    slopes = torch.tensor(_head_slopes(attention_heads))

    # Prepare the ALiBi linear position bias. wav2vec2 is a
    # non-autoregressive model, so we want a symmetric mask with 0 on the
    # diagonal and otherwise linearly decreasing values.
    positions = torch.arange(max_positions)
    distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()

    # (H, 1, 1) * (1, T, T) broadcasts to (H, T, T).
    return slopes.view(-1, 1, 1) * -distance.unsqueeze(0)


def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
    """Restrict an ALiBi bias to the positions selected by a mask.

    Args:
        alibi_bias: Bias tensor that can be viewed as
            ``(orig_B, H, orig_T, orig_T)``.
        mask_indices: Mask selecting the positions to keep; expected to be
            a boolean tensor of shape ``(orig_B, orig_T)`` (required by
            ``masked_select`` and the broadcasting below). Every sample
            must select the same number ``M`` of positions.
        orig_B: Batch size.
        orig_T: Sequence length.

    Returns:
        A tensor of shape ``(orig_B * H, M, M)`` holding, for each sample
        and head, the bias between the selected positions only.

    Raises:
        ValueError: If the samples select different numbers of positions.
    """
    per_sample = alibi_bias.view(orig_B, -1, orig_T, orig_T)
    num_heads = per_sample.size(1)

    # The reshapes below only line up when every sample keeps the same
    # number of positions; otherwise they can succeed on misaligned data.
    kept = mask_indices.sum(dim=-1).reshape(-1)
    if kept.numel() > 1 and bool((kept != kept[0]).any()):
        raise ValueError(
            "mask_indices must select the same number of positions "
            "for every sample"
        )

    head_mask = mask_indices.unsqueeze(1)

    # Keep only the selected rows (query positions).
    rows = per_sample.masked_select(head_mask.unsqueeze(-1))
    rows = rows.view(orig_B, num_heads, -1, orig_T)
    num_kept = rows.size(-2)

    # Keep only the selected columns (key positions).
    selected = rows.masked_select(head_mask.unsqueeze(-2))
    return selected.view(-1, num_kept, num_kept)