Pytorch MNIST - GAN


Here I will discuss my attempt to implement a Generative Adversarial Network trained on the MNIST handwritten digit dataset.

Generative Adversarial Networks, or GANs, are a fascinating part of machine learning that's all about creating something new from existing data. The idea involves two neural networks in a bit of a friendly competition. Think of it as a game where one network tries to create fake data, and the other judges whether it's real or not, learning from each interaction.

The whole setup works because the networks learn from each other. The generator network makes the fake data, and the discriminator network tries to catch the fakes. The generator's aim isn't to create a perfect match but to be good enough to trick the discriminator. It's a bit like a game of cat and mouse, where both sides constantly evolve to outdo the other. At the beginning, both networks are untrained.

Here is the original paper: https://arxiv.org/abs/1406.2661

I will first go through the paper to get a fundamental understanding of the concept.

Abstract

There are two models: "a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G". Model G creates new data that looks like it came from the training data. Model D tries to tell which samples are from the actual training set and which are from model G.

G is trained to make D think that data from G are real data from the training set. D is trained to correctly discriminate between real data and fake data from G.

Both networks can be implemented as Neural Networks and trained using Backpropagation.

Introduction

The paper has an interesting analogy about the whole process.

💡
The generative model can be thought of as analogous to a team of counterfeiters, trying to produce fake currency and use it without detection, while the discriminative model is analogous to the police, trying to detect the counterfeit currency. Competition in this game drives both teams to improve their methods until the counterfeits are indistinguishable from the genuine articles.

Math

The paper summarises the entire model as a min-max game, as shown in the following equation.
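Reproducing it from the paper:

\[
\min_{G} \max_{D} V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
\]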

The expectation \( \mathbb{E}_{x\sim p_{\text{data}}(x)} \) over the real data distribution \( p_{\text{data}} \) assesses how well the discriminator recognizes real data \( x \), while the expectation \( \mathbb{E}_{z\sim p_z(z)} \) over the generator's noise distribution \( p_z \) measures the discriminator's ability to detect fake data produced by the generator \( G(z) \).

The \( \min_{G} \) part means the generator is trying to minimise this function—it wants to get better at fooling the discriminator. The \( \max_{D} \) part means the discriminator is trying to maximise the function—it wants to get better at catching the fakes.

The first term is where the discriminator evaluates real data and tries to assign a high probability (or score) to it being real.

The second term is where the discriminator looks at fake data from the generator and tries to give it a low score. The paper states that, instead of minimising \( \log(1 - D(G(z))) \), we can maximise \( \log(D(G(z))) \) for the generator, because the former saturates and gives weak gradients early in training.

Pseudocode

Training the Discriminator

According to the algorithm, the discriminator is updated by ascending the stochastic gradient of the following expression (for a minibatch of \( m \) real samples \( x^{(i)} \) and \( m \) noise samples \( z^{(i)} \)).
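\[
\frac{1}{m} \sum_{i=1}^{m} \Big[ \log D\big(x^{(i)}\big) + \log\Big(1 - D\big(G\big(z^{(i)}\big)\big)\Big) \Big]
\]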

Now let's take a look at the Binary Cross Entropy loss provided by the PyTorch library (https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html). In its unweighted form, the per-sample loss is:
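\[
\ell_n = -\big[\, y_n \log x_n + (1 - y_n) \log(1 - x_n) \,\big]
\]

where \( x_n \) is the predicted probability and \( y_n \) is the target label (1 for real, 0 for fake).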

Here the loss has a negative sign in front. Ascending the original expression is therefore the same as descending the BCE loss, using a target of 1 for real samples and 0 for fake samples. So we can use the BCE loss to train the discriminator.

Training the Generator

According to the algorithm, the generator is updated by descending the stochastic gradient of the following expression.
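\[
\frac{1}{m} \sum_{i=1}^{m} \log\Big(1 - D\big(G\big(z^{(i)}\big)\big)\Big)
\]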

For this we will also use the BCE loss, but it is not as straightforward as for the discriminator. Since the discriminator is deciding between two classes, using BCE loss there is intuitive. I found this explanation for the generator case: https://stats.stackexchange.com/questions/242907/why-use-binary-cross-entropy-for-generator-in-adversarial-networks
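In short (my summary of that answer): feeding the discriminator's output on fake data into BCE with a target of 1 reduces the loss to \( -\log D(G(z)) \), so minimising it is exactly the non-saturating objective of maximising \( \log D(G(z)) \) mentioned earlier.

\[
\text{BCE}\big(D(G(z)),\, 1\big) = -\big[\, 1 \cdot \log D(G(z)) + 0 \cdot \log\big(1 - D(G(z))\big) \,\big] = -\log D(G(z))
\]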

Code Walkthrough

First we need to decide which device to run our model on. If CUDA is available we can run on the GPU; if your computer is an Apple silicon Mac you can use the Metal (MPS) backend; otherwise we just use the CPU.

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

Next step is loading data.

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import maxabs_scale


class FeatureDataset(Dataset):
    def __init__(self, file_name):
        # Column 0 holds the label, columns 1-784 hold the flattened 28x28 pixel values
        data_csv = pd.read_csv(file_name, header=0)
        dataset = np.array(data_csv, dtype=float)

        x = dataset[:, 1:785]
        x = maxabs_scale(x, axis=1)  # scale each row by its maximum absolute value

        x = torch.tensor(x, dtype=torch.float, device=device)

        # Every sample is a real image, so the discriminator target is always 1
        y = torch.ones((x.shape[0], 1), dtype=torch.float, device=device)

        self.x_train = x
        self.y_train = y

    def __len__(self):
        return len(self.y_train)

    def __getitem__(self, idx):
        return self.x_train[idx], self.y_train[idx]


batch_size = 64
epochs = 50
learning_rate = 3e-4

feature_set = FeatureDataset('data/mnist_train.csv')
data_loader = DataLoader(feature_set, batch_size=batch_size, shuffle=True, drop_last=True)

We've got this class called FeatureDataset which is a custom dataset inheriting from PyTorch's Dataset class. It's tailored to load and prep data for our model.

Inside the __init__ function, it reads a CSV file containing our data using pandas and then converts that data into a NumPy array.

Format of the dataset

  • column 0 - Target (integer from 0 to 9)
  • columns 1 to 784 - Inputs (integers from 0 to 255)

There are 784 inputs because the images are 28x28 pixels. The 2D pixel array is flattened into a single dimension in the input dataset.

So we slice the array to grab all rows but only columns 1 through 784; this is our feature set. Column 0 holds the labels.

The maxabs_scale function comes from sklearn.preprocessing and scales our features by dividing each row by its maximum absolute value. This is a common prep step to help our model learn better. After scaling, the feature values lie in the [-1, +1] range (for MNIST, whose pixel values are non-negative, they actually end up between 0 and +1).
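Here is a tiny example of what maxabs_scale does to a single row of pixel values (a toy illustration, not part of the original code):

import numpy as np
from sklearn.preprocessing import maxabs_scale

row = np.array([[0.0, 127.5, 255.0]])   # one flattened row of pixel values
print(maxabs_scale(row, axis=1))        # [[0.  0.5 1. ]] -- each row is divided by its max absolute value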

After scaling, the features (x) and labels (y) are converted into PyTorch tensors. We pass our device as a parameter when creating all tensors to keep them on the same device.

Now, self.x_train and self.y_train are set up as the features and labels for training. The __len__ method just tells PyTorch how many examples we've got.

The __getitem__ method is PyTorch's way of grabbing an individual data point. Here, it's set up to return a single feature-label pair when given an index.

The __len__ and __getitem__ methods are mandatory to implement when subclassing the Dataset class.

Then, outside of the class, some straightforward variables are set for the batch size (how many examples we look at one time), epochs (how many times we'll run through the whole dataset), and the learning rate (a small step size to adjust the model's weights during training).

Finally, we create an instance of our FeatureDataset class by feeding it the path to our CSV file and wrap it with a DataLoader. DataLoader handles all the data loading tasks, making the process much easier for us. We pass the batch size, set shuffle to true, and set drop_last to true. Samples from our dataset will then be shuffled, and any samples left over after forming full batches will be ignored (if the dataset size is 100 and the batch size is 30, the last 10 samples will be dropped).
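As a quick sanity check (my own snippet, not from the original post), we can pull one batch from the loader and look at its shape:

x_batch, y_batch = next(iter(data_loader))
print(x_batch.shape)  # torch.Size([64, 784]) -- a batch of 64 flattened images
print(y_batch.shape)  # torch.Size([64, 1])   -- the corresponding "real" labels (all ones)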

The next step is to define our models.

class Generator(nn.Module):

    def __init__(self):
        super(Generator, self).__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(64, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh()
        )

    def forward(self, x):
        logits = self.linear_stack(x)
        return logits

To create a model we need to extend the nn.Module class and implement the forward pass. We have a linear stack of layers. The first layer is linear with 64 inputs, because we will use a random noise vector of 64 values as the input for the generator, followed by the ReLU activation function. In the final layer we need an image as the output, so data in the shape of our training samples is produced: 784 values. After scaling, our training data values were between -1 and +1, so we use the Tanh activation function in the last layer to also get outputs between -1 and +1.
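As a quick shape check (again my own snippet, not from the original post), we can push a single noise vector through an untrained generator:

gen = Generator().to(device)

noise = torch.randn(1, 64, device=device)   # one 64-dimensional noise vector
fake = gen(noise)                           # shape: (1, 784), values in (-1, 1) because of Tanh
image = fake.view(28, 28)                   # reshape back into a 28x28 image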

class Discriminator(nn.Module):

    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        logits = self.linear_stack(x.float())
        return logits

In the discriminator we take an image as the input. Again we use the ReLU activation. In the final layer we just need a single node, which gives the probability of the input being real. To get an output between 0 and 1, the sigmoid activation is used.
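Similarly (my own snippet), the discriminator maps any flattened 784-pixel image to a single probability:

disc = Discriminator().to(device)

some_image = torch.randn(1, 784, device=device)   # stand-in for a flattened 28x28 image
prob_real = disc(some_image)                      # shape: (1, 1), value between 0 and 1
print(prob_real.item())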

Here is the most important part: training the model.
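The loop below uses the models, the loss, and two optimizers, which are not instantiated in the snippet itself. A minimal setup could look like the following; the use of Adam and BCELoss here is my assumption based on the discussion above, since the original post only shows the loop.

gen = Generator().to(device)
disc = Discriminator().to(device)

criterion = nn.BCELoss()  # binary cross entropy, as discussed in the loss section
opt_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
opt_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)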

for epoch in range(epochs):
    print(f"epoch: {epoch}/{epochs}")
    for batch_idx, (real, _) in enumerate(data_loader):
        noise = torch.randn((batch_size, 64), device=device)
        fake = gen(noise)

        ### Train Discriminator
        disc_real = disc(real)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real, device=device))

        disc_fake = disc(fake.detach())
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake, device=device))

        lossD = (lossD_real + lossD_fake) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator
        noise = torch.randn((batch_size, 64), device=device)
        fake = gen(noise)
        output = disc(fake)
        lossG = criterion(output, torch.ones_like(output, device=device))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

    torch.save(gen.state_dict(), f"models/generator_epoch_{epoch}.pt")
    torch.save(disc.state_dict(), f"models/discriminator_epoch_{epoch}.pt")

We loop the training for the number of epochs. In each epoch, we loop through each batch of the dataset. For each batch, we:

  • Generate a fake image
    • generate a noise to use as the input for the generative model
    • generate a fake image from the generator
  • Train the discriminator
    • get the output for the real image from the discriminator
    • we use 1 as the desired label. (All images should be classified as real)
    • calculate real loss
    • get the output for the fake image from the discriminator
    • we use 0 as the desired label. (All images should be classified as fake)
    • calculate fake loss
    • calculate the final loss as the average of fake loss and real loss
    • do backprop on the discriminator
  • Train the Generator
    • generate a noise
    • generate a fake image from that
    • we use 1 as the desired label from the discriminator. (All images should be classified as real.) Even though these images are fake, from the generator's perspective it is trying to fool the discriminator.
    • get the output from the discriminator
    • calculate the loss
    • do backprop on the generator

At the end of each epoch, both models are saved.
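To inspect the results, a saved generator checkpoint can be reloaded and sampled, for example like this (my own sketch; the checkpoint name and the 4x4 grid are arbitrary choices):

import matplotlib.pyplot as plt

gen = Generator().to(device)
gen.load_state_dict(torch.load("models/generator_epoch_49.pt", map_location=device))
gen.eval()

with torch.no_grad():
    noise = torch.randn(16, 64, device=device)
    images = gen(noise).view(-1, 28, 28).cpu()

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for img, ax in zip(images, axes.flatten()):
    ax.imshow(img, cmap="gray")
    ax.axis("off")
plt.show()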

Output

Training data:

Output data: