A Generic Optimizer

To build up our accelerated SGD tricks, we’ll need to start with a nice flexible optimizer foundation. No library prior to fastai provided such a foundation, but during fastai’s development we realized that all the optimizer improvements we’d seen in the academic literature could be handled using optimizer callbacks. These are small pieces of code that we can compose, mix and match in an optimizer to build the optimizer step. They are called by fastai’s lightweight Optimizer class. These are the definitions in Optimizer of the two key methods that we’ve been using in this book:

def zero_grad(self):
    for p,*_ in self.all_params():
        p.grad.detach_()
        p.grad.zero_()

def step(self):
    for p,pg,state,hyper in self.all_params():
        for cb in self.cbs:
            state = _update(state, cb(p, **{**state, **hyper}))
        self.state[p] = state

As we saw when training an MNIST model from scratch, zero_grad just loops through the parameters of the model and sets the gradients to zero. It also calls detach_, which removes any history of gradient computation, since it won’t be needed after zero_grad.
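
Zeroing is necessary in the first place because PyTorch accumulates gradients across backward passes rather than overwriting them. Here is a tiny standalone illustration of that accumulation (not part of the Optimizer source):

import torch

# PyTorch sums gradients across backward() calls instead of replacing them,
# which is why an optimizer has to zero them between steps.
p = torch.ones(3, requires_grad=True)
(p * 2).sum().backward()
(p * 2).sum().backward()
print(p.grad)  # tensor([4., 4., 4.]) -- the two passes accumulated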

The more interesting method is step, which loops through the callbacks (cbs) and calls them to update the parameters (the _update function just calls state.update if there’s anything returned by cb). As you can see, Optimizer doesn’t actually do any SGD steps itself. Let’s see how we can add SGD to Optimizer.
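
Based on that description, _update can be as simple as the following sketch (an assumption about its body, not a verbatim copy of the fastai source):

def _update(state, new=None):
    # If the callback returned a dict, merge it into the per-parameter state;
    # callbacks that return None leave the state unchanged.
    if new is None: return state
    if isinstance(new, dict): state.update(new)
    return state

Each callback therefore receives the current state and hyperparameters as keyword arguments, and can return a dict of values (such as a running average) to be stored and passed back to it on the next step.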

Here’s an optimizer callback that does a single SGD step, by multiplying -lr by the gradients and adding that to the parameter (when Tensor.add_ in PyTorch is passed two parameters, they are multiplied together before the addition):

In [ ]:

def sgd_cb(p, lr, **kwargs): p.data.add_(-lr, p.grad.data)
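
Note that passing a scalar as the first positional argument to add_ is deprecated in recent versions of PyTorch; the equivalent spelling with the alpha keyword does the same thing:

def sgd_cb(p, lr, **kwargs): p.data.add_(p.grad.data, alpha=-lr)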

We can pass this to Optimizer using the cbs parameter; we’ll need to use partial since Learner will call this function to create our optimizer later:

In [ ]:

opt_func = partial(Optimizer, cbs=[sgd_cb])
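
When Learner builds the optimizer, it calls this partial with the model's parameters and the hyperparameters, roughly like this (a sketch assuming the usual Optimizer(params, cbs, **hypers) signature; the exact call made by Learner may differ):

# Hypothetical expansion of what the partial does when called later:
opt = opt_func(model.parameters(), lr=0.03)
# ...which is equivalent to Optimizer(model.parameters(), cbs=[sgd_cb], lr=0.03)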

Let’s see if this trains:

In [ ]:

learn = get_learner(opt_func=opt_func)
learn.fit(3, 0.03)
epoch  train_loss  valid_loss  accuracy  time
0      2.730918    2.009971    0.332739  00:09
1      2.204893    1.747202    0.441529  00:09
2      1.875621    1.684515    0.445350  00:09

It’s working! So that’s how we create SGD from scratch in fastai. Now let’s see what “momentum” is.