|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import sys |
|
|
|
from typing import Optional, List |
|
from dataclasses import dataclass, field |
|
from omegaconf import MISSING, II |
|
|
|
from fairseq.data import SubsampleDataset |
|
from fairseq.dataclass import FairseqDataclass |
|
from fairseq.tasks import FairseqTask, register_task |
|
|
|
try:
    from ..data import MaeImageDataset
except (ImportError, ValueError):
    # Fallback for when this module is not imported as part of its package
    # (the relative import has no parent package to resolve against): make the
    # parent directory importable and load the top-level ``data`` module.
    # ValueError is kept because Python < 3.9 raised it for relative imports
    # beyond the top-level package; 3.9+ raises ImportError instead.
    sys.path.append("..")
    from data import MaeImageDataset


logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class ImageMaskingConfig:
    """Settings for masks precomputed by the image dataset.

    The first group of defaults are OmegaConf interpolations (``II``), so
    unless overridden they resolve to the matching values of the model
    configuration at config-resolution time.
    """

    # Interpolated from ``model.modalities.image``.
    patch_size: int = field(default=II("model.modalities.image.patch_size"))
    mask_prob: float = field(default=II("model.modalities.image.mask_prob"))
    mask_prob_adjust: float = field(
        default=II("model.modalities.image.mask_prob_adjust")
    )
    mask_length: int = field(default=II("model.modalities.image.mask_length"))
    inverse_mask: bool = field(default=II("model.modalities.image.inverse_mask"))
    mask_dropout: float = field(default=II("model.modalities.image.mask_dropout"))

    # Interpolated from the top-level model config.
    clone_batch: int = field(default=II("model.clone_batch"))

    # Plain switches with literal defaults.
    expand_adjacent: bool = field(default=False)
    non_overlapping: bool = field(default=False)
|
|
|
|
|
@dataclass
class MaeImagePretrainingConfig(FairseqDataclass):
    """Configuration of the ``mae_image_pretraining`` task.

    Fields are declared in a fixed order; most of them are forwarded to the
    image dataset when a split is loaded.
    """

    # Dataset root. ``multi_data``, when not None, is used as the root instead.
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    multi_data: Optional[List[str]] = field(default=None)

    # Options passed straight through to the image dataset.
    input_size: int = field(default=224)
    local_cache_path: Optional[str] = field(default=None)
    key: str = field(default="imgs")

    # Transform switches, also passed through to the image dataset.
    beit_transforms: bool = field(default=False)
    target_transform: bool = field(default=False)
    no_transform: bool = field(default=False)

    # NOTE(review): not read by the task in this file; presumably consumed by
    # generic fairseq batching code — verify before changing.
    rebuild_batches: bool = field(default=True)

    # When set, the dataset is built with ``compute_mask=True`` and receives
    # these settings as keyword arguments.
    precompute_mask_config: Optional[ImageMaskingConfig] = field(default=None)

    # Fraction of each split to keep; values below 1 wrap the dataset in a
    # shuffled SubsampleDataset seeded with ``seed``.
    subsample: float = field(default=1)
    seed: int = field(default=II("common.seed"))

    # Passed through to the image dataset.
    dataset_type: str = field(default="imagefolder")
|
|
|
|
|
@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig)
class MaeImagePretrainingTask(FairseqTask):
    """Image pretraining task that builds ``MaeImageDataset`` splits.

    The task has no source or target dictionaries. It creates one image
    dataset per split from its config and can optionally subsample it.
    """

    cfg: MaeImagePretrainingConfig

    @classmethod
    def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (MaeImagePretrainingConfig): configuration of this task

        Returns:
            A new instance of this task built from ``cfg``.
        """
        return cls(cfg)

    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        """Build the dataset for ``split`` and store it in ``self.datasets``.

        Args:
            split: split name, forwarded to ``MaeImageDataset``.
            task_cfg: optional config that replaces ``self.cfg`` for every
                option except ``data`` (see note below).
        """
        from dataclasses import asdict, is_dataclass

        # NOTE: the data path is always taken from the task's own config,
        # even when ``task_cfg`` is supplied (behavior kept from the original).
        data_path = self.cfg.data
        cfg = task_cfg or self.cfg

        mask_cfg = cfg.precompute_mask_config
        compute_mask = mask_cfg is not None
        mask_args = {}
        if compute_mask:
            # ``**`` requires a mapping. An OmegaConf DictConfig already is one,
            # but a plain ImageMaskingConfig instance (e.g. when the task is
            # built directly from dataclasses) must be converted first.
            if is_dataclass(mask_cfg) and not isinstance(mask_cfg, type):
                mask_args = asdict(mask_cfg)
            else:
                mask_args = mask_cfg

        self.datasets[split] = MaeImageDataset(
            root=data_path if cfg.multi_data is None else cfg.multi_data,
            split=split,
            input_size=cfg.input_size,
            local_cache_path=cfg.local_cache_path,
            key=cfg.key,
            beit_transforms=cfg.beit_transforms,
            target_transform=cfg.target_transform,
            no_transform=cfg.no_transform,
            compute_mask=compute_mask,
            dataset_type=cfg.dataset_type,
            **mask_args,
        )

        # Keep only a shuffled fraction of the split when requested.
        if cfg.subsample < 1:
            self.datasets[split] = SubsampleDataset(
                self.datasets[split],
                cfg.subsample,
                shuffle=True,
                seed=cfg.seed,
            )

    @property
    def source_dictionary(self):
        """This task has no source dictionary."""
        return None

    @property
    def target_dictionary(self):
        """This task has no target dictionary."""
        return None

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        # Effectively unbounded for both source and target.
        return sys.maxsize, sys.maxsize
|
|