Variational Autoencoders on MNIST
Build a Conditional Variational Autoencoder (CVAE) in PyTorch that learns to generate hand-written digits from the MNIST dataset.
1. Import Libraries
Import PyTorch, torchvision for MNIST, and matplotlib for visualisation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
2. Hyperparameters
Set device, latent space dimensionality (20), number of classes (10 digits), batch size, and image dimensions.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 20
num_classes = 10
batch_size = 128
img_size = 28
img_channels = 1
3. Load Dataset
Download MNIST (60,000 training images of handwritten digits 0–9), apply a ToTensor transform, and wrap in a DataLoader with batch size 128 and shuffling enabled.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
4. One-Hot Label Encoding
Convert integer digit labels (0–9) into 10-dimensional one-hot vectors. These are concatenated with image data to condition the model on which digit to encode or generate.
def one_hot(labels, num_classes=10):
    return F.one_hot(labels, num_classes).float()
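A quick sanity check of the encoding on a hypothetical mini-batch of labels:

```python
import torch
import torch.nn.functional as F

def one_hot(labels, num_classes=10):
    return F.one_hot(labels, num_classes).float()

# A hypothetical mini-batch of three digit labels
labels = torch.tensor([0, 3, 9])
encoded = one_hot(labels)
print(encoded.shape)   # torch.Size([3, 10])
print(encoded[1])      # 1.0 at index 3, zeros elsewhere
```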
5. Why Variational Autoencoders?
Simple autoencoders have a deterministic latent space — they compress each input to a single point, which makes the space discontinuous and poorly suited for generation. Variational Autoencoders (VAEs) instead encode each input as a distribution (mean μ and variance σ²) rather than a single point. Sampling from this distribution during training forces the latent space to be smooth and continuous, so you can sample any point and decode it into a realistic image.
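The key to making this sampling trainable is the reparameterization trick used below: because the noise ε is drawn independently of μ and σ, gradients still flow through the sampled point. A minimal sketch (the toy values here are illustrative, not from the model):

```python
import torch

mu = torch.tensor([0.5], requires_grad=True)
logvar = torch.tensor([0.0], requires_grad=True)

std = torch.exp(0.5 * logvar)   # sigma = 1.0 for logvar = 0
eps = torch.randn_like(std)     # noise drawn outside the computation graph
z = mu + eps * std              # the sample is differentiable w.r.t. mu and logvar

z.sum().backward()
print(mu.grad)   # dz/dmu = 1, regardless of the noise that was drawn
```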
6. Conditional VAE Model
The CVAE has three components:
- Encoder — takes a flattened image + one-hot label (784 + 10 = 794 dims) and outputs μ and log σ² of the latent distribution
- Reparameterize — samples z = μ + ε × σ, where ε ~ N(0, 1), making sampling differentiable for backprop
- Decoder — takes z + one-hot label and reconstructs the image via a linear layer + Sigmoid (pixel values ∈ [0, 1])
class CVAE(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # Encoder: flattened image + one-hot label -> hidden representation
        self.encoder = nn.Sequential(
            nn.Linear(28*28 + num_classes, 400),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)
        # Decoder: latent sample + one-hot label -> reconstructed image
        self.decoder_input = nn.Linear(latent_dim + num_classes, 400)
        self.decoder = nn.Sequential(
            nn.ReLU(),
            nn.Linear(400, 28*28),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        x = x.view(x.size(0), -1)           # flatten to (batch, 784)
        x = torch.cat([x, labels], dim=1)   # condition on the label
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)         # noise drawn outside the graph
        return mu + eps * std

    def decode(self, z, labels):
        z = torch.cat([z, labels], dim=1)   # condition on the label
        h = self.decoder_input(z)
        return self.decoder(h).view(-1, 1, 28, 28)

    def forward(self, x, labels):
        mu, logvar = self.encode(x, labels)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, labels)
        return recon, mu, logvar
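Before training, it is worth confirming the tensor shapes on a random batch. The standalone sketch below mirrors the layer sizes of the model above (the variable names are illustrative):

```python
import torch
import torch.nn as nn

latent_dim, num_classes, batch = 20, 10, 4

# Mirror of the encoder/decoder dimensions used in the CVAE
encoder = nn.Sequential(nn.Linear(28*28 + num_classes, 400), nn.ReLU())
fc_mu = nn.Linear(400, latent_dim)
decoder = nn.Sequential(nn.Linear(latent_dim + num_classes, 400),
                        nn.ReLU(), nn.Linear(400, 28*28), nn.Sigmoid())

x = torch.rand(batch, 1, 28, 28)                        # fake image batch
y = torch.eye(num_classes)[torch.tensor([0, 1, 2, 3])]  # one-hot labels

h = encoder(torch.cat([x.view(batch, -1), y], dim=1))   # (batch, 794) in
mu = fc_mu(h)
recon = decoder(torch.cat([mu, y], dim=1)).view(-1, 1, 28, 28)
print(mu.shape, recon.shape)  # torch.Size([4, 20]) torch.Size([4, 1, 28, 28])
```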
7. Loss Function
The VAE loss has two terms:
- BCE (Reconstruction Loss) — how well the decoded output matches the original image
- KL Divergence — how close the learned latent distribution is to a standard Gaussian N(0, 1); it acts as a regulariser that keeps the latent space smooth
def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: per-pixel binary cross-entropy, summed over the batch
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD
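A quick check on the KL term: when the encoder outputs exactly a standard Gaussian (μ = 0, log σ² = 0) the divergence vanishes, and it grows as the distribution drifts away. A sketch using the same formula:

```python
import torch

def kld(mu, logvar):
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1)
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

zero = torch.zeros(1, 20)
print(kld(zero, zero))        # matched distributions -> divergence of 0
print(kld(zero + 1.0, zero))  # shifting every mean to 1 adds 0.5 per dimension
```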
8. Training
Instantiate the CVAE, use Adam optimizer (lr=1e-3), and train for 100 epochs. Each batch: convert labels to one-hot, forward pass, compute VAE loss, backprop, update weights. Print average loss every 10 epochs.
model = CVAE(latent_dim, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 100
model.train()
for epoch in range(epochs):
    total_loss = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels_onehot = one_hot(labels).to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(imgs, labels_onehot)
        loss = loss_function(recon, imgs, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader.dataset)  # per-sample loss
    if epoch % 10 == 0:  # epoch 0 is already covered by the modulo check
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
9. Generate Digits
Sample a random latent vector z ~ N(0,1), condition it on a target digit label, decode it, and display the generated image.
def show_generated_digit(model, digit=4):
    model.eval()
    with torch.no_grad():
        z = torch.randn(1, latent_dim).to(device)   # sample z ~ N(0, 1)
        labels = one_hot(torch.tensor([digit]), num_classes).to(device)
        generated = model.decode(z, labels).cpu()
        plt.imshow(generated[0].squeeze(), cmap='gray')
        plt.axis('off')
        plt.show()

show_generated_digit(model, digit=4)
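To inspect the conditioning across all classes at once, one can decode a random sample per digit and tile the results. The sketch below is illustrative: the `digit_grid` helper and the untrained `_Stub` model are not part of the notebook — `_Stub` only shares the `decode()` signature so the shapes can be exercised; in practice you would pass the trained CVAE instead.

```python
import torch
import torch.nn as nn

latent_dim, num_classes = 20, 10

def digit_grid(model, latent_dim=20, num_classes=10):
    """Decode one random latent sample per digit class; returns (10, 1, 28, 28)."""
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_classes, latent_dim)
        labels = torch.eye(num_classes)   # one-hot rows for digits 0..9
        return model.decode(z, labels)

# Untrained stand-in with the same decode() signature as the CVAE above,
# used here only to exercise the function end to end.
class _Stub(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 28*28), nn.Sigmoid())
    def decode(self, z, labels):
        return self.net(torch.cat([z, labels], dim=1)).view(-1, 1, 28, 28)

imgs = digit_grid(_Stub())
print(imgs.shape)  # torch.Size([10, 1, 28, 28])
```

With the trained model, `digit_grid(model)` gives one image per class that can be shown with `plt.imshow` in a 2×5 subplot grid.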