ViTGaze / data /masking.py
yhsong's picture
initial commit
f9561b9 verified
raw
history blame
6.06 kB
import random
import math
import numpy as np
import torch
from torch.nn import functional as F
class SceneMaskGenerator:
    """Sample a random block-wise mask over a patch grid, excluding the head.

    Rectangles with a random area and aspect ratio are added to the grid
    until a randomly drawn budget of masked patches is reached (or no
    further rectangle can be placed).  The result is then AND-ed with the
    positions where the supplied ``head_mask`` is below 0.5, so positions
    where ``head_mask`` is 0.5 or above are never masked.
    """

    def __init__(
        self,
        input_size,
        min_num_patches=16,
        max_num_patches_ratio=0.5,
        min_aspect=0.3,
    ):
        """Configure the generator.

        Args:
            input_size: Patch-grid size.  Either a scalar (square grid) or
                an iterable ``(height, width)`` such as a tuple or list.
            min_num_patches: Lower bound used when sampling rectangle areas
                and the overall masking budget.
            max_num_patches_ratio: Upper bound of the masking budget, as a
                fraction of the total number of patches.
            min_aspect: Smallest rectangle aspect ratio; the largest is its
                reciprocal (sampling is uniform in log space).
        """
        # Accept lists and other iterables as well as tuples; anything that
        # cannot be iterated is treated as the side length of a square grid.
        try:
            input_size = tuple(input_size)
        except TypeError:
            input_size = (input_size,) * 2
        self.input_size = input_size
        self.num_patches = input_size[0] * input_size[1]
        self.min_num_patches = min_num_patches
        self.max_num_patches = max_num_patches_ratio * self.num_patches
        self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))

    def _mask(self, mask, max_mask_patches):
        """Try to add one rectangle to ``mask`` in place.

        Up to four random rectangles are tried; the first one that adds at
        least one and at most ``max_mask_patches`` new patches is applied.

        Args:
            mask: 2-D boolean NumPy array, modified in place.
            max_mask_patches: Remaining masking budget.

        Returns:
            Number of newly masked patches, or 0 if every attempt failed.
        """
        height, width = self.input_size
        for _ in range(4):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < width and h < height:
                top = random.randint(0, height - h)
                left = random.randint(0, width - w)
                already_masked = int(mask[top : top + h, left : left + w].sum())
                # Only the patches not yet masked count against the budget.
                newly_masked = h * w - already_masked
                if 0 < newly_masked <= max_mask_patches:
                    mask[top : top + h, left : left + w] = True
                    return newly_masked
        return 0

    def __call__(self, head_mask):
        """Generate a scene mask.

        Args:
            head_mask: Tensor that is resized to the patch grid with
                ``F.interpolate`` after one leading dimension is added.
                Presumably a float tensor of shape ``(1, H, W)`` with high
                values on the head region -- TODO confirm against caller.

        Returns:
            Boolean tensor; with a ``(1, H, W)`` ``head_mask`` its shape is
            ``input_size``.  ``True`` marks masked patches.
        """
        mask = np.zeros(shape=self.input_size, dtype=bool)
        mask_count = 0
        num_masking_patches = random.uniform(
            self.min_num_patches, self.max_num_patches
        )
        while mask_count < num_masking_patches:
            delta = self._mask(mask, num_masking_patches - mask_count)
            if delta == 0:
                # No rectangle could be placed; give up rather than spin.
                break
            mask_count += delta

        mask = torch.from_numpy(mask).unsqueeze(0)
        # Keep only positions where the resized head mask is below 0.5.
        outside_head = (
            F.interpolate(head_mask.unsqueeze(0), mask.shape[-2:]).squeeze(0) < 0.5
        )
        return torch.logical_and(mask, outside_head).squeeze(0)
