Writing your own Keras layers

For simple, stateless custom operations, you are probably better off using layers.core.Lambda layers. But for any custom operation that has trainable weights, you should implement your own layer.

Here is the skeleton of a Keras layer, as of Keras 2.0 (if you have an older version, please upgrade). There are only three methods you need to implement:

  • build(input_shape): this is where you will define your weights. This method must set self.built = True at the end, which can be done by calling super([Layer], self).build(input_shape), where [Layer] stands for the name of your own layer class.
  • call(x): this is where the layer's logic lives. Unless you want your layer to support masking, you only have to care about the first argument passed to call: the input tensor.
  • compute_output_shape(input_shape): in case your layer modifies the shape of its input, you should specify here the shape transformation logic. This allows Keras to do automatic shape inference.

```python
from keras import backend as K
from keras.layers import Layer


class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        # Project the (batch, input_dim) input onto output_dim units.
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        # The batch dimension is kept; the last dimension becomes output_dim.
        return (input_shape[0], self.output_dim)
```
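
To see how the three methods fit together without a Keras backend, here is a minimal, framework-free sketch in NumPy. ToyLayer and ToyMyLayer are hypothetical classes written only for illustration: they imitate the contract described above (build runs once, lazily, on the first call and receives the input's shape; call computes the output; compute_output_shape predicts its shape), but they are not how Keras implements Layer internally. The initializer range is an arbitrary choice for the example.

```python
import numpy as np


class ToyLayer:
    """Hypothetical stand-in for keras.layers.Layer (illustration only)."""

    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Subclasses create their weights first, then call this to mark the layer built.
        self.built = True

    def call(self, x):
        raise NotImplementedError

    def compute_output_shape(self, input_shape):
        return input_shape

    def __call__(self, x):
        # Build lazily on first use, from the shape of the actual input.
        if not self.built:
            self.build(x.shape)
        return self.call(x)


class ToyMyLayer(ToyLayer):
    """NumPy mirror of MyLayer: a bias-free dense projection."""

    def __init__(self, output_dim, seed=0):
        super().__init__()
        self.output_dim = output_dim
        self.rng = np.random.default_rng(seed)

    def build(self, input_shape):
        # Same kernel shape as MyLayer: (input_dim, output_dim).
        # The [-0.05, 0.05] range is an assumption made for this sketch.
        self.kernel = self.rng.uniform(-0.05, 0.05,
                                       size=(input_shape[1], self.output_dim))
        super().build(input_shape)

    def call(self, x):
        return x @ self.kernel

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)


layer = ToyMyLayer(output_dim=4)
x = np.ones((8, 3))
y = layer(x)
print(layer.built, y.shape, layer.compute_output_shape(x.shape))  # → True (8, 4) (8, 4)
```

Because build is deferred until the first input arrives, the kernel's first dimension can be read from input_shape[1] instead of being passed to the constructor, which is why MyLayer only asks for output_dim.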

It is also possible to define Keras layers which have multiple input tensors and multiple output tensors. To do this, you should assume that the inputs and outputs of the methods build(input_shape), call(x) and compute_output_shape(input_shape) are lists. Here is an example, similar to the one above:

```python
from keras import backend as K
from keras.layers import Layer


class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        # Create a trainable weight variable for this layer, sized from the first input.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        assert isinstance(x, list)
        a, b = x
        # b is expected to have shape (batch, output_dim) so the sum is well-defined.
        return [K.dot(a, self.kernel) + b, K.mean(b, axis=-1)]

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        shape_a, shape_b = input_shape
        # One shape per output: (batch, output_dim), and b's shape without its last axis.
        return [(shape_a[0], self.output_dim), shape_b[:-1]]
```
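
The list-based contract can be sanity-checked the same way. The sketch below is a hypothetical NumPy mirror of the two-input, two-output layer (multi_io_forward and multi_io_output_shape are illustrative names, not Keras functions); it checks that the shapes returned by compute_output_shape match what call actually produces. It also makes one implicit requirement visible: K.dot(a, self.kernel) + b only works when b has shape (batch, output_dim).

```python
import numpy as np


def multi_io_forward(inputs, kernel):
    """Hypothetical NumPy mirror of the two-input, two-output call()."""
    assert isinstance(inputs, list)
    a, b = inputs
    # First output: project a, then add b elementwise.
    # Second output: mean of b over its last axis, which drops that axis.
    return [a @ kernel + b, b.mean(axis=-1)]


def multi_io_output_shape(input_shapes, output_dim):
    """Mirror of compute_output_shape(): one shape tuple per output."""
    shape_a, shape_b = input_shapes
    return [(shape_a[0], output_dim), shape_b[:-1]]


batch, input_dim, output_dim = 5, 3, 4
kernel = np.full((input_dim, output_dim), 0.1)
a = np.ones((batch, input_dim))
b = np.ones((batch, output_dim))  # last dimension must equal output_dim

out, mean_b = multi_io_forward([a, b], kernel)
shapes = multi_io_output_shape([a.shape, b.shape], output_dim)
print([out.shape, mean_b.shape] == shapes)  # → True
```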

The existing Keras layers provide examples of how to implement almost anything. Never hesitate to read the source code!