Visualization of the filters of VGG16, via gradient ascent in input space.

This script can run on CPU in a few minutes.

Results example: Visualization

  1. from __future__ import print_function
  2. import time
  3. import numpy as np
  4. from PIL import Image as pil_image
  5. from keras.preprocessing.image import save_img
  6. from keras import layers
  7. from keras.applications import vgg16
  8. from keras import backend as K
  9. def normalize(x):
  10. """utility function to normalize a tensor.
  11. # Arguments
  12. x: An input tensor.
  13. # Returns
  14. The normalized input tensor.
  15. """
  16. return x / (K.sqrt(K.mean(K.square(x))) + K.epsilon())
  17. def deprocess_image(x):
  18. """utility function to convert a float array into a valid uint8 image.
  19. # Arguments
  20. x: A numpy-array representing the generated image.
  21. # Returns
  22. A processed numpy-array, which could be used in e.g. imshow.
  23. """
  24. # normalize tensor: center on 0., ensure std is 0.25
  25. x -= x.mean()
  26. x /= (x.std() + K.epsilon())
  27. x *= 0.25
  28. # clip to [0, 1]
  29. x += 0.5
  30. x = np.clip(x, 0, 1)
  31. # convert to RGB array
  32. x *= 255
  33. if K.image_data_format() == 'channels_first':
  34. x = x.transpose((1, 2, 0))
  35. x = np.clip(x, 0, 255).astype('uint8')
  36. return x
  37. def process_image(x, former):
  38. """utility function to convert a valid uint8 image back into a float array.
  39. Reverses `deprocess_image`.
  40. # Arguments
  41. x: A numpy-array, which could be used in e.g. imshow.
  42. former: The former numpy-array.
  43. Need to determine the former mean and variance.
  44. # Returns
  45. A processed numpy-array representing the generated image.
  46. """
  47. if K.image_data_format() == 'channels_first':
  48. x = x.transpose((2, 0, 1))
  49. return (x / 255 - 0.5) * 4 * former.std() + former.mean()
  50. def visualize_layer(model,
  51. layer_name,
  52. step=1.,
  53. epochs=15,
  54. upscaling_steps=9,
  55. upscaling_factor=1.2,
  56. output_dim=(412, 412),
  57. filter_range=(0, None)):
  58. """Visualizes the most relevant filters of one conv-layer in a certain model.
  59. # Arguments
  60. model: The model containing layer_name.
  61. layer_name: The name of the layer to be visualized.
  62. Has to be a part of model.
  63. step: step size for gradient ascent.
  64. epochs: Number of iterations for gradient ascent.
  65. upscaling_steps: Number of upscaling steps.
  66. Starting image is in this case (80, 80).
  67. upscaling_factor: Factor to which to slowly upgrade
  68. the image towards output_dim.
  69. output_dim: [img_width, img_height] The output image dimensions.
  70. filter_range: Tupel[lower, upper]
  71. Determines the to be computed filter numbers.
  72. If the second value is `None`,
  73. the last filter will be inferred as the upper boundary.
  74. """
  75. def _generate_filter_image(input_img,
  76. layer_output,
  77. filter_index):
  78. """Generates image for one particular filter.
  79. # Arguments
  80. input_img: The input-image Tensor.
  81. layer_output: The output-image Tensor.
  82. filter_index: The to be processed filter number.
  83. Assumed to be valid.
  84. #Returns
  85. Either None if no image could be generated.
  86. or a tuple of the image (array) itself and the last loss.
  87. """
  88. s_time = time.time()
  89. # we build a loss function that maximizes the activation
  90. # of the nth filter of the layer considered
  91. if K.image_data_format() == 'channels_first':
  92. loss = K.mean(layer_output[:, filter_index, :, :])
  93. else:
  94. loss = K.mean(layer_output[:, :, :, filter_index])
  95. # we compute the gradient of the input picture wrt this loss
  96. grads = K.gradients(loss, input_img)[0]
  97. # normalization trick: we normalize the gradient
  98. grads = normalize(grads)
  99. # this function returns the loss and grads given the input picture
  100. iterate = K.function([input_img], [loss, grads])
  101. # we start from a gray image with some random noise
  102. intermediate_dim = tuple(
  103. int(x / (upscaling_factor ** upscaling_steps)) for x in output_dim)
  104. if K.image_data_format() == 'channels_first':
  105. input_img_data = np.random.random(
  106. (1, 3, intermediate_dim[0], intermediate_dim[1]))
  107. else:
  108. input_img_data = np.random.random(
  109. (1, intermediate_dim[0], intermediate_dim[1], 3))
  110. input_img_data = (input_img_data - 0.5) * 20 + 128
  111. # Slowly upscaling towards the original size prevents
  112. # a dominating high-frequency of the to visualized structure
  113. # as it would occur if we directly compute the 412d-image.
  114. # Behaves as a better starting point for each following dimension
  115. # and therefore avoids poor local minima
  116. for up in reversed(range(upscaling_steps)):
  117. # we run gradient ascent for e.g. 20 steps
  118. for _ in range(epochs):
  119. loss_value, grads_value = iterate([input_img_data])
  120. input_img_data += grads_value * step
  121. # some filters get stuck to 0, we can skip them
  122. if loss_value <= K.epsilon():
  123. return None
  124. # Calculate upscaled dimension
  125. intermediate_dim = tuple(
  126. int(x / (upscaling_factor ** up)) for x in output_dim)
  127. # Upscale
  128. img = deprocess_image(input_img_data[0])
  129. img = np.array(pil_image.fromarray(img).resize(intermediate_dim,
  130. pil_image.BICUBIC))
  131. input_img_data = np.expand_dims(
  132. process_image(img, input_img_data[0]), 0)
  133. # decode the resulting input image
  134. img = deprocess_image(input_img_data[0])
  135. e_time = time.time()
  136. print('Costs of filter {:3}: {:5.0f} ( {:4.2f}s )'.format(filter_index,
  137. loss_value,
  138. e_time - s_time))
  139. return img, loss_value
  140. def _draw_filters(filters, n=None):
  141. """Draw the best filters in a nxn grid.
  142. # Arguments
  143. filters: A List of generated images and their corresponding losses
  144. for each processed filter.
  145. n: dimension of the grid.
  146. If none, the largest possible square will be used
  147. """
  148. if n is None:
  149. n = int(np.floor(np.sqrt(len(filters))))
  150. # the filters that have the highest loss are assumed to be better-looking.
  151. # we will only keep the top n*n filters.
  152. filters.sort(key=lambda x: x[1], reverse=True)
  153. filters = filters[:n * n]
  154. # build a black picture with enough space for
  155. # e.g. our 8 x 8 filters of size 412 x 412, with a 5px margin in between
  156. MARGIN = 5
  157. width = n * output_dim[0] + (n - 1) * MARGIN
  158. height = n * output_dim[1] + (n - 1) * MARGIN
  159. stitched_filters = np.zeros((width, height, 3), dtype='uint8')
  160. # fill the picture with our saved filters
  161. for i in range(n):
  162. for j in range(n):
  163. img, _ = filters[i * n + j]
  164. width_margin = (output_dim[0] + MARGIN) * i
  165. height_margin = (output_dim[1] + MARGIN) * j
  166. stitched_filters[
  167. width_margin: width_margin + output_dim[0],
  168. height_margin: height_margin + output_dim[1], :] = img
  169. # save the result to disk
  170. save_img('vgg_{0:}_{1:}x{1:}.png'.format(layer_name, n), stitched_filters)
  171. # this is the placeholder for the input images
  172. assert len(model.inputs) == 1
  173. input_img = model.inputs[0]
  174. # get the symbolic outputs of each "key" layer (we gave them unique names).
  175. layer_dict = dict([(layer.name, layer) for layer in model.layers[1:]])
  176. output_layer = layer_dict[layer_name]
  177. assert isinstance(output_layer, layers.Conv2D)
  178. # Compute to be processed filter range
  179. filter_lower = filter_range[0]
  180. filter_upper = (filter_range[1]
  181. if filter_range[1] is not None
  182. else len(output_layer.get_weights()[1]))
  183. assert(filter_lower >= 0
  184. and filter_upper <= len(output_layer.get_weights()[1])
  185. and filter_upper > filter_lower)
  186. print('Compute filters {:} to {:}'.format(filter_lower, filter_upper))
  187. # iterate through each filter and generate its corresponding image
  188. processed_filters = []
  189. for f in range(filter_lower, filter_upper):
  190. img_loss = _generate_filter_image(input_img, output_layer.output, f)
  191. if img_loss is not None:
  192. processed_filters.append(img_loss)
  193. print('{} filter processed.'.format(len(processed_filters)))
  194. # Finally draw and store the best filters to disk
  195. _draw_filters(processed_filters)
  196. if __name__ == '__main__':
  197. # the name of the layer we want to visualize
  198. # (see model definition at keras/applications/vgg16.py)
  199. LAYER_NAME = 'block5_conv1'
  200. # build the VGG16 network with ImageNet weights
  201. vgg = vgg16.VGG16(weights='imagenet', include_top=False)
  202. print('Model loaded.')
  203. vgg.summary()
  204. # example function call
  205. visualize_layer(vgg, LAYER_NAME)