# Copyright (c) Facebook, Inc. All Rights Reserved

import torch

from torch import nn


class Loss(object):
    """Base class for loss wrappers; subclasses implement __call__."""
    def __call__(self, *args, **kwargs):
        raise NotImplementedError


# Dummy Loss for testing.
class DummyLoss(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


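class DummyK400Loss(Loss):
    """Dummy K400 loss for MViT; ignores `targets` and uses random labels."""
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        # Draw one random class index in [0, 400) per batch element.
        random_targets = torch.randint(
            0, 400, (logits.size(0),), device=logits.device)
        return self.loss(logits, random_targets)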


class CrossEntropy(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        # Flatten all leading dimensions so logits are (N, C) and targets (N,).
        return self.loss(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


class ArgmaxCrossEntropy(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, targets, **kwargs):
        # Convert per-class target scores along dim 1 into class indices.
        return self.loss(logits, targets.argmax(dim=1))


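class BCE(Loss):
    def __init__(self):
        self.loss = nn.BCEWithLogitsLoss()

    def __call__(self, logits, targets, **kwargs):
        # Drop dim 0 of targets if it has size 1 (no-op otherwise).
        # BCEWithLogitsLoss expects float targets shaped like logits.
        targets = targets.squeeze(0)
        return self.loss(logits, targets)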


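class NLGLoss(Loss):
    def __init__(self):
        self.loss = nn.CrossEntropyLoss()

    def __call__(self, logits, text_label, **kwargs):
        # Keep only labels that are not -100 (the default ignore_index of
        # nn.CrossEntropyLoss). Assumption: logits were already restricted
        # to the same positions, so their first dimension matches.
        targets = text_label[text_label != -100]
        return self.loss(logits, targets)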


class MSE(Loss):
    def __init__(self):
        self.loss = nn.MSELoss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class L1(Loss):
    def __init__(self):
        self.loss = nn.L1Loss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)


class SmoothL1(Loss):
    def __init__(self):
        self.loss = nn.SmoothL1Loss()

    def __call__(self, logits, targets, **kwargs):
        return self.loss(logits, targets)