# Copyright (c) Microsoft Corporation. All rights reserved.
# See LICENSE in the repo root for license information.
#
# Portions:
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

from PIL import Image
from torchvision import transforms

from .transforms import (
    GaussianBlur,
    MaybeToTensor,
    make_normalize_transform,
)


logger = logging.getLogger("dinov2")


class DataAugmentationDINO:
    """Multi-crop augmentation for DINO-style self-supervised training.

    Every call turns one input image into two global crops and
    ``local_crops_number`` local crops.

    All crops go through the same colour-jitter stage and the same final
    tensor-conversion/normalization stage. They differ in:

    * crop size and scale; and
    * the ``p`` value passed to ``GaussianBlur``:
      0.5 for the first global crop, 0.1 for the second, 0.5 for local crops.

    Args:
        global_crops_scale: ``scale`` range given to ``RandomResizedCrop`` for
            the global crops.
        local_crops_scale: ``scale`` range given to ``RandomResizedCrop`` for
            the local crops.
        local_crops_number: number of local crops produced per call.
        global_crops_size: output size of each global crop (default 224).
        local_crops_size: output size of each local crop (default 96).
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        # Lazy %-style arguments: the message is only built if INFO is enabled.
        logger.info("###################################")
        logger.info("Using data augmentation parameters:")
        logger.info("global_crops_scale: %s", global_crops_scale)
        logger.info("local_crops_scale: %s", local_crops_scale)
        logger.info("local_crops_number: %s", local_crops_number)
        logger.info("global_crops_size: %s", global_crops_size)
        logger.info("local_crops_size: %s", local_crops_size)
        logger.info("###################################")

        # Random resized crop and flip.
        # This is the geometric stage, applied to the raw input image.
        self.geometric_augmentation_global = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

        self.geometric_augmentation_local = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

        # Colour distortions / blurring.
        # A single jitter pipeline is shared by every crop type.
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )

        # The two global views use different blur settings,
        # so they are asymmetric.
        global_transfo1_extra = GaussianBlur(p=0.5)

        global_transfo2_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
            ]
        )

        local_transfo_extra = GaussianBlur(p=0.5)

        # Normalization: tensor conversion via MaybeToTensor, then the
        # project's normalize transform.
        # NOTE(review): MaybeToTensor presumably skips inputs that are already
        # tensors — confirm in .transforms.
        self.normalize = transforms.Compose(
            [
                MaybeToTensor(),
                make_normalize_transform(),
            ]
        )

        self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
        self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
        self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])

    def __call__(self, image):
        """Produce the multi-crop views of ``image``.

        Returns:
            A dict with these keys:

            * ``"global_crops"``: list of the two global crops.
            * ``"global_crops_teacher"``: the same two crop objects; the
              teacher gets no separate augmentation.
            * ``"local_crops"``: list of ``local_crops_number`` local crops.
            * ``"offsets"``: always an empty tuple here.
        """
        output = {}

        # Global crops: each view draws its own random geometry,
        # then gets its own photometric pipeline.
        im1_base = self.geometric_augmentation_global(image)
        global_crop_1 = self.global_transfo1(im1_base)

        im2_base = self.geometric_augmentation_global(image)
        global_crop_2 = self.global_transfo2(im2_base)

        output["global_crops"] = [global_crop_1, global_crop_2]

        # Global crops for the teacher: the student's crops are reused as-is.
        output["global_crops_teacher"] = [global_crop_1, global_crop_2]

        # Local crops.
        local_crops = [
            self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
        ]
        output["local_crops"] = local_crops
        output["offsets"] = ()

        return output


def get_online_classification_augmentation_from_config(cfg) -> transforms.Compose:
    """Build the augmentation pipeline for online classification evaluation.

    Reads the following from ``cfg``:

    * from ``cfg.evaluation.online.augmentation``: ``interpolation``,
      ``degrees``, ``scale``, ``shear`` and ``horizontal_flip``;
    * ``cfg.crops.global_crops_size``.

    The pipeline, in order:

    1. resize;
    2. square center crop of ``global_crops_size``;
    3. random affine;
    4. tensor conversion;
    5. normalization;
    6. an optional random horizontal flip.

    Args:
        cfg: configuration object exposing the attributes listed above.

    Returns:
        A ``transforms.Compose`` that applies the pipeline.

    Raises:
        AttributeError: if ``interpolation`` does not name a member of
            ``PIL.Image.Resampling`` (e.g. ``"BICUBIC"``).
    """
    augmentation_config = cfg.evaluation.online.augmentation
    # The config stores the Pillow resampling filter by its enum member name.
    interpolation = getattr(Image.Resampling, augmentation_config.interpolation)
    # Resizing to the crop size before the center crop keeps as much of the
    # image as possible inside the square crop.
    resize_size = crop_size = cfg.crops.global_crops_size
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    crop = transforms.CenterCrop(crop_size)
    affine = transforms.RandomAffine(
        degrees=augmentation_config.degrees,
        scale=augmentation_config.scale,
        shear=augmentation_config.shear,
        interpolation=interpolation,
    )
    transforms_list = [
        resize,
        crop,
        affine,
        MaybeToTensor(),
        make_normalize_transform(),
    ]
    if augmentation_config.horizontal_flip:
        # Appended after normalization, so the flip operates on the tensor
        # output rather than on the PIL image.
        transforms_list.append(transforms.RandomHorizontalFlip())
    return transforms.Compose(transforms_list)