from typing import List, Union

import torch
from torch.nn.functional import cross_entropy

from .constants import IGNORE_INDEX

__all__ = ["soft_cross_entropy"]


def soft_cross_entropy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    soft_tokens: Union[torch.Tensor, List[int]],
    std: float = 1.0,
    ignore_index: int = IGNORE_INDEX,
) -> torch.Tensor:
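    """Cross entropy with Gaussian label smoothing over a set of soft tokens.

    Positions whose target is not in ``soft_tokens`` contribute a standard
    cross-entropy term. Positions whose target is in ``soft_tokens`` are
    scored against a Gaussian distribution over the soft-token ids with
    standard deviation ``std``, so near misses among neighboring soft tokens
    are penalized less than distant ones. The summed losses are averaged over
    all non-ignored target positions.
    """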
    # Shift so that logits at position i are scored against the token at
    # position i + 1 (the standard causal language-modeling alignment).
    outputs = outputs[..., :-1, :].contiguous()
    targets = targets[..., 1:].contiguous()

    # Flatten batch and sequence dimensions: (N, vocab) logits, (N,) labels.
    targets = targets.view(-1)
    outputs = outputs.view(targets.size(0), -1)

    # Drop positions masked out with ignore_index.
    indices = targets != ignore_index
    outputs = outputs[indices]
    targets = targets[indices]

    # Move the soft-token ids to the same device and dtype as the targets.
    if isinstance(soft_tokens, list):
        soft_tokens = torch.tensor(soft_tokens).to(targets)

    # Hard targets: ordinary cross entropy, accumulated as a sum.
    indices = torch.isin(targets, soft_tokens, invert=True)
    loss = cross_entropy(outputs[indices], targets[indices], reduction="sum")

    # Soft targets: replace each one-hot label with a Gaussian over the
    # soft-token ids, centered on the ground-truth id. Distance is measured
    # in token-id space, which assumes the soft tokens occupy consecutive ids
    # ordered by the values they represent.
    indices = torch.isin(targets, soft_tokens)
    soft_targets = torch.zeros_like(outputs[indices])
    for k, target in enumerate(targets[indices]):
        dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
        soft_targets[k][soft_tokens] = dist / dist.sum()
    loss += cross_entropy(outputs[indices], soft_targets, reduction="sum")

    # Mean over all supervised (non-ignored) positions.
    return loss / targets.size(0)
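
# Minimal smoke test; a sketch only. The vocabulary size and soft-token ids
# below are invented for illustration, and the block assumes the module is
# executed as part of its package (e.g. ``python -m <package>.<module>``)
# so the relative import of IGNORE_INDEX above resolves.
if __name__ == "__main__":
    torch.manual_seed(0)
    vocab_size = 32
    soft_ids = [10, 11, 12, 13]  # pretend these consecutive ids encode an ordered quantity
    logits = torch.randn(2, 8, vocab_size, requires_grad=True)
    labels = torch.randint(0, vocab_size, (2, 8))
    loss = soft_cross_entropy(logits, labels, soft_tokens=soft_ids, std=1.0)
    loss.backward()  # gradients flow through both the hard and soft terms
    print(f"soft_cross_entropy: {loss.item():.4f}")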