"""Graph distillation for homo GD"""
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from ..utils import distance_metric, min_cosine


class DistillationKernel(nn.Module):
  """Graph distillation kernel.

  Computes the edge weight e_{j->k} from each source modality j to the
  target modality k. The target modality k is specified by to_idx, and the
  source modalities j are given by from_idx.
  """

  def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
               gd_prior, gd_reg, w_losses, metric, alpha, hyp_params=None):
    super(DistillationKernel, self).__init__()
    # Projections of the logits and hidden representations into the
    # graph-distillation embedding space.
    self.W_logit = nn.Linear(n_classes, gd_size)
    self.W_repr = nn.Linear(hidden_size, gd_size)
    # Scores one concatenated (target, source) pair of embeddings.
    self.W_edge = nn.Linear(gd_size * 4, 1)
    self.gd_size = gd_size
    self.to_idx = to_idx
    self.from_idx = from_idx
    self.alpha = alpha
    # Registered as a buffer instead of the deprecated Variable(...).cuda(),
    # so it moves with the module across devices.
    self.register_buffer('gd_prior', torch.FloatTensor(gd_prior))
    self.gd_reg = gd_reg
    self.w_losses = w_losses
    self.metric = metric
    self.hyp_params = hyp_params

  def forward(self, logits, reprs):
    """
    Args:
      logits: (n_modalities, batch_size, n_classes)
      reprs: (n_modalities, batch_size, hidden_size)
    Returns:
      edges: edge weights e_{j->k}, (n_modalities_from, batch_size)
      edges_origin: raw edge scores summed over the batch, (n_edges, 1)
    """
    n_modalities, batch_size = logits.size()[:2]
    z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
    z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
    # Each modality's embedding concatenates its projected logits and its
    # projected hidden representation.
    z = torch.cat((z_logits, z_reprs), dim=1).view(n_modalities, batch_size,
                                                   self.gd_size * 2)
    edges = []
    for j in self.to_idx:
      for i in self.from_idx:
        if i == j:
          continue
        # To compute the edge into target j from source i, score the
        # concatenation of z^j and z^i.
        e = self.W_edge(torch.cat((z[j], z[i]), dim=1))
        edges.append(e)
    edges = torch.cat(edges, dim=1)
    # Raw edge scores summed over the batch: (n_edges, 1).
    edges_origin = edges.sum(0).unsqueeze(0).transpose(0, 1)
    # Edge weights normalized across edges with a temperature-scaled
    # softmax: (n_edges, batch_size).
    edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1)
    return edges, edges_origin

  def distillation_loss(self, logits, reprs, edges):
    """Calculates the graph distillation losses: the prior regularization
    loss, the loss on logits, and the loss on representations.
    """
    loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg
    loss_logit, loss_repr = 0, 0
    x = 0  # edge index, in the order the edges were built in forward()
    for j in self.to_idx:
      for idx in self.from_idx:
        if idx == j:  # compare modality indices, as in forward()
          continue
        w_distill = edges[x] + self.gd_prior[x]
        loss_logit += self.w_losses[0] * distance_metric(
            logits[j], logits[idx], self.metric, w_distill)
        loss_repr += self.w_losses[1] * distance_metric(
            reprs[j], reprs[idx], self.metric, w_distill)
        x += 1
    return loss_reg, loss_logit, loss_repr
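
# How the three returned loss terms are combined is left to the caller. A
# plausible training step (an assumption, not code from this repository)
# simply adds them to the task loss:
#
#   loss_reg, loss_logit, loss_repr = kernel.distillation_loss(
#       logits, reprs, edges)
#   loss = task_loss + loss_reg + loss_logit + loss_repr
#   loss.backward()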


def get_distillation_kernel(n_classes,
                            hidden_size,
                            gd_size,
                            to_idx,
                            from_idx,
                            gd_prior,
                            gd_reg,
                            w_losses,
                            metric,
                            alpha=1 / 8,
                            hyp_params=None):
  return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx,
                            gd_prior, gd_reg, w_losses, metric, alpha,
                            hyp_params)
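

# Minimal usage sketch (not part of the original file). It builds a kernel
# with three modalities, where modality 0 is the distillation target and
# modalities 1 and 2 are the sources, and runs one forward pass on random
# tensors. All sizes and hyperparameter values below are illustrative
# assumptions, not values prescribed by the repository. Because of the
# relative ..utils import at the top, run this as a module, e.g.
# `python -m <package>.<this_module>`, not as a standalone script.
if __name__ == '__main__':
  n_modalities, batch_size = 3, 4
  n_classes, hidden_size = 10, 32
  kernel = get_distillation_kernel(
      n_classes=n_classes,
      hidden_size=hidden_size,
      gd_size=16,
      to_idx=[0],           # target modality k
      from_idx=[1, 2],      # source modalities j
      gd_prior=[0.5, 0.5],  # one prior weight per source->target edge
      gd_reg=10.0,
      w_losses=[1.0, 1.0],
      metric='cosine')      # assumed name; must match utils.distance_metric
  logits = torch.randn(n_modalities, batch_size, n_classes)
  reprs = torch.randn(n_modalities, batch_size, hidden_size)
  edges, edges_origin = kernel(logits, reprs)
  print(edges.shape)         # (n_edges=2, batch_size=4), softmax-normalized
  print(edges_origin.shape)  # (n_edges=2, 1), raw scores summed over batch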