---
tags:
- cgan
- conditional-gan
- generative-adversarial-network
- image-generation
- deep-learning
datasets:
- MNIST
license: mit
---

# Conditional GAN Model Card

## Model Description

This is a Conditional GAN (cGAN) trained on the MNIST dataset to generate 28x28 grayscale images of handwritten digits. The generator is conditioned on a digit label (0-9), so the class of the generated image can be chosen at sampling time. The model was developed as part of the [Generative AI](https://github.com/hussamalafandi/Generative_AI) course.

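To make "conditioned on a digit label" concrete, here is a minimal sketch of one common conditioning scheme: embed the label and concatenate the embedding with the noise vector along the channel axis. This is an assumption for illustration, not the confirmed architecture of this checkpoint. `LabelConditioner` is a hypothetical name, and the sizes (10 classes, embedding dimension 50, latent dimension 100) mirror the configuration in the Usage section.

```python
import torch
import torch.nn as nn


class LabelConditioner(nn.Module):
    """Hypothetical helper: embeds a class label and appends it to the noise vector."""

    def __init__(self, num_classes=10, embed_dim=50):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)

    def forward(self, z, labels):
        # z: (N, latent_dim, 1, 1); labels: (N,) integer class ids
        e = self.embed(labels).unsqueeze(-1).unsqueeze(-1)  # (N, embed_dim, 1, 1)
        # Concatenate along the channel axis: (N, latent_dim + embed_dim, 1, 1)
        return torch.cat([z, e], dim=1)


conditioner = LabelConditioner()
z = torch.randn(4, 100, 1, 1)
labels = torch.tensor([0, 3, 7, 9])
conditioned = conditioner(z, labels)
print(conditioned.shape)  # torch.Size([4, 150, 1, 1])
```

In a scheme like this, the generator's first layer would take `latent_dim + embed_dim` input channels, and two samples that share `z` but differ in label produce different inputs.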
## Training Details

- **Dataset**: MNIST
- **Subset Size**: 60,000 images
- **Image Size**: 28x28
- **Number of Channels**: 1 (grayscale)
- **Latent Dimension**: 100
- **Generator Feature Map Size**: 64
- **Discriminator Feature Map Size**: 64
- **Batch Size**: 128
- **Epochs**: 50
- **Learning Rate**: 0.0002
- **Beta1**: 0.5
- **Weight Decay**: 0
- **Optimizer**: Adam
- **Hardware**: CUDA-enabled GPU
- **Logging**: Weights and Biases (wandb)

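As a rough illustration of how the hyperparameters above map onto PyTorch, the sketch below builds two Adam optimizers and works out the number of optimization steps. It is not the course's training script. The two `nn.Linear` modules are hypothetical stand-ins for the real networks, `BETA2 = 0.999` is an assumption (PyTorch's default, since the card lists only Beta1), and the step count assumes the last partial batch of each epoch is kept.

```python
import math

import torch
import torch.nn as nn

# Hyperparameters copied from the Training Details list
LEARNING_RATE = 2e-4
BETA1 = 0.5
BETA2 = 0.999  # assumption: PyTorch default; the card does not state beta2
WEIGHT_DECAY = 0.0
BATCH_SIZE = 128
EPOCHS = 50
NUM_IMAGES = 60_000

# Hypothetical stand-ins for the real generator and discriminator
generator = nn.Linear(100, 28 * 28)
discriminator = nn.Linear(28 * 28, 1)

opt_g = torch.optim.Adam(
    generator.parameters(),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY,
)
opt_d = torch.optim.Adam(
    discriminator.parameters(),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY,
)

# Assumes the final partial batch is kept (drop_last=False)
steps_per_epoch = math.ceil(NUM_IMAGES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
print(steps_per_epoch, total_steps)  # 469 23450
```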
### Weights and Biases Run

Training was tracked with [Weights and Biases](https://wandb.ai). The full training logs and metrics are available in the [cGAN-MNIST run](https://wandb.ai/hussam-alafandi/cGAN-MNIST/runs/w11n93e5?nw=nwuserhussamalafandi).

## Usage

### Downloading the Model from the Hub

Download the generator checkpoint from the Hub with the `huggingface_hub` library:

```python
from huggingface_hub import hf_hub_download

# Download the model checkpoint from the hub
checkpoint_path = hf_hub_download(repo_id="hussamalafandi/cGAN-MNIST", filename="generator.pth")
```

### Loading the Model Locally

```python
import torch

# Generator must be importable from a local c_gan.py module
from c_gan import Generator

# Configuration used to build the generator
config = {
    "latent_dim": 100,
    "ngf": 64,
    "nc": 1,
    "num_classes": 10,
    "embed_dim": 50,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the generator
generator = Generator(config)

# Load the downloaded checkpoint (checkpoint_path comes from the download step, or use a local path)
generator.load_state_dict(torch.load(checkpoint_path, map_location=device))

# Move the model to the device and set it to evaluation mode
generator.to(device).eval()

# Example: generate one image of the digit 7
latent_vector = torch.randn(1, config["latent_dim"], 1, 1, device=device)  # Batch size of 1
label = torch.tensor([7], device=device)  # The label must be on the same device as the model

with torch.no_grad():
    generated_image = generator(latent_vector, label)
```

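The single-image call above extends naturally to sampling one image per digit. The sketch below is a hypothetical helper, not part of the course code. It assumes the generator takes `(noise, labels)` as in the snippet above and that its output lies in [-1, 1] (a tanh output layer, which this card does not confirm). `DummyGenerator` is a stand-in so the example runs without `c_gan`; replace it with the loaded `generator`.

```python
import torch


def sample_digits(generator, labels, latent_dim=100, device="cpu"):
    """Generate one image per requested label with a (noise, labels) -> images generator."""
    labels = torch.tensor(list(labels), dtype=torch.long, device=device)
    z = torch.randn(labels.shape[0], latent_dim, 1, 1, device=device)
    with torch.no_grad():
        images = generator(z, labels)
    # Assumption: outputs are in [-1, 1]; rescale to [0, 1] for display
    return ((images + 1) / 2).clamp(0, 1)


class DummyGenerator(torch.nn.Module):
    """Hypothetical stand-in that returns random 28x28 images in [-1, 1]."""

    def forward(self, z, labels):
        return torch.tanh(torch.randn(z.shape[0], 1, 28, 28, device=z.device))


images = sample_digits(DummyGenerator(), labels=range(10))
print(images.shape)  # torch.Size([10, 1, 28, 28])
```

With the real model, pass `device=device` so that the noise and labels are created on the same device as the generator.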
## Example Results

![Generated Image](generated_image.png)

## Resources

- **Course Repository**: [Generative AI Course](https://github.com/hussamalafandi/Generative_AI)
- **WandB Run**: [cGAN-MNIST Run](https://wandb.ai/hussam-alafandi/cGAN-MNIST/runs/w11n93e5?nw=nwuserhussamalafandi)