File size: 3,064 Bytes
6789f6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import sys
import torch
from typing import Optional
from dataclasses import dataclass, field
from omegaconf import MISSING
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from fairseq.logging import metrics
try:
from ..data import MaeFinetuningImageDataset
except:
sys.path.append("..")
from data import MaeFinetuningImageDataset
logger = logging.getLogger(__name__)
@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
data: str = field(default=MISSING, metadata={"help": "path to data directory"})
input_size: int = 224
local_cache_path: Optional[str] = None
rebuild_batches: bool = True
@register_task("mae_image_classification", dataclass=MaeImageClassificationConfig)
class MaeImageClassificationTask(FairseqTask):
""" """
cfg: MaeImageClassificationConfig
@classmethod
def setup_task(cls, cfg: MaeImageClassificationConfig, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
cfg (AudioPretrainingConfig): configuration of this task
"""
return cls(cfg)
def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
data_path = self.cfg.data
cfg = task_cfg or self.cfg
self.datasets[split] = MaeFinetuningImageDataset(
root=data_path,
split=split,
is_train=split == "train",
input_size=cfg.input_size,
local_cache_path=cfg.local_cache_path,
shuffle=split == "train",
)
def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
model = super().build_model(model_cfg, from_checkpoint)
actualized_cfg = getattr(model, "cfg", None)
if actualized_cfg is not None:
if hasattr(actualized_cfg, "pretrained_model_args"):
model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args
return model
def reduce_metrics(self, logging_outputs, criterion):
super().reduce_metrics(logging_outputs, criterion)
if "correct" in logging_outputs[0]:
zero = torch.scalar_tensor(0.0)
correct = sum(log.get("correct", zero) for log in logging_outputs)
metrics.log_scalar_sum("_correct", correct)
metrics.log_derived(
"accuracy",
lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
)
@property
def source_dictionary(self):
return None
@property
def target_dictionary(self):
return None
def max_positions(self):
"""Maximum input length supported by the encoder."""
return sys.maxsize, sys.maxsize
|