在 CIFAR10 小型图像数据集上训练一个简单的 CNN-Capsule Network。

无数据增益的情况下:在 10 轮迭代后验证集准确率达到 75%,在 15 轮后达到 79%,在 20 轮后过拟合。

有数据增益情况下:在 10 轮迭代后验证集准确率达到 75%,在 15 轮后达到 79%,在 30 轮后达到 83%。在我的测试中,50 轮后最高的验证集准确率为 83.79%。

这是一个快速版的实现,在 GTX 1070 GPU 上迭代只需 20s/epoch。

  1. from __future__ import print_function
  2. from keras import backend as K
  3. from keras.layers import Layer
  4. from keras import activations
  5. from keras import utils
  6. from keras.datasets import cifar10
  7. from keras.models import Model
  8. from keras.layers import *
  9. from keras.preprocessing.image import ImageDataGenerator
  10. # 挤压函数
  11. # 我们在此使用 0.5,而不是 Hinton 论文中给出的 1
  12. # 如果为 1,则向量的范数将被缩小。
  13. # 如果为 0.5,则当原始范数小于 0.5 时,范数将被放大,
  14. # 当原始范数大于 0.5 时,范数将被缩小。
  15. def squash(x, axis=-1):
  16. s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
  17. scale = K.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
  18. return scale * x
  19. # 使用自定义的 softmax 函数,而非 K.softmax,
  20. # 因为 K.softmax 不能指定轴。
  21. def softmax(x, axis=-1):
  22. ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
  23. return ex / K.sum(ex, axis=axis, keepdims=True)
  24. # 定义 margin loss,类似于 hinge loss
  25. def margin_loss(y_true, y_pred):
  26. lamb, margin = 0.5, 0.1
  27. return K.sum(y_true * K.square(K.relu(1 - margin - y_pred)) + lamb * (
  28. 1 - y_true) * K.square(K.relu(y_pred - margin)), axis=-1)
  29. class Capsule(Layer):
  30. """一个由纯 Keras 实现的 Capsule 网络。
  31. 总共有两个版本的 Capsule。
  32. 一个类似于全连接层 (用于固定尺寸的输入),
  33. 另一个类似于时序分布全连接层 (用于变成输入)。
  34. Capsure 的输入尺寸必须为 (batch_size,
  35. input_num_capsule,
  36. input_dim_capsule
  37. )
  38. 以及输出尺寸必须为 (batch_size,
  39. num_capsule,
  40. dim_capsule
  41. )
  42. Capsule 实现来自于 https://github.com/bojone/Capsule/
  43. Capsule 论文: https://arxiv.org/abs/1710.09829
  44. """
  45. def __init__(self,
  46. num_capsule,
  47. dim_capsule,
  48. routings=3,
  49. share_weights=True,
  50. activation='squash',
  51. **kwargs):
  52. super(Capsule, self).__init__(**kwargs)
  53. self.num_capsule = num_capsule
  54. self.dim_capsule = dim_capsule
  55. self.routings = routings
  56. self.share_weights = share_weights
  57. if activation == 'squash':
  58. self.activation = squash
  59. else:
  60. self.activation = activations.get(activation)
  61. def build(self, input_shape):
  62. input_dim_capsule = input_shape[-1]
  63. if self.share_weights:
  64. self.kernel = self.add_weight(
  65. name='capsule_kernel',
  66. shape=(1, input_dim_capsule,
  67. self.num_capsule * self.dim_capsule),
  68. initializer='glorot_uniform',
  69. trainable=True)
  70. else:
  71. input_num_capsule = input_shape[-2]
  72. self.kernel = self.add_weight(
  73. name='capsule_kernel',
  74. shape=(input_num_capsule, input_dim_capsule,
  75. self.num_capsule * self.dim_capsule),
  76. initializer='glorot_uniform',
  77. trainable=True)
  78. def call(self, inputs):
  79. """遵循 Hinton 论文中的路由算法,
  80. 但是将 b = b + <u,v> 替换为 b = <u,v>。
  81. 这一改变将提升 Capsule 的特征表示能力。
  82. 然而,你仍可以将
  83. b = K.batch_dot(outputs, hat_inputs, [2, 3])
  84. 替换为
  85. b += K.batch_dot(outputs, hat_inputs, [2, 3])
  86. 来实现一个标准的路由。
  87. """
  88. if self.share_weights:
  89. hat_inputs = K.conv1d(inputs, self.kernel)
  90. else:
  91. hat_inputs = K.local_conv1d(inputs, self.kernel, [1], [1])
  92. batch_size = K.shape(inputs)[0]
  93. input_num_capsule = K.shape(inputs)[1]
  94. hat_inputs = K.reshape(hat_inputs,
  95. (batch_size, input_num_capsule,
  96. self.num_capsule, self.dim_capsule))
  97. hat_inputs = K.permute_dimensions(hat_inputs, (0, 2, 1, 3))
  98. b = K.zeros_like(hat_inputs[:, :, :, 0])
  99. for i in range(self.routings):
  100. c = softmax(b, 1)
  101. o = self.activation(K.batch_dot(c, hat_inputs, [2, 2]))
  102. if i < self.routings - 1:
  103. b = K.batch_dot(o, hat_inputs, [2, 3])
  104. if K.backend() == 'theano':
  105. o = K.sum(o, axis=1)
  106. return o
  107. def compute_output_shape(self, input_shape):
  108. return (None, self.num_capsule, self.dim_capsule)
  109. batch_size = 128
  110. num_classes = 10
  111. epochs = 100
  112. (x_train, y_train), (x_test, y_test) = cifar10.load_data()
  113. x_train = x_train.astype('float32')
  114. x_test = x_test.astype('float32')
  115. x_train /= 255
  116. x_test /= 255
  117. y_train = utils.to_categorical(y_train, num_classes)
  118. y_test = utils.to_categorical(y_test, num_classes)
  119. # 一个常规的 Conv2D 模型
  120. input_image = Input(shape=(None, None, 3))
  121. x = Conv2D(64, (3, 3), activation='relu')(input_image)
  122. x = Conv2D(64, (3, 3), activation='relu')(x)
  123. x = AveragePooling2D((2, 2))(x)
  124. x = Conv2D(128, (3, 3), activation='relu')(x)
  125. x = Conv2D(128, (3, 3), activation='relu')(x)
  126. """现在我们将其尺寸重新调整为 (batch_size, input_num_capsule, input_dim_capsule),再连接一个 Capsule 网络。
  127. 最终模型的输出为长度为 10 的 Capsure,其 dim=16。
  128. Capsule 的长度表示为 proba,
  129. 因此问题变成了一个『10个二分类』的问题。
  130. """
  131. x = Reshape((-1, 128))(x)
  132. capsule = Capsule(10, 16, 3, True)(x)
  133. output = Lambda(lambda z: K.sqrt(K.sum(K.square(z), 2)))(capsule)
  134. model = Model(inputs=input_image, outputs=output)
  135. # 使用 margin loss
  136. model.compile(loss=margin_loss, optimizer='adam', metrics=['accuracy'])
  137. model.summary()
  138. # 可以比较有无数据增益对应的性能
  139. data_augmentation = True
  140. if not data_augmentation:
  141. print('Not using data augmentation.')
  142. model.fit(
  143. x_train,
  144. y_train,
  145. batch_size=batch_size,
  146. epochs=epochs,
  147. validation_data=(x_test, y_test),
  148. shuffle=True)
  149. else:
  150. print('Using real-time data augmentation.')
  151. # 这一步将进行数据处理和实时数据增益:
  152. datagen = ImageDataGenerator(
  153. featurewise_center=False, # 将整个数据集的均值设为 0
  154. samplewise_center=False, # 将每个样本的均值设为 0
  155. featurewise_std_normalization=False, # 将输入除以整个数据集的标准差
  156. samplewise_std_normalization=False, # 将输入除以其标准差
  157. zca_whitening=False, # 运用 ZCA 白化
  158. zca_epsilon=1e-06, # ZCA 白化的 epsilon值
  159. rotation_range=0, # 随机旋转图像范围 (角度, 0 to 180)
  160. width_shift_range=0.1, # 随机水平移动图像 (总宽度的百分比)
  161. height_shift_range=0.1, # 随机垂直移动图像 (总高度的百分比)
  162. shear_range=0., # 设置随机裁剪范围
  163. zoom_range=0., # 设置随机放大范围
  164. channel_shift_range=0., # 设置随机通道切换的范围
  165. # 设置填充输入边界之外的点的模式
  166. fill_mode='nearest',
  167. cval=0., # 在 fill_mode = "constant" 时使用的值
  168. horizontal_flip=True, # 随机水平翻转图像
  169. vertical_flip=False, # 随机垂直翻转图像
  170. # 设置缩放因子 (在其他转换之前使用)
  171. rescale=None,
  172. # 设置将应用于每一个输入的函数
  173. preprocessing_function=None,
  174. # 图像数据格式,"channels_first" 或 "channels_last" 之一
  175. data_format=None,
  176. # 保留用于验证的图像比例(严格在0和1之间)
  177. validation_split=0.0)
  178. # 计算特征标准化所需的计算量
  179. # (如果应用 ZCA 白化,则为 std,mean和主成分)。
  180. datagen.fit(x_train)
  181. # 利用由 datagen.flow() 生成的批来训练模型。
  182. model.fit_generator(
  183. datagen.flow(x_train, y_train, batch_size=batch_size),
  184. epochs=epochs,
  185. validation_data=(x_test, y_test),
  186. workers=4)