Learner

We have data, a model, and a loss function; we only need one more thing before we can fit a model, and that’s an optimizer! Here’s SGD:

In [ ]:

    class SGD:
        def __init__(self, params, lr, wd=0.): store_attr()
        def step(self):
            for p in self.params:
                p.data -= (p.grad.data + p.data*self.wd) * self.lr
                p.grad.data.zero_()

As we’ve seen in this book, life is easier with a Learner. The Learner class needs to know our training and validation sets, which means we need DataLoaders to store them. We don’t need any other functionality, just a place to store them and access them:

In [ ]:

    class DataLoaders:
        def __init__(self, *dls): self.train,self.valid = dls

    dls = DataLoaders(train_dl,valid_dl)

Now we’re ready to create our Learner class:

In [ ]:

    class Learner:
        def __init__(self, model, dls, loss_func, lr, cbs, opt_func=SGD):
            store_attr()
            for cb in cbs: cb.learner = self

        def one_batch(self):
            self('before_batch')
            xb,yb = self.batch
            self.preds = self.model(xb)
            self.loss = self.loss_func(self.preds, yb)
            if self.model.training:
                self.loss.backward()
                self.opt.step()
            self('after_batch')

        def one_epoch(self, train):
            self.model.training = train
            self('before_epoch')
            dl = self.dls.train if train else self.dls.valid
            for self.num,self.batch in enumerate(progress_bar(dl, leave=False)):
                self.one_batch()
            self('after_epoch')

        def fit(self, n_epochs):
            self('before_fit')
            self.opt = self.opt_func(self.model.parameters(), self.lr)
            self.n_epochs = n_epochs
            try:
                for self.epoch in range(n_epochs):
                    self.one_epoch(True)
                    self.one_epoch(False)
            except CancelFitException: pass
            self('after_fit')

        def __call__(self,name):
            for cb in self.cbs: getattr(cb,name,noop)()

This is the largest class we’ve created in the book, but each method is quite small, so by looking at each in turn you should be able to follow what’s going on.

The main method we’ll be calling is fit. This loops with:

    for self.epoch in range(n_epochs)

and at each epoch calls self.one_epoch with train=True and then with train=False. Then self.one_epoch calls self.one_batch for each batch in dls.train or dls.valid, as appropriate (after wrapping the DataLoader in fastprogress.progress_bar). Finally, self.one_batch follows the usual set of steps to fit one mini-batch that we’ve seen throughout this book.

Before and after each step, Learner calls self, which calls __call__ (which is standard Python functionality). __call__ uses the Python built-in getattr(cb, name, noop) on each callback in self.cbs; getattr returns the attribute (a method, in this case) with the requested name, or noop if no such attribute exists. So, for instance, self('before_fit') will call cb.before_fit() for each callback where that method is defined, and silently skip the rest.
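
To make that dispatch concrete, here is a minimal standalone sketch (not from the book; the PrintCB and SilentCB names are made up for illustration) of how getattr with a noop default behaves:

    def noop(*args, **kwargs): pass

    class PrintCB:
        def before_fit(self): print("before_fit ran")

    class SilentCB: pass  # defines no hooks at all

    for cb in [PrintCB(), SilentCB()]:
        getattr(cb, 'before_fit', noop)()  # SilentCB is silently skipped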

As you can see, Learner is really just using our standard training loop, except that it’s also calling callbacks at appropriate times. So let’s define some callbacks!

Callbacks

In Learner.__init__ we have:

    for cb in cbs: cb.learner = self

In other words, every callback knows what learner it is used in. This is critical, since otherwise a callback can’t get information from the learner, or change things in the learner. Because getting information from the learner is so common, we make that easier by defining Callback as a subclass of GetAttr, with a default attribute of learner:

In [ ]:

    class Callback(GetAttr): _default='learner'

GetAttr is a fastai class that implements Python’s standard __getattr__ and __dir__ methods for you, such that any time you try to access an attribute that doesn’t exist, it passes the request along to whatever you have defined as _default.
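
To give a sense of the mechanics, here is a rough, simplified sketch of the delegation GetAttr provides (this is an approximation for illustration, not fastcore’s actual implementation, which also handles __dir__ and can filter which names are forwarded):

    class SimpleGetAttr:
        _default = 'learner'
        def __getattr__(self, name):
            # __getattr__ is only called when normal lookup fails, so
            # attributes that exist on the object itself are unaffected.
            if name == self._default: raise AttributeError(name)
            return getattr(getattr(self, self._default), name)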

For instance, we want to move all model parameters to the GPU automatically at the start of fit. We could do this by defining before_fit as self.learner.model.cuda(); however, because learner is the default attribute, and we have SetupLearnerCB inherit from Callback (which inherits from GetAttr), we can remove the .learner and just call self.model.cuda():

In [ ]:

    class SetupLearnerCB(Callback):
        def before_batch(self):
            xb,yb = to_device(self.batch)
            self.learner.batch = tfm_x(xb),yb

        def before_fit(self): self.model.cuda()

In SetupLearnerCB we also move each mini-batch to the GPU, by calling to_device(self.batch) (we could also have used the longer to_device(self.learner.batch)). Note, however, that in the line self.learner.batch = tfm_x(xb),yb we can’t remove .learner, because here we’re setting the attribute, not getting it.
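
That asymmetry is easy to see in a toy example (the Owner and Delegate names are hypothetical, purely for illustration): attribute reads fall through to the delegated object, but writes always land on the object itself:

    from fastcore.basics import GetAttr

    class Owner: lr = 0.1

    class Delegate(GetAttr):
        _default = 'owner'
        def __init__(self, owner): self.owner = owner

    d = Delegate(Owner())
    print(d.lr)        # 0.1 -- the read is forwarded to the owner
    d.lr = 0.01        # creates d.lr on the Delegate itself; the owner is untouched
    print(d.owner.lr)  # still 0.1 -- to change the owner, assign through .owner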

Before we try our Learner out, let’s create a callback to track and print progress. Otherwise we won’t really know if it’s working properly:

In [ ]:

    class TrackResults(Callback):
        def before_epoch(self): self.accs,self.losses,self.ns = [],[],[]

        def after_epoch(self):
            n = sum(self.ns)
            print(self.epoch, self.model.training,
                  sum(self.losses).item()/n, sum(self.accs).item()/n)

        def after_batch(self):
            xb,yb = self.batch
            acc = (self.preds.argmax(dim=1)==yb).float().sum()
            self.accs.append(acc)
            n = len(xb)
            self.losses.append(self.loss*n)
            self.ns.append(n)

Now we’re ready to use our Learner for the first time!

In [ ]:

    cbs = [SetupLearnerCB(),TrackResults()]
    learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs)
    learn.fit(1)

    0 True 2.1275552130636814 0.2314922378287042
    0 False 1.9942575636942674 0.2991082802547771

It’s quite amazing to realize that we can implement all the key ideas from fastai’s Learner in so little code! Let’s now add some learning rate scheduling.

Scheduling the Learning Rate

If we’re going to get good results, we’ll want an LR finder and 1cycle training. These are both annealing callbacks; that is, they gradually change hyperparameters as we train. Here’s LRFinder:

In [ ]:

    class LRFinder(Callback):
        def before_fit(self):
            self.losses,self.lrs = [],[]
            self.learner.lr = 1e-6

        def before_batch(self):
            if not self.model.training: return
            self.opt.lr *= 1.2

        def after_batch(self):
            if not self.model.training: return
            if self.opt.lr>10 or torch.isnan(self.loss): raise CancelFitException
            self.losses.append(self.loss.item())
            self.lrs.append(self.opt.lr)

This shows how we’re using CancelFitException, which is itself an empty class, used only to signal the type of exception. You can see in Learner that this exception is caught. (You should add and test CancelBatchException, CancelEpochException, and so on yourself; one rough sketch of what that might look like follows.)
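
Here is one possible, untested sketch (not the book’s own code): the extra exception classes are empty, just like CancelFitException, and one_batch catches its exception the same way fit does, so a callback can abandon the current batch without stopping training:

    class CancelBatchException(Exception): pass
    class CancelEpochException(Exception): pass

    class Learner(Learner):  # hypothetical extension of the Learner defined above
        def one_batch(self):
            try:
                self('before_batch')
                xb,yb = self.batch
                self.preds = self.model(xb)
                self.loss = self.loss_func(self.preds, yb)
                if self.model.training:
                    self.loss.backward()
                    self.opt.step()
            except CancelBatchException: pass
            self('after_batch')

Let’s try out LRFinder by adding it to our list of callbacks: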

In [ ]:

    lrfind = LRFinder()
    learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[lrfind])
    learn.fit(2)

    0 True 2.6336045582954903 0.11014890695955222
    0 False 2.230653363853503 0.18318471337579617
    16.22% [12/74 00:02<00:12]

And take a look at the results:

In [ ]:

    plt.plot(lrfind.lrs[:-2],lrfind.losses[:-2])
    plt.xscale('log')

[Plot: loss versus learning rate (log scale) recorded by LRFinder]

Now we can define our OneCycle training callback:

In [ ]:

    class OneCycle(Callback):
        def __init__(self, base_lr): self.base_lr = base_lr
        def before_fit(self): self.lrs = []

        def before_batch(self):
            if not self.model.training: return
            n = len(self.dls.train)
            bn = self.epoch*n + self.num
            mn = self.n_epochs*n
            pct = bn/mn
            pct_start,div_start = 0.25,10
            if pct<pct_start:
                pct /= pct_start
                lr = (1-pct)*self.base_lr/div_start + pct*self.base_lr
            else:
                pct = (pct-pct_start)/(1-pct_start)
                lr = (1-pct)*self.base_lr
            self.opt.lr = lr
            self.lrs.append(lr)
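
The schedule is piecewise linear: over the first 25% of batches (pct_start) the learning rate warms up from base_lr/div_start to base_lr, and over the remaining batches it decays linearly back to zero. As a quick sanity check, here is a small standalone snippet (not part of the book’s code) that evaluates the same formula at a few points, assuming base_lr=0.1:

    base_lr, pct_start, div_start = 0.1, 0.25, 10

    def sched(pct):
        # same piecewise-linear formula as OneCycle.before_batch above
        if pct < pct_start:
            p = pct/pct_start
            return (1-p)*base_lr/div_start + p*base_lr
        p = (pct-pct_start)/(1-pct_start)
        return (1-p)*base_lr

    print(sched(0.0))   # 0.01 -- start of training: base_lr/div_start
    print(sched(0.25))  # 0.1  -- peak at pct_start: base_lr
    print(sched(1.0))   # 0.0  -- end of training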

We’ll try an LR of 0.1:

In [ ]:

    onecyc = OneCycle(0.1)
    learn = Learner(simple_cnn(), dls, cross_entropy, lr=0.1, cbs=cbs+[onecyc])

Let’s fit for a while and see how it looks (we won’t show all the output in the book—try it in the notebook to see the results):

In [ ]:

    #hide_output
    learn.fit(8)

Finally, we’ll check that the learning rate followed the schedule we defined (as you see, we’re not using cosine annealing here):

In [ ]:

    plt.plot(onecyc.lrs);

[Plot: the 1cycle learning rate schedule recorded in onecyc.lrs, rising linearly to the peak and then falling linearly to zero]