from typing import List, Union

import torch
from torch.nn.functional import cross_entropy

from .constants import IGNORE_INDEX

__all__ = ["soft_cross_entropy"]


def soft_cross_entropy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    soft_tokens: Union[torch.Tensor, List[int]],
    std: float = 1.0,
    ignore_index: int = IGNORE_INDEX,
) -> torch.Tensor:
    """Next-token cross-entropy with Gaussian label smoothing over ``soft_tokens``.

    Positions whose target is an ordinary token use standard cross-entropy.
    Positions whose target is one of ``soft_tokens`` are instead trained against
    a Gaussian distribution over the soft-token IDs, centered on the target ID
    with width ``std``, so that near misses among neighboring soft tokens are
    penalized less than distant ones. This assumes the soft-token IDs are
    ordered such that numerically close IDs encode close values.

    Args:
        outputs: Logits of shape ``(..., seq_len, vocab_size)``.
        targets: Token IDs of shape ``(..., seq_len)``.
        soft_tokens: Vocabulary IDs to apply Gaussian smoothing over.
        std: Standard deviation of the Gaussian, in token-ID units.
        ignore_index: Target value excluded from the loss.

    Returns:
        Scalar loss averaged over all non-ignored target tokens.
    """
    # Shift for next-token prediction: drop the last output and the first target
    outputs = outputs[..., :-1, :].contiguous()
    targets = targets[..., 1:].contiguous()

    # Flatten to (num_tokens, vocab_size) and (num_tokens,)
    targets = targets.view(-1)
    outputs = outputs.view(targets.size(0), -1)

    # Drop positions marked with ignore_index
    valid = targets != ignore_index
    outputs = outputs[valid]
    targets = targets[valid]

    # Convert soft-token IDs to a tensor on the same device/dtype as targets
    if isinstance(soft_tokens, list):
        soft_tokens = torch.tensor(soft_tokens).to(targets)

    # Hard cross-entropy for targets outside the soft-token set
    hard = torch.isin(targets, soft_tokens, invert=True)
    loss = cross_entropy(outputs[hard], targets[hard], reduction="sum")

    # Soft cross-entropy: build a normalized Gaussian label distribution
    # over the soft-token IDs for each soft-token target
    soft = torch.isin(targets, soft_tokens)
    soft_targets = torch.zeros_like(outputs[soft])
    for k, target in enumerate(targets[soft]):
        dist = torch.exp(-((target - soft_tokens) ** 2) / (2 * std**2))
        soft_targets[k][soft_tokens] = dist / dist.sum()
    loss += cross_entropy(outputs[soft], soft_targets, reduction="sum")

    # Average over all non-ignored tokens
    return loss / targets.size(0)
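
# Usage sketch, for illustration only: the shapes, vocabulary size, and the
# numeric-token ID range below are hypothetical assumptions, not values from
# this codebase. Run as a module (``python -m <package>.<this_module>``) so
# the relative import of IGNORE_INDEX resolves.
if __name__ == "__main__":
    batch_size, seq_len, vocab_size = 2, 8, 128
    logits = torch.randn(batch_size, seq_len, vocab_size)  # stand-in model outputs
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))
    labels[0, :3] = IGNORE_INDEX  # e.g. mask prompt positions from the loss
    # Hypothetical contiguous IDs for digit tokens "0".."9"; the Gaussian
    # smoothing is only meaningful when neighboring IDs encode neighboring values.
    digit_ids = list(range(10, 20))
    loss = soft_cross_entropy(logits, labels, soft_tokens=digit_ids, std=1.0)
    print(f"soft cross-entropy: {loss.item():.4f}")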