Variational Autoencoder (VAE) - MNIST Implementation

A comprehensive PyTorch implementation of Variational Autoencoders trained on the MNIST dataset with detailed analysis and visualizations.

Model Description

This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 2-dimensional latent space and decode them back to reconstructed images, enabling both data compression and generation of new digit-like images.

Architecture Details

  • Model Type: Variational Autoencoder (VAE)
  • Framework: PyTorch
  • Input: 28×28 grayscale images (784 dimensions)
  • Latent Space: 2 dimensions (for visualization)
  • Hidden Layers: 256 → 128 (encoder), 128 → 256 (decoder)
  • Total Parameters: ~400K
  • Model Size: 1.8MB

Key Components

  1. Encoder Network: Maps input images to latent distribution parameters (μ, σ²)
  2. Reparameterization Trick: Enables differentiable sampling from the latent distribution
  3. Decoder Network: Reconstructs images from latent space samples
  4. Loss Function: Combines reconstruction loss (binary cross-entropy) and KL divergence; a sketch follows this list
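
A minimal sketch of the loss function (item 4), assuming the standard formulation; the function name vae_loss and the reduction choice are illustrative, and the exact code is in the notebook:

import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    # Reconstruction term: binary cross-entropy summed over all pixels
    bce = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # beta = 1.0 recovers the standard VAE objective
    return bce + beta * kld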

Training Details

  • Dataset: MNIST (60,000 training images, 10,000 test images)
  • Batch Size: 128
  • Epochs: 20
  • Optimizer: Adam
  • Learning Rate: 1e-3
  • Beta Parameter: 1.0 (standard VAE)
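
A condensed training loop consistent with these settings (data loading and the flattening step are assumptions; the complete loop lives in the notebook). It uses the vae_loss sketch above and the VAE class defined in the Quick Start below:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# MNIST pixels in [0, 1] match the binary cross-entropy reconstruction loss
train_data = datasets.MNIST('./data', train=True, download=True,
                            transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    for x, _ in train_loader:
        x = x.view(-1, 784)  # flatten 28x28 images to 784-dim vectors
        recon_x, mu, logvar = model(x)
        loss = vae_loss(recon_x, x, mu, logvar, beta=1.0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()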

Model Performance

Metrics

  • Final Training Loss: ~85.2
  • Final Validation Loss: ~86.1
  • Reconstruction Loss: ~83.5
  • KL Divergence: ~1.7

Capabilities

  • ✅ High-quality digit reconstruction
  • ✅ Smooth latent space interpolation
  • ✅ Generation of new digit-like samples
  • ✅ Well-organized latent space with digit clusters

Usage

Quick Start

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Load the model (after downloading the files)
class VAE(nn.Module):
    # Sketch matching the documented architecture (784 -> 256 -> 128 -> 2);
    # the authoritative definition is in the notebook, and its layer names
    # must match this class for load_state_dict to succeed.
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super(VAE, self).__init__()
        self.beta = beta
        self.enc_fc1 = nn.Linear(input_dim, hidden_dim)          # 784 -> 256
        self.enc_fc2 = nn.Linear(hidden_dim, hidden_dim // 2)    # 256 -> 128
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)      # 128 -> 2
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)  # 128 -> 2
        self.dec_fc1 = nn.Linear(latent_dim, hidden_dim // 2)    # 2 -> 128
        self.dec_fc2 = nn.Linear(hidden_dim // 2, hidden_dim)    # 128 -> 256
        self.dec_out = nn.Linear(hidden_dim, input_dim)          # 256 -> 784

    def encode(self, x):
        h = F.relu(self.enc_fc2(F.relu(self.enc_fc1(x))))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with eps ~ N(0, I): differentiable sampling
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        h = F.relu(self.dec_fc2(F.relu(self.dec_fc1(z))))
        return torch.sigmoid(self.dec_out(h))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Load trained model
model = VAE()
model.load_state_dict(torch.load('vae_logs_latent2_beta1.0/best_vae_model.pth'))
model.eval()

# Generate new samples
with torch.no_grad():
    # Sample from latent space
    z = torch.randn(16, 2)  # 16 samples, 2D latent space
    generated_images = model.decode(z)
    
    # Reshape and visualize
    generated_images = generated_images.view(-1, 28, 28)
    # Plot the generated images...
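
One way to finish the plotting step above, continuing from the Quick Start variables (the 4x4 grid layout is an arbitrary choice):

# Display the 16 generated digits in a 4x4 grid
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for img, ax in zip(generated_images, axes.flatten()):
    ax.imshow(img.numpy(), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()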

Visualizations Available

  1. Latent Space Visualization: 2D scatter plot showing digit clusters (a sketch follows this list)
  2. Reconstructions: Original vs. reconstructed digit comparisons
  3. Generated Samples: New digits sampled from the latent space
  4. Interpolations: Smooth transitions between different digits
  5. Training Curves: Loss components over training epochs
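
For the latent space plot (item 1), a sketch of the usual approach, reusing the Quick Start imports and model: encode the MNIST test split and color each 2D latent mean by its digit label. Styling in the notebook may differ:

# Encode the test set and scatter the 2D latent means, colored by digit
test_data = datasets.MNIST('./data', train=False, download=True,
                           transform=transforms.ToTensor())
x = test_data.data.view(-1, 784).float() / 255.0  # scale pixels to [0, 1]
with torch.no_grad():
    mu, _ = model.encode(x)
plt.figure(figsize=(7, 6))
plt.scatter(mu[:, 0], mu[:, 1], c=test_data.targets, cmap='tab10', s=3)
plt.colorbar(label='digit')
plt.xlabel('z[0]')
plt.ylabel('z[1]')
plt.show()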

Files and Outputs

  • Untitled.ipynb: Complete implementation with training and visualization
  • best_vae_model.pth: Trained model weights
  • training_metrics.csv: Detailed training metrics
  • generated_samples.png: Grid of generated digit samples
  • latent_space_visualization.png: 2D latent space plot
  • reconstruction_comparison.png: Original vs reconstructed images
  • latent_interpolation.png: Interpolation between digit pairs
  • comprehensive_training_curves.png: Training loss curves

Applications

This VAE implementation can be used for:

  • Generative Modeling: Create new handwritten digit images
  • Dimensionality Reduction: Compress images to 2D representations
  • Anomaly Detection: Identify unusual digits using reconstruction error (see the sketch after this list)
  • Data Augmentation: Generate synthetic training data
  • Representation Learning: Learn meaningful features for downstream tasks
  • Educational Purposes: Understand VAE concepts and implementation
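
A minimal sketch of the anomaly-detection idea, assuming per-image binary cross-entropy as the score; the threshold choice is application-specific, and x is a batch of flattened images in [0, 1] such as the test set prepared above:

import torch
import torch.nn.functional as F

def reconstruction_error(model, x):
    # Per-image BCE between input and reconstruction; higher = more unusual
    with torch.no_grad():
        recon_x, _, _ = model(x)
        return F.binary_cross_entropy(recon_x, x, reduction='none').sum(dim=1)

# Example: flag the top 1% of images by reconstruction error
scores = reconstruction_error(model, x)
threshold = torch.quantile(scores, 0.99)
anomalies = (scores > threshold).nonzero().squeeze()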

Research and Educational Value

This implementation serves as an excellent educational resource for:

  • Understanding Variational Autoencoders theory and practice
  • Learning PyTorch implementation techniques
  • Exploring generative modeling concepts
  • Analyzing latent space representations
  • Studying the balance between reconstruction and regularization

Citation

If you use this implementation in your research or projects, please cite:

@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/VariationalAutoencoders}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.

Additional Resources

  • GitHub Repository: VariationalAutoencoders
  • Detailed Documentation: Check grok.md for comprehensive VAE explanations
  • Training Logs: Complete metrics and analysis in the log directories

Tags: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning

Model Card Authors: Gruhesh Kurra
