"""Command-line tool for sampling from trained SmoothDiffusionUNet checkpoints.

Modes:
  * ``--checkpoint PATH``: sample from one checkpoint with ``--method``.
  * ``--log_dir DIR``: sample from every ``model_epoch_<N>.pth`` in DIR.
  * neither: compare the checkpoints in the lexicographically last
    sub-directory of ``./logs``.
"""

import argparse
import inspect
import os
from datetime import datetime

import torch
import torchvision
from torchvision.utils import save_image, make_grid

from config import Config
from model import SmoothDiffusionUNet
from noise_scheduler import FrequencyAwareNoise
from sample import frequency_aware_sample, progressive_frequency_sample, aggressive_frequency_sample


def _torch_load_checkpoint(checkpoint_path, device):
    """Read *checkpoint_path* with ``torch.load``, mapping tensors onto *device*.

    Training checkpoints may store the config object itself under ``'config'``
    (see ``load_model``). PyTorch >= 2.6 defaults to ``weights_only=True``,
    which refuses to unpickle arbitrary Python objects, so full unpickling is
    requested explicitly. SECURITY: full unpickling can execute arbitrary
    code - only load checkpoint files you trust.
    """
    load_kwargs = {'map_location': device}
    # ``weights_only`` only exists on newer PyTorch; older versions always unpickle fully.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(checkpoint_path, **load_kwargs)


def load_model(checkpoint_path, device):
    """Load model from checkpoint.

    Args:
        checkpoint_path: File written by ``torch.save``. Either a dict holding
            ``'model_state_dict'`` (plus optional ``'config'``, ``'epoch'``,
            ``'loss'``) or a bare state dict (final model).
        device: ``torch.device`` to map tensors to and move the model onto.

    Returns:
        ``(model, noise_scheduler, config)``; the model is in eval mode.
    """
    print(f"Loading model from: {checkpoint_path}")

    checkpoint = _torch_load_checkpoint(checkpoint_path, device)

    # Prefer the config saved alongside the weights so the architecture matches.
    if 'config' in checkpoint:
        config = checkpoint['config']
    else:
        config = Config()  # Fallback to default config

    model = SmoothDiffusionUNet(config).to(device)
    noise_scheduler = FrequencyAwareNoise(config)

    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
        epoch = checkpoint.get('epoch', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        print(f"Loaded model from epoch {epoch}, loss: {loss}")
    else:
        # Handle simple state dict (final model)
        model.load_state_dict(checkpoint)
        print("Loaded model state dict")

    # This module only uses the model for sampling; switch off training-mode
    # behavior (dropout, batch-norm statistic updates) if the model has any.
    model.eval()

    return model, noise_scheduler, config


def generate_samples(model, noise_scheduler, config, device, n_samples=16, save_path=None):
    """Generate samples using the frequency-aware approach.

    Args:
        model: Denoising model passed to ``frequency_aware_sample``.
        noise_scheduler: Scheduler passed to ``frequency_aware_sample``.
        config: Accepted for interface compatibility; not used here.
        device: ``torch.device`` used for sampling.
        n_samples: Number of samples to generate.
        save_path: If given, the sample grid is written there.

    Returns:
        ``(samples, grid)`` as returned by ``frequency_aware_sample``.
    """
    print(f"Generating {n_samples} samples using frequency-aware sampling...")

    samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=n_samples)

    print(f"Final samples range: [{samples.min().item():.3f}, {samples.max().item():.3f}]")

    if save_path:
        save_image(grid, save_path, normalize=False)
        print(f"Samples saved to: {save_path}")

    return samples, grid


def _sample_with_method(model, noise_scheduler, device, n_samples, method):
    """Run the sampler named by *method* and return ``(samples, grid, label)``.

    'progressive' and 'aggressive' select their samplers; any other value falls
    back to the optimized frequency-aware sampler. *label* is the name of the
    sampler actually used ('progressive', 'aggressive' or 'optimized').
    """
    if method == 'progressive':
        print("Using progressive frequency sampling...")
        samples, grid = progressive_frequency_sample(model, noise_scheduler, device, n_samples=n_samples)
        return samples, grid, 'progressive'
    if method == 'aggressive':
        print("Using aggressive frequency sampling...")
        samples, grid = aggressive_frequency_sample(model, noise_scheduler, device, n_samples=n_samples)
        return samples, grid, 'aggressive'
    print("Using optimized frequency-aware sampling...")
    samples, grid = frequency_aware_sample(model, noise_scheduler, device, n_samples=n_samples)
    return samples, grid, 'optimized'


def _parse_checkpoint_epoch(filename):
    """Return the epoch encoded in a ``model_epoch_<N>...`` filename, or None.

    The epoch is the third '_'-separated field, cut at its first '.'
    (``model_epoch_10.pth`` -> 10). A non-integer field such as
    ``model_epoch_best.pth`` yields None instead of raising.
    """
    fields = filename.split('_')
    if len(fields) < 3:
        return None
    try:
        return int(fields[2].split('.')[0])
    except ValueError:
        return None


