A generative adversarial network (GAN) is a deep neural architecture designed to generate fake/synthetic data items. A GAN has a clever architecture made of two neural networks: a generator that creates fake data items, and a discriminator that classifies a data item as fake (0) or real (1). GANs are most often used to generate synthetic images, but they can generate any kind of data.
Training a GAN is quite difficult. There are twice as many hyperparameters to deal with (number of hidden layers, number of nodes in each layer, activation function, batch size, optimization algorithm, learning rate, loss function, and so on) as there are for a regular neural network.
Knowing when to stop training a regular neural network is difficult and usually involves looking at the value of the loss function during training. I’ve been experimenting with GANs and wondered how to know when to stop training. Looking at the loss values of the generator and the discriminator won’t really work because the generator is constantly trying to create fake data to fool the discriminator, while at the same time the discriminator is learning how to tell fake data from real data.
I did a simple thought experiment and figured that if the discriminator can distinguish fake data items from real data items with only about 50% accuracy (no better than guessing), then the generator is doing a good job of creating fake data items.
So I coded up an experiment with a function that computes the prediction accuracy of the discriminator, based on n fake data items produced by the generator. I used PyTorch, my neural code library of choice, but the same ideas can be used in Keras or TensorFlow. In pseudo-code:
loop n times
  use generator to create a fake data item
  feed fake image to discriminator, get result p
  if p < 0.5 then
    n_correct += 1  # determined it was fake
  else
    n_wrong += 1    # thought it was real
end-loop
return n_correct / (n_correct + n_wrong)
The code implementation is:
def Accuracy(gen, dis, n, verbose=False):
  # accuracy of discriminator on n fake images from generator
  gen.eval(); dis.eval()
  n_correct = 0; n_wrong = 0
  for i in range(n):
    zz = T.normal(0.0, 1.0,
      size=(1, gen.inpt_dim)).to(device)  # 20 values
    fake_image = gen(zz)   # one fake image
    pp = dis(fake_image)   # pseudo-prob of "real"
    if pp < 0.5:
      n_correct += 1  # discriminator knew it was fake
    else:
      n_wrong += 1    # dis thought it was a real image
    if verbose:
      print("")
      print(fake_image)
      print(pp)
      input()
  return (n_correct * 1.0) / (n_correct + n_wrong)
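To make the idea concrete without needing PyTorch, here is a minimal stand-alone sketch of the same accuracy logic using plain Python. The fake_generator and weak_discriminator functions are hypothetical stand-ins (not from my actual experiment) for gen(zz) and dis(fake_image):

```python
import random

def fake_generator(dim=20):
  # stand-in for gen(zz): returns a random "image" vector
  return [random.gauss(0.0, 1.0) for _ in range(dim)]

def weak_discriminator(image):
  # stand-in for dis(fake_image): returns a pseudo-probability
  # of "real"; here it just guesses randomly
  return random.random()

def accuracy(gen_fn, dis_fn, n):
  # fraction of n fake items the discriminator flags as fake (p < 0.5)
  n_correct = 0
  for _ in range(n):
    fake = gen_fn()
    p = dis_fn(fake)
    if p < 0.5:
      n_correct += 1
  return n_correct / n

random.seed(1)
acc = accuracy(fake_generator, weak_discriminator, 200)
print(acc)  # near 0.5 because this stand-in discriminator guesses randomly
```

Because the stand-in discriminator guesses at random, the computed accuracy hovers near 0.5, which is exactly the regime the heuristic says corresponds to a well-trained generator.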
I ran the experiment code on a GAN that creates synthetic ‘3’ digits based on the UCI Digits dataset. Each ‘3’ digit is a crude 8×8 grayscale image of a handwritten digit.
The results were quite satisfactory. The accuracy of the discriminator started out near 100% as expected because the generator hadn’t learned to make good fake images yet. As training continued, the accuracy of the discriminator slowly went down as the generator got better.
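Under this heuristic, a training loop could periodically check the discriminator's accuracy on fake items and stop once it drifts close to 50%. Here's a hedged sketch of that stopping rule; the accuracy_fn callback, check_every interval, and tol threshold are all hypothetical choices, not values from my actual experiment:

```python
def train_with_accuracy_stop(accuracy_fn, max_epochs=100,
                             check_every=10, tol=0.05):
  # stop when discriminator accuracy on fakes is within tol of 0.5
  for epoch in range(1, max_epochs + 1):
    # ... one epoch of GAN training would go here ...
    if epoch % check_every == 0:
      acc = accuracy_fn(epoch)
      if abs(acc - 0.5) <= tol:
        return epoch, acc  # generator fools discriminator often enough
  return max_epochs, accuracy_fn(max_epochs)

# simulated accuracy that decays from ~1.0 toward 0.5 as training proceeds
sim_acc = lambda epoch: 0.5 + 0.5 * (0.95 ** epoch)
stopped_at, final_acc = train_with_accuracy_stop(sim_acc)
print(stopped_at, final_acc)
```

With the simulated decay curve, the loop halts the first time the checked accuracy falls within 0.05 of 0.5, mirroring the behavior I observed: accuracy starts near 100% and slowly descends as the generator improves.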
Anyway, very interesting and good fun.
Animated films are synthetic versions of reality. I like several stop motion animation films, including these three. Left: “Coraline” (2009) is sort of a dark, modern day Alice in Wonderland. Great story, great animation. Center: “Isle of Dogs” (2018) is a fantastically creative story that’s difficult to describe. An amazing film. Right: “James and the Giant Peach” (1996) is an adaptation of a story from the ultra-inventive mind of author Roald Dahl (1916-1990).