import logging
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

__all__ = ['MMDataLoader']

logger = logging.getLogger('MMSA')


class MMDataset(Dataset):
    def __init__(self, args, mode='train'):
        self.mode = mode
        self.args = args
        # Dispatch to the dataset-specific initializer; 'sims' reuses the
        # same loading logic via __init_sims defined below.
        DATASET_MAP = {
            'mosi': self.__init_mosi,
            'mosei': self.__init_mosei,
            'sims': self.__init_sims,
        }
        DATASET_MAP[args['dataset_name']]()

    def __init_mosi(self):
        # Load the pre-extracted multimodal features for the current split.
        with open(self.args['featurePath'], 'rb') as f:
            data = pickle.load(f)

        if 'use_bert' in self.args and self.args['use_bert']:
            self.text = data[self.mode]['text_bert'].astype(np.float32)
        else:
            self.text = data[self.mode]['text'].astype(np.float32)
        self.vision = data[self.mode]['vision'].astype(np.float32)
        self.audio = data[self.mode]['audio'].astype(np.float32)
        self.raw_text = data[self.mode]['raw_text']
        self.ids = data[self.mode]['id']

        # Optionally override single modalities with custom feature files and
        # update the corresponding feature dimensions.
        if self.args['feature_T'] != "":
            with open(self.args['feature_T'], 'rb') as f:
                data_T = pickle.load(f)
            if 'use_bert' in self.args and self.args['use_bert']:
                self.text = data_T[self.mode]['text_bert'].astype(np.float32)
                self.args['feature_dims'][0] = 768
            else:
                self.text = data_T[self.mode]['text'].astype(np.float32)
                self.args['feature_dims'][0] = self.text.shape[2]
        if self.args['feature_A'] != "":
            with open(self.args['feature_A'], 'rb') as f:
                data_A = pickle.load(f)
            self.audio = data_A[self.mode]['audio'].astype(np.float32)
            self.args['feature_dims'][1] = self.audio.shape[2]
        if self.args['feature_V'] != "":
            with open(self.args['feature_V'], 'rb') as f:
                data_V = pickle.load(f)
            self.vision = data_V[self.mode]['vision'].astype(np.float32)
            self.args['feature_dims'][2] = self.vision.shape[2]

        self.labels = {
            'M': np.array(data[self.mode]['regression_labels']).astype(np.float32)
        }

        logger.info(f"{self.mode} samples: {self.labels['M'].shape}")

        # For unaligned data, keep the true sequence length of each sample.
        if not self.args['need_data_aligned']:
            if self.args['feature_A'] != "":
                self.audio_lengths = list(data_A[self.mode]['audio_lengths'])
            else:
                self.audio_lengths = data[self.mode]['audio_lengths']
            if self.args['feature_V'] != "":
                self.vision_lengths = list(data_V[self.mode]['vision_lengths'])
            else:
                self.vision_lengths = data[self.mode]['vision_lengths']
        self.audio[self.audio == -np.inf] = 0

        if 'need_normalized' in self.args and self.args['need_normalized']:
            self.__normalize()

    def __init_mosei(self):
        return self.__init_mosi()

    def __init_sims(self):
        return self.__init_mosi()

    def __truncate(self):
        # Truncate each modality to the length given in args['seq_lens'],
        # starting from the first non-padding time step of each instance.
        def do_truncate(modal_features, length):
            if length == modal_features.shape[1]:
                return modal_features
            truncated_feature = []
            padding = np.zeros(modal_features.shape[2])
            for instance in modal_features:
                for index in range(modal_features.shape[1]):
                    if (instance[index] == padding).all():
                        if index + length >= modal_features.shape[1]:
                            truncated_feature.append(instance[index:index+length])
                            break
                    else:
                        truncated_feature.append(instance[index:index+length])
                        break
            truncated_feature = np.array(truncated_feature)
            return truncated_feature

        text_length, audio_length, video_length = self.args['seq_lens']
        self.vision = do_truncate(self.vision, video_length)
        self.text = do_truncate(self.text, text_length)
        self.audio = do_truncate(self.audio, audio_length)

    def __normalize(self):
        # Collapse the time dimension: represent each sample by its mean
        # feature vector over time (keepdims preserves a length-1 time axis).
        self.vision = np.mean(self.vision, axis=1, keepdims=True)
        self.audio = np.mean(self.audio, axis=1, keepdims=True)

        # Replace NaNs produced by the averaging with zeros.
        self.vision[np.isnan(self.vision)] = 0
        self.audio[np.isnan(self.audio)] = 0

    def __len__(self):
        return len(self.labels['M'])

    def get_seq_len(self):
        if 'use_bert' in self.args and self.args['use_bert']:
            # BERT text features carry the sequence length on axis 2.
            return (self.text.shape[2], self.audio.shape[1], self.vision.shape[1])
        else:
            return (self.text.shape[1], self.audio.shape[1], self.vision.shape[1])

    def get_feature_dim(self):
        return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]

    def __getitem__(self, index):
        sample = {
            'raw_text': self.raw_text[index],
            'text': torch.Tensor(self.text[index]),
            'audio': torch.Tensor(self.audio[index]),
            'vision': torch.Tensor(self.vision[index]),
            'index': index,
            'id': self.ids[index],
            'labels': {k: torch.Tensor(v[index].reshape(-1)) for k, v in self.labels.items()}
        }
        if not self.args['need_data_aligned']:
            sample['audio_lengths'] = self.audio_lengths[index]
            sample['vision_lengths'] = self.vision_lengths[index]
        return sample


def MMDataLoader(args, num_workers):
    datasets = {
        'train': MMDataset(args, mode='train'),
        'valid': MMDataset(args, mode='valid'),
        'test': MMDataset(args, mode='test')
    }

    if 'seq_lens' in args:
        args['seq_lens'] = datasets['train'].get_seq_len()

    dataLoader = {
        ds: DataLoader(datasets[ds],
                       batch_size=args['batch_size'],
                       num_workers=num_workers,
                       shuffle=True)
        for ds in datasets.keys()
    }

    return dataLoader
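

# Minimal usage sketch (not part of the library API): the keys below mirror
# what this module reads from `args`; the path, feature dimensions, and
# batch size are placeholders and depend on your feature files.
if __name__ == '__main__':
    example_args = {
        'dataset_name': 'mosi',
        'featurePath': '/path/to/MOSI/aligned_50.pkl',  # hypothetical path
        'feature_T': '', 'feature_A': '', 'feature_V': '',
        'feature_dims': [768, 5, 20],   # placeholder text/audio/vision dims
        'use_bert': True,
        'need_data_aligned': True,
        'need_normalized': False,
        'seq_lens': None,               # overwritten by get_seq_len() below
        'batch_size': 32,
    }
    loaders = MMDataLoader(example_args, num_workers=0)
    batch = next(iter(loaders['train']))
    print(batch['text'].shape, batch['audio'].shape, batch['vision'].shape)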