""" here is the mian backbone for DLF """ import torch import torch.nn as nn import torch.nn.functional as F from ...subNets import BertTextEncoder from ...subNets.transformers_encoder.transformer import TransformerEncoder from huggingface_hub import PyTorchModelHubMixin class DLF(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/pwang322/DLF", paper_url="https://huggingface.co/papers/2412.12225", tags=["sentiment-analysis"], license="mit"): def __init__(self, args): super(DLF, self).__init__() if args.use_bert: self.text_model = BertTextEncoder(use_finetune=args.use_finetune, transformers=args.transformers, pretrained=args.pretrained) self.use_bert = args.use_bert dst_feature_dims, nheads = args.dst_feature_dim_nheads if args.dataset_name == 'mosi': if args.need_data_aligned: self.len_l, self.len_v, self.len_a = 50, 50, 50 else: self.len_l, self.len_v, self.len_a = 50, 500, 375 if args.dataset_name == 'mosei': if args.need_data_aligned: self.len_l, self.len_v, self.len_a = 50, 50, 50 else: self.len_l, self.len_v, self.len_a = 50, 500, 500 self.orig_d_l, self.orig_d_a, self.orig_d_v = args.feature_dims self.d_l = self.d_a = self.d_v = dst_feature_dims self.num_heads = nheads self.layers = args.nlevels self.attn_dropout = args.attn_dropout self.attn_dropout_a = args.attn_dropout_a self.attn_dropout_v = args.attn_dropout_v self.relu_dropout = args.relu_dropout self.embed_dropout = args.embed_dropout self.res_dropout = args.res_dropout self.output_dropout = args.output_dropout self.text_dropout = args.text_dropout self.attn_mask = args.attn_mask combined_dim_low = self.d_a combined_dim_high = self.d_a combined_dim = (self.d_l + self.d_a + self.d_v ) + self.d_l * 3 output_dim = 1 # 1. Temporal convolutional layers for initial feature self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=args.conv1d_kernel_size_l, padding=0, bias=False) self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=args.conv1d_kernel_size_a, padding=0, bias=False) self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=args.conv1d_kernel_size_v, padding=0, bias=False) # 2. Modality-specific encoder self.encoder_s_l = self.get_network(self_type='l', layers = self.layers) self.encoder_s_v = self.get_network(self_type='v', layers = self.layers) self.encoder_s_a = self.get_network(self_type='a', layers = self.layers) # Modality-shared encoder self.encoder_c = self.get_network(self_type='l', layers = self.layers) # 3. 
        # 3. Decoders to reconstruct the three modalities
        self.decoder_l = nn.Conv1d(self.d_l * 2, self.d_l, kernel_size=1, padding=0, bias=False)
        self.decoder_v = nn.Conv1d(self.d_v * 2, self.d_v, kernel_size=1, padding=0, bias=False)
        self.decoder_a = nn.Conv1d(self.d_a * 2, self.d_a, kernel_size=1, padding=0, bias=False)

        # For computing cosine similarity between the s_x features
        self.proj_cosine_l = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
        self.proj_cosine_v = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
        self.proj_cosine_a = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)

        # For aligning c_l, c_v, c_a
        self.align_c_l = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
        self.align_c_v = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
        self.align_c_a = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)

        self.self_attentions_c_l = self.get_network(self_type='l')
        self.self_attentions_c_v = self.get_network(self_type='v')
        self.self_attentions_c_a = self.get_network(self_type='a')

        self.proj1_c = nn.Linear(self.d_l * 3, self.d_l * 3)
        self.proj2_c = nn.Linear(self.d_l * 3, self.d_l * 3)
        self.out_layer_c = nn.Linear(self.d_l * 3, output_dim)

        # 4. Multimodal crossmodal attentions
        self.trans_l_with_a = self.get_network(self_type='la', layers=self.layers)
        self.trans_l_with_v = self.get_network(self_type='lv', layers=self.layers)
        self.trans_a_with_l = self.get_network(self_type='al')
        self.trans_a_with_v = self.get_network(self_type='av')
        self.trans_v_with_l = self.get_network(self_type='vl')
        self.trans_v_with_a = self.get_network(self_type='va')

        self.trans_l_mem = self.get_network(self_type='l_mem', layers=self.layers)
        self.trans_a_mem = self.get_network(self_type='a_mem', layers=3)
        self.trans_v_mem = self.get_network(self_type='v_mem', layers=3)

        # 5. FC layers for the shared (low-level) features
        self.proj1_l_low = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
        self.proj2_l_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1))
        self.out_layer_l_low = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), output_dim)
        self.proj1_v_low = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
        self.proj2_v_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1))
        self.out_layer_v_low = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), output_dim)
        self.proj1_a_low = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)
        self.proj2_a_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1))
        self.out_layer_a_low = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), output_dim)
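        # The *_low layers above flatten each shared component to
        # (batch, d * seq) before projecting, which is why their Linear sizes
        # depend on the post-convolution sequence length: Conv1d with
        # padding=0 and kernel size k shortens a length-L sequence to L - k + 1.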
        # 6. FC layers for the specific (high-level) features
        self.proj1_l_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.proj2_l_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.out_layer_l_high = nn.Linear(combined_dim_high, output_dim)
        self.proj1_v_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.proj2_v_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.out_layer_v_high = nn.Linear(combined_dim_high, output_dim)
        self.proj1_a_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.proj2_a_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.out_layer_a_high = nn.Linear(combined_dim_high, output_dim)

        # 7. Projectors for fusion
        self.projector_l = nn.Linear(self.d_l, self.d_l)
        self.projector_v = nn.Linear(self.d_v, self.d_v)
        self.projector_a = nn.Linear(self.d_a, self.d_a)
        self.projector_c = nn.Linear(3 * self.d_l, 3 * self.d_l)

        # 8. Final projection
        self.proj1 = nn.Linear(combined_dim, combined_dim)
        self.proj2 = nn.Linear(combined_dim, combined_dim)
        self.out_layer = nn.Linear(combined_dim, output_dim)

    def get_network(self, self_type='l', layers=-1):
        if self_type in ['l', 'al', 'vl']:
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type in ['a', 'la', 'va']:
            embed_dim, attn_dropout = self.d_a, self.attn_dropout_a
        elif self_type in ['v', 'lv', 'av']:
            embed_dim, attn_dropout = self.d_v, self.attn_dropout_v
        elif self_type == 'l_mem':
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type == 'a_mem':
            embed_dim, attn_dropout = self.d_a, self.attn_dropout
        elif self_type == 'v_mem':
            embed_dim, attn_dropout = self.d_v, self.attn_dropout
        else:
            raise ValueError("Unknown network type")
        return TransformerEncoder(embed_dim=embed_dim,
                                  num_heads=self.num_heads,
                                  layers=max(self.layers, layers),
                                  attn_dropout=attn_dropout,
                                  relu_dropout=self.relu_dropout,
                                  res_dropout=self.res_dropout,
                                  embed_dropout=self.embed_dropout,
                                  attn_mask=self.attn_mask)

    def forward(self, text, audio, video):
        # Extraction
        if self.use_bert:
            text = self.text_model(text)
        x_l = F.dropout(text.transpose(1, 2), p=self.text_dropout, training=self.training)
        x_a = audio.transpose(1, 2)
        x_v = video.transpose(1, 2)

        # Project the three modalities to a common dimension if needed.
        proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l)
        proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a)
        proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v)
        # Conv1d outputs (batch, dim, seq); the transformer encoders expect (seq, batch, dim).
        proj_x_l = proj_x_l.permute(2, 0, 1)
        proj_x_v = proj_x_v.permute(2, 0, 1)
        proj_x_a = proj_x_a.permute(2, 0, 1)

        # Disentanglement into modality-specific (s_*) and modality-shared (c_*) parts
        s_l = self.encoder_s_l(proj_x_l)
        s_v = self.encoder_s_v(proj_x_v)
        s_a = self.encoder_s_a(proj_x_a)
        c_l = self.encoder_c(proj_x_l)
        c_v = self.encoder_c(proj_x_v)
        c_a = self.encoder_c(proj_x_a)
        s_l = s_l.permute(1, 2, 0)
        s_v = s_v.permute(1, 2, 0)
        s_a = s_a.permute(1, 2, 0)
        c_l = c_l.permute(1, 2, 0)
        c_v = c_v.permute(1, 2, 0)
        c_a = c_a.permute(1, 2, 0)
        c_list = [c_l, c_v, c_a]

        c_l_sim = self.align_c_l(c_l.contiguous().view(x_l.size(0), -1))
        c_v_sim = self.align_c_v(c_v.contiguous().view(x_l.size(0), -1))
        c_a_sim = self.align_c_a(c_a.contiguous().view(x_l.size(0), -1))

        # Reconstruct each modality from its concatenated specific and shared parts
        recon_l = self.decoder_l(torch.cat([s_l, c_list[0]], dim=1))
        recon_v = self.decoder_v(torch.cat([s_v, c_list[1]], dim=1))
        recon_a = self.decoder_a(torch.cat([s_a, c_list[2]], dim=1))
        recon_l = recon_l.permute(2, 0, 1)
        recon_v = recon_v.permute(2, 0, 1)
        recon_a = recon_a.permute(2, 0, 1)
        s_l_r = self.encoder_s_l(recon_l).permute(1, 2, 0)
        s_v_r = self.encoder_s_v(recon_v).permute(1, 2, 0)
        s_a_r = self.encoder_s_a(recon_a).permute(1, 2, 0)

        # Back to (seq, batch, dim)
        s_l = s_l.permute(2, 0, 1)
        s_v = s_v.permute(2, 0, 1)
        s_a = s_a.permute(2, 0, 1)
        c_l = c_l.permute(2, 0, 1)
        c_v = c_v.permute(2, 0, 1)
        c_a = c_a.permute(2, 0, 1)

        # Enhancement: per-modality heads on the shared (low-level) features
        hs_l_low = c_l.transpose(0, 1).contiguous().view(x_l.size(0), -1)
        repr_l_low = self.proj1_l_low(hs_l_low)
        hs_proj_l_low = self.proj2_l_low(
            F.dropout(F.relu(repr_l_low, inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_l_low += hs_l_low
        logits_l_low = self.out_layer_l_low(hs_proj_l_low)

        hs_v_low = c_v.transpose(0, 1).contiguous().view(x_v.size(0), -1)
        repr_v_low = self.proj1_v_low(hs_v_low)
        hs_proj_v_low = self.proj2_v_low(
            F.dropout(F.relu(repr_v_low, inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_v_low += hs_v_low
        logits_v_low = self.out_layer_v_low(hs_proj_v_low)

        hs_a_low = c_a.transpose(0, 1).contiguous().view(x_a.size(0), -1)
        repr_a_low = self.proj1_a_low(hs_a_low)
        hs_proj_a_low = self.proj2_a_low(
            F.dropout(F.relu(repr_a_low, inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_a_low += hs_a_low
        logits_a_low = self.out_layer_a_low(hs_proj_a_low)

        # Self-attention over the shared features, keeping the last time step
        c_l_att = self.self_attentions_c_l(c_l)
        if type(c_l_att) == tuple:
            c_l_att = c_l_att[0]
        c_l_att = c_l_att[-1]
        c_v_att = self.self_attentions_c_v(c_v)
        if type(c_v_att) == tuple:
            c_v_att = c_v_att[0]
        c_v_att = c_v_att[-1]
        c_a_att = self.self_attentions_c_a(c_a)
        if type(c_a_att) == tuple:
            c_a_att = c_a_att[0]
        c_a_att = c_a_att[-1]

        c_fusion = torch.cat([c_l_att, c_v_att, c_a_att], dim=1)
        c_proj = self.proj2_c(
            F.dropout(F.relu(self.proj1_c(c_fusion), inplace=True), p=self.output_dropout, training=self.training))
        c_proj += c_fusion
        logits_c = self.out_layer_c(c_proj)

        # LFA
        # L --> L
        h_ls = s_l
        h_ls = self.trans_l_mem(h_ls)
        if type(h_ls) == tuple:
            h_ls = h_ls[0]
        last_h_l = last_hs = h_ls[-1]

        # A --> L
        h_l_with_as = self.trans_l_with_a(s_l, s_a, s_a)
        h_as = h_l_with_as
        h_as = self.trans_a_mem(h_as)
        if type(h_as) == tuple:
            h_as = h_as[0]
        last_h_a = last_hs = h_as[-1]

        # V --> L
        h_l_with_vs = self.trans_l_with_v(s_l, s_v, s_v)
        h_vs = h_l_with_vs
        h_vs = self.trans_v_mem(h_vs)
        if type(h_vs) == tuple:
            h_vs = h_vs[0]
        last_h_v = last_hs = h_vs[-1]

        hs_proj_l_high = self.proj2_l_high(
            F.dropout(F.relu(self.proj1_l_high(last_h_l), inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_l_high += last_h_l
        logits_l_high = self.out_layer_l_high(hs_proj_l_high)

        hs_proj_v_high = self.proj2_v_high(
            F.dropout(F.relu(self.proj1_v_high(last_h_v), inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_v_high += last_h_v
        logits_v_high = self.out_layer_v_high(hs_proj_v_high)

        hs_proj_a_high = self.proj2_a_high(
            F.dropout(F.relu(self.proj1_a_high(last_h_a), inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_a_high += last_h_a
        logits_a_high = self.out_layer_a_high(hs_proj_a_high)

        # Fusion
        last_h_l = torch.sigmoid(self.projector_l(hs_proj_l_high))
        last_h_v = torch.sigmoid(self.projector_v(hs_proj_v_high))
        last_h_a = torch.sigmoid(self.projector_a(hs_proj_a_high))
        c_fusion = torch.sigmoid(self.projector_c(c_fusion))
        last_hs = torch.cat([last_h_l, last_h_v, last_h_a, c_fusion], dim=1)

        # Prediction
        last_hs_proj = self.proj2(
            F.dropout(F.relu(self.proj1(last_hs), inplace=True), p=self.output_dropout, training=self.training))
        last_hs_proj += last_hs
        output = self.out_layer(last_hs_proj)
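        # Everything below besides 'output_logit' is presumably consumed by
        # auxiliary objectives in the training loop (reconstruction,
        # shared/specific similarity, and per-modality logits).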
        res = {
            'origin_l': proj_x_l,
            'origin_v': proj_x_v,
            'origin_a': proj_x_a,
            's_l': s_l,
            's_v': s_v,
            's_a': s_a,
            'c_l': c_l,
            'c_v': c_v,
            'c_a': c_a,
            's_l_r': s_l_r,
            's_v_r': s_v_r,
            's_a_r': s_a_r,
            'recon_l': recon_l,
            'recon_v': recon_v,
            'recon_a': recon_a,
            'c_l_sim': c_l_sim,
            'c_v_sim': c_v_sim,
            'c_a_sim': c_a_sim,
            'logits_l_hetero': logits_l_high,
            'logits_v_hetero': logits_v_high,
            'logits_a_hetero': logits_a_high,
            'logits_c': logits_c,
            'output_logit': output
        }
        return res
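

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original repo). It assumes
# pre-extracted features (use_bert=False) and invents plausible
# hyperparameter values; the relative imports above mean this module normally
# runs inside its package, so treat this purely as shape documentation.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        use_bert=False,                   # feed pre-extracted text features directly
        dataset_name='mosi',
        need_data_aligned=True,           # all sequence lengths become 50
        feature_dims=(768, 5, 20),        # (text, audio, video) input dims (assumed)
        dst_feature_dim_nheads=(32, 8),   # common dim 32, 8 attention heads (assumed)
        nlevels=2,
        conv1d_kernel_size_l=1, conv1d_kernel_size_a=1, conv1d_kernel_size_v=1,
        attn_dropout=0.1, attn_dropout_a=0.1, attn_dropout_v=0.1,
        relu_dropout=0.1, embed_dropout=0.1, res_dropout=0.1,
        output_dropout=0.1, text_dropout=0.1, attn_mask=True,
    )
    model = DLF(args)

    batch = 4
    text = torch.randn(batch, 50, 768)    # (batch, seq_len, feat_dim)
    audio = torch.randn(batch, 50, 5)
    video = torch.randn(batch, 50, 20)
    res = model(text, audio, video)
    print(res['output_logit'].shape)      # expected: torch.Size([4, 1])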