# DLF / data_loader.py
# Author: peter-wang321
# Commit 9157432: "Initial DLF commit"
# (Header copied from the GitHub file view: raw | history blame | 6.03 kB)
import logging
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
# Module-level logger, registered under the 'MMSA' name.
logger = logging.getLogger(name='MMSA')
# Only the loader factory is part of this module's public API.
__all__ = ['MMDataLoader']
class MMDataset(Dataset):
    """Text / audio / vision dataset for one split of MOSI, MOSEI or SIMS.

    The split ``mode`` is read from the pickle at ``args['featurePath']``.
    Non-empty ``args['feature_T']`` / ``['feature_A']`` / ``['feature_V']``
    paths point at extra pickles that replace the corresponding modality; in
    that case ``args['feature_dims']`` is updated in place.

    Args:
        args: dict-like configuration. Keys read here: ``dataset_name``,
            ``featurePath``, ``need_data_aligned``, ``feature_dims`` (when an
            override pickle is used), and optionally ``feature_T``,
            ``feature_A``, ``feature_V``, ``use_bert``, ``need_normalized``
            and ``seq_lens`` (the latter only by the truncation helper).
        mode: split key inside the pickle(s), e.g. ``'train'``, ``'valid'``
            or ``'test'``.

    Raises:
        KeyError: if ``args['dataset_name']`` is not a supported dataset.
    """

    def __init__(self, args, mode='train'):
        self.mode = mode
        self.args = args
        DATASET_MAP = {
            'mosi': self.__init_mosi,
            'mosei': self.__init_mosei,
            'sims': self.__init_sims,
        }
        dataset_name = args['dataset_name']
        if dataset_name not in DATASET_MAP:
            raise KeyError(
                f"Unsupported dataset_name {dataset_name!r}; "
                f"expected one of {sorted(DATASET_MAP)}"
            )
        DATASET_MAP[dataset_name]()

    def __init_mosi(self):
        """Load split ``self.mode`` from the feature pickle(s).

        MOSEI and SIMS reuse this loader, i.e. they are expected to share the
        same pickle layout.
        """
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # these paths at trusted feature files.
        with open(self.args['featurePath'], 'rb') as f:
            data = pickle.load(f)
        split = data[self.mode]
        use_bert = bool(self.args.get('use_bert'))

        self.text = split['text_bert' if use_bert else 'text'].astype(np.float32)
        self.vision = split['vision'].astype(np.float32)
        self.audio = split['audio'].astype(np.float32)
        self.raw_text = split['raw_text']
        self.ids = split['id']

        # Optional per-modality overrides. An empty or missing path means
        # "keep the features from featurePath".
        data_A = data_V = None
        if self.args.get('feature_T'):
            with open(self.args['feature_T'], 'rb') as f:
                data_T = pickle.load(f)
            if use_bert:
                self.text = data_T[self.mode]['text_bert'].astype(np.float32)
                # presumably the BERT hidden size — TODO confirm
                self.args['feature_dims'][0] = 768
            else:
                self.text = data_T[self.mode]['text'].astype(np.float32)
                self.args['feature_dims'][0] = self.text.shape[2]
        if self.args.get('feature_A'):
            with open(self.args['feature_A'], 'rb') as f:
                data_A = pickle.load(f)
            self.audio = data_A[self.mode]['audio'].astype(np.float32)
            self.args['feature_dims'][1] = self.audio.shape[2]
        if self.args.get('feature_V'):
            with open(self.args['feature_V'], 'rb') as f:
                data_V = pickle.load(f)
            self.vision = data_V[self.mode]['vision'].astype(np.float32)
            self.args['feature_dims'][2] = self.vision.shape[2]

        self.labels = {
            'M': np.array(split['regression_labels']).astype(np.float32)
        }
        logger.info(f"{self.mode} samples: {self.labels['M'].shape}")

        if not self.args['need_data_aligned']:
            # Lengths come from the same pickle as the features they describe.
            audio_src = data_A[self.mode] if data_A is not None else split
            vision_src = data_V[self.mode] if data_V is not None else split
            self.audio_lengths = list(audio_src['audio_lengths'])
            self.vision_lengths = list(vision_src['vision_lengths'])

        # Replace -inf entries in the audio features with 0.
        self.audio[self.audio == -np.inf] = 0

        if self.args.get('need_normalized'):
            self.__normalize()

    def __init_mosei(self):
        """Load MOSEI using the MOSI loader."""
        return self.__init_mosi()

    def __init_sims(self):
        """Load SIMS using the MOSI loader."""
        return self.__init_mosi()

    def __truncate(self):
        """Cut text/audio/vision along axis 1 to ``args['seq_lens']``.

        ``seq_lens`` is ordered (text, audio, vision). For each instance a
        window of ``length`` steps is kept. The window starts at the first
        step that is not all zeros, but never later than
        ``seq_len - length``, so every window has exactly ``length`` steps
        whenever ``length < seq_len``.
        """
        def do_truncate(modal_features, length):
            seq_len = modal_features.shape[1]
            if length == seq_len:
                return modal_features
            truncated_feature = []
            for instance in modal_features:
                for index in range(seq_len):
                    # An all-zero step is treated as padding.
                    is_padding = not instance[index].any()
                    # Start the window at the first real step, or as soon as
                    # only `length` steps remain.
                    if not is_padding or index + length >= seq_len:
                        truncated_feature.append(instance[index:index + length])
                        break
            return np.array(truncated_feature)

        text_length, audio_length, video_length = self.args['seq_lens']
        self.vision = do_truncate(self.vision, video_length)
        self.text = do_truncate(self.text, text_length)
        self.audio = do_truncate(self.audio, audio_length)

    def __normalize(self):
        """Average vision and audio over axis 1, keeping that axis at size 1.

        Any NaN entries in the averaged arrays are replaced with 0.
        """
        self.vision = np.mean(self.vision, axis=1, keepdims=True)
        self.audio = np.mean(self.audio, axis=1, keepdims=True)
        self.vision[np.isnan(self.vision)] = 0
        self.audio[np.isnan(self.audio)] = 0

    def __len__(self):
        return len(self.labels['M'])

    def get_seq_len(self):
        """Return the (text, audio, vision) sequence lengths."""
        if self.args.get('use_bert'):
            # presumably text_bert is stored as (N, channels, seq_len) — TODO confirm
            return (self.text.shape[2], self.audio.shape[1], self.vision.shape[1])
        return (self.text.shape[1], self.audio.shape[1], self.vision.shape[1])

    def get_feature_dim(self):
        """Return the last-axis size of the text, audio and vision arrays."""
        return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of tensors plus metadata.

        The keys ``audio_lengths``/``vision_lengths`` are present only when
        ``args['need_data_aligned']`` is falsy.
        """
        sample = {
            'raw_text': self.raw_text[index],
            'text': torch.Tensor(self.text[index]),
            'audio': torch.Tensor(self.audio[index]),
            'vision': torch.Tensor(self.vision[index]),
            'index': index,
            'id': self.ids[index],
            'labels': {k: torch.Tensor(v[index].reshape(-1)) for k, v in self.labels.items()}
        }
        if not self.args['need_data_aligned']:
            sample['audio_lengths'] = self.audio_lengths[index]
            sample['vision_lengths'] = self.vision_lengths[index]
        return sample
def MMDataLoader(args, num_workers):
    """Build train/valid/test DataLoaders over ``MMDataset``.

    If ``args`` already contains a ``seq_lens`` entry, it is overwritten with
    the sequence lengths of the training split.

    Args:
        args: dict-like configuration; ``batch_size`` is read here, and the
            rest is passed through to ``MMDataset``.
        num_workers: number of worker processes for each DataLoader.

    Returns:
        dict mapping ``'train'``, ``'valid'`` and ``'test'`` to DataLoaders.
        Only the training loader shuffles. The evaluation loaders iterate in
        dataset order, so their outputs are reproducible across runs.
    """
    datasets = {
        mode: MMDataset(args, mode=mode) for mode in ('train', 'valid', 'test')
    }

    if 'seq_lens' in args:
        args['seq_lens'] = datasets['train'].get_seq_len()

    return {
        mode: DataLoader(
            dataset,
            batch_size=args['batch_size'],
            num_workers=num_workers,
            shuffle=(mode == 'train'),
        )
        for mode, dataset in datasets.items()
    }