Generating cats with a WGAN
6/30/2019
Introduction
The following images contain one fake cat (generated by a neural network) in each row. The rest of the cats below come from the training data. Can you identify the fake cats? The answer is at the end of this article!
GANs
The six fake cats above were produced by a generative adversarial network, or GAN for short (code available on my GitHub). GANs, introduced by Goodfellow et al., consist of two networks that together learn to mimic the distribution of whatever data they are given (cat pictures, for example). The first network, called the generator, takes a random vector as input and attempts to produce an output that looks like it belongs to the dataset. The second network, called the discriminator, takes in a mix of “real” data from the dataset and “fake” data produced by the generator, and attempts to distinguish the real from the fake. The two networks compete, and in doing so each forces the other to keep improving.
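As a rough illustration of this competition, the standard GAN objectives from Goodfellow et al. can be sketched in plain Python. The function names are made up for this example, and the inputs stand in for discriminator outputs (probabilities that a sample is real). This is a sketch of the losses only, not the training code used for the cats.

```python
import math

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy: push D(x) toward 1 on real data and D(G(z)) toward 0 on fakes."""
    real_term = -sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = -sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term

def generator_loss(d_fake):
    """Non-saturating generator loss: push D(G(z)) toward 1."""
    return -sum(math.log(p) for p in d_fake) / len(d_fake)

# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
confused = discriminator_loss([0.5, 0.5], [0.5, 0.5])
# A discriminator that separates the two well has a lower loss.
confident = discriminator_loss([0.9, 0.9], [0.1, 0.1])
```

The generator's loss falls as the discriminator is fooled more often, which is what drives the two networks to improve together.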
WGANs
Training GANs is notoriously difficult, and a number of improvements have been suggested since they were first introduced. One of the more impactful is the Wasserstein GAN (or WGAN for short), introduced by Arjovsky et al. In a WGAN, the discriminator is replaced by a network called the critic. Rather than outputting the probability that the data is real, the critic outputs a real number that represents the quality of the image. The critic's goal is to output a high number when fed real data, and a low (negative) number when fed fake data. I won't go into the details of GANs or WGANs in this article, but these can be found in the linked papers.
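For readers who want something concrete, here is a minimal sketch of the WGAN losses in plain Python, where the inputs are lists of critic scores. The function names are hypothetical, and the sketch leaves out the constraint that the WGAN paper places on the critic (enforced there by weight clipping), so it shows the shape of the objective rather than a complete training step.

```python
def mean(xs):
    return sum(xs) / len(xs)

def critic_loss(real_scores, fake_scores):
    """Loss minimized by the critic: low when real scores are high and fake scores are low."""
    return mean(fake_scores) - mean(real_scores)

def wgan_generator_loss(fake_scores):
    """Loss minimized by the generator: low when the critic scores fakes highly."""
    return -mean(fake_scores)

# Critic scores for a batch of two real and two fake images (made-up numbers).
real_scores = [2.0, 4.0]
fake_scores = [-1.0, -3.0]
c_loss = critic_loss(real_scores, fake_scores)
g_loss = wgan_generator_loss(fake_scores)
```

Unlike the discriminator in a standard GAN, the critic's output is not squashed into a probability, so the scores can be any real number.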
Deep cats
I trained a WGAN to generate pictures of cats using the cats dataset, which consists of over 9000 images of cats (!) of various pixel resolutions, with facial features labeled. A preprocessing step (thanks to Jolicoeur-Martineau) is applied to crop the cats' faces and throw out all resulting images of size less than 128x128. The remaining images are then resized to 128x128. Unfortunately the processed dataset contains fewer than 9000 images of cats (!).
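The size filter in the preprocessing step can be sketched as follows. This assumes that "less than 128x128" means either side of the cropped face is shorter than 128 pixels; the actual preprocessing script may apply a different rule, and the function names here are made up for illustration.

```python
MIN_SIZE = 128  # target resolution from the article

def keep_crop(width, height, min_size=MIN_SIZE):
    """Keep a cropped face only if both sides are at least min_size pixels (assumed rule)."""
    return width >= min_size and height >= min_size

def filter_crops(sizes, min_size=MIN_SIZE):
    """Return the (width, height) pairs that survive the size filter."""
    return [(w, h) for (w, h) in sizes if keep_crop(w, h, min_size)]

# Hypothetical crop sizes; the last two would be discarded.
crops = [(200, 180), (128, 128), (127, 300), (64, 64)]
kept = filter_crops(crops)
```

The surviving crops would then be resized down to 128x128 with an image library of choice.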
The network architecture and training method follow Karras et al. In short, the generator and discriminator are convolutional networks with two convolutional layers per image resolution, and are mirror images of one another, with the generator using convolution transpose layers where the discriminator uses convolutional layers. Training is performed first on images resized to a low resolution until convergence; the resolution is then doubled, and the process repeats until the desired resolution is reached. In particular, the networks start with 4x4 images, then move to 8x8, 16x16, 32x32, 64x64, and finally 128x128. Each time the resolution is doubled, the higher resolution images are faded in slowly using a residual block structure. For more precise details, see the reference paper, which provides a very clear explanation.
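The resolution schedule and the fade-in can be sketched in a few lines of plain Python. This is a toy version operating on single-channel images stored as nested lists, with nearest-neighbor upsampling assumed; the names are hypothetical and the blending weight `alpha` is assumed to ramp from 0 to 1 over the course of a fade-in phase, as described by Karras et al.

```python
def resolution_schedule(start=4, final=128):
    """Resolutions visited during progressive training, doubling each time."""
    res = [start]
    while res[-1] < final:
        res.append(res[-1] * 2)
    return res

def upsample_nearest(img):
    """Double the height and width of a 2D list by repeating each pixel."""
    out = []
    for row in img:
        wide = [p for p in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out

def fade_in(low_res, high_res, alpha):
    """Blend the upsampled low resolution image with the new high resolution one."""
    up = upsample_nearest(low_res)
    return [
        [(1.0 - alpha) * u + alpha * h for u, h in zip(up_row, high_row)]
        for up_row, high_row in zip(up, high_res)
    ]

schedule = resolution_schedule()
low = [[1.0]]                       # output of the already trained 1x1 "network"
high = [[0.0, 0.0], [0.0, 0.0]]     # output of the newly added 2x2 layers
start_of_fade = fade_in(low, high, alpha=0.0)
end_of_fade = fade_in(low, high, alpha=1.0)
```

At `alpha = 0` the output is just the upsampled low resolution image, and at `alpha = 1` the new layers have taken over completely.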
Results
The WGAN I trained was able to produce some very convincing images of cats. Here are some of the better images it produced.
Unfortunately, not all the images produced by the network are high quality. Some are decent images with flaws, while others are downright nonsense. Here are some examples of not-so-great output (some of these are pretty funny). While the majority of the output images are of good quality, it would be desirable to improve the average quality.
To check that the network is really producing novel images rather than just “memorizing” the training data, a common test is to take some of the generated images and find their nearest neighbors in the real data. The following images show cats generated by the network (left column), each followed along its row by its five nearest neighbors in the training data according to pixel-wise L2 distance. You might recognize these cats from the introduction, as they are the same images reordered. This also answers the question of which cats from the introduction are fake: the ones in the left column.
As you can see, all the generated cats look original. I should note here that pixel-wise distance is not the best method of comparison since, for example, reflecting an image horizontally can produce an image that is far away in pixel-wise distance but that a human would not consider fundamentally different. A better method is to take a pretrained neural network that performs well on a computer vision task such as image classification (ideally involving cats' faces) and compute nearest neighbors using the activations of various layers of that network. However, I stuck with pixel-wise distance in this analysis for simplicity.
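The pixel-wise nearest neighbor search is simple enough to sketch in plain Python. Images are represented here as flat lists of pixel values, and the function names are made up for this example; a real implementation would use an array library for speed.

```python
import math

def l2_distance(a, b):
    """Pixel-wise L2 distance between two images given as flat lists."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def nearest_neighbors(query, dataset, k=5):
    """Indices of the k images in dataset closest to query, nearest first."""
    order = sorted(range(len(dataset)), key=lambda i: l2_distance(query, dataset[i]))
    return order[:k]

# Tiny made-up "images" with two pixels each.
dataset = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [0.1, 0.0]]
neighbors = nearest_neighbors([0.0, 0.0], dataset, k=2)

# The weakness mentioned above: a one-row image and its mirror image
# show the same content but are far apart in pixel space.
image = [0.0, 0.0, 1.0, 1.0]
flipped = list(reversed(image))
flip_distance = l2_distance(image, flipped)
```

Here the mirrored image differs from the original in every pixel, even though a human would see the same picture reflected.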
Cat morphs
As a cool visualization, the generator can be used to produce “cat morphs”, which are created by choosing a collection of random input vectors, linearly interpolating between them, and then passing the interpolated vectors through the generator. Below is an example of the kind of morphs the network can create.
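The interpolation step can be sketched as follows. The latent dimension of 8, the Gaussian sampling, and the function names are all assumptions made for illustration, and the generator itself is not included; each vector in the resulting path would be fed to the trained generator to render one frame.

```python
import random

def random_latent(dim, rng):
    """Sample a random input vector (standard normal entries assumed)."""
    return [rng.gauss(0.0, 1.0) for _ in range(dim)]

def lerp(z0, z1, t):
    """Linear interpolation between two vectors, with t in [0, 1]."""
    return [(1.0 - t) * a + t * b for a, b in zip(z0, z1)]

def morph_path(keyframes, steps_per_segment):
    """Latent vectors that move in straight lines from one keyframe to the next."""
    path = []
    for z0, z1 in zip(keyframes, keyframes[1:]):
        for s in range(steps_per_segment):
            path.append(lerp(z0, z1, s / steps_per_segment))
    path.append(list(keyframes[-1]))
    return path

rng = random.Random(0)
keyframes = [random_latent(8, rng) for _ in range(3)]
path = morph_path(keyframes, steps_per_segment=10)
```

With three keyframes and ten steps per segment, the path contains 21 vectors, starting at the first keyframe and ending at the last.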