peter-wang321 committed · Commit 9157432 · 1 Parent(s): ab35d3d
Initial DLF commit
Files changed:
- .gitignore +3 -0
- config.py +40 -0
- config/config.json +98 -0
- config/readme.md +1 -0
- data_loader.py +156 -0
- log/readme.md +1 -0
- pt/readme.md +1 -0
- requirements.txt +6 -0
- result/normal/mosei.csv +11 -0
- result/normal/mosi.csv +21 -0
- result/readme.md +1 -0
- run.py +201 -0
- test.py +7 -0
- train.py +7 -0
- trains/ATIO.py +15 -0
- trains/__init__.py +1 -0
- trains/singleTask/DLF.py +233 -0
- trains/singleTask/HingeLoss.py +57 -0
- trains/singleTask/__init__.py +1 -0
- trains/singleTask/distillnets/get_distillation_kernel.py +96 -0
- trains/singleTask/distillnets/get_distillation_kernel_homo.py +100 -0
- trains/singleTask/misc.py +196 -0
- trains/singleTask/model/DLF.py +345 -0
- trains/singleTask/utils/__init__.py +1 -0
- trains/singleTask/utils/misc.py +196 -0
- trains/subNets/AlignNets.py +106 -0
- trains/subNets/BertTextEncoder.py +52 -0
- trains/subNets/__init__.py +2 -0
- trains/subNets/transformers_encoder/multihead_attention.py +154 -0
- trains/subNets/transformers_encoder/position_embedding.py +77 -0
- trains/subNets/transformers_encoder/transformer.py +205 -0
- trains/utils/__init__.py +2 -0
- trains/utils/functions.py +51 -0
- trains/utils/metricsTop.py +125 -0
- utils/__init__.py +2 -0
- utils/functions.py +51 -0
- utils/metricsTop.py +111 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
log/*.log
pt/*.pth
__pycache__/
config.py
ADDED
@@ -0,0 +1,40 @@
import json
import os
from pathlib import Path
from easydict import EasyDict as edict


def get_config_regression(model_name, dataset_name, config_file=""):
    """
    Get the regression config of the given dataset and model from the config file.

    Parameters:
        config_file (str): Path to the config file. If an empty string is given, the default config file is used.
        model_name (str): Name of the model.
        dataset_name (str): Name of the dataset.

    Returns:
        config (dict): config of the given dataset and model
    """
    if config_file == "":
        config_file = Path(__file__).parent / "config" / "config_regression.json"
    with open(config_file, 'r') as f:
        config_all = json.load(f)
    model_common_args = config_all[model_name]['commonParams']
    model_dataset_args = config_all[model_name]['datasetParams'][dataset_name]
    dataset_args = config_all['datasetCommonParams'][dataset_name]
    # use aligned features if the model requires them, otherwise use unaligned features
    dataset_args = dataset_args['aligned'] if (model_common_args['need_data_aligned'] and 'aligned' in dataset_args) else dataset_args['unaligned']

    config = {}
    config['model_name'] = model_name
    config['dataset_name'] = dataset_name
    config.update(dataset_args)
    config.update(model_common_args)
    config.update(model_dataset_args)
    config['featurePath'] = os.path.join(config_all['datasetCommonParams']['dataset_root_dir'], config['featurePath'])
    config = edict(config)

    return config
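A minimal usage sketch (not part of the commit): loading the merged DLF/MOSI configuration. The explicit path mirrors the default that run.py resolves to (config/config.json); attribute access works because the result is an EasyDict.

from pathlib import Path
from config import get_config_regression

# Load the merged configuration for the DLF model on MOSI.
args = get_config_regression('DLF', 'mosi', Path('config') / 'config.json')
print(args.batch_size, args.learning_rate, args.featurePath)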
config/config.json
ADDED
@@ -0,0 +1,98 @@
{
    "datasetCommonParams": {
        "dataset_root_dir": "../dataset",
        "mosi": {
            "aligned": {
                "featurePath": "MOSI/Processed/aligned_50.pkl",
                "feature_dims": [768, 5, 20],
                "train_samples": 1284,
                "num_classes": 3,
                "language": "en",
                "KeyEval": "Loss"
            },
            "unaligned": {
                "featurePath": "MOSI/Processed/unaligned_50.pkl",
                "feature_dims": [768, 5, 20],
                "train_samples": 1284,
                "num_classes": 3,
                "language": "en",
                "KeyEval": "Loss"
            }
        },
        "mosei": {
            "aligned": {
                "featurePath": "MOSEI/Processed/aligned_50.pkl",
                "feature_dims": [768, 74, 35],
                "train_samples": 16326,
                "num_classes": 3,
                "language": "en",
                "KeyEval": "Loss"
            },
            "unaligned": {
                "featurePath": "MOSEI/Processed/unaligned_50.pkl",
                "feature_dims": [768, 74, 35],
                "train_samples": 16326,
                "num_classes": 3,
                "language": "en",
                "KeyEval": "Loss"
            }
        }
    },
    "DLF": {
        "commonParams": {
            "need_data_aligned": true,
            "need_model_aligned": true,
            "early_stop": 10,
            "use_bert": true,
            "use_finetune": true,
            "attn_mask": true,
            "update_epochs": 10
        },
        "datasetParams": {
            "mosi": {
                "attn_dropout_a": 0.2,
                "attn_dropout_v": 0.0,
                "relu_dropout": 0.0,
                "embed_dropout": 0.2,
                "res_dropout": 0.0,
                "dst_feature_dim_nheads": [50, 10],
                "batch_size": 16,
                "learning_rate": 0.0001,
                "nlevels": 2,
                "conv1d_kernel_size_l": 5,
                "conv1d_kernel_size_a": 5,
                "conv1d_kernel_size_v": 5,
                "text_dropout": 0.5,
                "attn_dropout": 0.3,
                "output_dropout": 0.5,
                "grad_clip": 0.6,
                "patience": 5,
                "weight_decay": 0.005,
                "transformers": "bert",
                "pretrained": "bert-base-uncased"
            },
            "mosei": {
                "attn_dropout_a": 0.0,
                "attn_dropout_v": 0.0,
                "relu_dropout": 0.0,
                "embed_dropout": 0.0,
                "res_dropout": 0.0,
                "dst_feature_dim_nheads": [50, 10],
                "batch_size": 16,
                "learning_rate": 0.0001,
                "nlevels": 2,
                "conv1d_kernel_size_l": 3,
                "conv1d_kernel_size_a": 3,
                "conv1d_kernel_size_v": 3,
                "text_dropout": 0.1,
                "attn_dropout": 0.5,
                "output_dropout": 0.5,
                "grad_clip": 0.6,
                "patience": 5,
                "weight_decay": 0.001,
                "transformers": "bert",
                "pretrained": "bert-base-uncased"
            }
        }
    }
}
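For reference, config.py merges these blocks with plain dict updates, so later updates take precedence. A sketch of the effective order (names follow config.py):

config = {}
config.update(dataset_args)        # datasetCommonParams[dataset]['aligned' or 'unaligned']
config.update(model_common_args)   # DLF.commonParams
config.update(model_dataset_args)  # DLF.datasetParams[dataset] (highest precedence)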
config/readme.md
ADDED
@@ -0,0 +1 @@
Configuration files are stored here.
data_loader.py
ADDED
@@ -0,0 +1,156 @@
import logging
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
__all__ = ['MMDataLoader']
logger = logging.getLogger('MMSA')

class MMDataset(Dataset):
    def __init__(self, args, mode='train'):
        self.mode = mode
        self.args = args
        DATASET_MAP = {
            'mosi': self.__init_mosi,
            'mosei': self.__init_mosei,
        }
        DATASET_MAP[args['dataset_name']]()

    def __init_mosi(self):
        with open(self.args['featurePath'], 'rb') as f:
            data = pickle.load(f)
        if 'use_bert' in self.args and self.args['use_bert']:
            self.text = data[self.mode]['text_bert'].astype(np.float32)
        else:
            self.text = data[self.mode]['text'].astype(np.float32)
        self.vision = data[self.mode]['vision'].astype(np.float32)
        self.audio = data[self.mode]['audio'].astype(np.float32)
        self.raw_text = data[self.mode]['raw_text']
        self.ids = data[self.mode]['id']

        if self.args['feature_T'] != "":
            with open(self.args['feature_T'], 'rb') as f:
                data_T = pickle.load(f)
            if 'use_bert' in self.args and self.args['use_bert']:
                self.text = data_T[self.mode]['text_bert'].astype(np.float32)
                self.args['feature_dims'][0] = 768
            else:
                self.text = data_T[self.mode]['text'].astype(np.float32)
                self.args['feature_dims'][0] = self.text.shape[2]
        if self.args['feature_A'] != "":
            with open(self.args['feature_A'], 'rb') as f:
                data_A = pickle.load(f)
            self.audio = data_A[self.mode]['audio'].astype(np.float32)
            self.args['feature_dims'][1] = self.audio.shape[2]
        if self.args['feature_V'] != "":
            with open(self.args['feature_V'], 'rb') as f:
                data_V = pickle.load(f)
            self.vision = data_V[self.mode]['vision'].astype(np.float32)
            self.args['feature_dims'][2] = self.vision.shape[2]

        self.labels = {
            'M': np.array(data[self.mode]['regression_labels']).astype(np.float32)
        }

        logger.info(f"{self.mode} samples: {self.labels['M'].shape}")

        if not self.args['need_data_aligned']:
            if self.args['feature_A'] != "":
                self.audio_lengths = list(data_A[self.mode]['audio_lengths'])
            else:
                self.audio_lengths = data[self.mode]['audio_lengths']
            if self.args['feature_V'] != "":
                self.vision_lengths = list(data_V[self.mode]['vision_lengths'])
            else:
                self.vision_lengths = data[self.mode]['vision_lengths']
        self.audio[self.audio == -np.inf] = 0

        if 'need_normalized' in self.args and self.args['need_normalized']:
            self.__normalize()

    def __init_mosei(self):
        return self.__init_mosi()

    def __init_sims(self):
        return self.__init_mosi()

    def __truncate(self):
        def do_truncate(modal_features, length):
            if length == modal_features.shape[1]:
                return modal_features
            truncated_feature = []
            padding = np.array([0 for i in range(modal_features.shape[2])])
            for instance in modal_features:
                for index in range(modal_features.shape[1]):
                    if (instance[index] == padding).all():
                        if index + length >= modal_features.shape[1]:
                            truncated_feature.append(instance[index:index+20])
                            break
                    else:
                        truncated_feature.append(instance[index:index+20])
                        break
            truncated_feature = np.array(truncated_feature)
            return truncated_feature

        text_length, audio_length, video_length = self.args['seq_lens']
        self.vision = do_truncate(self.vision, video_length)
        self.text = do_truncate(self.text, text_length)
        self.audio = do_truncate(self.audio, audio_length)

    def __normalize(self):
        self.vision = np.mean(self.vision, axis=1, keepdims=True)
        self.audio = np.mean(self.audio, axis=1, keepdims=True)

        self.vision[self.vision != self.vision] = 0
        self.audio[self.audio != self.audio] = 0

    def __len__(self):
        return len(self.labels['M'])

    def get_seq_len(self):
        if 'use_bert' in self.args and self.args['use_bert']:
            return (self.text.shape[2], self.audio.shape[1], self.vision.shape[1])
        else:
            return (self.text.shape[1], self.audio.shape[1], self.vision.shape[1])

    def get_feature_dim(self):
        return self.text.shape[2], self.audio.shape[2], self.vision.shape[2]

    def __getitem__(self, index):
        sample = {
            'raw_text': self.raw_text[index],
            'text': torch.Tensor(self.text[index]),
            'audio': torch.Tensor(self.audio[index]),
            'vision': torch.Tensor(self.vision[index]),
            'index': index,
            'id': self.ids[index],
            'labels': {k: torch.Tensor(v[index].reshape(-1)) for k, v in self.labels.items()}
        }
        if not self.args['need_data_aligned']:
            sample['audio_lengths'] = self.audio_lengths[index]
            sample['vision_lengths'] = self.vision_lengths[index]
        return sample

def MMDataLoader(args, num_workers):

    datasets = {
        'train': MMDataset(args, mode='train'),
        'valid': MMDataset(args, mode='valid'),
        'test': MMDataset(args, mode='test')
    }

    if 'seq_lens' in args:
        args['seq_lens'] = datasets['train'].get_seq_len()

    dataLoader = {
        ds: DataLoader(datasets[ds],
                       batch_size=args['batch_size'],
                       num_workers=num_workers,
                       shuffle=True)
        for ds in datasets.keys()
    }

    return dataLoader
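A minimal sketch (not part of the commit) of building the loaders directly; in practice run.py passes the merged config from get_config_regression. The keys below are the ones MMDataset actually reads, and the feature path assumes the processed MOSI pickle has been downloaded to ../dataset as in config/config.json.

from easydict import EasyDict as edict
from data_loader import MMDataLoader

args = edict({
    'dataset_name': 'mosi',
    'featurePath': '../dataset/MOSI/Processed/aligned_50.pkl',  # from config/config.json
    'use_bert': True,
    'need_data_aligned': True,
    'feature_T': '', 'feature_A': '', 'feature_V': '',
    'batch_size': 16,
})
loaders = MMDataLoader(args, num_workers=1)
for batch in loaders['train']:
    print(batch['text'].shape, batch['audio'].shape, batch['vision'].shape)
    break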
log/readme.md
ADDED
@@ -0,0 +1 @@
Training logs are saved here.
pt/readme.md
ADDED
@@ -0,0 +1 @@
Trained models are saved here.
requirements.txt
ADDED
@@ -0,0 +1,6 @@
transformers==4.33.1
huggingface-hub==0.17.1
numpy==1.21.5
scipy==1.9.1
scikit-learn==1.0.2
pandas==1.4.4
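Note (an observation from the imports in this commit, not a statement from the original file): torch, tqdm, and easydict are imported by the code but are not pinned here, so a working environment typically needs `pip install -r requirements.txt` plus those three packages installed separately.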
result/normal/mosei.csv
ADDED
@@ -0,0 +1,11 @@
Time,Model,acc_7,acc_5,acc_2,F1_score,Corr,MAE,Loss
2024/12/08 13:54:11 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
2024/12/08 16:00:34 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
2024/12/08 21:00:18 - ,DLF,"(53.87, 0.0)","(55.66, 0.0)","(84.31, 0.0)","(84.36, 0.0)","(75.75, 0.0)","(53.86, 0.0)","(53.82, 0.0)"
2024/12/08 22:10:45 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
2024/12/09 00:43:05 - ,DLF,"(53.14, 0.0)","(55.03, 0.0)","(84.92, 0.0)","(84.9, 0.0)","(76.46, 0.0)","(54.05, 0.0)","(54.03, 0.0)"
2024/12/09 02:40:19 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
2024/12/09 04:45:03 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
2024/12/09 16:10:24 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
2024/12/13 02:59:46 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
2024/12/15 23:53:26 - ,DLF,"(53.9, 0.0)","(55.7, 0.0)","(85.42, 0.0)","(85.27, 0.0)","(76.36, 0.0)","(53.61, 0.0)","(53.66, 0.0)"
result/normal/mosi.csv
ADDED
@@ -0,0 +1,21 @@
Time,Model,acc_7,acc_5,acc_2,F1_score,Corr,MAE,Loss
2024/12/08 14:55:08 - ,DLF,"(45.63, 0.0)","(52.77, 0.0)","(84.45, 0.0)","(84.42, 0.0)","(79.43, 0.0)","(72.2, 0.0)","(72.2, 0.0)"
2024/12/08 15:20:12 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/08 15:45:53 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/08 19:28:22 - ,DLF,"(45.63, 0.0)","(52.77, 0.0)","(84.45, 0.0)","(84.42, 0.0)","(79.43, 0.0)","(72.2, 0.0)","(72.2, 0.0)"
2024/12/08 19:46:53 - ,DLF,"(45.63, 0.0)","(52.77, 0.0)","(84.45, 0.0)","(84.42, 0.0)","(79.43, 0.0)","(72.2, 0.0)","(72.2, 0.0)"
2024/12/08 20:05:59 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/08 20:26:47 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/08 20:42:22 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/09 11:34:14 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/09 12:00:35 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/09 15:54:01 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/09 18:33:19 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/13 02:44:39 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/15 20:13:12 - ,DLF,"(46.36, 0.0)","(53.35, 0.0)","(83.38, 0.0)","(83.4, 0.0)","(78.83, 0.0)","(72.89, 0.0)","(72.94, 0.0)"
2024/12/15 20:38:24 - ,DLF,"(44.75, 0.0)","(51.9, 0.0)","(83.84, 0.0)","(83.85, 0.0)","(78.2, 0.0)","(72.78, 0.0)","(72.78, 0.0)"
2024/12/15 20:50:42 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/15 23:25:21 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/16 03:44:10 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/16 04:06:00 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
2024/12/16 04:59:28 - ,DLF,"(47.08, 0.0)","(52.33, 0.0)","(85.06, 0.0)","(85.04, 0.0)","(78.14, 0.0)","(73.08, 0.0)","(73.07, 0.0)"
result/readme.md
ADDED
@@ -0,0 +1 @@
The results will be saved here as CSV files.
run.py
ADDED
@@ -0,0 +1,201 @@
import gc
import logging
import os
import time
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from config import get_config_regression
from data_loader import MMDataLoader
from trains import ATIO
from utils import assign_gpu, setup_seed
from trains.singleTask.model import DLF
from trains.singleTask.distillnets import get_distillation_kernel, get_distillation_kernel_homo
from trains.singleTask.misc import softmax
import sys

from datetime import datetime
now = datetime.now()
format = "%Y/%m/%d %H:%M:%S"
formatted_now = now.strftime(format)
formatted_now = str(formatted_now) + " - "

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"
logger = logging.getLogger('MMSA')

def _set_logger(log_dir, model_name, dataset_name, verbose_level):

    # base logger
    log_file_path = Path(log_dir) / f"{model_name}-{dataset_name}.log"
    logger = logging.getLogger('MMSA')
    logger.setLevel(logging.DEBUG)

    # file handler
    fh = logging.FileHandler(log_file_path)
    fh_formatter = logging.Formatter('%(asctime)s - %(name)s [%(levelname)s] - %(message)s')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(fh_formatter)
    logger.addHandler(fh)

    # stream handler
    stream_level = {0: logging.ERROR, 1: logging.INFO, 2: logging.DEBUG}
    ch = logging.StreamHandler()
    ch.setLevel(stream_level[verbose_level])
    ch_formatter = logging.Formatter('%(name)s - %(message)s')
    ch.setFormatter(ch_formatter)
    logger.addHandler(ch)

    return logger


def DLF_run(
    model_name, dataset_name, config=None, config_file="", seeds=[], is_tune=False,
    tune_times=500, feature_T="", feature_A="", feature_V="",
    model_save_dir="", res_save_dir="", log_dir="",
    gpu_ids=[0], num_workers=1, verbose_level=1, mode='', is_training=False
):
    # Initialization
    model_name = model_name.upper()
    dataset_name = dataset_name.lower()

    if config_file != "":
        config_file = Path(config_file)
    else:  # use default config files
        config_file = Path(__file__).parent / "config" / "config.json"
    if not config_file.is_file():
        raise ValueError(f"Config file {str(config_file)} not found.")
    if model_save_dir == "":
        model_save_dir = Path.home() / "MMSA" / "saved_models"
    Path(model_save_dir).mkdir(parents=True, exist_ok=True)
    if res_save_dir == "":
        res_save_dir = Path.home() / "MMSA" / "results"
    Path(res_save_dir).mkdir(parents=True, exist_ok=True)
    if log_dir == "":
        log_dir = Path.home() / "MMSA" / "logs"
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    seeds = seeds if seeds != [] else [1111, 1112, 1113, 1114, 1115]
    logger = _set_logger(log_dir, model_name, dataset_name, verbose_level)

    args = get_config_regression(model_name, dataset_name, config_file)
    args.is_training = is_training
    args.mode = mode  # train or test
    args['model_save_path'] = Path(model_save_dir) / f"{args['model_name']}-{args['dataset_name']}.pth"
    args['device'] = assign_gpu(gpu_ids)
    args['train_mode'] = 'regression'
    args['feature_T'] = feature_T
    args['feature_A'] = feature_A
    args['feature_V'] = feature_V
    if config:
        args.update(config)

    res_save_dir = Path(res_save_dir) / "normal"
    res_save_dir.mkdir(parents=True, exist_ok=True)
    model_results = []
    for i, seed in enumerate(seeds):
        setup_seed(seed)
        args['cur_seed'] = i + 1
        result = _run(args, num_workers, is_tune)
        model_results.append(result)
    if args.is_training:
        criterions = list(model_results[0].keys())
        # save result to csv
        csv_file = res_save_dir / f"{dataset_name}.csv"
        if csv_file.is_file():
            df = pd.read_csv(csv_file)
        else:
            df = pd.DataFrame(columns=["Time"] + ["Model"] + criterions)
        # save results
        res = [model_name]
        for c in criterions:
            values = [r[c] for r in model_results]
            mean = round(np.mean(values) * 100, 2)
            std = round(np.std(values) * 100, 2)
            res.append((mean, std))

        res = [formatted_now] + res
        df.loc[len(df)] = res
        df.to_csv(csv_file, index=None)
        logger.info(f"Results saved to {csv_file}.")


def _run(args, num_workers=4, is_tune=False, from_sena=False):

    dataloader = MMDataLoader(args, num_workers)

    if args.is_training:
        print("training for DLF")

        args.gd_size_low = 64
        args.w_losses_low = [1, 10]
        args.metric_low = 'l1'

        args.gd_size_high = 32
        args.w_losses_high = [1, 10]
        args.metric_high = 'l1'

        to_idx = [0, 1, 2]
        from_idx = [0, 1, 2]
        assert len(from_idx) >= 1

        model = []
        model_DLF = getattr(DLF, 'DLF')(args)

        model_distill_homo = getattr(get_distillation_kernel_homo, 'DistillationKernel')(
            n_classes=1,
            hidden_size=args.dst_feature_dim_nheads[0],
            gd_size=args.gd_size_low,
            to_idx=to_idx, from_idx=from_idx,
            gd_prior=softmax([0, 0, 1, 0, 1, 0], 0.25),
            gd_reg=10,
            w_losses=args.w_losses_low,
            metric=args.metric_low,
            alpha=1 / 8,
            hyp_params=args)

        model_distill_hetero = getattr(get_distillation_kernel, 'DistillationKernel')(
            n_classes=1,
            hidden_size=args.dst_feature_dim_nheads[0] * 2,
            gd_size=args.gd_size_high,
            to_idx=to_idx, from_idx=from_idx,
            gd_prior=softmax([0, 0, 1, 0, 1, 1], 0.25),
            gd_reg=10,
            w_losses=args.w_losses_high,
            metric=args.metric_high,
            alpha=1 / 8,
            hyp_params=args)

        model_DLF = model_DLF.cuda()

        model = [model_DLF]
    else:
        print("testing phase for DLF")
        model = getattr(DLF, 'DLF')(args)
        model = model.cuda()

    trainer = ATIO().getTrain(args)

    # test
    if args.mode == 'test':
        model.load_state_dict(torch.load('./pt/DLF' + str(args.dataset_name) + '.pth'), strict=False)
        results = trainer.do_test(model, dataloader['test'], mode="TEST")
        sys.stdout.flush()
        input('[Press Any Key to start another run]')
    # train
    else:
        epoch_results = trainer.do_train(model, dataloader, return_epoch_results=from_sena)
        model[0].load_state_dict(torch.load('./pt/DLF' + str(args.dataset_name) + '.pth'))

        results = trainer.do_test(model[0], dataloader['test'], mode="TEST")

    del model
    torch.cuda.empty_cache()
    gc.collect()
    time.sleep(1)
    return results
test.py
ADDED
@@ -0,0 +1,7 @@
"""
Testing script for DLF
"""
from run import DLF_run

DLF_run(model_name='DLF', dataset_name='mosei', is_tune=False, seeds=[1111], model_save_dir="./pt",
        res_save_dir="./result", log_dir="./log", mode='test', is_training=False)
train.py
ADDED
@@ -0,0 +1,7 @@
"""
Training script for DLF
"""
from run import DLF_run

DLF_run(model_name='DLF', dataset_name='mosi', is_tune=False, seeds=[1111], model_save_dir="./pt",
        res_save_dir="./result", log_dir="./log", mode='train', is_training=True)
trains/ATIO.py
ADDED
@@ -0,0 +1,15 @@
"""
ATIO -- All Trains in One
"""
from .singleTask import *

__all__ = ['ATIO']

class ATIO():
    def __init__(self):
        self.TRAIN_MAP = {
            'DLF': DLF,
        }

    def getTrain(self, args):
        return self.TRAIN_MAP[args['model_name']](args)
trains/__init__.py
ADDED
@@ -0,0 +1 @@
from .ATIO import ATIO
trains/singleTask/DLF.py
ADDED
@@ -0,0 +1,233 @@
import logging
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from ..utils import MetricsTop, dict_to_str
from .HingeLoss import HingeLoss


logger = logging.getLogger('MMSA')

class MSE(nn.Module):
    def __init__(self):
        super(MSE, self).__init__()

    def forward(self, pred, real):
        diffs = torch.add(real, -pred)
        n = torch.numel(diffs.data)
        mse = torch.sum(diffs.pow(2)) / n
        return mse

class DLF():
    def __init__(self, args):
        self.args = args
        self.criterion = nn.L1Loss()
        self.cosine = nn.CosineEmbeddingLoss()
        self.metrics = MetricsTop(args.train_mode).getMetics(args.dataset_name)
        self.MSE = MSE()
        self.sim_loss = HingeLoss()

    def do_train(self, model, dataloader, return_epoch_results=False):

        # 0: DLF model
        params = model[0].parameters()

        optimizer = optim.Adam(params, lr=self.args.learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, verbose=True, patience=self.args.patience)

        epochs, best_epoch = 0, 0
        if return_epoch_results:
            epoch_results = {
                'train': [],
                'valid': [],
                'test': []
            }
        min_or_max = 'min' if self.args.KeyEval in ['Loss'] else 'max'
        best_valid = 1e8 if min_or_max == 'min' else 0

        net = []
        net_DLF = model[0]
        net.append(net_DLF)
        model = net

        while True:
            epochs += 1
            y_pred, y_true = [], []
            for mod in model:
                mod.train()

            train_loss = 0.0
            left_epochs = self.args.update_epochs
            with tqdm(dataloader['train']) as td:
                for batch_data in td:

                    if left_epochs == self.args.update_epochs:
                        optimizer.zero_grad()
                    left_epochs -= 1
                    vision = batch_data['vision'].to(self.args.device)
                    audio = batch_data['audio'].to(self.args.device)
                    text = batch_data['text'].to(self.args.device)
                    labels = batch_data['labels']['M'].to(self.args.device)
                    labels = labels.view(-1, 1)

                    output = model[0](text, audio, vision)

                    # task loss
                    loss_task_all = self.criterion(output['output_logit'], labels)

                    loss_task_l_hetero = self.criterion(output['logits_l_hetero'], labels)
                    loss_task_v_hetero = self.criterion(output['logits_v_hetero'], labels)
                    loss_task_a_hetero = self.criterion(output['logits_a_hetero'], labels)
                    loss_task_c = self.criterion(output['logits_c'], labels)

                    # total MSA loss L_msa
                    loss_task = 1 * (1 * loss_task_all + 1 * loss_task_c + 3 * loss_task_l_hetero + 1 * loss_task_v_hetero + 1 * loss_task_a_hetero)

                    # reconstruction loss L_r
                    loss_recon_l = self.MSE(output['recon_l'], output['origin_l'])
                    loss_recon_v = self.MSE(output['recon_v'], output['origin_v'])
                    loss_recon_a = self.MSE(output['recon_a'], output['origin_a'])
                    loss_recon = loss_recon_l + loss_recon_v + loss_recon_a

                    # specific loss L_s
                    loss_sl_slr = self.MSE(output['s_l'].permute(1, 2, 0), output['s_l_r'])
                    loss_sv_slv = self.MSE(output['s_v'].permute(1, 2, 0), output['s_v_r'])
                    loss_sa_sla = self.MSE(output['s_a'].permute(1, 2, 0), output['s_a_r'])
                    loss_s_sr = loss_sl_slr + loss_sv_slv + loss_sa_sla

                    # ort loss L_o
                    if self.args.dataset_name == 'mosi':
                        num = 50
                    elif self.args.dataset_name == 'mosei':
                        num = 10

                    cosine_similarity_s_c_l = self.cosine(output['s_l'].reshape(-1, num), output['c_l'].reshape(-1, num), torch.tensor([-1]).cuda())
                    cosine_similarity_s_c_v = self.cosine(output['s_v'].reshape(-1, num), output['c_v'].reshape(-1, num), torch.tensor([-1]).cuda())
                    cosine_similarity_s_c_a = self.cosine(output['s_a'].reshape(-1, num), output['c_a'].reshape(-1, num), torch.tensor([-1]).cuda())

                    loss_ort = cosine_similarity_s_c_l + cosine_similarity_s_c_v + cosine_similarity_s_c_a

                    # triplet margin loss L_m
                    c_l, c_v, c_a = output['c_l_sim'], output['c_v_sim'], output['c_a_sim']
                    ids, feats = [], []
                    for i in range(labels.size(0)):
                        feats.append(c_l[i].view(1, -1))
                        feats.append(c_v[i].view(1, -1))
                        feats.append(c_a[i].view(1, -1))
                        ids.append(labels[i].view(1, -1))
                        ids.append(labels[i].view(1, -1))
                        ids.append(labels[i].view(1, -1))
                    feats = torch.cat(feats, dim=0)
                    ids = torch.cat(ids, dim=0)
                    loss_sim = self.sim_loss(ids, feats)

                    # overall loss L_DLF
                    combined_loss = loss_task + (loss_s_sr + loss_recon + (loss_sim + loss_ort) * 0.1) * 0.1

                    combined_loss.backward()

                    if self.args.grad_clip != -1.0:
                        params = list(model[0].parameters())
                        nn.utils.clip_grad_value_(params, self.args.grad_clip)

                    train_loss += combined_loss.item()

                    y_pred.append(output['output_logit'].cpu())
                    y_true.append(labels.cpu())
                    if not left_epochs:
                        optimizer.step()
                        left_epochs = self.args.update_epochs
                if not left_epochs:
                    # update
                    optimizer.step()

            train_loss = train_loss / len(dataloader['train'])
            pred, true = torch.cat(y_pred), torch.cat(y_true)
            train_results = self.metrics(pred, true)
            logger.info(
                f">> Epoch: {epochs} "
                f"TRAIN -({self.args.model_name}) [{epochs - best_epoch}/{epochs}/{self.args.cur_seed}] "
                f">> total_loss: {round(train_loss, 4)} "
                f"{dict_to_str(train_results)}"
            )
            # validation
            val_results = self.do_test(model[0], dataloader['valid'], mode="VAL")
            test_results = self.do_test(model[0], dataloader['test'], mode="TEST")
            cur_valid = val_results[self.args.KeyEval]
            scheduler.step(val_results['Loss'])
            # save each epoch model
            torch.save(model[0].state_dict(), './pt/' + str(self.args.dataset_name) + '_' + str(epochs) + '.pth')
            # save best model
            isBetter = cur_valid <= (best_valid - 1e-6) if min_or_max == 'min' else cur_valid >= (best_valid + 1e-6)
            if isBetter:
                best_valid, best_epoch = cur_valid, epochs
                # save model
                model_save_path = './pt/DLF' + str(self.args.dataset_name) + '.pth'
                torch.save(model[0].state_dict(), model_save_path)

            if return_epoch_results:
                train_results["Loss"] = train_loss
                epoch_results['train'].append(train_results)
                epoch_results['valid'].append(val_results)
                test_results = self.do_test(model, dataloader['test'], mode="TEST")
                epoch_results['test'].append(test_results)
            # early stop
            if epochs - best_epoch >= self.args.early_stop:
                return epoch_results if return_epoch_results else None

    def do_test(self, model, dataloader, mode="VAL", return_sample_results=False):

        model.eval()
        y_pred, y_true = [], []

        eval_loss = 0.0
        if return_sample_results:
            ids, sample_results = [], []
            all_labels = []
            features = {
                "Feature_t": [],
                "Feature_a": [],
                "Feature_v": [],
                "Feature_f": [],
            }

        with torch.no_grad():
            with tqdm(dataloader) as td:
                for batch_data in td:
                    vision = batch_data['vision'].to(self.args.device)
                    audio = batch_data['audio'].to(self.args.device)
                    text = batch_data['text'].to(self.args.device)
                    labels = batch_data['labels']['M'].to(self.args.device)
                    labels = labels.view(-1, 1)
                    output = model(text, audio, vision)
                    loss = self.criterion(output['output_logit'], labels)
                    eval_loss += loss.item()
                    y_pred.append(output['output_logit'].cpu())
                    y_true.append(labels.cpu())

        eval_loss = eval_loss / len(dataloader)
        pred, true = torch.cat(y_pred), torch.cat(y_true)

        eval_results = self.metrics(pred, true)
        eval_results["Loss"] = round(eval_loss, 4)
        logger.info(f"{mode}-({self.args.model_name}) >> {dict_to_str(eval_results)}")

        if return_sample_results:
            eval_results["Ids"] = ids
            eval_results["SResults"] = sample_results
            for k in features.keys():
                features[k] = np.concatenate(features[k], axis=0)
            eval_results['Features'] = features
            eval_results['Labels'] = all_labels

        return eval_results
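For readability, the objective assembled in do_train can be summarized as follows (weights taken directly from the code; symbol names follow the in-code comments L_msa, L_s, L_r, L_m, L_o):

L_DLF = L_msa + 0.1 * (L_s + L_r + 0.1 * (L_m + L_o))
L_msa = L_all + L_c + 3 * L_l + L_v + L_a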
trains/singleTask/HingeLoss.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn


class HingeLoss(nn.Module):
    def __init__(self):
        super(HingeLoss, self).__init__()

    def compute_cosine(self, x, y):
        # x = self.compute_compact_s(x)
        # y = self.compute_compact_s(y)
        x_norm = torch.sqrt(torch.sum(torch.pow(x, 2), 1) + 1e-8)
        x_norm = torch.max(x_norm, 1e-8 * torch.ones_like(x_norm))
        y_norm = torch.sqrt(torch.sum(torch.pow(y, 2), 1) + 1e-8)
        y_norm = torch.max(y_norm, 1e-8 * torch.ones_like(y_norm))
        cosine = torch.sum(x * y, 1) / (x_norm * y_norm)
        return cosine

    def forward(self, ids, feats, margin=0.1):
        B, F = feats.shape

        s = feats.repeat(1, B).view(-1, F)  # B**2 x F
        s_ids = ids.view(B, 1).repeat(1, B)  # B x B

        t = feats.repeat(B, 1)  # B**2 x F
        t_ids = ids.view(1, B).repeat(B, 1)  # B x B

        cosine = self.compute_cosine(s, t)  # B**2
        equal_mask = torch.eye(B, dtype=torch.bool)  # B x B
        s_ids = s_ids[~equal_mask].view(B, B - 1)  # B x (B-1)
        t_ids = t_ids[~equal_mask].view(B, B - 1)  # B x (B-1)
        cosine = cosine.view(B, B)[~equal_mask].view(B, B - 1)  # B x (B-1)

        sim_mask = (s_ids == t_ids)  # B x (B-1)
        margin = 0.15 * abs(s_ids - t_ids)  # [~sim_mask].view(B, B - 3)

        loss = 0
        loss_num = 0

        for i in range(B):
            sim_num = sum(sim_mask[i])
            dif_num = B - 1 - sim_num
            if not sim_num or not dif_num:
                continue
            sim_cos = cosine[i, sim_mask[i]].reshape(-1, 1).repeat(1, dif_num)
            dif_cos = cosine[i, ~sim_mask[i]].reshape(-1, 1).repeat(1, sim_num).transpose(0, 1)
            t_margin = margin[i, ~sim_mask[i]].reshape(-1, 1).repeat(1, sim_num).transpose(0, 1)

            loss_i = torch.max(torch.zeros_like(sim_cos), t_margin - sim_cos + dif_cos).mean()
            loss += loss_i
            loss_num += 1

        if loss_num == 0:
            loss_num = 1

        loss = loss / loss_num
        return loss
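A minimal sketch (not part of the commit) exercising HingeLoss on random CPU tensors, assuming the repository's dependencies are installed. The shapes mirror how do_train builds feats and ids from the modality-invariant features and the labels; the concrete values here are made up for illustration.

import torch
from trains.singleTask.HingeLoss import HingeLoss

loss_fn = HingeLoss()
feats = torch.randn(12, 50)                 # 12 stacked feature vectors of dimension 50
ids = torch.randint(0, 3, (12, 1)).float()  # labels; equal ids define the "similar" pairs
print(loss_fn(ids, feats))                  # scalar margin-based ranking loss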
trains/singleTask/__init__.py
ADDED
@@ -0,0 +1 @@
from .DLF import DLF
trains/singleTask/distillnets/get_distillation_kernel.py
ADDED
@@ -0,0 +1,96 @@
"""Graph distillation for hetero GD"""

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from ..utils import distance_metric, min_cosine

class DistillationKernel(nn.Module):
    """Graph Distillation kernel.

    Calculate the edge weights e_{j->k} for each j. Modality k is specified by
    to_idx, and the other modalities are specified by from_idx.
    """

    def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
                 gd_prior, gd_reg, w_losses, metric, alpha, hyp_params):
        super(DistillationKernel, self).__init__()
        self.W_logit = nn.Linear(n_classes, gd_size)
        self.W_repr = nn.Linear(hidden_size, gd_size)
        self.W_edge = nn.Linear(gd_size * 4, 1)

        self.gd_size = gd_size
        self.to_idx = to_idx
        self.from_idx = from_idx
        self.alpha = alpha
        self.gd_prior = Variable(torch.FloatTensor(gd_prior).cuda())
        self.gd_reg = gd_reg
        self.w_losses = w_losses
        self.metric = metric
        self.hyp_params = hyp_params

    def forward(self, logits, reprs):
        """
        Args:
          logits: (n_modalities, batch_size, n_classes)
          reprs: (n_modalities, batch_size, hidden_size)
        Return:
          edges: weights e_{j->k} (n_modalities_from, batch_size)
        """
        n_modalities, batch_size = logits.size()[:2]
        z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
        z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
        z = torch.cat(
            (z_logits, z_reprs), dim=1).view(n_modalities, batch_size,
                                             self.gd_size * 2)
        edges = []
        for j in self.to_idx:
            for i in self.from_idx:
                if i == j:
                    continue
                else:
                    # To calculate e_{j->k}, concatenate z^j, z^k
                    e = self.W_edge(torch.cat((z[j], z[i]), dim=1))
                    edges.append(e)
        edges = torch.cat(edges, dim=1)
        edges_origin = edges.sum(0).unsqueeze(0).transpose(0, 1)
        edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1)
        return edges, edges_origin

    def distillation_loss(self, logits, reprs, edges):
        """Calculate graph distillation losses, which include:
        regularization loss, loss for logits, and loss for representation.
        """
        loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg
        loss_logit, loss_repr = 0, 0
        x = 0
        for j in self.to_idx:
            for i, idx in enumerate(self.from_idx):
                if i == j:
                    continue
                else:
                    w_distill = edges[x] + self.gd_prior[x]
                    # print(edges.sum(1), w_distill.sum(0))
                    loss_logit += self.w_losses[0] * distance_metric(
                        logits[j], logits[idx], self.metric, w_distill)
                    loss_repr += self.w_losses[1] * min_cosine(
                        reprs[j], reprs[idx], self.metric, w_distill)
                    x = x + 1
        return loss_reg, loss_logit, loss_repr


def get_distillation_kernel(n_classes,
                            hidden_size,
                            gd_size,
                            to_idx,
                            from_idx,
                            gd_prior,
                            gd_reg,
                            w_losses,
                            metric,
                            alpha=1 / 8):
    return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx,
                              gd_prior, gd_reg, w_losses, metric, alpha)
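A minimal shape-check sketch (not part of the commit). It instantiates DistillationKernel directly, as run.py does, because the constructor also expects hyp_params; a CUDA device is assumed since gd_prior is moved to the GPU in __init__. hidden_size=100 mirrors the hetero setting in run.py (dst_feature_dim_nheads[0] * 2); the batch size of 16 is illustrative.

import torch
from trains.singleTask.distillnets.get_distillation_kernel import DistillationKernel
from trains.singleTask.misc import softmax

kernel = DistillationKernel(
    n_classes=1, hidden_size=100, gd_size=32,
    to_idx=[0, 1, 2], from_idx=[0, 1, 2],
    gd_prior=softmax([0, 0, 1, 0, 1, 1], 0.25), gd_reg=10,
    w_losses=[1, 10], metric='l1', alpha=1 / 8, hyp_params=None).cuda()

logits = torch.randn(3, 16, 1).cuda()    # (n_modalities, batch_size, n_classes)
reprs = torch.randn(3, 16, 100).cuda()   # (n_modalities, batch_size, hidden_size)
edges, edges_origin = kernel(logits, reprs)              # edges: (6, batch_size)
print(kernel.distillation_loss(logits, reprs, edges))    # (loss_reg, loss_logit, loss_repr)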
trains/singleTask/distillnets/get_distillation_kernel_homo.py
ADDED
@@ -0,0 +1,100 @@
"""Graph distillation for homo GD"""

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from ..utils import distance_metric, min_cosine

class DistillationKernel(nn.Module):
    """Graph Distillation kernel.

    Calculate the edge weights e_{j->k} for each j. Modality k is specified by
    to_idx, and the other modalities are specified by from_idx.
    """

    def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
                 gd_prior, gd_reg, w_losses, metric, alpha, hyp_params):
        super(DistillationKernel, self).__init__()
        self.W_logit = nn.Linear(n_classes, gd_size)
        self.W_repr = nn.Linear(hidden_size, gd_size)
        self.W_edge = nn.Linear(gd_size * 4, 1)

        self.gd_size = gd_size
        self.to_idx = to_idx
        self.from_idx = from_idx
        self.alpha = alpha
        self.gd_prior = Variable(torch.FloatTensor(gd_prior).cuda())
        self.gd_reg = gd_reg
        self.w_losses = w_losses
        self.metric = metric
        self.hyp_params = hyp_params

    def forward(self, logits, reprs):
        """
        Args:
          logits: (n_modalities, batch_size, n_classes)
          reprs: (n_modalities, batch_size, hidden_size)
        Return:
          edges: weights e_{j->k} (n_modalities_from, batch_size)
        """
        n_modalities, batch_size = logits.size()[:2]
        z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
        z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
        z = torch.cat(
            (z_logits, z_reprs), dim=1).view(n_modalities, batch_size,
                                             self.gd_size * 2)

        edges = []
        for j in self.to_idx:
            for i in self.from_idx:
                if i == j:
                    continue
                else:
                    # To calculate e_{j->k}, concatenate z^j, z^k
                    e = self.W_edge(torch.cat((z[j], z[i]), dim=1))
                    edges.append(e)
        edges = torch.cat(edges, dim=1)
        edges_origin = edges.sum(0).unsqueeze(0).transpose(0, 1)  # original value of edges
        edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1)  # normalized value of edges
        return edges, edges_origin

    def distillation_loss(self, logits, reprs, edges):
        """Calculate graph distillation losses, which include:
        regularization loss, loss for logits, and loss for representation.
        """
        loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg

        loss_logit, loss_repr = 0, 0
        x = 0
        for j in self.to_idx:
            for i, idx in enumerate(self.from_idx):
                if i == j:
                    continue
                else:
                    w_distill = edges[x] + self.gd_prior[x]
                    # print(edges.sum(1), w_distill.sum(0))
                    loss_logit += self.w_losses[0] * distance_metric(
                        logits[j], logits[idx], self.metric, w_distill)
                    loss_repr += self.w_losses[1] * distance_metric(
                        reprs[j], reprs[idx], self.metric, w_distill)
                    x = x + 1
        return loss_reg, loss_logit, loss_repr


def get_distillation_kernel(n_classes,
                            hidden_size,
                            gd_size,
                            to_idx,
                            from_idx,
                            gd_prior,
                            gd_reg,
                            w_losses,
                            metric,
                            alpha=1 / 8):
    return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx,
                              gd_prior, gd_reg, w_losses, metric, alpha)
trains/singleTask/misc.py
ADDED
@@ -0,0 +1,196 @@
import numpy as np
from sklearn.metrics import average_precision_score
import torch
import torch.nn.functional as F
import torch.utils.data

def to_numpy(array):
    if isinstance(array, np.ndarray):
        return array
    if isinstance(array, torch.autograd.Variable):
        array = array.data
    if array.is_cuda:
        array = array.cpu()

    return array.numpy()


def squeeze(array):
    if not isinstance(array, list) or len(array) > 1:
        return array
    else:  # len(array) == 1
        return array[0]


def unsqueeze(array):
    if isinstance(array, list):
        return array
    else:
        return [array]


def is_due(*args):
    """Determines whether to perform an action or not, depending on the epoch.
    Used for logging, saving, learning rate decay, etc.

    Args:
      *args: epoch, due_at (due at epoch due_at)
             epoch, num_epochs, due_every (due every due_every epochs)
             step, due_every (due every due_every steps)
    Returns:
      due: boolean: perform action or not
    """
    if len(args) == 2 and isinstance(args[1], list):
        epoch, due_at = args
        due = epoch in due_at
    elif len(args) == 3:
        epoch, num_epochs, due_every = args
        due = (due_every >= 0) and (epoch % due_every == 0 or epoch == num_epochs)
    else:
        step, due_every = args
        due = (due_every > 0) and (step % due_every == 0)

    return due


def softmax(w, t=1.0, axis=None):
    w = np.array(w) / t
    e = np.exp(w - np.amax(w, axis=axis, keepdims=True))
    dist = e / np.sum(e, axis=axis, keepdims=True)
    return dist


def min_cosine(student, teacher, option, weights=None):
    cosine = torch.nn.CosineEmbeddingLoss()
    dists = cosine(student, teacher.detach(), torch.tensor([-1]).cuda())
    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()

    return dist


def distance_metric(student, teacher, option, weights=None):
    """Distance metric to calculate the imitation loss.

    Args:
      student: batch_size x n_classes
      teacher: batch_size x n_classes
      option: one of [cosine, l2, l1, kl]
      weights: batch_size or float

    Returns:
      The computed distance metric.
    """
    if option == 'cosine':
        dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)
        # dists = 1 - F.cosine_similarity(student, teacher, dim=1)
    elif option == 'l2':
        dists = (student - teacher.detach()).pow(2).sum(1)
    elif option == 'l1':
        dists = torch.abs(student - teacher.detach()).sum(1)
    elif option == 'kl':
        assert weights is None
        T = 8
        # averaged for each minibatch
        dist = F.kl_div(
            F.log_softmax(student / T), F.softmax(teacher.detach() / T)) * (
                T * T)
        return dist
    else:
        raise NotImplementedError

    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()

    return dist


def get_segments(input, timestep):
    """Split the entire input into segments of length timestep.

    Args:
      input: 1 x total_length x n_frames x ...
      timestep: the segment length.

    Returns:
      input: concatenated video segments
      start_indices: indices of the segments
    """
    assert input.size(0) == 1, 'Test time, batch_size must be 1'

    input.squeeze_(dim=0)
    # Find overlapping segments
    length = input.size()[0]
    step = timestep // 2
    num_segments = (length - timestep) // step + 1
    start_indices = (np.arange(num_segments) * step).tolist()
    if length % step > 0:
        start_indices.append(length - timestep)

    # Get the segments
    segments = []
    for s in start_indices:
        segment = input[s: (s + timestep)].unsqueeze(0)
        segments.append(segment)
    input = torch.cat(segments, dim=0)
    return input, start_indices

def get_stats(logit, label):
    '''
    Calculate the accuracy.
    '''
    logit = to_numpy(logit)
    label = to_numpy(label)

    pred = np.argmax(logit, 1)
    acc = np.sum(pred == label) / label.shape[0]

    return acc, pred, label


def get_stats_detection(logit, label, n_classes=52):
    '''
    Calculate the accuracy and average precisions.
    '''
    logit = to_numpy(logit)
    label = to_numpy(label)
    scores = softmax(logit, axis=1)

    pred = np.argmax(logit, 1)
    length = label.shape[0]
    acc = np.sum(pred == label) / length

    keep_bg = label == 0
    acc_bg = np.sum(pred[keep_bg] == label[keep_bg]) / label[keep_bg].shape[0]
    ratio_bg = np.sum(keep_bg) / length

    keep_action = label != 0
    acc_action = np.sum(
        pred[keep_action] == label[keep_action]) / label[keep_action].shape[0]

    # Average precision
    y_true = np.zeros((len(label), n_classes))
    y_true[np.arange(len(label)), label] = 1
    acc = np.sum(pred == label) / label.shape[0]
    aps = average_precision_score(y_true, scores, average=None)
    aps = list(filter(lambda x: not np.isnan(x), aps))
    ap = np.mean(aps)

    return ap, acc, acc_bg, acc_action, ratio_bg, pred, label


def info(text):
    print('\033[94m' + text + '\033[0m')


def warn(text):
    print('\033[93m' + text + '\033[0m')


def err(text):
    print('\033[91m' + text + '\033[0m')
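A minimal sketch (not part of the commit) of the two helpers the rest of the code leans on: the temperature softmax that run.py uses to build gd_prior, and distance_metric with the 'l1' option used as the imitation loss. The tensor shapes are illustrative.

import torch
from trains.singleTask.misc import softmax, distance_metric

print(softmax([0, 0, 1, 0, 1, 1], 0.25))   # 6-way edge prior, sums to 1

student = torch.randn(16, 1)
teacher = torch.randn(16, 1)
weights = torch.rand(16)
print(distance_metric(student, teacher, 'l1', weights))  # weighted mean absolute error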
trains/singleTask/model/DLF.py
ADDED
@@ -0,0 +1,345 @@
"""
The main backbone for DLF.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...subNets import BertTextEncoder
from ...subNets.transformers_encoder.transformer import TransformerEncoder


class DLF(nn.Module):
    def __init__(self, args):
        super(DLF, self).__init__()
        if args.use_bert:
            self.text_model = BertTextEncoder(use_finetune=args.use_finetune, transformers=args.transformers,
                                              pretrained=args.pretrained)
        self.use_bert = args.use_bert
        dst_feature_dims, nheads = args.dst_feature_dim_nheads
        if args.dataset_name == 'mosi':
            if args.need_data_aligned:
                self.len_l, self.len_v, self.len_a = 50, 50, 50
            else:
                self.len_l, self.len_v, self.len_a = 50, 500, 375
        if args.dataset_name == 'mosei':
            if args.need_data_aligned:
                self.len_l, self.len_v, self.len_a = 50, 50, 50
            else:
                self.len_l, self.len_v, self.len_a = 50, 500, 500
        self.orig_d_l, self.orig_d_a, self.orig_d_v = args.feature_dims
        self.d_l = self.d_a = self.d_v = dst_feature_dims
        self.num_heads = nheads
        self.layers = args.nlevels
        self.attn_dropout = args.attn_dropout
        self.attn_dropout_a = args.attn_dropout_a
        self.attn_dropout_v = args.attn_dropout_v
        self.relu_dropout = args.relu_dropout
        self.embed_dropout = args.embed_dropout
        self.res_dropout = args.res_dropout
        self.output_dropout = args.output_dropout
        self.text_dropout = args.text_dropout
        self.attn_mask = args.attn_mask
        combined_dim_low = self.d_a
        combined_dim_high = self.d_a
        combined_dim = (self.d_l + self.d_a + self.d_v) + self.d_l * 3

        output_dim = 1

        # 1. Temporal convolutional layers for initial features
        self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=args.conv1d_kernel_size_l, padding=0, bias=False)
        self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=args.conv1d_kernel_size_a, padding=0, bias=False)
        self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=args.conv1d_kernel_size_v, padding=0, bias=False)

        # 2. Modality-specific encoders
        self.encoder_s_l = self.get_network(self_type='l', layers=self.layers)
        self.encoder_s_v = self.get_network(self_type='v', layers=self.layers)
        self.encoder_s_a = self.get_network(self_type='a', layers=self.layers)

        # Modality-shared encoder
        self.encoder_c = self.get_network(self_type='l', layers=self.layers)

        # 3. Decoders to reconstruct the three modalities
        self.decoder_l = nn.Conv1d(self.d_l * 2, self.d_l, kernel_size=1, padding=0, bias=False)
        self.decoder_v = nn.Conv1d(self.d_v * 2, self.d_v, kernel_size=1, padding=0, bias=False)
        self.decoder_a = nn.Conv1d(self.d_a * 2, self.d_a, kernel_size=1, padding=0, bias=False)

        # For calculating the cosine similarity between s_x
        self.proj_cosine_l = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
        self.proj_cosine_v = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
        self.proj_cosine_a = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)

        # For aligning c_l, c_v, c_a
        self.align_c_l = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
        self.align_c_v = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
        self.align_c_a = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)

        self.self_attentions_c_l = self.get_network(self_type='l')
        self.self_attentions_c_v = self.get_network(self_type='v')
        self.self_attentions_c_a = self.get_network(self_type='a')

        self.proj1_c = nn.Linear(self.d_l * 3, self.d_l * 3)
        self.proj2_c = nn.Linear(self.d_l * 3, self.d_l * 3)
        self.out_layer_c = nn.Linear(self.d_l * 3, output_dim)

        # 4. Multimodal crossmodal attentions
        self.trans_l_with_a = self.get_network(self_type='la', layers=self.layers)
        self.trans_l_with_v = self.get_network(self_type='lv', layers=self.layers)
        self.trans_a_with_l = self.get_network(self_type='al')
        self.trans_a_with_v = self.get_network(self_type='av')
        self.trans_v_with_l = self.get_network(self_type='vl')
        self.trans_v_with_a = self.get_network(self_type='va')
        self.trans_l_mem = self.get_network(self_type='l_mem', layers=self.layers)
        self.trans_a_mem = self.get_network(self_type='a_mem', layers=3)
        self.trans_v_mem = self.get_network(self_type='v_mem', layers=3)

        # 5. fc layers for shared features
        self.proj1_l_low = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), combined_dim_low)
        self.proj2_l_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1))
        self.out_layer_l_low = nn.Linear(combined_dim_low * (self.len_l - args.conv1d_kernel_size_l + 1), output_dim)
        self.proj1_v_low = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), combined_dim_low)
        self.proj2_v_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1))
        self.out_layer_v_low = nn.Linear(combined_dim_low * (self.len_v - args.conv1d_kernel_size_v + 1), output_dim)
        self.proj1_a_low = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), combined_dim_low)
        self.proj2_a_low = nn.Linear(combined_dim_low, combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1))
        self.out_layer_a_low = nn.Linear(combined_dim_low * (self.len_a - args.conv1d_kernel_size_a + 1), output_dim)

        # 6. fc layers for specific features
        self.proj1_l_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.proj2_l_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.out_layer_l_high = nn.Linear(combined_dim_high, output_dim)
        self.proj1_v_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.proj2_v_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.out_layer_v_high = nn.Linear(combined_dim_high, output_dim)
        self.proj1_a_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.proj2_a_high = nn.Linear(combined_dim_high, combined_dim_high)
        self.out_layer_a_high = nn.Linear(combined_dim_high, output_dim)

        # 7. Projections for fusion
        self.projector_l = nn.Linear(self.d_l, self.d_l)
        self.projector_v = nn.Linear(self.d_v, self.d_v)
        self.projector_a = nn.Linear(self.d_a, self.d_a)
        self.projector_c = nn.Linear(3 * self.d_l, 3 * self.d_l)

        # 8. Final projection
        self.proj1 = nn.Linear(combined_dim, combined_dim)
        self.proj2 = nn.Linear(combined_dim, combined_dim)
        self.out_layer = nn.Linear(combined_dim, output_dim)

    def get_network(self, self_type='l', layers=-1):
        if self_type in ['l', 'al', 'vl']:
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type in ['a', 'la', 'va']:
            embed_dim, attn_dropout = self.d_a, self.attn_dropout_a
        elif self_type in ['v', 'lv', 'av']:
            embed_dim, attn_dropout = self.d_v, self.attn_dropout_v
        elif self_type == 'l_mem':
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type == 'a_mem':
            embed_dim, attn_dropout = self.d_a, self.attn_dropout
        elif self_type == 'v_mem':
            embed_dim, attn_dropout = self.d_v, self.attn_dropout
        else:
            raise ValueError("Unknown network type")

        return TransformerEncoder(embed_dim=embed_dim,
                                  num_heads=self.num_heads,
                                  layers=max(self.layers, layers),
                                  attn_dropout=attn_dropout,
                                  relu_dropout=self.relu_dropout,
                                  res_dropout=self.res_dropout,
                                  embed_dropout=self.embed_dropout,
                                  attn_mask=self.attn_mask)

    def forward(self, text, audio, video):
        # extraction
        if self.use_bert:
            text = self.text_model(text)
        x_l = F.dropout(text.transpose(1, 2), p=self.text_dropout, training=self.training)
        x_a = audio.transpose(1, 2)
        x_v = video.transpose(1, 2)

        proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l)
        proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a)
        proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v)

        proj_x_l = proj_x_l.permute(2, 0, 1)
        proj_x_v = proj_x_v.permute(2, 0, 1)
        proj_x_a = proj_x_a.permute(2, 0, 1)

        # disentanglement
        s_l = self.encoder_s_l(proj_x_l)
        s_v = self.encoder_s_v(proj_x_v)
        s_a = self.encoder_s_a(proj_x_a)

        c_l = self.encoder_c(proj_x_l)
        c_v = self.encoder_c(proj_x_v)
        c_a = self.encoder_c(proj_x_a)

        s_l = s_l.permute(1, 2, 0)
        s_v = s_v.permute(1, 2, 0)
        s_a = s_a.permute(1, 2, 0)

        c_l = c_l.permute(1, 2, 0)
        c_v = c_v.permute(1, 2, 0)
        c_a = c_a.permute(1, 2, 0)
        c_list = [c_l, c_v, c_a]

        c_l_sim = self.align_c_l(c_l.contiguous().view(x_l.size(0), -1))
        c_v_sim = self.align_c_v(c_v.contiguous().view(x_l.size(0), -1))
        c_a_sim = self.align_c_a(c_a.contiguous().view(x_l.size(0), -1))

        recon_l = self.decoder_l(torch.cat([s_l, c_list[0]], dim=1))
        recon_v = self.decoder_v(torch.cat([s_v, c_list[1]], dim=1))
        recon_a = self.decoder_a(torch.cat([s_a, c_list[2]], dim=1))

        recon_l = recon_l.permute(2, 0, 1)
        recon_v = recon_v.permute(2, 0, 1)
        recon_a = recon_a.permute(2, 0, 1)

        s_l_r = self.encoder_s_l(recon_l).permute(1, 2, 0)
        s_v_r = self.encoder_s_v(recon_v).permute(1, 2, 0)
        s_a_r = self.encoder_s_a(recon_a).permute(1, 2, 0)

        s_l = s_l.permute(2, 0, 1)
        s_v = s_v.permute(2, 0, 1)
        s_a = s_a.permute(2, 0, 1)

        c_l = c_l.permute(2, 0, 1)
        c_v = c_v.permute(2, 0, 1)
        c_a = c_a.permute(2, 0, 1)

        # enhancement
        hs_l_low = c_l.transpose(0, 1).contiguous().view(x_l.size(0), -1)
        repr_l_low = self.proj1_l_low(hs_l_low)
        hs_proj_l_low = self.proj2_l_low(
            F.dropout(F.relu(repr_l_low, inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_l_low += hs_l_low
        logits_l_low = self.out_layer_l_low(hs_proj_l_low)

        hs_v_low = c_v.transpose(0, 1).contiguous().view(x_v.size(0), -1)
        repr_v_low = self.proj1_v_low(hs_v_low)
        hs_proj_v_low = self.proj2_v_low(
            F.dropout(F.relu(repr_v_low, inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_v_low += hs_v_low
        logits_v_low = self.out_layer_v_low(hs_proj_v_low)

        hs_a_low = c_a.transpose(0, 1).contiguous().view(x_a.size(0), -1)
        repr_a_low = self.proj1_a_low(hs_a_low)
        hs_proj_a_low = self.proj2_a_low(
            F.dropout(F.relu(repr_a_low, inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_a_low += hs_a_low
        logits_a_low = self.out_layer_a_low(hs_proj_a_low)

        c_l_att = self.self_attentions_c_l(c_l)
        if type(c_l_att) == tuple:
            c_l_att = c_l_att[0]
        c_l_att = c_l_att[-1]

        c_v_att = self.self_attentions_c_v(c_v)
        if type(c_v_att) == tuple:
            c_v_att = c_v_att[0]
        c_v_att = c_v_att[-1]

        c_a_att = self.self_attentions_c_a(c_a)
        if type(c_a_att) == tuple:
            c_a_att = c_a_att[0]
        c_a_att = c_a_att[-1]

        c_fusion = torch.cat([c_l_att, c_v_att, c_a_att], dim=1)

        c_proj = self.proj2_c(
            F.dropout(F.relu(self.proj1_c(c_fusion), inplace=True), p=self.output_dropout,
                      training=self.training))
        c_proj += c_fusion
        logits_c = self.out_layer_c(c_proj)

        # LFA
        # L --> L
        h_ls = s_l
        h_ls = self.trans_l_mem(h_ls)
        if type(h_ls) == tuple:
            h_ls = h_ls[0]
        last_h_l = last_hs = h_ls[-1]

        # A --> L
        h_l_with_as = self.trans_l_with_a(s_l, s_a, s_a)
        h_as = h_l_with_as
        h_as = self.trans_a_mem(h_as)
        if type(h_as) == tuple:
            h_as = h_as[0]
        last_h_a = last_hs = h_as[-1]

        # V --> L
        h_l_with_vs = self.trans_l_with_v(s_l, s_v, s_v)
        h_vs = h_l_with_vs
        h_vs = self.trans_v_mem(h_vs)
        if type(h_vs) == tuple:
            h_vs = h_vs[0]
        last_h_v = last_hs = h_vs[-1]

        hs_proj_l_high = self.proj2_l_high(
            F.dropout(F.relu(self.proj1_l_high(last_h_l), inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_l_high += last_h_l
        logits_l_high = self.out_layer_l_high(hs_proj_l_high)

        hs_proj_v_high = self.proj2_v_high(
            F.dropout(F.relu(self.proj1_v_high(last_h_v), inplace=True), p=self.output_dropout, training=self.training))
        hs_proj_v_high += last_h_v
        logits_v_high = self.out_layer_v_high(hs_proj_v_high)

        hs_proj_a_high = self.proj2_a_high(
            F.dropout(F.relu(self.proj1_a_high(last_h_a), inplace=True), p=self.output_dropout,
                      training=self.training))
        hs_proj_a_high += last_h_a
        logits_a_high = self.out_layer_a_high(hs_proj_a_high)

        # fusion
        last_h_l = torch.sigmoid(self.projector_l(hs_proj_l_high))
        last_h_v = torch.sigmoid(self.projector_v(hs_proj_v_high))
        last_h_a = torch.sigmoid(self.projector_a(hs_proj_a_high))
        c_fusion = torch.sigmoid(self.projector_c(c_fusion))

        last_hs = torch.cat([last_h_l, last_h_v, last_h_a, c_fusion], dim=1)

        # prediction
        last_hs_proj = self.proj2(
            F.dropout(F.relu(self.proj1(last_hs), inplace=True), p=self.output_dropout, training=self.training))
        last_hs_proj += last_hs

        output = self.out_layer(last_hs_proj)

        res = {
            'origin_l': proj_x_l,
            'origin_v': proj_x_v,
            'origin_a': proj_x_a,
            's_l': s_l,
            's_v': s_v,
            's_a': s_a,
            'c_l': c_l,
            'c_v': c_v,
            'c_a': c_a,
            's_l_r': s_l_r,
            's_v_r': s_v_r,
            's_a_r': s_a_r,
            'recon_l': recon_l,
            'recon_v': recon_v,
            'recon_a': recon_a,
            'c_l_sim': c_l_sim,
            'c_v_sim': c_v_sim,
            'c_a_sim': c_a_sim,
            'logits_l_hetero': logits_l_high,
            'logits_v_hetero': logits_v_high,
            'logits_a_hetero': logits_a_high,
            'logits_c': logits_c,
            'output_logit': output
        }
        return res
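Not part of the committed file: a rough instantiation sketch for DLF. Every hyper-parameter below is a made-up placeholder (the real values are loaded from config/config.json via config.py), and use_bert=False is chosen so the example runs on pre-extracted text features without downloading BERT.

    from types import SimpleNamespace
    import torch

    args = SimpleNamespace(
        use_bert=False, use_finetune=False, transformers='bert', pretrained='bert-base-uncased',
        dataset_name='mosi', need_data_aligned=True,
        feature_dims=(768, 5, 20), dst_feature_dim_nheads=(32, 8), nlevels=2,
        attn_dropout=0.1, attn_dropout_a=0.1, attn_dropout_v=0.1,
        relu_dropout=0.1, embed_dropout=0.1, res_dropout=0.1,
        output_dropout=0.1, text_dropout=0.1, attn_mask=False,
        conv1d_kernel_size_l=1, conv1d_kernel_size_a=1, conv1d_kernel_size_v=1,
    )
    model = DLF(args)
    text = torch.randn(4, 50, 768)   # pre-extracted text features (use_bert=False)
    audio = torch.randn(4, 50, 5)    # aligned MOSI-style length of 50 per modality
    video = torch.randn(4, 50, 20)
    res = model(text, audio, video)
    print(res['output_logit'].shape)  # torch.Size([4, 1])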
trains/singleTask/utils/__init__.py
ADDED
@@ -0,0 +1 @@
from .misc import *
trains/singleTask/utils/misc.py
ADDED
@@ -0,0 +1,196 @@
import numpy as np
from sklearn.metrics import average_precision_score
import torch
import torch.nn.functional as F
import torch.utils.data


def to_numpy(array):
    if isinstance(array, np.ndarray):
        return array
    if isinstance(array, torch.autograd.Variable):
        array = array.data
    if array.is_cuda:
        array = array.cpu()

    return array.numpy()


def squeeze(array):
    if not isinstance(array, list) or len(array) > 1:
        return array
    else:  # len(array) == 1
        return array[0]


def unsqueeze(array):
    if isinstance(array, list):
        return array
    else:
        return [array]


def is_due(*args):
    """Determines whether to perform an action or not, depending on the epoch.
    Used for logging, saving, learning rate decay, etc.

    Args:
        *args: one of
            (epoch, due_at): due at the epochs listed in due_at
            (epoch, num_epochs, due_every): due every due_every epochs
            (step, due_every): due every due_every steps
    Returns:
        due: boolean: perform action or not
    """
    if len(args) == 2 and isinstance(args[1], list):
        epoch, due_at = args
        due = epoch in due_at
    elif len(args) == 3:
        epoch, num_epochs, due_every = args
        due = (due_every >= 0) and (epoch % due_every == 0 or epoch == num_epochs)
    else:
        step, due_every = args
        due = (due_every > 0) and (step % due_every == 0)

    return due


def softmax(w, t=1.0, axis=None):
    w = np.array(w) / t
    e = np.exp(w - np.amax(w, axis=axis, keepdims=True))
    dist = e / np.sum(e, axis=axis, keepdims=True)
    return dist


def min_cosine(student, teacher, option, weights=None):
    cosine = torch.nn.CosineEmbeddingLoss()
    dists = cosine(student, teacher.detach(), torch.tensor([-1]).cuda())
    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()

    return dist


def distance_metric(student, teacher, option, weights=None):
    """Distance metric to calculate the imitation loss.

    Args:
        student: batch_size x n_classes
        teacher: batch_size x n_classes
        option: one of [cosine, l2, l1, kl]
        weights: batch_size or float

    Returns:
        The computed distance metric.
    """
    if option == 'cosine':
        dists = 1 - F.cosine_similarity(student, teacher.detach(), dim=1)
        # dists = 1 - F.cosine_similarity(student, teacher, dim=1)
    elif option == 'l2':
        dists = (student - teacher.detach()).pow(2).sum(1)
    elif option == 'l1':
        dists = torch.abs(student - teacher.detach()).sum(1)
    elif option == 'kl':
        # assert weights is None
        T = 8
        # averaged for each minibatch
        dist = F.kl_div(
            F.log_softmax(student / T), F.softmax(teacher.detach() / T)) * (
                T * T)
        return dist
    else:
        raise NotImplementedError

    if weights is None:
        dist = dists.mean()
    else:
        dist = (dists * weights).mean()

    return dist


def get_segments(input, timestep):
    """Split the entire input into segments of length timestep.

    Args:
        input: 1 x total_length x n_frames x ...
        timestep: length of each segment.

    Returns:
        input: concatenated video segments
        start_indices: start indices of the segments
    """
    assert input.size(0) == 1, 'Test time, batch_size must be 1'

    input.squeeze_(dim=0)
    # Find overlapping segments
    length = input.size()[0]
    step = timestep // 2
    num_segments = (length - timestep) // step + 1
    start_indices = (np.arange(num_segments) * step).tolist()
    if length % step > 0:
        start_indices.append(length - timestep)

    # Get the segments
    segments = []
    for s in start_indices:
        segment = input[s: (s + timestep)].unsqueeze(0)
        segments.append(segment)
    input = torch.cat(segments, dim=0)
    return input, start_indices


def get_stats(logit, label):
    '''
    Calculate the accuracy.
    '''
    logit = to_numpy(logit)
    label = to_numpy(label)

    pred = np.argmax(logit, 1)
    acc = np.sum(pred == label) / label.shape[0]

    return acc, pred, label


def get_stats_detection(logit, label, n_classes=52):
    '''
    Calculate the accuracy and average precisions.
    '''
    logit = to_numpy(logit)
    label = to_numpy(label)
    scores = softmax(logit, axis=1)

    pred = np.argmax(logit, 1)
    length = label.shape[0]
    acc = np.sum(pred == label) / length

    keep_bg = label == 0
    acc_bg = np.sum(pred[keep_bg] == label[keep_bg]) / label[keep_bg].shape[0]
    ratio_bg = np.sum(keep_bg) / length

    keep_action = label != 0
    acc_action = np.sum(
        pred[keep_action] == label[keep_action]) / label[keep_action].shape[0]

    # Average precision
    y_true = np.zeros((len(label), n_classes))
    y_true[np.arange(len(label)), label] = 1
    acc = np.sum(pred == label) / label.shape[0]
    aps = average_precision_score(y_true, scores, average=None)
    aps = list(filter(lambda x: not np.isnan(x), aps))
    ap = np.mean(aps)

    return ap, acc, acc_bg, acc_action, ratio_bg, pred, label


def info(text):
    print('\033[94m' + text + '\033[0m')


def warn(text):
    print('\033[93m' + text + '\033[0m')


def err(text):
    print('\033[91m' + text + '\033[0m')
trains/subNets/AlignNets.py
ADDED
@@ -0,0 +1,106 @@
import torch
import torch.nn as nn

__all__ = ['AlignSubNet']


class CTCModule(nn.Module):
    def __init__(self, in_dim, out_seq_len):
        '''
        This module performs alignment from A (e.g., audio) to B (e.g., text).
        :param in_dim: Dimension for input modality A
        :param out_seq_len: Sequence length for output modality B
        From: https://github.com/yaohungt/Multimodal-Transformer
        '''
        super(CTCModule, self).__init__()
        # Use an LSTM for predicting the position from A to B
        self.pred_output_position_inclu_blank = nn.LSTM(in_dim, out_seq_len + 1, num_layers=2, batch_first=True)  # 1 denoting blank

        self.out_seq_len = out_seq_len

        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        '''
        :input x: Input with shape [batch_size x in_seq_len x in_dim]
        '''
        # NOTE that the index 0 refers to blank.
        pred_output_position_inclu_blank, _ = self.pred_output_position_inclu_blank(x)

        prob_pred_output_position_inclu_blank = self.softmax(pred_output_position_inclu_blank)  # batch_size x in_seq_len x out_seq_len+1
        prob_pred_output_position = prob_pred_output_position_inclu_blank[:, :, 1:]  # batch_size x in_seq_len x out_seq_len
        prob_pred_output_position = prob_pred_output_position.transpose(1, 2)  # batch_size x out_seq_len x in_seq_len
        pseudo_aligned_out = torch.bmm(prob_pred_output_position, x)  # batch_size x out_seq_len x in_dim

        # pseudo_aligned_out is regarded as the aligned A (w.r.t. B)
        # return pseudo_aligned_out, (pred_output_position_inclu_blank)
        return pseudo_aligned_out


class AlignSubNet(nn.Module):
    def __init__(self, args, mode):
        """
        mode: the way of aligning
            avg_pool, ctc, conv1d
        """
        super(AlignSubNet, self).__init__()
        assert mode in ['avg_pool', 'ctc', 'conv1d']

        in_dim_t, in_dim_a, in_dim_v = args.feature_dims
        seq_len_t, seq_len_a, seq_len_v = args.seq_lens
        self.dst_len = seq_len_t
        self.mode = mode

        self.ALIGN_WAY = {
            'avg_pool': self.__avg_pool,
            'ctc': self.__ctc,
            'conv1d': self.__conv1d
        }

        if mode == 'conv1d':
            self.conv1d_T = nn.Conv1d(seq_len_t, self.dst_len, kernel_size=1, bias=False)
            self.conv1d_A = nn.Conv1d(seq_len_a, self.dst_len, kernel_size=1, bias=False)
            self.conv1d_V = nn.Conv1d(seq_len_v, self.dst_len, kernel_size=1, bias=False)
        elif mode == 'ctc':
            self.ctc_t = CTCModule(in_dim_t, self.dst_len)
            self.ctc_a = CTCModule(in_dim_a, self.dst_len)
            self.ctc_v = CTCModule(in_dim_v, self.dst_len)

    def get_seq_len(self):
        return self.dst_len

    def __ctc(self, text_x, audio_x, video_x):
        text_x = self.ctc_t(text_x) if text_x.size(1) != self.dst_len else text_x
        audio_x = self.ctc_a(audio_x) if audio_x.size(1) != self.dst_len else audio_x
        video_x = self.ctc_v(video_x) if video_x.size(1) != self.dst_len else video_x
        return text_x, audio_x, video_x

    def __avg_pool(self, text_x, audio_x, video_x):
        def align(x):
            raw_seq_len = x.size(1)
            if raw_seq_len == self.dst_len:
                return x
            if raw_seq_len // self.dst_len == raw_seq_len / self.dst_len:
                pad_len = 0
                pool_size = raw_seq_len // self.dst_len
            else:
                pad_len = self.dst_len - raw_seq_len % self.dst_len
                pool_size = raw_seq_len // self.dst_len + 1
            pad_x = x[:, -1, :].unsqueeze(1).expand([x.size(0), pad_len, x.size(-1)])
            x = torch.cat([x, pad_x], dim=1).view(x.size(0), pool_size, self.dst_len, -1)
            x = x.mean(dim=1)
            return x
        text_x = align(text_x)
        audio_x = align(audio_x)
        video_x = align(video_x)
        return text_x, audio_x, video_x

    def __conv1d(self, text_x, audio_x, video_x):
        # Note: the committed version passed text_x to conv1d_A and conv1d_V as well,
        # which looks like a copy-paste slip; each modality is projected from its own input here.
        text_x = self.conv1d_T(text_x) if text_x.size(1) != self.dst_len else text_x
        audio_x = self.conv1d_A(audio_x) if audio_x.size(1) != self.dst_len else audio_x
        video_x = self.conv1d_V(video_x) if video_x.size(1) != self.dst_len else video_x
        return text_x, audio_x, video_x

    def forward(self, text_x, audio_x, video_x):
        # already aligned
        if text_x.size(1) == audio_x.size(1) == video_x.size(1):
            return text_x, audio_x, video_x
        return self.ALIGN_WAY[self.mode](text_x, audio_x, video_x)
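Not part of the committed file: a small sketch of AlignSubNet in 'avg_pool' mode. The seq_lens and feature_dims values are illustrative, not the project defaults.

    from types import SimpleNamespace
    import torch

    args = SimpleNamespace(feature_dims=(768, 5, 20), seq_lens=(50, 500, 500))
    align = AlignSubNet(args, mode='avg_pool')
    text = torch.randn(2, 50, 768)
    audio = torch.randn(2, 500, 5)
    video = torch.randn(2, 500, 20)
    text, audio, video = align(text, audio, video)
    print(audio.shape)  # torch.Size([2, 50, 5]): audio/video are pooled down to the text length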
trains/subNets/BertTextEncoder.py
ADDED
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer

__all__ = ['BertTextEncoder']

TRANSFORMERS_MAP = {
    'bert': (BertModel, BertTokenizer),
    'roberta': (RobertaModel, RobertaTokenizer),
}


class BertTextEncoder(nn.Module):
    def __init__(self, use_finetune=False, transformers='bert', pretrained='bert-base-uncased'):
        super().__init__()

        tokenizer_class = TRANSFORMERS_MAP[transformers][1]
        model_class = TRANSFORMERS_MAP[transformers][0]
        self.tokenizer = tokenizer_class.from_pretrained(pretrained)
        self.model = model_class.from_pretrained(pretrained)
        self.use_finetune = use_finetune

    def get_tokenizer(self):
        return self.tokenizer

    # def from_text(self, text):
    #     """
    #     text: raw data
    #     """
    #     input_ids = self.get_id(text)
    #     with torch.no_grad():
    #         last_hidden_states = self.model(input_ids)[0]  # Models outputs are now tuples
    #     return last_hidden_states.squeeze()

    def forward(self, text):
        """
        text: (batch_size, 3, seq_len)
        3: input_ids, input_mask, segment_ids
        input_ids: input_ids,
        input_mask: attention_mask,
        segment_ids: token_type_ids
        """
        input_ids, input_mask, segment_ids = text[:, 0, :].long(), text[:, 1, :].float(), text[:, 2, :].long()
        if self.use_finetune:
            last_hidden_states = self.model(input_ids=input_ids,
                                            attention_mask=input_mask,
                                            token_type_ids=segment_ids)[0]  # Models outputs are now tuples
        else:
            with torch.no_grad():
                last_hidden_states = self.model(input_ids=input_ids,
                                                attention_mask=input_mask,
                                                token_type_ids=segment_ids)[0]  # Models outputs are now tuples
        return last_hidden_states
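Not part of the committed file: a usage sketch. It assumes the 'bert-base-uncased' checkpoint can be downloaded, and packs the tokenizer output into the (batch_size, 3, seq_len) layout that forward() expects.

    import torch

    encoder = BertTextEncoder(use_finetune=False, transformers='bert', pretrained='bert-base-uncased')
    tok = encoder.get_tokenizer()
    enc = tok(["an example sentence"], padding='max_length', max_length=50, return_tensors='pt')
    # Stack (input_ids, attention_mask, token_type_ids) along dim 1.
    text = torch.stack([enc['input_ids'], enc['attention_mask'], enc['token_type_ids']], dim=1).float()
    hidden = encoder(text)  # (batch_size, seq_len, 768) last hidden states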
trains/subNets/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .BertTextEncoder import BertTextEncoder
from .AlignNets import AlignSubNet
trains/subNets/transformers_encoder/multihead_attention.py
ADDED
@@ -0,0 +1,154 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter


class MultiheadAttention(nn.Module):
    """Multi-headed attention.
    See "Attention Is All You Need" for more details.
    """

    def __init__(self, embed_dim, num_heads, attn_dropout=0.,
                 bias=True, add_bias_kv=False, add_zero_attn=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.attn_dropout = attn_dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        self.register_parameter('in_proj_bias', None)
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def forward(self, query, key, value, attn_mask=None):
        """Input shape: Time x Batch x Channel
        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Timesteps can be masked by supplying a T x T mask in the
        `attn_mask` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        aved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)

            if key is None:
                assert value is None
                k = v = None
            else:
                k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q = q * self.scaling

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        src_len = k.size(1)

        if self.add_zero_attn:
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            try:
                attn_weights += attn_mask.unsqueeze(0)
            except:
                print(attn_weights.shape)
                print(attn_mask.unsqueeze(0).shape)
                assert False

        attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
        # attn_weights = F.relu(attn_weights)
        # attn_weights = attn_weights / torch.max(attn_weights)
        attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        # average attention weights over heads
        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights.sum(dim=1) / self.num_heads
        return attn, attn_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query, **kwargs):
        return self._in_proj(query, end=self.embed_dim, **kwargs)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None, **kwargs):
        weight = kwargs.get('weight', self.in_proj_weight)
        bias = kwargs.get('bias', self.in_proj_bias)
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)
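Not part of the committed file: a quick shape check with illustrative sizes. Inputs follow the (time, batch, channel) layout; passing the same tensor as key and value exercises the cross-modal path.

    import torch

    attn = MultiheadAttention(embed_dim=32, num_heads=4, attn_dropout=0.1)
    q = torch.randn(20, 2, 32)    # target sequence
    kv = torch.randn(30, 2, 32)   # source sequence
    out, weights = attn(q, kv, kv)
    print(out.shape, weights.shape)  # torch.Size([20, 2, 32]) torch.Size([2, 20, 30])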
trains/subNets/transformers_encoder/position_embedding.py
ADDED
@@ -0,0 +1,77 @@
import math
import torch
import torch.nn as nn


def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.
    Position numbers begin at padding_idx+1.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    device = tensor.get_device()
    buf_name = f'range_buf_{device}'
    if not hasattr(make_positions, buf_name):
        setattr(make_positions, buf_name, tensor.new())
    setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor))
    if getattr(make_positions, buf_name).numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name))
    mask = tensor.ne(padding_idx)
    positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    new_tensor = tensor.clone()
    return new_tensor.masked_scatter_(mask, positions[mask]).long()


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = dict()  # device --> actual weight; due to nn.DataParallel :-(
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.
        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        device = input.get_device()
        if device not in self.weights or max_pos > self.weights[device].size(0):
            # recompute/expand embeddings if needed
            self.weights[device] = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights[device] = self.weights[device].type_as(self._float_tensor).to(input.device)
        positions = make_positions(input, self.padding_idx, self.left_pad)
        return self.weights[device].index_select(0, positions.contiguous().view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
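Not part of the committed file: a tiny check of the sinusoidal table with illustrative sizes.

    emb = SinusoidalPositionalEmbedding.get_embedding(num_embeddings=10, embedding_dim=8, padding_idx=0)
    print(emb.shape)     # torch.Size([10, 8])
    print(emb[0].sum())  # tensor(0.): the padding row is zeroed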
trains/subNets/transformers_encoder/transformer.py
ADDED
@@ -0,0 +1,205 @@
import math
import torch
import torch.nn.functional as F
from torch import nn
from .multihead_attention import MultiheadAttention
from .position_embedding import SinusoidalPositionalEmbedding


class TransformerEncoder(nn.Module):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.
    Args:
        embed_tokens (torch.nn.Embedding): input embedding
        num_heads (int): number of heads
        layers (int): number of layers
        attn_dropout (float): dropout applied on the attention weights
        relu_dropout (float): dropout applied on the first layer of the residual block
        res_dropout (float): dropout applied on the residual block
        attn_mask (bool): whether to apply mask on the attention weights
    """

    def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0,
                 embed_dropout=0.0, attn_mask=False):
        super().__init__()
        self.dropout = embed_dropout      # Embedding dropout
        self.attn_dropout = attn_dropout
        self.embed_dim = embed_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(embed_dim)

        self.attn_mask = attn_mask

        self.layers = nn.ModuleList([])  # define multiple transformer layers
        for layer in range(layers):
            new_layer = TransformerEncoderLayer(embed_dim,
                                                num_heads=num_heads,
                                                attn_dropout=attn_dropout,
                                                relu_dropout=relu_dropout,
                                                res_dropout=res_dropout,
                                                attn_mask=attn_mask)
            self.layers.append(new_layer)

        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = True
        if self.normalize:
            self.layer_norm = LayerNorm(embed_dim)

    def forward(self, x_in, x_in_k=None, x_in_v=None):
        """
        Args:
            x_in (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
            x_in_k (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
            x_in_v (FloatTensor): embedded input of shape `(src_len, batch, embed_dim)`
        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
        """
        # embed tokens and positions
        x = self.embed_scale * x_in
        # breakpoint()
        if self.embed_positions is not None:
            x += self.embed_positions(x_in.transpose(0, 1)[:, :, 0]).transpose(0, 1)  # Add positional embedding
        x = F.dropout(x, p=self.dropout, training=self.training)

        if x_in_k is not None and x_in_v is not None:
            # embed tokens and positions
            x_k = self.embed_scale * x_in_k
            x_v = self.embed_scale * x_in_v
            if self.embed_positions is not None:
                x_k += self.embed_positions(x_in_k.transpose(0, 1)[:, :, 0]).transpose(0, 1)  # Add positional embedding
                x_v += self.embed_positions(x_in_v.transpose(0, 1)[:, :, 0]).transpose(0, 1)  # Add positional embedding
            x_k = F.dropout(x_k, p=self.dropout, training=self.training)
            x_v = F.dropout(x_v, p=self.dropout, training=self.training)

        # encoder layers
        intermediates = [x]
        for layer in self.layers:
            if x_in_k is not None and x_in_v is not None:
                x = layer(x, x_k, x_v)
            else:
                x = layer(x)
            intermediates.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        return x

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions())


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.
    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.
    Args:
        embed_dim: Embedding dimension
    """

    def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1,
                 attn_mask=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            attn_dropout=attn_dropout
        )
        self.attn_mask = attn_mask

        self.relu_dropout = relu_dropout
        self.res_dropout = res_dropout
        self.normalize_before = True  # True means using the tensor2tensor approach

        self.fc1 = Linear(self.embed_dim, 4 * self.embed_dim)  # The "Add & Norm" part in the paper
        self.fc2 = Linear(4 * self.embed_dim, self.embed_dim)
        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for _ in range(2)])  # Define two layer-norm layers

    def forward(self, x, x_k=None, x_v=None):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, src_len)` where padding elements are indicated by ``1``.
            x_k (Tensor): same as x
            x_v (Tensor): same as x
        Returns:
            encoded output of shape `(batch, src_len, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(0, x, before=True)
        mask = buffered_future_mask(x, x_k) if self.attn_mask else None
        if x_k is None and x_v is None:
            x, _ = self.self_attn(query=x, key=x, value=x, attn_mask=mask)
        else:
            x_k = self.maybe_layer_norm(0, x_k, before=True)
            x_v = self.maybe_layer_norm(0, x_v, before=True)
            x, _ = self.self_attn(query=x, key=x_k, value=x_v, attn_mask=mask)
        x = F.dropout(x, p=self.res_dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(0, x, after=True)  # First sub-layer (attention)

        residual = x
        x = self.maybe_layer_norm(1, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.res_dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(1, x, after=True)
        return x  # Second sub-layer (feed-forward)

    def maybe_layer_norm(self, i, x, before=False, after=False):
        assert before ^ after  # before XOR after: exactly one may be true
        if after ^ self.normalize_before:
            return self.layer_norms[i](x)
        else:
            return x


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def buffered_future_mask(tensor, tensor2=None):
    dim1 = dim2 = tensor.size(0)
    if tensor2 is not None:
        dim2 = tensor2.size(0)
    future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), 1 + abs(dim2 - dim1))
    if tensor.is_cuda:
        future_mask = future_mask.to(tensor.device)
    return future_mask[:dim1, :dim2]


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m


def LayerNorm(embedding_dim):
    m = nn.LayerNorm(embedding_dim)
    return m


if __name__ == '__main__':
    encoder = TransformerEncoder(300, 4, 2)  # embed_dim, num_heads, layers
    x = torch.tensor(torch.rand(20, 2, 300))
    print(encoder(x).shape)
trains/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .functions import dict_to_str, setup_seed, assign_gpu, count_parameters
from .metricsTop import MetricsTop
trains/utils/functions.py
ADDED
@@ -0,0 +1,51 @@
import torch
import numpy as np
import random
import pynvml
import logging


logger = logging.getLogger('MMSA')


def dict_to_str(src_dict):
    dst_str = ""
    for key in src_dict.keys():
        dst_str += " %s: %.4f " % (key, src_dict[key])
    return dst_str


def setup_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def assign_gpu(gpu_ids, memory_limit=1e16):
    if len(gpu_ids) == 0 and torch.cuda.is_available():
        # find the most free gpu
        pynvml.nvmlInit()
        n_gpus = pynvml.nvmlDeviceGetCount()
        dst_gpu_id, min_mem_used = 0, memory_limit
        for g_id in range(n_gpus):
            handle = pynvml.nvmlDeviceGetHandleByIndex(g_id)
            meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
            mem_used = meminfo.used
            if mem_used < min_mem_used:
                min_mem_used = mem_used
                dst_gpu_id = g_id
        logger.info(f'Found gpu {dst_gpu_id}, used memory {min_mem_used}.')
        gpu_ids.append(dst_gpu_id)
    # device
    using_cuda = len(gpu_ids) > 0 and torch.cuda.is_available()
    # logger.info("Let's use %d GPUs!" % len(gpu_ids))
    device = torch.device('cuda:%d' % int(gpu_ids[0]) if using_cuda else 'cpu')
    return device


def count_parameters(model):
    res = 0
    for p in model.parameters():
        if p.requires_grad:
            res += p.numel()
            # print(p)
    return res
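Not part of the committed file: the pure helpers can be exercised directly (assign_gpu needs pynvml and a visible GPU, so it is left out); the seed value below is arbitrary.

    setup_seed(1111)
    print(dict_to_str({'MAE': 0.7123, 'Corr': 0.7901}))  # " MAE: 0.7123  Corr: 0.7901 "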
trains/utils/metricsTop.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import mutual_info_score

device = "cuda:0" if torch.cuda.is_available() else "cpu"

__all__ = ['MetricsTop']

class MetricsTop():
    def __init__(self, train_mode):
        if train_mode == "regression":
            self.metrics_dict = {
                'MOSI': self.__eval_mosi_regression,
                'MOSEI': self.__eval_mosei_regression,
            }
        else:
            self.metrics_dict = {
                'MOSI': self.__eval_mosi_classification,
                'MOSEI': self.__eval_mosei_classification,
            }

    def __eval_mosi_classification(self, y_pred, y_true):
        """
        {
            "Negative": 0,
            "Neutral": 1,
            "Positive": 2
        }
        """
        y_pred = y_pred.cpu().detach().numpy()
        y_true = y_true.cpu().detach().numpy()
        # three classes
        y_pred_3 = np.argmax(y_pred, axis=1)
        Mult_acc_3 = accuracy_score(y_pred_3, y_true)
        F1_score_3 = f1_score(y_true, y_pred_3, average='weighted')
        # two classes
        y_pred = np.array([[v[0], v[2]] for v in y_pred])
        # with 0 (<= 0 or > 0)
        y_pred_2 = np.argmax(y_pred, axis=1)
        y_true_2 = []
        for v in y_true:
            y_true_2.append(0 if v <= 1 else 1)
        y_true_2 = np.array(y_true_2)
        Has0_acc_2 = accuracy_score(y_pred_2, y_true_2)
        Has0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')
        # without 0 (< 0 or > 0)
        non_zeros = np.array([i for i, e in enumerate(y_true) if e != 1])
        y_pred_2 = y_pred[non_zeros]
        y_pred_2 = np.argmax(y_pred_2, axis=1)
        y_true_2 = y_true[non_zeros]
        Non0_acc_2 = accuracy_score(y_pred_2, y_true_2)
        Non0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')

        eval_results = {
            "Has0_acc_2": round(Has0_acc_2, 4),
            "Has0_F1_score": round(Has0_F1_score, 4),
            "Non0_acc_2": round(Non0_acc_2, 4),
            "Non0_F1_score": round(Non0_F1_score, 4),
            "Acc_3": round(Mult_acc_3, 4),
            "F1_score_3": round(F1_score_3, 4)
        }
        return eval_results

    def __eval_mosei_classification(self, y_pred, y_true):
        return self.__eval_mosi_classification(y_pred, y_true)

    def __multiclass_acc(self, y_pred, y_true):
        """
        Compute the multiclass accuracy w.r.t. groundtruth

        :param preds: Float array representing the predictions, dimension (N,)
        :param truths: Float/int array representing the groundtruth classes, dimension (N,)
        :return: Classification accuracy
        """
        return np.sum(np.round(y_pred) == np.round(y_true)) / float(len(y_true))

    def __eval_mosei_regression(self, y_pred, y_true, exclude_zero=False):
        test_preds = y_pred.view(-1).cpu().detach().numpy()
        test_truth = y_true.view(-1).cpu().detach().numpy()

        test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
        test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
        test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
        test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)
        test_preds_a3 = np.clip(test_preds, a_min=-1., a_max=1.)
        test_truth_a3 = np.clip(test_truth, a_min=-1., a_max=1.)

        mae = np.mean(np.absolute(test_preds - test_truth)).astype(np.float64)
        corr = np.corrcoef(test_preds, test_truth)[0][1]
        mult_a7 = self.__multiclass_acc(test_preds_a7, test_truth_a7)
        mult_a5 = self.__multiclass_acc(test_preds_a5, test_truth_a5)
        mult_a3 = self.__multiclass_acc(test_preds_a3, test_truth_a3)

        non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0])
        non_zeros_binary_truth = (test_truth[non_zeros] > 0)
        non_zeros_binary_preds = (test_preds[non_zeros] > 0)

        non_zeros_acc2 = accuracy_score(non_zeros_binary_preds, non_zeros_binary_truth)
        non_zeros_f1_score = f1_score(non_zeros_binary_truth, non_zeros_binary_preds, average='weighted')

        binary_truth = (test_truth >= 0)
        binary_preds = (test_preds >= 0)
        acc2 = accuracy_score(binary_preds, binary_truth)
        f_score = f1_score(binary_truth, binary_preds, average='weighted')

        eval_results = {
            "acc_7": round(mult_a7, 4),
            "acc_5": round(mult_a5, 4),
            "acc_2": round(non_zeros_acc2, 4),
            "F1_score": round(non_zeros_f1_score, 4),
            "Corr": round(corr, 4),
            "MAE": round(mae, 4)
        }
        return eval_results

    def __eval_mosi_regression(self, y_pred, y_true):
        return self.__eval_mosei_regression(y_pred, y_true)

    def getMetics(self, datasetName):
        return self.metrics_dict[datasetName.upper()]
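A short usage sketch for MetricsTop (not part of the commit; the random tensors are stand-ins for model outputs and MOSI regression labels):

import torch
from trains.utils.metricsTop import MetricsTop

metrics = MetricsTop(train_mode="regression")
evaluate = metrics.getMetics("mosi")           # dataset name is upper-cased internally
y_pred = torch.randn(16, 1)                    # fake predictions in place of model outputs
y_true = torch.empty(16, 1).uniform_(-3, 3)    # fake labels in the MOSI range [-3, 3]
print(evaluate(y_pred, y_true))                # {'acc_7': ..., 'acc_5': ..., 'acc_2': ..., 'F1_score': ..., 'Corr': ..., 'MAE': ...}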
utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .functions import dict_to_str, setup_seed, assign_gpu, count_parameters
from .metricsTop import MetricsTop
utils/functions.py
ADDED
@@ -0,0 +1,51 @@
import torch
import numpy as np
import random
import pynvml
import logging


logger = logging.getLogger('MMSA')


def dict_to_str(src_dict):
    dst_str = ""
    for key in src_dict.keys():
        dst_str += " %s: %.4f " % (key, src_dict[key])
    return dst_str

def setup_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def assign_gpu(gpu_ids, memory_limit=1e16):
    if len(gpu_ids) == 0 and torch.cuda.is_available():
        # find most free gpu
        pynvml.nvmlInit()
        n_gpus = pynvml.nvmlDeviceGetCount()
        dst_gpu_id, min_mem_used = 0, memory_limit
        for g_id in range(n_gpus):
            handle = pynvml.nvmlDeviceGetHandleByIndex(g_id)
            meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
            mem_used = meminfo.used
            if mem_used < min_mem_used:
                min_mem_used = mem_used
                dst_gpu_id = g_id
        logger.info(f'Found gpu {dst_gpu_id}, used memory {min_mem_used}.')
        gpu_ids.append(dst_gpu_id)
    # device
    using_cuda = len(gpu_ids) > 0 and torch.cuda.is_available()
    # logger.info("Let's use %d GPUs!" % len(gpu_ids))
    device = torch.device('cuda:%d' % int(gpu_ids[0]) if using_cuda else 'cpu')
    return device

def count_parameters(model):
    res = 0
    for p in model.parameters():
        if p.requires_grad:
            res += p.numel()
            # print(p)
    return res
utils/metricsTop.py
ADDED
@@ -0,0 +1,111 @@
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

__all__ = ['MetricsTop']

class MetricsTop():
    def __init__(self, train_mode):
        if train_mode == "regression":
            self.metrics_dict = {
                'MOSI': self.__eval_mosi_regression,
                'MOSEI': self.__eval_mosei_regression,
            }
        else:
            self.metrics_dict = {
                'MOSI': self.__eval_mosi_classification,
                'MOSEI': self.__eval_mosei_classification,
            }

    def __eval_mosi_classification(self, y_pred, y_true):
        y_pred = y_pred.cpu().detach().numpy()
        y_true = y_true.cpu().detach().numpy()
        # three classes
        y_pred_3 = np.argmax(y_pred, axis=1)
        Mult_acc_3 = accuracy_score(y_pred_3, y_true)
        F1_score_3 = f1_score(y_true, y_pred_3, average='weighted')
        # two classes
        y_pred = np.array([[v[0], v[2]] for v in y_pred])
        # with 0 (<= 0 or > 0)
        y_pred_2 = np.argmax(y_pred, axis=1)
        y_true_2 = []
        for v in y_true:
            y_true_2.append(0 if v <= 1 else 1)
        y_true_2 = np.array(y_true_2)
        Has0_acc_2 = accuracy_score(y_pred_2, y_true_2)
        Has0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')
        # without 0 (< 0 or > 0)
        non_zeros = np.array([i for i, e in enumerate(y_true) if e != 1])
        y_pred_2 = y_pred[non_zeros]
        y_pred_2 = np.argmax(y_pred_2, axis=1)
        y_true_2 = y_true[non_zeros]
        Non0_acc_2 = accuracy_score(y_pred_2, y_true_2)
        Non0_F1_score = f1_score(y_true_2, y_pred_2, average='weighted')

        eval_results = {
            "Has0_acc_2": round(Has0_acc_2, 4),
            "Has0_F1_score": round(Has0_F1_score, 4),
            "Non0_acc_2": round(Non0_acc_2, 4),
            "Non0_F1_score": round(Non0_F1_score, 4),
            "Acc_3": round(Mult_acc_3, 4),
            "F1_score_3": round(F1_score_3, 4)
        }
        return eval_results

    def __eval_mosei_classification(self, y_pred, y_true):
        return self.__eval_mosi_classification(y_pred, y_true)

    def __multiclass_acc(self, y_pred, y_true):
        """
        Compute the multiclass accuracy w.r.t. groundtruth

        :param preds: Float array representing the predictions, dimension (N,)
        :param truths: Float/int array representing the groundtruth classes, dimension (N,)
        :return: Classification accuracy
        """
        return np.sum(np.round(y_pred) == np.round(y_true)) / float(len(y_true))

    def __eval_mosei_regression(self, y_pred, y_true, exclude_zero=False):
        test_preds = y_pred.view(-1).cpu().detach().numpy()
        test_truth = y_true.view(-1).cpu().detach().numpy()

        test_preds_a7 = np.clip(test_preds, a_min=-3., a_max=3.)
        test_truth_a7 = np.clip(test_truth, a_min=-3., a_max=3.)
        test_preds_a5 = np.clip(test_preds, a_min=-2., a_max=2.)
        test_truth_a5 = np.clip(test_truth, a_min=-2., a_max=2.)
        test_preds_a3 = np.clip(test_preds, a_min=-1., a_max=1.)
        test_truth_a3 = np.clip(test_truth, a_min=-1., a_max=1.)

        mae = np.mean(np.absolute(test_preds - test_truth)).astype(np.float64)  # Average L1 distance between preds and truths
        corr = np.corrcoef(test_preds, test_truth)[0][1]
        mult_a7 = self.__multiclass_acc(test_preds_a7, test_truth_a7)
        mult_a5 = self.__multiclass_acc(test_preds_a5, test_truth_a5)
        mult_a3 = self.__multiclass_acc(test_preds_a3, test_truth_a3)

        non_zeros = np.array([i for i, e in enumerate(test_truth) if e != 0])
        non_zeros_binary_truth = (test_truth[non_zeros] > 0)
        non_zeros_binary_preds = (test_preds[non_zeros] > 0)

        non_zeros_acc2 = accuracy_score(non_zeros_binary_preds, non_zeros_binary_truth)
        non_zeros_f1_score = f1_score(non_zeros_binary_truth, non_zeros_binary_preds, average='weighted')

        binary_truth = (test_truth >= 0)
        binary_preds = (test_preds >= 0)
        acc2 = accuracy_score(binary_preds, binary_truth)
        f_score = f1_score(binary_truth, binary_preds, average='weighted')

        eval_results = {
            "Acc_2": round(non_zeros_acc2, 4),
            "F1_score": round(non_zeros_f1_score, 4),
            "Acc_7": round(mult_a7, 4),
            "MAE": round(mae, 4),
        }
        return eval_results

    def __eval_mosi_regression(self, y_pred, y_true):
        return self.__eval_mosei_regression(y_pred, y_true)

    def getMetics(self, datasetName):
        return self.metrics_dict[datasetName.upper()]
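For reference, a small worked example of the seven-class accuracy used above: predictions and labels are clipped to [-3, 3] and rounded to the nearest integer bin before comparison (the numbers below are made up for illustration):

import numpy as np

preds  = np.array([2.4, -0.3, 1.7])
truths = np.array([3.0,  0.0, 0.8])
# after clip + round: preds -> [2, 0, 2], truths -> [3, 0, 1]; 1 of 3 positions match
acc_7 = np.sum(np.round(np.clip(preds, -3, 3)) == np.round(np.clip(truths, -3, 3))) / len(truths)
print(acc_7)  # 0.333...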