"""Graph distillation for homo GD""" |
|
|
|
import torch |
|
from torch.autograd import Variable |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from ..utils import distance_metric, min_cosine |
|
|
class DistillationKernel(nn.Module):
  """Graph Distillation kernel.

  Calculates the edge weights e_{j->k} for each source modality j. The target
  modality k is specified by to_idx, and the source modalities are specified
  by from_idx.
  """
|
  def __init__(self, n_classes, hidden_size, gd_size, to_idx, from_idx,
               gd_prior, gd_reg, w_losses, metric, alpha, hyp_params):
    super(DistillationKernel, self).__init__()
    # W_logit / W_repr embed per-modality logits and representations into the
    # graph distillation space; W_edge scores a concatenated pair of embeddings.
    self.W_logit = nn.Linear(n_classes, gd_size)
    self.W_repr = nn.Linear(hidden_size, gd_size)
    self.W_edge = nn.Linear(gd_size * 4, 1)

    self.gd_size = gd_size
    self.to_idx = to_idx
    self.from_idx = from_idx
    self.alpha = alpha
    self.gd_prior = Variable(torch.FloatTensor(gd_prior).cuda())
    self.gd_reg = gd_reg
    self.w_losses = w_losses
    self.metric = metric
    self.hyp_params = hyp_params
|
  def forward(self, logits, reprs):
    """
    Args:
      logits: (n_modalities, batch_size, n_classes)
      reprs: (n_modalities, batch_size, hidden_size)
    Returns:
      edges: normalized weights e_{j->k}, shape (n_modalities_from, batch_size)
      edges_origin: raw edge scores summed over the batch,
        shape (n_modalities_from, 1)
    """
|
    n_modalities, batch_size = logits.size()[:2]
    z_logits = self.W_logit(logits.view(n_modalities * batch_size, -1))
    z_reprs = self.W_repr(reprs.view(n_modalities * batch_size, -1))
    z = torch.cat((z_logits, z_reprs), dim=1).view(
        n_modalities, batch_size, self.gd_size * 2)

    # Score every (to, from) pair except self-edges. Each score has shape
    # (batch_size, 1), so the concatenation below is (batch_size, n_pairs).
    edges = []
    for j in self.to_idx:
      for i in self.from_idx:
        if i == j:
          continue
        e = self.W_edge(torch.cat((z[j], z[i]), dim=1))
        edges.append(e)
    edges = torch.cat(edges, dim=1)
    edges_origin = edges.sum(0).unsqueeze(0).transpose(0, 1)
    edges = F.softmax(edges * self.alpha, dim=1).transpose(0, 1)
    return edges, edges_origin
|
  def distillation_loss(self, logits, reprs, edges):
    """Calculates the graph distillation losses: the regularization loss, the
    loss on logits, and the loss on representations.
    """
|
    loss_reg = (edges.mean(1) - self.gd_prior).pow(2).sum() * self.gd_reg

    loss_logit, loss_repr = 0, 0
    x = 0  # running index into edges / gd_prior, one entry per (j, idx) pair
    for j in self.to_idx:
      for idx in self.from_idx:
        # Skip self-edges so that x follows the same pair order as forward().
        if idx == j:
          continue
        w_distill = edges[x] + self.gd_prior[x]
        loss_logit += self.w_losses[0] * distance_metric(
            logits[j], logits[idx], self.metric, w_distill)
        loss_repr += self.w_losses[1] * distance_metric(
            reprs[j], reprs[idx], self.metric, w_distill)
        x = x + 1
    return loss_reg, loss_logit, loss_repr
|
|
def get_distillation_kernel(n_classes,
                            hidden_size,
                            gd_size,
                            to_idx,
                            from_idx,
                            gd_prior,
                            gd_reg,
                            w_losses,
                            metric,
                            alpha=1 / 8,
                            hyp_params=None):
  # hyp_params is required by DistillationKernel.__init__ but unused in this
  # module, so it defaults to None here.
  return DistillationKernel(n_classes, hidden_size, gd_size, to_idx, from_idx,
                            gd_prior, gd_reg, w_losses, metric, alpha,
                            hyp_params)
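

# A minimal usage sketch of the forward pass (illustrative only: the shapes,
# index lists, prior, and the 'cosine' metric name below are assumptions, not
# values taken from the rest of the project). __init__ moves gd_prior onto the
# GPU, so a CUDA device is assumed; run via `python -m <package>.<module>` so
# the relative import above resolves.
if __name__ == '__main__':
  n_modalities, batch_size, n_classes, hidden_size = 3, 4, 10, 32
  kernel = get_distillation_kernel(
      n_classes=n_classes,
      hidden_size=hidden_size,
      gd_size=16,
      to_idx=[0],           # distill into modality 0
      from_idx=[0, 1, 2],   # candidate sources; the self-edge 0->0 is skipped
      gd_prior=[0.5, 0.5],  # one prior weight per retained edge
      gd_reg=1.0,
      w_losses=[1.0, 1.0],
      metric='cosine').cuda()
  logits = torch.randn(n_modalities, batch_size, n_classes).cuda()
  reprs = torch.randn(n_modalities, batch_size, hidden_size).cuda()
  edges, edges_origin = kernel(logits, reprs)
  print(edges.size(), edges_origin.size())  # (2, 4) and (2, 1) for this setup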
|
|