import torch
from torch import nn
from transformers import Trainer
class CustomTrainer(Trainer):
    """Trainer that replaces the default loss with a class-weighted cross-entropy.

    Override ``class_weights`` (on the class or on an instance) to change the
    weighting. It holds one weight per label, in label-id order, and its length
    must equal the number of classes in the model's logits.
    """

    # Per-class weights passed to nn.CrossEntropyLoss, indexed by label id.
    # The default is the original example: 3 labels weighted 1, 2 and 3.
    class_weights = (1.0, 2.0, 3.0)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute a class-weighted cross-entropy loss for one batch.

        Args:
            model: The model to run forward. It may be a wrapper such as
                ``nn.DataParallel``, so only ``model(**inputs)`` is relied on.
            inputs: Mapping of model inputs. It must contain a ``"labels"`` entry.
                It is not modified.
            return_outputs: If True, return ``(loss, outputs)`` instead of the loss.
            num_items_in_batch: Passed by newer ``Trainer`` versions. It is
                accepted for signature compatibility and is not used here.
                NOTE(review): with gradient accumulation, some ``Trainer``
                versions expect the loss to be normalized by this value rather
                than averaged per batch. Confirm against the installed version.

        Returns:
            The scalar loss tensor, or ``(loss, outputs)`` when ``return_outputs``
            is True.
        """
        # Separate the labels without mutating the caller's mapping.
        labels = inputs["labels"]
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}

        # Forward pass without labels, so the model does not compute its own loss.
        outputs = model(**model_inputs)
        logits = outputs["logits"]
        num_labels = logits.size(-1)

        # Create the weights on the logits' device and dtype. A wrapped model has
        # no ``.device`` attribute, and the loss needs weight and input to share
        # a dtype (for example when logits are fp16).
        weight = torch.tensor(self.class_weights, device=logits.device, dtype=logits.dtype)
        loss_fct = nn.CrossEntropyLoss(weight=weight)
        loss = loss_fct(logits.reshape(-1, num_labels), labels.reshape(-1))
        return (loss, outputs) if return_outputs else loss
## Callbacks

Another way to customize the [`Trainer`] is with callbacks.