Example Code for a Generative Adversarial Network (GAN) Using PyTorch

One weekend, I decided to implement a generative adversarial network (GAN) using the PyTorch library. The purpose of a GAN is to generate fake image data that is realistic looking. I used the well-known MNIST image dataset to train a GAN and then used the GAN to generate fake images.

Each MNIST image is a grayscale 28 x 28 (784 total pixels) picture of a handwritten digit from ‘0’ to ‘9’. I used the training set, which has 60,000 images. Each pixel value is an integer between 0 and 255, but it’s standard practice to scale all pixel values to between 0.0 and 1.0.
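
As a quick sanity check (a sketch, not part of the demo program), torchvision’s ToTensor() transform, which the demo also uses, does this rescaling automatically and converts each image to a 1 x 28 x 28 float tensor with values in [0.0, 1.0]:

# sanity-check sketch: ToTensor() rescales MNIST pixels from [0, 255] to [0.0, 1.0]
import torchvision as tv

train_ds = tv.datasets.MNIST(root="data", train=True,
  download=True, transform=tv.transforms.ToTensor())
img, lbl = train_ds[0]                     # first training image and its label
print(len(train_ds))                       # 60000
print(img.shape)                           # torch.Size([1, 28, 28])
print(img.min().item(), img.max().item())  # 0.0 1.0 for a typical digit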

Implementing the GAN was very difficult. I used about a dozen resources I found on the Internet. Even so, the demo code took me well over 40 hours to get to a point where the code was running more or less correctly.

The GAN has a Generator and a Discriminator. The demo Generator accepts 100 random Gaussian distributed values (with mean 0 and standard deviation 1, so most values are between -3.0 and +3.0) and emits 784 values between 0 and 1 — an image. The Discriminator accepts 784 values between 0 and 1 and emits a value between 0 and 1 where values less than 0.5 indicate a fake image and values greater than 0.5 indicate a real image from the MNIST training dataset.
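
In code, using the Generator and Discriminator classes from the full listing below, a quick shape check looks roughly like this (a sketch; the batch size of 4 is arbitrary):

# shape-check sketch; Generator and Discriminator are defined in the listing below
import torch as T

gen = Generator()      # 100-32-64-128-784
dis = Discriminator()  # 784-128-64-32-1
zz = T.normal(0.0, 1.0, size=(4, 100))  # 4 random Gaussian latent vectors
fake = gen(zz)                          # shape [4, 784], values in (0, 1)
p = dis(fake)                           # shape [4, 1], values in (0, 1)
print(fake.shape, p.shape)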


Left: 16 fake images created after 40 training epochs. Right: 16 fake images after 100 training epochs. It’s not easy to know when to stop training a GAN.

The Generator tries to generate images that fool the Discriminator, meaning the Discriminator outputs a value very close to 0.5 when presented with a fake image created by the Generator. At the same time, the Generator is updated using loss/error information from the Discriminator so that the generated images become harder for the Discriminator to classify correctly. The idea is conceptually difficult, and the implementation is technically difficult.
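
The key implementation trick is the flipped-label Generator update that appears in step 2 of the training loop below. In isolation it looks something like this (a sketch; gen, dis and loss_func are the objects from the full listing):

# flipped-label sketch: the generator's loss is the BCE loss computed as if
# the fake images were labeled real (1), so minimizing it pushes the
# generator toward images the discriminator scores as real
import torch as T

zz = T.normal(0.0, 1.0, size=(64, 100))   # random latent vectors
fake_images = gen(zz)                     # [64, 784]
p_fake = dis(fake_images).squeeze()       # discriminator scores, [64]
gen_loss = loss_func(p_fake, T.ones(64))  # labels flipped to all-ones
gen_loss.backward()                       # gradients flow back through dis into gen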

After I got my demo running, I realized that because a GAN is essentially an unsupervised technique, looking at the loss values doesn’t tell you anything about when to stop training. To gain insight about when to stop, you’d have to look at the distribution of the Discriminator’s output values: when almost all of them are very close to 0.5, the Discriminator is being fooled by the Generator. I didn’t implement that functionality in my demo because it would mean another 4-8 hours of coding and experimentation.
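
A rough sketch of that monitoring idea, assuming the gen, dis and device objects from the listing below (the sample size of 256 and the 0.05 band around 0.5 are arbitrary choices):

# sketch: fraction of discriminator outputs on fake images that fall within
# a small band around 0.5; values near 1.0 suggest the discriminator is
# being thoroughly fooled by the generator
import torch as T

def frac_near_half(gen, dis, n=256, band=0.05):
  gen.eval(); dis.eval()
  with T.no_grad():
    zz = T.normal(0.0, 1.0, size=(n, 100)).to(device)
    p = dis(gen(zz)).squeeze()   # discriminator scores on n fake images
  gen.train(); dis.train()
  return T.mean(((p - 0.5).abs() < band).float()).item()

# could be called every ep_log_interval epochs during training:
# print("frac near 0.5 = %0.3f" % frac_near_half(gen, dis))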

I have looked at generative adversarial networks before and didn’t have much use for them. The main reason I was re-looking at GANs was that I read a very interesting research paper that proposed using GANs for time series anomaly detection. Briefly, an Autoencoder-based architecture for anomaly detection is too accurate at creating fakes, so too many anomalies are missed, while a GAN is less accurate at creating fakes, so you get too many anomaly false alarms. The research paper I read proposed combining an Autoencoder with a GAN.



Three famous pairs of movie adversaries where the villain is a PhD doctor.

Left: James Bond and Dr. Julius No. In this scene from “Dr. No” (1962), Bond has been captured and is in Dr. No’s underwater lair. This movie launched the Bond franchise.

Center: Sir Denis Nayland Smith and Dr. Fu Manchu. This is a scene from “The Mask of Fu Manchu” (1932) where Scotland Yard inspector Smith has been captured by the evil doctor and his daughter.

Right: Sherlock Holmes and Dr. James Moriarty. This is a scene from “Sherlock Holmes: A Game of Shadows” (2011) where Holmes is in Moriarty’s office. Unknown to Moriarty (and the movie audience), Holmes notices one of Moriarty’s books which is eventually the key to defeating the evil doctor.


# gan_mnist.py

# PyTorch 1.7.0-CPU Anaconda3-2020.02  Python 3.7.6
# Windows 10 
# kludged together using many examples found online

import numpy as np
import torch as T
import torchvision as tv
import time
import matplotlib.pyplot as plt

device = T.device("cpu")  # apply to Tensor or Module

# -----------------------------------------------------------

class Discriminator(T.nn.Module):  # 784-128-64-32-1
  def __init__(self):
    super(Discriminator, self).__init__()
    self.fc1 = T.nn.Linear(784, 128)
    self.fc2 = T.nn.Linear(128, 64)
    self.fc3 = T.nn.Linear(64, 32)
    self.fc4 = T.nn.Linear(32, 1)

    T.nn.init.xavier_uniform_(self.fc1.weight)
    T.nn.init.zeros_(self.fc1.bias)
    T.nn.init.xavier_uniform_(self.fc2.weight)
    T.nn.init.zeros_(self.fc2.bias)
    T.nn.init.xavier_uniform_(self.fc3.weight)
    T.nn.init.zeros_(self.fc3.bias)
    T.nn.init.xavier_uniform_(self.fc4.weight)
    T.nn.init.zeros_(self.fc4.bias)
        
  def forward(self, x):
    x = x.view(-1, 28*28)   # flatten image(s)
    z = T.tanh(self.fc1(x))
    z = T.tanh(self.fc2(z))
    z = T.tanh(self.fc3(z))
    oupt = T.sigmoid(self.fc4(z))  # BCELoss
    return oupt

# -----------------------------------------------------------

class Generator(T.nn.Module):  # 100-32-64-128-784
  def __init__(self): 
    super(Generator, self).__init__()
    self.fc1 = T.nn.Linear(100, 32)
    self.fc2 = T.nn.Linear(32, 64)
    self.fc3 = T.nn.Linear(64, 128)
    self.fc4 = T.nn.Linear(128, 784)

    T.nn.init.xavier_uniform_(self.fc1.weight)
    T.nn.init.zeros_(self.fc1.bias)
    T.nn.init.xavier_uniform_(self.fc2.weight)
    T.nn.init.zeros_(self.fc2.bias)
    T.nn.init.xavier_uniform_(self.fc3.weight)
    T.nn.init.zeros_(self.fc3.bias)
    T.nn.init.xavier_uniform_(self.fc4.weight)
    T.nn.init.zeros_(self.fc4.bias)

  def forward(self, x):
    z = T.tanh(self.fc1(x))
    z = T.tanh(self.fc2(z))
    z = T.tanh(self.fc3(z))
    oupt = T.sigmoid(self.fc4(z))  # consider no activation
    return oupt

# -----------------------------------------------------------

def view_images(images_list, idx):
  # display the 16 flattened images in images_list[idx] as a 4x4 grid
  fig, axes = plt.subplots(figsize=(7,7), nrows=4,
    ncols=4, sharey=True, sharex=True)
  for ax, img in zip(axes.flatten(), images_list[idx]):
    img = img.detach()
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.imshow(img.reshape((28,28)), cmap='gray_r')
  plt.show()

# -----------------------------------------------------------

def main():
  # 0. get started
  print("\nBegin GAN MNIST demo ")
  np.random.seed(1)
  T.manual_seed(1)
  
  # 1. create MNIST DataLoader object
  print("\nCreating MNIST Dataset and DataLoader ")

  bat_size = 64
  # 60,000 train images / 64 = 937 batches of size 64 plus a final
  # partial batch of size 32; drop_last=True below discards that
  # final partial batch

  trfrm = tv.transforms.ToTensor()   # also divides by 255
  train_ds = tv.datasets.MNIST(root="data", train=True,
    download=True, transform=trfrm)
  train_ldr = T.utils.data.DataLoader(train_ds,
    batch_size=bat_size, shuffle=True, drop_last=True)   

  # 2. create networks
  dis = Discriminator().to(device)  # 784-128-64-32-1
  gen = Generator().to(device)      # 100-32-64-128-784

  # 3. train GAN model
  max_epochs = 100
  ep_log_interval = 10
  lrn_rate = 0.002       # small for Adam

  dis.train()            # set mode
  gen.train()
  dis_optimizer = T.optim.Adam(dis.parameters(), lrn_rate)
  gen_optimizer = T.optim.Adam(gen.parameters(), lrn_rate)
  loss_func = T.nn.BCELoss()
  all_ones = T.ones(bat_size, dtype=T.float32).to(device)
  all_zeros = T.zeros(bat_size, dtype=T.float32).to(device)

# -----------------------------------------------------------

  print("\nStarting training ")
  for epoch in range(0, max_epochs):
    dis_accum_loss = 0.0  # accumulated over the epoch, to display progress
    gen_accum_loss = 0.0
    for batch_idx, (real_images, _) in enumerate(train_ldr):

      # 1a. train discriminator (0/1) using real images
      dis_optimizer.zero_grad()
      dis_real_oupt = dis(real_images)  # [0, 1]
      dis_real_loss = loss_func(dis_real_oupt.squeeze(),
        all_ones)

      # 1b. train discriminator using fake images
      zz = T.normal(0.0, 1.0,
        size=(bat_size, 100)).to(device)
      fake_images = gen(zz)
      dis_fake_oupt = dis(fake_images)
      dis_fake_loss = loss_func(dis_fake_oupt.squeeze(),
        all_zeros)     
      dis_loss_tot = dis_real_loss + dis_fake_loss
      dis_accum_loss += dis_loss_tot.item()  # accumulate as a float

      dis_loss_tot.backward()
      dis_optimizer.step()

      # 2. train gen with fake images and flipped labels
      gen_optimizer.zero_grad()
      zz = T.normal(0.0, 1.0,
        size=(bat_size, 100)).to(device) 
      fake_images = gen(zz)
      dis_fake_oupt = dis(fake_images)
      gen_loss = loss_func(dis_fake_oupt.squeeze(), all_ones)
      gen_accum_loss += gen_loss.item()  # accumulate as a float

      gen_loss.backward()
      gen_optimizer.step()

    if epoch % ep_log_interval == 0 or epoch == max_epochs-1:
      print(" epoch: %4d | dis loss: %0.4f | gen loss: %0.4f "\
        % (epoch, dis_accum_loss, gen_accum_loss))
      dt = time.strftime("%Y_%m_%d-%H_%M_%S")
      fn = ".\\Models\\" + str(dt) + str("-") + "epoch_" + \
        str(epoch) + "_gan_mnist_model.pt"
      T.save(gen.state_dict(), fn)

  print("Training commplete ")

# -----------------------------------------------------------

  print("\nGenerating 16 images using trained generator ")
  num_images = 16
  zz = T.normal(0.0, 1.0, size=(num_images, 100)).to(device)
  gen.eval() 
  rand_images = gen(zz)
  images_list = [rand_images]
  view_images(images_list, 0)

  print("\nEnd GAN MNIST demo ")

if __name__ == "__main__":
  main()
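
Because the training loop saves only the Generator’s state_dict, a saved checkpoint can later be reloaded to generate images without retraining. A minimal sketch, reusing the Generator class and view_images() function from the listing above (the file name shown is hypothetical; use one of the names actually written by T.save() during training):

# sketch: reload a saved generator checkpoint and generate 16 fake images
import torch as T

gen = Generator()
gen.load_state_dict(T.load(
  ".\\Models\\2021_01_25-10_30_00-epoch_90_gan_mnist_model.pt"))  # hypothetical name
gen.eval()
with T.no_grad():
  zz = T.normal(0.0, 1.0, size=(16, 100))
  fake_images = gen(zz)          # [16, 784], values in (0, 1)
view_images([fake_images], 0)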

2 Responses to Example Code for a Generative Adversarial Network (GAN) Using PyTorch

  1. Peter Boos says:

    Interesting. If I’m correct, then the latest text generators (which used to be only LSTMs) are GAN + LSTM based as well now.

  2. Yes — and this is why I was re-looking at GANs.