class HeadMaskGenerator:
    """Sample a random rectangular mask inside a bounding box.

    The box is given in normalised coordinates and mapped onto the patch
    grid; a single rectangle with random area and aspect ratio is then
    placed at a random position inside it.
    """

    def __init__(
        self,
        input_size,
        min_num_patches=4,
        max_num_patches_ratio=0.5,
        min_aspect=0.3,
    ):
        """Configure the generator.

        Args:
            input_size: Patch-grid size.  Either a scalar (square grid) or
                an iterable ``(height, width)`` such as a tuple or list.
            min_num_patches: Boxes covering fewer grid patches than this
                are left unmasked; also one bound of the sampled area.
            max_num_patches_ratio: The other bound of the sampled area, as
                a fraction of the box area in patches.
            min_aspect: Smallest rectangle aspect ratio; the largest is its
                reciprocal (sampling is uniform in log space).
        """
        # Accept lists and other iterables as well as tuples; anything that
        # cannot be iterated is treated as the side length of a square grid.
        try:
            input_size = tuple(input_size)
        except TypeError:
            input_size = (input_size,) * 2
        self.input_size = input_size
        self.num_patches = input_size[0] * input_size[1]
        self.min_num_patches = min_num_patches
        self.max_num_patches_ratio = max_num_patches_ratio
        self.log_aspect_ratio = (math.log(min_aspect), -math.log(min_aspect))

    def __call__(
        self,
        x_min,
        y_min,
        x_max,
        y_max,  # coords in [0,1]
    ):
        """Generate a mask inside the given box.

        Args:
            x_min, y_min, x_max, y_max: Box corners, normalised to [0, 1].

        Returns:
            Boolean tensor of shape ``input_size``; ``True`` marks masked
            patches.  All ``False`` when the box is degenerate, inverted,
            or covers fewer than ``min_num_patches`` patches.
        """
        grid_h, grid_w = self.input_size[0], self.input_size[1]
        height = math.floor((y_max - y_min) * grid_h)
        width = math.floor((x_max - x_min) * grid_w)
        origin_area = width * height
        # An inverted box has two negative extents whose product is
        # positive, so the extents are checked explicitly as well.
        if height <= 0 or width <= 0 or origin_area < self.min_num_patches:
            return torch.zeros(size=self.input_size, dtype=bool)

        target_area = random.uniform(
            self.min_num_patches, self.max_num_patches_ratio * origin_area
        )
        aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
        # Clamp the rectangle so it never exceeds the box itself.
        h = min(int(round(math.sqrt(target_area * aspect_ratio))), height)
        w = min(int(round(math.sqrt(target_area / aspect_ratio))), width)
        # Random offset within the box, shifted to the box's grid origin.
        top = random.randint(0, height - h) + int(y_min * grid_h)
        left = random.randint(0, width - w) + int(x_min * grid_w)

        mask = torch.zeros(size=self.input_size, dtype=bool)
        mask[top : top + h, left : left + w] = True
        return mask
class MaskGenerator:
    """Randomly apply scene masking, head masking, both, or neither.

    Each call flips two independent coins (one per mask type).  A mask type
    is applied only when it was enabled at construction time and its coin
    succeeds; when both apply, the two masks are OR-ed together.
    """

    def __init__(
        self,
        input_size,
        mask_scene: bool = False,
        mask_head: bool = False,
        min_scene_patches=16,
        max_scene_patches_ratio=0.5,
        min_head_patches=4,
        max_head_patches_ratio=0.5,
        min_aspect=0.3,
        mask_prob=0.2,
        head_prob=0.2,
    ):
        """Configure the generator.

        Args:
            input_size: Patch-grid size.  Either a scalar (square grid) or
                an iterable ``(height, width)`` such as a tuple or list.
            mask_scene: Enable scene masking.
            mask_head: Enable head masking.
            min_scene_patches: Forwarded to ``SceneMaskGenerator``.
            max_scene_patches_ratio: Forwarded to ``SceneMaskGenerator``.
            min_head_patches: Forwarded to ``HeadMaskGenerator``.
            max_head_patches_ratio: Forwarded to ``HeadMaskGenerator``.
            min_aspect: Forwarded to both generators.
            mask_prob: Per-call probability of applying the scene mask.
            head_prob: Per-call probability of applying the head mask.
        """
        # Accept lists and other iterables as well as tuples; anything that
        # cannot be iterated is treated as the side length of a square grid.
        try:
            input_size = tuple(input_size)
        except TypeError:
            input_size = (input_size,) * 2
        self.input_size = input_size

        if mask_scene:
            self.scene_mask_generator = SceneMaskGenerator(
                input_size, min_scene_patches, max_scene_patches_ratio, min_aspect
            )
        else:
            self.scene_mask_generator = None

        if mask_head:
            self.head_mask_generator = HeadMaskGenerator(
                input_size, min_head_patches, max_head_patches_ratio, min_aspect
            )
        else:
            self.head_mask_generator = None

        # Mode flags: nothing enabled / head only / scene only.
        self.no_mask = not (mask_scene or mask_head)
        self.mask_head = mask_head and not mask_scene
        self.mask_scene = mask_scene and not mask_head
        self.scene_prob = mask_prob
        self.head_prob = head_prob

    def __call__(
        self,
        x_min,
        y_min,
        x_max,
        y_max,
        head_mask,
    ):
        """Generate a mask for one sample.

        Args:
            x_min, y_min, x_max, y_max: Box corners passed to the head mask
                generator.
            head_mask: Passed to the scene mask generator.

        Returns:
            Boolean tensor of shape ``input_size`` when no mask is applied;
            otherwise whatever the selected generator(s) return.
        """
        # Both coins are always flipped, even for a disabled mask type, so
        # the amount of randomness consumed does not depend on the mode.
        scene_drawn = random.random() < self.scene_prob
        head_drawn = random.random() < self.head_prob
        use_scene = scene_drawn and self.scene_mask_generator is not None
        use_head = head_drawn and self.head_mask_generator is not None

        if use_scene and use_head:
            return torch.logical_or(
                self.scene_mask_generator(head_mask),
                self.head_mask_generator(x_min, y_min, x_max, y_max),
            )
        if use_head:
            return self.head_mask_generator(x_min, y_min, x_max, y_max)
        if use_scene:
            return self.scene_mask_generator(head_mask)
        return torch.zeros(size=self.input_size, dtype=bool)