# Variational Autoencoder (VAE) - MNIST Implementation
A comprehensive PyTorch implementation of Variational Autoencoders trained on the MNIST dataset with detailed analysis and visualizations.
## Model Description
This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 2-dimensional latent space and decode them back to reconstructed images, enabling both data compression and generation of new digit-like images.
### Architecture Details
- Model Type: Variational Autoencoder (VAE)
- Framework: PyTorch
- Input: 28×28 grayscale images (784 dimensions)
- Latent Space: 2 dimensions (for visualization)
- Hidden Layers: 256 → 128 (encoder), 128 → 256 (decoder)
- Total Parameters: ~400K
- Model Size: 1.8MB
### Key Components
- Encoder Network: Maps input images to latent distribution parameters (μ, σ²)
- Reparameterization Trick: Enables differentiable sampling from the latent distribution
- Decoder Network: Reconstructs images from latent space samples
- Loss Function: Combines reconstruction loss (binary cross-entropy) and KL divergence (a minimal sketch follows below)
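For reference, the standard VAE objective combines these two terms. The following is a minimal sketch of that loss, assuming a flattened 784-dimensional input and summed reduction; the exact formulation used in training is in the notebook:

```python
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Binary cross-entropy reconstruction term plus KL divergence to N(0, I)."""
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # Closed-form KL between N(mu, sigma^2) and the standard normal prior
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + beta * kld
```

Setting `beta` above 1.0 would weight the KL term more heavily (a β-VAE); this model uses the standard β = 1.0.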
## Training Details
- Dataset: MNIST (60,000 training images, 10,000 test images)
- Batch Size: 128
- Epochs: 20
- Optimizer: Adam
- Learning Rate: 1e-3
- Beta Parameter: 1.0 (standard VAE)
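A minimal training loop under these settings might look like the following. This is a sketch, not the notebook's exact loop; it assumes the `VAE` class from the Quick Start below and the `vae_loss` helper sketched above:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Assumes the VAE class (Quick Start below) and vae_loss (sketched above).
train_data = datasets.MNIST('./data', train=True, download=True,
                            transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

model = VAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(20):
    model.train()
    total = 0.0
    for x, _ in train_loader:
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        loss = vae_loss(recon, x, mu, logvar, beta=1.0)
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f"epoch {epoch + 1}: loss {total / len(train_data):.2f}")
```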
## Model Performance
### Metrics
- Final Training Loss: ~85.2
- Final Validation Loss: ~86.1
- Reconstruction Loss: ~83.5
- KL Divergence: ~1.7
### Capabilities
- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters
## Usage
### Quick Start
The class below is a minimal sketch matching the architecture described above; the notebook contains the full implementation, and the checkpoint's layer names must match the ones defined here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Minimal sketch of the architecture described above (784 -> 256 -> 128 -> 2);
# the full implementation, with matching layer names, is in the notebook.
class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super(VAE, self).__init__()
        self.beta = beta
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2), nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid())

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)  # z = mu + sigma * eps stays differentiable
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        h = self.encoder(x.view(-1, 784))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Load trained model (after downloading the files)
model = VAE()
model.load_state_dict(torch.load('vae_logs_latent2_beta1.0/best_vae_model.pth'))
model.eval()

# Generate new samples
with torch.no_grad():
    z = torch.randn(16, 2)  # 16 samples, 2D latent space
    generated_images = model.decode(z)
    # Reshape and visualize
    generated_images = generated_images.view(-1, 28, 28)
    # Plot the generated images...
```
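To render the sampled grid, something like the following works as a continuation of the snippet above (a sketch using matplotlib; the 4×4 layout and figure size are arbitrary choices):

```python
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, img in zip(axes.flat, generated_images):
    ax.imshow(img.numpy(), cmap='gray')  # each img is a 28x28 tensor
    ax.axis('off')
plt.tight_layout()
plt.show()
```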
## Visualizations Available
- Latent Space Visualization: 2D scatter plot showing digit clusters
- Reconstructions: Original vs. reconstructed digit comparisons
- Generated Samples: New digits sampled from the latent space
- Interpolations: Smooth transitions between different digits (see the sketch below)
- Training Curves: Loss components over training epochs
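Interpolation in a 2D latent space is just a straight line between two latent codes. A minimal sketch, assuming the `VAE` model loaded in the Quick Start; `z_start` and `z_end` are hypothetical endpoints (in practice, e.g. the encoded means of two test digits):

```python
import torch

with torch.no_grad():
    z_start = torch.tensor([[-2.0, 0.0]])  # hypothetical latent code of one digit
    z_end = torch.tensor([[2.0, 1.0]])     # hypothetical latent code of another
    steps = torch.linspace(0, 1, 10).view(-1, 1)
    z_path = z_start + steps * (z_end - z_start)  # 10 evenly spaced points
    frames = model.decode(z_path).view(-1, 28, 28)  # decode each point on the line
```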
## Files and Outputs
- `Untitled.ipynb`: Complete implementation with training and visualization
- `best_vae_model.pth`: Trained model weights
- `training_metrics.csv`: Detailed training metrics
- `generated_samples.png`: Grid of generated digit samples
- `latent_space_visualization.png`: 2D latent space plot
- `reconstruction_comparison.png`: Original vs. reconstructed images
- `latent_interpolation.png`: Interpolation between digit pairs
- `comprehensive_training_curves.png`: Training loss curves
## Applications
This VAE implementation can be used for:
- Generative Modeling: Create new handwritten digit images
- Dimensionality Reduction: Compress images to 2D representations
- Anomaly Detection: Identify unusual digits using reconstruction error (see the sketch after this list)
- Data Augmentation: Generate synthetic training data
- Representation Learning: Learn meaningful features for downstream tasks
- Educational Purposes: Understand VAE concepts and implementation
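For the anomaly-detection use case, the per-image reconstruction error can serve as an anomaly score. A minimal sketch, assuming the loaded `VAE` from the Quick Start; the threshold is a hypothetical value you would tune on held-out data:

```python
import torch
import torch.nn.functional as F

def anomaly_score(model, x):
    """Per-image binary cross-entropy between input and its reconstruction."""
    with torch.no_grad():
        recon, _, _ = model(x)
        errors = F.binary_cross_entropy(
            recon, x.view(-1, 784), reduction='none').sum(dim=1)
    return errors  # higher = more anomalous

# scores = anomaly_score(model, batch)
# flagged = scores > 90.0  # hypothetical threshold, tune on validation data
```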
## Research and Educational Value
This implementation serves as an excellent educational resource for:
- Understanding Variational Autoencoders theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization
## Citation
If you use this implementation in your research or projects, please cite:

```bibtex
@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/VariationalAutoencoders}
}
```
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Additional Resources
- GitHub Repository: VariationalAutoencoders
- Detailed Documentation: Check `grok.md` for comprehensive VAE explanations
- Training Logs: Complete metrics and analysis in the log directories
Tags: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning
Model Card Authors: Gruhesh Kurra