import copy
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class VGGLikeEncode(nn.Module):
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 128,
        feature_dim: int = 32,
        apply_pooling: bool = False
    ):
        """
        VGG-like encoder for grayscale images.

        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param feature_dim: number of channels in the intermediate layers
        :param apply_pooling: whether to apply global average pooling at the end
        """
        super().__init__()
        self.apply_pooling = apply_pooling
        # Each block is Conv -> BN -> ReLU -> Conv -> ReLU -> pool.
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, feature_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(feature_dim, feature_dim * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim * 2, feature_dim * 2, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(feature_dim * 2, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # kernel_size=1 with its default stride of 1 is an identity op;
            # kept so block3 preserves the spatial grid (1/4 of the input size).
            nn.MaxPool2d(kernel_size=1)
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        # Plain list on purpose: the blocks are already registered as
        # submodules above, so this is only a lookup table for get_conv_layer.
        self.blocks = [self.block1, self.block2, self.block3]

    def forward(self, x: Tensor) -> Tensor:
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        if self.apply_pooling:
            # Global average pooling: (B, C, H, W) -> (B, C).
            x = self.global_avg_pool(x).view(x.shape[0], -1)
        return x

    def get_conv_layer(self, block_num: int) -> Optional[nn.Conv2d]:
        """Return the first convolution of the given block, or None if out of range."""
        if block_num >= len(self.blocks):
            return None
        return self.blocks[block_num][0]
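
# Illustrative shape-check sketch. The 32x32 input size is an assumption
# inferred from `seq_len = 8 * 8` in the classifier below: two stride-2
# pools map 32x32 down to 8x8.
def _demo_encoder_shapes() -> None:
    encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32)
    x = torch.randn(4, 1, 32, 32)               # batch of 4 grayscale images
    assert encoder(x).shape == (4, 128, 8, 8)   # spatial feature map
    pooled = VGGLikeEncode(apply_pooling=True)(x)
    assert pooled.shape == (4, 128)             # globally pooled vector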

class CrossAttentionClassifier(nn.Module):
    def __init__(
        self,
        feature_dim: int = 32,
        num_heads: int = 4,
        linear_dim: int = 128,
        out_channels: int = 128,
        encoder: Optional[VGGLikeEncode] = None
    ):
        """
        Cross-attention classifier for comparing two grayscale images.

        :param feature_dim: number of channels in the intermediate layers
        :param num_heads: number of attention heads
        :param linear_dim: number of units in the linear layer
        :param out_channels: number of output channels
        :param encoder: encoder to use; if None, a default VGGLikeEncode is created
        """
        super().__init__()
        if encoder is not None:
            self.encoder = encoder
        else:
            self.encoder = VGGLikeEncode(in_channels=1, feature_dim=feature_dim, out_channels=out_channels)
        self.out_channels = out_channels
        # 8x8 feature grid flattened to a sequence; assumes 32x32 inputs.
        self.seq_len = 8 * 8
        # Learned positional embedding in sequence-first layout
        # (seq_len, batch, channels), broadcast over the batch dimension.
        self.pos_embedding = nn.Parameter(torch.randn(self.seq_len, 1, out_channels) * 0.01)
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=out_channels,
            num_heads=num_heads,
            batch_first=False
        )
        self.norm = nn.LayerNorm(out_channels)
        self.classifier = nn.Sequential(
            nn.Linear(out_channels, linear_dim),
            nn.ReLU(),
            nn.Linear(linear_dim, 1)
        )

    def forward(self, img1: Tensor, img2: Tensor) -> Tuple[Tensor, Tensor]:
        feat1 = self.encoder(img1)
        feat2 = self.encoder(img2)
        B, C, H, W = feat1.shape
        seq_len = H * W
        if seq_len != self.seq_len:
            raise ValueError(f"Expected {self.seq_len} spatial positions, got {seq_len}")
        # (B, C, H, W) -> (H*W, B, C) for sequence-first attention.
        feat1_flat = feat1.view(B, C, seq_len).permute(2, 0, 1)
        feat2_flat = feat2.view(B, C, seq_len).permute(2, 0, 1)
        feat1_flat = self.norm(feat1_flat + self.pos_embedding)
        feat2_flat = self.norm(feat2_flat + self.pos_embedding)
        # Image 1 queries image 2: each position of feat1 attends over feat2.
        attn_output, attn_weights = self.cross_attention(
            query=feat1_flat,
            key=feat2_flat,
            value=feat2_flat,
            need_weights=True,
            average_attn_weights=True
        )
        # Mean-pool over the sequence dimension, then classify the pair.
        pooled_features = attn_output.mean(dim=0)
        logits = self.classifier(pooled_features).squeeze(-1)
        return logits, attn_weights
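
# Illustrative usage sketch, assuming 32x32 grayscale image pairs and binary
# same/different labels; training with BCE-with-logits is an assumption, not
# something the module itself prescribes.
def _demo_classifier_forward() -> None:
    model = CrossAttentionClassifier()
    img1 = torch.randn(4, 1, 32, 32)
    img2 = torch.randn(4, 1, 32, 32)
    logits, attn_weights = model(img1, img2)
    assert logits.shape == (4,)               # one logit per image pair
    assert attn_weights.shape == (4, 64, 64)  # head-averaged attention maps
    labels = torch.randint(0, 2, (4,)).float()
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss.backward()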

class NormalizedMSELoss(nn.Module):
    def __init__(self):
        """
        Normalized MSE loss for BYOL training.

        For L2-normalized vectors, 2 - 2 * cos(v1, v2) equals ||v1 - v2||^2,
        i.e. the MSE between the normalized views. Returns one loss per sample.
        """
        super().__init__()

    def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
        v1 = F.normalize(view1, dim=-1)
        v2 = F.normalize(view2, dim=-1)
        return 2 - 2 * (v1 * v2).sum(dim=-1)


class MLP(nn.Module):
    def __init__(self, input_dim: int, projection_dim: int = 128, hidden_dim: int = 512):
        """
        MLP for BYOL training.

        :param input_dim: input dimension
        :param projection_dim: projection dimension
        :param hidden_dim: hidden dimension
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_dim)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)


class EncoderProjecter(nn.Module):
    def __init__(self, encoder: nn.Module, hidden_dim: int = 512, projection_out_dim: int = 128):
        """
        Encoder followed by a projection MLP.

        :param encoder: encoder to use; must emit 128-dim pooled features,
            matching VGGLikeEncode's default out_channels
        :param hidden_dim: hidden dimension
        :param projection_out_dim: projection output dimension
        """
        super().__init__()
        self.encoder = encoder
        self.projection = MLP(input_dim=128, projection_dim=projection_out_dim, hidden_dim=hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        h = self.encoder(x)
        return self.projection(h)


# https://arxiv.org/pdf/2006.07733
class BYOL(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 512,
        projection_out_dim: int = 128,
        target_decay: float = 0.9975
    ):
        """
        BYOL model for self-supervised learning.

        :param hidden_dim: hidden dimension
        :param projection_out_dim: projection output dimension
        :param target_decay: target network decay rate (EMA coefficient tau)
        """
        super().__init__()
        encoder = VGGLikeEncode(in_channels=1, out_channels=128, feature_dim=32, apply_pooling=True)
        self.online_network = EncoderProjecter(encoder, hidden_dim=hidden_dim, projection_out_dim=projection_out_dim)
        self.online_predictor = MLP(input_dim=projection_out_dim, projection_dim=projection_out_dim, hidden_dim=hidden_dim)
        # The target network must be an independent copy. Building it from the
        # same encoder instance would share weights with the online network,
        # and freezing it would then freeze the online encoder as well.
        self.target_network = copy.deepcopy(self.online_network)
        self.target_network.eval()
        for param in self.target_network.parameters():
            param.requires_grad = False
        self.target_decay = target_decay
        self.loss_function = NormalizedMSELoss()

    @torch.no_grad()
    def soft_update_target_network(self):
        """EMA update: theta_target <- tau * theta_target + (1 - tau) * theta_online."""
        for online_p, target_p in zip(self.online_network.parameters(), self.target_network.parameters()):
            target_p.data = target_p.data * self.target_decay + online_p.data * (1. - self.target_decay)
        # Buffers (e.g. BatchNorm running stats) are not parameters; keep the
        # target's copies in sync with the online network.
        for online_b, target_b in zip(self.online_network.buffers(), self.target_network.buffers()):
            target_b.data.copy_(online_b.data)

    def forward(self, view: Tensor) -> Tuple[Tensor, Tensor]:
        online_proj = self.online_network(view)
        with torch.no_grad():  # target branch never needs gradients
            target_proj = self.target_network(view)
        return online_proj, target_proj

    def loss(self, view1: Tensor, view2: Tensor) -> Tensor:
        """Symmetrized BYOL loss: each view's prediction regresses the other view's target projection."""
        online_proj1, target_proj1 = self(view1)
        online_proj2, target_proj2 = self(view2)
        online_prediction_1 = self.online_predictor(online_proj1)
        online_prediction_2 = self.online_predictor(online_proj2)
        loss1 = self.loss_function(online_prediction_1, target_proj2.detach())
        loss2 = self.loss_function(online_prediction_2, target_proj1.detach())
        return torch.mean(loss1 + loss2)
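
# Illustrative training-step sketch. The noise-based "augmentations" are
# placeholders; BYOL expects view1 and view2 to be two random augmentations
# of the same image batch.
def _demo_byol_training_step() -> None:
    model = BYOL()
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=3e-4
    )
    images = torch.randn(8, 1, 32, 32)
    view1 = images + 0.1 * torch.randn_like(images)  # stand-in for real augmentations
    view2 = images + 0.1 * torch.randn_like(images)
    loss = model.loss(view1, view2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.soft_update_target_network()  # EMA update after each optimizer step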
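
if __name__ == "__main__":
    # Smoke-run the illustrative sketches above.
    _demo_encoder_shapes()
    _demo_classifier_forward()
    _demo_byol_training_step()
    print("All demo sketches ran successfully.")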