import torch import torch.nn as nn class HingeLoss(nn.Module): def __init__(self): super(HingeLoss, self).__init__() def compute_cosine(self, x, y): # x = self.compute_compact_s(x) # y = self.compute_compact_s(y) x_norm = torch.sqrt(torch.sum(torch.pow(x, 2), 1)+1e-8) x_norm = torch.max(x_norm, 1e-8*torch.ones_like(x_norm)) y_norm = torch.sqrt(torch.sum(torch.pow(y, 2), 1)+1e-8) y_norm = torch.max(y_norm, 1e-8*torch.ones_like(y_norm)) cosine = torch.sum(x * y, 1) / (x_norm * y_norm) return cosine def forward(self, ids, feats, margin=0.1): B, F = feats.shape s = feats.repeat(1, B).view(-1, F) # B**2 X F s_ids = ids.view(B, 1).repeat(1, B) # B X B t = feats.repeat(B, 1) # B**2 X F t_ids = ids.view(1, B).repeat(B, 1) # B X B cosine = self.compute_cosine(s, t) # B**2 equal_mask = torch.eye(B, dtype=torch.bool) # B X B s_ids = s_ids[~equal_mask].view(B, B-1) # B X (B-1) t_ids = t_ids[~equal_mask].view(B, B-1) # B X (B-1) cosine = cosine.view(B, B)[~equal_mask].view(B, B-1) # B X (B-1) sim_mask = (s_ids == t_ids) # B X (B-1) margin = 0.15 * abs(s_ids - t_ids)#[~sim_mask].view(B, B - 3) loss = 0 loss_num = 0 for i in range(B): sim_num = sum(sim_mask[i]) dif_num = B - 1 - sim_num if not sim_num or not dif_num: continue sim_cos = cosine[i, sim_mask[i]].reshape(-1, 1).repeat(1, dif_num) dif_cos = cosine[i, ~sim_mask[i]].reshape(-1, 1).repeat(1, sim_num).transpose(0, 1) t_margin = margin[i, ~sim_mask[i]].reshape(-1, 1).repeat(1, sim_num).transpose(0, 1) loss_i = torch.max(torch.zeros_like(sim_cos), t_margin - sim_cos + dif_cos).mean() loss += loss_i loss_num += 1 if loss_num == 0: loss_num = 1 loss = loss / loss_num return loss