Train an Auxiliary Classifier GAN (ACGAN) on the MNIST dataset.

More details on Auxiliary Classifier GANs.

You should start to see reasonable images after ~5 epochs, and good images by ~15 epochs. You should use a GPU, as the convolution-heavy operations are very slow on the CPU. Prefer the TensorFlow backend if you plan on iterating, as the compilation time can be a blocker using Theano.

Timings:

| Hardware          | Backend | Time / Epoch |
|-------------------|---------|--------------|
| CPU               | TF      | 3 hrs        |
| Titan X (maxwell) | TF      | 4 min        |
| Titan X (maxwell) | TH      | 7 min        |

Consult Auxiliary Classifier Generative Adversarial Networks in Keras for more information and example output.

  1. from __future__ import print_function
  2. from collections import defaultdict
  3. try:
  4. import cPickle as pickle
  5. except ImportError:
  6. import pickle
  7. from PIL import Image
  8. from six.moves import range
  9. from keras.datasets import mnist
  10. from keras import layers
  11. from keras.layers import Input, Dense, Reshape, Flatten, Embedding, Dropout
  12. from keras.layers import BatchNormalization
  13. from keras.layers.advanced_activations import LeakyReLU
  14. from keras.layers.convolutional import Conv2DTranspose, Conv2D
  15. from keras.models import Sequential, Model
  16. from keras.optimizers import Adam
  17. from keras.utils.generic_utils import Progbar
  18. import numpy as np
  # Seed NumPy's global RNG so the noise vectors and sampled labels drawn
  # below are repeatable from run to run.
  np.random.seed(1337)

  # Number of label classes: the size of the generator's class embedding and
  # of the discriminator's auxiliary softmax output.
  num_classes = 10
  def build_generator(latent_size):
      """Build the ACGAN generator network.

      The generator maps a pair ``(z, L)``, where ``z`` is a latent vector and
      ``L`` is a class label drawn from P_c, to image space
      ``(..., 28, 28, 1)``.

      Args:
          latent_size: Length of the latent vector ``z``. The class embedding
              uses the same width so the two can be multiplied element-wise.

      Returns:
          A Keras ``Model`` that takes ``[latent, image_class]`` and returns
          the generated image (the last layer uses a ``tanh`` activation).
      """
      cnn = Sequential()

      cnn.add(Dense(3 * 3 * 384, input_dim=latent_size, activation='relu'))
      cnn.add(Reshape((3, 3, 384)))

      # upsample to (7, 7, ...)
      cnn.add(Conv2DTranspose(192, 5, strides=1, padding='valid',
                              activation='relu',
                              kernel_initializer='glorot_normal'))
      cnn.add(BatchNormalization())

      # upsample to (14, 14, ...)
      cnn.add(Conv2DTranspose(96, 5, strides=2, padding='same',
                              activation='relu',
                              kernel_initializer='glorot_normal'))
      cnn.add(BatchNormalization())

      # upsample to (28, 28, ...)
      cnn.add(Conv2DTranspose(1, 5, strides=2, padding='same',
                              activation='tanh',
                              kernel_initializer='glorot_normal'))

      # this is the z space commonly referred to in GAN papers
      latent = Input(shape=(latent_size, ))

      # this will be our label
      image_class = Input(shape=(1,), dtype='int32')

      # Embedding a length-1 label sequence yields a rank-3 tensor
      # (batch, 1, latent_size). Flatten it to (batch, latent_size) so it has
      # the same rank as `latent` and the product below is a plain
      # element-wise multiply instead of relying on merge-layer broadcasting.
      cls = Flatten()(Embedding(num_classes, latent_size,
                                embeddings_initializer='glorot_normal')(image_class))

      # hadamard product between z-space and a class conditional embedding
      h = layers.multiply([latent, cls])

      fake_image = cnn(h)

      return Model([latent, image_class], fake_image)
  def build_discriminator():
      """Build the ACGAN discriminator network.

      A relatively standard conv net over ``(28, 28, 1)`` images, with
      LeakyReLUs as suggested in the reference paper and dropout after every
      convolution.

      Returns:
          A Keras ``Model`` that takes an image and returns two outputs:
          ``generation`` (a single sigmoid unit: whether the discriminator
          thinks the image is fake) and ``auxiliary`` (a softmax over
          ``num_classes``: the class the discriminator thinks the image
          belongs to).
      """
      cnn = Sequential()

      cnn.add(Conv2D(32, 3, padding='same', strides=2,
                     input_shape=(28, 28, 1)))
      cnn.add(LeakyReLU(0.2))
      cnn.add(Dropout(0.3))

      cnn.add(Conv2D(64, 3, padding='same', strides=1))
      cnn.add(LeakyReLU(0.2))
      cnn.add(Dropout(0.3))

      cnn.add(Conv2D(128, 3, padding='same', strides=2))
      cnn.add(LeakyReLU(0.2))
      cnn.add(Dropout(0.3))

      cnn.add(Conv2D(256, 3, padding='same', strides=1))
      cnn.add(LeakyReLU(0.2))
      cnn.add(Dropout(0.3))

      cnn.add(Flatten())

      image = Input(shape=(28, 28, 1))

      features = cnn(image)

      # first output (name=generation) is whether or not the discriminator
      # thinks the image that is being shown is fake, and the second output
      # (name=auxiliary) is the class that the discriminator thinks the image
      # belongs to.
      fake = Dense(1, activation='sigmoid', name='generation')(features)
      aux = Dense(num_classes, activation='softmax', name='auxiliary')(features)

      return Model(image, [fake, aux])
  if __name__ == '__main__':
      # Script entry point: build the discriminator, the generator and the
      # combined (generator -> frozen discriminator) model, then alternate
      # discriminator / generator updates over MNIST. After every epoch the
      # script evaluates on the test split, prints a loss table, saves both
      # sets of weights and writes a PNG grid of generated vs. real digits.
      # The loss history is pickled once all epochs have finished.

      # batch and latent size taken from the paper
      epochs = 100
      batch_size = 100
      latent_size = 100

      # Adam parameters suggested in https://arxiv.org/abs/1511.06434
      adam_lr = 0.0002
      adam_beta_1 = 0.5

      # build the discriminator
      print('Discriminator model:')
      discriminator = build_discriminator()
      discriminator.compile(
          optimizer=Adam(learning_rate=adam_lr, beta_1=adam_beta_1),
          loss=['binary_crossentropy', 'sparse_categorical_crossentropy']
      )
      discriminator.summary()

      # build the generator
      generator = build_generator(latent_size)

      latent = Input(shape=(latent_size, ))
      image_class = Input(shape=(1,), dtype='int32')

      # get a fake image
      fake = generator([latent, image_class])

      # we only want to be able to train generation for the combined model
      discriminator.trainable = False
      fake, aux = discriminator(fake)
      combined = Model([latent, image_class], [fake, aux])

      print('Combined model:')
      combined.compile(
          optimizer=Adam(learning_rate=adam_lr, beta_1=adam_beta_1),
          loss=['binary_crossentropy', 'sparse_categorical_crossentropy']
      )
      combined.summary()

      # get our mnist data, and force it to be of shape (..., 28, 28, 1) with
      # range [-1, 1]
      (x_train, y_train), (x_test, y_test) = mnist.load_data()
      x_train = (x_train.astype(np.float32) - 127.5) / 127.5
      x_train = np.expand_dims(x_train, axis=-1)

      x_test = (x_test.astype(np.float32) - 127.5) / 127.5
      x_test = np.expand_dims(x_test, axis=-1)

      num_train, num_test = x_train.shape[0], x_test.shape[0]

      # The number of batches does not change between epochs, so compute it
      # once; the last batch may be smaller than batch_size.
      num_batches = int(np.ceil(num_train / float(batch_size)))

      # use one-sided soft real/fake labels
      # Salimans et al., 2016
      # https://arxiv.org/pdf/1606.03498.pdf (Section 3.4)
      soft_zero, soft_one = 0, 0.95

      # row layout of the per-epoch loss table
      ROW_FMT = '{0:<22s} | {1:<4.2f} | {2:<15.4f} | {3:<5.4f}'

      train_history = defaultdict(list)
      test_history = defaultdict(list)

      for epoch in range(1, epochs + 1):
          print('Epoch {}/{}'.format(epoch, epochs))

          progress_bar = Progbar(target=num_batches)

          epoch_gen_loss = []
          epoch_disc_loss = []

          for index in range(num_batches):
              # get a batch of real images
              image_batch = x_train[index * batch_size:(index + 1) * batch_size]
              label_batch = y_train[index * batch_size:(index + 1) * batch_size]
              batch_len = len(image_batch)

              # generate a new batch of noise
              noise = np.random.uniform(-1, 1, (batch_len, latent_size))

              # sample some labels from p_c
              sampled_labels = np.random.randint(0, num_classes, batch_len)

              # generate a batch of fake images, using the generated labels as a
              # conditioner. We reshape the sampled labels to be
              # (len(image_batch), 1) so that we can feed them into the embedding
              # layer as a length one sequence
              generated_images = generator.predict(
                  [noise, sampled_labels.reshape((-1, 1))], verbose=0)

              x = np.concatenate((image_batch, generated_images))

              # real images first (soft "one"), generated images second
              y = np.array(
                  [soft_one] * batch_len + [soft_zero] * batch_len)
              aux_y = np.concatenate((label_batch, sampled_labels), axis=0)

              # we don't want the discriminator to also maximize the classification
              # accuracy of the auxiliary classifier on generated images, so we
              # don't train discriminator to produce class labels for generated
              # images (see https://openreview.net/forum?id=rJXTf9Bxg).
              # To preserve sum of sample weights for the auxiliary classifier,
              # we assign sample weight of 2 to the real images.
              disc_sample_weight = [np.ones(2 * batch_len),
                                    np.concatenate((np.ones(batch_len) * 2,
                                                    np.zeros(batch_len)))]

              # see if the discriminator can figure itself out...
              epoch_disc_loss.append(discriminator.train_on_batch(
                  x, [y, aux_y], sample_weight=disc_sample_weight))

              # make new noise. we generate 2 * batch size here such that we have
              # the generator optimize over an identical number of images as the
              # discriminator
              noise = np.random.uniform(-1, 1, (2 * batch_len, latent_size))
              sampled_labels = np.random.randint(0, num_classes, 2 * batch_len)

              # we want to train the generator to trick the discriminator
              # For the generator, we want all the {fake, not-fake} labels to say
              # not-fake
              trick = np.ones(2 * batch_len) * soft_one

              epoch_gen_loss.append(combined.train_on_batch(
                  [noise, sampled_labels.reshape((-1, 1))],
                  [trick, sampled_labels]))

              progress_bar.update(index + 1)

          print('Testing for epoch {}:'.format(epoch))

          # evaluate the testing loss here

          # generate a new batch of noise
          noise = np.random.uniform(-1, 1, (num_test, latent_size))

          # sample some labels from p_c and generate images from them
          sampled_labels = np.random.randint(0, num_classes, num_test)
          generated_images = generator.predict(
              [noise, sampled_labels.reshape((-1, 1))], verbose=0)

          x = np.concatenate((x_test, generated_images))
          y = np.array([1] * num_test + [0] * num_test)
          aux_y = np.concatenate((y_test, sampled_labels), axis=0)

          # see if the discriminator can figure itself out...
          discriminator_test_loss = discriminator.evaluate(
              x, [y, aux_y], verbose=0)

          discriminator_train_loss = np.mean(np.array(epoch_disc_loss), axis=0)

          # make new noise
          noise = np.random.uniform(-1, 1, (2 * num_test, latent_size))
          sampled_labels = np.random.randint(0, num_classes, 2 * num_test)

          trick = np.ones(2 * num_test)

          generator_test_loss = combined.evaluate(
              [noise, sampled_labels.reshape((-1, 1))],
              [trick, sampled_labels], verbose=0)

          generator_train_loss = np.mean(np.array(epoch_gen_loss), axis=0)

          # generate an epoch report on performance
          train_history['generator'].append(generator_train_loss)
          train_history['discriminator'].append(discriminator_train_loss)

          test_history['generator'].append(generator_test_loss)
          test_history['discriminator'].append(discriminator_test_loss)

          print('{0:<22s} | {1:4s} | {2:15s} | {3:5s}'.format(
              'component', *discriminator.metrics_names))
          print('-' * 65)

          print(ROW_FMT.format('generator (train)',
                               *train_history['generator'][-1]))
          print(ROW_FMT.format('generator (test)',
                               *test_history['generator'][-1]))
          print(ROW_FMT.format('discriminator (train)',
                               *train_history['discriminator'][-1]))
          print(ROW_FMT.format('discriminator (test)',
                               *test_history['discriminator'][-1]))

          # save weights every epoch (overwriting any file left by an earlier run)
          generator.save_weights(
              'params_generator_epoch_{0:03d}.hdf5'.format(epoch),
              overwrite=True)
          discriminator.save_weights(
              'params_discriminator_epoch_{0:03d}.hdf5'.format(epoch),
              overwrite=True)

          # generate some digits to display
          num_rows = 40
          # the same num_rows noise vectors are reused for every class, so each
          # row of the grid shows one latent vector rendered as each digit
          noise = np.tile(np.random.uniform(-1, 1, (num_rows, latent_size)),
                          (num_classes, 1))

          sampled_labels = np.array([
              [i] * num_rows for i in range(num_classes)
          ]).reshape(-1, 1)

          # get a batch to display
          generated_images = generator.predict(
              [noise, sampled_labels], verbose=0)

          # prepare real images sorted by class label; each epoch takes the
          # next num_rows * num_classes training samples
          real_labels = y_train[(epoch - 1) * num_rows * num_classes:
                                epoch * num_rows * num_classes]
          indices = np.argsort(real_labels, axis=0)
          real_images = x_train[(epoch - 1) * num_rows * num_classes:
                                epoch * num_rows * num_classes][indices]

          # display generated images, white separator, real images
          img = np.concatenate(
              (generated_images,
               np.repeat(np.ones_like(x_train[:1]), num_rows, axis=0),
               real_images))

          # arrange them into a grid: 2 * num_classes + 1 column strips of
          # num_rows images each, mapped back from [-1, 1] to [0, 255]
          img = (np.concatenate([r.reshape(-1, 28)
                                 for r in np.split(img, 2 * num_classes + 1)
                                 ], axis=-1) * 127.5 + 127.5).astype(np.uint8)

          Image.fromarray(img).save(
              'plot_epoch_{0:03d}_generated.png'.format(epoch))

      # persist the per-epoch loss history once training has finished
      with open('acgan-history.pkl', 'wb') as f:
          pickle.dump({'train': train_history, 'test': test_history}, f)