# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import logging
import sys
from typing import Optional, List
from dataclasses import dataclass, field
from omegaconf import MISSING, II
from fairseq.data import SubsampleDataset
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
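
# Prefer the package-relative import; if this module is used outside its parent
# package (so the relative import fails), fall back to importing
# MaeImageDataset from the sibling "data" directory via sys.path.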
try:
    from ..data import MaeImageDataset
except:
    sys.path.append("..")

    from data import MaeImageDataset

logger = logging.getLogger(__name__)


@dataclass
class ImageMaskingConfig:
    patch_size: int = II("model.modalities.image.patch_size")
    mask_prob: float = II("model.modalities.image.mask_prob")
    mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust")
    mask_length: int = II("model.modalities.image.mask_length")
    inverse_mask: bool = II("model.modalities.image.inverse_mask")
    mask_dropout: float = II("model.modalities.image.mask_dropout")
    clone_batch: int = II("model.clone_batch")
    expand_adjacent: bool = False
    non_overlapping: bool = False
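
# Note: the II("...") defaults above are OmegaConf interpolations; they are
# resolved at runtime from the corresponding fields of the model config
# (e.g. model.modalities.image.patch_size), keeping the precomputed-mask
# settings in sync with the model.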


@dataclass
class MaeImagePretrainingConfig(FairseqDataclass):
    data: str = field(default=MISSING, metadata={"help": "path to data directory"})
    multi_data: Optional[List[str]] = None
    input_size: int = 224
    local_cache_path: Optional[str] = None
    key: str = "imgs"
    beit_transforms: bool = False
    target_transform: bool = False
    no_transform: bool = False
    rebuild_batches: bool = True
    precompute_mask_config: Optional[ImageMaskingConfig] = None
    subsample: float = 1
    seed: int = II("common.seed")
    dataset_type: str = "imagefolder"
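
# Illustrative Hydra-style task section for this config (a minimal sketch;
# values below are placeholders, not defaults taken from any shipped config):
#
#   task:
#     _name: mae_image_pretraining
#     data: /path/to/images
#     input_size: 224
#     dataset_type: imagefolder
#     precompute_mask_config: {}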


@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig)
class MaeImagePretrainingTask(FairseqTask):
    """Task for masked-image-modeling pretraining on unlabeled images."""

    cfg: MaeImagePretrainingConfig

    @classmethod
    def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (MaeImagePretrainingConfig): configuration of this task
        """

        return cls(cfg)

    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        data_path = self.cfg.data
        cfg = task_cfg or self.cfg

        # When a precompute_mask_config is given, its fields are forwarded to
        # the dataset so masks are precomputed alongside the images.
        compute_mask = cfg.precompute_mask_config is not None
        mask_args = {}
        if compute_mask:
            mask_args = cfg.precompute_mask_config

        self.datasets[split] = MaeImageDataset(
            root=data_path if cfg.multi_data is None else cfg.multi_data,
            split=split,
            input_size=cfg.input_size,
            local_cache_path=cfg.local_cache_path,
            key=cfg.key,
            beit_transforms=cfg.beit_transforms,
            target_transform=cfg.target_transform,
            no_transform=cfg.no_transform,
            compute_mask=compute_mask,
            dataset_type=cfg.dataset_type,
            **mask_args,
        )

        # Optionally train on a random subset of the data.
        if cfg.subsample < 1:
            self.datasets[split] = SubsampleDataset(
                self.datasets[split],
                cfg.subsample,
                shuffle=True,
                seed=cfg.seed,
            )

    @property
    def source_dictionary(self):
        return None

    @property
    def target_dictionary(self):
        return None

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return sys.maxsize, sys.maxsize
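

# Minimal usage sketch (illustrative only; "/path/to/images" is a placeholder
# pointing at an imagefolder-style dataset on disk):
#
#   cfg = MaeImagePretrainingConfig(data="/path/to/images")
#   task = MaeImagePretrainingTask.setup_task(cfg)
#   task.load_dataset("train")
#   dataset = task.datasets["train"]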