Ligeng-Zhu's picture
Upload files with `vila-upload`.
c857c8b verified
raw
history blame
1.57 kB
from typing import List, Union
import torch
from torch.nn.functional import cross_entropy
from .constants import IGNORE_INDEX
# Public API of this module: the only name exported by a star-import.
__all__ = ["soft_cross_entropy"]
def soft_cross_entropy(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    soft_tokens: Union[torch.Tensor, List[int]],
    std: float = 1,
    ignore_index: int = IGNORE_INDEX,
) -> torch.Tensor:
    """Compute a shifted cross-entropy loss with Gaussian-smoothed soft tokens.

    The logits are shifted against the labels by one position (the last
    position of ``outputs`` and the first position of ``targets`` are
    dropped), positions labelled ``ignore_index`` are discarded, and the
    remaining positions are split into two groups:

    * targets that are NOT in ``soft_tokens`` contribute an ordinary
      class-index cross-entropy term;
    * targets that ARE in ``soft_tokens`` are scored against a probability
      distribution that is non-zero only on the ``soft_tokens`` classes, with
      weights ``exp(-(target - token) ** 2 / (2 * std ** 2))`` normalised to
      sum to one.

    Args:
        outputs: Logits whose last dimension is the class dimension and whose
            second-to-last dimension is the sequence dimension.
        targets: Integer class labels whose last dimension is the sequence
            dimension; after shifting it must flatten to the same number of
            positions as ``outputs``.
        soft_tokens: Class IDs that receive the smoothed target distribution,
            given as a tensor or a list of ints. A tensor is moved to the
            device and dtype of ``targets`` and flattened.
        std: Standard deviation of the Gaussian, measured in token-ID units.
            Must be non-zero.
        ignore_index: Label value marking positions excluded from the loss.

    Returns:
        A scalar tensor: the summed loss of both groups divided by the number
        of non-ignored positions. If every position is ignored the result is
        zero (rather than NaN from a 0 / 0 division).
    """
    # Shift so that the logits at position t are scored against label t + 1.
    outputs = outputs[..., :-1, :].contiguous()
    targets = targets[..., 1:].contiguous()

    # Flatten to one row of logits per position.
    targets = targets.view(-1)
    outputs = outputs.view(targets.size(0), -1)

    # Drop positions labelled with ignore_index.
    keep = targets != ignore_index
    outputs = outputs[keep]
    targets = targets[keep]

    # Normalise soft token IDs to a flat tensor living next to `targets`, so
    # that list and tensor inputs (on any device) behave the same way.
    soft_tokens = torch.as_tensor(soft_tokens).to(targets).reshape(-1)

    # One membership test serves both groups.
    is_soft = torch.isin(targets, soft_tokens)
    is_hard = ~is_soft

    # Ordinary cross-entropy for targets outside the soft token set.
    loss = cross_entropy(outputs[is_hard], targets[is_hard], reduction="sum")

    # Gaussian-smoothed targets for the soft tokens, built for all rows at
    # once: entry (i, j) weighs soft token j for the i-th soft position.
    soft_outputs = outputs[is_soft]
    soft_targets = targets[is_soft]
    offsets = soft_targets.unsqueeze(-1) - soft_tokens
    weights = torch.exp(-(offsets**2) / (2 * std**2))
    # Each row contains exp(0) = 1 for its own target, so the sum is >= 1.
    weights = weights / weights.sum(dim=-1, keepdim=True)

    target_probs = torch.zeros_like(soft_outputs)
    # Cast explicitly: the weights are computed in the default float dtype,
    # which may differ from the dtype of the logits.
    target_probs[:, soft_tokens] = weights.to(target_probs.dtype)
    loss = loss + cross_entropy(soft_outputs, target_probs, reduction="sum")

    # Average over the non-ignored positions; guard against an empty batch.
    return loss / max(targets.size(0), 1)