File size: 2,095 Bytes
29c9ba5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
from torch import nn
class Loss(object):
    """Base interface for the loss wrappers in this module.

    Subclasses are used as callables and must override ``__call__``.
    """

    def __call__(self, *args, **kwargs):
        # Calling the base class directly is an error: there is no default loss.
        raise NotImplementedError()
# Dummy Loss for testing.
class DummyLoss(Loss):
    """Plain cross-entropy between ``logits`` and ``targets``, intended for tests."""

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        criterion = self.loss
        return criterion(logits, targets)
class DummyK400Loss(Loss):
    """dummy k400 loss for MViT.

    Ignores the provided ``targets`` and scores ``logits`` against labels drawn
    uniformly at random from ``[0, 400)``, one per row along dim 0.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        batch_size = logits.size(0)
        # Random labels are created on the same device as the logits.
        fake_targets = torch.randint(
            0, 400, (batch_size,), device=logits.device
        )
        criterion = self.loss
        return criterion(logits, fake_targets)
class CrossEntropy(Loss):
    """Cross-entropy that flattens all leading dimensions before scoring.

    The last dimension of ``logits`` is treated as the class dimension. Every
    other dimension is merged into one, and ``targets`` is flattened to match.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        num_classes = logits.size(-1)
        flat_logits = logits.reshape(-1, num_classes)
        flat_targets = targets.flatten()
        return self.loss(flat_logits, flat_targets)
class ArgmaxCrossEntropy(Loss):
    """Cross-entropy against the argmax of ``targets`` along dim 1.

    NOTE(review): presumably ``targets`` holds one-hot or soft label scores per
    class in dim 1; confirm against callers.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        class_indices = targets.argmax(dim=1)
        criterion = self.loss
        return criterion(logits, class_indices)
class BCE(Loss):
    """Binary cross-entropy on raw logits (``nn.BCEWithLogitsLoss``).

    ``targets`` may carry one extra leading singleton dimension, e.g. ``(1, B, C)``
    for logits of shape ``(B, C)``. That dimension is removed before scoring.
    """

    def __init__(self):
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, logits, targets, **kwargs):
        # Only strip the leading dim when targets actually have an extra one.
        # An unconditional squeeze(0) would also collapse a genuine batch of
        # size 1, e.g. targets (1, C) -> (C,). BCEWithLogitsLoss would then
        # reject the mismatch with logits of shape (1, C).
        if targets.dim() > logits.dim():
            targets = targets.squeeze(0)
        return self.loss(logits, targets)
class NLGLoss(Loss):
    """Cross-entropy over the entries of ``text_label`` that are not -100.

    Positions labelled -100 are removed, and the remaining labels form a 1-D
    target tensor.
    NOTE(review): assumes ``logits`` already contains only the kept positions,
    in the same order; confirm against the model producing them.
    """

    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, text_label, **kwargs):
        keep_mask = text_label != -100
        targets = text_label[keep_mask]
        criterion = self.loss
        return criterion(logits, targets)
class MSE(Loss):
    """Mean squared error between ``logits`` and ``targets`` (``nn.MSELoss``)."""

    def __init__(self):
        self.loss = nn.MSELoss()

    def __call__(self, logits, targets, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        criterion = self.loss
        return criterion(logits, targets)
class L1(Loss):
    """Mean absolute error between ``logits`` and ``targets`` (``nn.L1Loss``)."""

    def __init__(self):
        self.loss = nn.L1Loss()

    def __call__(self, logits, targets, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        criterion = self.loss
        return criterion(logits, targets)
class SmoothL1(Loss):
    """Smooth L1 loss between ``logits`` and ``targets`` (``nn.SmoothL1Loss``)."""

    def __init__(self):
        self.loss = nn.SmoothL1Loss()

    def __call__(self, logits, targets, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        criterion = self.loss
        return criterion(logits, targets)
|