from typing import List, Optional

import torch

from ._functional import soft_tversky_score
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
from .dice import DiceLoss

__all__ = ["TverskyLoss"]


class TverskyLoss(DiceLoss):
"""Implementation of Tversky loss for image segmentation task.
Where TP and FP is weighted by alpha and beta params.
With alpha == beta == 0.5, this loss becomes equal DiceLoss.
It supports binary, multiclass and multilabel cases
Args:
mode: Metric mode {'binary', 'multiclass', 'multilabel'}
classes: Optional list of classes that contribute in loss computation;
By default, all channels are included.
log_loss: If True, loss computed as ``-log(tversky)`` otherwise ``1 - tversky``
from_logits: If True assumes input is raw logits
smooth:
ignore_index: Label that indicates ignored pixels (does not contribute to loss)
eps: Small epsilon for numerical stability
alpha: Weight constant that penalize model for FPs (False Positives)
beta: Weight constant that penalize model for FNs (False Positives)
gamma: Constant that squares the error function. Defaults to ``1.0``
Return:
loss: torch.Tensor
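
    Example:
        A minimal usage sketch (illustrative only). Shapes follow the multiclass
        convention inherited from ``DiceLoss``: logits of shape ``(N, C, H, W)``
        and integer class targets of shape ``(N, H, W)``::

            loss_fn = TverskyLoss(mode="multiclass", alpha=0.3, beta=0.7)
            logits = torch.randn(2, 3, 64, 64)          # raw model output
            target = torch.randint(0, 3, (2, 64, 64))   # per-pixel class ids
            loss = loss_fn(logits, target)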
"""

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__(mode, classes, log_loss, from_logits, smooth, ignore_index, eps)
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
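
    # ``gamma`` is applied as a power to the batch-mean loss below; gamma == 1.0
    # reduces to a plain mean, while larger values act as a focal-style exponent
    # on the aggregated Tversky loss.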
    def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
        return loss.mean() ** self.gamma
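
    # ``soft_tversky_score`` is expected to implement the soft Tversky index,
    #     TI = TP / (TP + alpha * FP + beta * FN)
    # reduced over ``dims``, with ``smooth`` and ``eps`` used for numerical
    # stability (exact handling lives in ._functional).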
    def compute_score(self, output, target, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor:
        return soft_tversky_score(output, target, self.alpha, self.beta, smooth, eps, dims)