# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
import logging
import math
import os
import random
import time
from shutil import copyfile

import numpy as np
import torch
from torchvision import datasets, transforms

from fairseq.data import FairseqDataset
from fairseq.data.data_utils import compute_block_mask_1d, compute_block_mask_2d

from .path_dataset import PathDataset

logger = logging.getLogger(__name__)


def load(path, loader, cache):
    # A permission failure below may re-point the cache root at runtime;
    # prefer the updated root over the one baked into the partial.
    if hasattr(caching_loader, "cache_root"):
        cache = caching_loader.cache_root

    cached_path = cache + path

    num_tries = 3
    for curr_try in range(num_tries):
        try:
            # Last attempt: give up on caching and read the original path.
            if curr_try == 2:
                return loader(path)
            if not os.path.exists(cached_path) or curr_try > 0:
                os.makedirs(os.path.dirname(cached_path), exist_ok=True)
                copyfile(path, cached_path)
                os.chmod(cached_path, 0o777)
            return loader(cached_path)
        except Exception as e:
            logger.warning(str(e))
            if "Errno 13" in str(e):
                caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}"
                logger.warning(f"setting cache root to {caching_loader.cache_root}")
                cached_path = caching_loader.cache_root + path
            if curr_try == (num_tries - 1):
                raise
            time.sleep(2)


def caching_loader(cache_root: str, loader):
    if cache_root is None:
        return loader

    if cache_root == "slurm_tmpdir":
        cache_root = os.environ["SLURM_TMPDIR"]
        assert len(cache_root) > 0

    if not cache_root.endswith("/"):
        cache_root += "/"

    return partial(load, loader=loader, cache=cache_root)
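
# A minimal usage sketch (the dataset path below is hypothetical): wrap
# torchvision's default image loader so each file is copied to fast
# node-local storage on first read and served from the copy afterwards.
# On repeated copy failures the loader falls back to the original path,
# and on permission errors ("Errno 13") it re-points the cache root.
#
#   loader = caching_loader("slurm_tmpdir", datasets.folder.default_loader)
#   img = loader("/datasets/imagenet/train/n01440764/x.JPEG")  # cached copy
#
# Passing cache_root=None returns the loader unchanged, so caching can be
# disabled without touching call sites.
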
class RandomResizedCropAndInterpolationWithTwoPic:
    """Crop the given PIL Image to a random size and aspect ratio, with a
    random interpolation method.

    A crop of random size (default: 0.08 to 1.0 of the original size) and a
    random aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is
    made. This crop is finally resized to the given size. This is popularly
    used to train the Inception networks.

    Args:
        size: expected output size of each edge
        second_size: if set, a second resized crop of this size is returned
            alongside the first
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(
        self,
        size,
        second_size=None,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation="bilinear",
        second_interpolation="lanczos",
    ):
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)
        if second_size is not None:
            if isinstance(second_size, tuple):
                self.second_size = second_size
            else:
                self.second_size = (second_size, second_size)
        else:
            self.second_size = None
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            logger.warning("scale and ratio should be of kind (min, max)")

        if interpolation == "random":
            from PIL import Image

            self.interpolation = (Image.BILINEAR, Image.BICUBIC)
        else:
            self.interpolation = self._pil_interp(interpolation)

        self.second_interpolation = (
            self._pil_interp(second_interpolation)
            if second_interpolation is not None
            else None
        )
        self.scale = scale
        self.ratio = ratio

    def _pil_interp(self, method):
        from PIL import Image

        if method == "bicubic":
            return Image.BICUBIC
        elif method == "lanczos":
            return Image.LANCZOS
        elif method == "hamming":
            return Image.HAMMING
        else:
            # default bilinear, do we want to allow nearest?
            return Image.BILINEAR

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio
                cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        area = img.size[0] * img.size[1]

        for attempt in range(10):
            target_area = random.uniform(*scale) * area
            log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= img.size[0] and h <= img.size[1]:
                i = random.randint(0, img.size[1] - h)
                j = random.randint(0, img.size[0] - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = img.size[0] / img.size[1]
        if in_ratio < min(ratio):
            w = img.size[0]
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = img.size[1]
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = img.size[0]
            h = img.size[1]
        i = (img.size[1] - h) // 2
        j = (img.size[0] - w) // 2
        return i, j, h, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        import torchvision.transforms.functional as F

        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        if self.second_size is None:
            return F.resized_crop(img, i, j, h, w, self.size, interpolation)
        else:
            return F.resized_crop(
                img, i, j, h, w, self.size, interpolation
            ), F.resized_crop(
                img, i, j, h, w, self.second_size, self.second_interpolation
            )
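
# A minimal usage sketch (the parameter values are illustrative, not from the
# original): when second_size is set, __call__ returns two views of the same
# random crop, e.g. a larger bicubic view plus a smaller lanczos view, as in
# BEiT-style two-picture pipelines.
#
#   two_pic = RandomResizedCropAndInterpolationWithTwoPic(
#       size=224,
#       second_size=112,
#       interpolation="bicubic",
#       second_interpolation="lanczos",
#   )
#   view_a, view_b = two_pic(pil_img)  # identical crop, two sizes
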
class MaeImageDataset(FairseqDataset):
    def __init__(
        self,
        root: str,
        split: str,
        input_size,
        local_cache_path=None,
        shuffle=True,
        key="imgs",
        beit_transforms=False,
        target_transform=False,
        no_transform=False,
        compute_mask=False,
        patch_size: int = 16,
        mask_prob: float = 0.75,
        mask_prob_adjust: float = 0,
        mask_length: int = 1,
        inverse_mask: bool = False,
        expand_adjacent: bool = False,
        mask_dropout: float = 0,
        non_overlapping: bool = False,
        require_same_masks: bool = True,
        clone_batch: int = 1,
        dataset_type: str = "imagefolder",
    ):
        FairseqDataset.__init__(self)

        self.shuffle = shuffle
        self.key = key

        loader = caching_loader(local_cache_path, datasets.folder.default_loader)

        self.transform_source = None
        self.transform_target = None

        if target_transform:
            self.transform_source = transforms.ColorJitter(0.4, 0.4, 0.4)
            self.transform_target = transforms.ColorJitter(0.4, 0.4, 0.4)

        if no_transform:
            if input_size <= 224:
                crop_pct = 224 / 256
            else:
                crop_pct = 1.0
            size = int(input_size / crop_pct)

            self.transform_train = transforms.Compose(
                [
                    transforms.Resize(size, interpolation=3),  # 3 is bicubic
                    transforms.CenterCrop(input_size),
                ]
            )

            # NOTE: the plain resize below overrides the resize-and-center-crop
            # pipeline above, so only the direct resize is actually applied.
            self.transform_train = transforms.Resize((input_size, input_size))
        elif beit_transforms:
            beit_transform_list = []
            if not target_transform:
                beit_transform_list.append(transforms.ColorJitter(0.4, 0.4, 0.4))
            beit_transform_list.extend(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    RandomResizedCropAndInterpolationWithTwoPic(
                        size=input_size,
                        second_size=None,
                        interpolation="bicubic",
                        second_interpolation=None,
                    ),
                ]
            )
            self.transform_train = transforms.Compose(beit_transform_list)
        else:
            self.transform_train = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        input_size, scale=(0.2, 1.0), interpolation=3
                    ),  # 3 is bicubic
                    transforms.RandomHorizontalFlip(),
                ]
            )

        self.final_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        if dataset_type == "imagefolder":
            self.dataset = datasets.ImageFolder(
                os.path.join(root, split), loader=loader
            )
        elif dataset_type == "path":
            self.dataset = PathDataset(
                root,
                loader,
                None,
                None,
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            )
        else:
            raise Exception(f"invalid dataset type {dataset_type}")

        logger.info(
            f"initial transform: {self.transform_train}, "
            f"source transform: {self.transform_source}, "
            f"target transform: {self.transform_target}, "
            f"final transform: {self.final_transform}"
        )
        logger.info(f"loaded {len(self.dataset)} examples")

        self.is_compute_mask = compute_mask
        self.patches = (input_size // patch_size) ** 2
        self.mask_prob = mask_prob
        self.mask_prob_adjust = mask_prob_adjust
        self.mask_length = mask_length
        self.inverse_mask = inverse_mask
        self.expand_adjacent = expand_adjacent
        self.mask_dropout = mask_dropout
        self.non_overlapping = non_overlapping
        self.require_same_masks = require_same_masks
        self.clone_batch = clone_batch

    def __getitem__(self, index):
        img, _ = self.dataset[index]

        img = self.transform_train(img)

        source = None
        target = None
        if self.transform_source is not None:
            source = self.final_transform(self.transform_source(img))
        if self.transform_target is not None:
            target = self.final_transform(self.transform_target(img))

        if source is None:
            img = self.final_transform(img)

        v = {"id": index, self.key: source if source is not None else img}
        if target is not None:
            v["target"] = target

        if self.is_compute_mask:
            if self.mask_length == 1:
                mask = compute_block_mask_1d(
                    shape=(self.clone_batch, self.patches),
                    mask_prob=self.mask_prob,
                    mask_length=self.mask_length,
                    mask_prob_adjust=self.mask_prob_adjust,
                    inverse_mask=self.inverse_mask,
                    require_same_masks=True,
                )
            else:
                mask = compute_block_mask_2d(
                    shape=(self.clone_batch, self.patches),
                    mask_prob=self.mask_prob,
                    mask_length=self.mask_length,
                    mask_prob_adjust=self.mask_prob_adjust,
                    inverse_mask=self.inverse_mask,
                    require_same_masks=True,
                    expand_adjcent=self.expand_adjacent,  # (sic) upstream kwarg spelling
                    mask_dropout=self.mask_dropout,
                    non_overlapping=self.non_overlapping,
                )

            v["precomputed_mask"] = mask

        return v

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if len(samples) == 0:
            return {}

        collated_img = torch.stack([s[self.key] for s in samples], dim=0)

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {
                self.key: collated_img,
            },
        }

        if "target" in samples[0]:
            collated_target = torch.stack([s["target"] for s in samples], dim=0)
            res["net_input"]["target"] = collated_target

        if "precomputed_mask" in samples[0]:
            collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0)
            res["net_input"]["precomputed_mask"] = collated_mask

        return res

    def num_tokens(self, index):
        return 1

    def size(self, index):
        return 1

    @property
    def sizes(self):
        return np.full((len(self),), 1)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        return order[0]
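

# A minimal end-to-end sketch (the root path and sizes are hypothetical):
# build the dataset with precomputed 2d block masks and collate a few items
# the way a fairseq epoch iterator would.
#
#   dataset = MaeImageDataset(
#       root="/datasets/imagenet",  # expects an ImageFolder layout per split
#       split="train",
#       input_size=224,
#       compute_mask=True,
#       mask_length=3,   # >1 selects the 2d block-mask path
#       clone_batch=2,   # two mask clones per image
#   )
#   batch = dataset.collater([dataset[i] for i in range(4)])
#
# Under these assumptions, batch["net_input"]["imgs"] should be a
# (4, 3, 224, 224) tensor, and batch["net_input"]["precomputed_mask"] should
# be (8, 196): four samples times clone_batch=2 masks, concatenated along
# dim 0 by the collater, over (224 // 16) ** 2 = 196 patches.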