Callbacks

Sometimes you need to change how things work a little bit. In fact, we have already seen examples of this: Mixup, fp16 training, resetting the model after each epoch for training RNNs, and so forth. How do we go about making these kinds of tweaks to the training process?

We’ve seen the basic training loop, which, with the help of the Optimizer class, looks like this for a single epoch:

```python
for xb,yb in dl:
    loss = loss_func(model(xb), yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
```

The following figure shows how to picture that.

Figure: Basic training loop

The usual way for deep learning practitioners to customize the training loop is to make a copy of an existing training loop, and then insert the code necessary for their particular changes into it. This is how nearly all code that you find online will look. But it has some very serious problems.

It’s unlikely that any given tweaked training loop is going to meet your particular needs. There are hundreds of changes that can be made to a training loop, which means there are billions and billions of possible combinations. You can’t just copy one tweak from a training loop here, another from a training loop there, and expect them all to work together. Each will be based on different assumptions about the environment that it’s working in, use different naming conventions, and expect the data to be in different formats.

We need a way to allow users to insert their own code at any part of the training loop, but in a consistent and well-defined way. Computer scientists have already come up with an elegant solution: the callback. A callback is a piece of code that you write, and inject into another piece of code at some predefined point. In fact, callbacks have been used with deep learning training loops for years. The problem is that in previous libraries it was only possible to inject code in a small subset of places where this may have been required, and, more importantly, callbacks were not able to do all the things they needed to do.

In order to be just as flexible as manually copying and pasting a training loop and directly inserting code into it, a callback must be able to read every possible piece of information available in the training loop, modify all of it as needed, and fully control when a batch, epoch, or even the whole training loop should be terminated. fastai is the first library to provide all of this functionality. It modifies the training loop so it looks like the following figure.

Figure: Training loop with callbacks

The real effectiveness of this approach has been borne out over the last couple of years: by using the fastai callback system, we were able to implement every single new paper we tried and fulfill every user request for modifying the training loop. The training loop itself has not required modifications. The following figure shows just a few of the callbacks that have been added.

Figure: Some fastai callbacks

This is important because it means that whatever idea we have in our heads, we can implement it. We never need to dig into the source code of PyTorch or fastai and hack together some one-off system to try out our ideas. And when we do implement our own callbacks to develop our own ideas, we know that they will work together with all of the other functionality provided by fastai, so we will get progress bars, mixed-precision training, hyperparameter annealing, and so forth.

Another advantage is that it makes it easy to gradually remove or add functionality and perform ablation studies. You just need to adjust the list of callbacks you pass along to your fit function.

As an example, here is the fastai source code that is run for each batch of the training loop:

```python
try:
    self._split(b); self('begin_batch')
    self.pred = self.model(*self.xb); self('after_pred')
    self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
    if not self.training: return
    self.loss.backward(); self('after_backward')
    self.opt.step(); self('after_step')
    self.opt.zero_grad()
except CancelBatchException: self('after_cancel_batch')
finally: self('after_batch')
```

The calls of the form self('...') are where the callbacks are called. As you see, this happens after every step. The callback will receive the entire state of training, and can also modify it. For instance, the input data and target labels are in self.xb and self.yb, respectively; a callback can modify these to alter the data the training loop sees. It can also modify self.loss, or even the gradients.
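To make the self('...') mechanism concrete, here is a minimal pure-Python sketch of how a learner might dispatch an event name to every callback that defines a method with that name. The MiniLearner, Callback, and DoubleInputs classes are hypothetical simplifications, not fastai's implementation: they use plain numbers instead of tensors and leave out the backward pass and the optimizer. The batch logic follows the shape of the fastai snippet above.

```python
class CancelBatchException(Exception):
    "Raised by a callback to skip the rest of the current batch."

class Callback:
    "Base class: subclasses define methods named after the events they handle."
    learn = None  # set by the learner when the callback is registered

class MiniLearner:
    "Hypothetical, stripped-down learner: no tensors, no backward pass, no optimizer."
    def __init__(self, model, loss_func, cbs=()):
        self.model, self.loss_func = model, loss_func
        self.cbs = list(cbs)
        for cb in self.cbs: cb.learn = self  # give each callback a handle on the learner

    def __call__(self, event_name):
        # Run the method called `event_name` on every callback that defines one
        for cb in self.cbs:
            method = getattr(cb, event_name, None)
            if method is not None: method()

    def one_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        try:
            self('begin_batch')
            self.pred = self.model(*self.xb); self('after_pred')
            self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')
        except CancelBatchException: self('after_cancel_batch')
        finally: self('after_batch')

class DoubleInputs(Callback):
    "Hypothetical callback that rewrites the input before the model sees it."
    def begin_batch(self):
        self.learn.xb = tuple(x * 2 for x in self.learn.xb)

learn = MiniLearner(model=lambda x: x + 1,
                    loss_func=lambda pred, y: (pred - y) ** 2,
                    cbs=[DoubleInputs()])
learn.one_batch(xb=(3,), yb=(7,))
print(learn.pred, learn.loss)  # 7 0
```

Because the callback changed xb before the model ran, the model saw 6 rather than 3, which is exactly the kind of intervention (Mixup, for instance) described above.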

Let’s see how this works in practice by writing a callback.

Creating a Callback

When you want to write your own callback, the full list of available events is:

  • begin_fit: called before doing anything; ideal for initial setup.
  • begin_epoch: called at the beginning of each epoch; useful for any behavior you need to reset at each epoch.
  • begin_train: called at the beginning of the training part of an epoch.
  • begin_batch: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyperparameter scheduling) or to change the input/target before it goes into the model (for instance, apply Mixup).
  • after_pred: called after computing the output of the model on the batch. It can be used to change that output before it’s fed to the loss function.
  • after_loss: called after the loss has been computed, but before the backward pass. It can be used to add a penalty to the loss (AR or TAR in RNN training, for instance).
  • after_backward: called after the backward pass, but before the update of the parameters. It can be used to make changes to the gradients before said update (via gradient clipping, for instance).
  • after_step: called after the step and before the gradients are zeroed.
  • after_batch: called at the end of a batch, to perform any required cleanup before the next one.
  • after_train: called at the end of the training phase of an epoch.
  • begin_validate: called at the beginning of the validation phase of an epoch; useful for any setup needed specifically for validation.
  • after_validate: called at the end of the validation part of an epoch.
  • after_epoch: called at the end of an epoch, for any cleanup before the next one.
  • after_fit: called at the end of training, for final cleanup.
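The nesting of these events can be sketched as a skeleton fit loop that simply records each event name as it fires. This is a hypothetical, heavily simplified illustration (no model, no data, no cancellation); the real fastai loop does much more, but the relative order of the events should match the list above.

```python
def fit_skeleton(n_epoch, n_train_batches, n_valid_batches):
    "Return the events a simplified training loop would fire, in order."
    events = []
    fire = events.append  # stand-in for "call this method on every callback"

    def run_batches(n_batches, training):
        for _ in range(n_batches):
            fire('begin_batch')
            fire('after_pred')
            fire('after_loss')
            if training:  # validation stops here, like `if not self.training: return`
                fire('after_backward')
                fire('after_step')
            fire('after_batch')

    fire('begin_fit')
    for _ in range(n_epoch):
        fire('begin_epoch')
        fire('begin_train')
        run_batches(n_train_batches, training=True)
        fire('after_train')
        fire('begin_validate')
        run_batches(n_valid_batches, training=False)
        fire('after_validate')
        fire('after_epoch')
    fire('after_fit')
    return events

for name in fit_skeleton(n_epoch=1, n_train_batches=1, n_valid_batches=1):
    print(name)
```

Note that after_backward and after_step never fire during validation, since no parameters are updated there.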

The elements of this list are available as attributes of the special variable event, so you can just type event. and hit Tab in your notebook to see a list of all the options.
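One simple way to expose event names for tab completion is an object whose attributes are the event names themselves. The sketch below builds such an object with types.SimpleNamespace; this is an assumption for illustration, and fastai's own construction of event may differ. The point is that event.after_pred evaluates to the string 'after_pred', so a misspelled name raises an AttributeError instead of silently never firing.

```python
from types import SimpleNamespace

# The 14 event names from the list above
_events = ['begin_fit', 'begin_epoch', 'begin_train', 'begin_batch', 'after_pred',
           'after_loss', 'after_backward', 'after_step', 'after_batch', 'after_train',
           'begin_validate', 'after_validate', 'after_epoch', 'after_fit']

# Each attribute's value is its own name, e.g. event.after_loss == 'after_loss'
event = SimpleNamespace(**{name: name for name in _events})

print(event.after_loss)   # after_loss
print(len(vars(event)))   # 14
```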

Let’s take a look at an example. Do you recall how, in an earlier chapter, we needed to ensure that our special reset method was called at the start of training and validation for each epoch? We used the ModelResetter callback provided by fastai to do this for us. But how does it work? Here’s the full source code for that class:

```python
class ModelResetter(Callback):
    def begin_train(self): self.model.reset()
    def begin_validate(self): self.model.reset()
```

Yes, that’s actually it! It just does what we said in the preceding paragraph: before starting training or validation for an epoch, it calls a method named reset.
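To see when reset actually runs, here is a toy version driven by hand. CountingModel is a hypothetical stand-in for a stateful RNN whose reset only counts its calls, and the loop below imitates where the training loop would fire begin_train and begin_validate. In fastai, self.model on a callback resolves to the learner's model; here we simply store it on the callback to keep the sketch self-contained.

```python
class CountingModel:
    "Hypothetical stand-in for a stateful RNN: reset() would clear its hidden state."
    def __init__(self): self.n_resets = 0
    def reset(self): self.n_resets += 1

class ModelResetter:
    "The same two methods as the fastai callback, without the Callback base class."
    def __init__(self, model): self.model = model
    def begin_train(self): self.model.reset()
    def begin_validate(self): self.model.reset()

model = CountingModel()
cb = ModelResetter(model)
for epoch in range(3):
    cb.begin_train()     # fired before the training batches of the epoch
    # ... training batches would run here ...
    cb.begin_validate()  # fired before the validation batches of the epoch
    # ... validation batches would run here ...
print(model.n_resets)    # 6
```

Two resets per epoch means the hidden state never leaks from training into validation, or from one epoch into the next.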

Callbacks are often “short and sweet” like this one. In fact, let’s look at one more. Here’s the fastai source for the callback that adds RNN regularization (AR and TAR):

```python
class RNNRegularizer(Callback):
    def __init__(self, alpha=0., beta=0.): self.alpha,self.beta = alpha,beta

    def after_pred(self):
        # Keep the extra RNN outputs and pass only the real predictions on to the loss
        self.raw_out,self.out = self.pred[1],self.pred[2]
        self.learn.pred = self.pred[0]

    def after_loss(self):
        if not self.training: return
        if self.alpha != 0.:
            # AR: penalize large activations
            self.learn.loss += self.alpha * self.out[-1].float().pow(2).mean()
        if self.beta != 0.:
            # TAR: penalize large changes between consecutive time steps
            h = self.raw_out[-1]
            if len(h)>1:
                self.learn.loss += self.beta * (h[:,1:] - h[:,:-1]
                                               ).float().pow(2).mean()
```
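Written as a formula, the code above adds the following two penalties to the loss during training (each only when its coefficient is nonzero). Here $o$ stands for self.out[-1], the dropped-out activations of the last layer; $h$ stands for self.raw_out[-1], the activations of that layer before dropout; $h_{:,t}$ is the slice of $h$ at time step $t$ along its second (sequence) dimension; and each mean is taken over every element of its argument:

$$
\text{loss} \;\leftarrow\; \text{loss} \;+\; \alpha \,\operatorname{mean}\!\left(o^{2}\right) \;+\; \beta \,\operatorname{mean}_{t = 1,\dots,T-1}\!\left(\left(h_{:,t} - h_{:,t-1}\right)^{2}\right)
$$

The first term (AR) discourages large activations; the second (TAR) discourages large jumps between consecutive time steps, with $T$ the sequence length and $t$ indexed from zero as in the slicing h[:,1:] - h[:,:-1].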

Note: Code It Yourself: Go back and reread “Activation Regularization and Temporal Activation Regularization” in the earlier chapter on language models, then take another look at the code here. Make sure you understand what it’s doing, and why.

In both of these examples, notice how we can access attributes of the training loop by directly checking self.model or self.pred. That’s because whenever a Callback is asked for an attribute it doesn’t have, it looks that attribute up in its associated Learner. These are shortcuts for self.learn.model or self.learn.pred. Note that they work for reading attributes, but not for writing them, which is why, when RNNRegularizer changes the loss or the predictions, you see self.learn.loss = or self.learn.pred =.
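The read/write asymmetry follows from how Python's __getattr__ works: it is consulted only when ordinary attribute lookup fails, while assignment always creates or updates an attribute on the object itself. Here is a minimal sketch of that delegation pattern; Callback and TinyLearner are hypothetical simplifications, and fastai's real classes are more elaborate.

```python
class Callback:
    "Sketch of attribute delegation: unknown attributes are looked up on self.learn."
    learn = None  # class-level default, so looking up `learn` never reaches __getattr__

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal attribute lookup has failed
        return getattr(self.learn, name)

class TinyLearner:
    "Hypothetical learner holding one piece of training state."
    def __init__(self): self.loss = 1.0

learn = TinyLearner()
cb = Callback()
cb.learn = learn

print(cb.loss)        # 1.0: the read falls through to learn.loss
cb.loss = 5.0         # assignment creates a new attribute on the callback only
print(learn.loss)     # 1.0: the learner's loss is unchanged
cb.learn.loss = 5.0   # writing through self.learn is what actually updates it
print(learn.loss)     # 5.0
```

The silent failure in the middle is the trap: cb.loss = 5.0 raises no error, but every other callback and the training loop itself still see the learner's original loss.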

When writing a callback, the following attributes of Learner are available:

  • model: The model used for training/validation.
  • data: The underlying DataLoaders.
  • loss_func: The loss function used.
  • opt: The optimizer used to update the model parameters.
  • opt_func: The function used to create the optimizer.
  • cbs: The list containing all the Callbacks.
  • dl: The current DataLoader used for iteration.
  • x/xb: The last input drawn from self.dl (potentially modified by callbacks). xb is always a tuple (potentially with one element) and x is detuplified. You can only assign to xb.
  • y/yb: The last target drawn from self.dl (potentially modified by callbacks). yb is always a tuple (potentially with one element) and y is detuplified. You can only assign to yb.
  • pred: The last predictions from self.model (potentially modified by callbacks).
  • loss: The last computed loss (potentially modified by callbacks).
  • n_epoch: The number of epochs in this training.
  • n_iter: The number of iterations in the current self.dl.
  • epoch: The current epoch index (from 0 to n_epoch-1).
  • iter: The current iteration index in self.dl (from 0 to n_iter-1).
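The relationship between xb and x can be sketched with a read-only property. The detuplify helper and the BatchState class below are hypothetical illustrations of the behavior described above (a one-element tuple is unwrapped, anything longer is left as a tuple), not fastai's code. Because x has no setter, assigning to it fails, which mirrors "You can only assign to xb."

```python
def detuplify(t):
    "Return the sole element of a one-element tuple; otherwise return the tuple unchanged."
    return t[0] if len(t) == 1 else t

class BatchState:
    "Hypothetical holder for the current batch, mimicking the xb/x relationship."
    def __init__(self, xb): self.xb = tuple(xb)

    @property
    def x(self):
        # Read-only view derived from xb; there is no setter, so `obj.x = ...` raises
        return detuplify(self.xb)

single = BatchState(xb=['images'])
multi = BatchState(xb=['text', 'lengths'])
print(single.x)   # images
print(multi.x)    # ('text', 'lengths')
```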

The following attributes are added by TrainEvalCallback and should be available unless you went out of your way to remove that callback:

  • train_iter: The number of training iterations done since the beginning of this training
  • pct_train: The percentage of training iterations completed (from 0. to 1.)
  • training: A flag to indicate whether or not we’re in training mode

The following attribute is added by Recorder and should be available unless you went out of your way to remove that callback:

  • smooth_loss: An exponentially averaged version of the training loss
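As a rough illustration of what "exponentially averaged" means here, the sketch below computes a debiased exponential moving average of a sequence of losses. The smoothing constant beta = 0.98 and the debiasing step are assumptions for this sketch; fastai's Recorder uses smoothing of this general kind, but check its source for the exact details. Dividing by 1 - beta**i corrects the bias toward the 0.0 starting value, so a constant loss stays constant instead of ramping up from zero.

```python
def smooth_losses(losses, beta=0.98):
    "Debiased exponential moving average of a sequence of losses (assumed beta=0.98)."
    avg, smoothed = 0.0, []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss   # exponentially weighted running average
        smoothed.append(avg / (1 - beta ** i)) # correct the bias toward the initial 0.0
    return smoothed

print(smooth_losses([2.0, 2.0, 2.0]))           # each value is 2.0, up to rounding
print(smooth_losses([1.0, 0.0], beta=0.5))      # approximately [1.0, 0.333]
```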

Callbacks can also interrupt any part of the training loop by using a system of exceptions.

Callback Ordering and Exceptions

Sometimes, callbacks need to be able to tell fastai to skip over a batch, or an epoch, or stop training altogether. For instance, consider TerminateOnNaNCallback. This handy callback will automatically stop training any time the loss becomes infinite or NaN (not a number). Here’s the fastai source for this callback:

```python
class TerminateOnNaNCallback(Callback):
    run_before=Recorder
    def after_batch(self):
        if torch.isinf(self.loss) or torch.isnan(self.loss):
            raise CancelFitException
```

The line raise CancelFitException tells the training loop to interrupt training at this point. The training loop catches this exception and does not run any further training or validation. The callback control flow exceptions available are:

  • CancelBatchException: Skip the rest of this batch and go to after_batch.
  • CancelTrainException: Skip the rest of the training part of the epoch and go to after_train.
  • CancelValidException: Skip the rest of the validation part of the epoch and go to after_validate.
  • CancelEpochException: Skip the rest of this epoch and go to after_epoch.
  • CancelFitException: Interrupt training and go to after_fit.
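Here is a pure-Python sketch of how TerminateOnNaNCallback and CancelFitException interact. TerminateOnNaN and fit are hypothetical simplifications: losses are plain floats checked with math.isnan and math.isinf instead of tensors, and the callback receives the loss as an argument rather than reading self.loss. The structure (raise inside the loop, catch outside it, always finish with after_fit) is the part that carries over.

```python
import math

class CancelFitException(Exception):
    "Raised by a callback to stop training altogether."

class TerminateOnNaN:
    "Hypothetical pure-Python analogue of TerminateOnNaNCallback, for float losses."
    def after_batch(self, loss):
        if math.isinf(loss) or math.isnan(loss):
            raise CancelFitException

def fit(losses, cb):
    "Toy loop: each 'batch' just reports the next precomputed loss."
    log = []
    try:
        for loss in losses:
            log.append(loss)
            cb.after_batch(loss)
    except CancelFitException:
        log.append('after_cancel_fit')  # the event fired when training is cancelled
    finally:
        log.append('after_fit')         # runs whether or not training was cancelled
    return log

print(fit([0.9, 0.7, float('nan'), 0.5], TerminateOnNaN()))
# [0.9, 0.7, nan, 'after_cancel_fit', 'after_fit']
print(fit([0.9, 0.7], TerminateOnNaN()))
# [0.9, 0.7, 'after_fit']
```

The batch with loss 0.5 never runs: once the exception propagates out of the loop, no further training happens.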

You can detect whether one of these exceptions has occurred, and run code immediately afterward, with the following events:

  • after_cancel_batch: Reached immediately after a CancelBatchException, before proceeding to after_batch.
  • after_cancel_train: Reached immediately after a CancelTrainException, before proceeding to after_train.
  • after_cancel_valid: Reached immediately after a CancelValidException, before proceeding to after_validate.
  • after_cancel_epoch: Reached immediately after a CancelEpochException, before proceeding to after_epoch.
  • after_cancel_fit: Reached immediately after a CancelFitException, before proceeding to after_fit.
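The batch-level version of this control flow can be traced by recording every event as it fires. TraceLearner and SkipOddBatches are hypothetical classes for illustration; to keep the sketch short, handlers receive the learner as an argument instead of using self.learn. When a batch is cancelled in after_pred, after_loss is skipped, after_cancel_batch fires, and after_batch still runs.

```python
class CancelBatchException(Exception):
    "Raised by a callback to skip the rest of the current batch."

class SkipOddBatches:
    "Hypothetical callback that cancels every odd-numbered batch right after prediction."
    def after_pred(self, learner):
        if learner.iter % 2 == 1:
            raise CancelBatchException

class TraceLearner:
    "Hypothetical learner that records every event it fires."
    def __init__(self, cbs):
        self.cbs, self.trace, self.iter = cbs, [], 0

    def __call__(self, name):
        self.trace.append(f'{self.iter}:{name}')  # record the event first
        for cb in self.cbs:                         # then run any matching handler
            handler = getattr(cb, name, None)
            if handler is not None: handler(self)

    def all_batches(self, n):
        for i in range(n):
            self.iter = i
            try:
                self('begin_batch')
                self('after_pred')
                self('after_loss')  # never reached when after_pred cancels the batch
            except CancelBatchException:
                self('after_cancel_batch')
            finally:
                self('after_batch')

learn = TraceLearner([SkipOddBatches()])
learn.all_batches(2)
print(learn.trace)
# ['0:begin_batch', '0:after_pred', '0:after_loss', '0:after_batch',
#  '1:begin_batch', '1:after_pred', '1:after_cancel_batch', '1:after_batch']
```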

Sometimes, callbacks need to be called in a particular order. For example, in the case of TerminateOnNaNCallback, it’s important that Recorder runs its after_batch after this callback, to avoid registering an NaN loss. You can specify run_before (this callback must run before …) or run_after (this callback must run after …) in your callback to ensure the ordering that you need.
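Turning pairwise run_before/run_after constraints into a single order is a topological-sort problem. The sketch below uses Kahn's algorithm and assumes each callback class names at most one other class in run_before or run_after; it illustrates the idea and is not fastai's implementation. LossLogger and its constraint are hypothetical; Recorder and TerminateOnNaNCallback are empty stand-ins for the classes in the text.

```python
def sort_by_run(cbs):
    "Reorder callbacks so every run_before/run_after constraint holds (Kahn's algorithm)."
    n = len(cbs)
    must_precede = set()  # (i, j) means cbs[i] must run before cbs[j]
    for i, a in enumerate(cbs):
        for j, b in enumerate(cbs):
            if i == j: continue
            if getattr(a, 'run_before', None) is type(b): must_precede.add((i, j))
            if getattr(a, 'run_after', None) is type(b): must_precede.add((j, i))
    indeg = [sum((i, j) in must_precede for i in range(n)) for j in range(n)]
    ready = [j for j in range(n) if indeg[j] == 0]  # original order breaks ties
    order = []
    while ready:
        i = ready.pop(0)
        order.append(i)
        for j in range(n):
            if (i, j) in must_precede:
                indeg[j] -= 1
                if indeg[j] == 0: ready.append(j)
    if len(order) != n:
        raise ValueError('run_before/run_after constraints form a cycle')
    return [cbs[i] for i in order]

class Recorder: pass
class LossLogger:               # hypothetical callback, for illustration only
    run_after = Recorder
class TerminateOnNaNCallback:
    run_before = Recorder

cbs = [Recorder(), LossLogger(), TerminateOnNaNCallback()]
print([type(cb).__name__ for cb in sort_by_run(cbs)])
# ['TerminateOnNaNCallback', 'Recorder', 'LossLogger']
```

The cycle check matters: two callbacks that each insist on running before the other have no valid order, and failing loudly is better than picking one arbitrarily.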