|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import sys |
|
import torch |
|
|
|
from typing import Optional |
|
from dataclasses import dataclass, field |
|
from omegaconf import MISSING |
|
|
|
from fairseq.dataclass import FairseqDataclass |
|
from fairseq.tasks import FairseqTask, register_task |
|
from fairseq.logging import metrics |
|
|
|
# Prefer the package-relative import; fall back to a plain import when this
# module is not loaded as part of a package.
try:
    from ..data import MaeFinetuningImageDataset
except (ImportError, ValueError):
    # ImportError: no known parent package / module not found.
    # ValueError: raised by Python < 3.9 when a relative import goes beyond
    # the top-level package.
    # Anything else (KeyboardInterrupt, SystemExit, genuine bugs raised while
    # importing ``data``) is deliberately left to propagate.
    sys.path.append("..")
    from data import MaeFinetuningImageDataset

logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
    """Configuration for the ``mae_image_classification`` task."""

    # Required: the default is MISSING, so a value has to be supplied.
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    # Passed to the dataset as ``input_size``.
    input_size: int = field(default=224)
    # Passed to the dataset as ``local_cache_path``; unset by default.
    local_cache_path: Optional[str] = field(default=None)

    # NOTE(review): not read anywhere in this file - presumably consumed by
    # code outside this module; confirm before removing.
    rebuild_batches: bool = field(default=True)
|
|
|
|
|
@register_task("mae_image_classification", dataclass=MaeImageClassificationConfig)
class MaeImageClassificationTask(FairseqTask):
    """Image classification task backed by ``MaeFinetuningImageDataset``.

    Builds one dataset per split and derives an ``accuracy`` metric from the
    ``correct`` counts found in the logging outputs.
    """

    cfg: MaeImageClassificationConfig

    @classmethod
    def setup_task(cls, cfg: MaeImageClassificationConfig, **kwargs):
        """Set up the task.

        Args:
            cfg (MaeImageClassificationConfig): configuration of this task
            **kwargs: accepted for interface compatibility and ignored

        Returns:
            a new task instance constructed from ``cfg``
        """
        return cls(cfg)

    def load_dataset(
        self, split: str, task_cfg: Optional[FairseqDataclass] = None, **kwargs
    ):
        """Build the dataset for ``split`` and store it in ``self.datasets``.

        Args:
            split (str): name of the split; the ``"train"`` split is created
                with ``is_train`` and ``shuffle`` enabled, any other split
                with both disabled
            task_cfg (FairseqDataclass, optional): when given (and truthy),
                supplies ``input_size`` and ``local_cache_path`` instead of
                ``self.cfg``
            **kwargs: accepted for interface compatibility and ignored
        """
        cfg = task_cfg or self.cfg
        is_train = split == "train"

        # The data root always comes from the task's own config, even when an
        # overriding ``task_cfg`` is supplied.
        self.datasets[split] = MaeFinetuningImageDataset(
            root=self.cfg.data,
            split=split,
            is_train=is_train,
            input_size=cfg.input_size,
            local_cache_path=cfg.local_cache_path,
            shuffle=is_train,
        )

    def build_model(self, model_cfg: FairseqDataclass, from_checkpoint=False):
        """Build the model and copy back its ``pretrained_model_args``.

        Args:
            model_cfg (FairseqDataclass): model configuration; mutated in
                place when the built model exposes ``cfg.pretrained_model_args``
            from_checkpoint (bool): forwarded unchanged to the parent class

        Returns:
            the model returned by the parent ``build_model``
        """
        model = super().build_model(model_cfg, from_checkpoint)

        # NOTE(review): presumably done so the resolved pretrained args are
        # carried along with ``model_cfg`` afterwards - confirm with callers.
        actualized_cfg = getattr(model, "cfg", None)
        if actualized_cfg is not None and hasattr(
            actualized_cfg, "pretrained_model_args"
        ):
            model_cfg.pretrained_model_args = actualized_cfg.pretrained_model_args

        return model

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and register the ``accuracy`` metric.

        Args:
            logging_outputs: sequence of dict-like logging outputs; may be empty
            criterion: forwarded unchanged to the parent class
        """
        super().reduce_metrics(logging_outputs, criterion)

        # An empty list has no first element to inspect, so there is nothing
        # task-specific to aggregate.
        if not logging_outputs or "correct" not in logging_outputs[0]:
            return

        # Entries without a "correct" count contribute zero.
        zero = torch.scalar_tensor(0.0)
        correct = sum(log.get("correct", zero) for log in logging_outputs)
        metrics.log_scalar_sum("_correct", correct)

        # Accuracy in percent. The "sample_size" meter is not logged here;
        # presumably it is registered by the criterion - TODO confirm.
        metrics.log_derived(
            "accuracy",
            lambda meters: 100 * meters["_correct"].sum / meters["sample_size"].sum,
        )

    @property
    def source_dictionary(self):
        """Always ``None``: this task defines no source dictionary."""
        return None

    @property
    def target_dictionary(self):
        """Always ``None``: this task defines no target dictionary."""
        return None

    def max_positions(self):
        """Maximum input length supported by the encoder.

        Returns:
            a ``(sys.maxsize, sys.maxsize)`` tuple, i.e. no task-level limit
        """
        return sys.maxsize, sys.maxsize
|
|