import logging
import os
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from ..utils import MetricsTop, dict_to_str
from .HingeLoss import HingeLoss
logger = logging.getLogger('MMSA')
class MSE(nn.Module):
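    """Mean squared error averaged over every element of the input tensors."""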
def __init__(self):
        super().__init__()
def forward(self, pred, real):
        diffs = real - pred
        n = torch.numel(diffs)
mse = torch.sum(diffs.pow(2)) / n
return mse
class DLF:
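    """Training and evaluation wrapper for the DLF multimodal sentiment model."""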
def __init__(self, args):
self.args = args
self.criterion = nn.L1Loss()
self.cosine = nn.CosineEmbeddingLoss()
self.metrics = MetricsTop(args.train_mode).getMetics(args.dataset_name)
self.MSE = MSE()
self.sim_loss = HingeLoss()
def do_train(self, model, dataloader, return_epoch_results=False):
# 0: DLF model
params = model[0].parameters()
optimizer = optim.Adam(params, lr=self.args.learning_rate)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, verbose=True, patience=self.args.patience)
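        # halve the learning rate once the validation loss has plateaued for `patience` epochs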
epochs, best_epoch = 0, 0
if return_epoch_results:
epoch_results = {
'train': [],
'valid': [],
'test': []
}
min_or_max = 'min' if self.args.KeyEval in ['Loss'] else 'max'
best_valid = 1e8 if min_or_max == 'min' else 0
        model = [model[0]]  # keep only the DLF network
while True:
epochs += 1
y_pred, y_true = [], []
for mod in model:
mod.train()
train_loss = 0.0
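            # gradients are accumulated over `update_epochs` mini-batches before each optimizer step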
left_epochs = self.args.update_epochs
with tqdm(dataloader['train']) as td:
for batch_data in td:
if left_epochs == self.args.update_epochs:
optimizer.zero_grad()
left_epochs -= 1
vision = batch_data['vision'].to(self.args.device)
audio = batch_data['audio'].to(self.args.device)
text = batch_data['text'].to(self.args.device)
labels = batch_data['labels']['M'].to(self.args.device)
labels = labels.view(-1, 1)
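                    # a single forward pass returns the fused logit plus all auxiliary outputs used by the losses below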
output = model[0](text, audio, vision)
# task loss
loss_task_all = self.criterion(output['output_logit'], labels)
loss_task_l_hetero = self.criterion(output['logits_l_hetero'], labels)
loss_task_v_hetero = self.criterion(output['logits_v_hetero'], labels)
loss_task_a_hetero = self.criterion(output['logits_a_hetero'], labels)
loss_task_c = self.criterion(output['logits_c'], labels)
                    # total MSA task loss L_msa (language-focused: the text branch is weighted 3x)
                    loss_task = loss_task_all + loss_task_c + 3 * loss_task_l_hetero + loss_task_v_hetero + loss_task_a_hetero
# reconstruction loss L_r
loss_recon_l = self.MSE(output['recon_l'], output['origin_l'])
loss_recon_v = self.MSE(output['recon_v'], output['origin_v'])
loss_recon_a = self.MSE(output['recon_a'], output['origin_a'])
loss_recon = loss_recon_l + loss_recon_v + loss_recon_a
# specific loss L_s
loss_sl_slr = self.MSE(output['s_l'].permute(1, 2, 0), output['s_l_r'])
loss_sv_slv = self.MSE(output['s_v'].permute(1, 2, 0), output['s_v_r'])
loss_sa_sla = self.MSE(output['s_a'].permute(1, 2, 0), output['s_a_r'])
loss_s_sr = loss_sl_slr + loss_sv_slv + loss_sa_sla
                    # orthogonality loss L_o
                    if self.args.dataset_name == 'mosi':
                        num = 50
                    elif self.args.dataset_name == 'mosei':
                        num = 10
                    else:
                        raise ValueError(f"Unsupported dataset: {self.args.dataset_name}")
                    # CosineEmbeddingLoss with target -1 pushes each specific feature away from
                    # its common counterpart (features flattened into rows of length `num`)
                    ort_target = torch.tensor([-1], device=self.args.device)
                    cosine_similarity_s_c_l = self.cosine(output['s_l'].reshape(-1, num), output['c_l'].reshape(-1, num), ort_target)
                    cosine_similarity_s_c_v = self.cosine(output['s_v'].reshape(-1, num), output['c_v'].reshape(-1, num), ort_target)
                    cosine_similarity_s_c_a = self.cosine(output['s_a'].reshape(-1, num), output['c_a'].reshape(-1, num), ort_target)
loss_ort = cosine_similarity_s_c_l + cosine_similarity_s_c_v + cosine_similarity_s_c_a
# triplet margin loss L_m
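                    # stack the three per-modality common features for every sample, repeating
                    # the sample's label for each, so the hinge loss can contrast them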
c_l, c_v, c_a = output['c_l_sim'], output['c_v_sim'], output['c_a_sim']
ids, feats = [], []
for i in range(labels.size(0)):
feats.append(c_l[i].view(1, -1))
feats.append(c_v[i].view(1, -1))
feats.append(c_a[i].view(1, -1))
ids.append(labels[i].view(1, -1))
ids.append(labels[i].view(1, -1))
ids.append(labels[i].view(1, -1))
feats = torch.cat(feats, dim=0)
ids = torch.cat(ids, dim=0)
loss_sim = self.sim_loss(ids, feats)
                    # overall loss L_DLF
                    combined_loss = loss_task + (loss_s_sr + loss_recon + (loss_sim + loss_ort) * 0.1) * 0.1
combined_loss.backward()
if self.args.grad_clip != -1.0:
params = list(model[0].parameters())
nn.utils.clip_grad_value_(params, self.args.grad_clip)
train_loss += combined_loss.item()
y_pred.append(output['output_logit'].cpu())
y_true.append(labels.cpu())
if not left_epochs:
optimizer.step()
left_epochs = self.args.update_epochs
                if not left_epochs:
                    # flush any gradients left over from a partial accumulation window
                    optimizer.step()
train_loss = train_loss / len(dataloader['train'])
pred, true = torch.cat(y_pred), torch.cat(y_true)
train_results = self.metrics(pred, true)
logger.info(
f">> Epoch: {epochs} "
f"TRAIN -({self.args.model_name}) [{epochs - best_epoch}/{epochs}/{self.args.cur_seed}] "
f">> total_loss: {round(train_loss, 4)} "
f"{dict_to_str(train_results)}"
)
# validation
val_results = self.do_test(model[0], dataloader['valid'], mode="VAL")
test_results = self.do_test(model[0], dataloader['test'], mode="TEST")
cur_valid = val_results[self.args.KeyEval]
scheduler.step(val_results['Loss'])
            # save a checkpoint for every epoch
            os.makedirs('./pt', exist_ok=True)
            torch.save(model[0].state_dict(), f'./pt/{self.args.dataset_name}_{epochs}.pth')
# save best model
isBetter = cur_valid <= (best_valid - 1e-6) if min_or_max == 'min' else cur_valid >= (best_valid + 1e-6)
if isBetter:
best_valid, best_epoch = cur_valid, epochs
# save model
                model_save_path = f'./pt/DLF{self.args.dataset_name}.pth'
torch.save(model[0].state_dict(), model_save_path)
if return_epoch_results:
train_results["Loss"] = train_loss
epoch_results['train'].append(train_results)
epoch_results['valid'].append(val_results)
                test_results = self.do_test(model[0], dataloader['test'], mode="TEST")
                epoch_results['test'].append(test_results)
# early stop
if epochs - best_epoch >= self.args.early_stop:
return epoch_results if return_epoch_results else None
def do_test(self, model, dataloader, mode="VAL", return_sample_results=False):
model.eval()
y_pred, y_true = [], []
eval_loss = 0.0
if return_sample_results:
ids, sample_results = [], []
all_labels = []
features = {
"Feature_t": [],
"Feature_a": [],
"Feature_v": [],
"Feature_f": [],
}
with torch.no_grad():
with tqdm(dataloader) as td:
for batch_data in td:
vision = batch_data['vision'].to(self.args.device)
audio = batch_data['audio'].to(self.args.device)
text = batch_data['text'].to(self.args.device)
labels = batch_data['labels']['M'].to(self.args.device)
labels = labels.view(-1, 1)
output = model(text, audio, vision)
loss = self.criterion(output['output_logit'], labels)
eval_loss += loss.item()
y_pred.append(output['output_logit'].cpu())
y_true.append(labels.cpu())
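                    if return_sample_results:
                        # Hypothetical gap-fill sketch: the collection lists above are never
                        # populated in the original loop, so this assumes each batch carries
                        # an 'id' field and that the model's output dict exposes the
                        # 'Feature_*' tensors named above -- adjust the keys to match what
                        # DLF's forward() actually returns.
                        ids.extend(batch_data['id'])
                        all_labels.extend(labels.detach().cpu().tolist())
                        sample_results.extend(output['output_logit'].detach().cpu().numpy().squeeze(-1).tolist())
                        for k in features.keys():
                            if k in output:
                                features[k].append(output[k].detach().cpu().numpy())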
eval_loss = eval_loss / len(dataloader)
pred, true = torch.cat(y_pred), torch.cat(y_true)
eval_results = self.metrics(pred, true)
eval_results["Loss"] = round(eval_loss, 4)
logger.info(f"{mode}-({self.args.model_name}) >> {dict_to_str(eval_results)}")
if return_sample_results:
eval_results["Ids"] = ids
eval_results["SResults"] = sample_results
            for k in features.keys():
                # guard against feature keys the model did not return
                features[k] = np.concatenate(features[k], axis=0) if features[k] else np.empty((0,))
eval_results['Features'] = features
eval_results['Labels'] = all_labels
return eval_results
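
# Minimal usage sketch (hypothetical driver; `args` is the MMSA-style namespace
# this trainer already reads: device, learning_rate, update_epochs, patience,
# early_stop, KeyEval, dataset_name, ...):
#
#     trainer = DLF(args)
#     trainer.do_train([dlf_model], dataloaders)  # dataloaders: {'train': ..., 'valid': ..., 'test': ...}
#     results = trainer.do_test(dlf_model, dataloaders['test'], mode="TEST")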