# DLF/trains/subNets/transformers_encoder/position_embedding.py
import math
import torch
import torch.nn as nn

def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored, but
    it is necessary to specify whether padding is added on the left side
    (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    device = tensor.get_device()
    buf_name = f'range_buf_{device}'
    # Cache one position-range buffer per device on the function object so it
    # can be reused (and grown) across calls.
    if not hasattr(make_positions, buf_name):
        setattr(make_positions, buf_name, tensor.new())
    setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor))
    if getattr(make_positions, buf_name).numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name))
    mask = tensor.ne(padding_idx)
    positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        # Shift positions so the first non-padding symbol still gets position
        # padding_idx + 1 when the padding sits on the left.
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    new_tensor = tensor.clone()
    return new_tensor.masked_scatter_(mask, positions[mask]).long()
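
# A minimal worked example (assuming padding_idx=0 and left_pad=False, i.e.
# right padding): an input of token ids
#     [[7, 9, 0, 0],
#      [5, 6, 8, 0]]
# is mapped by make_positions to
#     [[1, 2, 0, 0],
#      [1, 2, 3, 0]]
# i.e. non-padding symbols get positions starting at padding_idx + 1, while
# padding positions keep the value padding_idx.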

class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = dict()  # device --> actual weight; due to nn.DataParallel :-(
        self.register_buffer('_float_tensor', torch.FloatTensor(1))
    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        # Geometric progression of inverse frequencies, from 1 down to 1/10000.
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        # Outer product: one row of angles per position, one column per frequency.
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad the last channel when embedding_dim is odd
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb
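
    # Layout note: for position pos and channel i in [0, half_dim), the table
    # built above satisfies
    #     emb[pos, i]            = sin(pos * exp(-i * log(10000) / (half_dim - 1)))
    #     emb[pos, half_dim + i] = cos(pos * exp(-i * log(10000) / (half_dim - 1)))
    # so all sine channels come first, then all cosine channels (the
    # tensor2tensor layout), rather than interleaving sin/cos pairs as in the
    # paper.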
    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        device = input.get_device()
        if device not in self.weights or max_pos > self.weights[device].size(0):
            # recompute/expand embeddings if needed
            self.weights[device] = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights[device] = self.weights[device].type_as(self._float_tensor).to(input.device)
        positions = make_positions(input, self.padding_idx, self.left_pad)
        # Look up one embedding row per position id; detach because the
        # sinusoidal table is fixed, not learned.
        return self.weights[device].index_select(0, positions.contiguous().view(-1)).view(bsz, seq_len, -1).detach()
    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
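

if __name__ == '__main__':
    # Minimal usage sketch with illustrative shapes and values.
    # Note: forward()/make_positions() call Tensor.get_device(), which assumes
    # either CUDA inputs or a recent PyTorch where it returns -1 for CPU tensors.
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=6, padding_idx=0, left_pad=0)
    tokens = torch.tensor([[7, 9, 0, 0],
                           [5, 6, 8, 0]])  # [bsz x seqlen], 0 marks padding
    if torch.cuda.is_available():
        pos_emb, tokens = pos_emb.cuda(), tokens.cuda()
    out = pos_emb(tokens)
    print(out.shape)  # expected: torch.Size([2, 4, 6])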