# Spaces:
# svjack
# /
# Runtime error
#
# yuandong513
# feat: init
# 17cd746
# raw
# history blame
# 1.1 kB
import torch
import torch.nn as nn
class SmoothL1Loss(nn.Module):
    """Smooth-L1 (Huber-style) loss on the Euclidean norm over the last axis.

    With ``d`` the L2 norm of ``output - groundtruth`` along the last
    dimension and ``s = scale``::

        loss = 0.5 * d**2 / s    if d < s
        loss = d - 0.5 * s       otherwise

    The two pieces agree in value and slope at ``d == s``.
    """

    def __init__(self, scale=0.01):
        """Store ``scale``, the distance where the loss turns linear."""
        super().__init__()
        self.scale = scale
        # Not read by forward(); kept so existing attribute access still works.
        self.EPSILON = 1e-10

    def __repr__(self):
        return "SmoothL1Loss()"

    def forward(
        self,
        output: torch.Tensor,
        groundtruth: torch.Tensor,
        reduction: str = 'mean',
    ) -> torch.Tensor:
        """Compute the loss between ``output`` and ``groundtruth``.

        The last dimension is treated as the coordinate axis and is summed
        away when forming the squared distance. When ``output`` has four
        dimensions, ``groundtruth`` is reshaped to
        ``(shape[0], shape[1], 1, shape[3])`` so that it broadcasts across
        the third axis of ``output``.

        Args:
            output: Predicted values (the original note gives ``b x n x 2``).
            groundtruth: Target values, broadcastable against ``output``.
            reduction: ``'mean'`` averages all elements, ``'sum'`` adds them;
                any other value returns the unreduced per-element loss.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'``, otherwise a tensor with
            the shape of ``output`` minus its last dimension.
        """
        threshold = self.scale * self.scale

        if output.dim() == 4:
            shape = output.shape
            groundtruth = groundtruth.reshape(shape[0], shape[1], 1, shape[3])

        delta_2 = (output - groundtruth).pow(2).sum(dim=-1)

        # Clamp before sqrt so its gradient stays finite at zero distance.
        # The floor must not exceed scale**2: every value selected for the
        # linear branch has delta_2 >= scale**2, and a larger floor would
        # overwrite those distances when scale < 1e-3.
        sqrt_floor = min(1e-6, threshold)
        delta = delta_2.clamp(min=sqrt_floor).sqrt()

        loss = torch.where(
            delta_2 < threshold,
            0.5 / self.scale * delta_2,
            delta - 0.5 * self.scale,
        )

        if reduction == 'mean':
            loss = loss.mean()
        elif reduction == 'sum':
            loss = loss.sum()

        return loss