PyTorch MNIST - GAN
Here I will discuss my attempt to implement a Generative Adversarial Network (GAN) trained on the MNIST handwritten digit dataset.
Generative Adversarial Networks, or GANs, are a fascinating part of machine learning that's all about creating something new from existing data. The idea involves two neural networks in a bit of a friendly competition. Think of it as a game where one network tries to create fake data, and the other judges whether it's real or not, learning from each interaction.
The whole setup works because the networks learn from each other. The generator network makes the fake data, and the discriminator network tries to catch the fakes. The generator's aim isn't to create a perfect match but to be good enough to trick the discriminator. It's a bit like a game of cat and mouse, where both sides are constantly evolving to get better than the other. At the beginning, both networks are untrained.
Here is the original paper: https://arxiv.org/abs/1406.2661
I will first go through the paper to get a fundamental understanding of the concept.
Abstract
There are two models: "a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G". Model G creates new data that looks like it came from the training data. Model D tries to tell which data are from the actual training set and which are from model G.
G is trained to make D believe that data from G are real data from the training set. D is trained to correctly discriminate between real data and fake data from G.
Both networks can be implemented as neural networks and trained using backpropagation.
Introduction
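The paper has an interesting analogy for the whole process: the generative model is like a team of counterfeiters trying to produce fake currency and use it without detection, while the discriminative model is like the police, trying to detect the counterfeit currency.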
Math
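The paper summarises the entire model as a minimax game with value function \( V(D, G) \), as shown in the following equation.
\[ \min_{G} \max_{D} V(D, G) = \mathbb{E}_{x\sim p_{\text{data}}(x)}\left[\log D(x)\right] + \mathbb{E}_{z\sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right] \]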
The expectation \( \mathbb{E}_{x\sim p_{\text{data}}(x)} \) over the real data distribution \( p_{\text{data}} \) assesses how well the discriminator recognizes real data \( x \), while the expectation \( \mathbb{E}_{z\sim p_z(z)} \) over the generator's noise distribution \( p_z \) measures the discriminator's ability to detect fake data produced by the generator \( G(z) \).
The \( \min_{G} \) part means the generator is trying to minimise this function—it wants to get better at fooling the discriminator. The \( \max_{D} \) part means the discriminator is trying to maximise the function—it wants to get better at catching the fakes.
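To make the two expectations concrete, here is a small Monte-Carlo sketch that estimates \( V(D, G) \) by averaging over samples. Everything in it is a toy assumption for illustration only: 1-D "real" data drawn from a normal distribution centred at 2, a hypothetical generator \( G(z) = 0.5z \), and a fixed hypothetical discriminator \( D(x) = \text{sigmoid}(x - 1) \).

```python
import math
import random


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


# Hypothetical toy models, chosen only to illustrate the formula
def D(x):
    # Discriminator: probability that x is real
    return sigmoid(x - 1.0)


def G(z):
    # Generator: maps noise z to a fake sample
    return 0.5 * z


def estimate_value(n=10_000, seed=0):
    rng = random.Random(seed)
    real_term = 0.0
    fake_term = 0.0
    for _ in range(n):
        x = rng.gauss(2.0, 1.0)  # x ~ p_data (assumed N(2, 1))
        z = rng.gauss(0.0, 1.0)  # z ~ p_z (standard normal noise)
        real_term += math.log(D(x))
        fake_term += math.log(1.0 - D(G(z)))
    # Sample means approximate the two expectations in V(D, G)
    return real_term / n + fake_term / n


print(estimate_value())
```

D can raise this value by pushing \( D(x) \) towards 1 on real samples and \( D(G(z)) \) towards 0 on fakes, while G lowers it by moving its samples to where D outputs high probabilities.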
The first term is where the discriminator evaluates real data and tries to assign a high probability (or score) to it being real.
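The second term is where the discriminator looks at fake data from the generator and tries to give it a low score. The paper notes that early in training, when G is still poor, D can reject its samples with high confidence, so \( \log(1 - D(G(z))) \) saturates and gives the generator very weak gradients. Therefore, instead of training G to minimise \( \log(1 - D(G(z))) \), we can train it to maximise \( \log(D(G(z))) \).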
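A quick way to see why the swap helps, as a sketch: write \( D(G(z)) = \text{sigmoid}(a) \), where \( a \) is the discriminator's pre-sigmoid score, and compare the gradient of each generator objective with respect to \( a \). It uses the standard identities \( \frac{d}{da}\log(1 - \text{sigmoid}(a)) = -\text{sigmoid}(a) \) and \( \frac{d}{da}\log \text{sigmoid}(a) = 1 - \text{sigmoid}(a) \); the score values in the loop are arbitrary choices.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def grad_saturating(a):
    # d/da of log(1 - sigmoid(a)): the term G minimises in the original game
    return -sigmoid(a)


def grad_non_saturating(a):
    # d/da of log(sigmoid(a)): the term G maximises in the modified version
    return 1.0 - sigmoid(a)


# Very negative scores mean D confidently rejects the fake (D(G(z)) close to 0),
# which is the typical situation early in training.
for a in (-8.0, -4.0, 0.0):
    print(f"D(G(z))={sigmoid(a):.4f}  "
          f"saturating grad={grad_saturating(a):.4f}  "
          f"non-saturating grad={grad_non_saturating(a):.4f}")
```

When \( D(G(z)) \) is close to 0, the first gradient is almost zero, so the generator barely learns, while the second stays close to 1.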
Pseudocode
Training the Discriminator
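According to the algorithm, the discriminator is updated by ascending its stochastic gradient over a minibatch of \( m \) noise samples \( z^{(i)} \) and \( m \) real examples \( x^{(i)} \):
\[ \nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m} \left[ \log D\left(x^{(i)}\right) + \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right) \right] \]
Now let's take a look at the Binary Cross Entropy (BCE) loss provided by the PyTorch library (https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html). For a prediction \( p_n \) and target \( y_n \), the loss for each element is
\[ l_n = -w_n \left[ y_n \log p_n + (1 - y_n) \log(1 - p_n) \right] \]
where \( w_n \) is an optional weight (1 by default), and with the default settings these values are averaged over the batch.
Here the loss has a negative sign in front. With target \( y = 1 \) for a real image the BCE reduces to \( -\log D(x) \), and with target \( y = 0 \) for a fake image it reduces to \( -\log(1 - D(G(z))) \). Together these are exactly the negative of the expression above, so ascending the original objective is the same as descending the BCE loss. Therefore we can use the BCE loss to train the discriminator.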
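The equivalence can be checked numerically. This sketch re-implements the element-wise BCE formula in plain Python (mean reduction, no weights) instead of calling torch, and the discriminator outputs are made-up numbers. With label 1 for real samples and label 0 for fakes, the averaged loss used later in the training loop equals minus half of the batch estimate of the discriminator's objective.

```python
import math


def bce(preds, targets):
    # Mean of -[y * log(p) + (1 - y) * log(1 - p)], like BCELoss's default 'mean' reduction
    total = 0.0
    for p, y in zip(preds, targets):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(preds)


# Hypothetical discriminator outputs for a batch of 3 real and 3 fake images
d_real = [0.9, 0.8, 0.7]  # D(x)
d_fake = [0.2, 0.1, 0.3]  # D(G(z))

# Loss as computed in the training loop: average of the real and fake BCE terms
loss_d = (bce(d_real, [1.0] * 3) + bce(d_fake, [0.0] * 3)) / 2

# Batch estimate of the objective the discriminator ascends
objective = sum(math.log(p) for p in d_real) / 3 + sum(math.log(1 - p) for p in d_fake) / 3

print(loss_d, -objective / 2)  # the two numbers agree
```

Because the loss is the objective multiplied by a negative constant, a gradient descent step on the loss is a gradient ascent step on the objective.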
Training the Generator
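According to the algorithm, the generator is updated by descending its stochastic gradient over a minibatch of \( m \) noise samples:
\[ \nabla_{\theta_g} \frac{1}{m} \sum_{i=1}^{m} \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right) \]
For this we will also be using the BCE loss, but this is not as straightforward as for the discriminator. Since the discriminator is distinguishing between two classes, using BCE loss for it is intuitive; the generator, however, is not a classifier. The trick is the modification mentioned earlier: we pass the fake images through the discriminator and use the target \( y = 1 \). The BCE then becomes \( -\log D(G(z)) \), so minimising it maximises \( \log D(G(z)) \) instead of minimising \( \log(1 - D(G(z))) \). I found this explanation for this issue: https://stats.stackexchange.com/questions/242907/why-use-binary-cross-entropy-for-generator-in-adversarial-networks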
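In the same spirit, here is a plain-Python sketch (again with a hand-written BCE and made-up discriminator outputs) of what the generator's loss becomes when the fakes are labelled 1: the \( (1 - y) \) term vanishes, so BCE reduces to the mean of \( -\log D(G(z)) \).

```python
import math


def bce(preds, targets):
    # Mean binary cross-entropy, same formula as torch.nn.BCELoss with default settings
    return sum(-(y * math.log(p) + (1 - y) * math.log(1 - p))
               for p, y in zip(preds, targets)) / len(preds)


# Hypothetical discriminator outputs on a batch of generated images
d_fake = [0.05, 0.2, 0.6]

# Generator loss as in the training loop: fakes paired with the "real" label 1
loss_g = bce(d_fake, [1.0] * len(d_fake))

# Non-saturating generator objective (to be maximised)
non_saturating = sum(math.log(p) for p in d_fake) / len(d_fake)

print(loss_g, -non_saturating)  # equal: minimising loss_g maximises log D(G(z))
```

The loss shrinks as the discriminator's outputs on fakes move towards 1, which is exactly the direction the generator wants.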
Code Walkthrough
First we need to choose the device to run our model on. If CUDA is available we run on CUDA; if your computer is an Apple silicon Mac you can use the Metal (MPS) backend; otherwise we just use the CPU.
import torch

# Use CUDA if available, then Apple's Metal (MPS) backend, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
The next step is loading the data.
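import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import maxabs_scale
from torch.utils.data import Dataset


class FeatureDataset(Dataset):
    def __init__(self, file_name):
        # Read the CSV (first row is the header) and convert it to a NumPy array
        data_csv = pd.read_csv(file_name, header=0)
        dataset = np.array(data_csv, dtype=float)
        # Columns 1..784 hold the pixel values; column 0 (the digit) is not needed
        x = dataset[:, 1:785]
        # Scale each image to [0, 1], then shift to [-1, 1] to match the generator's Tanh output
        x = maxabs_scale(x, axis=1) * 2 - 1
        x = torch.tensor(x, dtype=torch.float, device=device)
        # Every training image is real, so every label is 1
        y = torch.ones((x.shape[0], 1), dtype=torch.float, device=device)
        self.x_train = x
        self.y_train = y

    def __len__(self):
        return len(self.y_train)

    def __getitem__(self, idx):
        return self.x_train[idx], self.y_train[idx]


batch_size = 64
epochs = 50
learning_rate = 3e-4

feature_set = FeatureDataset('data/mnist_train.csv')
data_loader = torch.utils.data.DataLoader(feature_set, batch_size=batch_size, shuffle=True, drop_last=True)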
We've got this class called FeatureDataset, which is a custom dataset inheriting from PyTorch's Dataset class. It's tailored to load and prep data for our model.
Inside the __init__ function, it reads a CSV file containing our data using pandas. Then, it converts that data into a NumPy array.
Format of the dataset:
- column 0: target (integer from 0 to 9, the digit shown)
- columns 1 to 784 (the slice 1:785): inputs (integers from 0 to 255, the pixel intensities)
There are 784 inputs because the images are 28*28 pixels: the 2D pixel array is flattened into a single dimension in the dataset.
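So, we slice the array to grab all rows but only columns 1 to 784 (the second to the 785th column). This is our feature set. The first column holds the digit label; a GAN does not need it, so the code discards it and sets every label to 1, meaning "real".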
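To make the slicing concrete, here is a small NumPy sketch. Instead of reading the real CSV, it builds a hypothetical two-row array in the same 785-column layout (label first, then 784 pixels), slices out the pixels exactly as the dataset class does, and reshapes one row back into a 28 by 28 image.

```python
import numpy as np

# Hypothetical stand-in for the CSV contents: 2 rows, label + 784 pixel values each
rng = np.random.default_rng(0)
labels = np.array([[5], [0]], dtype=float)
pixels = rng.integers(0, 256, size=(2, 784)).astype(float)
dataset = np.hstack([labels, pixels])  # shape (2, 785)

x = dataset[:, 1:785]    # all rows, columns 1..784 -> the pixel features
y_digit = dataset[:, 0]  # column 0 -> the digit label (unused by the GAN)

image = x[0].reshape(28, 28)  # undo the flattening for one sample

print(dataset.shape, x.shape, image.shape)
```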
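The maxabs_scale function comes from sklearn.preprocessing and scales our features, a common preprocessing step that helps the model learn. With axis=1, it divides each image by its largest absolute pixel value. Because pixel values are never negative, this alone gives values from 0 to +1, whereas the generator's Tanh output ranges from -1 to +1, so the scaled values also need the shift x * 2 - 1 to end up between -1 and +1.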
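Here is a NumPy sketch of that range issue. It reproduces per-row max-abs scaling by hand (dividing each row by its largest absolute value, which is what maxabs_scale(x, axis=1) computes) rather than importing scikit-learn, and the pixel rows are made up. The scaled values land in [0, 1], and the extra x * 2 - 1 step maps them to [-1, 1].

```python
import numpy as np

# Hypothetical pixel rows (values 0..255), not real MNIST images
x = np.array([[0.0, 64.0, 128.0, 255.0],
              [0.0, 10.0, 200.0, 100.0]])

# Per-row max-abs scaling: divide each row by its largest absolute value
max_abs = np.abs(x).max(axis=1, keepdims=True)
max_abs[max_abs == 0] = 1.0  # guard against all-zero rows
scaled = x / max_abs         # non-negative input -> values in [0, 1]

# Shift to [-1, 1] so real images share the generator's Tanh output range
shifted = scaled * 2 - 1

print(scaled.min(), scaled.max())
print(shifted.min(), shifted.max())
```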
After scaling, the features (x) are converted into a PyTorch tensor, and the labels (y) are created as a tensor of ones. We pass our device as a parameter when creating both tensors so that they live on the same device as the models.
Now, self.x_train and self.y_train are set up as the features and labels for training. The __len__ method just tells PyTorch how many examples we've got.
The __getitem__ method is PyTorch's way of grabbing an individual data point. Here, it's set up to return a single feature-label pair when given an index.
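A map-style Dataset subclass must implement __getitem__, and it should also implement __len__, because DataLoader's default samplers (including the shuffling one) rely on knowing the dataset size.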
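A minimal sketch of the two methods in action. It uses a hypothetical in-memory dataset of random tensors in place of the CSV (the class name TinyDataset is made up for illustration). Once __len__ and __getitem__ exist, len() and indexing work, and DataLoader can batch the samples.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    # Hypothetical stand-in for FeatureDataset: 10 random "images" labelled real (1)
    def __init__(self):
        self.x_train = torch.rand(10, 784) * 2 - 1
        self.y_train = torch.ones(10, 1)

    def __len__(self):
        # Number of samples; used by len() and by DataLoader's samplers
        return len(self.y_train)

    def __getitem__(self, idx):
        # One (features, label) pair
        return self.x_train[idx], self.y_train[idx]


ds = TinyDataset()
features, label = ds[3]
print(len(ds), features.shape, label.shape)

loader = DataLoader(ds, batch_size=4, shuffle=True)
for xb, yb in loader:
    print(xb.shape, yb.shape)  # two batches of 4, then one of 2
```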
Then, outside of the class, some straightforward variables are set for the batch size (how many examples we look at one time), epochs (how many times we'll run through the whole dataset), and the learning rate (a small step size to adjust the model's weights during training).
Finally, we create an instance of our FeatureDataset class by passing it the path to our CSV file, and wrap it in a DataLoader. DataLoader handles batching and shuffling for us, which makes the process much easier. We pass the batch size, shuffle=True and drop_last=True: the samples are reshuffled every epoch, and any samples left over that cannot fill a complete batch are dropped. (If the dataset size is 100 and the batch size is 30, we get three batches of 30 and the last 10 samples are dropped.)
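The batch arithmetic can be sketched in a few lines of plain Python. The helper batch_sizes is hypothetical and only mimics the counting; it does not call DataLoader. It reproduces the 100/30 example and, assuming the training CSV has the usual 60,000 rows, shows how many samples are dropped per epoch with our batch size of 64.

```python
def batch_sizes(n_samples, batch_size, drop_last):
    # Sizes of the batches a loader would yield for n_samples items
    full, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full
    if remainder and not drop_last:
        sizes.append(remainder)  # the last, partial batch is kept only if drop_last is False
    return sizes


print(batch_sizes(100, 30, drop_last=False))  # [30, 30, 30, 10]
print(batch_sizes(100, 30, drop_last=True))   # [30, 30, 30]

sizes = batch_sizes(60_000, 64, drop_last=True)
print(len(sizes), 60_000 - sum(sizes))        # 937 32
```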
The next step is to define our models.
import torch
from torch import nn


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(64, 512),  # 64-value noise vector in
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),  # one output per pixel
            nn.Tanh()  # squash pixel values into [-1, 1]
        )

    def forward(self, x):
        return self.linear_stack(x)
To create a model we extend the nn.Module class and implement the forward pass. We have a linear stack of layers. The first layer is linear with 64 inputs, because the generator's input is a vector of 64 random values drawn from a standard normal distribution (torch.randn). The hidden layers use the ReLU activation function. In the final layer we need an image as the output, so the generator produces data in the shape of our training data: 784 values. After scaling, our training data lies between -1 and +1, so we use the Tanh activation function in the last layer to also get values between -1 and +1.
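A quick shape check, as a sketch: it re-declares the generator from above so the snippet runs on its own, feeds it a batch of standard-normal noise, and reshapes the flat 784-value output into 28 by 28 images. The batch of 16 is an arbitrary choice.

```python
import torch
from torch import nn


class Generator(nn.Module):
    # Same architecture as the generator above
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(64, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, 28 * 28), nn.Tanh(),
        )

    def forward(self, x):
        return self.linear_stack(x)


gen = Generator()
noise = torch.randn((16, 64))  # 16 noise vectors of 64 values each
with torch.no_grad():
    fake = gen(noise)  # shape (16, 784), values in [-1, 1] because of Tanh
images = fake.view(16, 28, 28)  # undo the flattening for display
print(fake.shape, images.shape)
```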
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(784, 512),  # flattened 28x28 image in
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),  # a single score per image
            nn.Sigmoid()  # probability that the image is real
        )

    def forward(self, x):
        return self.linear_stack(x.float())
In the discriminator we take an image (784 values) as the input. Again we use the ReLU activation. In the final layer we just need a single node, which gives the probability that the input is real rather than fake. To keep the output between 0 and 1, the sigmoid activation is used.
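A matching sketch for the discriminator: it re-declares the class and applies it to a batch of random values in [-1, 1] standing in for images (an assumption for illustration). The output has one column per sample, and the sigmoid keeps every value between 0 and 1, which is what nn.BCELoss expects as its input.

```python
import torch
from torch import nn


class Discriminator(nn.Module):
    # Same architecture as the discriminator above
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(784, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.linear_stack(x.float())


disc = Discriminator()
batch = torch.rand(8, 784) * 2 - 1  # stand-in "images" with values in [-1, 1]
with torch.no_grad():
    probs = disc(batch)  # shape (8, 1): estimated probability that each input is real
print(probs.shape, probs.min().item(), probs.max().item())

# BCELoss accepts these probabilities directly, with targets of the same shape
loss = nn.BCELoss()(probs, torch.ones_like(probs))
print(loss.item())
```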
Here is the most important part: training the model.
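import os

# Assumed setup (not shown in the original post): models, loss and optimisers used by the loop.
# Adam is a guess; any torch.optim optimiser with learning_rate would fit here.
gen = Generator().to(device)
disc = Discriminator().to(device)
criterion = nn.BCELoss()
opt_gen = torch.optim.Adam(gen.parameters(), lr=learning_rate)
opt_disc = torch.optim.Adam(disc.parameters(), lr=learning_rate)
os.makedirs("models", exist_ok=True)

for epoch in range(epochs):
    print(f"epoch: {epoch}/{epochs}")
    for batch_idx, (real, _) in enumerate(data_loader):
        # Fake images from random noise
        noise = torch.randn((batch_size, 64), device=device)
        fake = gen(noise)

        ### Train Discriminator
        disc_real = disc(real)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real, device=device))
        # detach() so this update does not backpropagate into the generator
        disc_fake = disc(fake.detach())
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake, device=device))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator
        noise = torch.randn((batch_size, 64), device=device)
        fake = gen(noise)
        output = disc(fake)
        # Label the fakes as real (1): the non-saturating generator loss, -log D(G(z))
        lossG = criterion(output, torch.ones_like(output, device=device))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

    # Save both models at the end of each epoch
    torch.save(gen.state_dict(), f"models/generator_epoch_{epoch}.pt")
    torch.save(disc.state_dict(), f"models/discriminator_epoch_{epoch}.pt")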
We loop the training for the number of epochs. In each epoch, we loop through each batch of the dataset. For each batch:
- Generate a fake image
  - generate noise to use as the input for the generator
  - generate a fake image from the generator
- Train the discriminator
  - get the discriminator's output for the real images
  - use 1 as the desired label (all real images should be classified as real) and calculate the real loss
  - get the discriminator's output for the fake images, detached so that this step does not update the generator
  - use 0 as the desired label (all fake images should be classified as fake) and calculate the fake loss
  - calculate the final loss as the average of the real and fake losses
  - backpropagate and update the discriminator
- Train the generator
  - generate new noise
  - generate fake images from it
  - get the discriminator's output for these fake images
  - use 1 as the desired label (all images should be classified as real). Even though these images are fake, from the generator's perspective the goal is to fool the discriminator.
  - calculate the loss
  - backpropagate and update the generator
At the end of each epoch, both models are saved.
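Since a generator checkpoint is written every epoch, producing new digits later only needs the generator. The sketch below assumes the same Generator architecture as above. It saves a fresh (untrained) generator's state_dict to a temporary file, loads it into a new instance the way one would load, say, models/generator_epoch_49.pt, and samples a batch of images. The temporary path only stands in for a real checkpoint.

```python
import os
import tempfile

import torch
from torch import nn


class Generator(nn.Module):
    # Same architecture as the generator defined earlier
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(64, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, 28 * 28), nn.Tanh(),
        )

    def forward(self, x):
        return self.linear_stack(x)


trained = Generator()  # stands in for a generator after training
path = os.path.join(tempfile.mkdtemp(), "generator.pt")
torch.save(trained.state_dict(), path)

# Rebuild the model and load the saved weights (map_location lets a GPU checkpoint load on CPU)
gen = Generator()
gen.load_state_dict(torch.load(path, map_location="cpu"))
gen.eval()

noise = torch.randn(10, 64)
with torch.no_grad():
    images = gen(noise).view(10, 28, 28)        # values in [-1, 1]
    pixels = ((images + 1) / 2 * 255).round()   # back to the 0..255 pixel range

print(images.shape, pixels.min().item(), pixels.max().item())
```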
Output
Training data:
Output data: