|
import torch |
|
import torch.nn as nn |
|
|
|
# Public API of this module: only the alignment sub-network is exported;
# CTCModule is a helper that AlignSubNet builds in 'ctc' mode.
__all__ = ["AlignSubNet"]
|
|
|
class CTCModule(nn.Module):
    """Soft alignment of a sequence from modality A onto the positions of modality B.

    A two-layer LSTM scores, for every input step, each of the
    ``out_seq_len`` output positions plus one extra "blank" slot. The
    softmax-normalised scores (blank slot dropped) are used as weights to mix
    the input steps into ``out_seq_len`` output steps.

    From: https://github.com/yaohungt/Multimodal-Transformer
    """

    def __init__(self, in_dim, out_seq_len):
        """
        :param in_dim: Dimension for input modality A (e.g., audio)
        :param out_seq_len: Sequence length for output modality B (e.g., text)
        """
        super().__init__()
        self.out_seq_len = out_seq_len

        # Hidden size is out_seq_len + 1: one unit per output position plus
        # the blank slot at index 0, which forward() discards.
        self.pred_output_position_inclu_blank = nn.LSTM(
            in_dim,
            out_seq_len + 1,
            num_layers=2,
            batch_first=True,
        )
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """
        :input x: Input with shape [batch_size x in_seq_len x in_dim]
        :return: Tensor with shape [batch_size x out_seq_len x in_dim]
        """
        lstm_out, _ = self.pred_output_position_inclu_blank(x)

        # Normalise over the (blank + output positions) axis.
        position_probs = self.softmax(lstm_out)

        # Drop the blank slot, then move output positions to axis 1:
        # [batch_size x out_seq_len x in_seq_len]
        mixing_weights = position_probs[:, :, 1:].transpose(1, 2)

        # Each output step is a weighted sum of the input steps.
        return torch.bmm(mixing_weights, x)
|
|
|
class AlignSubNet(nn.Module):
    """Bring text, audio and video sequences to one common sequence length.

    Every input is measured along dimension 1 (``x.size(1)``) and aligned to
    ``dst_len``, which is the text sequence length taken from
    ``args.seq_lens``. Three strategies are available: ``avg_pool``, ``ctc``
    and ``conv1d``.
    """

    def __init__(self, args, mode):
        """
        :param args: object exposing ``feature_dims`` and ``seq_lens``; each is
            unpacked as a (text, audio, video) triple
        :param mode: the way of aligning; one of avg_pool, ctc, conv1d
        """
        super(AlignSubNet, self).__init__()
        assert mode in ['avg_pool', 'ctc', 'conv1d'], \
            "mode must be one of 'avg_pool', 'ctc', 'conv1d'"

        in_dim_t, in_dim_a, in_dim_v = args.feature_dims
        seq_len_t, seq_len_a, seq_len_v = args.seq_lens
        self.dst_len = seq_len_t
        self.mode = mode

        # Dispatch table from mode name to the aligning method.
        self.ALIGN_WAY = {
            'avg_pool': self.__avg_pool,
            'ctc': self.__ctc,
            'conv1d': self.__conv1d
        }

        # Only the learnable sub-modules needed by the chosen mode are built.
        if mode == 'conv1d':
            # kernel_size=1 convolutions whose *channels* are the sequence
            # steps: they map seq_len_* input steps to dst_len output steps.
            self.conv1d_T = nn.Conv1d(seq_len_t, self.dst_len, kernel_size=1, bias=False)
            self.conv1d_A = nn.Conv1d(seq_len_a, self.dst_len, kernel_size=1, bias=False)
            self.conv1d_V = nn.Conv1d(seq_len_v, self.dst_len, kernel_size=1, bias=False)
        elif mode == 'ctc':
            self.ctc_t = CTCModule(in_dim_t, self.dst_len)
            self.ctc_a = CTCModule(in_dim_a, self.dst_len)
            self.ctc_v = CTCModule(in_dim_v, self.dst_len)

    def get_seq_len(self):
        """Return the target sequence length every modality is aligned to."""
        return self.dst_len

    def __ctc(self, text_x, audio_x, video_x):
        """Align each modality with its own CTCModule; inputs already at dst_len pass through."""
        text_x = self.ctc_t(text_x) if text_x.size(1) != self.dst_len else text_x
        audio_x = self.ctc_a(audio_x) if audio_x.size(1) != self.dst_len else audio_x
        video_x = self.ctc_v(video_x) if video_x.size(1) != self.dst_len else video_x
        return text_x, audio_x, video_x

    def __avg_pool(self, text_x, audio_x, video_x):
        """Align each modality by padding with its last step and averaging groups of steps."""
        def align(x):
            raw_seq_len = x.size(1)
            if raw_seq_len == self.dst_len:
                return x

            # Number of steps averaged into each output step (ceil division),
            # and how many steps must be appended to reach pool_size * dst_len.
            pool_size = -(-raw_seq_len // self.dst_len)
            pad_len = pool_size * self.dst_len - raw_seq_len

            if pad_len:
                # Pad by repeating the final step.
                pad_x = x[:, -1:, :].expand(x.size(0), pad_len, x.size(-1))
                x = torch.cat([x, pad_x], dim=1)

            # Output step j is the mean of input steps j, j + dst_len,
            # j + 2 * dst_len, ...
            # NOTE(review): this interleaved grouping is kept exactly as in
            # the original code; pooling adjacent steps instead would need
            # reshape(B, dst_len, pool_size, -1).mean(dim=2) - confirm intent
            # before changing, since it alters model outputs.
            x = x.reshape(x.size(0), pool_size, self.dst_len, -1)
            return x.mean(dim=1)

        text_x = align(text_x)
        audio_x = align(audio_x)
        video_x = align(video_x)
        return text_x, audio_x, video_x

    def __conv1d(self, text_x, audio_x, video_x):
        """Align each modality with its own Conv1d; inputs already at dst_len pass through."""
        # Each convolution must receive its own modality's tensor.
        text_x = self.conv1d_T(text_x) if text_x.size(1) != self.dst_len else text_x
        audio_x = self.conv1d_A(audio_x) if audio_x.size(1) != self.dst_len else audio_x
        video_x = self.conv1d_V(video_x) if video_x.size(1) != self.dst_len else video_x
        return text_x, audio_x, video_x

    def forward(self, text_x, audio_x, video_x):
        """
        :input text_x, audio_x, video_x: tensors whose dimension 1 is the sequence length
        :return: tuple (text_x, audio_x, video_x); returned untouched when the
            three inputs already share the same sequence length
        """
        if text_x.size(1) == audio_x.size(1) == video_x.size(1):
            return text_x, audio_x, video_x
        return self.ALIGN_WAY[self.mode](text_x, audio_x, video_x)