def compare_checkpoints(log_dir, device, n_samples=8, method='optimized'):
    """Compare samples from different checkpoints.

    Every ``model_epoch_<N>.pth`` file in *log_dir* is loaded and sampled, in
    epoch order. Each grid is saved as ``test_samples_epoch_<N>.png`` inside
    *log_dir*. A failure on one checkpoint is reported, and the remaining
    checkpoints are still processed.

    Args:
        log_dir: Directory containing the checkpoint files.
        device: ``torch.device`` used for loading and sampling.
        n_samples: Number of samples per checkpoint.
        method: 'optimized' (default), 'progressive' or 'aggressive'.
    """
    print(f"Comparing checkpoints in: {log_dir}")

    if not os.path.isdir(log_dir):
        print(f"Log directory not found: {log_dir}")
        return

    checkpoint_files = []
    for filename in os.listdir(log_dir):
        if not (filename.startswith('model_epoch_') and filename.endswith('.pth')):
            continue
        epoch = _parse_checkpoint_epoch(filename)
        if epoch is None:
            print(f"Skipping checkpoint with unrecognized epoch: {filename}")
            continue
        checkpoint_files.append((epoch, filename))

    # Sort by epoch
    checkpoint_files.sort()

    if not checkpoint_files:
        print("No checkpoint files found!")
        return

    print(f"Found {len(checkpoint_files)} checkpoints")

    epochs = []
    for epoch, filename in checkpoint_files:
        print(f"\n--- Testing Epoch {epoch} ---")
        checkpoint_path = os.path.join(log_dir, filename)
        try:
            model, noise_scheduler, _config = load_model(checkpoint_path, device)
            _samples, grid, _label = _sample_with_method(model, noise_scheduler, device, n_samples, method)
            save_path = os.path.join(log_dir, f"test_samples_epoch_{epoch}.png")
            save_image(grid, save_path, normalize=False)
        except Exception as e:
            # Best effort: one broken checkpoint must not abort the whole comparison.
            print(f"Error testing epoch {epoch}: {e}")
            continue
        epochs.append(epoch)

    if epochs:
        print(f"Generated samples for {len(epochs)} epochs: {epochs}")
        print("Individual epoch samples saved in log directory")
        print("Note: Matplotlib comparison disabled due to NumPy compatibility issues")
    else:
        print("No samples were generated for any checkpoint.")


def test_single_checkpoint(checkpoint_path, device, n_samples=16, method='optimized'):
    """Test a single checkpoint with the chosen sampling method.

    The grid is saved in the current directory as
    ``test_samples_<method>_<YYYYmmdd_HHMMSS>.png``. Unknown method names use
    the optimized sampler and the 'optimized' filename label.

    Returns:
        ``(samples, grid)`` from the chosen sampler.
    """
    model, noise_scheduler, _config = load_model(checkpoint_path, device)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    samples, grid, label = _sample_with_method(model, noise_scheduler, device, n_samples, method)
    save_path = f"test_samples_{label}_{timestamp}.png"

    save_image(grid, save_path, normalize=False)
    print(f"Samples saved to: {save_path}")

    return samples, grid


def main():
    """Parse command-line arguments and run the selected test mode."""
    parser = argparse.ArgumentParser(description='Test trained diffusion model')
    parser.add_argument('--checkpoint', type=str, help='Path to specific checkpoint file')
    parser.add_argument('--log_dir', type=str, help='Path to log directory (for comparing all checkpoints)')
    parser.add_argument('--n_samples', type=int, default=16, help='Number of samples to generate')
    parser.add_argument('--device', type=str, default='auto', help='Device to use (cuda/cpu/auto)')
    parser.add_argument('--method', type=str, default='optimized',
                        choices=['optimized', 'progressive', 'aggressive'],
                        help='Sampling method: optimized (adaptive), progressive (fewer steps), or aggressive (strong denoising)')

    args = parser.parse_args()

    if args.device == 'auto':
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    print(f"Using device: {device}")

    if args.checkpoint:
        print("=== Testing Single Checkpoint ===")
        test_single_checkpoint(args.checkpoint, device, args.n_samples, args.method)
    elif args.log_dir:
        print("=== Comparing All Checkpoints ===")
        compare_checkpoints(args.log_dir, device, args.n_samples, args.method)
    else:
        # Interactive mode - pick the lexicographically last run directory in ./logs.
        logs_root = './logs'
        log_dirs = []
        if os.path.isdir(logs_root):
            log_dirs = [item for item in os.listdir(logs_root)
                        if os.path.isdir(os.path.join(logs_root, item))]

        if log_dirs:
            latest_log = max(log_dirs)
            log_path = os.path.join(logs_root, latest_log)
            print(f"Found latest log directory: {log_path}")
            print("=== Comparing All Checkpoints in Latest Run ===")
            compare_checkpoints(log_path, device, args.n_samples, args.method)
        else:
            print("No log directories found. Please specify --checkpoint or --log_dir")


if __name__ == "__main__":
    main()