Upload folder using huggingface_hub
- .gitattributes +5 -0
- .ipynb_checkpoints/Untitled-checkpoint.ipynb +0 -0
- README.md +176 -0
- README_HF.md +176 -0
- Untitled.ipynb +0 -0
- data/MNIST/raw/t10k-images-idx3-ubyte +3 -0
- data/MNIST/raw/t10k-images-idx3-ubyte.gz +3 -0
- data/MNIST/raw/t10k-labels-idx1-ubyte +0 -0
- data/MNIST/raw/t10k-labels-idx1-ubyte.gz +3 -0
- data/MNIST/raw/train-images-idx3-ubyte +3 -0
- data/MNIST/raw/train-images-idx3-ubyte.gz +3 -0
- data/MNIST/raw/train-labels-idx1-ubyte +0 -0
- data/MNIST/raw/train-labels-idx1-ubyte.gz +3 -0
- grok.md +310 -0
- pytorch_vae_logs/pytorch_vae_training.log +1 -0
- vae_logs_latent2_beta1.0/.ipynb_checkpoints/training_metrics-checkpoint.csv +21 -0
- vae_logs_latent2_beta1.0/best_vae_model.pth +3 -0
- vae_logs_latent2_beta1.0/comprehensive_training_curves.png +3 -0
- vae_logs_latent2_beta1.0/generated_samples.png +3 -0
- vae_logs_latent2_beta1.0/latent_interpolation.png +0 -0
- vae_logs_latent2_beta1.0/latent_space_visualization.png +3 -0
- vae_logs_latent2_beta1.0/pytorch_vae_training.log +44 -0
- vae_logs_latent2_beta1.0/reconstruction_comparison.png +0 -0
- vae_logs_latent2_beta1.0/training_metrics.csv +21 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/MNIST/raw/t10k-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+data/MNIST/raw/train-images-idx3-ubyte filter=lfs diff=lfs merge=lfs -text
+vae_logs_latent2_beta1.0/comprehensive_training_curves.png filter=lfs diff=lfs merge=lfs -text
+vae_logs_latent2_beta1.0/generated_samples.png filter=lfs diff=lfs merge=lfs -text
+vae_logs_latent2_beta1.0/latent_space_visualization.png filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/Untitled-checkpoint.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
README.md
ADDED
@@ -0,0 +1,176 @@
---
title: Variational Autoencoder (VAE) - MNIST
emoji: 🎨
colorFrom: blue
colorTo: purple
sdk: pytorch
app_file: Untitled.ipynb
pinned: false
license: mit
tags:
- deep-learning
- generative-ai
- pytorch
- vae
- variational-autoencoder
- mnist
- computer-vision
- unsupervised-learning
- representation-learning
datasets:
- mnist
---

# Variational Autoencoder (VAE) - MNIST Implementation

A comprehensive PyTorch implementation of a Variational Autoencoder trained on the MNIST dataset, with detailed analysis and visualizations.

## Model Description

This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model encodes images into a 2-dimensional latent space and decodes them back into reconstructed images, enabling both data compression and generation of new digit-like images.

### Architecture Details

- **Model Type**: Variational Autoencoder (VAE)
- **Framework**: PyTorch
- **Input**: 28×28 grayscale images (784 dimensions)
- **Latent Space**: 2 dimensions (for visualization)
- **Hidden Layers**: 256 → 128 (encoder), 128 → 256 (decoder)
- **Total Parameters**: ~400K
- **Model Size**: 1.8 MB

### Key Components

1. **Encoder Network**: Maps input images to latent distribution parameters (μ, σ²)
2. **Reparameterization Trick**: Enables differentiable sampling from the latent distribution
3. **Decoder Network**: Reconstructs images from latent space samples
4. **Loss Function**: Combines reconstruction loss (binary cross-entropy) and KL divergence
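
As a sketch of how these components combine into the training objective (the function name and `beta` weighting are illustrative; the notebook contains the actual implementation):

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_recon, mu, log_var, beta=1.0):
    """Reconstruction term (BCE, summed over pixels) plus beta-weighted KL divergence."""
    # Binary cross-entropy between the 784-dim input and its reconstruction
    recon = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL divergence between N(mu, sigma^2) and the standard normal prior
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + beta * kl
```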

## Training Details

- **Dataset**: MNIST (60,000 training images, 10,000 test images)
- **Batch Size**: 128
- **Epochs**: 20
- **Optimizer**: Adam
- **Learning Rate**: 1e-3
- **Beta Parameter**: 1.0 (standard VAE)

## Model Performance

### Metrics
- **Final Training Loss**: ~155.7
- **Final Validation Loss**: ~148.7
- **Reconstruction Loss**: ~150.0
- **KL Divergence**: ~5.7

(Values are the epoch-20 figures from `training_metrics.csv`.)

### Capabilities
- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters

## Usage

### Quick Start

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Load the model (after downloading the files).
# This is a minimal sketch matching the architecture above (784 -> 256 -> 128 -> 2
# and back); layer names are illustrative and must match the notebook's for
# load_state_dict to succeed -- see Untitled.ipynb for the reference implementation.
class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super().__init__()
        self.beta = beta
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim // 2, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2), nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterize(self, mu, log_var):
        return mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

# Load trained model
model = VAE()
model.load_state_dict(torch.load('vae_logs_latent2_beta1.0/best_vae_model.pth'))
model.eval()

# Generate new samples
with torch.no_grad():
    z = torch.randn(16, 2)  # 16 samples, 2D latent space
    generated_images = model.decode(z)

# Reshape and visualize
generated_images = generated_images.view(-1, 28, 28)
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for img, ax in zip(generated_images, axes.flat):
    ax.imshow(img, cmap='gray')
    ax.axis('off')
plt.show()
```

### Visualizations Available

1. **Latent Space Visualization**: 2D scatter plot showing digit clusters
2. **Reconstructions**: Original vs. reconstructed digit comparisons
3. **Generated Samples**: New digits sampled from the latent space
4. **Interpolations**: Smooth transitions between different digits
5. **Training Curves**: Loss components over training epochs
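
A minimal sketch of how an interpolation like `latent_interpolation.png` can be produced, assuming the `model.decode` interface from the Quick Start above (the endpoints are illustrative):

```python
import torch

# Two illustrative latent points; in practice they come from encoding two digits
z_start = torch.tensor([-2.0, 0.0])
z_end = torch.tensor([2.0, 0.0])

with torch.no_grad():
    steps = torch.linspace(0, 1, 10).unsqueeze(1)   # 10 interpolation weights, shape (10, 1)
    z_path = z_start + steps * (z_end - z_start)    # linear path in latent space, shape (10, 2)
    frames = model.decode(z_path).view(-1, 28, 28)  # decoded 28x28 frames along the path
```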

## Files and Outputs

- `Untitled.ipynb`: Complete implementation with training and visualization
- `best_vae_model.pth`: Trained model weights
- `training_metrics.csv`: Detailed training metrics
- `generated_samples.png`: Grid of generated digit samples
- `latent_space_visualization.png`: 2D latent space plot
- `reconstruction_comparison.png`: Original vs. reconstructed images
- `latent_interpolation.png`: Interpolation between digit pairs
- `comprehensive_training_curves.png`: Training loss curves

## Applications

This VAE implementation can be used for:

- **Generative Modeling**: Create new handwritten digit images
- **Dimensionality Reduction**: Compress images to 2D representations
- **Anomaly Detection**: Identify unusual digits using reconstruction error
- **Data Augmentation**: Generate synthetic training data
- **Representation Learning**: Learn meaningful features for downstream tasks
- **Educational Purposes**: Understand VAE concepts and implementation

## Research and Educational Value

This implementation serves as an educational resource for:

- Understanding Variational Autoencoder theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization

## Citation

If you use this implementation in your research or projects, please cite:

```bibtex
@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/VariationalAutoencoders}
}
```

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Additional Resources

- **GitHub Repository**: [VariationalAutoencoders](https://github.com/GruheshKurra/VariationalAutoencoders)
- **Detailed Documentation**: See `grok.md` for comprehensive VAE explanations
- **Training Logs**: Complete metrics and analysis in the log directories

---

**Tags**: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning

**Model Card Authors**: Gruhesh Kurra
README_HF.md
ADDED
@@ -0,0 +1,176 @@
Content identical to README.md above.
Untitled.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
data/MNIST/raw/t10k-images-idx3-ubyte
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0fa7898d509279e482958e8ce81c8e77db3f2f8254e26661ceb7762c4d494ce7
size 7840016
data/MNIST/raw/t10k-images-idx3-ubyte.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6
size 1648877
data/MNIST/raw/t10k-labels-idx1-ubyte
ADDED
Binary file (10 kB)
data/MNIST/raw/t10k-labels-idx1-ubyte.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6
size 4542
data/MNIST/raw/train-images-idx3-ubyte
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba891046e6505d7aadcbbe25680a0738ad16aec93bde7f9b65e87a2fc25776db
size 47040016
data/MNIST/raw/train-images-idx3-ubyte.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609
size 9912422
data/MNIST/raw/train-labels-idx1-ubyte
ADDED
Binary file (60 kB)
data/MNIST/raw/train-labels-idx1-ubyte.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c
size 28881
grok.md
ADDED
@@ -0,0 +1,310 @@
Below is a comprehensive response to your request for learning about **Variational Autoencoders (VAEs)**, tailored for a beginner preparing for a master’s in AI with high school math and basic Python knowledge. The response follows your specified structure.

---

### 1. Simple Explanation of Variational Autoencoders (VAEs) (100–150 words)

A Variational Autoencoder (VAE) is a type of neural network used in AI to learn and generate data, such as images or text, by modeling the underlying patterns in a dataset. Think of it as a system that compresses data into a simpler, lower-dimensional "code" (latent space) and then reconstructs it. Unlike regular autoencoders, VAEs add a probabilistic twist: they learn a distribution of possible codes, allowing them to generate new, similar data. For example, a VAE trained on faces can generate new face-like images. VAEs balance two goals: reconstructing the input accurately and ensuring the latent space follows a simple distribution (like a normal distribution). This makes them powerful for tasks like image generation, denoising, or data synthesis in AI applications.

---

### 2. Detailed Flow of Variational Autoencoders (Roadmap of Key Concepts)

To fully understand VAEs, follow this logical progression of subtopics:

1. **Autoencoders Basics**:
   - Understand autoencoders: neural networks with an encoder (compresses input to a latent representation) and a decoder (reconstructs input from the latent representation).
   - Goal: Minimize reconstruction error (e.g., mean squared error between input and output).

2. **Probabilistic Modeling**:
   - Learn basic probability concepts: probability density, normal distribution, and sampling.
   - VAEs model data as coming from a probability distribution, not a single point.

3. **Latent Space and Regularization**:
   - The latent space is a lower-dimensional space where data is compressed.
   - VAEs enforce a structured latent space (e.g., normal distribution) using a regularization term.

4. **Encoder and Decoder Networks**:
   - Encoder: Maps input data to a mean and variance of a latent distribution.
   - Decoder: Reconstructs data by sampling from this distribution.

5. **Loss Function**:
   - VAEs optimize two losses:
     - **Reconstruction Loss**: Measures how well the output matches the input.
     - **KL-Divergence**: Ensures the latent distribution is close to a standard normal distribution.

6. **Reparameterization Trick**:
   - Enables backpropagation through random sampling by rephrasing the sampling process.

7. **Training and Generation**:
   - Train the VAE to balance reconstruction and regularization.
   - Generate new data by sampling from the latent space and passing it through the decoder.

8. **Applications**:
   - Explore use cases like image generation, denoising, or anomaly detection.

---

### 3. Relevant Formulas with Explanations

VAEs involve several key formulas. Below are the most important ones, with explanations of terms and their usage in AI.

1. **VAE Loss Function**:
   \[
   \mathcal{L}_{\text{VAE}} = \mathcal{L}_{\text{reconstruction}} + \mathcal{L}_{\text{KL}}
   \]
   - **Purpose**: The total loss combines reconstruction accuracy and latent space regularization.
   - **Terms**:
     - \(\mathcal{L}_{\text{reconstruction}}\): Measures how well the decoder reconstructs the input (e.g., mean squared error or binary cross-entropy).
     - \(\mathcal{L}_{\text{KL}}\): Kullback-Leibler divergence, which ensures the latent distribution is close to a standard normal distribution.
   - **AI Usage**: Balances data fidelity and generative capability.

2. **Reconstruction Loss (Mean Squared Error)**:
   \[
   \mathcal{L}_{\text{reconstruction}} = \frac{1}{N} \sum_{i=1}^N (x_i - \hat{x}_i)^2
   \]
   - **Terms**:
     - \(x_i\): Original input data (e.g., pixel values of an image).
     - \(\hat{x}_i\): Reconstructed output from the decoder.
     - \(N\): Number of data points (e.g., pixels in an image).
   - **AI Usage**: Ensures the VAE reconstructs inputs accurately, critical for tasks like image denoising.

3. **KL-Divergence**:
   \[
   \mathcal{L}_{\text{KL}} = \frac{1}{2} \sum_{j=1}^J \left( \mu_j^2 + \sigma_j^2 - \log(\sigma_j^2) - 1 \right)
   \]
   - **Terms**:
     - \(\mu_j\): Mean of the latent variable distribution for dimension \(j\).
     - \(\sigma_j\): Standard deviation of the latent variable distribution for dimension \(j\).
     - \(J\): Number of dimensions in the latent space.
   - **AI Usage**: Encourages the latent space to follow a standard normal distribution, enabling smooth data generation.

4. **Reparameterization Trick**:
   \[
   z = \mu + \sigma \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, 1)
   \]
   - **Terms**:
     - \(z\): Latent variable sampled from the distribution.
     - \(\mu\): Mean predicted by the encoder.
     - \(\sigma\): Standard deviation predicted by the encoder.
     - \(\epsilon\): Random noise sampled from a standard normal distribution.
   - **AI Usage**: Allows gradients to flow through the sampling process during training.
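
To make this concrete, a minimal PyTorch sketch (values mirror the worked example below) showing that gradients reach \(\mu\) and \(\log(\sigma^2)\) through the sampled \(z\):

```python
import torch

mu = torch.tensor([0.5, -0.3], requires_grad=True)       # encoder mean
log_var = torch.tensor([0.2, 0.4], requires_grad=True)   # encoder log-variance

eps = torch.randn(2)                     # epsilon ~ N(0, 1); no gradient needed
z = mu + torch.exp(0.5 * log_var) * eps  # z = mu + sigma * eps

z.sum().backward()
print(mu.grad, log_var.grad)  # both populated: gradients flow through the sample
```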

---

### 4. Step-by-Step Example Calculation

Let’s compute the VAE loss for a single data point, assuming a 2D latent space and a small image (4 pixels for simplicity). Suppose the input image is \(x = [0.8, 0.2, 0.6, 0.4]\).

#### Step 1: Encoder Output
The encoder predicts:
- Mean: \(\mu = [0.5, -0.3]\)
- Log-variance: \(\log(\sigma^2) = [0.2, 0.4]\)
- Compute \(\sigma\):
  \[
  \sigma_1 = \sqrt{e^{0.2}} \approx \sqrt{1.221} \approx 1.105, \quad \sigma_2 = \sqrt{e^{0.4}} \approx \sqrt{1.492} \approx 1.222
  \]
  So, \(\sigma = [1.105, 1.222]\).

#### Step 2: Sample Latent Variable (Reparameterization)
Suppose the noise drawn from \(\mathcal{N}(0, 1)\) is \(\epsilon = [0.1, -0.2]\). Compute:
\[
z_1 = 0.5 + 1.105 \cdot 0.1 = 0.5 + 0.1105 = 0.6105
\]
\[
z_2 = -0.3 + 1.222 \cdot (-0.2) = -0.3 - 0.2444 = -0.5444
\]
So, \(z = [0.6105, -0.5444]\).

#### Step 3: Decoder Output
The decoder reconstructs \(\hat{x} = [0.75, 0.25, 0.65, 0.35]\) from \(z\).

#### Step 4: Reconstruction Loss
Compute mean squared error:
\[
\mathcal{L}_{\text{reconstruction}} = \frac{1}{4} \left( (0.8 - 0.75)^2 + (0.2 - 0.25)^2 + (0.6 - 0.65)^2 + (0.4 - 0.35)^2 \right)
\]
\[
= \frac{1}{4} \left( 0.0025 + 0.0025 + 0.0025 + 0.0025 \right) = \frac{0.01}{4} = 0.0025
\]

#### Step 5: KL-Divergence
\[
\mathcal{L}_{\text{KL}} = \frac{1}{2} \left( (0.5^2 + 1.105^2 - 0.2 - 1) + ((-0.3)^2 + 1.222^2 - 0.4 - 1) \right)
\]
\[
= \frac{1}{2} \left( (0.25 + 1.221 - 0.2 - 1) + (0.09 + 1.493 - 0.4 - 1) \right)
\]
\[
= \frac{1}{2} \left( 0.271 + 0.183 \right) = \frac{0.454}{2} = 0.227
\]

#### Step 6: Total Loss
\[
\mathcal{L}_{\text{VAE}} = 0.0025 + 0.227 = 0.2295
\]

This loss is used to update the VAE’s weights during training.
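
These hand calculations can be checked in a few lines of NumPy (the printed values reproduce the numbers above up to rounding):

```python
import numpy as np

x       = np.array([0.8, 0.2, 0.6, 0.4])      # input
x_hat   = np.array([0.75, 0.25, 0.65, 0.35])  # reconstruction
mu      = np.array([0.5, -0.3])               # encoder mean
log_var = np.array([0.2, 0.4])                # encoder log-variance

recon = np.mean((x - x_hat) ** 2)                         # 0.0025
kl = 0.5 * np.sum(mu**2 + np.exp(log_var) - log_var - 1)  # ~0.227
print(recon, kl, recon + kl)                              # total ~0.2295
```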

---

### 5. Python Implementation

Below is a complete, beginner-friendly Python implementation of a VAE using the MNIST dataset (28×28 grayscale digit images). The code is designed to run in Google Colab or a local Python environment.

#### Library Installations
```bash
!pip install tensorflow matplotlib
```

#### Full Code Example
```python
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import matplotlib.pyplot as plt

# Load and preprocess the MNIST dataset (labels are not needed)
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0  # Normalize pixel values to [0, 1]
x_test = x_test.astype('float32') / 255.0
x_train = x_train.reshape(-1, 28*28)  # Flatten 28x28 images to 784-dim vectors
x_test = x_test.reshape(-1, 28*28)

# VAE parameters
original_dim = 784      # 28x28 pixels
latent_dim = 2          # 2D latent space for visualization
intermediate_dim = 256  # Hidden layer size

# Encoder
inputs = layers.Input(shape=(original_dim,))
h = layers.Dense(intermediate_dim, activation='relu')(inputs)  # Hidden layer with ReLU
z_mean = layers.Dense(latent_dim)(h)     # Mean of latent distribution
z_log_var = layers.Dense(latent_dim)(h)  # Log-variance of latent distribution

# Sampling function (reparameterization trick: z = mu + sigma * epsilon)
def sampling(args):
    z_mean, z_log_var = args
    epsilon = tf.random.normal(shape=(tf.shape(z_mean)[0], latent_dim))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

z = layers.Lambda(sampling)([z_mean, z_log_var])

# Decoder (sigmoid keeps outputs in [0, 1], matching the normalized pixels)
decoder_h = layers.Dense(intermediate_dim, activation='relu')
decoder_mean = layers.Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

# VAE model mapping input to reconstructed output
vae = Model(inputs, x_decoded_mean)

# Loss: binary cross-entropy reconstruction term plus KL-divergence regularizer
reconstruction_loss = tf.reduce_mean(
    tf.keras.losses.binary_crossentropy(inputs, x_decoded_mean)
) * original_dim
kl_loss = 0.5 * tf.reduce_sum(
    tf.square(z_mean) + tf.exp(z_log_var) - z_log_var - 1.0, axis=-1
)
vae_loss = tf.reduce_mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)  # Attach the custom loss to the model
vae.compile(optimizer='adam')

# Train the VAE (no targets are passed: the loss was attached via add_loss)
vae.fit(x_train, epochs=10, batch_size=128, validation_data=(x_test, None))

# Build a standalone generator: latent vector -> decoded image
decoder_input = layers.Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

# Decode a grid of latent-space points into digit images
n = 15  # Grid size (n x n samples)
digit_size = 28
grid_x = np.linspace(-2, 2, n)
grid_y = np.linspace(-2, 2, n)
figure = np.zeros((digit_size * n, digit_size * n))
for i, xi in enumerate(grid_x):
    for j, yi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

# Plot the grid of generated digits
plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
```

This code trains a VAE on the MNIST dataset and generates new digit images by sampling from the 2D latent space. The output is a grid of generated digits.

---

### 6. Practical AI Use Case

VAEs are widely used in **image generation and denoising**. For example, in medical imaging, VAEs can denoise MRI scans by learning to reconstruct clean images from noisy inputs. A VAE trained on a dataset of brain scans can remove noise while preserving critical details, aiding doctors in diagnosis. Another use case is **generative art**, where VAEs generate novel artworks by sampling from the latent space trained on a dataset of paintings. VAEs are also used in **anomaly detection**, such as identifying fraudulent transactions by modeling normal patterns and flagging outliers.
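
As an illustration of the anomaly-detection idea, here is a hedged PyTorch sketch: it assumes a trained model exposing the `(reconstruction, mu, log_var)` forward interface used elsewhere in this repository, and the threshold is a placeholder to be calibrated on held-out data.

```python
import torch
import torch.nn.functional as F

def anomaly_score(vae, x):
    """Per-sample reconstruction error; unusually high scores flag outliers."""
    with torch.no_grad():
        x = x.view(-1, 784)
        x_recon, mu, log_var = vae(x)
        # Sum binary cross-entropy over pixels, keeping the batch dimension
        return F.binary_cross_entropy(x_recon, x, reduction='none').sum(dim=1)

# flagged = anomaly_score(model, batch) > threshold  # threshold tuned on validation data
```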

---

### 7. Tips for Mastering Variational Autoencoders

1. **Practice Problems**:
   - Implement a VAE on a different dataset (e.g., Fashion-MNIST or CIFAR-10).
   - Experiment with different latent space dimensions (e.g., 2, 10, 20) and observe the effect on generated images.
   - Modify the loss function to use mean squared error instead of binary cross-entropy and compare results.

2. **Additional Resources**:
   - **Papers**: Read the original VAE paper by Kingma and Welling (2013) for a foundational understanding.
   - **Tutorials**: Follow TensorFlow or PyTorch VAE tutorials online (e.g., TensorFlow’s official VAE guide).
   - **Courses**: Enroll in online courses like Coursera’s “Deep Learning Specialization” by Andrew Ng, which covers VAEs.
   - **Books**: “Deep Learning” by Goodfellow, Bengio, and Courville has a chapter on generative models.

3. **Hands-On Tips**:
   - Visualize the latent space by plotting \(\mu\) values for test data to see how classes (e.g., digits) are organized.
   - Experiment with the balance between reconstruction and KL-divergence losses by adding a weighting factor (e.g., \(\beta\)-VAE), as sketched just after this list.
   - Use Google Colab to run experiments with GPUs for faster training.
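
The \(\beta\)-weighting mentioned above is a one-line change to the total loss. A minimal sketch, where `recon_loss` and `kl_loss` stand for the two terms defined in Section 3:

```python
def beta_vae_loss(recon_loss, kl_loss, beta=4.0):
    # beta = 1.0 recovers the standard VAE objective; beta > 1 pushes the latent
    # space toward the prior (more disentangled, at some cost in reconstruction)
    return recon_loss + beta * kl_loss
```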

---

This response provides a beginner-friendly, structured introduction to VAEs, complete with formulas, calculations, and a working Python implementation. Let me know if you need further clarification or additional details!
pytorch_vae_logs/pytorch_vae_training.log
ADDED
@@ -0,0 +1 @@
2025-07-14 09:17:49,540 - INFO - PyTorch VAE Logger initialized - Device: mps
vae_logs_latent2_beta1.0/.ipynb_checkpoints/training_metrics-checkpoint.csv
ADDED
@@ -0,0 +1,21 @@
epoch,train_loss,train_recon_loss,train_kl_loss,val_loss,val_recon_loss,val_kl_loss,learning_rate,epoch_time
1,188.35515491536458,184.45106513671874,3.9040897810618085,164.151391796875,159.60658376464843,4.544807939147949,0.001,5.432532072067261
2,166.28923030598958,161.43092807617188,4.85830191599528,158.2647181640625,153.32461416015624,4.940103398132324,0.001,3.6735332012176514
3,163.25960035807293,158.17172805989583,5.087872340901693,156.55175959472658,151.44342153320312,5.108337896728516,0.001,3.839945077896118
4,162.0263754231771,156.81538173828125,5.210993135070801,154.87307084960938,149.6405438964844,5.232526518249512,0.001,3.733858823776245
5,160.9149369140625,155.62400013020834,5.2909370697021485,153.78888503417969,148.52310532226562,5.265780213928223,0.001,3.67457914352417
6,160.21662462565104,154.8810202311198,5.335604771931966,153.2333844482422,147.92949780273437,5.3038869445800785,0.001,3.6811368465423584
7,159.56022389322916,154.16078053385417,5.39944333190918,152.96496821289062,147.54441083984375,5.420557695007324,0.001,3.7990362644195557
8,159.04838385416667,153.60747981770834,5.440903743489583,152.15937470703125,146.76362829589843,5.395746961212158,0.001,3.682948112487793
9,158.70791236979167,153.25011666666666,5.457795530192057,151.66011083984375,146.22637351074218,5.433736952972412,0.001,3.695668935775757
10,158.2189732096354,152.72582350260416,5.493149665323894,151.2625805908203,145.85069638671874,5.411884022521972,0.001,3.7226319313049316
11,157.9163614908854,152.4047295247396,5.511631852213542,150.96995981445312,145.50289182128907,5.467067713928222,0.001,3.7245447635650635
12,157.47360802408855,151.924595703125,5.549012482706706,150.46631416015626,144.97489509277344,5.491418801879883,0.001,3.7169361114501953
13,157.20434516601563,151.63325626627605,5.571089208984375,150.01544140625,144.47538759765624,5.540053869628906,0.001,3.750225067138672
14,156.99553050130208,151.4061220377604,5.58940846862793,149.91740380859375,144.36160576171875,5.555798022460937,0.001,3.7699289321899414
15,156.69896909179687,151.05847415364585,5.6404951171875,150.20377961425783,144.6526451904297,5.5511338394165035,0.001,4.046542167663574
16,156.41999225260417,150.7703130859375,5.649679092407227,149.6098265625,144.01771640625,5.592109250640869,0.001,4.275192737579346
17,156.3925813639323,150.72934807942707,5.663233009847005,149.50000378417968,143.92094489746094,5.579058618164063,0.001,3.984846830368042
18,156.17917737630208,150.50347294921875,5.675704354858398,148.98313510742187,143.31297673339844,5.670158586120605,0.001,3.6668288707733154
19,155.92019244791666,150.23931930338543,5.680872892252604,149.0982058105469,143.44878134765625,5.649423721313476,0.001,3.7049410343170166
20,155.6999745768229,150.00538517252605,5.694589482625325,148.69115900878907,143.03701298828125,5.654146389007568,0.001,3.6872589588165283
vae_logs_latent2_beta1.0/best_vae_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b49f386826b8dbb6673936aa986ef5d7f3735f0bad0b4926e31ade72cbccaa3d
size 1900949
vae_logs_latent2_beta1.0/comprehensive_training_curves.png
ADDED
(image, stored with Git LFS)
vae_logs_latent2_beta1.0/generated_samples.png
ADDED
(image, stored with Git LFS)
vae_logs_latent2_beta1.0/latent_interpolation.png
ADDED
(image)
vae_logs_latent2_beta1.0/latent_space_visualization.png
ADDED
(image, stored with Git LFS)
vae_logs_latent2_beta1.0/pytorch_vae_training.log
ADDED
@@ -0,0 +1,44 @@
2025-07-14 09:20:28,117 - INFO - PyTorch VAE Logger initialized - Device: mps
2025-07-14 09:20:33,858 - INFO - Epoch 1 | Total Loss: 188.3552 | Recon: 184.4511 | KL: 3.9041 | Val Loss: 164.1514 | Time: 5.43s
2025-07-14 09:20:37,848 - INFO - Epoch 2 | Total Loss: 166.2892 | Recon: 161.4309 | KL: 4.8583 | Val Loss: 158.2647 | Time: 3.67s
2025-07-14 09:20:41,977 - INFO - Epoch 3 | Total Loss: 163.2596 | Recon: 158.1717 | KL: 5.0879 | Val Loss: 156.5518 | Time: 3.84s
2025-07-14 09:20:46,005 - INFO - Epoch 4 | Total Loss: 162.0264 | Recon: 156.8154 | KL: 5.2110 | Val Loss: 154.8731 | Time: 3.73s
2025-07-14 09:20:49,970 - INFO - Epoch 5 | Total Loss: 160.9149 | Recon: 155.6240 | KL: 5.2909 | Val Loss: 153.7889 | Time: 3.67s
2025-07-14 09:20:53,940 - INFO - Epoch 6 | Total Loss: 160.2166 | Recon: 154.8810 | KL: 5.3356 | Val Loss: 153.2334 | Time: 3.68s
2025-07-14 09:20:58,032 - INFO - Epoch 7 | Total Loss: 159.5602 | Recon: 154.1608 | KL: 5.3994 | Val Loss: 152.9650 | Time: 3.80s
2025-07-14 09:21:02,005 - INFO - Epoch 8 | Total Loss: 159.0484 | Recon: 153.6075 | KL: 5.4409 | Val Loss: 152.1594 | Time: 3.68s
2025-07-14 09:21:05,990 - INFO - Epoch 9 | Total Loss: 158.7079 | Recon: 153.2501 | KL: 5.4578 | Val Loss: 151.6601 | Time: 3.70s
2025-07-14 09:21:10,002 - INFO - Epoch 10 | Total Loss: 158.2190 | Recon: 152.7258 | KL: 5.4931 | Val Loss: 151.2626 | Time: 3.72s
2025-07-14 09:21:14,025 - INFO - Epoch 11 | Total Loss: 157.9164 | Recon: 152.4047 | KL: 5.5116 | Val Loss: 150.9700 | Time: 3.72s
2025-07-14 09:21:18,032 - INFO - Epoch 12 | Total Loss: 157.4736 | Recon: 151.9246 | KL: 5.5490 | Val Loss: 150.4663 | Time: 3.72s
2025-07-14 09:21:22,097 - INFO - Epoch 13 | Total Loss: 157.2043 | Recon: 151.6333 | KL: 5.5711 | Val Loss: 150.0154 | Time: 3.75s
2025-07-14 09:21:26,156 - INFO - Epoch 14 | Total Loss: 156.9955 | Recon: 151.4061 | KL: 5.5894 | Val Loss: 149.9174 | Time: 3.77s
2025-07-14 09:21:30,548 - INFO - Epoch 15 | Total Loss: 156.6990 | Recon: 151.0585 | KL: 5.6405 | Val Loss: 150.2038 | Time: 4.05s
2025-07-14 09:21:35,110 - INFO - Epoch 16 | Total Loss: 156.4200 | Recon: 150.7703 | KL: 5.6497 | Val Loss: 149.6098 | Time: 4.28s
2025-07-14 09:21:39,387 - INFO - Epoch 17 | Total Loss: 156.3926 | Recon: 150.7293 | KL: 5.6632 | Val Loss: 149.5000 | Time: 3.98s
2025-07-14 09:21:43,344 - INFO - Epoch 18 | Total Loss: 156.1792 | Recon: 150.5035 | KL: 5.6757 | Val Loss: 148.9831 | Time: 3.67s
2025-07-14 09:21:47,338 - INFO - Epoch 19 | Total Loss: 155.9202 | Recon: 150.2393 | KL: 5.6809 | Val Loss: 149.0982 | Time: 3.70s
2025-07-14 09:21:51,308 - INFO - Epoch 20 | Total Loss: 155.7000 | Recon: 150.0054 | KL: 5.6946 | Val Loss: 148.6912 | Time: 3.69s
2025-07-14 09:22:14,046 - ERROR - No such comm target registered: jupyter.widget.control
2025-07-14 09:22:14,048 - WARNING - No such comm: 7b9713dd-3a94-42c7-9ed2-811713e7436f
2025-07-14 09:22:23,562 - INFO - PyTorch VAE Logger initialized - Device: mps
2025-07-14 09:22:27,619 - INFO - Epoch 1 | Total Loss: 188.3552 | Recon: 184.4511 | KL: 3.9041 | Val Loss: 164.1514 | Time: 3.77s
2025-07-14 09:22:31,661 - INFO - Epoch 2 | Total Loss: 166.2892 | Recon: 161.4309 | KL: 4.8583 | Val Loss: 158.2647 | Time: 3.75s
2025-07-14 09:22:35,755 - INFO - Epoch 3 | Total Loss: 163.2596 | Recon: 158.1717 | KL: 5.0879 | Val Loss: 156.5518 | Time: 3.80s
2025-07-14 09:22:39,954 - INFO - Epoch 4 | Total Loss: 162.0264 | Recon: 156.8154 | KL: 5.2110 | Val Loss: 154.8731 | Time: 3.90s
2025-07-14 09:22:44,328 - INFO - Epoch 5 | Total Loss: 160.9149 | Recon: 155.6240 | KL: 5.2909 | Val Loss: 153.7889 | Time: 4.04s
2025-07-14 09:22:49,933 - INFO - Epoch 6 | Total Loss: 160.2166 | Recon: 154.8810 | KL: 5.3356 | Val Loss: 153.2334 | Time: 5.18s
2025-07-14 09:22:55,510 - INFO - Epoch 7 | Total Loss: 159.5602 | Recon: 154.1608 | KL: 5.3994 | Val Loss: 152.9650 | Time: 5.11s
2025-07-14 09:23:01,493 - INFO - Epoch 8 | Total Loss: 159.0484 | Recon: 153.6075 | KL: 5.4409 | Val Loss: 152.1594 | Time: 5.51s
2025-07-14 09:23:07,288 - INFO - Epoch 9 | Total Loss: 158.7079 | Recon: 153.2501 | KL: 5.4578 | Val Loss: 151.6601 | Time: 5.35s
2025-07-14 09:23:13,055 - INFO - Epoch 10 | Total Loss: 158.2190 | Recon: 152.7258 | KL: 5.4931 | Val Loss: 151.2626 | Time: 5.37s
2025-07-14 09:23:18,909 - INFO - Epoch 11 | Total Loss: 157.9164 | Recon: 152.4047 | KL: 5.5116 | Val Loss: 150.9700 | Time: 5.37s
2025-07-14 09:23:24,246 - INFO - Epoch 12 | Total Loss: 157.4736 | Recon: 151.9246 | KL: 5.5490 | Val Loss: 150.4663 | Time: 5.01s
2025-07-14 09:23:28,706 - INFO - Epoch 13 | Total Loss: 157.2043 | Recon: 151.6333 | KL: 5.5711 | Val Loss: 150.0154 | Time: 4.16s
2025-07-14 09:23:32,893 - INFO - Epoch 14 | Total Loss: 156.9955 | Recon: 151.4061 | KL: 5.5894 | Val Loss: 149.9174 | Time: 3.84s
2025-07-14 09:23:37,613 - INFO - Epoch 15 | Total Loss: 156.6990 | Recon: 151.0585 | KL: 5.6405 | Val Loss: 150.2038 | Time: 4.35s
2025-07-14 09:23:42,356 - INFO - Epoch 16 | Total Loss: 156.4200 | Recon: 150.7703 | KL: 5.6497 | Val Loss: 149.6098 | Time: 4.43s
2025-07-14 09:23:46,511 - INFO - Epoch 17 | Total Loss: 156.3926 | Recon: 150.7293 | KL: 5.6632 | Val Loss: 149.5000 | Time: 3.86s
2025-07-14 09:23:50,521 - INFO - Epoch 18 | Total Loss: 156.1792 | Recon: 150.5035 | KL: 5.6757 | Val Loss: 148.9831 | Time: 3.71s
2025-07-14 09:23:54,803 - INFO - Epoch 19 | Total Loss: 155.9202 | Recon: 150.2393 | KL: 5.6809 | Val Loss: 149.0982 | Time: 3.94s
2025-07-14 09:23:59,161 - INFO - Epoch 20 | Total Loss: 155.7000 | Recon: 150.0054 | KL: 5.6946 | Val Loss: 148.6912 | Time: 4.05s
vae_logs_latent2_beta1.0/reconstruction_comparison.png
ADDED
(image)
vae_logs_latent2_beta1.0/training_metrics.csv
ADDED
@@ -0,0 +1,21 @@
epoch,train_loss,train_recon_loss,train_kl_loss,val_loss,val_recon_loss,val_kl_loss,learning_rate,epoch_time
1,188.35515491536458,184.45106513671874,3.9040897810618085,164.151391796875,159.60658376464843,4.544807939147949,0.001,3.7716939449310303
2,166.28923030598958,161.43092807617188,4.85830191599528,158.2647181640625,153.32461416015624,4.940103398132324,0.001,3.7459120750427246
3,163.25960035807293,158.17172805989583,5.087872340901693,156.55175959472658,151.44342153320312,5.108337896728516,0.001,3.7995049953460693
4,162.0263754231771,156.81538173828125,5.210993135070801,154.87307084960938,149.6405438964844,5.232526518249512,0.001,3.9026689529418945
5,160.9149369140625,155.62400013020834,5.2909370697021485,153.78888503417969,148.52310532226562,5.265780213928223,0.001,4.043428897857666
6,160.21662462565104,154.8810202311198,5.335604771931966,153.2333844482422,147.92949780273437,5.3038869445800785,0.001,5.177571058273315
7,159.56022389322916,154.16078053385417,5.39944333190918,152.96496821289062,147.54441083984375,5.420557695007324,0.001,5.113352060317993
8,159.04838385416667,153.60747981770834,5.440903743489583,152.15937470703125,146.76362829589843,5.395746961212158,0.001,5.509716987609863
9,158.70791236979167,153.25011666666666,5.457795530192057,151.66011083984375,146.22637351074218,5.433736952972412,0.001,5.350080966949463
10,158.2189732096354,152.72582350260416,5.493149665323894,151.2625805908203,145.85069638671874,5.411884022521972,0.001,5.374446153640747
11,157.9163614908854,152.4047295247396,5.511631852213542,150.96995981445312,145.50289182128907,5.467067713928222,0.001,5.369218826293945
12,157.47360802408855,151.924595703125,5.549012482706706,150.46631416015626,144.97489509277344,5.491418801879883,0.001,5.005874156951904
13,157.20434516601563,151.63325626627605,5.571089208984375,150.01544140625,144.47538759765624,5.540053869628906,0.001,4.162843942642212
14,156.99553050130208,151.4061220377604,5.58940846862793,149.91740380859375,144.36160576171875,5.555798022460937,0.001,3.8413190841674805
15,156.69896909179687,151.05847415364585,5.6404951171875,150.20377961425783,144.6526451904297,5.5511338394165035,0.001,4.347658157348633
16,156.41999225260417,150.7703130859375,5.649679092407227,149.6098265625,144.01771640625,5.592109250640869,0.001,4.433474779129028
17,156.3925813639323,150.72934807942707,5.663233009847005,149.50000378417968,143.92094489746094,5.579058618164063,0.001,3.8597419261932373
18,156.17917737630208,150.50347294921875,5.675704354858398,148.98313510742187,143.31297673339844,5.670158586120605,0.001,3.713671922683716
19,155.92019244791666,150.23931930338543,5.680872892252604,149.0982058105469,143.44878134765625,5.649423721313476,0.001,3.9440419673919678
20,155.6999745768229,150.00538517252605,5.694589482625325,148.69115900878907,143.03701298828125,5.654146389007568,0.001,4.0462260246276855