"""Dataset and DataLoader helpers for pickled multimodal (text/audio/vision) features."""
import logging
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

__all__ = ['MMDataLoader']

logger = logging.getLogger('MMSA')


def _load_pickle(path):
    """Return the object stored in the pickle file at *path*.

    NOTE(security): ``pickle.load`` can execute arbitrary code while
    deserializing. Only point the feature paths at files you trust.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


class MMDataset(Dataset):
    """One split (``train`` / ``valid`` / ``test``) of a multimodal dataset.

    Features are read from the pickle at ``args['featurePath']``; the text,
    audio and vision features can each be overridden by the pickles named in
    ``args['feature_T']``, ``args['feature_A']`` and ``args['feature_V']``
    (an empty string means "no override").

    Feature arrays are indexed as ``shape[1]`` / ``shape[2]`` throughout, so
    they are assumed to be ``(num_samples, seq_len, feature_dim)``
    -- TODO confirm against the feature files.

    Args:
        args: Mapping of configuration values. Keys read here:
            ``dataset_name``, ``featurePath``, ``feature_T``, ``feature_A``,
            ``feature_V``, ``feature_dims``, ``need_data_aligned`` and,
            optionally, ``use_bert``, ``need_normalized``, ``seq_lens``.
            ``args['feature_dims']`` is updated in place when a per-modality
            override file is used.
        mode: Name of the split to load; used as the top-level key of the
            pickled feature dictionaries.

    Raises:
        KeyError: If ``args['dataset_name']`` is not a supported dataset or a
            required key is missing from ``args`` or the feature files.
    """

    def __init__(self, args, mode='train'):
        self.mode = mode
        self.args = args
        dataset_map = {
            'mosi': self.__init_mosi,
            'mosei': self.__init_mosei,
            # The initializer below already existed but was never registered,
            # which made it unreachable.
            'sims': self.__init_sims,
        }
        dataset_map[args['dataset_name']]()

    def __use_bert(self):
        """Return whether the optional ``use_bert`` flag is set and truthy."""
        return bool('use_bert' in self.args and self.args['use_bert'])

    def __init_mosi(self):
        """Load features, labels and (if unaligned) sequence lengths."""
        args = self.args
        use_bert = self.__use_bert()
        text_key = 'text_bert' if use_bert else 'text'

        data = _load_pickle(args['featurePath'])
        split = data[self.mode]
        self.text = split[text_key].astype(np.float32)
        self.vision = split['vision'].astype(np.float32)
        self.audio = split['audio'].astype(np.float32)
        self.raw_text = split['raw_text']
        self.ids = split['id']

        # Optional per-modality overrides; each one also records the new
        # feature dimension in args['feature_dims'] (order: text, audio, vision).
        data_A = None
        data_V = None
        if args['feature_T'] != "":
            data_T = _load_pickle(args['feature_T'])
            self.text = data_T[self.mode][text_key].astype(np.float32)
            args['feature_dims'][0] = 768 if use_bert else self.text.shape[2]
        if args['feature_A'] != "":
            data_A = _load_pickle(args['feature_A'])
            self.audio = data_A[self.mode]['audio'].astype(np.float32)
            args['feature_dims'][1] = self.audio.shape[2]
        if args['feature_V'] != "":
            data_V = _load_pickle(args['feature_V'])
            self.vision = data_V[self.mode]['vision'].astype(np.float32)
            args['feature_dims'][2] = self.vision.shape[2]

        self.labels = {
            'M': np.array(split['regression_labels']).astype(np.float32)
        }
        logger.info(f"{self.mode} samples: {self.labels['M'].shape}")

        if not args['need_data_aligned']:
            # Lengths must come from the same file as the features they describe.
            if data_A is not None:
                self.audio_lengths = list(data_A[self.mode]['audio_lengths'])
            else:
                self.audio_lengths = split['audio_lengths']
            if data_V is not None:
                self.vision_lengths = list(data_V[self.mode]['vision_lengths'])
            else:
                self.vision_lengths = split['vision_lengths']

        # Replace -inf entries in the audio features with 0.
        self.audio[self.audio == -np.inf] = 0

        if 'need_normalized' in args and args['need_normalized']:
            self.__normalize()

    def __init_mosei(self):
        """Load a MOSEI split; delegates to the MOSI loader."""
        return self.__init_mosi()

    def __init_sims(self):
        """Load a SIMS split; delegates to the MOSI loader."""
        return self.__init_mosi()

    def __truncate(self):
        """Cut every modality down to the lengths in ``args['seq_lens']``.

        ``args['seq_lens']`` is unpacked as ``(text, audio, vision)``. For
        each sample, leading all-zero rows are skipped and ``length`` rows
        are kept from the first non-zero row; if the rows are zero so far
        into the sequence that only ``length`` rows remain, those are kept.

        Requested lengths are expected not to exceed the stored sequence
        length; a longer request cannot be satisfied by slicing.

        NOTE(review): the source this was reformatted from had lost its
        indentation; the ``else`` is read as pairing with the padding check.
        This method is not called anywhere in this file.
        """
        def do_truncate(modal_features, length):
            total = modal_features.shape[1]
            if length == total:
                return modal_features
            truncated = []
            for instance in modal_features:
                for index in range(total):
                    is_padding = (instance[index] == 0).all()
                    if is_padding and index + length < total:
                        # Still in the leading padding and there is room to
                        # move the window further right.
                        continue
                    # Slice by the requested length (this used to be a
                    # hard-coded 20, which ignored ``length``).
                    truncated.append(instance[index:index + length])
                    break
            return np.array(truncated)

        text_length, audio_length, video_length = self.args['seq_lens']
        self.vision = do_truncate(self.vision, video_length)
        self.text = do_truncate(self.text, text_length)
        self.audio = do_truncate(self.audio, audio_length)

    def __normalize(self):
        """Average vision and audio over axis 1 (kept as size 1) and zero NaNs."""
        self.vision = np.mean(self.vision, axis=1, keepdims=True)
        self.audio = np.mean(self.audio, axis=1, keepdims=True)
        self.vision[np.isnan(self.vision)] = 0
        self.audio[np.isnan(self.audio)] = 0

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.labels['M'])

    def get_seq_len(self):
        """Return ``(text, audio, vision)`` sequence lengths.

        With ``use_bert`` the text length is read from axis 2 instead of
        axis 1 -- presumably because BERT inputs are stored as
        ``(num_samples, 3, seq_len)``; verify against the feature files.
        """
        text_axis = 2 if self.__use_bert() else 1
        return (self.text.shape[text_axis], self.audio.shape[1], self.vision.shape[1])

    def get_feature_dim(self):
        """Return the size of axis 2 of the text, audio and vision arrays."""
        return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]

    def __getitem__(self, index):
        """Return the sample at *index* as a dict of tensors and metadata.

        ``audio_lengths`` and ``vision_lengths`` are included only when
        ``args['need_data_aligned']`` is falsy.
        """
        sample = {
            'raw_text': self.raw_text[index],
            'text': torch.Tensor(self.text[index]),
            'audio': torch.Tensor(self.audio[index]),
            'vision': torch.Tensor(self.vision[index]),
            'index': index,
            'id': self.ids[index],
            'labels': {k: torch.Tensor(v[index].reshape(-1)) for k, v in self.labels.items()},
        }
        if not self.args['need_data_aligned']:
            sample['audio_lengths'] = self.audio_lengths[index]
            sample['vision_lengths'] = self.vision_lengths[index]
        return sample


def MMDataLoader(args, num_workers):
    """Build one ``DataLoader`` per split.

    Args:
        args: Configuration mapping passed to :class:`MMDataset`;
            ``args['batch_size']`` sets the batch size. If ``'seq_lens'`` is
            present it is overwritten with the training split's sequence
            lengths.
        num_workers: Forwarded to each ``DataLoader``.

    Returns:
        Dict mapping ``'train'``, ``'valid'`` and ``'test'`` to a
        ``DataLoader`` over the corresponding split.
    """
    datasets = {
        'train': MMDataset(args, mode='train'),
        'valid': MMDataset(args, mode='valid'),
        'test': MMDataset(args, mode='test'),
    }

    if 'seq_lens' in args:
        args['seq_lens'] = datasets['train'].get_seq_len()

    # NOTE(review): every split is shuffled, including valid/test -- confirm
    # that downstream evaluation does not rely on sample order.
    dataLoader = {
        ds: DataLoader(datasets[ds],
                       batch_size=args['batch_size'],
                       num_workers=num_workers,
                       shuffle=True)
        for ds in datasets
    }
    return dataLoader