peter-wang321 committed on
Commit 9157432 · 1 Parent(s): ab35d3d

Initial DLF commit

.gitignore ADDED
@@ -0,0 +1,3 @@
+ log/*.log
+ pt/*.pth
+ __pycache__/
config.py ADDED
@@ -0,0 +1,40 @@
+ import json
+ import os
+ from pathlib import Path
+ from easydict import EasyDict as edict
+
+
+ def get_config_regression(model_name, dataset_name, config_file=""):
+     """
+     Get the regression config of the given dataset and model from a config file.
+
+     Parameters:
+         config_file (str): Path to the config file. If an empty string is given, the default config file is used.
+         model_name (str): Name of the model.
+         dataset_name (str): Name of the dataset.
+
+     Returns:
+         config (dict): config of the given dataset and model
+     """
+     if config_file == "":
+         config_file = Path(__file__).parent / "config" / "config_regression.json"
+     with open(config_file, 'r') as f:
+         config_all = json.load(f)
+     model_common_args = config_all[model_name]['commonParams']
+     model_dataset_args = config_all[model_name]['datasetParams'][dataset_name]
+     dataset_args = config_all['datasetCommonParams'][dataset_name]
+     # use aligned features if the model requires them, otherwise use unaligned features
+     dataset_args = dataset_args['aligned'] if (model_common_args['need_data_aligned'] and 'aligned' in dataset_args) else dataset_args['unaligned']
+
+     config = {}
+     config['model_name'] = model_name
+     config['dataset_name'] = dataset_name
+     config.update(dataset_args)
+     config.update(model_common_args)
+     config.update(model_dataset_args)
+     config['featurePath'] = os.path.join(config_all['datasetCommonParams']['dataset_root_dir'], config['featurePath'])
+     config = edict(config)
+
+     return config
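A minimal usage sketch of this helper (not part of the commit; the explicit config path mirrors run.py, since the default config_regression.json referenced inside the function is not included here):

from config import get_config_regression

args = get_config_regression("DLF", "mosi", config_file="config/config.json")
# The returned EasyDict merges datasetCommonParams, commonParams and
# datasetParams, so dataset fields and model hyperparameters are all
# attribute-accessible.
print(args.featurePath, args.batch_size, args.KeyEval)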
config/config.json ADDED
@@ -0,0 +1,98 @@
+ {
+     "datasetCommonParams": {
+         "dataset_root_dir": "../dataset",
+         "mosi": {
+             "aligned": {
+                 "featurePath": "MOSI/Processed/aligned_50.pkl",
+                 "feature_dims": [768, 5, 20],
+                 "train_samples": 1284,
+                 "num_classes": 3,
+                 "language": "en",
+                 "KeyEval": "Loss"
+             },
+             "unaligned": {
+                 "featurePath": "MOSI/Processed/unaligned_50.pkl",
+                 "feature_dims": [768, 5, 20],
+                 "train_samples": 1284,
+                 "num_classes": 3,
+                 "language": "en",
+                 "KeyEval": "Loss"
+             }
+         },
+         "mosei": {
+             "aligned": {
+                 "featurePath": "MOSEI/Processed/aligned_50.pkl",
+                 "feature_dims": [768, 74, 35],
+                 "train_samples": 16326,
+                 "num_classes": 3,
+                 "language": "en",
+                 "KeyEval": "Loss"
+             },
+             "unaligned": {
+                 "featurePath": "MOSEI/Processed/unaligned_50.pkl",
+                 "feature_dims": [768, 74, 35],
+                 "train_samples": 16326,
+                 "num_classes": 3,
+                 "language": "en",
+                 "KeyEval": "Loss"
+             }
+         }
+     },
+     "DLF": {
+         "commonParams": {
+             "need_data_aligned": true,
+             "need_model_aligned": true,
+             "early_stop": 10,
+             "use_bert": true,
+             "use_finetune": true,
+             "attn_mask": true,
+             "update_epochs": 10
+         },
+         "datasetParams": {
+             "mosi": {
+                 "attn_dropout_a": 0.2,
+                 "attn_dropout_v": 0.0,
+                 "relu_dropout": 0.0,
+                 "embed_dropout": 0.2,
+                 "res_dropout": 0.0,
+                 "dst_feature_dim_nheads": [50, 10],
+                 "batch_size": 16,
+                 "learning_rate": 0.0001,
+                 "nlevels": 2,
+                 "conv1d_kernel_size_l": 5,
+                 "conv1d_kernel_size_a": 5,
+                 "conv1d_kernel_size_v": 5,
+                 "text_dropout": 0.5,
+                 "attn_dropout": 0.3,
+                 "output_dropout": 0.5,
+                 "grad_clip": 0.6,
+                 "patience": 5,
+                 "weight_decay": 0.005,
+                 "transformers": "bert",
+                 "pretrained": "bert-base-uncased"
+             },
+             "mosei": {
+                 "attn_dropout_a": 0.0,
+                 "attn_dropout_v": 0.0,
+                 "relu_dropout": 0.0,
+                 "embed_dropout": 0.0,
+                 "res_dropout": 0.0,
+                 "dst_feature_dim_nheads": [50, 10],
+                 "batch_size": 16,
+                 "learning_rate": 0.0001,
+                 "nlevels": 2,
+                 "conv1d_kernel_size_l": 3,
+                 "conv1d_kernel_size_a": 3,
+                 "conv1d_kernel_size_v": 3,
+                 "text_dropout": 0.1,
+                 "attn_dropout": 0.5,
+                 "output_dropout": 0.5,
+                 "grad_clip": 0.6,
+                 "patience": 5,
+                 "weight_decay": 0.001,
+                 "transformers": "bert",
+                 "pretrained": "bert-base-uncased"
+             }
+         }
+     }
+ }
config/readme.md ADDED
@@ -0,0 +1 @@
+ Config file for DLF.
data_loader.py ADDED
@@ -0,0 +1,156 @@
1
+ import logging
2
+ import pickle
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import DataLoader, Dataset
6
+ __all__ = ['MMDataLoader']
7
+ logger = logging.getLogger('MMSA')
8
+
9
+ class MMDataset(Dataset):
10
+ def __init__(self, args, mode='train'):
11
+ self.mode = mode
12
+ self.args = args
13
+ DATASET_MAP = {
14
+ 'mosi': self.__init_mosi,
15
+ 'mosei': self.__init_mosei,
16
+ }
17
+ DATASET_MAP[args['dataset_name']]()
18
+
19
+ def __init_mosi(self):
20
+ with open(self.args['featurePath'], 'rb') as f:
21
+ data = pickle.load(f)
22
+ if 'use_bert' in self.args and self.args['use_bert']:
23
+ self.text = data[self.mode]['text_bert'].astype(np.float32)
24
+ else:
25
+ self.text = data[self.mode]['text'].astype(np.float32)
26
+ self.vision = data[self.mode]['vision'].astype(np.float32)
27
+ self.audio = data[self.mode]['audio'].astype(np.float32)
28
+ self.raw_text = data[self.mode]['raw_text']
29
+ self.ids = data[self.mode]['id']
30
+
31
+
32
+ if self.args['feature_T'] != "":
33
+ with open(self.args['feature_T'], 'rb') as f:
34
+ data_T = pickle.load(f)
35
+ if 'use_bert' in self.args and self.args['use_bert']:
36
+ self.text = data_T[self.mode]['text_bert'].astype(np.float32)
37
+ self.args['feature_dims'][0] = 768
38
+ else:
39
+ self.text = data_T[self.mode]['text'].astype(np.float32)
40
+ self.args['feature_dims'][0] = self.text.shape[2]
41
+ if self.args['feature_A'] != "":
42
+ with open(self.args['feature_A'], 'rb') as f:
43
+ data_A = pickle.load(f)
44
+ self.audio = data_A[self.mode]['audio'].astype(np.float32)
45
+ self.args['feature_dims'][1] = self.audio.shape[2]
46
+ if self.args['feature_V'] != "":
47
+ with open(self.args['feature_V'], 'rb') as f:
48
+ data_V = pickle.load(f)
49
+ self.vision = data_V[self.mode]['vision'].astype(np.float32)
50
+ self.args['feature_dims'][2] = self.vision.shape[2]
51
+
52
+ self.labels = {
53
+ 'M': np.array(data[self.mode]['regression_labels']).astype(np.float32)
54
+ }
55
+
56
+ logger.info(f"{self.mode} samples: {self.labels['M'].shape}")
57
+
58
+
59
+ if not self.args['need_data_aligned']:
60
+ if self.args['feature_A'] != "":
61
+ self.audio_lengths = list(data_A[self.mode]['audio_lengths'])
62
+ else:
63
+ self.audio_lengths = data[self.mode]['audio_lengths']
64
+ if self.args['feature_V'] != "":
65
+ self.vision_lengths = list(data_V[self.mode]['vision_lengths'])
66
+ else:
67
+ self.vision_lengths = data[self.mode]['vision_lengths']
68
+ self.audio[self.audio == -np.inf] = 0
69
+
70
+ if 'need_normalized' in self.args and self.args['need_normalized']:
71
+ self.__normalize()
72
+
73
+ def __init_mosei(self):
74
+ return self.__init_mosi()
75
+
76
+ def __init_sims(self):
77
+ return self.__init_mosi()
78
+
79
+ def __truncate(self):
80
+ def do_truncate(modal_features, length):
81
+ if length == modal_features.shape[1]:
82
+ return modal_features
83
+ truncated_feature = []
84
+ padding = np.array([0 for i in range(modal_features.shape[2])])
85
+ for instance in modal_features:
86
+ for index in range(modal_features.shape[1]):
87
+ if((instance[index] == padding).all()):
88
+ if(index + length >= modal_features.shape[1]):
89
+ truncated_feature.append(instance[index:index+20])
90
+ break
91
+ else:
92
+ truncated_feature.append(instance[index:index+20])
93
+ break
94
+ truncated_feature = np.array(truncated_feature)
95
+ return truncated_feature
96
+
97
+ text_length, audio_length, video_length = self.args['seq_lens']
98
+ self.vision = do_truncate(self.vision, video_length)
99
+ self.text = do_truncate(self.text, text_length)
100
+ self.audio = do_truncate(self.audio, audio_length)
101
+
102
+ def __normalize(self):
103
+
104
+ self.vision = np.mean(self.vision, axis=1, keepdims=True)
105
+ self.audio = np.mean(self.audio, axis=1, keepdims=True)
106
+
107
+ self.vision[self.vision != self.vision] = 0
108
+ self.audio[self.audio != self.audio] = 0
109
+
110
+ def __len__(self):
111
+ return len(self.labels['M'])
112
+
113
+ def get_seq_len(self):
114
+ if 'use_bert' in self.args and self.args['use_bert']:
115
+ return (self.text.shape[2], self.audio.shape[1], self.vision.shape[1])
116
+ else:
117
+ return (self.text.shape[1], self.audio.shape[1], self.vision.shape[1])
118
+
119
+ def get_feature_dim(self):
120
+ return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]
121
+
122
+ def __getitem__(self, index):
123
+ sample = {
124
+ 'raw_text': self.raw_text[index],
125
+ 'text': torch.Tensor(self.text[index]),
126
+ 'audio': torch.Tensor(self.audio[index]),
127
+ 'vision': torch.Tensor(self.vision[index]),
128
+ 'index': index,
129
+ 'id': self.ids[index],
130
+ 'labels': {k: torch.Tensor(v[index].reshape(-1)) for k, v in self.labels.items()}
131
+ }
132
+ if not self.args['need_data_aligned']:
133
+ sample['audio_lengths'] = self.audio_lengths[index]
134
+ sample['vision_lengths'] = self.vision_lengths[index]
135
+ return sample
136
+
137
+ def MMDataLoader(args, num_workers):
138
+
139
+ datasets = {
140
+ 'train': MMDataset(args, mode='train'),
141
+ 'valid': MMDataset(args, mode='valid'),
142
+ 'test': MMDataset(args, mode='test')
143
+ }
144
+
145
+ if 'seq_lens' in args:
146
+ args['seq_lens'] = datasets['train'].get_seq_len()
147
+
148
+ dataLoader = {
149
+ ds: DataLoader(datasets[ds],
150
+ batch_size=args['batch_size'],
151
+ num_workers=num_workers,
152
+ shuffle=True)
153
+ for ds in datasets.keys()
154
+ }
155
+
156
+ return dataLoader
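A rough sketch of building the loaders from the merged config (not part of the commit; it assumes the processed feature .pkl files exist under the configured dataset_root_dir and that config/config.json is passed explicitly, as in run.py):

from config import get_config_regression
from data_loader import MMDataLoader

args = get_config_regression("DLF", "mosi", config_file="config/config.json")
# run.py normally injects these keys; empty strings mean "use the features
# from featurePath" rather than custom feature files.
args["feature_T"] = args["feature_A"] = args["feature_V"] = ""
dataloader = MMDataLoader(args, num_workers=1)
for batch in dataloader["train"]:
    print(batch["text"].shape, batch["audio"].shape, batch["vision"].shape)
    break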
log/readme.md ADDED
@@ -0,0 +1 @@
+ Training logs are written here.
pt/readme.md ADDED
@@ -0,0 +1 @@
+ Trained models are saved here.
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers==4.33.1
+ huggingface-hub==0.17.1
+ numpy==1.21.5
+ scipy==1.9.1
+ scikit-learn==1.0.2
+ pandas==1.4.4
result/normal/mosei.csv ADDED
@@ -0,0 +1,11 @@
1
+ Time,Model,acc_7,acc_5,acc_2,F1_score,Corr,MAE,Loss
2
+ 2024/12/08 13:54:11 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
3
+ 2024/12/08 16:00:34 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
4
+ 2024/12/08 21:00:18 - ,DLF,"(53.87, 0.0)","(55.66, 0.0)","(84.31, 0.0)","(84.36, 0.0)","(75.75, 0.0)","(53.86, 0.0)","(53.82, 0.0)"
5
+ 2024/12/08 22:10:45 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
6
+ 2024/12/09 00:43:05 - ,DLF,"(53.14, 0.0)","(55.03, 0.0)","(84.92, 0.0)","(84.9, 0.0)","(76.46, 0.0)","(54.05, 0.0)","(54.03, 0.0)"
7
+ 2024/12/09 02:40:19 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
8
+ 2024/12/09 04:45:03 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
9
+ 2024/12/09 16:10:24 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
10
+ 2024/12/13 02:59:46 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
11
+ 2024/12/15 23:53:26 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
result/normal/mosi.csv ADDED
@@ -0,0 +1,21 @@
1
+ Time,Model,acc_7,acc_5,acc_2,F1_score,Corr,MAE,Loss
2
+ 2024/12/08 14:55:08 - ,DLF,"(45.63, 0.0)","(52.77, 0.0)","(84.45, 0.0)","(84.42, 0.0)","(79.43, 0.0)","(72.2, 0.0)","(72.2, 0.0)"
3
+ 2024/12/08 15:20:12 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
4
+ 2024/12/08 15:45:53 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
5
+ 2024/12/08 19:28:22 - ,DLF,"(45.63, 0.0)","(52.77, 0.0)","(84.45, 0.0)","(84.42, 0.0)","(79.43, 0.0)","(72.2, 0.0)","(72.2, 0.0)"
6
+ 2024/12/08 19:46:53 - ,DLF,"(45.63, 0.0)","(52.77, 0.0)","(84.45, 0.0)","(84.42, 0.0)","(79.43, 0.0)","(72.2, 0.0)","(72.2, 0.0)"
7
+ 2024/12/08 20:05:59 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
8
+ 2024/12/08 20:26:47 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
9
+ 2024/12/08 20:42:22 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
10
+ 2024/12/09 11:34:14 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
11
+ 2024/12/09 12:00:35 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
12
+ 2024/12/09 15:54:01 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
13
+ 2024/12/09 18:33:19 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
14
+ 2024/12/13 02:44:39 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
15
+ 2024/12/15 20:13:12 - ,DLF,"(46.36, 0.0)","(53.35, 0.0)","(83.38, 0.0)","(83.4, 0.0)","(78.83, 0.0)","(72.89, 0.0)","(72.94, 0.0)"
16
+ 2024/12/15 20:38:24 - ,DLF,"(44.75, 0.0)","(51.9, 0.0)","(83.84, 0.0)","(83.85, 0.0)","(78.2, 0.0)","(72.78, 0.0)","(72.78, 0.0)"
17
+ 2024/12/15 20:50:42 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
18
+ 2024/12/15 23:25:21 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
19
+ 2024/12/16 03:44:10 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
20
+ 2024/12/16 04:06:00 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
21
+ 2024/12/16 04:59:28 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
result/readme.md ADDED
@@ -0,0 +1 @@
+ Results are saved here as CSV files.
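Note on the CSV result files above: each metric cell is a "(mean, std)" pair produced by run.py, i.e. the mean and standard deviation of that metric across the seeds of one run, multiplied by 100 and rounded to two decimals. With a single seed, as in train.py and test.py, the std is 0.0.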
run.py ADDED
@@ -0,0 +1,201 @@
1
+ import gc
2
+ import logging
3
+ import os
4
+ import time
5
+ from pathlib import Path
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from config import get_config_regression
10
+ from data_loader import MMDataLoader
11
+ from trains import ATIO
12
+ from utils import assign_gpu, setup_seed
13
+ from trains.singleTask.model import DLF
14
+ from trains.singleTask.distillnets import get_distillation_kernel, get_distillation_kernel_homo
15
+ from trains.singleTask.misc import softmax
16
+ import sys
17
+
18
+ from datetime import datetime
19
+ now = datetime.now()
20
+ format = "%Y/%m/%d %H:%M:%S"
21
+ formatted_now = now.strftime(format)
22
+ formatted_now = str(formatted_now)+" - "
23
+
24
+ os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
25
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
26
+ logger = logging.getLogger('MMSA')
27
+
28
+ def _set_logger(log_dir, model_name, dataset_name, verbose_level):
29
+
30
+ # base logger
31
+ log_file_path = Path(log_dir) / f"{model_name}-{dataset_name}.log"
32
+ logger = logging.getLogger('MMSA')
33
+ logger.setLevel(logging.DEBUG)
34
+
35
+ # file handler
36
+ fh = logging.FileHandler(log_file_path)
37
+ fh_formatter = logging.Formatter('%(asctime)s - %(name)s [%(levelname)s] - %(message)s')
38
+ fh.setLevel(logging.DEBUG)
39
+ fh.setFormatter(fh_formatter)
40
+ logger.addHandler(fh)
41
+
42
+ # stream handler
43
+ stream_level = {0: logging.ERROR, 1: logging.INFO, 2: logging.DEBUG}
44
+ ch = logging.StreamHandler()
45
+ ch.setLevel(stream_level[verbose_level])
46
+ ch_formatter = logging.Formatter('%(name)s - %(message)s')
47
+ ch.setFormatter(ch_formatter)
48
+ logger.addHandler(ch)
49
+
50
+ return logger
51
+
52
+
53
+ def DLF_run(
54
+ model_name, dataset_name, config=None, config_file="", seeds=[], is_tune=False,
55
+ tune_times=500, feature_T="", feature_A="", feature_V="",
56
+ model_save_dir="", res_save_dir="", log_dir="",
57
+ gpu_ids=[0], num_workers=1, verbose_level=1, mode = '', is_training = False
58
+ ):
59
+ # Initialization
60
+ model_name = model_name.upper()
61
+ dataset_name = dataset_name.lower()
62
+
63
+ if config_file != "":
64
+ config_file = Path(config_file)
65
+ else: # use default config files
66
+ config_file = Path(__file__).parent / "config" / "config.json"
67
+ if not config_file.is_file():
68
+ raise ValueError(f"Config file {str(config_file)} not found.")
69
+ if model_save_dir == "":
70
+ model_save_dir = Path.home() / "MMSA" / "saved_models"
71
+ Path(model_save_dir).mkdir(parents=True, exist_ok=True)
72
+ if res_save_dir == "":
73
+ res_save_dir = Path.home() / "MMSA" / "results"
74
+ Path(res_save_dir).mkdir(parents=True, exist_ok=True)
75
+ if log_dir == "":
76
+ log_dir = Path.home() / "MMSA" / "logs"
77
+ Path(log_dir).mkdir(parents=True, exist_ok=True)
78
+ seeds = seeds if seeds != [] else [1111, 1112, 1113, 1114, 1115]
79
+ logger = _set_logger(log_dir, model_name, dataset_name, verbose_level)
80
+
81
+
82
+ args = get_config_regression(model_name, dataset_name, config_file)
83
+ args.is_training = is_training
84
+ args.mode = mode # train or test
85
+ args['model_save_path'] = Path(model_save_dir) / f"{args['model_name']}-{args['dataset_name']}.pth"
86
+ args['device'] = assign_gpu(gpu_ids)
87
+ args['train_mode'] = 'regression'
88
+ args['feature_T'] = feature_T
89
+ args['feature_A'] = feature_A
90
+ args['feature_V'] = feature_V
91
+ if config:
92
+ args.update(config)
93
+
94
+
95
+ res_save_dir = Path(res_save_dir) / "normal"
96
+ res_save_dir.mkdir(parents=True, exist_ok=True)
97
+ model_results = []
98
+ for i, seed in enumerate(seeds):
99
+ setup_seed(seed)
100
+ args['cur_seed'] = i + 1
101
+ result = _run(args, num_workers, is_tune)
102
+ model_results.append(result)
103
+ if args.is_training:
104
+ criterions = list(model_results[0].keys())
105
+ # save result to csv
106
+ csv_file = res_save_dir / f"{dataset_name}.csv"
107
+ if csv_file.is_file():
108
+ df = pd.read_csv(csv_file)
109
+ else:
110
+ df = pd.DataFrame(columns=["Time"]+["Model"] + criterions)
111
+ # save results
112
+ res = [model_name]
113
+ for c in criterions:
114
+ values = [r[c] for r in model_results]
115
+ mean = round(np.mean(values)*100, 2)
116
+ std = round(np.std(values)*100, 2)
117
+ res.append((mean, std))
118
+
119
+ res = [formatted_now]+res
120
+ df.loc[len(df)] = res
121
+ df.to_csv(csv_file, index=None)
122
+ logger.info(f"Results saved to {csv_file}.")
123
+
124
+
125
+ def _run(args, num_workers=4, is_tune=False, from_sena=False):
126
+
127
+ dataloader = MMDataLoader(args, num_workers)
128
+
129
+ if args.is_training:
130
+ print("training for DLF")
131
+
132
+
133
+ args.gd_size_low = 64
134
+ args.w_losses_low = [1, 10]
135
+ args.metric_low = 'l1'
136
+
137
+
138
+ args.gd_size_high = 32
139
+ args.w_losses_high = [1, 10]
140
+ args.metric_high = 'l1'
141
+
142
+ to_idx = [0, 1, 2]
143
+ from_idx = [0, 1, 2]
144
+ assert len(from_idx) >= 1
145
+
146
+ model = []
147
+ model_DLF = getattr(DLF, 'DLF')(args)
148
+
149
+ model_distill_homo = getattr(get_distillation_kernel_homo, 'DistillationKernel')(n_classes=1,
150
+ hidden_size=
151
+ args.dst_feature_dim_nheads[0],
152
+ gd_size=args.gd_size_low,
153
+ to_idx=to_idx, from_idx=from_idx,
154
+ gd_prior=softmax([0, 0, 1, 0, 1, 0], 0.25),
155
+ gd_reg=10,
156
+ w_losses=args.w_losses_low,
157
+ metric=args.metric_low,
158
+ alpha=1 / 8,
159
+ hyp_params=args)
160
+
161
+ model_distill_hetero = getattr(get_distillation_kernel, 'DistillationKernel')(n_classes=1,
162
+ hidden_size=
163
+ args.dst_feature_dim_nheads[0] * 2,
164
+ gd_size=args.gd_size_high,
165
+ to_idx=to_idx, from_idx=from_idx,
166
+ gd_prior=softmax([0, 0, 1, 0, 1, 1], 0.25),
167
+ gd_reg=10,
168
+ w_losses=args.w_losses_high,
169
+ metric=args.metric_high,
170
+ alpha=1 / 8,
171
+ hyp_params=args)
172
+
173
+ model_DLF = model_DLF.cuda()
174
+
175
+ model = [model_DLF]
176
+ else:
177
+ print("testing phase for DLF")
178
+ model = getattr(DLF, 'DLF')(args)
179
+ model = model.cuda()
180
+
181
+ trainer = ATIO().getTrain(args)
182
+
183
+
184
+ #test
185
+ if args.mode == 'test':
186
+ model.load_state_dict(torch.load('./pt/DLF'+str(args.dataset_name)+'.pth'),strict=False)
187
+ results = trainer.do_test(model, dataloader['test'], mode="TEST")
188
+ sys.stdout.flush()
189
+ input('[Press Any Key to start another run]')
190
+ #train
191
+ else:
192
+ epoch_results = trainer.do_train(model, dataloader, return_epoch_results=from_sena)
193
+ model[0].load_state_dict(torch.load('./pt/DLF'+str(args.dataset_name)+'.pth'))
194
+
195
+ results = trainer.do_test(model[0], dataloader['test'], mode="TEST")
196
+
197
+ del model
198
+ torch.cuda.empty_cache()
199
+ gc.collect()
200
+ time.sleep(1)
201
+ return results
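Besides the fixed calls in train.py and test.py, DLF_run also accepts a config dict that is merged over the values read from the JSON file (args.update(config)). A hedged sketch, with override keys chosen only for illustration:

from run import DLF_run

DLF_run(model_name='DLF', dataset_name='mosi', seeds=[1111, 1112],
        config={'batch_size': 32, 'learning_rate': 5e-5},
        model_save_dir="./pt", res_save_dir="./result", log_dir="./log",
        mode='train', is_training=True)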
test.py ADDED
@@ -0,0 +1,7 @@
+ """
+ Testing script for DLF
+ """
+ from run import DLF_run
+
+ DLF_run(model_name='DLF', dataset_name='mosei', is_tune=False, seeds=[1111], model_save_dir="./pt",
+         res_save_dir="./result", log_dir="./log", mode='test', is_training=False)
train.py ADDED
@@ -0,0 +1,7 @@
+ """
+ Training script for DLF
+ """
+ from run import DLF_run
+
+ DLF_run(model_name='DLF', dataset_name='mosi', is_tune=False, seeds=[1111], model_save_dir="./pt",
+         res_save_dir="./result", log_dir="./log", mode='train', is_training=True)
trains/ATIO.py ADDED
@@ -0,0 +1,15 @@
+ """
+ ATIO -- All Trains in One
+ """
+ from .singleTask import *
+
+ __all__ = ['ATIO']
+
+ class ATIO():
+     def __init__(self):
+         self.TRAIN_MAP = {
+             'DLF': DLF,
+         }
+
+     def getTrain(self, args):
+         return self.TRAIN_MAP[args['model_name']](args)
trains/__init__.py ADDED
@@ -0,0 +1 @@
+ from .ATIO import ATIO
trains/singleTask/DLF.py ADDED
@@ -0,0 +1,233 @@
1
+ import logging
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch import optim
6
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
7
+ from tqdm import tqdm
8
+ from ..utils import MetricsTop, dict_to_str
9
+ from .HingeLoss import HingeLoss
10
+
11
+
12
+ logger = logging.getLogger('MMSA')
13
+
14
+ class MSE(nn.Module):
15
+ def __init__(self):
16
+ super(MSE, self).__init__()
17
+
18
+ def forward(self, pred, real):
19
+ diffs = torch.add(real, -pred)
20
+ n = torch.numel(diffs.data)
21
+ mse = torch.sum(diffs.pow(2)) / n
22
+ return mse
23
+
24
+ class DLF():
25
+ def __init__(self, args):
26
+ self.args = args
27
+ self.criterion = nn.L1Loss()
28
+ self.cosine = nn.CosineEmbeddingLoss()
29
+ self.metrics = MetricsTop(args.train_mode).getMetics(args.dataset_name)
30
+ self.MSE = MSE()
31
+ self.sim_loss = HingeLoss()
32
+
33
+ def do_train(self, model, dataloader, return_epoch_results=False):
34
+
35
+ # 0: DLF model
36
+ params = model[0].parameters()
37
+
38
+ optimizer = optim.Adam(params, lr=self.args.learning_rate)
39
+ scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, verbose=True, patience=self.args.patience)
40
+
41
+ epochs, best_epoch = 0, 0
42
+ if return_epoch_results:
43
+ epoch_results = {
44
+ 'train': [],
45
+ 'valid': [],
46
+ 'test': []
47
+ }
48
+ min_or_max = 'min' if self.args.KeyEval in ['Loss'] else 'max'
49
+ best_valid = 1e8 if min_or_max == 'min' else 0
50
+
51
+ net = []
52
+ net_DLF = model[0]
53
+ net.append(net_DLF)
54
+ model = net
55
+
56
+ while True:
57
+ epochs += 1
58
+ y_pred, y_true = [], []
59
+ for mod in model:
60
+ mod.train()
61
+
62
+
63
+ train_loss = 0.0
64
+ left_epochs = self.args.update_epochs
65
+ with tqdm(dataloader['train']) as td:
66
+ for batch_data in td:
67
+
68
+ if left_epochs == self.args.update_epochs:
69
+ optimizer.zero_grad()
70
+ left_epochs -= 1
71
+ vision = batch_data['vision'].to(self.args.device)
72
+ audio = batch_data['audio'].to(self.args.device)
73
+ text = batch_data['text'].to(self.args.device)
74
+ labels = batch_data['labels']['M'].to(self.args.device)
75
+ labels = labels.view(-1, 1)
76
+
77
+
78
+
79
+ output = model[0](text, audio, vision)
80
+
81
+ # task loss
82
+ loss_task_all = self.criterion(output['output_logit'], labels)
83
+
84
+ loss_task_l_hetero = self.criterion(output['logits_l_hetero'], labels)
85
+ loss_task_v_hetero = self.criterion(output['logits_v_hetero'], labels)
86
+ loss_task_a_hetero = self.criterion(output['logits_a_hetero'], labels)
87
+ loss_task_c = self.criterion(output['logits_c'], labels)
88
+
89
+ # total MSA loss L_msa
90
+ loss_task = 1* (1 * loss_task_all + 1*loss_task_c + 3 * loss_task_l_hetero + 1*loss_task_v_hetero + 1*loss_task_a_hetero)
91
+
92
+ # reconstruction loss L_r
93
+ loss_recon_l = self.MSE(output['recon_l'], output['origin_l'])
94
+ loss_recon_v = self.MSE(output['recon_v'], output['origin_v'])
95
+ loss_recon_a = self.MSE(output['recon_a'], output['origin_a'])
96
+ loss_recon = loss_recon_l + loss_recon_v + loss_recon_a
97
+
98
+ # specific loss L_s
99
+ loss_sl_slr = self.MSE(output['s_l'].permute(1, 2, 0), output['s_l_r'])
100
+ loss_sv_slv = self.MSE(output['s_v'].permute(1, 2, 0), output['s_v_r'])
101
+ loss_sa_sla = self.MSE(output['s_a'].permute(1, 2, 0), output['s_a_r'])
102
+ loss_s_sr = loss_sl_slr + loss_sv_slv + loss_sa_sla
103
+
104
+ # ort loss L_o
105
+ if self.args.dataset_name == 'mosi':
106
+ num = 50
107
+ elif self.args.dataset_name == 'mosei':
108
+ num = 10
109
+
110
+ cosine_similarity_s_c_l = self.cosine(output['s_l'].reshape(-1, num), output['c_l'].reshape(-1, num), torch.tensor([-1]).cuda())
111
+ cosine_similarity_s_c_v = self.cosine(output['s_v'].reshape(-1, num), output['c_v'].reshape(-1, num), torch.tensor([-1]).cuda())
112
+ cosine_similarity_s_c_a = self.cosine(output['s_a'].reshape(-1, num), output['c_a'].reshape(-1, num), torch.tensor([-1]).cuda())
113
+
114
+ loss_ort = cosine_similarity_s_c_l + cosine_similarity_s_c_v + cosine_similarity_s_c_a
115
+
116
+ # triplet margin loss L_m
117
+ c_l, c_v, c_a = output['c_l_sim'], output['c_v_sim'], output['c_a_sim']
118
+ ids, feats = [], []
119
+ for i in range(labels.size(0)):
120
+ feats.append(c_l[i].view(1, -1))
121
+ feats.append(c_v[i].view(1, -1))
122
+ feats.append(c_a[i].view(1, -1))
123
+ ids.append(labels[i].view(1, -1))
124
+ ids.append(labels[i].view(1, -1))
125
+ ids.append(labels[i].view(1, -1))
126
+ feats = torch.cat(feats, dim=0)
127
+ ids = torch.cat(ids, dim=0)
128
+ loss_sim = self.sim_loss(ids, feats)
129
+
130
+ #overall loss L_DLF
131
+ combined_loss = loss_task + (loss_s_sr + loss_recon + (loss_sim+loss_ort) * 0.1) * 0.1
132
+
133
+ combined_loss.backward()
134
+
135
+
136
+ if self.args.grad_clip != -1.0:
137
+ params = list(model[0].parameters())
138
+
139
+ nn.utils.clip_grad_value_(params, self.args.grad_clip)
140
+
141
+ train_loss += combined_loss.item()
142
+
143
+
144
+ y_pred.append(output['output_logit'].cpu())
145
+ y_true.append(labels.cpu())
146
+ if not left_epochs:
147
+ optimizer.step()
148
+ left_epochs = self.args.update_epochs
149
+ if not left_epochs:
150
+ # update
151
+ optimizer.step()
152
+
153
+
154
+ train_loss = train_loss / len(dataloader['train'])
155
+ pred, true = torch.cat(y_pred), torch.cat(y_true)
156
+ train_results = self.metrics(pred, true)
157
+ logger.info(
158
+ f">> Epoch: {epochs} "
159
+ f"TRAIN -({self.args.model_name}) [{epochs - best_epoch}/{epochs}/{self.args.cur_seed}] "
160
+ f">> total_loss: {round(train_loss, 4)} "
161
+ f"{dict_to_str(train_results)}"
162
+ )
163
+ # validation
164
+ val_results = self.do_test(model[0], dataloader['valid'], mode="VAL")
165
+ test_results = self.do_test(model[0], dataloader['test'], mode="TEST")
166
+ cur_valid = val_results[self.args.KeyEval]
167
+ scheduler.step(val_results['Loss'])
168
+ # save each epoch model
169
+ torch.save(model[0].state_dict(), './pt/' + str(self.args.dataset_name) + '_' + str(epochs) + '.pth')
170
+ # save best model
171
+ isBetter = cur_valid <= (best_valid - 1e-6) if min_or_max == 'min' else cur_valid >= (best_valid + 1e-6)
172
+ if isBetter:
173
+ best_valid, best_epoch = cur_valid, epochs
174
+ # save model
175
+ model_save_path = './pt/DLF' + str(self.args.dataset_name)+'.pth'
176
+ torch.save(model[0].state_dict(), model_save_path)
177
+
178
+ if return_epoch_results:
179
+ train_results["Loss"] = train_loss
180
+ epoch_results['train'].append(train_results)
181
+ epoch_results['valid'].append(val_results)
182
+ test_results = self.do_test(model, dataloader['test'], mode="TEST")
183
+ epoch_results['test'].append(test_results)
184
+ # early stop
185
+ if epochs - best_epoch >= self.args.early_stop:
186
+ return epoch_results if return_epoch_results else None
187
+
188
+ def do_test(self, model, dataloader, mode="VAL", return_sample_results=False):
189
+
190
+ model.eval()
191
+ y_pred, y_true = [], []
192
+
193
+ eval_loss = 0.0
194
+ if return_sample_results:
195
+ ids, sample_results = [], []
196
+ all_labels = []
197
+ features = {
198
+ "Feature_t": [],
199
+ "Feature_a": [],
200
+ "Feature_v": [],
201
+ "Feature_f": [],
202
+ }
203
+
204
+ with torch.no_grad():
205
+ with tqdm(dataloader) as td:
206
+ for batch_data in td:
207
+ vision = batch_data['vision'].to(self.args.device)
208
+ audio = batch_data['audio'].to(self.args.device)
209
+ text = batch_data['text'].to(self.args.device)
210
+ labels = batch_data['labels']['M'].to(self.args.device)
211
+ labels = labels.view(-1, 1)
212
+ output = model(text, audio, vision)
213
+ loss = self.criterion(output['output_logit'], labels)
214
+ eval_loss += loss.item()
215
+ y_pred.append(output['output_logit'].cpu())
216
+ y_true.append(labels.cpu())
217
+
218
+ eval_loss = eval_loss / len(dataloader)
219
+ pred, true = torch.cat(y_pred), torch.cat(y_true)
220
+
221
+ eval_results = self.metrics(pred, true)
222
+ eval_results["Loss"] = round(eval_loss, 4)
223
+ logger.info(f"{mode}-({self.args.model_name}) >> {dict_to_str(eval_results)}")
224
+
225
+ if return_sample_results:
226
+ eval_results["Ids"] = ids
227
+ eval_results["SResults"] = sample_results
228
+ for k in features.keys():
229
+ features[k] = np.concatenate(features[k], axis=0)
230
+ eval_results['Features'] = features
231
+ eval_results['Labels'] = all_labels
232
+
233
+ return eval_results
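For reference, the objective assembled in do_train above can be written out compactly using the code's own variable names: loss_task = loss_task_all + loss_task_c + 3*loss_task_l_hetero + loss_task_v_hetero + loss_task_a_hetero (the MSA task loss L_msa over the fused, shared and per-modality predictions), and combined_loss = loss_task + 0.1*(loss_s_sr + loss_recon + 0.1*(loss_sim + loss_ort)), where loss_recon is the reconstruction loss L_r, loss_s_sr the specific-feature consistency loss L_s, loss_ort the cosine (orthogonality) loss L_o, and loss_sim the triplet/hinge loss L_m.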
trains/singleTask/HingeLoss.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class HingeLoss(nn.Module):
6
+ def __init__(self):
7
+ super(HingeLoss, self).__init__()
8
+
9
+ def compute_cosine(self, x, y):
10
+ # x = self.compute_compact_s(x)
11
+ # y = self.compute_compact_s(y)
12
+ x_norm = torch.sqrt(torch.sum(torch.pow(x, 2), 1)+1e-8)
13
+ x_norm = torch.max(x_norm, 1e-8*torch.ones_like(x_norm))
14
+ y_norm = torch.sqrt(torch.sum(torch.pow(y, 2), 1)+1e-8)
15
+ y_norm = torch.max(y_norm, 1e-8*torch.ones_like(y_norm))
16
+ cosine = torch.sum(x * y, 1) / (x_norm * y_norm)
17
+ return cosine
18
+
19
+ def forward(self, ids, feats, margin=0.1):
20
+ B, F = feats.shape
21
+
22
+ s = feats.repeat(1, B).view(-1, F) # B**2 X F
23
+ s_ids = ids.view(B, 1).repeat(1, B) # B X B
24
+
25
+ t = feats.repeat(B, 1) # B**2 X F
26
+ t_ids = ids.view(1, B).repeat(B, 1) # B X B
27
+
28
+ cosine = self.compute_cosine(s, t) # B**2
29
+ equal_mask = torch.eye(B, dtype=torch.bool) # B X B
30
+ s_ids = s_ids[~equal_mask].view(B, B-1) # B X (B-1)
31
+ t_ids = t_ids[~equal_mask].view(B, B-1) # B X (B-1)
32
+ cosine = cosine.view(B, B)[~equal_mask].view(B, B-1) # B X (B-1)
33
+
34
+ sim_mask = (s_ids == t_ids) # B X (B-1)
35
+ margin = 0.15 * abs(s_ids - t_ids)#[~sim_mask].view(B, B - 3)
36
+
37
+ loss = 0
38
+ loss_num = 0
39
+
40
+ for i in range(B):
41
+ sim_num = sum(sim_mask[i])
42
+ dif_num = B - 1 - sim_num
43
+ if not sim_num or not dif_num:
44
+ continue
45
+ sim_cos = cosine[i, sim_mask[i]].reshape(-1, 1).repeat(1, dif_num)
46
+ dif_cos = cosine[i, ~sim_mask[i]].reshape(-1, 1).repeat(1, sim_num).transpose(0, 1)
47
+ t_margin = margin[i, ~sim_mask[i]].reshape(-1, 1).repeat(1, sim_num).transpose(0, 1)
48
+
49
+ loss_i = torch.max(torch.zeros_like(sim_cos), t_margin - sim_cos + dif_cos).mean()
50
+ loss += loss_i
51
+ loss_num += 1
52
+
53
+ if loss_num == 0:
54
+ loss_num = 1
55
+
56
+ loss = loss / loss_num
57
+ return loss
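A small self-contained sketch of the hinge loss with dummy tensors (an illustration only; it assumes the rest of the trains package, e.g. trains/utils imported through trains/__init__, is importable):

import torch
from trains.singleTask.HingeLoss import HingeLoss

hinge = HingeLoss()
# Two samples, three modality views each: views of the same sample share a
# label and act as positives, the other sample's views act as negatives.
feats = torch.randn(6, 50)
ids = torch.tensor([0.5, 0.5, 0.5, -1.0, -1.0, -1.0]).view(-1, 1)
print(hinge(ids, feats).item())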
trains/singleTask/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .DLF import DLF
trains/singleTask/distillnets/get_distillation_kernel.py ADDED
@@ -0,0 +1,96 @@
1
+ """Graph distillation for hetero GD"""
2
+
3
+ import torch
4
+ from torch.autograd import Variable
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from ..utils import distance_metric, min_cosine
8
+
9
+ class DistillationKernel(nn.Module):
10
+ """Graph Distillation kernel.
11
+
12
+ Calculate the edge weights e_{j->k} for each j. Modality k is specified by
13
+ to_idx, and the other modalities are specified by from_idx.
14
+ """
15
+
16
+ def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
17
+ gd_prior, gd_reg, w_losses, metric, alpha, hyp_params):
18
+ super(DistillationKernel, self).__init__()
19
+ self.W_logit = nn.Linear(n_classes, gd_size)
20
+ self.W_repr = nn.Linear(hidden_size, gd_size)
21
+ self.W_edge = nn.Linear(gd_size * 4, 1)
22
+
23
+ self.gd_size = gd_size
24
+ self.to_idx = to_idx
25
+ self.from_idx = from_idx
26
+ self.alpha = alpha
27
+ self.gd_prior = Variable(torch.FloatTensor(gd_prior).cuda())
28
+ self.gd_reg = gd_reg
29
+ self.w_losses = w_losses
30
+ self.metric = metric
31
+ self.hyp_params = hyp_params
32
+
33
+
34
+ def forward(self, logits, reprs):
35
+ """
36
+ Args:
37
+ logits: (n_modalities, batch_size, n_classes)
38
+ reprs: (n_modalities, batch_size, hidden_size)
39
+ Return:
40
+ edges: weights e_{j->k} (n_modalities_from, batch_size)
41
+ """
42
+ n_modalities, batch_size = logits.size()[:2]
43
+ z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
44
+ z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
45
+ z = torch.cat(
46
+ (z_logits, z_reprs), dim=1).view(n_modalities, batch_size,
47
+ self.gd_size * 2)
48
+ edges = []
49
+ for j in self.to_idx:
50
+ for i in self.from_idx:
51
+ if i == j:
52
+ continue
53
+ else:
54
+ # To calculate e_{j->k}, concatenate z^j, z^k
55
+ e = self.W_edge(torch.cat((z[j], z[i]), dim=1))
56
+ edges.append(e)
57
+ edges = torch.cat(edges, dim=1)
58
+ edges_origin = edges.sum(0).unsqueeze(0).transpose(0, 1)
59
+ edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1)
60
+ return edges, edges_origin
61
+
62
+
63
+ def distillation_loss(self, logits, reprs, edges):
64
+ """Calculate graph distillation losses, which include:
65
+ regularization loss, loss for logits, and loss for representation.
66
+ """
67
+ loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg
68
+ loss_logit, loss_repr = 0, 0
69
+ x = 0
70
+ for j in self.to_idx:
71
+ for i, idx in enumerate(self.from_idx):
72
+ if i == j:
73
+ continue
74
+ else:
75
+ w_distill = edges[x] + self.gd_prior[x]
76
+ # print(edges.sum(1), w_distill.sum(0))
77
+ loss_logit += self.w_losses[0] * distance_metric(
78
+ logits[j], logits[idx], self.metric, w_distill)
79
+ loss_repr += self.w_losses[1] * min_cosine(
80
+ reprs[j], reprs[idx], self.metric, w_distill)
81
+ x = x + 1
82
+ return loss_reg, loss_logit, loss_repr
83
+
84
+
85
+ def get_distillation_kernel(n_classes,
+                             hidden_size,
+                             gd_size,
+                             to_idx,
+                             from_idx,
+                             gd_prior,
+                             gd_reg,
+                             w_losses,
+                             metric,
+                             alpha=1 / 8,
+                             hyp_params=None):
+     # DistillationKernel requires hyp_params, so the factory must pass it through.
+     return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx,
+                               gd_prior, gd_reg, w_losses, metric, alpha, hyp_params)
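A construction sketch mirroring how run.py instantiates the hetero kernel (dummy tensors for illustration; a CUDA device is required because the prior is moved to the GPU in __init__, and hyp_params is only stored, so None suffices here):

import torch
from trains.singleTask.distillnets.get_distillation_kernel import DistillationKernel
from trains.singleTask.misc import softmax

kernel = DistillationKernel(n_classes=1, hidden_size=100, gd_size=32,
                            to_idx=[0, 1, 2], from_idx=[0, 1, 2],
                            gd_prior=softmax([0, 0, 1, 0, 1, 1], 0.25),
                            gd_reg=10, w_losses=[1, 10], metric='l1',
                            alpha=1 / 8, hyp_params=None).cuda()

logits = torch.randn(3, 16, 1).cuda()    # (n_modalities, batch_size, n_classes)
reprs = torch.randn(3, 16, 100).cuda()   # (n_modalities, batch_size, hidden_size)
edges, edges_origin = kernel(logits, reprs)   # normalized edge weights and per-edge raw sums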
trains/singleTask/distillnets/get_distillation_kernel_homo.py ADDED
@@ -0,0 +1,100 @@
1
+ """Graph distillation for homo GD"""
2
+
3
+ import torch
4
+ from torch.autograd import Variable
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from ..utils import distance_metric, min_cosine
8
+
9
+ class DistillationKernel(nn.Module):
10
+ """Graph Distillation kernel.
11
+
12
+ Calculate the edge weights e_{j->k} for each j. Modality k is specified by
13
+ to_idx, and the other modalities are specified by from_idx.
14
+ """
15
+
16
+ def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
17
+ gd_prior, gd_reg, w_losses, metric, alpha, hyp_params):
18
+ super(DistillationKernel, self).__init__()
19
+ self.W_logit = nn.Linear(n_classes, gd_size)
20
+ self.W_repr = nn.Linear(hidden_size, gd_size)
21
+ self.W_edge = nn.Linear(gd_size * 4, 1)
22
+
23
+ self.gd_size = gd_size
24
+ self.to_idx = to_idx
25
+ self.from_idx = from_idx
26
+ self.alpha = alpha
27
+ self.gd_prior = Variable(torch.FloatTensor(gd_prior).cuda())
28
+ self.gd_reg = gd_reg
29
+ self.w_losses = w_losses
30
+ self.metric = metric
31
+ self.hyp_params = hyp_params
32
+
33
+
34
+ def forward(self, logits, reprs):
35
+ """
36
+ Args:
37
+ logits: (n_modalities, batch_size, n_classes)
38
+ reprs: (n_modalities, batch_size, hidden_size)
39
+ Return:
40
+ edges: weights e_{j->k} (n_modalities_from, batch_size)
41
+ """
42
+ n_modalities, batch_size = logits.size()[:2]
43
+ z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
44
+ z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
45
+ z = torch.cat(
46
+ (z_logits, z_reprs), dim=1).view(n_modalities, batch_size,
47
+ self.gd_size * 2)
48
+
49
+
50
+ edges = []
51
+ for j in self.to_idx:
52
+ for i in self.from_idx:
53
+ if i == j:
54
+ continue
55
+ else:
56
+ # To calculate e_{j->k}, concatenate z^j, z^k
57
+ e = self.W_edge(torch.cat((z[j], z[i]), dim=1))
58
+ edges.append(e)
59
+ edges = torch.cat(edges, dim=1)
60
+ edges_origin = edges.sum(0).unsqueeze(0).transpose(0, 1) # original value of edges
61
+ edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1) # normalized value of edges
62
+ return edges, edges_origin
63
+
64
+
65
+ def distillation_loss(self, logits, reprs, edges):
66
+ """Calculate graph distillation losses, which include:
67
+ regularization loss, loss for logits, and loss for representation.
68
+ """
69
+ loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg
70
+
71
+
72
+ loss_logit, loss_repr = 0, 0
73
+ x = 0
74
+ for j in self.to_idx:
75
+ for i, idx in enumerate(self.from_idx):
76
+ if i == j:
77
+ continue
78
+ else:
79
+ w_distill = edges[x] + self.gd_prior[x]
80
+ # print(edges.sum(1), w_distill.sum(0))
81
+ loss_logit += self.w_losses[0] * distance_metric(
82
+ logits[j], logits[idx], self.metric, w_distill)
83
+ loss_repr += self.w_losses[1] * distance_metric(
84
+ reprs[j], reprs[idx], self.metric, w_distill)
85
+ x = x + 1
86
+ return loss_reg, loss_logit, loss_repr
87
+
88
+
89
+ def get_distillation_kernel(n_classes,
+                             hidden_size,
+                             gd_size,
+                             to_idx,
+                             from_idx,
+                             gd_prior,
+                             gd_reg,
+                             w_losses,
+                             metric,
+                             alpha=1 / 8,
+                             hyp_params=None):
+     # DistillationKernel requires hyp_params, so the factory must pass it through.
+     return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx,
+                               gd_prior, gd_reg, w_losses, metric, alpha, hyp_params)
trains/singleTask/misc.py ADDED
@@ -0,0 +1,196 @@
1
+ import numpy as np
2
+ from sklearn.metrics import average_precision_score
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.data
6
+
7
+ def to_numpy(array):
8
+ if isinstance(array, np.ndarray):
9
+ return array
10
+ if isinstance(array, torch.autograd.Variable):
11
+ array = array.data
12
+ if array.is_cuda:
13
+ array = array.cpu()
14
+
15
+ return array.numpy()
16
+
17
+
18
+ def squeeze(array):
19
+ if not isinstance(array, list) or len(array) > 1:
20
+ return array
21
+ else: # len(array) == 1:
22
+ return array[0]
23
+
24
+
25
+ def unsqueeze(array):
26
+ if isinstance(array, list):
27
+ return array
28
+ else:
29
+ return [array]
30
+
31
+
32
+ def is_due(*args):
33
+ """Determines whether to perform an action or not, depending on the epoch.
34
+ Used for logging, saving, learning rate decay, etc.
35
+
36
+ Args:
37
+ *args: epoch, due_at (due at epoch due_at) epoch, num_epochs,
38
+ due_every (due every due_every epochs)
39
+ step, due_every (due every due_every steps)
40
+ Returns:
41
+ due: boolean: perform action or not
42
+ """
43
+ if len(args) == 2 and isinstance(args[1], list):
44
+ epoch, due_at = args
45
+ due = epoch in due_at
46
+ elif len(args) == 3:
47
+ epoch, num_epochs, due_every = args
48
+ due = (due_every >= 0) and (epoch % due_every == 0 or epoch == num_epochs)
49
+ else:
50
+ step, due_every = args
51
+ due = (due_every > 0) and (step % due_every == 0)
52
+
53
+ return due
54
+
55
+
56
+ def softmax(w, t=1.0, axis=None):
57
+ w = np.array(w) / t
58
+ e = np.exp(w - np.amax(w, axis=axis, keepdims=True))
59
+ dist = e / np.sum(e, axis=axis, keepdims=True)
60
+ return dist
61
+
62
+
63
+ def min_cosine(student, teacher, option, weights=None):
64
+ cosine = torch.nn.CosineEmbeddingLoss()
65
+ dists = cosine(student, teacher.detach(), torch.tensor([-1]).cuda())
66
+ if weights is None:
67
+ dist = dists.mean()
68
+ else:
69
+ dist = (dists * weights).mean()
70
+
71
+ return dist
72
+
73
+
74
+
75
+ def distance_metric(student, teacher, option, weights=None):
76
+ """Distance metric to calculate the imitation loss.
77
+
78
+ Args:
79
+ student: batch_size x n_classes
80
+ teacher: batch_size x n_classes
81
+ option: one of [cosine, l2, l1, kl]
82
+ weights: batch_size or float
83
+
84
+ Returns:
85
+ The computed distance metric.
86
+ """
87
+ if option == 'cosine':
88
+ dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)
89
+ # dists = 1 - F.cosine_similarity(student, teacher, dim=1)
90
+ elif option == 'l2':
91
+ dists = (student-teacher.detach()).pow(2).sum(1)
92
+ elif option == 'l1':
93
+ dists = torch.abs(student-teacher.detach()).sum(1)
94
+ elif option == 'kl':
95
+ assert weights is None
96
+ T = 8
97
+ # averaged for each minibatch
98
+ dist = F.kl_div(
99
+ F.log_softmax(student / T), F.softmax(teacher.detach() / T)) * (
100
+ T * T)
101
+ return dist
102
+ else:
103
+ raise NotImplementedError
104
+
105
+ if weights is None:
106
+ dist = dists.mean()
107
+ else:
108
+ dist = (dists * weights).mean()
109
+
110
+ return dist
111
+
112
+
113
+ def get_segments(input, timestep):
114
+ """Split entire input into segments of length timestep.
115
+
116
+ Args:
117
+ input: 1 x total_length x n_frames x ...
118
+ timestep: length of each segment along the time axis.
119
+
120
+ Returns:
121
+ input: concatenated video segments
122
+ start_indices: indices of the segments
123
+ """
124
+ assert input.size(0) == 1, 'Test time, batch_size must be 1'
125
+
126
+ input.squeeze_(dim=0)
127
+ # Find overlapping segments
128
+ length = input.size()[0]
129
+ step = timestep // 2
130
+ num_segments = (length - timestep) // step + 1
131
+ start_indices = (np.arange(num_segments) * step).tolist()
132
+ if length % step > 0:
133
+ start_indices.append(length - timestep)
134
+
135
+ # Get the segments
136
+ segments = []
137
+ for s in start_indices:
138
+ segment = input[s: (s + timestep)].unsqueeze(0)
139
+ segments.append(segment)
140
+ input = torch.cat(segments, dim=0)
141
+ return input, start_indices
142
+
143
+ def get_stats(logit, label):
144
+ '''
145
+ Calculate the accuracy.
146
+ '''
147
+ logit = to_numpy(logit)
148
+ label = to_numpy(label)
149
+
150
+ pred = np.argmax(logit, 1)
151
+ acc = np.sum(pred == label)/label.shape[0]
152
+
153
+ return acc, pred, label
154
+
155
+
156
+ def get_stats_detection(logit, label, n_classes=52):
157
+ '''
158
+ Calculate the accuracy and average precisions.
159
+ '''
160
+ logit = to_numpy(logit)
161
+ label = to_numpy(label)
162
+ scores = softmax(logit, axis=1)
163
+
164
+ pred = np.argmax(logit, 1)
165
+ length = label.shape[0]
166
+ acc = np.sum(pred == label)/length
167
+
168
+ keep_bg = label == 0
169
+ acc_bg = np.sum(pred[keep_bg] == label[keep_bg])/label[keep_bg].shape[0]
170
+ ratio_bg = np.sum(keep_bg)/length
171
+
172
+ keep_action = label != 0
173
+ acc_action = np.sum(
174
+ pred[keep_action] == label[keep_action]) / label[keep_action].shape[0]
175
+
176
+ # Average precision
177
+ y_true = np.zeros((len(label), n_classes))
178
+ y_true[np.arange(len(label)), label] = 1
179
+ acc = np.sum(pred == label)/label.shape[0]
180
+ aps = average_precision_score(y_true, scores, average=None)
181
+ aps = list(filter(lambda x: not np.isnan(x), aps))
182
+ ap = np.mean(aps)
183
+
184
+ return ap, acc, acc_bg, acc_action, ratio_bg, pred, label
185
+
186
+
187
+ def info(text):
188
+ print('\033[94m' + text + '\033[0m')
189
+
190
+
191
+ def warn(text):
192
+ print('\033[93m' + text + '\033[0m')
193
+
194
+
195
+ def err(text):
196
+ print('\033[91m' + text + '\033[0m')
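Two of these helpers are used directly by run.py and the distillation kernels; a quick self-contained sketch with dummy tensors (CPU only, illustrative values):

import torch
from trains.singleTask.misc import softmax, distance_metric

# Temperature-scaled softmax used to build the graph-distillation priors.
print(softmax([0, 0, 1, 0, 1, 1], t=0.25))          # six weights summing to 1

# Weighted imitation distance between "student" and "teacher" predictions.
student, teacher = torch.randn(16, 1), torch.randn(16, 1)
print(distance_metric(student, teacher, option='l1', weights=torch.rand(16)))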
trains/singleTask/model/DLF.py ADDED
@@ -0,0 +1,345 @@
1
+ """
2
+ Here is the main backbone for DLF.
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from ...subNets import BertTextEncoder
8
+ from ...subNets.transformers_encoder.transformer import TransformerEncoder
9
+
10
+ class DLF(nn.Module):
11
+ def __init__(self, args):
12
+ super(DLF, self).__init__()
13
+ if args.use_bert:
14
+ self.text_model = BertTextEncoder(use_finetune=args.use_finetune, transformers=args.transformers,
15
+ pretrained=args.pretrained)
16
+ self.use_bert = args.use_bert
17
+ dst_feature_dims, nheads = args.dst_feature_dim_nheads
18
+ if args.dataset_name == 'mosi':
19
+ if args.need_data_aligned:
20
+ self.len_l, self.len_v, self.len_a = 50, 50, 50
21
+ else:
22
+ self.len_l, self.len_v, self.len_a = 50, 500, 375
23
+ if args.dataset_name == 'mosei':
24
+ if args.need_data_aligned:
25
+ self.len_l, self.len_v, self.len_a = 50, 50, 50
26
+ else:
27
+ self.len_l, self.len_v, self.len_a = 50, 500, 500
28
+ self.orig_d_l, self.orig_d_a, self.orig_d_v = args.feature_dims
29
+ self.d_l = self.d_a = self.d_v = dst_feature_dims
30
+ self.num_heads = nheads
31
+ self.layers = args.nlevels
32
+ self.attn_dropout = args.attn_dropout
33
+ self.attn_dropout_a = args.attn_dropout_a
34
+ self.attn_dropout_v = args.attn_dropout_v
35
+ self.relu_dropout = args.relu_dropout
36
+ self.embed_dropout = args.embed_dropout
37
+ self.res_dropout = args.res_dropout
38
+ self.output_dropout = args.output_dropout
39
+ self.text_dropout = args.text_dropout
40
+ self.attn_mask = args.attn_mask
41
+ combined_dim_low = self.d_a
42
+ combined_dim_high = self.d_a
43
+ combined_dim = (self.d_l + self.d_a + self.d_v ) + self.d_l * 3
44
+
45
+ output_dim = 1
46
+
47
+ # 1. Temporal convolutional layers for initial feature
48
+ self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=args.conv1d_kernel_size_l, padding=0, bias=False)
49
+ self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=args.conv1d_kernel_size_a, padding=0, bias=False)
50
+ self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=args.conv1d_kernel_size_v, padding=0, bias=False)
51
+
52
+ # 2. Modality-specific encoder
53
+ self.encoder_s_l = self.get_network(self_type='l', layers = self.layers)
54
+ self.encoder_s_v = self.get_network(self_type='v', layers = self.layers)
55
+ self.encoder_s_a = self.get_network(self_type='a', layers = self.layers)
56
+
57
+ # Modality-shared encoder
58
+ self.encoder_c = self.get_network(self_type='l', layers = self.layers)
59
+
60
+
61
+ # 3. Decoder for reconstruct three modalities
62
+ self.decoder_l = nn.Conv1d(self.d_l * 2, self.d_l, kernel_size=1, padding=0, bias=False)
63
+ self.decoder_v = nn.Conv1d(self.d_v * 2, self.d_v, kernel_size=1, padding=0, bias=False)
64
+ self.decoder_a = nn.Conv1d(self.d_a * 2, self.d_a, kernel_size=1, padding=0, bias=False)
65
+
66
+ # for calculate cosine sim between s_x
67
+ self.proj_cosine_l = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
68
+ self.proj_cosine_v = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
69
+ self.proj_cosine_a = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)
70
+
71
+ # for align c_l, c_v, c_a
72
+ self.align_c_l = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
73
+ self.align_c_v = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
74
+ self.align_c_a = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)
75
+
76
+ self.self_attentions_c_l = self.get_network(self_type='l')
77
+ self.self_attentions_c_v = self.get_network(self_type='v')
78
+ self.self_attentions_c_a = self.get_network(self_type='a')
79
+
80
+ self.proj1_c = nn.Linear(self.d_l * 3, self.d_l * 3)
81
+ self.proj2_c = nn.Linear(self.d_l * 3, self.d_l * 3)
82
+ self.out_layer_c = nn.Linear(self.d_l * 3, output_dim)
83
+
84
+
85
+ # 4 Multimodal Crossmodal Attentions
86
+ self.trans_l_with_a = self.get_network(self_type='la', layers = self.layers)
87
+ self.trans_l_with_v = self.get_network(self_type='lv', layers = self.layers)
88
+ self.trans_a_with_l = self.get_network(self_type='al')
89
+ self.trans_a_with_v = self.get_network(self_type='av')
90
+ self.trans_v_with_l = self.get_network(self_type='vl')
91
+ self.trans_v_with_a = self.get_network(self_type='va')
92
+ self.trans_l_mem = self.get_network(self_type='l_mem', layers=self.layers)
93
+ self.trans_a_mem = self.get_network(self_type='a_mem', layers=3)
94
+ self.trans_v_mem = self.get_network(self_type='v_mem', layers=3)
95
+
96
+
97
+ # 5. fc layers for shared features
98
+ self.proj1_l_low = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
99
+ self.proj2_l_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1))
100
+ self.out_layer_l_low = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), output_dim)
101
+ self.proj1_v_low = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
102
+ self.proj2_v_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1))
103
+ self.out_layer_v_low = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), output_dim)
104
+ self.proj1_a_low = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)
105
+ self.proj2_a_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1))
106
+ self.out_layer_a_low = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), output_dim)
107
+
108
+
109
+ # 6. fc layers for specific features
110
+ self.proj1_l_high = nn.Linear(combined_dim_high, combined_dim_high)
111
+ self.proj2_l_high = nn.Linear(combined_dim_high, combined_dim_high)
112
+ self.out_layer_l_high = nn.Linear(combined_dim_high, output_dim)
113
+ self.proj1_v_high = nn.Linear(combined_dim_high, combined_dim_high)
114
+ self.proj2_v_high = nn.Linear(combined_dim_high, combined_dim_high)
115
+ self.out_layer_v_high = nn.Linear(combined_dim_high, output_dim)
116
+ self.proj1_a_high = nn.Linear(combined_dim_high, combined_dim_high)
117
+ self.proj2_a_high = nn.Linear(combined_dim_high, combined_dim_high)
118
+ self.out_layer_a_high = nn.Linear(combined_dim_high, output_dim)
119
+
120
+ # 7. project for fusion
121
+ self.projector_l = nn.Linear(self.d_l, self.d_l)
122
+ self.projector_v = nn.Linear(self.d_v, self.d_v)
123
+ self.projector_a = nn.Linear(self.d_a, self.d_a)
124
+ self.projector_c = nn.Linear(3 * self.d_l, 3 * self.d_l)
125
+
126
+ # 8. final project
127
+ self.proj1 = nn.Linear(combined_dim, combined_dim)
128
+ self.proj2 = nn.Linear(combined_dim, combined_dim)
129
+ self.out_layer = nn.Linear(combined_dim, output_dim)
130
+
131
+ def get_network(self, self_type='l', layers=-1):
132
+ if self_type in ['l', 'al', 'vl']:
133
+ embed_dim, attn_dropout = self.d_l, self.attn_dropout
134
+ elif self_type in ['a', 'la', 'va']:
135
+ embed_dim, attn_dropout = self.d_a, self.attn_dropout_a
136
+ elif self_type in ['v', 'lv', 'av']:
137
+ embed_dim, attn_dropout = self.d_v, self.attn_dropout_v
138
+ elif self_type == 'l_mem':
139
+ embed_dim, attn_dropout = self.d_l, self.attn_dropout
140
+ elif self_type == 'a_mem':
141
+ embed_dim, attn_dropout = self.d_a, self.attn_dropout
142
+ elif self_type == 'v_mem':
143
+ embed_dim, attn_dropout = self.d_v, self.attn_dropout
144
+ else:
145
+ raise ValueError("Unknown network type")
146
+
147
+ return TransformerEncoder(embed_dim=embed_dim,
148
+ num_heads=self.num_heads,
149
+ layers=max(self.layers, layers),
150
+ attn_dropout=attn_dropout,
151
+ relu_dropout=self.relu_dropout,
152
+ res_dropout=self.res_dropout,
153
+ embed_dropout=self.embed_dropout,
154
+ attn_mask=self.attn_mask)
155
+
156
+
157
+ def forward(self, text, audio, video):
158
+ #extraction
159
+ if self.use_bert:
160
+ text = self.text_model(text)
161
+ x_l = F.dropout(text.transpose(1, 2), p=self.text_dropout, training=self.training)
162
+ x_a = audio.transpose(1, 2)
163
+ x_v = video.transpose(1, 2)
164
+
165
+
166
+ proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l)
167
+ proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a)
168
+ proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v)
169
+
170
+ proj_x_l = proj_x_l.permute(2, 0, 1)
171
+ proj_x_v = proj_x_v .permute(2, 0, 1)
172
+ proj_x_a = proj_x_a.permute(2, 0, 1)
173
+
174
+ #disentanglement
175
+ s_l = self.encoder_s_l(proj_x_l)
176
+ s_v = self.encoder_s_v(proj_x_v)
177
+ s_a = self.encoder_s_a(proj_x_a)
178
+
179
+ c_l = self.encoder_c(proj_x_l)
180
+ c_v = self.encoder_c(proj_x_v)
181
+ c_a = self.encoder_c(proj_x_a)
182
+
183
+
184
+ s_l = s_l.permute(1, 2, 0)
185
+ s_v = s_v.permute(1, 2, 0)
186
+ s_a = s_a.permute(1, 2, 0)
187
+
188
+ c_l = c_l.permute(1, 2, 0)
189
+ c_v = c_v.permute(1, 2, 0)
190
+ c_a = c_a.permute(1, 2, 0)
191
+ c_list = [c_l, c_v, c_a]
192
+
193
+
194
+ c_l_sim = self.align_c_l(c_l.contiguous().view(x_l.size(0), -1))
195
+ c_v_sim = self.align_c_v(c_v.contiguous().view(x_l.size(0), -1))
196
+ c_a_sim = self.align_c_a(c_a.contiguous().view(x_l.size(0), -1))
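+ # flattened common features projected by align_c_*; presumably consumed by a similarity/alignment loss in the trainer (assumption)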
197
+
198
+ recon_l = self.decoder_l(torch.cat([s_l, c_list[0]], dim=1))
199
+ recon_v = self.decoder_v(torch.cat([s_v, c_list[1]], dim=1))
200
+ recon_a = self.decoder_a(torch.cat([s_a, c_list[2]], dim=1))
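+ # each decoder reconstructs its modality from [specific; common] features; the reconstructions are re-encoded below into s_*_r, presumably for a reconstruction/consistency objective (assumption)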
201
+
202
+ recon_l = recon_l.permute(2, 0, 1)
203
+ recon_v = recon_v.permute(2, 0, 1)
204
+ recon_a = recon_a.permute(2, 0, 1)
205
+
206
+ s_l_r = self.encoder_s_l(recon_l).permute(1, 2, 0)
207
+ s_v_r = self.encoder_s_v(recon_v).permute(1, 2, 0)
208
+ s_a_r = self.encoder_s_a(recon_a).permute(1, 2, 0)
209
+
210
+ s_l = s_l.permute(2, 0, 1)
211
+ s_v = s_v.permute(2, 0, 1)
212
+ s_a = s_a.permute(2, 0, 1)
213
+
214
+ c_l = c_l.permute(2, 0, 1)
215
+ c_v = c_v.permute(2, 0, 1)
216
+ c_a = c_a.permute(2, 0, 1)
217
+
218
+ #enhancement
219
+ hs_l_low = c_l.transpose(0, 1).contiguous().view(x_l.size(0), -1)
220
+ repr_l_low = self.proj1_l_low(hs_l_low)
221
+ hs_proj_l_low = self.proj2_l_low(
222
+ F.dropout(F.relu(repr_l_low, inplace=True), p=self.output_dropout, training=self.training))
223
+ hs_proj_l_low += hs_l_low
224
+ logits_l_low = self.out_layer_l_low(hs_proj_l_low)
225
+
226
+ hs_v_low = c_v.transpose(0, 1).contiguous().view(x_v.size(0), -1)
227
+ repr_v_low = self.proj1_v_low(hs_v_low)
228
+ hs_proj_v_low = self.proj2_v_low(
229
+ F.dropout(F.relu(repr_v_low, inplace=True), p=self.output_dropout, training=self.training))
230
+ hs_proj_v_low += hs_v_low
231
+ logits_v_low = self.out_layer_v_low(hs_proj_v_low)
232
+
233
+ hs_a_low = c_a.transpose(0, 1).contiguous().view(x_a.size(0), -1)
234
+ repr_a_low = self.proj1_a_low(hs_a_low)
235
+ hs_proj_a_low = self.proj2_a_low(
236
+ F.dropout(F.relu(repr_a_low, inplace=True), p=self.output_dropout, training=self.training))
237
+ hs_proj_a_low += hs_a_low
238
+ logits_a_low = self.out_layer_a_low(hs_proj_a_low)
239
+
240
+
241
+ c_l_att = self.self_attentions_c_l(c_l)
242
+ if type(c_l_att) == tuple:
243
+ c_l_att = c_l_att[0]
244
+ c_l_att = c_l_att[-1]
245
+
246
+ c_v_att = self.self_attentions_c_v(c_v)
247
+ if type(c_v_att) == tuple:
248
+ c_v_att = c_v_att[0]
249
+ c_v_att = c_v_att[-1]
250
+
251
+ c_a_att = self.self_attentions_c_a(c_a)
252
+ if type(c_a_att) == tuple:
253
+ c_a_att = c_a_att[0]
254
+ c_a_att = c_a_att[-1]
255
+
256
+ c_fusion = torch.cat([c_l_att, c_v_att, c_a_att], dim=1)
257
+
258
+ c_proj = self.proj2_c(
259
+ F.dropout(F.relu(self.proj1_c(c_fusion), inplace=True), p=self.output_dropout,
260
+ training=self.training))
261
+ c_proj += c_fusion
262
+ logits_c = self.out_layer_c(c_proj)
263
+
264
+ # LFA: cross-modal attention that uses the language-specific features (s_l) as queries
265
+ # L --> L
266
+ h_ls = s_l
267
+ h_ls = self.trans_l_mem(h_ls)
268
+ if type(h_ls) == tuple:
269
+ h_ls = h_ls[0]
270
+ last_h_l = last_hs = h_ls[-1]
271
+
272
+ # A --> L
273
+ h_l_with_as = self.trans_l_with_a(s_l, s_a, s_a)
274
+ h_as = h_l_with_as
275
+ h_as = self.trans_a_mem(h_as)
276
+ if type(h_as) == tuple:
277
+ h_as = h_as[0]
278
+ last_h_a = last_hs = h_as[-1]
279
+
280
+ # V --> L
281
+ h_l_with_vs = self.trans_l_with_v(s_l, s_v, s_v)
282
+ h_vs = h_l_with_vs
283
+ h_vs = self.trans_v_mem(h_vs)
284
+ if type(h_vs) == tuple:
285
+ h_vs = h_vs[0]
286
+ last_h_v = last_hs = h_vs[-1]
287
+
288
+
289
+ hs_proj_l_high = self.proj2_l_high(
290
+ F.dropout(F.relu(self.proj1_l_high(last_h_l), inplace=True), p=self.output_dropout, training=self.training))
291
+ hs_proj_l_high += last_h_l
292
+ logits_l_high = self.out_layer_l_high(hs_proj_l_high)
293
+
294
+ hs_proj_v_high = self.proj2_v_high(
295
+ F.dropout(F.relu(self.proj1_v_high(last_h_v), inplace=True), p=self.output_dropout, training=self.training))
296
+ hs_proj_v_high += last_h_v
297
+ logits_v_high = self.out_layer_v_high(hs_proj_v_high)
298
+
299
+ hs_proj_a_high = self.proj2_a_high(
300
+ F.dropout(F.relu(self.proj1_a_high(last_h_a), inplace=True), p=self.output_dropout,
301
+ training=self.training))
302
+ hs_proj_a_high += last_h_a
303
+ logits_a_high = self.out_layer_a_high(hs_proj_a_high)
304
+
305
+ #fusion
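+ # sigmoid-gate each high-level representation and the common fusion vector before concatenating them for the final prediction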
306
+ last_h_l = torch.sigmoid(self.projector_l(hs_proj_l_high))
307
+ last_h_v = torch.sigmoid(self.projector_v(hs_proj_v_high))
308
+ last_h_a = torch.sigmoid(self.projector_a(hs_proj_a_high))
309
+ c_fusion = torch.sigmoid(self.projector_c(c_fusion))
310
+
311
+ last_hs = torch.cat([last_h_l, last_h_v, last_h_a, c_fusion], dim=1)
312
+
313
+ #prediction
314
+ last_hs_proj = self.proj2(
315
+ F.dropout(F.relu(self.proj1(last_hs), inplace=True), p=self.output_dropout, training=self.training))
316
+ last_hs_proj += last_hs
317
+
318
+ output = self.out_layer(last_hs_proj)
319
+
320
+ res = {
321
+ 'origin_l': proj_x_l,
322
+ 'origin_v': proj_x_v,
323
+ 'origin_a': proj_x_a,
324
+ 's_l': s_l,
325
+ 's_v': s_v,
326
+ 's_a': s_a,
327
+ 'c_l': c_l,
328
+ 'c_v': c_v,
329
+ 'c_a': c_a,
330
+ 's_l_r': s_l_r,
331
+ 's_v_r': s_v_r,
332
+ 's_a_r': s_a_r,
333
+ 'recon_l': recon_l,
334
+ 'recon_v': recon_v,
335
+ 'recon_a': recon_a,
336
+ 'c_l_sim': c_l_sim,
337
+ 'c_v_sim': c_v_sim,
338
+ 'c_a_sim': c_a_sim,
339
+ 'logits_l_hetero': logits_l_high,
340
+ 'logits_v_hetero': logits_v_high,
341
+ 'logits_a_hetero': logits_a_high,
342
+ 'logits_c': logits_c,
343
+ 'output_logit': output
344
+ }
345
+ return res
trains/singleTask/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .misc import *
trains/singleTask/utils/misc.py ADDED
@@ -0,0 +1,196 @@
1
+ import numpy as np
2
+ from sklearn.metrics import average_precision_score
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.utils.data
6
+
7
+ def to_numpy(array):
8
+ if isinstance(array, np.ndarray):
9
+ return array
10
+ if isinstance(array, torch.autograd.Variable):
11
+ array = array.data
12
+ if array.is_cuda:
13
+ array = array.cpu()
14
+
15
+ return array.numpy()
16
+
17
+
18
+ def squeeze(array):
19
+ if not isinstance(array, list) or len(array) > 1:
20
+ return array
21
+ else: # len(array) == 1:
22
+ return array[0]
23
+
24
+
25
+ def unsqueeze(array):
26
+ if isinstance(array, list):
27
+ return array
28
+ else:
29
+ return [array]
30
+
31
+
32
+ def is_due(*args):
33
+ """Determines whether to perform an action or not, depending on the epoch.
34
+ Used for logging, saving, learning rate decay, etc.
35
+
36
+ Args:
37
+ *args: one of three call signatures:
38
+ (epoch, due_at): due when epoch is in the list due_at
+ (epoch, num_epochs, due_every): due every due_every epochs, or at the final epoch
39
+ (step, due_every): due every due_every steps
40
+ Returns:
41
+ due: boolean: perform action or not
42
+ """
43
+ if len(args) == 2 and isinstance(args[1], list):
44
+ epoch, due_at = args
45
+ due = epoch in due_at
46
+ elif len(args) == 3:
47
+ epoch, num_epochs, due_every = args
48
+ due = (due_every >= 0) and (epoch % due_every == 0 or epoch == num_epochs)
49
+ else:
50
+ step, due_every = args
51
+ due = (due_every > 0) and (step % due_every == 0)
52
+
53
+ return due
54
+
55
+
56
+ def softmax(w, t=1.0, axis=None):
57
+ w = np.array(w) / t
58
+ e = np.exp(w - np.amax(w, axis=axis, keepdims=True))
59
+ dist = e / np.sum(e, axis=axis, keepdims=True)
60
+ return dist
61
+
62
+ def min_cosine(student, teacher, option, weights=None):
63
+ cosine = torch.nn.CosineEmbeddingLoss()
64
+ dists = cosine(student, teacher.detach(), torch.tensor([-1]).cuda())
65
+ if weights is None:
66
+ dist = dists.mean()
67
+ else:
68
+ dist = (dists * weights).mean()
69
+
70
+ return dist
71
+
72
+
73
+ def distance_metric(student, teacher, option, weights=None):
74
+ """Distance metric to calculate the imitation loss.
75
+
76
+ Args:
77
+ student: batch_size x n_classes
78
+ teacher: batch_size x n_classes
79
+ option: one of [cosine, l2, l1, kl]
80
+ weights: batch_size or float
81
+
82
+ Returns:
83
+ The computed distance metric.
84
+ """
85
+ if option == 'cosine':
86
+ dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)
87
+ # dists = 1 - F.cosine_similarity(student, teacher, dim=1)
88
+ elif option == 'l2':
89
+ dists = (student-teacher.detach()).pow(2).sum(1)
90
+ elif option == 'l1':
91
+ dists = torch.abs(student-teacher.detach()).sum(1)
92
+ elif option == 'kl':
93
+ # assert weights is None
94
+ T = 8
95
+ # averaged for each minibatch
96
+ dist = F.kl_div(
97
+ F.log_softmax(student / T, dim=1), F.softmax(teacher.detach() / T, dim=1)) * (
98
+ T * T)
99
+ return dist
100
+ else:
101
+ raise NotImplementedError
102
+
103
+ if weights is None:
104
+ dist = dists.mean()
105
+ else:
106
+ dist = (dists * weights).mean()
107
+
108
+ return dist
109
+
110
+
111
+ def get_segments(input, timestep):
112
+ """Split entire input into segments of length timestep.
113
+
114
+ Args:
115
+ input: 1 x total_length x n_frames x ...
116
+ timestep: the length (number of frames) of each segment.
117
+
118
+ Returns:
119
+ input: concatenated video segments
120
+ start_indices: indices of the segments
121
+ """
122
+ assert input.size(0) == 1, 'Test time, batch_size must be 1'
123
+
124
+ input.squeeze_(dim=0)
125
+ # Find overlapping segments
126
+ length = input.size()[0]
127
+ step = timestep // 2
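+ # consecutive segments overlap by 50% (stride = timestep // 2)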
128
+ num_segments = (length - timestep) // step + 1
129
+ start_indices = (np.arange(num_segments) * step).tolist()
130
+ if length % step > 0:
131
+ start_indices.append(length - timestep)
132
+
133
+ # Get the segments
134
+ segments = []
135
+ for s in start_indices:
136
+ segment = input[s: (s + timestep)].unsqueeze(0)
137
+ segments.append(segment)
138
+ input = torch.cat(segments, dim=0)
139
+ return input, start_indices
140
+
141
+ def get_stats(logit, label):
142
+ '''
143
+ Calculate the accuracy.
144
+ '''
145
+ logit = to_numpy(logit)
146
+ label = to_numpy(label)
147
+
148
+ pred = np.argmax(logit, 1)
149
+ acc = np.sum(pred == label)/label.shape[0]
150
+
151
+ return acc, pred, label
152
+
153
+
154
+ def get_stats_detection(logit, label, n_classes=52):
155
+ '''
156
+ Calculate the accuracy and average precisions.
157
+ '''
158
+ logit = to_numpy(logit)
159
+ label = to_numpy(label)
160
+ scores = softmax(logit, axis=1)
161
+
162
+ pred = np.argmax(logit, 1)
163
+ length = label.shape[0]
164
+ acc = np.sum(pred == label)/length
165
+
166
+ keep_bg = label == 0
167
+ acc_bg = np.sum(pred[keep_bg] == label[keep_bg])/label[keep_bg].shape[0]
168
+ ratio_bg = np.sum(keep_bg)/length
169
+
170
+ keep_action = label != 0
171
+ acc_action = np.sum(
172
+ pred[keep_action] == label[keep_action]) / label[keep_action].shape[0]
173
+
174
+ # Average precision
175
+ y_true = np.zeros((len(label), n_classes))
176
+ y_true[np.arange(len(label)), label] = 1
177
+ acc = np.sum(pred == label)/label.shape[0]
178
+ aps = average_precision_score(y_true, scores, average=None)
179
+ aps = list(filter(lambda x: not np.isnan(x), aps))
180
+ ap = np.mean(aps)
181
+
182
+ return ap, acc, acc_bg, acc_action, ratio_bg, pred, label
183
+
184
+
185
+ def info(text):
186
+ print('\033[94m' + text + '\033[0m')
187
+
188
+
189
+ def warn(text):
190
+ print('\033[93m' + text + '\033[0m')
191
+
192
+
193
+ def err(text):
194
+ print('\033[91m' + text + '\033[0m')
195
+
196
+
trains/subNets/AlignNets.py ADDED
@@ -0,0 +1,106 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ __all__ = ['AlignSubNet']
5
+
6
+ class CTCModule(nn.Module):
7
+ def __init__(self, in_dim, out_seq_len):
8
+ '''
9
+ This module is performing alignment from A (e.g., audio) to B (e.g., text).
10
+ :param in_dim: Dimension for input modality A
11
+ :param out_seq_len: Sequence length for output modality B
12
+ From: https://github.com/yaohungt/Multimodal-Transformer
13
+ '''
14
+ super(CTCModule, self).__init__()
15
+ # Use LSTM for predicting the position from A to B
16
+ self.pred_output_position_inclu_blank = nn.LSTM(in_dim, out_seq_len+1, num_layers=2, batch_first=True) # 1 denoting blank
17
+
18
+ self.out_seq_len = out_seq_len
19
+
20
+ self.softmax = nn.Softmax(dim=2)
21
+
22
+ def forward(self, x):
23
+ '''
24
+ :input x: Input with shape [batch_size x in_seq_len x in_dim]
25
+ '''
26
+ # NOTE that the index 0 refers to blank.
27
+ pred_output_position_inclu_blank, _ = self.pred_output_position_inclu_blank(x)
28
+
29
+ prob_pred_output_position_inclu_blank = self.softmax(pred_output_position_inclu_blank) # batch_size x in_seq_len x out_seq_len+1
30
+ prob_pred_output_position = prob_pred_output_position_inclu_blank[:, :, 1:] # batch_size x in_seq_len x out_seq_len
31
+ prob_pred_output_position = prob_pred_output_position.transpose(1,2) # batch_size x out_seq_len x in_seq_len
32
+ pseudo_aligned_out = torch.bmm(prob_pred_output_position, x) # batch_size x out_seq_len x in_dim
33
+
34
+ # pseudo_aligned_out is regarded as the aligned A (w.r.t B)
35
+ # return pseudo_aligned_out, (pred_output_position_inclu_blank)
36
+ return pseudo_aligned_out
37
+
38
+ class AlignSubNet(nn.Module):
39
+ def __init__(self, args, mode):
40
+ """
41
+ mode: the way of aligning
42
+ avg_pool, ctc, conv1d
43
+ """
44
+ super(AlignSubNet, self).__init__()
45
+ assert mode in ['avg_pool', 'ctc', 'conv1d']
46
+
47
+ in_dim_t, in_dim_a, in_dim_v = args.feature_dims
48
+ seq_len_t, seq_len_a, seq_len_v = args.seq_lens
49
+ self.dst_len = seq_len_t
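+ # all modalities are aligned to the text sequence length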
50
+ self.mode = mode
51
+
52
+ self.ALIGN_WAY = {
53
+ 'avg_pool': self.__avg_pool,
54
+ 'ctc': self.__ctc,
55
+ 'conv1d': self.__conv1d
56
+ }
57
+
58
+ if mode == 'conv1d':
59
+ self.conv1d_T = nn.Conv1d(seq_len_t, self.dst_len, kernel_size=1, bias=False)
60
+ self.conv1d_A = nn.Conv1d(seq_len_a, self.dst_len, kernel_size=1, bias=False)
61
+ self.conv1d_V = nn.Conv1d(seq_len_v, self.dst_len, kernel_size=1, bias=False)
62
+ elif mode == 'ctc':
63
+ self.ctc_t = CTCModule(in_dim_t, self.dst_len)
64
+ self.ctc_a = CTCModule(in_dim_a, self.dst_len)
65
+ self.ctc_v = CTCModule(in_dim_v, self.dst_len)
66
+
67
+ def get_seq_len(self):
68
+ return self.dst_len
69
+
70
+ def __ctc(self, text_x, audio_x, video_x):
71
+ text_x = self.ctc_t(text_x) if text_x.size(1) != self.dst_len else text_x
72
+ audio_x = self.ctc_a(audio_x) if audio_x.size(1) != self.dst_len else audio_x
73
+ video_x = self.ctc_v(video_x) if video_x.size(1) != self.dst_len else video_x
74
+ return text_x, audio_x, video_x
75
+
76
+ def __avg_pool(self, text_x, audio_x, video_x):
77
+ def align(x):
78
+ raw_seq_len = x.size(1)
79
+ if raw_seq_len == self.dst_len:
80
+ return x
81
+ if raw_seq_len // self.dst_len == raw_seq_len / self.dst_len:
82
+ pad_len = 0
83
+ pool_size = raw_seq_len // self.dst_len
84
+ else:
85
+ pad_len = self.dst_len - raw_seq_len % self.dst_len
86
+ pool_size = raw_seq_len // self.dst_len + 1
87
+ pad_x = x[:, -1, :].unsqueeze(1).expand([x.size(0), pad_len, x.size(-1)])
88
+ x = torch.cat([x, pad_x], dim=1).view(x.size(0), pool_size, self.dst_len, -1)
89
+ x = x.mean(dim=1)
90
+ return x
91
+ text_x = align(text_x)
92
+ audio_x = align(audio_x)
93
+ video_x = align(video_x)
94
+ return text_x, audio_x, video_x
95
+
96
+ def __conv1d(self, text_x, audio_x, video_x):
97
+ text_x = self.conv1d_T(text_x) if text_x.size(1) != self.dst_len else text_x
98
+ audio_x = self.conv1d_A(audio_x) if audio_x.size(1) != self.dst_len else audio_x
99
+ video_x = self.conv1d_V(video_x) if video_x.size(1) != self.dst_len else video_x
100
+ return text_x, audio_x, video_x
101
+
102
+ def forward(self, text_x, audio_x, video_x):
103
+ # already aligned
104
+ if text_x.size(1) == audio_x.size(1) == video_x.size(1):
105
+ return text_x, audio_x, video_x
106
+ return self.ALIGN_WAY[self.mode](text_x, audio_x, video_x)
trains/subNets/BertTextEncoder.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
4
+
5
+ __all__ = ['BertTextEncoder']
6
+
7
+ TRANSFORMERS_MAP = {
8
+ 'bert': (BertModel, BertTokenizer),
9
+ 'roberta': (RobertaModel, RobertaTokenizer),
10
+ }
11
+
12
+ class BertTextEncoder(nn.Module):
13
+ def __init__(self, use_finetune=False, transformers='bert', pretrained='bert-base-uncased'):
14
+ super().__init__()
15
+
16
+ tokenizer_class = TRANSFORMERS_MAP[transformers][1]
17
+ model_class = TRANSFORMERS_MAP[transformers][0]
18
+ self.tokenizer = tokenizer_class.from_pretrained(pretrained)
19
+ self.model = model_class.from_pretrained(pretrained)
20
+ self.use_finetune = use_finetune
21
+
22
+ def get_tokenizer(self):
23
+ return self.tokenizer
24
+
25
+ # def from_text(self, text):
26
+ # """
27
+ # text: raw data
28
+ # """
29
+ # input_ids = self.get_id(text)
30
+ # with torch.no_grad():
31
+ # last_hidden_states = self.model(input_ids)[0] # Models outputs are now tuples
32
+ # return last_hidden_states.squeeze()
33
+
34
+ def forward(self, text):
35
+ """
36
+ text: (batch_size, 3, seq_len)
37
+ 3: input_ids, input_mask, segment_ids
38
+ input_ids: input_ids,
39
+ input_mask: attention_mask,
40
+ segment_ids: token_type_ids
41
+ """
42
+ input_ids, input_mask, segment_ids = text[:,0,:].long(), text[:,1,:].float(), text[:,2,:].long()
43
+ if self.use_finetune:
44
+ last_hidden_states = self.model(input_ids=input_ids,
45
+ attention_mask=input_mask,
46
+ token_type_ids=segment_ids)[0] # Models outputs are now tuples
47
+ else:
48
+ with torch.no_grad():
49
+ last_hidden_states = self.model(input_ids=input_ids,
50
+ attention_mask=input_mask,
51
+ token_type_ids=segment_ids)[0] # Models outputs are now tuples
52
+ return last_hidden_states
trains/subNets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .BertTextEncoder import BertTextEncoder
2
+ from .AlignNets import AlignSubNet
trains/subNets/transformers_encoder/multihead_attention.py ADDED
@@ -0,0 +1,154 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+ from torch.nn import Parameter
5
+
6
+ class MultiheadAttention(nn.Module):
7
+ """Multi-headed attention.
8
+ See "Attention Is All You Need" for more details.
9
+ """
10
+
11
+ def __init__(self, embed_dim, num_heads, attn_dropout=0.,
12
+ bias=True, add_bias_kv=False, add_zero_attn=False):
13
+ super().__init__()
14
+ self.embed_dim = embed_dim
15
+ self.num_heads = num_heads
16
+ self.attn_dropout = attn_dropout
17
+ self.head_dim = embed_dim // num_heads
18
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
19
+ self.scaling = self.head_dim ** -0.5
20
+
21
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
22
+ self.register_parameter('in_proj_bias', None)
23
+ if bias:
24
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
25
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
26
+
27
+ if add_bias_kv:
28
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
29
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
30
+ else:
31
+ self.bias_k = self.bias_v = None
32
+
33
+ self.add_zero_attn = add_zero_attn
34
+
35
+ self.reset_parameters()
36
+
37
+ def reset_parameters(self):
38
+ nn.init.xavier_uniform_(self.in_proj_weight)
39
+ nn.init.xavier_uniform_(self.out_proj.weight)
40
+ if self.in_proj_bias is not None:
41
+ nn.init.constant_(self.in_proj_bias, 0.)
42
+ nn.init.constant_(self.out_proj.bias, 0.)
43
+ if self.bias_k is not None:
44
+ nn.init.xavier_normal_(self.bias_k)
45
+ if self.bias_v is not None:
46
+ nn.init.xavier_normal_(self.bias_v)
47
+
48
+ def forward(self, query, key, value, attn_mask=None):
49
+ """Input shape: Time x Batch x Channel
50
+ Self-attention can be implemented by passing in the same arguments for
51
+ query, key and value. Timesteps can be masked by supplying a T x T mask in the
52
+ `attn_mask` argument. Padding elements can be excluded from
53
+ the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
54
+ batch x src_len, where padding elements are indicated by 1s.
55
+ """
56
+ qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
57
+ kv_same = key.data_ptr() == value.data_ptr()
58
+
59
+ tgt_len, bsz, embed_dim = query.size()
60
+ assert embed_dim == self.embed_dim
61
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
62
+ assert key.size() == value.size()
63
+
64
+ aved_state = None
65
+
66
+ if qkv_same:
67
+ # self-attention
68
+ q, k, v = self.in_proj_qkv(query)
69
+ elif kv_same:
70
+ # encoder-decoder attention
71
+ q = self.in_proj_q(query)
72
+
73
+ if key is None:
74
+ assert value is None
75
+ k = v = None
76
+ else:
77
+ k, v = self.in_proj_kv(key)
78
+ else:
79
+ q = self.in_proj_q(query)
80
+ k = self.in_proj_k(key)
81
+ v = self.in_proj_v(value)
82
+ q = q * self.scaling
83
+
84
+ if self.bias_k is not None:
85
+ assert self.bias_v is not None
86
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
87
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
88
+ if attn_mask is not None:
89
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
90
+
91
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
92
+ if k is not None:
93
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
94
+ if v is not None:
95
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
96
+
97
+ src_len = k.size(1)
98
+
99
+ if self.add_zero_attn:
100
+ src_len += 1
101
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
102
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
103
+ if attn_mask is not None:
104
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
105
+
106
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
107
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
108
+
109
+ if attn_mask is not None:
110
+ try:
111
+ attn_weights += attn_mask.unsqueeze(0)
112
+ except:
113
+ print(attn_weights.shape)
114
+ print(attn_mask.unsqueeze(0).shape)
115
+ assert False
116
+
117
+ attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
118
+ # attn_weights = F.relu(attn_weights)
119
+ # attn_weights = attn_weights / torch.max(attn_weights)
120
+ attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training)
121
+
122
+ attn = torch.bmm(attn_weights, v)
123
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
124
+
125
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
126
+ attn = self.out_proj(attn)
127
+
128
+ # average attention weights over heads
129
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
130
+ attn_weights = attn_weights.sum(dim=1) / self.num_heads
131
+ return attn, attn_weights
132
+
133
+ def in_proj_qkv(self, query):
134
+ return self._in_proj(query).chunk(3, dim=-1)
135
+
136
+ def in_proj_kv(self, key):
137
+ return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)
138
+
139
+ def in_proj_q(self, query, **kwargs):
140
+ return self._in_proj(query, end=self.embed_dim, **kwargs)
141
+
142
+ def in_proj_k(self, key):
143
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
144
+
145
+ def in_proj_v(self, value):
146
+ return self._in_proj(value, start=2 * self.embed_dim)
147
+
148
+ def _in_proj(self, input, start=0, end=None, **kwargs):
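+ # slices the packed QKV projection: Q = [0, embed_dim), K = [embed_dim, 2*embed_dim), V = [2*embed_dim, 3*embed_dim)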
149
+ weight = kwargs.get('weight', self.in_proj_weight)
150
+ bias = kwargs.get('bias', self.in_proj_bias)
151
+ weight = weight[start:end, :]
152
+ if bias is not None:
153
+ bias = bias[start:end]
154
+ return F.linear(input, weight, bias)
trains/subNets/transformers_encoder/position_embedding.py ADDED
@@ -0,0 +1,77 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ def make_positions(tensor, padding_idx, left_pad):
6
+ """Replace non-padding symbols with their position numbers.
7
+ Position numbers begin at padding_idx+1.
8
+ Padding symbols are ignored, but it is necessary to specify whether padding
9
+ is added on the left side (left_pad=True) or right side (left_pad=False).
10
+ """
11
+ max_pos = padding_idx + 1 + tensor.size(1)
12
+ device = tensor.get_device()
13
+ buf_name = f'range_buf_{device}'
14
+ if not hasattr(make_positions, buf_name):
15
+ setattr(make_positions, buf_name, tensor.new())
16
+ setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor))
17
+ if getattr(make_positions, buf_name).numel() < max_pos:
18
+ torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name))
19
+ mask = tensor.ne(padding_idx)
20
+ positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor)
21
+ if left_pad:
22
+ positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
23
+ new_tensor = tensor.clone()
24
+ return new_tensor.masked_scatter_(mask, positions[mask]).long()
25
+
26
+
27
+ class SinusoidalPositionalEmbedding(nn.Module):
28
+ """This module produces sinusoidal positional embeddings of any length.
29
+ Padding symbols are ignored, but it is necessary to specify whether padding
30
+ is added on the left side (left_pad=True) or right side (left_pad=False).
31
+ """
32
+
33
+ def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128):
34
+ super().__init__()
35
+ self.embedding_dim = embedding_dim
36
+ self.padding_idx = padding_idx
37
+ self.left_pad = left_pad
38
+ self.weights = dict() # device --> actual weight; due to nn.DataParallel :-(
39
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
40
+
41
+ @staticmethod
42
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
43
+ """Build sinusoidal embeddings.
44
+ This matches the implementation in tensor2tensor, but differs slightly
45
+ from the description in Section 3.5 of "Attention Is All You Need".
46
+ """
47
+ half_dim = embedding_dim // 2
48
+ emb = math.log(10000) / (half_dim - 1)
49
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
50
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
51
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
52
+ if embedding_dim % 2 == 1:
53
+ # zero pad
54
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
55
+ if padding_idx is not None:
56
+ emb[padding_idx, :] = 0
57
+ return emb
58
+
59
+ def forward(self, input):
60
+ """Input is expected to be of size [bsz x seqlen]."""
61
+ bsz, seq_len = input.size()
62
+ max_pos = self.padding_idx + 1 + seq_len
63
+ device = input.get_device()
64
+ if device not in self.weights or max_pos > self.weights[device].size(0):
65
+ # recompute/expand embeddings if needed
66
+ self.weights[device] = SinusoidalPositionalEmbedding.get_embedding(
67
+ max_pos,
68
+ self.embedding_dim,
69
+ self.padding_idx,
70
+ )
71
+ self.weights[device] = self.weights[device].type_as(self._float_tensor).to(input.device)
72
+ positions = make_positions(input, self.padding_idx, self.left_pad)
73
+ return self.weights[device].index_select(0, positions.contiguous().view(-1)).view(bsz, seq_len, -1).detach()
74
+
75
+ def max_positions(self):
76
+ """Maximum number of supported positions."""
77
+ return int(1e5) # an arbitrary large number
trains/subNets/transformers_encoder/transformer.py ADDED
@@ -0,0 +1,205 @@
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ from .multihead_attention import MultiheadAttention
6
+ from .position_embedding import SinusoidalPositionalEmbedding
7
+
8
+ class TransformerEncoder(nn.Module):
9
+ """
10
+ Transformer encoder consisting of *args.encoder_layers* layers. Each layer
11
+ is a :class:`TransformerEncoderLayer`.
12
+ Args:
13
+ embed_tokens (torch.nn.Embedding): input embedding
14
+ num_heads (int): number of heads
15
+ layers (int): number of layers
16
+ attn_dropout (float): dropout applied on the attention weights
17
+ relu_dropout (float): dropout applied on the first layer of the residual block
18
+ res_dropout (float): dropout applied on the residual block
19
+ attn_mask (bool): whether to apply mask on the attention weights
20
+ """
21
+
22
+ def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0,
23
+ embed_dropout=0.0, attn_mask=False):
24
+ super().__init__()
25
+ self.dropout = embed_dropout # Embedding dropout
26
+ self.attn_dropout = attn_dropout
27
+ self.embed_dim = embed_dim
28
+ self.embed_scale = math.sqrt(embed_dim)
29
+ self.embed_positions = SinusoidalPositionalEmbedding(embed_dim)
30
+
31
+ self.attn_mask = attn_mask
32
+
33
+ self.layers = nn.ModuleList([]) #define multiple transformer layers
34
+ for layer in range(layers):
35
+ new_layer = TransformerEncoderLayer(embed_dim,
36
+ num_heads=num_heads,
37
+ attn_dropout=attn_dropout,
38
+ relu_dropout=relu_dropout,
39
+ res_dropout=res_dropout,
40
+ attn_mask=attn_mask)
41
+ self.layers.append(new_layer)
42
+
43
+ self.register_buffer('version', torch.Tensor([2]))
44
+ self.normalize = True
45
+ if self.normalize:
46
+ self.layer_norm = LayerNorm(embed_dim)
47
+
48
+ def forward(self, x_in, x_in_k = None, x_in_v = None):
49
+ """
50
+ Args:
51
+ x_in (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
52
+ x_in_k (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
53
+ x_in_v (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
54
+ Returns:
55
+ dict:
56
+ - **encoder_out** (Tensor): the last encoder layer's output of
57
+ shape `(src_len, batch, embed_dim)`
58
+ - **encoder_padding_mask** (ByteTensor): the positions of
59
+ padding elements of shape `(batch, src_len)`
60
+ """
61
+ # embed tokens and positions
62
+ x = self.embed_scale * x_in
63
+ #breakpoint()
64
+ if self.embed_positions is not None:
65
+ x += self.embed_positions(x_in.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
66
+ x = F.dropout(x, p=self.dropout, training=self.training)
67
+
68
+ if x_in_k is not None and x_in_v is not None:
69
+ # embed tokens and positions
70
+ x_k = self.embed_scale * x_in_k
71
+ x_v = self.embed_scale * x_in_v
72
+ if self.embed_positions is not None:
73
+ x_k += self.embed_positions(x_in_k.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
74
+ x_v += self.embed_positions(x_in_v.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
75
+ x_k = F.dropout(x_k, p=self.dropout, training=self.training)
76
+ x_v = F.dropout(x_v, p=self.dropout, training=self.training)
77
+
78
+ # encoder layers
79
+ intermediates = [x]
80
+ for layer in self.layers:
81
+ if x_in_k is not None and x_in_v is not None:
82
+ x = layer(x, x_k, x_v)
83
+ else:
84
+ x = layer(x)
85
+ intermediates.append(x)
86
+
87
+ if self.normalize:
88
+ x = self.layer_norm(x)
89
+
90
+ return x
91
+
92
+ def max_positions(self):
93
+ """Maximum input length supported by the encoder."""
94
+ if self.embed_positions is None:
95
+ return self.max_source_positions
96
+ return min(self.max_source_positions, self.embed_positions.max_positions())
97
+
98
+
99
+ class TransformerEncoderLayer(nn.Module):
100
+ """Encoder layer block.
101
+ In the original paper each operation (multi-head attention or FFN) is
102
+ postprocessed with: `dropout -> add residual -> layernorm`. In the
103
+ tensor2tensor code they suggest that learning is more robust when
104
+ preprocessing each layer with layernorm and postprocessing with:
105
+ `dropout -> add residual`. We default to the approach in the paper, but the
106
+ tensor2tensor approach can be enabled by setting
107
+ *args.encoder_normalize_before* to ``True``.
108
+ Args:
109
+ embed_dim: Embedding dimension
110
+ """
111
+
112
+ def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1,
113
+ attn_mask=False):
114
+ super().__init__()
115
+ self.embed_dim = embed_dim
116
+ self.num_heads = num_heads
117
+
118
+ self.self_attn = MultiheadAttention(
119
+ embed_dim=self.embed_dim,
120
+ num_heads=self.num_heads,
121
+ attn_dropout=attn_dropout
122
+ )
123
+ self.attn_mask = attn_mask
124
+
125
+ self.relu_dropout = relu_dropout
126
+ self.res_dropout = res_dropout
127
+ self.normalize_before = True # True applies LayerNorm before each sub-layer (tensor2tensor style)
128
+
129
+ self.fc1 = Linear(self.embed_dim, 4*self.embed_dim) # position-wise feed-forward sub-layer (expansion)
130
+ self.fc2 = Linear(4*self.embed_dim, self.embed_dim) # position-wise feed-forward sub-layer (projection back)
131
+ self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)]) # two LayerNorm layers: one for the attention sub-block, one for the FFN sub-block
132
+
133
+ def forward(self, x, x_k=None, x_v=None): # one encoder layer: attention sub-block followed by a feed-forward sub-block
134
+ """
135
+ Args:
136
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
137
139
+ x_k (Tensor): same as x
140
+ x_v (Tensor): same as x
141
+ Returns:
142
+ encoded output of shape `(seq_len, batch, embed_dim)`
143
+ """
144
+ residual = x
145
+ x = self.maybe_layer_norm(0, x, before=True)
146
+ mask = buffered_future_mask(x, x_k) if self.attn_mask else None
147
+ if x_k is None and x_v is None:
148
+ x, _ = self.self_attn(query=x, key=x, value=x, attn_mask=mask)
149
+ else:
150
+ x_k = self.maybe_layer_norm(0, x_k, before=True)
151
+ x_v = self.maybe_layer_norm(0, x_v, before=True)
152
+ x, _ = self.self_attn(query=x, key=x_k, value=x_v, attn_mask=mask)
153
+ x = F.dropout(x, p=self.res_dropout, training=self.training)
154
+ x = residual + x
155
+ x = self.maybe_layer_norm(0, x, after=True) # end of the attention sub-block
156
+
157
+ residual = x
158
+ x = self.maybe_layer_norm(1, x, before=True)
159
+ x = F.relu(self.fc1(x))
160
+ x = F.dropout(x, p=self.relu_dropout, training=self.training)
161
+ x = self.fc2(x)
162
+ x = F.dropout(x, p=self.res_dropout, training=self.training)
163
+ x = residual + x
164
+ x = self.maybe_layer_norm(1, x, after=True)
165
+ return x # end of the feed-forward sub-block
166
+
167
+ def maybe_layer_norm(self, i, x, before=False, after=False):
168
+ assert before ^ after # before XOR after: exactly one must be True
169
+ if after ^ self.normalize_before:
170
+ return self.layer_norms[i](x)
171
+ else:
172
+ return x
173
+
174
+ def fill_with_neg_inf(t):
175
+ """FP16-compatible function that fills a tensor with -inf."""
176
+ return t.float().fill_(float('-inf')).type_as(t)
177
+
178
+
179
+ def buffered_future_mask(tensor, tensor2=None):
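+ # upper-triangular -inf mask: position i may only attend to positions <= i (offset by the length difference when query and key lengths differ)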
180
+ dim1 = dim2 = tensor.size(0)
181
+ if tensor2 is not None:
182
+ dim2 = tensor2.size(0)
183
+ future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), 1+abs(dim2-dim1))
184
+ if tensor.is_cuda:
185
+ future_mask = future_mask.to(tensor.device)
186
+ return future_mask[:dim1, :dim2]
187
+
188
+
189
+ def Linear(in_features, out_features, bias=True):
190
+ m = nn.Linear(in_features, out_features, bias)
191
+ nn.init.xavier_uniform_(m.weight)
192
+ if bias:
193
+ nn.init.constant_(m.bias, 0.)
194
+ return m
195
+
196
+
197
+ def LayerNorm(embedding_dim):
198
+ m = nn.LayerNorm(embedding_dim)
199
+ return m
200
+
201
+
202
+ if __name__ == '__main__':
203
+ encoder = TransformerEncoder(300, 4, 2) #embed_dim, num_heads, layers
204
+ x = torch.rand(20, 2, 300) # wrapping an existing tensor in torch.tensor() is unnecessary and triggers a warning
205
+ print(encoder(x).shape)
trains/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .functions import dict_to_str, setup_seed, assign_gpu, count_parameters
2
+ from .metricsTop import MetricsTop
trains/utils/functions.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import pynvml
5
+ import logging
6
+
7
+
8
+ logger = logging.getLogger('MMSA')
9
+
10
+
11
+ def dict_to_str(src_dict):
12
+ dst_str = ""
13
+ for key in src_dict.keys():
14
+ dst_str += " %s: %.4f " %(key, src_dict[key])
15
+ return dst_str
16
+
17
+ def setup_seed(seed):
18
+ torch.manual_seed(seed)
19
+ np.random.seed(seed)
20
+ random.seed(seed)
21
+ torch.backends.cudnn.benchmark = False
22
+ torch.backends.cudnn.deterministic = True
23
+
24
+ def assign_gpu(gpu_ids, memory_limit=1e16):
25
+ if len(gpu_ids) == 0 and torch.cuda.is_available():
26
+ # find most free gpu
27
+ pynvml.nvmlInit()
28
+ n_gpus = pynvml.nvmlDeviceGetCount()
29
+ dst_gpu_id, min_mem_used = 0, memory_limit
30
+ for g_id in range(n_gpus):
31
+ handle = pynvml.nvmlDeviceGetHandleByIndex(g_id)
32
+ meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
33
+ mem_used = meminfo.used
34
+ if mem_used < min_mem_used:
35
+ min_mem_used = mem_used
36
+ dst_gpu_id = g_id
37
+ logger.info(f'Found gpu {dst_gpu_id}, used memory {min_mem_used}.')
38
+ gpu_ids.append(dst_gpu_id)
39
+ # device
40
+ using_cuda = len(gpu_ids) > 0 and torch.cuda.is_available()
41
+ # logger.info("Let's use %d GPUs!" % len(gpu_ids))
42
+ device = torch.device('cuda:%d' % int(gpu_ids[0]) if using_cuda else 'cpu')
43
+ return device
44
+
45
+ def count_parameters(model):
46
+ res = 0
47
+ for p in model.parameters():
48
+ if p.requires_grad:
49
+ res += p.numel()
50
+ # print(p)
51
+ return res
trains/utils/metricsTop.py ADDED
@@ -0,0 +1,125 @@
1
+ import numpy as np
2
+ import torch
3
+ from sklearn.metrics import accuracy_score, f1_score
4
+ from sklearn.metrics import mutual_info_score
5
+
6
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
7
+
8
+ __all__ = ['MetricsTop']
9
+
10
+ class MetricsTop():
11
+ def __init__(self, train_mode):
12
+ if train_mode == "regression":
13
+ self.metrics_dict = {
14
+ 'MOSI': self.__eval_mosi_regression,
15
+ 'MOSEI': self.__eval_mosei_regression,
16
+ }
17
+ else:
18
+ self.metrics_dict = {
19
+ 'MOSI': self.__eval_mosi_classification,
20
+ 'MOSEI': self.__eval_mosei_classification,
21
+ }
22
+
23
+ def __eval_mosi_classification(self, y_pred, y_true):
24
+ """
25
+ {
26
+ "Negative": 0,
27
+ "Neutral": 1,
28
+ "Positive": 2
29
+ }
30
+ """
31
+ y_pred = y_pred.cpu().detach().numpy()
32
+ y_true = y_true.cpu().detach().numpy()
33
+ # three classes
34
+ y_pred_3 = np.argmax(y_pred, axis=1)
35
+ Mult_acc_3 = accuracy_score(y_pred_3, y_true)
36
+ F1_score_3 = f1_score(y_true, y_pred_3, average='weighted')
37
+ # two classes
38
+ y_pred = np.array([[v[0], v[2]] for v in y_pred])
39
+ # with 0 (<= 0 or > 0)
40
+ y_pred_2 = np.argmax(y_pred, axis=1)
41
+ y_true_2 = []
42
+ for v in y_true:
43
+ y_true_2.append(0 if v <= 1 else 1)
44
+ y_true_2 = np.array(y_true_2)
45
+ Has0_acc_2 = accuracy_score(y_pred_2, y_true_2)
46
+ Has0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')
47
+ # without 0 (< 0 or > 0)
48
+ non_zeros = np.array([i for i, e in enumerate(y_true) if e != 1])
49
+ y_pred_2 = y_pred[non_zeros]
50
+ y_pred_2 = np.argmax(y_pred_2, axis=1)
51
+ y_true_2 = y_true[non_zeros]
52
+ Non0_acc_2 = accuracy_score(y_pred_2, y_true_2)
53
+ Non0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')
54
+
55
+ eval_results = {
56
+ "Has0_acc_2": round(Has0_acc_2, 4),
57
+ "Has0_F1_score": round(Has0_F1_score, 4),
58
+ "Non0_acc_2": round(Non0_acc_2, 4),
59
+ "Non0_F1_score": round(Non0_F1_score, 4),
60
+ "Acc_3": round(Mult_acc_3, 4),
61
+ "F1_score_3": round(F1_score_3, 4)
62
+ }
63
+ return eval_results
64
+
65
+ def __eval_mosei_classification(self, y_pred, y_true):
66
+ return self.__eval_mosi_classification(y_pred, y_true)
67
+
68
+ def __multiclass_acc(self, y_pred, y_true):
69
+ """
70
+ Compute the multiclass accuracy w.r.t. groundtruth
71
+
72
+ :param preds: Float array representing the predictions, dimension (N,)
73
+ :param truths: Float/int array representing the groundtruth classes, dimension (N,)
74
+ :return: Classification accuracy
75
+ """
76
+ return np.sum(np.round(y_pred) == np.round(y_true)) / float(len(y_true))
77
+
78
+ def __eval_mosei_regression(self, y_pred, y_true, exclude_zero=False):
79
+ test_preds = y_pred.view(-1).cpu().detach().numpy()
80
+ test_truth = y_true.view(-1).cpu().detach().numpy()
81
+
82
+ test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
83
+ test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
84
+ test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
85
+ test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)
86
+ test_preds_a3 = np.clip(test_preds, a_min=-1., a_max=1.)
87
+ test_truth_a3 = np.clip(test_truth, a_min=-1., a_max=1.)
88
+
89
+
90
+ mae = np.mean(np.absolute(test_preds - test_truth)).astype(np.float64)
91
+ corr = np.corrcoef(test_preds, test_truth)[0][1]
92
+ mult_a7 = self.__multiclass_acc(test_preds_a7, test_truth_a7)
93
+ mult_a5 = self.__multiclass_acc(test_preds_a5, test_truth_a5)
94
+ mult_a3 = self.__multiclass_acc(test_preds_a3, test_truth_a3)
95
+
96
+ non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0])
97
+ non_zeros_binary_truth = (test_truth[non_zeros] > 0)
98
+ non_zeros_binary_preds = (test_preds[non_zeros] > 0)
99
+
100
+ non_zeros_acc2 = accuracy_score(non_zeros_binary_preds, non_zeros_binary_truth)
101
+ non_zeros_f1_score = f1_score(non_zeros_binary_truth, non_zeros_binary_preds, average='weighted')
102
+
103
+ binary_truth = (test_truth >= 0)
104
+ binary_preds = (test_preds >= 0)
105
+ acc2 = accuracy_score(binary_preds, binary_truth)
106
+ f_score = f1_score(binary_truth, binary_preds, average='weighted')
107
+
108
+ eval_results = {
109
+ "acc_7": round(mult_a7, 4),
110
+ "acc_5": round(mult_a5, 4),
111
+ "acc_2": round(non_zeros_acc2, 4),
112
+ "F1_score": round(non_zeros_f1_score, 4),
113
+ "Corr": round(corr, 4),
114
+ "MAE": round(mae, 4)
115
+ }
116
+
117
+
118
+
119
+ return eval_results
120
+
121
+ def __eval_mosi_regression(self, y_pred, y_true):
122
+ return self.__eval_mosei_regression(y_pred, y_true)
123
+
124
+ def getMetics(self, datasetName):
125
+ return self.metrics_dict[datasetName.upper()]
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .functions import dict_to_str, setup_seed, assign_gpu, count_parameters
2
+ from .metricsTop import MetricsTop
utils/functions.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import numpy as np
3
+ import random
4
+ import pynvml
5
+ import logging
6
+
7
+
8
+ logger = logging.getLogger('MMSA')
9
+
10
+
11
+ def dict_to_str(src_dict):
12
+ dst_str = ""
13
+ for key in src_dict.keys():
14
+ dst_str += " %s: %.4f " %(key, src_dict[key])
15
+ return dst_str
16
+
17
+ def setup_seed(seed):
18
+ torch.manual_seed(seed)
19
+ np.random.seed(seed)
20
+ random.seed(seed)
21
+ torch.backends.cudnn.benchmark = False
22
+ torch.backends.cudnn.deterministic = True
23
+
24
+ def assign_gpu(gpu_ids, memory_limit=1e16):
25
+ if len(gpu_ids) == 0 and torch.cuda.is_available():
26
+ # find most free gpu
27
+ pynvml.nvmlInit()
28
+ n_gpus = pynvml.nvmlDeviceGetCount()
29
+ dst_gpu_id, min_mem_used = 0, memory_limit
30
+ for g_id in range(n_gpus):
31
+ handle = pynvml.nvmlDeviceGetHandleByIndex(g_id)
32
+ meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
33
+ mem_used = meminfo.used
34
+ if mem_used < min_mem_used:
35
+ min_mem_used = mem_used
36
+ dst_gpu_id = g_id
37
+ logger.info(f'Found gpu {dst_gpu_id}, used memory {min_mem_used}.')
38
+ gpu_ids.append(dst_gpu_id)
39
+ # device
40
+ using_cuda = len(gpu_ids) > 0 and torch.cuda.is_available()
41
+ # logger.info("Let's use %d GPUs!" % len(gpu_ids))
42
+ device = torch.device('cuda:%d' % int(gpu_ids[0]) if using_cuda else 'cpu')
43
+ return device
44
+
45
+ def count_parameters(model):
46
+ res = 0
47
+ for p in model.parameters():
48
+ if p.requires_grad:
49
+ res += p.numel()
50
+ # print(p)
51
+ return res
utils/metricsTop.py ADDED
@@ -0,0 +1,111 @@
1
+ import numpy as np
2
+ from sklearn.metrics import accuracy_score, f1_score
3
+
4
+ __all__ = ['MetricsTop']
5
+
6
+ class MetricsTop():
7
+ def __init__(self, train_mode):
8
+ if train_mode == "regression":
9
+ self.metrics_dict = {
10
+ 'MOSI': self.__eval_mosi_regression,
11
+ 'MOSEI': self.__eval_mosei_regression,
12
+ }
13
+ else:
14
+ self.metrics_dict = {
15
+ 'MOSI': self.__eval_mosi_classification,
16
+ 'MOSEI': self.__eval_mosei_classification,
17
+ }
18
+
19
+ def __eval_mosi_classification(self, y_pred, y_true):
20
+ y_pred = y_pred.cpu().detach().numpy()
21
+ y_true = y_true.cpu().detach().numpy()
22
+ # three classes
23
+ y_pred_3 = np.argmax(y_pred, axis=1)
24
+ Mult_acc_3 = accuracy_score(y_pred_3, y_true)
25
+ F1_score_3 = f1_score(y_true, y_pred_3, average='weighted')
26
+ # two classes
27
+ y_pred = np.array([[v[0], v[2]] for v in y_pred])
28
+ # with 0 (<= 0 or > 0)
29
+ y_pred_2 = np.argmax(y_pred, axis=1)
30
+ y_true_2 = []
31
+ for v in y_true:
32
+ y_true_2.append(0 if v <= 1 else 1)
33
+ y_true_2 = np.array(y_true_2)
34
+ Has0_acc_2 = accuracy_score(y_pred_2, y_true_2)
35
+ Has0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')
36
+ # without 0 (< 0 or > 0)
37
+ non_zeros = np.array([i for i, e in enumerate(y_true) if e != 1])
38
+ y_pred_2 = y_pred[non_zeros]
39
+ y_pred_2 = np.argmax(y_pred_2, axis=1)
40
+ y_true_2 = y_true[non_zeros]
41
+ Non0_acc_2 = accuracy_score(y_pred_2, y_true_2)
42
+ Non0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')
43
+
44
+ eval_results = {
45
+ "Has0_acc_2": round(Has0_acc_2, 4),
46
+ "Has0_F1_score": round(Has0_F1_score, 4),
47
+ "Non0_acc_2": round(Non0_acc_2, 4),
48
+ "Non0_F1_score": round(Non0_F1_score, 4),
49
+ "Acc_3": round(Mult_acc_3, 4),
50
+ "F1_score_3": round(F1_score_3, 4)
51
+ }
52
+ return eval_results
53
+
54
+ def __eval_mosei_classification(self, y_pred, y_true):
55
+ return self.__eval_mosi_classification(y_pred, y_true)
56
+
57
+
58
+ def __multiclass_acc(self, y_pred, y_true):
59
+ """
60
+ Compute the multiclass accuracy w.r.t. groundtruth
61
+
62
+ :param preds: Float array representing the predictions, dimension (N,)
63
+ :param truths: Float/int array representing the groundtruth classes, dimension (N,)
64
+ :return: Classification accuracy
65
+ """
66
+ return np.sum(np.round(y_pred) == np.round(y_true)) / float(len(y_true))
67
+
68
+ def __eval_mosei_regression(self, y_pred, y_true, exclude_zero=False):
69
+ test_preds = y_pred.view(-1).cpu().detach().numpy()
70
+ test_truth = y_true.view(-1).cpu().detach().numpy()
71
+
72
+ test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
73
+ test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
74
+ test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
75
+ test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)
76
+ test_preds_a3 = np.clip(test_preds, a_min=-1., a_max=1.)
77
+ test_truth_a3 = np.clip(test_truth, a_min=-1., a_max=1.)
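+ # a7/a5/a3: predictions and labels clipped to [-3,3] / [-2,2] / [-1,1], then rounded, giving 7/5/3-class accuracies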
78
+
79
+
80
+ mae = np.mean(np.absolute(test_preds - test_truth)).astype(np.float64) # Average L1 distance between preds and truths
81
+ corr = np.corrcoef(test_preds, test_truth)[0][1]
82
+ mult_a7 = self.__multiclass_acc(test_preds_a7, test_truth_a7)
83
+ mult_a5 = self.__multiclass_acc(test_preds_a5, test_truth_a5)
84
+ mult_a3 = self.__multiclass_acc(test_preds_a3, test_truth_a3)
85
+
86
+ non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0])
87
+ non_zeros_binary_truth = (test_truth[non_zeros] > 0)
88
+ non_zeros_binary_preds = (test_preds[non_zeros] > 0)
89
+
90
+ non_zeros_acc2 = accuracy_score(non_zeros_binary_preds, non_zeros_binary_truth)
91
+ non_zeros_f1_score = f1_score(non_zeros_binary_truth, non_zeros_binary_preds, average='weighted')
92
+
93
+ binary_truth = (test_truth >= 0)
94
+ binary_preds = (test_preds >= 0)
95
+ acc2 = accuracy_score(binary_preds, binary_truth)
96
+ f_score = f1_score(binary_truth, binary_preds, average='weighted')
97
+
98
+ eval_results = {
99
+ "Acc_2": round(non_zeros_acc2, 4),
100
+ "F1_score": round(non_zeros_f1_score, 4),
101
+ "Acc_7": round(mult_a7, 4),
102
+ "MAE": round(mae, 4),
103
+ }
104
+ return eval_results
105
+
106
+
107
+ def __eval_mosi_regression(self, y_pred, y_true):
108
+ return self.__eval_mosei_regression(y_pred, y_true)
109
+
110
+ def getMetics(self, datasetName):
111
+ return self.metrics_dict[datasetName.upper()]