Generative Adversarial Networks (GANs) with MNIST Fashion Dataset

Gbemisola Akinola-Alli
Apr 1, 2023

Source: https://machinelearningmastery.com/what-are-generative-adversarial-networks-gans/

GANs are unsupervised deep learning models that learn the patterns in existing data and use them to generate new data. They are mainly used with images: the model automatically learns the structure of the input images and uses it to create new ones.

They can be used to create art, improve the resolution of images, generate images from text, and more.

GANs are also special because they need two models:

  • Generative model: As the name suggests, this model's job is to generate new samples that resemble the input data. It uses a vector drawn at random from a distribution (noise) as its seed.
  • Discriminative model: In supervised learning, discriminative models are known as classification models, and that is the role this model plays here. It receives both real samples and generated samples and classifies each one as either real or generated. In the long run, the aim is for the real and generated samples to become indistinguishable.

During the training process, the two models are updated by two separate optimisers.
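
For readers who want the formal version, the contest between the two models is usually written as the minimax objective from the original GAN paper (Goodfellow et al., 2014). This is the standard textbook formulation rather than anything specific to this article. The symbols are labels chosen here: G is the generator, D is the discriminator, x is a real sample and z is the random seed vector.

```latex
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The discriminator tries to make this value large by scoring real samples near 1 and generated samples near 0, while the generator tries to make it small by producing samples that the discriminator scores near 1.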

Using GANs with MNIST Fashion Dataset

In this article, we use the MNIST fashion dataset as input data to generate more images. Its training set contains 60,000 images of various clothing items, and it is one of the most popular datasets for computer vision.

1. Importing libraries

import numpy as np 
import torch.nn as nn
import torch
from torch.utils.data.dataloader import DataLoader
import torchvision.datasets as datasets
import torchvision
import matplotlib.pyplot as plt
import PIL.Image as Image
import torch.nn.functional as F
import pandas as pd
import torchvision.transforms as T
from torchvision.utils import make_grid

2. Preparing the dataset

Luckily, the MNIST fashion dataset is available through the torchvision.datasets module, along with many other popular datasets.


# Convert images to tensors and scale pixel values from [0, 1] to [-1, 1]
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])

data_mnist = datasets.FashionMNIST(root='.', train=True, transform=transforms, download=True)

batch_size = 32
train_loader = DataLoader(dataset=data_mnist, batch_size=batch_size, shuffle=True)

This is what the data looks like:

Images from the MNIST Fashion dataset

3. Creating the discriminator model

We create a simple model with linear layers. GANs can be far more complex and can be customised, but this discriminator has just 4 layers.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        # Flatten each 1x28x28 image into a vector of 784 values
        self.flat = nn.Flatten()
        self.dis_model = nn.Sequential(
            # 1st layer
            nn.Linear(784, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),

            # 2nd layer
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),

            # 3rd layer
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),

            # 4th layer: a single probability that the input is real
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.flat(x)
        out = self.dis_model(x)
        return out


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
discriminate = Discriminator().to(device)

4. Creating the Generator model

class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.gen_model = nn.Sequential(
            # 1st layer: takes a noise vector of 100 values
            nn.Linear(100, 256),
            nn.ReLU(),

            # 2nd layer
            nn.Linear(256, 512),
            nn.ReLU(),

            # 3rd layer
            nn.Linear(512, 1024),
            nn.ReLU(),

            # 4th layer: 784 values in [-1, 1], one per pixel
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        out = self.gen_model(x)
        # Reshape the flat output into 1x28x28 images
        out = out.view(x.size(0), 1, 28, 28)
        return out


generator = Generator().to(device)
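
Before training, it can help to check that the two models fit together. The sketch below uses deliberately tiny stand-in networks (one linear layer each, not the 4-layer models above) that keep the same input and output sizes, so the names toy_generator and toy_discriminator are illustrative only.

```python
import torch
import torch.nn as nn

latent_dim = 100  # size of the noise vector, matching the generator's first layer

# Stand-in generator: one linear layer, same input and output sizes as the real one
toy_generator = nn.Sequential(nn.Linear(latent_dim, 784), nn.Tanh())

# Stand-in discriminator: flattens 28x28 images and outputs one probability each
toy_discriminator = nn.Sequential(nn.Flatten(), nn.Linear(784, 1), nn.Sigmoid())

noise = torch.randn(8, latent_dim)

# Reshape the flat generator output into a batch of 1x28x28 images
fake_images = toy_generator(noise).view(-1, 1, 28, 28)

# One score per image: the estimated probability that it is real
scores = toy_discriminator(fake_images)
```

The same check should work with the full models by swapping in generator and discriminate and moving the noise to device.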

5. Setting model parameters

lr = 0.0001
epochs = 50
loss_function = nn.BCELoss()

# Each model gets its own optimiser
optim_gen = torch.optim.Adam(generator.parameters(), lr=lr)
optim_dis = torch.optim.Adam(discriminate.parameters(), lr=lr)
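
To show what the binary cross-entropy loss measures, here is a small sketch with made-up discriminator outputs. The three prediction values are invented for illustration; the manual formulas are the usual definition of binary cross-entropy.

```python
import torch
import torch.nn as nn

loss_function = nn.BCELoss()

# Made-up discriminator outputs for three samples (probability of "real")
predictions = torch.tensor([[0.9], [0.2], [0.6]])
real_labels = torch.ones((3, 1))
fake_labels = torch.zeros((3, 1))

# Against label 1 the loss averages -log(p); against label 0 it averages -log(1 - p)
loss_if_real = loss_function(predictions, real_labels)
loss_if_fake = loss_function(predictions, fake_labels)

# The same quantities written out by hand
manual_if_real = -torch.log(predictions).mean()
manual_if_fake = -torch.log(1 - predictions).mean()
```

This is why the generator is trained against labels of 1 for its own fake images: its loss shrinks as the discriminator's score for those images moves towards "real".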

6. Time to train!

for epoch in range(epochs):

    for n, (input_data, labels) in enumerate(train_loader):

        input_data = input_data.to(device)
        # Use the actual batch size, in case the last batch is smaller than batch_size
        current_batch = input_data.size(0)

        # Labels for the discriminator: binary 1 for real, 0 for fake
        input_labels = torch.ones((current_batch, 1)).to(device)
        fake_labels = torch.zeros((current_batch, 1)).to(device)

        # Create noise as the input data for the generator
        noise = torch.randn((current_batch, 100)).to(device)

        # Put noise into the generator; detach so this step only updates the discriminator
        generated_data = generator(noise).detach()

        # Combine real and fake samples and labels for training
        all_data = torch.cat((input_data, generated_data))
        all_labels = torch.cat((input_labels, fake_labels))

        # Training the discriminator
        discriminate.zero_grad()

        discriminate_output = discriminate(all_data)
        loss_discriminate = loss_function(discriminate_output, all_labels)

        loss_discriminate.backward()
        optim_dis.step()

        # Fresh noise for the generator
        noise = torch.randn((current_batch, 100)).to(device)

        # Training the generator
        generator.zero_grad()

        generated_output = generator(noise)
        dis_gen_output = discriminate(generated_output)
        # Real labels are used here so the generator is rewarded for fooling the discriminator
        loss_generate = loss_function(dis_gen_output, input_labels)

        loss_generate.backward()
        optim_gen.step()

        # Print the losses from the last batch of each epoch
        if n == len(train_loader) - 1:

            print(f'Epoch: {epoch+1} Loss Dis: {loss_discriminate.item():.4f}')
            print(f'Epoch: {epoch+1} Loss Gen: {loss_generate.item():.4f}')

Testing our model

After the model has been trained, only the generator model is really needed to create new images. The discriminator model can be repurposed as a classifier of some sort for further processing.
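
Since only the generator is needed afterwards, one option is to save its weights and reload them later. The sketch below is an assumption about how you might do that, not a step from the original walkthrough: it uses a small stand-in network in place of the trained generator, and the file name generator.pt is hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for the trained generator (the real one has four linear layers)
trained = nn.Sequential(nn.Linear(100, 784), nn.Tanh())

# Save only the weights; the file name is a hypothetical choice
path = os.path.join(tempfile.mkdtemp(), "generator.pt")
torch.save(trained.state_dict(), path)

# Later: rebuild the same architecture and load the saved weights into it
restored = nn.Sequential(nn.Linear(100, 784), nn.Tanh())
restored.load_state_dict(torch.load(path))
restored.eval()

# Both copies should produce identical output for the same noise
noise = torch.randn(4, 100)
with torch.no_grad():
    same = torch.allclose(trained(noise), restored(noise))
```

With the real model, the same pattern would apply to generator.state_dict() and a fresh Generator() instance.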

We test the trained model by putting some noise into the trained generator model:

noise = torch.randn((batch_size, 100)).to(device)
generated_data = generator(noise)

Et voilà! We get new images of clothes!

generated_data = generated_data.cpu().detach()

# Show the first 16 generated images in a 4x4 grid
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(generated_data[i].squeeze(), cmap="gray_r")
    plt.xticks([])
    plt.yticks([])

Generated Images with GAN model

My favourite thing about GANs is that when you don't have enough training data for another project, you can simply build a GAN to generate more data!

References:

https://machinelearningmastery.com/generative_adversarial_networks/
