Introduction to Generative Adversarial Networks (GANs): Types, Applications, and Implementation

In this article, we’ll introduce the reader to Generative Adversarial Networks (GANs). We assume the reader has some prior experience with neural networks, such as artificial neural networks.

Here’s the plan of attack:

  1. Introduction to Generative Adversarial Networks
  2. The Discriminative Model
  3. The Generative Model
  4. How GANs Work
  5. Different types of GANs
  6. Applications of GANs
  7. Build a simple GAN in Keras

Introduction to Generative Adversarial Networks

GANs were introduced by Ian Goodfellow et al. in 2014. Yann LeCun described adversarial training as the coolest thing since sliced bread. GANs are neural networks that learn to generate synthetic data resembling the data they were trained on. For example, GANs can be taught to generate images from text descriptions. A GAN consists of two models: a generative model and a discriminative model.

The Discriminative Model

The discriminative model operates like a normal binary classifier: it sorts images into two categories, deciding whether an image is real (drawn from a given dataset) or artificially generated.

The Generative Model

A discriminative model predicts a class given certain features, i.e. it models the probability of a class given the features, p(y | x). A generative model works the other way around: it models the probability of the features given a class, p(x | y), which is what allows it to produce new data.
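
To make the distinction concrete, here is a minimal sketch on made-up one-dimensional data (the class means, spread, and sample sizes are arbitrary assumptions). The generative side fits a Gaussian p(x | y) to each class and samples new feature values for a chosen class; the discriminative question p(y | x) is then answered from those fits with Bayes’ rule. A GAN’s generator learns its distribution implicitly rather than through an explicit formula like this, so treat it as an illustration of the idea, not of GAN internals.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy labelled data (assumed for illustration): one feature, two classes.
x0 = rng.normal(loc=-2.0, scale=1.0, size=500)  # class 0
x1 = rng.normal(loc=3.0, scale=1.0, size=500)   # class 1

# Generative view: model p(x | y) as one Gaussian per class.
params = {0: (x0.mean(), x0.std()), 1: (x1.mean(), x1.std())}

def sample_features(label, n):
    """Generate n new feature values for the given class from p(x | y)."""
    mean, std = params[label]
    return rng.normal(mean, std, size=n)

def gaussian_pdf(x, mean, std):
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2 * np.pi))

def predict_class_prob(x):
    """Discriminative question p(y=1 | x), answered via Bayes' rule with equal priors."""
    like0 = gaussian_pdf(x, *params[0])
    like1 = gaussian_pdf(x, *params[1])
    return like1 / (like0 + like1)

print(sample_features(1, 5))                           # features given a class
print(predict_class_prob(np.array([-2.0, 0.5, 3.0])))  # class given features
```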

How GANs Work

A GAN has two players: a generator and a discriminator. A generator generates new instances of an object while the discriminator determines whether the new instance belongs to the actual dataset.

Let’s say you have a dataset containing images of shoes and would like to generate ‘fake’ shoes. The role of the generator is to generate the new shoes, while the discriminator’s goal is to identify images coming from the generator as fake.

During training, the weights and biases of both networks are adjusted through backpropagation. The discriminator learns to distinguish real images of shoes from fake ones, and the generator uses the discriminator’s feedback to produce images that look more ‘real’. Ideally, the generator eventually produces images the discriminator can no longer tell apart from real ones. In image GANs, the discriminator is typically a convolutional neural network that classifies images as real or fake, while the generator typically produces new images with a de-convolutional (transposed convolutional) network.
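
The original GAN paper frames this contest as a minimax game: the discriminator D maximizes log D(x) + log(1 - D(G(z))) over real images x and noise z, while the generator G tries to minimize it. In practice the generator is usually trained to maximize log D(G(z)) instead (the "non-saturating" loss), which is what training a combined model on the label 1 amounts to. Below is a NumPy sketch of both losses written as binary cross-entropy; the discriminator outputs are hypothetical numbers, not results from a trained model.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy, the loss used for both players below."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

def discriminator_loss(d_real, d_fake):
    """D should output 1 for real images and 0 for generated ones."""
    return (binary_crossentropy(np.ones_like(d_real), d_real)
            + binary_crossentropy(np.zeros_like(d_fake), d_fake))

def generator_loss(d_fake):
    """Non-saturating generator loss: G wants D to label its images as real (1)."""
    return binary_crossentropy(np.ones_like(d_fake), d_fake)

# Hypothetical discriminator outputs (probability of "real") for a small batch.
d_real = np.array([0.9, 0.8, 0.95])
d_fake = np.array([0.1, 0.2, 0.05])
print(discriminator_loss(d_real, d_fake))  # low: D is currently winning
print(generator_loss(d_fake))              # high: G has to improve
```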

Different types of GANs

Deep Convolutional GANs (DCGANs)

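DCGANs improve on the original GAN architecture by using convolutional layers in the discriminator and transposed convolutional layers in the generator. They train more stably and generate higher-quality images. In a DCGAN, batch normalization is used in both networks, i.e. the generator and the discriminator. Convolutional GANs like this are also the building blocks of style-transfer models; for example, a model trained on handbags and shoes can generate shoes in the same style as a given handbag (see DiscoGAN below).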
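
As a rough illustration of what batch normalization does inside those networks, here is a NumPy sketch of its training-mode computation on a (batch, features) array: each feature is standardized with the current batch’s mean and variance, then scaled by gamma and shifted by beta. The activations are random stand-ins, and real layers additionally track running averages for use at inference time, which this sketch omits.

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Training-mode batch normalization over a (batch, features) array.

    Each feature is standardized with the mean and variance of the current
    batch, then scaled by gamma and shifted by beta (both learned in practice).
    """
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
activations = rng.normal(loc=5.0, scale=3.0, size=(64, 4))  # hypothetical layer outputs
out = batch_norm(activations, gamma=np.ones(4), beta=np.zeros(4))
print(out.mean(axis=0).round(6), out.std(axis=0).round(3))
```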

Conditional GANs (cGANs)

Conditional GANs feed extra information, such as a class label, to both the generator and the discriminator. Exploiting this information improves image quality and lets you control what the generated images look like, for example by asking a cGAN trained on digits for a specific digit.
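
One common way to condition the generator, sketched below under stated assumptions, is to concatenate a one-hot encoding of the desired label to the random noise vector; the discriminator likewise receives the label alongside the image. The noise size of 100 and the ten classes are assumptions chosen to match MNIST, and other conditioning schemes (such as learned label embeddings) are also used.

```python
import numpy as np

random_dim = 100   # noise size (assumed)
num_classes = 10   # e.g. the ten MNIST digits

def conditioned_noise(labels, rng):
    """Build generator inputs: random noise concatenated with one-hot class labels."""
    noise = rng.normal(0, 1, size=(len(labels), random_dim))
    one_hot = np.eye(num_classes)[labels]           # shape (n, num_classes)
    return np.concatenate([noise, one_hot], axis=1)

rng = np.random.default_rng(0)
z = conditioned_noise(np.array([3, 7]), rng)       # ask for a "3" and a "7"
print(z.shape)  # (2, 110)
```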

Stacked GANs (StackGAN)

The authors of the StackGAN paper address a hard problem in computer vision: synthesizing high-quality images from text descriptions. They propose Stacked Generative Adversarial Networks (StackGAN) to generate 256×256 photo-realistic images conditioned on text descriptions, decomposing the hard problem into more manageable sub-problems through a sketch-refinement process.

The Stage-I GAN sketches the primitive shape and colors of the object based on the given text description, yielding Stage-I low-resolution images. The Stage-II GAN takes Stage-I results and text descriptions as inputs, and generates high-resolution images with photo-realistic details.

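InfoGAN

InfoGAN is an information-theoretic extension of the GAN that learns disentangled representations in an unsupervised manner. Part of the generator’s input is a latent code, and training maximizes the mutual information between that code and the generated images, so that each part of the code ends up controlling a distinct, interpretable feature. InfoGANs are useful when your dataset is very complex, when you’d like cGAN-style control but your dataset is not labelled, and when you’d like to discover the most important features of your images.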
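
The sketch below shows how an InfoGAN generator input might be assembled. The sizes are assumptions for illustration: 62 dimensions of ordinary noise, one 10-way categorical code (which one might hope captures digit identity on MNIST), and two continuous codes drawn uniformly from [-1, 1]. During training, an auxiliary network that shares layers with the discriminator tries to predict the returned `code` from the generated image; that prediction loss is what encourages each code to control a distinct factor, without any labels.

```python
import numpy as np

noise_dim = 62      # incompressible noise z (assumed size)
n_categories = 10   # one categorical code (assumed)
n_continuous = 2    # continuous codes (assumed)

def sample_infogan_input(batch_size, rng):
    """Return the generator input [z, c] and the code c the auxiliary network must recover."""
    z = rng.normal(0, 1, size=(batch_size, noise_dim))
    cat = np.eye(n_categories)[rng.integers(0, n_categories, size=batch_size)]
    cont = rng.uniform(-1, 1, size=(batch_size, n_continuous))
    code = np.concatenate([cat, cont], axis=1)
    return np.concatenate([z, code], axis=1), code

gen_input, code = sample_infogan_input(8, np.random.default_rng(0))
print(gen_input.shape)  # (8, 74)
```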

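Wasserstein GANs (WGANs)

WGANs replace the standard GAN loss with an estimate of the Wasserstein (earth mover’s) distance between the distribution of real images and the distribution of generated images. Unlike the standard GAN loss, the WGAN loss correlates with the quality of the generated images, which makes it useful for monitoring training.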
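
Below is a minimal NumPy sketch of the ideas involved. For intuition, it computes the exact Wasserstein-1 distance between two equal-sized one-dimensional samples (in 1-D, the optimal transport plan simply matches sorted values). It then shows the WGAN losses, where the discriminator is replaced by a "critic" that outputs unbounded scores instead of probabilities, and the weight clipping the original WGAN uses to keep the critic approximately Lipschitz. The clip value 0.01 is an assumption here; the scores passed in would come from a real critic network.

```python
import numpy as np

def wasserstein_1d(a, b):
    """Earth mover's distance between two equal-sized 1-D samples.

    In one dimension the optimal transport plan matches sorted values,
    so the distance is the mean absolute gap between the sorted samples.
    """
    return np.mean(np.abs(np.sort(a) - np.sort(b)))

def critic_loss(c_real, c_fake):
    """WGAN critic loss: minimizing it widens the score gap between real and fake."""
    return np.mean(c_fake) - np.mean(c_real)

def generator_loss(c_fake):
    """WGAN generator loss: raise the critic's score on generated samples."""
    return -np.mean(c_fake)

def clip_weights(weights, c=0.01):
    """Weight clipping from the original WGAN, keeping the critic roughly Lipschitz."""
    return [np.clip(w, -c, c) for w in weights]

print(wasserstein_1d(np.array([0.0, 1.0, 2.0]), np.array([5.0, 6.0, 7.0])))  # 5.0
```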

Discover Cross-Domain Relations with Generative Adversarial Networks (DiscoGANs)

The authors of the DiscoGAN paper propose a GAN-based method that learns to discover relations between different domains without needing paired examples. Using the discovered relations, the network transfers style from one domain to another while preserving key attributes such as orientation and face identity.
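
A key ingredient is a round-trip reconstruction loss: one generator maps domain A to domain B, a second maps B back to A, and the model is penalized when an input does not survive the trip A -> B -> A (and likewise B -> A -> B), alongside the usual adversarial loss in each domain. The sketch below assumes mean squared error as the distance and uses hypothetical linear maps as stand-ins for the two generator networks.

```python
import numpy as np

def reconstruction_loss(x, x_reconstructed):
    """Mean squared distance between inputs and their round-trip reconstructions."""
    return np.mean((x - x_reconstructed) ** 2)

# Hypothetical stand-ins for the two generators (real ones are neural networks).
W = np.array([[0.0, 1.0], [1.0, 0.0]])  # swaps the two features

def g_ab(x):
    """Map samples from domain A to domain B."""
    return x @ W

def g_ba(x):
    """Map samples from domain B back to domain A."""
    return x @ W.T

x_a = np.array([[1.0, 2.0], [3.0, 4.0]])
x_aba = g_ba(g_ab(x_a))                 # A -> B -> A round trip
print(reconstruction_loss(x_a, x_aba))  # 0.0: this pair of maps loses no information
```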

There are many more types of GANs, but we won’t be able to cover all of them in this article. Let’s move on to some practical applications of GANs.

Applications of GANs

  • Predicting the next frame in a video
  • Increasing the resolution of an image

GANs can take low-resolution images and produce high-resolution versions of them.

  • Text-to-Image Generation

Using a StackGAN, one can generate images from a text description. For example, a StackGAN can generate an image of a flying bird from a sentence describing this image and action.

  • Image to Image Translation

You can feed sketches of objects to a GAN and have it generate realistic images from them.
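
For the super-resolution application, one common way to obtain training pairs, assumed here rather than prescribed, is to downsample existing high-resolution images: the generator learns to map each low-resolution copy back to its original, while the discriminator judges whether the result looks real. This sketch uses 2×2 average pooling on random stand-in images; bicubic resizing is another frequent choice.

```python
import numpy as np

def downsample(images, factor=2):
    """Average-pool square blocks to make low-resolution copies of (n, h, w) images.

    Assumes h and w are divisible by factor.
    """
    n, h, w = images.shape
    blocks = images.reshape(n, h // factor, factor, w // factor, factor)
    return blocks.mean(axis=(2, 4))

high_res = np.random.default_rng(0).random((16, 28, 28))  # stand-in for real images
low_res = downsample(high_res)                            # generator inputs
print(low_res.shape)  # (16, 14, 14)
```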

Building a simple GAN in Keras

Now that we have a proper understanding of GANs, let’s get our hands dirty by coding a simple GAN in Keras that’s going to generate digits.

We kick it off by importing Keras for building the model and Matplotlib for plotting the digits.

import keras
import matplotlib.pyplot as plt

Next we import the classes and packages we’ll use to build our GAN:

  • Input, used to instantiate a Keras tensor
  • Model and Sequential, used to define the networks
  • Dense, for fully connected layers
  • Dropout, which randomly drops units during training to reduce overfitting
  • LeakyReLU, a leaky version of the Rectified Linear Unit
  • The MNIST dataset
  • The Adam optimizer
  • Initializers, which define how the initial random weights of Keras layers are set
  • tqdm, a package that lets us visualize the training process with a progress bar

# keras.layers.core and keras.layers.advanced_activations no longer exist in recent Keras versions
from keras.layers import Input, Dense, Dropout, LeakyReLU
from keras.models import Model, Sequential
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers
from tqdm import tqdm

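To get the same results on every run, we set a random seed for NumPy, which we use to sample noise and training batches. (Keras’ weight initialization has its own seed; in recent versions, keras.utils.set_random_seed sets both.)

import numpy as np
np.random.seed(10)  # any fixed integer works

Next we set the dimension of the random noise vector that the generator takes as input.

random_dim = 100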

We shall use the popular MNIST dataset which contains images of digits from 0–9. It has a training set of 60,000 examples and a test set of 10,000 examples.

This dataset ships with Keras and can be loaded with mnist.load_data().

We then normalize the inputs so that pixel values lie between -1 and 1, the same range as the generator’s tanh output. After this we flatten the training set from a three-dimensional array of shape (60000, 28, 28) into a two-dimensional array of shape (60000, 784).

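def load_minst_data():
  # Load MNIST; x_train has shape (60000, 28, 28) with pixel values 0-255
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  # Scale pixels from [0, 255] to [-1, 1]
  x_train = (x_train.astype(np.float32) - 127.5)/127.5
  # Flatten each 28x28 image into a 784-element vector
  x_train = x_train.reshape(60000, 784)
  return (x_train, y_train, x_test, y_test)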
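
To see the scaling and flattening arithmetic without downloading MNIST, here is a small sketch on a synthetic batch of two 28×28 images (the pixel values are made up). A pixel of 0 maps to -1.0, 255 maps to 1.0, and each image becomes one row of 784 values. Using reshape(len(x), 784), or equivalently reshape(-1, 784), avoids hard-coding the number of images.

```python
import numpy as np

# A synthetic "batch" of two 28x28 grayscale images with pixel values 0..255.
pixels = np.zeros((2, 28, 28), dtype=np.uint8)
pixels[0, 0, 0] = 255
pixels[1, 0, 0] = 128

scaled = (pixels.astype(np.float32) - 127.5) / 127.5  # 0 -> -1.0, 255 -> 1.0
flat = scaled.reshape(len(scaled), 784)               # (2, 28, 28) -> (2, 784)
print(scaled.min(), scaled.max(), flat.shape)
```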

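Now we create the generator and discriminator networks, using the Adam optimizer for both. The learning rate (lr in older Keras versions, learning_rate in current ones) sets the step size. beta_1, the exponential decay rate of Adam’s running average of past gradients, lies between 0 and 1; its default is 0.9, but GANs are commonly trained with a lower value such as 0.5, as here.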

optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
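
To show what beta_1 actually controls, here is a NumPy sketch of one textbook Adam update. beta_1 weights the running average of gradients m, beta_2 the running average of squared gradients v, and both are bias-corrected because they start at zero. beta_2=0.999 and eps=1e-7 are assumed to match Keras’ defaults, and Keras’ internal implementation may differ in small details, such as exactly where epsilon is applied.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=0.0002, beta_1=0.5, beta_2=0.999, eps=1e-7):
    """One textbook Adam update; returns the new parameters and moment estimates."""
    m = beta_1 * m + (1 - beta_1) * grad        # running mean of gradients
    v = beta_2 * v + (1 - beta_2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1 - beta_1 ** t)               # bias corrections for the zero start
    v_hat = v / (1 - beta_2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

theta = np.array([1.0, -1.0])
m = v = np.zeros_like(theta)
theta, m, v = adam_step(theta, np.array([0.3, -4.0]), m, v, t=1)
print(theta)  # each weight moved by about lr, against the sign of its gradient
```

On the first step the bias-corrected ratio is roughly the sign of the gradient, so every weight moves by about the learning rate regardless of the gradient’s magnitude; a lower beta_1 makes later steps follow recent gradients more closely.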

We’ll create three hidden layers for each network and use LeakyReLU as the activation function. If you get the error below, make sure you’re using a recent version of TensorFlow.

AttributeError: module ‘tensorflow.python.ops.nn’ has no attribute ‘leaky_relu’
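
LeakyReLU itself is simple: positive inputs pass through unchanged, and negative inputs are multiplied by a small slope instead of being set to zero, so gradients keep flowing for negative activations. That matters in a GAN, where the generator learns only from gradients passed back through the discriminator. A NumPy sketch, assuming a slope of 0.2 (a common choice for GANs):

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """Pass positive values through; scale negative values by alpha instead of zeroing them."""
    return np.where(x > 0, x, alpha * x)

print(leaky_relu(np.array([-10.0, -1.0, 0.0, 2.5])))
```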

The generator takes a noise vector of length random_dim (100) as input and outputs 784 values, one per pixel of a 28 × 28 image. The discriminator therefore takes 784 inputs: each row of our 60,000 × 784 training set is one flattened image. We set the kernel initializer to RandomNormal, which draws the initial weights from a normal distribution; the stddev parameter, a Python scalar, is the standard deviation of those random values. We also add dropout layers to the discriminator to fight overfitting.

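generator = Sequential()
generator.add(Input(shape=(random_dim,)))  # input: a noise vector of length random_dim
for units in (256, 512, 1024):  # three hidden layers
  generator.add(Dense(units, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
  generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))  # tanh keeps outputs in [-1, 1], like the training data
generator.compile(loss='binary_crossentropy', optimizer=optimizer)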

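We set up the discriminator network in the same way, with layers shrinking toward a single sigmoid output: the probability that the input image is real.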

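discriminator = Sequential()
discriminator.add(Input(shape=(784,)))  # input: a flattened 28x28 image
for units in (1024, 512, 256):  # three hidden layers
  discriminator.add(Dense(units, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
  discriminator.add(LeakyReLU(0.2))
  discriminator.add(Dropout(0.3))  # dropout to fight overfitting
discriminator.add(Dense(1, activation='sigmoid'))  # probability that the image is real
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)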
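
For reference, this is roughly what a dropout layer does during training, sketched in NumPy: each unit is zeroed with probability `rate`, and the surviving units are scaled up by 1 / (1 - rate) so that the expected activation is unchanged ("inverted" dropout). Keras’ Dropout layer behaves this way during training and passes inputs through unchanged at inference. The rate of 0.3 and the all-ones input are arbitrary choices for the demonstration.

```python
import numpy as np

def dropout(x, rate, rng, training=True):
    """Inverted dropout: zero a fraction `rate` of units and rescale the rest.

    Rescaling by 1 / (1 - rate) keeps the expected activation unchanged, so
    nothing needs to change at inference time, when dropout is switched off.
    """
    if not training or rate == 0:
        return x
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1 - rate), 0.0)

rng = np.random.default_rng(0)
x = np.ones((1000, 256))
y = dropout(x, rate=0.3, rng=rng)
print(y.mean())  # close to 1.0 on average
```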

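Now that we have the discriminator and the generator, we need to combine them into a single model that is used to train the generator. Since we want to train only one network at a time, we set the discriminator’s trainable attribute to False, so that training the combined model updates only the generator’s weights. We then declare ganInput, an input tensor with the shape of the noise vector (random_dim).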

discriminator.trainable = False
ganInput = Input(shape=(random_dim,))

Next we pass ganInput through the generator, whose output is an image, and feed that image to the discriminator, whose output is the probability that the image is real. Since this is a binary classification, we use the binary_crossentropy loss function.

x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
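# A fresh Adam instance: recent Keras versions don't let one optimizer update two different sets of weights
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))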

The next step is to write a function that plots a wall of generated images with Matplotlib. Since we flattened the training images into 784-element vectors, the generator outputs flat vectors too, so we reshape each generated image back to 28 × 28 before plotting.

def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
  noise = np.random.normal(0, 1, size=[examples, random_dim])
  generated_images = generator.predict(noise, verbose=0)
  # Reshape the flat 784-pixel outputs back into 28x28 images
  generated_images = generated_images.reshape(examples, 28, 28)
  plt.figure(figsize=figsize)
  for i in range(generated_images.shape[0]):
    plt.subplot(dim[0], dim[1], i+1)
    plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
    plt.axis('off')
  plt.tight_layout()
  plt.savefig('image_generated_%d.png' % epoch)
  plt.close()  # free the figure so figures don't pile up across epochs
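
An alternative to one subplot per digit, sketched here as an option rather than the approach used above, is to tile all the images into a single array with NumPy and display it with one plt.imshow(wall, cmap='gray_r') call. The uniform random values below are stand-ins for generator output in the [-1, 1] tanh range.

```python
import numpy as np

def tile_images(images, rows, cols):
    """Arrange (rows*cols, h, w) images into one (rows*h, cols*w) array, row by row."""
    n, h, w = images.shape
    assert n == rows * cols, "need exactly rows * cols images"
    grid = images.reshape(rows, cols, h, w).transpose(0, 2, 1, 3)
    return grid.reshape(rows * h, cols * w)

# Stand-in for generator output: 100 flat 784-pixel images in [-1, 1].
flat_outputs = np.random.default_rng(0).uniform(-1, 1, size=(100, 784))
wall = tile_images(flat_outputs.reshape(100, 28, 28), rows=10, cols=10)
print(wall.shape)  # (280, 280)
```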

Next we define a function that trains the networks and generates the images. This function will:

  • Load the training and test data from Keras
  • Work through the data in batches of 128 images
  • Use tqdm to show the training progress
  • Sample a random batch of real images
  • Generate a batch of fake MNIST images from random noise
  • Train the discriminator on both
  • Train the generator through the combined model
  • Plot the generated images at the end of each epoch

def train(epochs=1, batch_size=128):
  x_train, y_train, x_test, y_test = load_minst_data()
  batch_count = x_train.shape[0] // batch_size  # number of full batches per epoch
  for e in range(1, epochs+1):
    print('-'*10, 'Epoch %d' % e, '-'*10)
    for _ in tqdm(range(batch_count)):
      # Random noise as generator input, and a random batch of real images
      noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
      # Generate fake images and stack them after the real ones
      generated_images = generator.predict(noise, verbose=0)
      X = np.concatenate([image_batch, generated_images])
      # Labels: 0.9 for real images (one-sided label smoothing), 0 for fake ones
      y_dis = np.zeros(2*batch_size)
      y_dis[:batch_size] = 0.9
      # Train the discriminator
      discriminator.trainable = True
      discriminator.train_on_batch(X, y_dis)
      # Train the generator: label fresh noise as "real" so it learns to fool the discriminator
      noise = np.random.normal(0, 1, size=[batch_size, random_dim])
      y_gen = np.ones(batch_size)
      discriminator.trainable = False
      gan.train_on_batch(noise, y_gen)
    plot_generated_images(e, generator)
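
The bookkeeping inside that loop can be checked on its own. The sketch below reproduces the batch count for MNIST’s 60,000 training images and the label vectors used for one training step: real images get the "smoothed" label 0.9 rather than 1.0 (one-sided label smoothing, which keeps the discriminator from becoming overconfident), generated images get 0, and the generator step uses 1 for every sample.

```python
import numpy as np

batch_size = 128
num_train = 60000                      # size of the MNIST training set
batch_count = num_train // batch_size  # full batches per epoch
print(batch_count)                     # 468

# Discriminator targets for one step: real images first, then generated ones.
y_dis = np.zeros(2 * batch_size)
y_dis[:batch_size] = 0.9   # one-sided label smoothing for the real half
# Generator targets: the combined model is trained to make D output "real".
y_gen = np.ones(batch_size)
print(y_dis[:3], y_dis[-3:], y_gen[:3])
```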

Next we call the function, passing the number of epochs and the batch size. The higher the number of epochs, the longer training takes.

train(1, 128)

Here’s a sample output obtained after running one epoch.

And here’s a sample output after 25 epochs. The output after more epochs is much clearer, but each additional epoch adds training time and processing power.


Generative Adversarial Networks are a relatively recent development in the data science field. You can learn more about them by reading the research papers mentioned above. You can also read up on other neural networks, such as convolutional neural networks and artificial neural networks.
