# Introduction to Generative Adversarial Networks (GANs)

In this notebook, we will introduce Generative Adversarial Networks (GANs) and implement a simple GAN using PyTorch. We will first discuss the basic concepts of GANs and training models on 2D data. Then, we will implement a GAN to generate handwritten digits from the MNIST dataset.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

First let us define the device to work on:

In [None]:
print(f"GPU available using Cuda: {torch.cuda.is_available()}")
print(
    f"GPU available using MPS: {torch.backends.mps.is_available() and torch.backends.mps.is_built()}"
)
if torch.cuda.is_available():
    device = torch.device("cuda")
    gpu_available = True
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    gpu_available = True
else:
    device = torch.device("cpu")
    gpu_available = False


print(f"Device: {device}")
print(f"Number of GPUs: {torch.cuda.device_count()}")

# 1. The GAN Model

Generative Adversarial Networks (GANs) are a class of generative models used to produce new data samples. A GAN consists of two neural networks, a generator and a discriminator, that are trained simultaneously. The generator produces new samples, while the discriminator tries to distinguish real samples from generated ones. The two networks are trained adversarially: the generator tries to produce samples that the discriminator cannot tell apart from real data, and the discriminator tries to catch them.

## 1.1 The Generator

One easy way to think about the generator is as a function that takes random noise as input and returns new data samples. In practice, this function is a neural network trained so that its outputs are indistinguishable from real samples.

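Typically, the generator is a feedforward neural network that takes random noise in $\mathbb{R}^d$ as input and produces samples in $\mathbb{R}^D$. The distribution of its outputs is called the generated distribution. The input noise is usually sampled from a normal distribution. For instance, if the generator is the linear function $G(z) = 2z - 3.5$ and $z \sim \mathcal{N}(0, 1)$, the generated samples follow a Gaussian with mean $-3.5$ and standard deviation $2$: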


In [None]:
def input_noise(n=1000, noise_dim=1):
    return torch.randn(n, noise_dim)


def generator_function(x):
    return 2 * x - 3.5


input_data = input_noise()
output_data = generator_function(input_data)
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.hist(input_data.numpy(), bins=100, range=(-10, 10))
plt.title("Input data")
plt.subplot(1, 2, 2)
plt.hist(output_data.numpy(), bins=100, range=(-10, 10))
plt.title("Output data")
plt.show()
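
As a quick numerical check of the linear example above: an affine map $G(z) = az + b$ applied to standard normal noise yields a Gaussian with mean $b$ and standard deviation $|a|$. The sketch below estimates both quantities from samples. The sample size and the seed are arbitrary choices made here for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# Standard normal noise pushed through the linear generator 2 * z - 3.5
z = torch.randn(100_000, 1)
samples = 2 * z - 3.5

empirical_mean = samples.mean().item()
empirical_std = samples.std().item()
print(f"mean ~ {empirical_mean:.2f}, std ~ {empirical_std:.2f}")
```

The estimates should land close to -3.5 and 2, which are the parameters of the Gaussian used as the model density in the next section.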

The generator can be implemented with a neural network to represent more complex functions.

## 1.2 The Discriminator

Let us assume that we have a target distribution $p_{\text{data}}(x)$ and a generator distribution $p_{\text{model}}(x)$. The discriminator is trained to distinguish samples drawn from the target distribution from samples drawn from the generator distribution. In other words, the discriminator outputs the probability that a sample comes from the target distribution. As a one-dimensional example, take a mixture of two Gaussians as the target and a single Gaussian as the model:

In [None]:
def target_density(x):
    return 0.5 * torch.exp(
        torch.distributions.Normal(-4, 1).log_prob(x)
    ) + 0.5 * torch.exp(torch.distributions.Normal(4, 1).log_prob(x))


def model_density(x):
    return torch.exp(torch.distributions.Normal(-3.5, 2).log_prob(x))

In [None]:
x = torch.linspace(-10, 10, 1000)
plt.plot(x, target_density(x), label="Target")
plt.plot(x, model_density(x), label="Model")
plt.legend()

The optimal discriminator is given by:
$$D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{\text{model}}(x)}$$
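
For reference, this expression can be derived from the standard GAN value function. The derivation below is a sketch that assumes the usual log-loss objective, which is not written out elsewhere in this notebook. For a fixed generator, the discriminator maximizes:
$$V(D) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{x \sim p_{\text{model}}}[\log(1 - D(x))] = \int \Big( p_{\text{data}}(x) \log D(x) + p_{\text{model}}(x) \log(1 - D(x)) \Big) \, dx$$

The integral can be maximized pointwise. For a fixed $x$, write $a = p_{\text{data}}(x)$, $b = p_{\text{model}}(x)$ and $y = D(x)$. The function $a \log y + b \log(1 - y)$ is concave in $y$, and its derivative vanishes when:
$$\frac{a}{y} - \frac{b}{1 - y} = 0 \iff y = \frac{a}{a + b}$$

Substituting $a$ and $b$ back gives $D^*(x)$ above. In particular, $D^*(x) = 1/2$ wherever the two densities are equal.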

In [None]:
Dopt = target_density(x) / (target_density(x) + model_density(x) + 1e-6)
plt.figure()
plt.plot(x, Dopt)
plt.title("Optimal Discriminator")
plt.show()

The generator is trained to produce samples that are indistinguishable from samples of the target distribution. In other words, it tries to produce samples to which the discriminator assigns a high probability of coming from the target distribution.
We will now implement a simple GAN on 2D data.

# 2. GAN on two-dimensional Gaussian mixture data

In this section, we will implement a simple GAN on 2D data. We will generate samples from a Gaussian Mixture distribution and train a GAN to generate samples from the same distribution using PyTorch.
We will implement a linear generator and discriminator for this task, then move on to more complex models. 


## 2.1 Training loop on simple data with linear models

First let us visualize the data:

In [None]:
def simple_2D_data(n=1000):
    means = torch.Tensor([[-4, 0], [0, 4], [-2, 2]])
    sigmas = torch.Tensor([0.6] * 3)
    n_per_cluster = n // 3
    data = []
    for i in range(3):
        data.append(
            torch.distributions.Normal(means[i], sigmas[i]).sample((n_per_cluster,))
        )
    data = torch.cat(data) / 6
    return data


dataset = simple_2D_data(n=5000)
plt.scatter(dataset[:, 0], dataset[:, 1], c="black", s=1)
plt.xlim(-1.1, 1.1)
plt.ylim(-1.1, 1.1)
plt.title("Target Data")
plt.show()

<font color='blue'>**TODO:**</font> Complete the code below to implement a linear generator and discriminator:
- The generator should be a linear layer that takes an input noise in dimension `noise_dim` and outputs a sample in dimension 2. You can add output activation if needed.
- The discriminator should be a linear layer that takes an input sample in dimension 2 and outputs a probability that the sample is from the target distribution. The output should be a scalar between 0 and 1. You can add output activation if needed.


In [None]:
class LinearGenerator(nn.Module):
    def __init__(self, noise_dim=1):
        super(LinearGenerator, self).__init__()
        # Complete the code here
        self.activation = nn.Tanh()

    def forward(self, x):
        # Complete the code here
        return self.activation(x)


class LinearDiscriminator(nn.Module):
    def __init__(self):
        super(LinearDiscriminator, self).__init__()
        # Complete the code here

    def forward(self, x):
        # Complete the code here 
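
One possible way to complete the two classes above is sketched below. It is only a sketch: the choice of `Tanh` on the generator output (to keep samples in the $[-1, 1]$ range of the plots) and of `Sigmoid` on the discriminator output (to obtain a probability) are assumptions that follow the hints in the TODO, and other completions are equally valid.

```python
import torch
import torch.nn as nn


class LinearGenerator(nn.Module):
    def __init__(self, noise_dim=1):
        super().__init__()
        # Single affine map from the noise space to the 2D data space
        self.linear = nn.Linear(noise_dim, 2)
        # Tanh keeps the outputs in [-1, 1]
        self.activation = nn.Tanh()

    def forward(self, x):
        return self.activation(self.linear(x))


class LinearDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Single affine map from a 2D sample to one logit
        self.linear = nn.Linear(2, 1)
        # Sigmoid turns the logit into a probability in (0, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        return self.activation(self.linear(x))


# Shape check on random inputs
noise = torch.randn(5, 10)
samples = LinearGenerator(noise_dim=10)(noise)
probs = LinearDiscriminator()(samples)
print(samples.shape, probs.shape)
```

With this completion the discriminator is a logistic regression, so its decision boundary in the plane is a straight line. This limits how well it can separate a mixture of several Gaussians from the generated samples.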

Let us visualize both models:

In [None]:
def visualize_generator(generator, noise_dim, n=1000):
    noise = input_noise(n, noise_dim=noise_dim).to(device)
    generated_data = generator(noise).detach().cpu()
    plt.figure()
    plt.scatter(generated_data[:, 0], generated_data[:, 1], c="blue", s=1)
    plt.xlim(-1.1, 1.1)
    plt.ylim(-1.1, 1.1)
    plt.title("Generated Data")
    plt.show()


def visualize_discriminator(discriminator, n=1000):
    x = torch.linspace(-1, 1, n)
    y = torch.linspace(-1, 1, n)
    # Explicit indexing avoids the deprecation warning of recent PyTorch versions
    xx, yy = torch.meshgrid(x, y, indexing="ij")
    grid = (
        torch.from_numpy(np.stack([xx, yy], axis=-1).reshape(-1, 2)).float().to(device)
    )
    with torch.no_grad():
        predictions = discriminator(grid).cpu().view(n, n)
    plt.imshow(predictions.T, extent=(-1, 1, -1, 1), origin="lower")
    plt.colorbar()
    plt.title("Discriminator Output")
    plt.show()


def visualize_loss(loss_D, loss_G):
    plt.figure()
    plt.plot(loss_D, label="Discriminator Loss")
    plt.plot(loss_G, label="Generator Loss")
    plt.legend()
    plt.title("Losses")
    plt.show()

In [None]:
noise_dim = 10
generator = LinearGenerator(noise_dim).to(device)
discriminator = LinearDiscriminator().to(device)

visualize_generator(generator, noise_dim)
visualize_discriminator(discriminator)

Since the discriminator is trained to differentiate between samples from the target distribution and the generator distribution, we can train it as a binary classifier. Let us label points from the target distribution as 1 and points from the generator distribution as 0.

<font color='blue'>**TODO:**</font> Complete the code below to implement the training loss for the discriminator using the binary cross entropy loss `torch.nn.BCELoss` [(Doc)](https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html). The discriminator should output the probability that the sample is from the target distribution. The target should be 1 for samples from the target distribution and 0 for samples from the generator distribution.


In [None]:
criterion = nn.BCELoss()


def discriminator_loss(output_real, output_fake):
    real_labels = torch.ones_like(output_real)
    fake_labels = torch.zeros_like(output_fake)

    real_loss = # Complete here
    fake_loss = # Complete here

    return real_loss + fake_loss
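
A possible completion of the discriminator loss is sketched below, assuming the discriminator outputs probabilities as required above. The sanity check at the end is an addition for illustration: a discriminator that answers 0.5 everywhere should yield a loss of $2 \log 2$.

```python
import math

import torch
import torch.nn as nn

criterion = nn.BCELoss()


def discriminator_loss(output_real, output_fake):
    # Real samples should be classified as 1, generated samples as 0
    real_labels = torch.ones_like(output_real)
    fake_labels = torch.zeros_like(output_fake)

    real_loss = criterion(output_real, real_labels)
    fake_loss = criterion(output_fake, fake_labels)

    return real_loss + fake_loss


# An undecided discriminator outputs 0.5 for every sample
undecided = torch.full((8, 1), 0.5)
loss = discriminator_loss(undecided, undecided)
print(loss.item(), 2 * math.log(2))
```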

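The generator has the opposite goal: it wants the discriminator to classify generated samples as real. We can write its loss as the binary cross entropy loss on generated samples with the labels flipped, that is, with target 1 instead of 0.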

<font color='blue'>**TODO:**</font> Complete the code below to implement the training loss for the generator. 



In [None]:
def generator_loss(output_fake):
    fake_labels = # Complete here
    fake_loss = # Complete here
    return fake_loss
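
A possible completion of the generator loss, following the label-flipping description above, is sketched here. The two test tensors are made up for illustration: the loss should be large when the discriminator confidently rejects the generated samples and small when it is fooled.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()


def generator_loss(output_fake):
    # Flipped labels: the generator wants its samples to be classified as real (1)
    fake_labels = torch.ones_like(output_fake)
    fake_loss = criterion(output_fake, fake_labels)
    return fake_loss


# Discriminator confidently rejects the generated samples
rejected = torch.full((8, 1), 0.01)
# Discriminator is fooled by the generated samples
fooled = torch.full((8, 1), 0.99)
print(generator_loss(rejected).item(), generator_loss(fooled).item())
```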

<font color='blue'>**TODO:**</font>  Run the code below to make sure that the training losses are implemented correctly.


In [None]:
noise_dim = 10
noise = input_noise(1000, noise_dim=noise_dim).to(device)
data = simple_2D_data(n=5000).to(device)

generator = LinearGenerator(noise_dim).to(device)
discriminator = LinearDiscriminator().to(device)

print(
    "The discriminator loss is: ",
    discriminator_loss(discriminator(data), discriminator(generator(noise))),
)
print("The generator loss is: ", generator_loss(discriminator(generator(noise))))

The training loop for the GAN consists of alternating gradient steps on the discriminator and the generator:
1. Sample a batch of data from the target distribution.
2. Sample a batch of noise from a normal distribution with the same batch size.
3. Generate samples from the generator.
4. Compute the discriminator loss based on real and generated samples.
5. Apply a gradient step on the discriminator.
6. Compute the generator loss based on generated samples.
7. Apply a gradient step on the generator.

<font color='blue'>**TODO:**</font> Complete the code below to implement the training step for the GAN. The training step should apply a gradient step on the discriminator and the generator. The training step should return the discriminator and generator losses.  The training step should be able to apply a gradient step on the discriminator `num_steps` times before applying a gradient step on the generator.

In [None]:
def training_step(
    generator, discriminator, optimizer_D, optimizer_G, noise, data, num_steps=1
):
    for _ in range(num_steps):
        discriminator.train()
        generator.eval()
        # Complete here to train the discriminator

    discriminator.eval()
    generator.train()
    # Complete here to train the generator

    return loss_D.item(), loss_G.item()
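
The training step can be sketched as follows. This is one possible completion, not the reference solution: the losses are inlined with `nn.BCELoss` so that the block is self-contained, and generating the fake samples under `torch.no_grad()` during the discriminator update is a design choice made here so that no gradient flows into the generator at that stage. The sketch assumes `num_steps >= 1`.

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()


def training_step(
    generator, discriminator, optimizer_D, optimizer_G, noise, data, num_steps=1
):
    for _ in range(num_steps):
        discriminator.train()
        generator.eval()
        optimizer_D.zero_grad()
        # The generator is fixed during the discriminator update
        with torch.no_grad():
            fake = generator(noise)
        out_real = discriminator(data)
        out_fake = discriminator(fake)
        loss_D = criterion(out_real, torch.ones_like(out_real)) + criterion(
            out_fake, torch.zeros_like(out_fake)
        )
        loss_D.backward()
        optimizer_D.step()

    discriminator.eval()
    generator.train()
    optimizer_G.zero_grad()
    # Gradients flow through the discriminator into the generator here,
    # but only the generator parameters are updated by optimizer_G
    out_fake = discriminator(generator(noise))
    loss_G = criterion(out_fake, torch.ones_like(out_fake))
    loss_G.backward()
    optimizer_G.step()

    return loss_D.item(), loss_G.item()
```

The gradients that the generator update leaves in the discriminator parameters are cleared by `optimizer_D.zero_grad()` at the start of the next discriminator step.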

<font color='blue'>**TODO:**</font> Run the code below to make sure that the training step is implemented correctly.

In [None]:
noise_dim = 10
batch_size = 100

generator = LinearGenerator(noise_dim).to(device)
discriminator = LinearDiscriminator().to(device)

optimizer_D = optim.Adam(discriminator.parameters(), lr=0.002)
optimizer_G = optim.Adam(generator.parameters(), lr=0.002)

noise = input_noise(n=batch_size, noise_dim=noise_dim).to(device)
data = simple_2D_data(n=batch_size).to(device)

print("Before training")
print(
    "The discriminator loss is: ",
    discriminator_loss(discriminator(data), discriminator(generator(noise))),
)
print("The generator loss is: ", generator_loss(discriminator(generator(noise))))

for _ in range(100):
    loss_D, loss_G = training_step(
        generator, discriminator, optimizer_D, optimizer_G, noise, data
    )

print("\nAfter a few training steps")
print(
    "The discriminator loss is: ",
    discriminator_loss(discriminator(data), discriminator(generator(noise))),
)
print("The generator loss is: ", generator_loss(discriminator(generator(noise))))

<font color='blue'>**TODO:**</font>  Complete the code below to implement the training loop for the GAN. At each epoch, we should sample 5000 data points from the target distribution and split them into batches of size `batch_size`. 

In [None]:
def train_GAN(
    generator,
    discriminator,
    optimizer_D,
    optimizer_G,
    noise_dim,
    sample_function,
    n_epochs=20,
    batch_size=100,
    num_steps=1,
):
    loss_D_values = []
    loss_G_values = []
    noise = # Complete here
    data = # Complete here
    n_batches = len(data) // batch_size
    for epoch in range(n_epochs):
        data = data[torch.randperm(len(data))]
        noise = noise[torch.randperm(len(noise))]
        avg_loss_D = 0
        avg_loss_G = 0

        for i in range(n_batches):
            noise_i = # Complete here
            data_i = # Complete here
            loss_D, loss_G = # Complete here


            loss_D_values.append(loss_D)
            loss_G_values.append(loss_G)
            avg_loss_D += loss_D
            avg_loss_G += loss_G

        print(
            f"Epoch {epoch+1}/{n_epochs} Loss D: {avg_loss_D/n_batches:.4f} Loss G: {avg_loss_G/n_batches:.4f}"
        )

    return loss_D_values, loss_G_values
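
The batching part of this loop can be sketched in isolation. `iterate_minibatches` is a hypothetical helper introduced here for illustration and is not part of the notebook: it shuffles the data and the noise independently, then yields aligned slices of size `batch_size`.

```python
import torch


def iterate_minibatches(data, noise, batch_size):
    # Hypothetical helper: only full batches are produced
    n_batches = min(len(data), len(noise)) // batch_size
    data = data[torch.randperm(len(data))]
    noise = noise[torch.randperm(len(noise))]
    for i in range(n_batches):
        start, stop = i * batch_size, (i + 1) * batch_size
        yield noise[start:stop], data[start:stop]


# Stand-ins for the sampled dataset and noise
data = torch.randn(5000, 2)
noise = torch.randn(5000, 10)
batches = list(iterate_minibatches(data, noise, batch_size=100))
print(len(batches), batches[0][0].shape, batches[0][1].shape)
```

Note that `simple_2D_data(n=5000)` returns 4998 points, because 5000 is not divisible by the 3 clusters. With `batch_size=100` this gives 49 full batches per epoch, and the remaining points are dropped.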

<font color='blue'>**TODO:**</font> Run the code below to train the GAN on 2D data.

In [None]:
generator = LinearGenerator(noise_dim).to(device)
discriminator = LinearDiscriminator().to(device)

optimizer_D = optim.Adam(discriminator.parameters(), betas=(0.5, 0.9), lr=5e-4)
optimizer_G = optim.Adam(generator.parameters(), betas=(0.5, 0.9), lr=5e-4)

loss_D_values, loss_G_values = train_GAN(
    generator,
    discriminator,
    optimizer_D,
    optimizer_G,
    noise_dim,
    simple_2D_data,
    n_epochs=20,
    batch_size=100,
    num_steps=5,
)

In [None]:
visualize_loss(loss_D_values, loss_G_values)
visualize_discriminator(discriminator)
visualize_generator(generator, noise_dim)

Let us try with a more complex model:

<font color='blue'>**TODO:**</font> Complete the code below to implement a more complex generator and discriminator:
- Generator: Implement a neural network with 3 hidden layers of sizes 256, 512, 1024 with Leaky ReLU activation functions. The output layer should output a sample in dimension 2.
- Discriminator: Implement a neural network with 3 hidden layers of sizes 1024, 512, 256 with Leaky ReLU activation functions. The output layer should output a scalar between 0 and 1.

In [None]:
class DenseGenerator(nn.Module):
    def __init__(self, noise_dim=1, out_dim=2):
        super(DenseGenerator, self).__init__()
        # Complete the code here

    def forward(self, x):
        # Complete the code here


class DenseDiscriminator(nn.Module):
    def __init__(self, in_dim=2):
        super(DenseDiscriminator, self).__init__()
        # Complete the code here

    def forward(self, x):
        # Complete the code here
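
One possible completion of the dense models is sketched below, using the hidden sizes listed in the TODO. Several details are assumptions made here: the negative slope of 0.2 for `LeakyReLU`, the `Tanh` output on the generator (the data lies in $[-1, 1]$), and the `Sigmoid` output on the discriminator.

```python
import torch
import torch.nn as nn


class DenseGenerator(nn.Module):
    def __init__(self, noise_dim=1, out_dim=2):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, out_dim),
            nn.Tanh(),  # assumption: outputs constrained to [-1, 1]
        )

    def forward(self, x):
        return self.blocks(x)


class DenseDiscriminator(nn.Module):
    def __init__(self, in_dim=2):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Linear(in_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the input is a real sample
        )

    def forward(self, x):
        return self.blocks(x)


# Shape check on random inputs
generated = DenseGenerator(noise_dim=10, out_dim=2)(torch.randn(4, 10))
scores = DenseDiscriminator(in_dim=2)(generated)
print(generated.shape, scores.shape)
```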

<font color='blue'>**TODO:**</font> Run the code below to train the GAN on 2D data with a more complex model. Observe the generated samples and the discriminator. 

In [None]:
noise_dim = 10
batch_size = 100

generator = DenseGenerator(noise_dim).to(device)
discriminator = DenseDiscriminator().to(device)

optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)

loss_D_values, loss_G_values = train_GAN(
    generator,
    discriminator,
    optimizer_D,
    optimizer_G,
    noise_dim,
    simple_2D_data,
    n_epochs=20,
    batch_size=100,
    num_steps=5,
)

In [None]:
visualize_loss(loss_D_values, loss_G_values)
visualize_discriminator(discriminator)
visualize_generator(generator, noise_dim)

We will now try slightly more complex 2D data:

In [None]:
def complex_2D_data(n=1000):
    means = torch.Tensor([[-2, 0], [2, 0], [0, 2], [0, -2]])
    sigmas = torch.Tensor([0.5] * 4)
    n_per_cluster = n // 4
    data = []
    for i in range(4):
        data.append(
            torch.distributions.Normal(means[i], sigmas[i]).sample((n_per_cluster,))
        )
    data = torch.cat(data) / 6
    return data


dataset = complex_2D_data(n=5000)
plt.scatter(dataset[:, 0], dataset[:, 1], c="black", s=1)
plt.xlim(-1.1, 1.1)
plt.ylim(-1.1, 1.1)
plt.title("Target Data")
plt.show()

<font color='blue'>**TODO:**</font> Run the code below to train the GAN on 2D data with a more complex model. Observe the generated samples and the discriminator.

In [None]:
noise_dim = 10
batch_size = 100

generator = DenseGenerator(noise_dim).to(device)
discriminator = DenseDiscriminator().to(device)

optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)

loss_D_values, loss_G_values = train_GAN(
    generator,
    discriminator,
    optimizer_D,
    optimizer_G,
    noise_dim,
    complex_2D_data,
    n_epochs=20,
    batch_size=100,
    num_steps=5,
)

In [None]:
visualize_loss(loss_D_values, loss_G_values)
visualize_discriminator(discriminator)
visualize_generator(generator, noise_dim)

The GAN model is known for being hard to train, for two main reasons:
- Mode collapse: The generator covers only one or a few modes of the target distribution and keeps producing very similar samples.
- Training instability: The generator and discriminator are trained simultaneously, and the training can be unstable.

We can easily observe mode collapse with 2D data on this dataset:


In [None]:
def complex_large_2D_data(n=1000):
    means = torch.Tensor([[-4, 0], [4, 0], [0, 4], [0, -4]])
    sigmas = torch.Tensor([0.6] * 4)
    n_per_cluster = n // 4
    data = []
    for i in range(4):
        data.append(
            torch.distributions.Normal(means[i], sigmas[i]).sample((n_per_cluster,))
        )
    data = torch.cat(data) / 6
    return data


dataset = complex_large_2D_data(n=5000)
plt.scatter(dataset[:, 0], dataset[:, 1], c="black", s=1)
plt.xlim(-1.1, 1.1)
plt.ylim(-1.1, 1.1)
plt.title("Target Data")
plt.show()

<font color='blue'>**TODO:**</font> Run the code below to train the GAN on 2D data with a more complex model. What do you conclude about mode collapse?

In [None]:
noise_dim = 10
batch_size = 100

generator = DenseGenerator(noise_dim).to(device)
discriminator = DenseDiscriminator().to(device)

optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)

loss_D_values, loss_G_values = train_GAN(
    generator,
    discriminator,
    optimizer_D,
    optimizer_G,
    noise_dim,
    complex_large_2D_data,
    n_epochs=20,
    batch_size=100,
    num_steps=5,
)

In [None]:
visualize_loss(loss_D_values, loss_G_values)
visualize_discriminator(discriminator)
visualize_generator(generator, noise_dim)
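
Beyond inspecting the scatter plot, mode collapse can be roughly quantified by counting how many generated samples fall near each mode of the target. The sketch below does this with a hypothetical helper, `mode_coverage`, which is not part of the notebook. The radius of 0.3 (three standard deviations of the rescaled clusters) and the synthetic "collapsed" samples are assumptions chosen for illustration.

```python
import torch


def mode_coverage(samples, centers, radius):
    # Hypothetical helper: count the samples within `radius` of their nearest center
    distances = torch.cdist(samples, centers)  # shape (n_samples, n_centers)
    nearest_distance, nearest_center = distances.min(dim=1)
    close_enough = nearest_distance < radius
    return torch.bincount(nearest_center[close_enough], minlength=len(centers))


# Mode centers of complex_large_2D_data after the division by 6
centers = torch.tensor([[-4.0, 0.0], [4.0, 0.0], [0.0, 4.0], [0.0, -4.0]]) / 6

# Synthetic stand-in for a collapsed generator: all samples around the first mode
torch.manual_seed(0)
collapsed = centers[0] + 0.05 * torch.randn(1000, 2)
counts = mode_coverage(collapsed, centers, radius=0.3)
print(counts)
```

In the notebook, `collapsed` would be replaced by `generator(noise).detach().cpu()`. Counts concentrated on one or two modes indicate collapse, while roughly equal counts indicate that all four modes are covered.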

# 3. GAN on MNIST

We will now implement a GAN to generate handwritten digits from the MNIST dataset. The MNIST dataset consists of 28x28 grayscale images of handwritten digits from 0 to 9. We will use the same architecture as before. For computational reasons, we will use a rescaled version of the MNIST dataset with images of size `size` x `size`.

In [None]:
size = 14

The code below loads the MNIST dataset and rescales the images to the desired size.

In [None]:
# Transform to resize, convert to tensor, flatten to a vector of size size*size, and rescale to [-1, 1]
transform = transforms.Compose(
    [
        transforms.Resize((size, size)),  # Resize to sizexsize pixels
        transforms.ToTensor(),  # Convert to a tensor with shape (1, size, size)
        transforms.Lambda(
            lambda x: x.view(size * size)
        ),  # Flatten to a vector of shape (size*size,)
        transforms.Lambda(
            lambda x: 2 * x - 1
        ),  # Rescale pixel values from [0, 1] to [-1, 1]
    ]
)


# Custom dataset that keeps only the digits 0 to 4 and returns images without labels
class FilteredMNIST(datasets.MNIST):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filter to keep only images with labels 0 to 4
        indices = [
            i for i, label in enumerate(self.targets) if label in [0, 1, 2, 3, 4]
        ]
        self.data = self.data[indices]
        self.targets = self.targets[indices]

    def __getitem__(self, index):
        # Return only the image, not the label
        image, _ = super().__getitem__(index)
        return image


# Create datasets and loaders
dataset_train = FilteredMNIST(
    root="./data", train=True, download=True, transform=transform
)
dataset_test = FilteredMNIST(
    root="./data", train=False, download=True, transform=transform
)

loader_train = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True)
loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=128, shuffle=True)

first_batch = next(iter(loader_test))

plt.figure(1, figsize=(10, 10))
plt.suptitle("Test Set", y=0.57)
for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(first_batch[i].view(size, size).numpy(), cmap="gray")
    plt.axis("off")
plt.show()

The function below will plot the generated samples from the generator.

In [None]:
def plot_images(generator, noise_dim):
    generator.eval()
    noise = input_noise(10, noise_dim=noise_dim).to(device)
    generated_images = generator(noise).detach().cpu()
    generated_images = (generated_images + 1) / 2
    plt.figure(figsize=(10, 1))
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        plt.imshow(generated_images[i].view(size, size).numpy(), cmap="gray")
        plt.axis("off")
    plt.show()

As usual for image data, we will split the data into batches of size `batch_size`. The training loop must be adapted to train the GAN on image data. 

<font color='blue'>**TODO:**</font> Complete the code below to implement the training loop for the GAN on MNIST. At each epoch, we iterate over the batches of the dataset and train the GAN.

In [None]:
def train_GAN_MNIST(
    generator,
    discriminator,
    optimizer_D,
    optimizer_G,
    noise_dim,
    n_epochs=20,
    num_steps=1,
):
    loss_D_values = []
    loss_G_values = []

    for epoch in range(n_epochs):
        avg_loss_D = 0
        avg_loss_G = 0

        # Complete here to train the GAN

        print(
            f"Epoch {epoch+1}/{n_epochs} Loss D: {avg_loss_D/len(loader_train):.4f} Loss G: {avg_loss_G/len(loader_train):.4f}"
        )
        if epoch % 5 == 0:
            plot_images(generator, noise_dim)

    return loss_D_values, loss_G_values
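
The body of one epoch can be sketched as follows. `run_epoch` and its `step_fn` argument are hypothetical names introduced here to keep the block self-contained. In the notebook, `step_fn` would wrap `training_step` with the models and optimizers, and `loader` would be `loader_train`. The noise is sized from each batch because the last batch of a `DataLoader` can be smaller than the others.

```python
import torch


def run_epoch(loader, step_fn, noise_dim, device="cpu"):
    # Hypothetical helper: one pass over the loader, returns the average losses
    total_loss_D, total_loss_G = 0.0, 0.0
    for images in loader:
        images = images.to(device)
        # One noise vector per image in the current batch
        noise = torch.randn(images.shape[0], noise_dim, device=device)
        loss_D, loss_G = step_fn(noise, images)
        total_loss_D += loss_D
        total_loss_G += loss_G
    return total_loss_D / len(loader), total_loss_G / len(loader)


# Stand-in loader with a smaller last batch and a dummy step function
fake_loader = [torch.zeros(4, 196), torch.zeros(4, 196), torch.zeros(2, 196)]
seen_shapes = []


def dummy_step(noise, images):
    seen_shapes.append((tuple(noise.shape), tuple(images.shape)))
    return 1.0, 2.0


avg_D, avg_G = run_epoch(fake_loader, dummy_step, noise_dim=100)
print(avg_D, avg_G, seen_shapes[-1])
```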

<font color='blue'>**TODO:**</font> Run the code below to train the GAN on MNIST.

In [None]:
noise_dim = 100
generator = DenseGenerator(noise_dim, out_dim=size * size).to(device)
discriminator = DenseDiscriminator(in_dim=size * size).to(device)

optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)

loss_D_values, loss_G_values = train_GAN_MNIST(
    generator,
    discriminator,
    optimizer_D,
    optimizer_G,
    noise_dim,
    n_epochs=100,
    num_steps=1,
)

In [None]:
visualize_loss(loss_D_values, loss_G_values)
plot_images(generator, noise_dim)