LSTM

LSTM is an architecture that was introduced back in 1997 by Jürgen Schmidhuber and Sepp Hochreiter. In this architecture, there are not one but two hidden states. In our base RNN, the hidden state is the output of the RNN at the previous time step. That hidden state is then responsible for two things:

  • Having the right information for the output layer to predict the correct next token
  • Retaining memory of everything that happened in the sentence

Consider, for example, the sentences “Henry has a dog and he likes his dog very much” and “Sophie has a dog and she likes her dog very much.” It’s very clear that the RNN needs to remember the name at the beginning of the sentence to be able to predict he/she or his/her.

In practice, RNNs are really bad at retaining memory of what happened much earlier in the sentence, which is the motivation to have another hidden state (called cell state) in the LSTM. The cell state will be responsible for keeping long short-term memory, while the hidden state will focus on the next token to predict. Let’s take a closer look at how this is achieved and build an LSTM from scratch.

Building an LSTM from Scratch

To build an LSTM, we first have to understand its architecture. The following figure shows its inner structure.

A graph showing the inner architecture of an LSTM

In this picture, our input $x_{t}$ enters on the left with the previous hidden state ($h_{t-1}$) and cell state ($c_{t-1}$). The four orange boxes represent four layers (our neural nets) with the activation being either sigmoid ($\sigma$) or tanh. tanh is just a sigmoid function rescaled to the range -1 to 1. Its mathematical expression can be written like this:

$$\tanh(x) = \frac{e^{x} - e^{-x}}{e^{x}+e^{-x}} = 2 \sigma(2x) - 1$$

where $\sigma$ is the sigmoid function. The green circles are elementwise operations. What goes out on the right is the new hidden state ($h_{t}$) and new cell state ($c_{t}$), ready for our next input. The new hidden state is also used as output, which is why the arrow splits to go up.
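
We can check the relationship between tanh and sigmoid numerically. The following is a minimal sketch; the sample points and the tolerance are arbitrary choices of ours, not something from the text above.

```python
import torch

# A handful of sample points (arbitrary choice for illustration)
x = torch.linspace(-3, 3, steps=7)

# tanh computed directly, and via the rescaled sigmoid 2*sigmoid(2x) - 1
tanh_direct = torch.tanh(x)
tanh_from_sigmoid = 2 * torch.sigmoid(2 * x) - 1

# The two should agree up to floating-point error
identity_holds = torch.allclose(tanh_direct, tanh_from_sigmoid, atol=1e-6)
print(identity_holds)
```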

Let’s go over the four neural nets (called gates) one by one and explain the diagram—but before this, notice how very little the cell state (at the top) is changed. It doesn’t even go directly through a neural net! This is exactly why it will carry on a longer-term state.

First, the arrows for input and old hidden state are joined together. In the RNN we wrote earlier in this chapter, we were adding them together. In the LSTM, we stack them in one big tensor. This means the dimension of our embeddings (which is the dimension of $x_{t}$) can be different than the dimension of our hidden state. If we call those n_in and n_hid, the arrow at the bottom is of size n_in + n_hid; thus all the neural nets (orange boxes) are linear layers with n_in + n_hid inputs and n_hid outputs.
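
To make the sizes concrete, here is a small sketch of the stacking step. The values chosen for `bs`, `n_in`, and `n_hid` are made up for illustration, and the single `gate` layer stands in for any one of the four orange boxes.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not values from the chapter)
bs, n_in, n_hid = 8, 32, 64

x_t = torch.randn(bs, n_in)      # current input (embedding)
h_prev = torch.randn(bs, n_hid)  # previous hidden state

# Stack hidden state and input along the feature dimension
stacked = torch.cat([h_prev, x_t], dim=1)

# Any one gate: n_in + n_hid inputs, n_hid outputs
gate = nn.Linear(n_in + n_hid, n_hid)
gate_out = gate(stacked)

print(stacked.shape, gate_out.shape)
```

Because the two tensors are concatenated rather than added, nothing forces `n_in` and `n_hid` to match.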

The first gate (looking from left to right) is called the forget gate. Since it’s a linear layer followed by a sigmoid, its output will consist of scalars between 0 and 1. We multiply this result by the cell state to determine which information to keep and which to throw away: values closer to 0 are discarded and values closer to 1 are kept. This gives the LSTM the ability to forget things about its long-term state. For instance, when crossing a period or an xxbos token, we would expect it to (have learned to) reset its cell state.
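
Here is a toy illustration of that multiplication. The cell state and the pre-sigmoid values are hand-picked by us (in a real LSTM they would come from the learned linear layer), so treat this as a sketch of the mechanism only.

```python
import torch

# Hand-picked previous cell state (hypothetical values)
c_prev = torch.tensor([[2.0, -1.5, 0.5]])

# Hypothetical outputs of the forget gate's linear layer, before the sigmoid
forget_logits = torch.tensor([[10.0, -10.0, 0.0]])

# Sigmoid squashes them to roughly [1, 0, 0.5]
forget = torch.sigmoid(forget_logits)

# Elementwise product: keep the first element, drop the second, halve the third
c_kept = c_prev * forget
print(c_kept)
```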

The second gate is called the input gate. It works with the third gate (which doesn’t really have a name but is sometimes called the cell gate) to update the cell state. For instance, we may see a new gender pronoun, in which case we’ll need to replace the information about gender that the forget gate removed. Similar to the forget gate, the input gate decides which elements of the cell state to update (values close to 1) or not (values close to 0). The third gate determines what those updated values are, in the range of -1 to 1 (thanks to the tanh function). The result is then added to the cell state.
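
The update step can be sketched the same way. All numbers below are hand-picked for illustration; in a trained LSTM the gate activations come from the learned layers.

```python
import torch

# Cell state after the forget step (hypothetical values)
c = torch.tensor([[0.0, 1.0, -1.0]])

# Input gate: which elements to update, roughly [1, 0, 1]
inp = torch.sigmoid(torch.tensor([[10.0, -10.0, 10.0]]))

# Cell gate: candidate values in (-1, 1), roughly [0.96, 0.96, -0.96]
cell = torch.tanh(torch.tensor([[2.0, 2.0, -2.0]]))

# Only the first and third elements are changed
c_new = c + inp * cell
print(c_new)
```

Note that the third element ends up below -1: each update is bounded by the tanh, but the cell state itself is a running sum and is not.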

The last gate is the output gate. It determines which information from the cell state to use to generate the output. The cell state goes through a tanh before being combined with the sigmoid output from the output gate, and the result is the new hidden state.
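
Putting the four gates together, the steps described above can be written as the following equations. The symbols $f_t$, $i_t$, $g_t$, $o_t$ and the weight and bias names are our own notation for this summary; $[h_{t-1}, x_t]$ denotes the stacked tensor and $\odot$ elementwise multiplication.

```latex
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) && \text{(forget gate)} \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) && \text{(input gate)} \\
g_t &= \tanh(W_g [h_{t-1}, x_t] + b_g) && \text{(cell gate)} \\
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) && \text{(output gate)} \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```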

In terms of code, we can write the same steps like this:

In [ ]:

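class LSTMCell(Module):  # fastai's Module: no super().__init__() call needed
    def __init__(self, ni, nh):
        # Four gates, each a linear layer over the stacked [hidden, input] tensor
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h,c = state
        # Stack the old hidden state and the input into one tensor
        h = torch.cat([h, input], dim=1)
        # Forget gate: how much of each cell-state element to keep
        forget = torch.sigmoid(self.forget_gate(h))
        c = c * forget
        # Input gate (which elements to update) and cell gate (the new values)
        inp = torch.sigmoid(self.input_gate(h))
        cell = torch.tanh(self.cell_gate(h))
        c = c + inp * cell
        # Output gate: what to expose from the cell state as the new hidden state
        out = torch.sigmoid(self.output_gate(h))
        h = out * torch.tanh(c)
        return h, (h,c)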
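
To see how such a cell would be used, here is a sketch that steps it across a short sequence. To keep the example self-contained we restate the cell on plain `nn.Module` under the hypothetical name `SketchLSTMCell`; the sizes and the random input are made up.

```python
import torch
from torch import nn

class SketchLSTMCell(nn.Module):
    "Same steps as the cell above, restated on nn.Module (hypothetical name)."
    def __init__(self, ni, nh):
        super().__init__()
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate = nn.Linear(ni + nh, nh)
        self.cell_gate = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h, c = state
        h = torch.cat([h, input], dim=1)
        c = c * torch.sigmoid(self.forget_gate(h))
        c = c + torch.sigmoid(self.input_gate(h)) * torch.tanh(self.cell_gate(h))
        h = torch.sigmoid(self.output_gate(h)) * torch.tanh(c)
        return h, (h, c)

# Illustrative sizes: batch, sequence length, input size, hidden size
bs, sl, ni, nh = 4, 5, 3, 6
cell = SketchLSTMCell(ni, nh)
x = torch.randn(bs, sl, ni)

# Start from an all-zero hidden state and cell state
state = (torch.zeros(bs, nh), torch.zeros(bs, nh))

# Feed one time step at a time, carrying the state forward
outputs = []
for t in range(sl):
    out, state = cell(x[:, t], state)
    outputs.append(out)
outputs = torch.stack(outputs, dim=1)  # shape: (bs, sl, nh)
print(outputs.shape)
```

The output at the last time step is the same tensor as the final hidden state, which is why the cell returns `h` twice.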

In practice, we can then refactor the code. Also, in terms of performance, it’s better to do one big matrix multiplication than four smaller ones (that’s because we only launch the special fast kernel on the GPU once, and it gives the GPU more work to do in parallel). The stacking takes a bit of time (since we have to move one of the tensors around on the GPU to have it all in a contiguous array), so we use two separate layers for the input and the hidden state. The optimized and refactored code then looks like this:

In [ ]:

class LSTMCell(Module):  # fastai's Module: no super().__init__() call needed
    def __init__(self, ni, nh):
        # Each layer produces the pre-activations of all four gates at once
        self.ih = nn.Linear(ni,4*nh)
        self.hh = nn.Linear(nh,4*nh)

    def forward(self, input, state):
        h,c = state
        # One big multiplication for all the gates is better than 4 smaller ones
        gates = (self.ih(input) + self.hh(h)).chunk(4, 1)
        # First three chunks use a sigmoid; the fourth (cell gate) uses tanh
        ingate,forgetgate,outgate = map(torch.sigmoid, gates[:3])
        cellgate = gates[3].tanh()

        c = (forgetgate*c) + (ingate*cellgate)
        h = outgate * c.tanh()
        return h, (h,c)
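
We can convince ourselves that one wide linear layer followed by `chunk` gives the same numbers as four separate layers. The sketch below does this for a single layer of our own making (`fused`), by slicing its weight and bias into four pieces and applying each with `F.linear`; the sizes are arbitrary.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
bs, ni, nh = 4, 3, 6  # illustrative sizes

# One wide layer that computes all four gates' pre-activations at once
fused = nn.Linear(ni, 4 * nh)
x = torch.randn(bs, ni)

# Fused path: one matrix multiplication, then split the result into 4 pieces
fused_chunks = fused(x).chunk(4, dim=1)

# Separate path: split the parameters into 4 smaller layers and apply each
weight_chunks = fused.weight.chunk(4, dim=0)  # each of shape (nh, ni)
bias_chunks = fused.bias.chunk(4, dim=0)      # each of shape (nh,)
separate = [F.linear(x, w, b) for w, b in zip(weight_chunks, bias_chunks)]

matches = all(torch.allclose(a, b, atol=1e-6)
              for a, b in zip(fused_chunks, separate))
print(matches)
```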

Here we use the PyTorch chunk method to split our tensor into four pieces. It works like this:

In [ ]:

t = torch.arange(0,10); t

Out[ ]:

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [ ]:

t.chunk(2)

Out[ ]:

(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
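
In the LSTM cell, the split happens along dimension 1 of a two-dimensional tensor rather than on a flat one. Here is a small sketch with made-up sizes, using consecutive integers so it is easy to see where each piece comes from.

```python
import torch

bs, nh = 2, 3  # illustrative sizes

# A (bs, 4*nh) tensor standing in for the stacked gate pre-activations
gates = torch.arange(bs * 4 * nh).reshape(bs, 4 * nh)

# Split into 4 pieces along dimension 1, as in the refactored cell
pieces = gates.chunk(4, 1)

for p in pieces:
    print(p)
```

Each piece keeps the batch dimension and takes `nh` consecutive columns, so every gate gets its own `(bs, nh)` slice.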

Let’s now use this architecture to train a language model!

Training a Language Model Using LSTMs

Here is the same network as LMModel5, using a two-layer LSTM. We can train it at a higher learning rate, for a shorter time, and get better accuracy:

In [ ]:

class LMModel6(Module):
    def __init__(self, vocab_sz, n_hidden, n_layers):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        # Two states (hidden and cell); `bs` is the batch size defined earlier
        self.h = [torch.zeros(n_layers, bs, n_hidden) for _ in range(2)]

    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        # Detach both states so gradients don't flow across batches
        self.h = [h_.detach() for h_ in h]
        return self.h_o(res)

    def reset(self):
        for h in self.h: h.zero_()
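
It can help to look at the shapes `nn.LSTM` works with when `batch_first=True`. The sketch below uses made-up sizes and random input in place of real embeddings.

```python
import torch
from torch import nn

# Illustrative sizes (assumptions, not the chapter's values)
bs, sl, n_hidden, n_layers = 4, 16, 64, 2

rnn = nn.LSTM(n_hidden, n_hidden, n_layers, batch_first=True)
x = torch.randn(bs, sl, n_hidden)  # stands in for the embedded tokens

# One hidden state and one cell state per layer
state = tuple(torch.zeros(n_layers, bs, n_hidden) for _ in range(2))

res, (h, c) = rnn(x, state)
print(res.shape, h.shape, c.shape)
```

With `batch_first=True` the output is `(bs, sl, n_hidden)`, while both states keep the layer dimension first: `(n_layers, bs, n_hidden)`. This is why the model creates its initial states with `n_layers` as the leading dimension.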

In [ ]:

learn = Learner(dls, LMModel6(len(vocab), 64, 2),
                loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(15, 1e-2)

epoch  train_loss  valid_loss  accuracy  time
0      3.000821    2.663942    0.438314  00:02
1      2.139642    2.184780    0.240479  00:02
2      1.607275    1.812682    0.439779  00:02
3      1.347711    1.830982    0.497477  00:02
4      1.123113    1.937766    0.594401  00:02
5      0.852042    2.012127    0.631592  00:02
6      0.565494    1.312742    0.725749  00:02
7      0.347445    1.297934    0.711263  00:02
8      0.208191    1.441269    0.731201  00:02
9      0.126335    1.569952    0.737305  00:02
10     0.079761    1.427187    0.754150  00:02
11     0.052990    1.494990    0.745117  00:02
12     0.039008    1.393731    0.757894  00:02
13     0.031502    1.373210    0.758464  00:02
14     0.028068    1.368083    0.758464  00:02

Now that’s better than a multilayer RNN! We can still see there is a bit of overfitting, however, which is a sign that a bit of regularization might help.