This is a Conditional GAN (cGAN) trained on the MNIST dataset to generate realistic 28×28 grayscale images of handwritten digits. The generator is conditioned on the digit label (0–9), so you choose which digit it produces. The model was developed as part of the Generative AI course.
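A common way for a cGAN to use the label is to look up an embedding vector for it and concatenate that vector with the noise vector before the first generator layer. This card does not say how the `Generator` in `c_gan` combines the two. The `embed_dim: 50` setting in the configuration below hints at a learned embedding, but that is an assumption. The stdlib-only sketch below shows the idea with the card's sizes: a 100-dim noise vector plus a 50-dim label embedding gives a 150-dim generator input. The helper names `make_embedding_table` and `conditioned_input` are hypothetical and not part of the project.

```python
import random

LATENT_DIM = 100  # noise size, matching config["latent_dim"]
NUM_CLASSES = 10  # digits 0-9, matching config["num_classes"]
EMBED_DIM = 50    # assumed label-embedding size, taken from config["embed_dim"]


def make_embedding_table(num_classes=NUM_CLASSES, embed_dim=EMBED_DIM, seed=0):
    """Hypothetical lookup table with one random row per class.

    In a trained cGAN these rows are learned parameters; here they are
    random placeholders, like a freshly initialized embedding layer.
    """
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(embed_dim)] for _ in range(num_classes)]


def conditioned_input(z, label, table):
    """Concatenate the noise vector z with the embedding row for `label`."""
    if not 0 <= label < len(table):
        raise ValueError(f"label must be in [0, {len(table) - 1}], got {label}")
    return list(z) + list(table[label])


# Usage: the same noise vector combined with two different labels
rng = random.Random(42)
z = [rng.gauss(0.0, 1.0) for _ in range(LATENT_DIM)]
table = make_embedding_table()
x7 = conditioned_input(z, 7, table)
x3 = conditioned_input(z, 3, table)
print(len(x7))  # 150
```

With `z` held fixed, `x7` and `x3` differ only in their last 50 entries. In this scheme, those entries are the only signal the generator receives about which digit to draw.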
The training process was tracked with Weights & Biases, where the full training logs and metrics are available.
You can download the model checkpoint directly from the Hugging Face Hub using the `huggingface_hub` library:
```python
from huggingface_hub import hf_hub_download

# Download the model checkpoint from the Hub
checkpoint_path = hf_hub_download(repo_id="hussamalafandi/cGAN-MNIST", filename="generator.pth")
```
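Next, load the checkpoint into the generator and sample an image. This assumes the `c_gan` module that defines `Generator` is importable (for example, it is on your Python path):

```python
import torch

from c_gan import Generator

# Model configuration (must match the architecture the checkpoint was saved from)
config = {
    "latent_dim": 100,  # size of the noise vector
    "ngf": 64,
    "nc": 1,
    "num_classes": 10,  # digits 0-9
    "embed_dim": 50,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the generator
generator = Generator(config)

# Load the downloaded checkpoint (or a local path) and move the model to the device
generator.load_state_dict(torch.load(checkpoint_path, map_location=device))
generator.to(device)

# Set the model to evaluation mode
generator.eval()

# Example: generate one image of the digit 7
latent_vector = torch.randn(1, config["latent_dim"], 1, 1, device=device)  # batch size of 1
label = torch.tensor([7], device=device)  # the label must be on the same device as the model
with torch.no_grad():  # inference only, so no gradients are needed
    generated_image = generator(latent_vector, label)
```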
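You can view or save the result without extra image libraries by writing the pixels as a PGM file, a simple grayscale format that most image viewers open. This sketch assumes the generator ends in a `tanh`, so pixel values lie in [-1, 1], and that `generated_image` has shape (1, 1, 28, 28). Neither is stated on this card, so check the `c_gan` code and adjust `low`/`high` if needed. The functions `to_uint8`, `pgm_bytes`, and `save_pgm` are hypothetical helpers, not part of the project.

```python
def to_uint8(image, low=-1.0, high=1.0):
    """Map a 2-D list of floats from [low, high] to ints in [0, 255], clamping outliers."""
    scale = 255.0 / (high - low)
    return [[max(0, min(255, round((v - low) * scale))) for v in row] for row in image]


def pgm_bytes(image):
    """Encode a 2-D list of 0..255 ints as a binary (P5) PGM image."""
    height, width = len(image), len(image[0])
    header = f"P5\n{width} {height}\n255\n".encode("ascii")
    return header + bytes(v for row in image for v in row)


def save_pgm(path, image):
    """Write the PGM encoding of `image` to `path`."""
    with open(path, "wb") as f:
        f.write(pgm_bytes(image))


# Usage with the example above (assumes a (1, 1, 28, 28) output in [-1, 1]):
# pixels = generated_image.detach().cpu().squeeze().tolist()  # 28 rows of 28 floats
# save_pgm("digit_7.pgm", to_uint8(pixels))
```

Clamping in `to_uint8` keeps the file valid even if a few outputs fall slightly outside the assumed range.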