Variational Autoencoders on MNIST

Build a Conditional Variational Autoencoder (CVAE) in PyTorch that learns to generate hand-written digits from the MNIST dataset.

1. Import Libraries

Import PyTorch, torchvision for MNIST, and matplotlib for visualisation.

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

2. Hyperparameters

Set device, latent space dimensionality (20), number of classes (10 digits), batch size, and image dimensions.

In [ ]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 20
num_classes = 10
batch_size = 128
img_size = 28
img_channels = 1

3. Load Dataset

Download MNIST (60,000 training images of handwritten digits 0–9), apply a ToTensor transform, and wrap in a DataLoader with batch size 128 and shuffling enabled.

In [ ]:
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

4. One-Hot Label Encoding

Convert integer digit labels (0–9) into 10-dimensional one-hot vectors. These are concatenated with image data to condition the model on which digit to encode or generate.

In [ ]:
def one_hot(labels, num_classes=10):
    return F.one_hot(labels, num_classes).float()
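
For example, the label 3 maps to a 10-dimensional vector with a 1 at index 3 (using `F.one_hot` directly, exactly as the helper above does):

```python
import torch
import torch.nn.functional as F

# Label 3 -> 10-dim one-hot vector
print(F.one_hot(torch.tensor([3]), 10).float())
# tensor([[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])
```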

5. Why Variational Autoencoders?

Simple autoencoders have a deterministic latent space — they compress input to a single point, which makes the space discontinuous and poorly suited for generation. Variational Autoencoders (VAEs) encode each input as a distribution (mean μ and variance σ²) rather than a single point. Sampling from this distribution during training forces the latent space to be smooth and continuous — meaning you can sample any point and decode it into a realistic image.
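
The trick that makes this sampling compatible with backpropagation (the "reparameterization" detailed in section 6) can be seen in isolation. A minimal standalone sketch, independent of the model: writing z = μ + ε·σ keeps the randomness in ε, so gradients still reach μ and logvar:

```python
import torch

# Treat mu and logvar as learnable quantities
mu = torch.zeros(5, requires_grad=True)
logvar = torch.zeros(5, requires_grad=True)

# Reparameterized sample: the randomness lives in eps, not in mu/logvar
std = torch.exp(0.5 * logvar)
eps = torch.randn(5)
z = mu + eps * std

# Gradients reach mu even though z was "sampled"
z.sum().backward()
print(mu.grad)  # tensor([1., 1., 1., 1., 1.])
```

Sampling directly from a distribution object would break this chain, which is why the trick is needed.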

6. Conditional VAE Model

The CVAE has three components:

  • Encoder — takes a flattened image + one-hot label (784 + 10 = 794 dims), outputs μ and logvar of the latent distribution
  • Reparameterize — samples z = μ + ε × σ where ε ~ N(0,1), making sampling differentiable for backprop
  • Decoder — takes z + one-hot label, reconstructs the image via a linear layer + Sigmoid (pixel values ∈ [0,1])
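
The input-size arithmetic in the first bullet can be sanity-checked with dummy tensors before building the model (shapes only; the values are random):

```python
import torch

batch = 8
imgs = torch.rand(batch, 1, 28, 28)                      # fake MNIST batch
labels = torch.eye(10)[torch.randint(0, 10, (batch,))]   # random one-hot labels

flat = imgs.view(batch, -1)                 # (8, 784)
enc_in = torch.cat([flat, labels], dim=1)   # (8, 794) = 784 pixels + 10 classes
print(enc_in.shape)  # torch.Size([8, 794])
```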
In [ ]:
class CVAE(nn.Module):
    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes

        self.encoder = nn.Sequential(
            nn.Linear(28*28 + num_classes, 400),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(400, latent_dim)
        self.fc_logvar = nn.Linear(400, latent_dim)

        self.decoder_input = nn.Linear(latent_dim + num_classes, 400)
        self.decoder = nn.Sequential(
            nn.ReLU(),
            nn.Linear(400, 28*28),
            nn.Sigmoid()
        )

    def encode(self, x, labels):
        x = x.view(x.size(0), -1)
        x = torch.cat([x, labels], dim=1)
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, labels):
        z = torch.cat([z, labels], dim=1)
        h = self.decoder_input(z)
        return self.decoder(h).view(-1, 1, 28, 28)

    def forward(self, x, labels):
        mu, logvar = self.encode(x, labels)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z, labels)
        return recon, mu, logvar

7. Loss Function

The VAE loss has two terms:

  • BCE (Reconstruction Loss) — how well the decoded output matches the original image
  • KL Divergence — how close the learned latent distribution is to a standard Gaussian N(0,1). Acts as a regulariser that keeps the latent space smooth.
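
The KL term has the closed form −½ Σ(1 + log σ² − μ² − σ²). A quick standalone check, not using the model: it is zero exactly when the latent distribution equals the N(0,1) prior, and grows with any deviation:

```python
import torch

def kld(mu, logvar):
    # KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

print(kld(torch.zeros(20), torch.zeros(20)))  # 0: latent matches the prior
print(kld(torch.ones(20),  torch.zeros(20)))  # 10: 0.5 per dimension for mu=1
```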
In [ ]:
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

8. Training

Instantiate the CVAE, use Adam optimizer (lr=1e-3), and train for 100 epochs. Each batch: convert labels to one-hot, forward pass, compute VAE loss, backprop, update weights. Print average loss every 10 epochs.

In [ ]:
model = CVAE(latent_dim, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 100
model.train()
for epoch in range(epochs):
    total_loss = 0
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels_onehot = one_hot(labels).to(device)

        optimizer.zero_grad()
        recon, mu, logvar = model(imgs, labels_onehot)
        loss = loss_function(recon, imgs, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader.dataset)
    if epoch % 10 == 0:  # epoch 0 already satisfies this, so print at 1, 11, 21, ...
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

9. Generate Digits

Sample a random latent vector z ~ N(0,1), condition it on a target digit label, decode it, and display the generated image.

In [ ]:
def show_generated_digit(model, digit=4):
    model.eval()
    with torch.no_grad():
        z = torch.randn(1, latent_dim).to(device)
        labels = one_hot(torch.tensor([digit]), num_classes).to(device)
        generated = model.decode(z, labels).cpu()

        plt.imshow(generated[0].squeeze(), cmap='gray')
        plt.axis('off')
        plt.show()

show_generated_digit(model, digit=4)
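
Beyond a single digit, a hypothetical helper (not part of the original notebook) can sample one image per class by pairing ten random latent vectors with the rows of an identity matrix as one-hot labels:

```python
import torch
import matplotlib.pyplot as plt

def show_digit_grid(model, latent_dim=20, num_classes=10, device="cpu"):
    """Decode one random latent vector per digit class and plot a 1x10 grid."""
    model.eval()
    with torch.no_grad():
        z = torch.randn(num_classes, latent_dim, device=device)
        labels = torch.eye(num_classes, device=device)  # rows are one-hot 0..9
        imgs = model.decode(z, labels).cpu()

    fig, axes = plt.subplots(1, num_classes, figsize=(num_classes, 1.5))
    for digit, ax in enumerate(axes):
        ax.imshow(imgs[digit].squeeze(), cmap="gray")
        ax.set_title(str(digit))
        ax.axis("off")
    plt.show()

# show_digit_grid(model, latent_dim, num_classes, device)
```

Because the labels condition the decoder, each column should show the requested digit even though all ten latent vectors are drawn from the same N(0,1) prior.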