duynhm's picture
Initial commit
be2715b
raw
history blame
1.82 kB
import re
import torch.nn as nn
class BaseObject(nn.Module):
    """``nn.Module`` base that reports a human-readable ``__name__``.

    Args:
        name: explicit name to report. When ``None``, the name is derived
            from the concrete class name converted to snake_case
            (e.g. ``DiceLoss`` -> ``dice_loss``).
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name

    @property
    def __name__(self):
        """Return the explicit name if one was given, else the snake_case class name."""
        if self._name is not None:
            return self._name
        class_name = self.__class__.__name__
        # Pass 1: insert "_" before each Capital+lowercase run that follows
        # another character ("HTTPResponse" -> "HTTP_Response").
        split_words = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', class_name)
        # Pass 2: insert "_" between a lowercase letter/digit and a following capital.
        snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', split_words)
        return snake.lower()
class Metric(BaseObject):
    """Base type for metrics.

    Adds no behavior beyond ``BaseObject`` (constructor ``name`` argument and
    the derived ``__name__`` property); presumably used as a common parent for
    metric implementations.
    """
class Loss(BaseObject):
    """Base class for losses that compose with ``+`` and scale with ``*``.

    ``loss_a + loss_b`` builds a ``SumOfLosses``, and ``k * loss`` / ``loss * k``
    builds a ``MultipliedLoss``. ``sum([loss_a, loss_b, ...])`` also works,
    because ``0 + loss`` returns ``loss`` unchanged.
    """

    def __add__(self, other):
        """Return ``SumOfLosses(self, other)``.

        Raises:
            ValueError: if ``other`` is not a ``Loss``.
        """
        if isinstance(other, Loss):
            return SumOfLosses(self, other)
        raise ValueError('Loss should be inherited from `Loss` class')

    def __radd__(self, other):
        """Handle ``other + self`` where ``other`` is not a ``Loss``.

        A numeric zero is treated as the additive identity and returns
        ``self``. Anything else is forwarded to ``__add__`` (which raises
        ``ValueError`` for non-``Loss`` operands).
        """
        # The builtin sum() starts from the int 0; without this, summing a
        # list of losses would raise.
        if isinstance(other, (int, float)) and other == 0:
            return self
        return self.__add__(other)

    def __mul__(self, value):
        """Return ``MultipliedLoss(self, value)`` for a numeric ``value``.

        Raises:
            ValueError: if ``value`` is not an ``int`` or ``float``.
        """
        if isinstance(value, (int, float)):
            return MultipliedLoss(self, value)
        raise ValueError(
            'Loss can only be multiplied by an int or float, got {}'.format(
                type(value).__name__
            )
        )

    def __rmul__(self, other):
        """Support ``k * loss`` by delegating to ``__mul__``."""
        return self.__mul__(other)
class SumOfLosses(Loss):
    """Loss whose value is the sum of two losses evaluated on the same inputs.

    Args:
        l1: first loss module.
        l2: second loss module.
    """

    def __init__(self, l1, l2):
        name = '{} + {}'.format(l1.__name__, l2.__name__)
        super().__init__(name=name)
        # Assigned after super().__init__() so nn.Module can register them
        # as submodules.
        self.l1 = l1
        self.l2 = l2

    def forward(self, *inputs):
        """Return ``l1(*inputs) + l2(*inputs)``."""
        # Defined as ``forward`` (not ``__call__``) so nn.Module.__call__
        # dispatches here. This object can then be nested inside another
        # SumOfLosses/MultipliedLoss, e.g. ``(a + b) + c`` or ``2 * (a + b)``.
        # Operands are invoked as modules, which routes through their own
        # ``forward`` (including combined losses) and honors module hooks.
        return self.l1(*inputs) + self.l2(*inputs)
class MultipliedLoss(Loss):
    """Loss whose value is another loss scaled by a constant multiplier.

    Args:
        loss: the loss module to scale.
        multiplier: numeric factor applied to the loss value.
    """

    def __init__(self, loss, multiplier):
        # Parenthesize compound names so "2 * (a + b)" is not displayed
        # as the misleading "2 * a + b".
        if '+' in loss.__name__:
            name = '{} * ({})'.format(multiplier, loss.__name__)
        else:
            name = '{} * {}'.format(multiplier, loss.__name__)
        super().__init__(name=name)
        self.loss = loss
        self.multiplier = multiplier

    def forward(self, *inputs):
        """Return ``multiplier * loss(*inputs)``."""
        # Defined as ``forward`` (not ``__call__``) so this object can be
        # nested inside other combined losses. The wrapped loss is invoked
        # as a module, so a SumOfLosses operand, as in ``2 * (a + b)``, works.
        return self.multiplier * self.loss(*inputs)