"""Training script for a BYOL (Bootstrap Your Own Latent) model with wandb logging."""

import argparse
from typing import Optional

import torch
import wandb
from torch import nn, optim
from torch.nn.functional import cosine_similarity
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.dataset import RandomAugmentedDataset, get_byol_transforms
from src.models import BYOL


def get_data_loaders(
    batch_size: int,
    num_train_samples: int,
    num_val_samples: int,
    shape_params: Optional[dict] = None,
    num_workers: int = 0,
):
    """Build train/validation DataLoaders over randomly augmented sample pairs."""
    augmentations = get_byol_transforms()
    train_dataset = RandomAugmentedDataset(
        augmentations, shape_params, num_samples=num_train_samples, train=True
    )
    val_dataset = RandomAugmentedDataset(
        augmentations, shape_params, num_samples=num_val_samples, train=False
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader


def build_model(lr: float):
    """Create the BYOL model, its optimizer, and an LR scheduler."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BYOL().to(device)
    # Only the online network and predictor are optimized; the target network
    # is updated via an exponential moving average of the online weights.
    optimizer = optim.Adam(
        list(model.online_network.parameters()) + list(model.online_predictor.parameters()),
        lr=lr,
    )
    # mode='max' because the scheduler is stepped on validation cosine similarity.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.1, patience=2)
    return model, optimizer, scheduler, device


def train_epoch(
    model: nn.Module,
    optimizer: optim.Optimizer,
    train_loader: DataLoader,
    device: torch.device,
) -> dict:
    """Run one training epoch and return averaged loss and diagnostic metrics."""
    model.train()
    running_train_loss = 0.0
    total_cos_sim, total_l2_dist, total_feat_norm, total_grad_norm = 0.0, 0.0, 0.0, 0.0
    num_train_batches = 0

    for view_1, view_2 in tqdm(train_loader, desc="Training"):
        view_1 = view_1.to(device)
        view_2 = view_2.to(device)

        loss = model.loss(view_1, view_2)
        optimizer.zero_grad()
        loss.backward()

        # Diagnostics: compare the online projection of view 1 with the target
        # projection of view 2, and track feature and gradient norms.
        with torch.no_grad():
            online_proj1, target_proj1 = model(view_1)
            online_proj2, target_proj2 = model(view_2)

            cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
            l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
            feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
            # Global gradient norm over the online network's parameters.
            grad_norm = torch.norm(
                torch.cat([
                    p.grad.flatten()
                    for p in model.online_network.parameters()
                    if p.grad is not None
                ])
            ).item()

        total_cos_sim += cos_sim
        total_l2_dist += l2_dist
        total_feat_norm += feat_norm
        total_grad_norm += grad_norm

        optimizer.step()
        # EMA update of the target network after each optimizer step.
        model.soft_update_target_network()

        running_train_loss += loss.item()
        num_train_batches += 1

    train_loss = running_train_loss / num_train_batches
    train_cos_sim = total_cos_sim / num_train_batches
    train_l2_dist = total_l2_dist / num_train_batches
    train_feat_norm = total_feat_norm / num_train_batches
    train_grad_norm = total_grad_norm / num_train_batches

    return {
        "loss": train_loss,
        "cos_sim": train_cos_sim,
        "l2_dist": train_l2_dist,
        "feat_norm": train_feat_norm,
        "grad_norm": train_grad_norm,
    }


@torch.no_grad()
def validate(
    model: nn.Module,
    val_loader: DataLoader,
    device: torch.device,
) -> dict:
    """Evaluate on the validation loader and return averaged metrics."""
    model.eval()
    running_val_loss = 0.0
    total_cos_sim, total_l2_dist, total_feat_norm = 0.0, 0.0, 0.0
    num_val_batches = 0

    for view_1, view_2 in tqdm(val_loader, desc="Validation"):
        view_1 = view_1.to(device)
        view_2 = view_2.to(device)

        loss = model.loss(view_1, view_2)
        running_val_loss += loss.item()

        online_proj1, target_proj1 = model(view_1)
        online_proj2, target_proj2 = model(view_2)

        cos_sim = cosine_similarity(online_proj1, target_proj2).mean().item()
        l2_dist = torch.norm(online_proj1 - target_proj2, dim=-1).mean().item()
        feat_norm = torch.norm(online_proj1, dim=-1).mean().item()
        total_cos_sim += cos_sim
        total_l2_dist += l2_dist
        total_feat_norm += feat_norm
        num_val_batches += 1

    val_loss = running_val_loss / num_val_batches
    val_cos_sim = total_cos_sim / num_val_batches
    val_l2_dist = total_l2_dist / num_val_batches
    val_feat_norm = total_feat_norm / num_val_batches

    return {
        "loss": val_loss,
        "cos_sim": val_cos_sim,
        "l2_dist": val_l2_dist,
        "feat_norm": val_feat_norm,
    }


def train(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler,
    device: torch.device,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int,
    early_stopping_patience: int = 3,
    save_path: str = "best_byol.pth",
):
    """Full training loop with wandb logging, checkpointing, and early stopping."""
    best_loss = float("inf")
    epochs_no_improve = 0

    print("Start training...")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        train_metrics = train_epoch(model, optimizer, train_loader, device)
        val_metrics = validate(model, val_loader, device)

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_metrics["loss"],
            "train_cos_sim": train_metrics["cos_sim"],
            "train_l2_dist": train_metrics["l2_dist"],
            "train_feat_norm": train_metrics["feat_norm"],
            "train_grad_norm": train_metrics["grad_norm"],
            "val_loss": val_metrics["loss"],
            "val_cos_sim": val_metrics["cos_sim"],
            "val_l2_dist": val_metrics["l2_dist"],
            "val_feat_norm": val_metrics["feat_norm"],
        })

        print(
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"CosSim: {train_metrics['cos_sim']:.4f} | "
            f"L2Dist: {train_metrics['l2_dist']:.4f}"
        )
        print(
            f"Val Loss: {val_metrics['loss']:.4f} | "
            f"CosSim: {val_metrics['cos_sim']:.4f} | "
            f"L2Dist: {val_metrics['l2_dist']:.4f}"
        )

        # Save the online encoder whenever validation loss improves or the
        # cosine-similarity target is reached.
        current_val_loss = val_metrics["loss"]
        if current_val_loss < best_loss or val_metrics["cos_sim"] >= 0.86:
            best_loss = current_val_loss
            encoder_state_dict = model.online_network.encoder.state_dict()
            torch.save(encoder_state_dict, save_path)
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # The scheduler monitors validation cosine similarity (mode='max').
        scheduler.step(val_metrics["cos_sim"])

        if epochs_no_improve >= early_stopping_patience:
            print(f"Early stopping on epoch {epoch + 1}")
            break


def main(config: dict):
    wandb.init(project="contrastive_learning_byol", config=config)

    train_loader, val_loader = get_data_loaders(
        batch_size=config["batch_size"],
        num_train_samples=config["num_train_samples"],
        num_val_samples=config["num_val_samples"],
        shape_params=config["shape_params"],
    )
    model, optimizer, scheduler, device = build_model(lr=config["lr"])

    train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=config["num_epochs"],
        early_stopping_patience=config["early_stopping_patience"],
        save_path=config["save_path"],
    )

    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train BYOL model")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--num_epochs", type=int, default=15)
    parser.add_argument("--num_train_samples", type=int, default=100000)
    parser.add_argument("--num_val_samples", type=int, default=10000)
    parser.add_argument("--random_intensity", type=int, default=1)
    parser.add_argument("--early_stopping_patience", type=int, default=3)
    parser.add_argument("--save_path", type=str, default="best_byol.pth")
    args = parser.parse_args()

    config = {
        "batch_size": args.batch_size,
        "lr": args.lr,
        "num_epochs": args.num_epochs,
        "num_train_samples": args.num_train_samples,
        "num_val_samples": args.num_val_samples,
        "shape_params": {"random_intensity": bool(args.random_intensity)},
        "early_stopping_patience": args.early_stopping_patience,
        "save_path": args.save_path,
    }
    main(config)
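# Example invocation (a sketch, assuming this file is saved as train.py at the
# repository root, the `src` package is importable, and wandb is already
# authenticated; the flag values below simply restate the defaults above):
#
#   python train.py --batch_size 512 --lr 5e-4 --num_epochs 15 \
#       --num_train_samples 100000 --num_val_samples 10000 \
#       --random_intensity 1 --early_stopping_patience 3 \
#       --save_path best_byol.pth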