Train a simple deep CNN on the CIFAR-10 small-images dataset.

It reaches 75% validation accuracy after 25 epochs and 79% after 50 epochs (it is still underfitting at that point, though).

  # Keeps print() a function under Python 2; must stay the first statement.
  from __future__ import print_function

  import os

  import keras
  from keras.datasets import cifar10
  from keras.preprocessing.image import ImageDataGenerator
  from keras.models import Sequential
  from keras.layers import Dense, Dropout, Activation, Flatten
  from keras.layers import Conv2D, MaxPooling2D
  # Training hyper-parameters.
  batch_size = 32
  num_classes = 10  # CIFAR-10 has ten classes; also the softmax width
  epochs = 100
  # When True, train on randomly shifted / horizontally flipped images
  # produced on the fly instead of the raw arrays.
  data_augmentation = True
  # NOTE(review): not referenced anywhere else in this script.
  num_predictions = 20
  # The trained model is written to <cwd>/saved_models/<model_name>.
  save_dir = os.path.join(os.getcwd(), 'saved_models')
  model_name = 'keras_cifar10_trained_model.h5'
  # Load CIFAR-10, already split into training and test sets.
  (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  print('x_train shape:', x_train.shape)
  print(x_train.shape[0], 'train samples')
  print(x_test.shape[0], 'test samples')

  # One-hot encode the integer class labels so they line up with the
  # num_classes-way softmax output and the categorical_crossentropy loss.
  y_train = keras.utils.to_categorical(y_train, num_classes)
  y_test = keras.utils.to_categorical(y_test, num_classes)
  # Two conv -> conv -> max-pool stages (32, then 64 filters), each followed
  # by dropout, then a 512-unit dense layer and a num_classes-way softmax.
  # The first conv of each stage uses 'same' padding; the second uses the
  # Conv2D default padding.
  model = Sequential()
  model.add(Conv2D(32, (3, 3), padding='same',
                   input_shape=x_train.shape[1:]))
  model.add(Activation('relu'))
  model.add(Conv2D(32, (3, 3)))
  model.add(Activation('relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Dropout(0.25))

  model.add(Conv2D(64, (3, 3), padding='same'))
  model.add(Activation('relu'))
  model.add(Conv2D(64, (3, 3)))
  model.add(Activation('relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Dropout(0.25))

  model.add(Flatten())
  model.add(Dense(512))
  model.add(Activation('relu'))
  model.add(Dropout(0.5))
  model.add(Dense(num_classes))
  model.add(Activation('softmax'))

  # RMSprop with inverse-time learning-rate decay:
  #     lr(step) = 1e-4 / (1 + 1e-6 * step)
  # This is exactly the schedule the legacy `decay=1e-6` argument applied.
  # The optimizers shipped since Keras/TF 2.11 reject `decay` with a
  # ValueError, so the decay is expressed as an explicit schedule instead.
  lr_schedule = keras.optimizers.schedules.InverseTimeDecay(
      initial_learning_rate=0.0001, decay_steps=1, decay_rate=1e-6)
  opt = keras.optimizers.RMSprop(learning_rate=lr_schedule)

  # Categorical cross-entropy pairs with the one-hot labels and softmax output.
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  # Convert the pixel arrays to float32 and scale them into [0, 1].
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  # Divide in place so no second full-size float array is allocated.
  x_train /= 255
  x_test /= 255
  # NOTE(review): the test split is passed as validation_data, so the
  # per-epoch "validation" metrics are really test-set metrics. Carve out a
  # separate validation split if you tune hyper-parameters on them.
  if not data_augmentation:
      print('Not using data augmentation.')
      model.fit(x_train, y_train,
                batch_size=batch_size,
                epochs=epochs,
                validation_data=(x_test, y_test),
                shuffle=True)
  else:
      print('Using real-time data augmentation.')
      # Real-time augmentation only: random shifts of up to 10% of the
      # width/height plus random horizontal flips. Every normalization and
      # whitening option is left disabled.
      datagen = ImageDataGenerator(
          featurewise_center=False,  # set input mean to 0 over the dataset
          samplewise_center=False,  # set each sample mean to 0
          featurewise_std_normalization=False,  # divide inputs by std of the dataset
          samplewise_std_normalization=False,  # divide each input by its std
          zca_whitening=False,  # apply ZCA whitening
          zca_epsilon=1e-06,  # epsilon for ZCA whitening
          rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
          # randomly shift images horizontally (fraction of total width)
          width_shift_range=0.1,
          # randomly shift images vertically (fraction of total height)
          height_shift_range=0.1,
          shear_range=0.,  # set range for random shear
          zoom_range=0.,  # set range for random zoom
          channel_shift_range=0.,  # set range for random channel shifts
          # set mode for filling points outside the input boundaries
          fill_mode='nearest',
          cval=0.,  # value used for fill_mode = "constant"
          horizontal_flip=True,  # randomly flip images left-right
          vertical_flip=False,  # do not flip images upside-down
          # rescaling factor; None disables it (pixels were already scaled
          # by 1/255 before training)
          rescale=None,
          # set function that will be applied on each input
          preprocessing_function=None,
          # image data format, either "channels_first" or "channels_last"
          data_format=None,
          # fraction of images reserved for validation (strictly between 0 and 1)
          validation_split=0.0)
      # Compute the statistics needed for feature-wise normalization (mean,
      # std, and principal components if ZCA whitening is applied). With
      # those options disabled above there is nothing to compute, but keeping
      # the call means switching one of them on later just works.
      datagen.fit(x_train)
      # Model.fit consumes the augmenting iterator directly. fit_generator is
      # deprecated in favour of it and no longer exists in Keras 3.
      model.fit(datagen.flow(x_train, y_train, batch_size=batch_size),
                epochs=epochs,
                validation_data=(x_test, y_test),
                workers=4)
  # Save the trained model (architecture and weights) to save_dir/model_name,
  # creating the output directory on the first run.
  if not os.path.isdir(save_dir):
      os.makedirs(save_dir)
  model_path = os.path.join(save_dir, model_name)
  model.save(model_path)
  print('Saved trained model at %s ' % model_path)
  # Score the trained model on the test split (the same data that was used
  # for per-epoch validation). Because the model was compiled with
  # metrics=['accuracy'], evaluate() returns [loss, accuracy].
  scores = model.evaluate(x_test, y_test, verbose=1)
  print('Test loss:', scores[0])
  print('Test accuracy:', scores[1])