Knowing When To Stop Training a Generative Adversarial Network (GAN)

A generative adversarial network (GAN) is a deep neural system that is designed to generate fake/synthetic data items. A GAN has a clever architecture made of two neural networks: a generator that creates fake data items, and a discriminator that classifies a data item as fake (0) or real (1). GANs are most often used to generate synthetic images, but GANs can generate any kind of data.
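Here is a rough PyTorch sketch of the two-network architecture. It is a minimal example sized for 8×8 images flattened to 64 values. The layer sizes, the activations, and the 20-value noise input (`inpt_dim`) are assumptions chosen for illustration. They are not the networks used in the experiment described below.

```python
import torch as T

class Generator(T.nn.Module):
  # maps a random noise vector to a fake 8x8 image (64 values)
  def __init__(self, inpt_dim=20):
    super().__init__()
    self.inpt_dim = inpt_dim  # size of the noise vector (assumed 20)
    self.net = T.nn.Sequential(
      T.nn.Linear(inpt_dim, 32), T.nn.LeakyReLU(0.2),
      T.nn.Linear(32, 64), T.nn.Sigmoid())  # pixel values in [0, 1]

  def forward(self, z):
    return self.net(z)

class Discriminator(T.nn.Module):
  # maps a 64-value image to a pseudo-probability that it is real
  def __init__(self):
    super().__init__()
    self.net = T.nn.Sequential(
      T.nn.Linear(64, 32), T.nn.LeakyReLU(0.2),
      T.nn.Linear(32, 1), T.nn.Sigmoid())  # near 0 = fake, near 1 = real

  def forward(self, x):
    return self.net(x)

gen = Generator()
dis = Discriminator()
z = T.normal(0.0, 1.0, size=(4, gen.inpt_dim))  # 4 noise vectors
fake = gen(z)  # shape (4, 64): four fake images
p = dis(fake)  # shape (4, 1): one pseudo-probability per image
```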

Training a GAN is quite difficult. Each of the two networks has its own set of hyperparameters (number of hidden layers, number of nodes in each layer, activation function, batch size, optimization algorithm, learning rate, loss function and so on), so there are about twice as many to deal with as there are for a regular neural network.

Knowing when to stop training a regular neural network is difficult and usually involves looking at the value of the loss function during training. I’ve been experimenting with GANs and wondered how to know when to stop training. Looking at the loss values of the generator and the discriminator won’t really work because the generator is constantly trying to create fake data to fool the discriminator, while at the same time the discriminator is learning how to tell fake data from real data.
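The sketch below shows why the two loss values are hard to interpret. It implements one common form of a GAN training step, using binary cross-entropy and Adam. It assumes toy one-layer networks and uses random numbers as a stand-in for a batch of real images, so it is not the training code from the experiment.

```python
import torch as T

T.manual_seed(0)
gen = T.nn.Sequential(T.nn.Linear(20, 64), T.nn.Sigmoid())  # toy generator
dis = T.nn.Sequential(T.nn.Linear(64, 1), T.nn.Sigmoid())   # toy discriminator
loss_fn = T.nn.BCELoss()
gen_opt = T.optim.Adam(gen.parameters(), lr=0.0002)
dis_opt = T.optim.Adam(dis.parameters(), lr=0.0002)

def train_step(real_batch):
  bs = real_batch.shape[0]
  ones = T.ones(bs, 1); zeros = T.zeros(bs, 1)

  # 1. discriminator step: push D(real) toward 1 and D(fake) toward 0
  dis_opt.zero_grad()
  fake = gen(T.normal(0.0, 1.0, size=(bs, 20))).detach()  # don't update gen here
  d_loss = loss_fn(dis(real_batch), ones) + loss_fn(dis(fake), zeros)
  d_loss.backward()
  dis_opt.step()

  # 2. generator step: push D(fake) toward 1, i.e., try to fool the discriminator
  gen_opt.zero_grad()
  fake = gen(T.normal(0.0, 1.0, size=(bs, 20)))
  g_loss = loss_fn(dis(fake), ones)
  g_loss.backward()
  gen_opt.step()
  return d_loss.item(), g_loss.item()

real = T.rand(8, 64)  # stand-in for a batch of 8 real images
d_loss, g_loss = train_step(real)
```

Each loss is measured against the other network, and that network is changing at the same time. A falling generator loss can therefore mean a better generator or a weaker discriminator, so neither number by itself says when to stop.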

I did a simple thought experiment and figured that if the discriminator can distinguish fake data items from real data items with only about 50% accuracy, no better than a coin flip, then the generator is doing a good job of creating fake data items.
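A small simulation in plain Python illustrates the thought experiment. It assumes, as an idealization, that against a poor generator the discriminator's outputs on fake items sit near 0, and that against a good generator they are effectively coin flips spread evenly over 0 to 1. The numbers are simulated and do not come from a real GAN.

```python
import random

def fake_accuracy(pseudo_probs):
  # fraction of fake items the discriminator correctly labels as fake (p < 0.5)
  n_correct = sum(1 for p in pseudo_probs if p < 0.5)
  return n_correct / len(pseudo_probs)

random.seed(1)
n = 10000
# poor generator: fakes are easy to spot, so outputs stay near 0
early = [random.uniform(0.0, 0.2) for _ in range(n)]
# good generator: discriminator can't tell, so outputs look like coin flips
late = [random.uniform(0.0, 1.0) for _ in range(n)]

print(fake_accuracy(early))  # 1.0
print(fake_accuracy(late))   # close to 0.5
```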

So I coded up an experiment with a function that computes the prediction accuracy of the discriminator, based on n fake data items produced by the generator. I used PyTorch, my neural code library of choice, but the same ideas can be used in Keras or TensorFlow. In pseudo-code:

loop n times
  use generator to create a fake data item
  feed fake item to discriminator, get result p
  if p < 0.5 then
    n_correct += 1  # determined it was fake
  else
    n_wrong += 1    # thought it was real
return n_correct / (n_correct + n_wrong)

The code implementation is:

import torch as T
device = T.device("cpu")  # or T.device("cuda") if a GPU is available

def Accuracy(gen, dis, n, verbose=False):
  # accuracy of discriminator on n fake images from generator
  n_correct = 0; n_wrong = 0

  for i in range(n):
    zz = T.normal(0.0, 1.0,
      size=(1, gen.inpt_dim)).to(device)  # 20 values
    fake_image = gen(zz)  # one fake image
    pp = dis(fake_image)  # pseudo-prob
    if pp < 0.5:
      n_correct += 1      # discriminator knew it was fake
    else:
      n_wrong += 1        # dis thought it was a real image

    if verbose == True:
      # report the pseudo-prob for each fake image
      print("fake image %d: pseudo-prob = %0.4f" % (i, pp.item()))

  return (n_correct * 1.0) / (n_correct + n_wrong)
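Scoring fake items one at a time is easy to follow but slow for large n. A batched variant could look like the sketch below. It assumes, as the Accuracy() function does, that the generator has an `inpt_dim` attribute. The name `accuracy_batched`, the `no_grad` context, and the train/eval mode handling are illustrative assumptions, not the code used in the experiment.

```python
import torch as T

def accuracy_batched(gen, dis, n, device="cpu"):
  # same metric as Accuracy(): fraction of n fake items the discriminator
  # labels fake (pseudo-prob < 0.5), but scored in a single forward pass
  modes = (gen.training, dis.training)
  gen.eval(); dis.eval()  # matters if the nets use dropout or batch norm
  with T.no_grad():       # evaluation only, so skip gradient tracking
    zz = T.normal(0.0, 1.0, size=(n, gen.inpt_dim)).to(device)
    pp = dis(gen(zz)).reshape(-1)  # n pseudo-probs
    n_correct = int((pp < 0.5).sum().item())
  gen.train(modes[0]); dis.train(modes[1])  # restore original modes
  return n_correct / n

# toy networks for illustration; inpt_dim is the attribute Accuracy() also expects
T.manual_seed(0)
gen = T.nn.Sequential(T.nn.Linear(20, 64), T.nn.Sigmoid())
gen.inpt_dim = 20
dis = T.nn.Sequential(T.nn.Linear(64, 1), T.nn.Sigmoid())
print(accuracy_batched(gen, dis, 1000))
```

Because it uses the same 0.5 threshold, the batched version computes the same metric as the loop version. It just generates and scores all n fake items in a single forward pass.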

I ran the experiment code on a GAN that creates synthetic ‘3’ digits based on the UCI Digits dataset. Each ‘3’ digit is a crude 8×8 grayscale image of a handwritten digit.
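Here is a sketch of how the '3' digits might be pulled out of the raw data. It assumes the comma-separated layout of the UCI optdigits files, where each row holds 64 pixel values in the range 0 to 16 followed by the class label. Scaling the pixels to 0.0 to 1.0 is also an assumption. The function name `load_threes` is hypothetical, and the two rows in the example are made up.

```python
def load_threes(lines):
  # keep only rows whose label is 3; scale pixels from 0..16 to 0.0..1.0
  images = []
  for line in lines:
    line = line.strip()
    if not line:
      continue  # skip blank lines
    vals = [int(v) for v in line.split(",")]
    pixels, label = vals[:64], vals[64]
    if label == 3:
      images.append([p / 16.0 for p in pixels])
  return images

# two made-up rows for illustration: a '3' and a '7'
row3 = ",".join(["16"] * 64) + ",3"
row7 = ",".join(["0"] * 64) + ",7"
threes = load_threes([row3, row7])
print(len(threes), len(threes[0]))  # 1 64
```

With a local copy of the training file (the name `optdigits.tra` is assumed here), the call would be `with open("optdigits.tra") as f: threes = load_threes(f)`.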

The results were quite satisfactory. The accuracy of the discriminator started out near 100% as expected because the generator hadn’t learned to make good fake images yet. As training continued, the accuracy of the discriminator slowly went down as the generator got better.
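One way to turn this observation into a stopping rule is sketched below. Training stops once the last few accuracy readings all fall within a tolerance of 50%. The function name `should_stop` and the values of `target`, `tol`, and `patience` are hypothetical choices, and the accuracy history is made-up illustrative data, not results from the experiment.

```python
def should_stop(acc_history, target=0.5, tol=0.05, patience=3):
  # stop once the last `patience` accuracy readings are all within tol of target
  if len(acc_history) < patience:
    return False
  return all(abs(a - target) <= tol for a in acc_history[-patience:])

# made-up accuracy readings, one per monitoring checkpoint
history = []
for acc in [0.99, 0.95, 0.81, 0.62, 0.54, 0.49, 0.52]:
  history.append(acc)
  if should_stop(history):
    print("stop after checkpoint", len(history))  # stop after checkpoint 7
    break
```

Requiring several consecutive readings near 50%, instead of just one, guards against stopping on a single lucky checkpoint while the two networks are still oscillating.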

Anyway, very interesting and good fun.

Animated films are synthetic versions of reality. I like several stop motion animation films, including these three. Left: “Coraline” (2009) is sort of a dark, modern-day Alice in Wonderland. Great story, great animation. Center: “Isle of Dogs” (2018) is a fantastically creative story that’s difficult to describe. An amazing film. Right: “James and the Giant Peach” (1996) is an adaptation of a story from the ultra-inventive mind of author Roald Dahl (1916-1990).

This entry was posted in PyTorch.