6.8. 长短期记忆(LSTM)

本节将介绍另一种常用的门控循环神经网络:长短期记忆(long short-termmemory,LSTM)[1]。它比门控循环单元的结构稍微复杂一点。

6.8.1. 长短期记忆

LSTM 中引入了3个门,即输入门(input gate)、遗忘门(forgetgate)和输出门(outputgate),以及与隐藏状态形状相同的记忆细胞(某些文献把记忆细胞当成一种特殊的隐藏状态),从而记录额外的信息。

6.8.1.1. 输入门、遗忘门和输出门

与门控循环单元中的重置门和更新门一样,如图6.7所示,长短期记忆的门的输入均为当前时间步输入

6.8. 长短期记忆(LSTM) - 图1 与上一时间步隐藏状态 6.8. 长短期记忆(LSTM) - 图2 ,输出由激活函数为sigmoid函数的全连接层计算得到。如此一来,这3个门元素的值域均为 6.8. 长短期记忆(LSTM) - 图3

长短期记忆中输入门、遗忘门和输出门的计算 图 6.7 长短期记忆中输入门、遗忘门和输出门的计算

具体来说,假设隐藏单元个数为

6.8. 长短期记忆(LSTM) - 图5 ,给定时间步 6.8. 长短期记忆(LSTM) - 图6 的小批量输入 6.8. 长短期记忆(LSTM) - 图7 (样本数为 6.8. 长短期记忆(LSTM) - 图8 ,输入个数为 6.8. 长短期记忆(LSTM) - 图9 )和上一时间步隐藏状态 6.8. 长短期记忆(LSTM) - 图10 。时间步 6.8. 长短期记忆(LSTM) - 图11 的输入门 6.8. 长短期记忆(LSTM) - 图12 、遗忘门 6.8. 长短期记忆(LSTM) - 图13 和输出门 6.8. 长短期记忆(LSTM) - 图14 分别计算如下:

6.8. 长短期记忆(LSTM) - 图15

其中的

6.8. 长短期记忆(LSTM) - 图166.8. 长短期记忆(LSTM) - 图17 是权重参数, 6.8. 长短期记忆(LSTM) - 图18 是偏差参数。

6.8.1.2. 候选记忆细胞

接下来,长短期记忆需要计算候选记忆细胞

6.8. 长短期记忆(LSTM) - 图19 。它的计算与上面介绍的3个门类似,但使用了值域在 6.8. 长短期记忆(LSTM) - 图20 的tanh函数作为激活函数,如图6.8所示。

长短期记忆中候选记忆细胞的计算 图 6.8 长短期记忆中候选记忆细胞的计算

具体来说,时间步

6.8. 长短期记忆(LSTM) - 图22 的候选记忆细胞 6.8. 长短期记忆(LSTM) - 图23 的计算为

6.8. 长短期记忆(LSTM) - 图24

其中

6.8. 长短期记忆(LSTM) - 图256.8. 长短期记忆(LSTM) - 图26 是权重参数, 6.8. 长短期记忆(LSTM) - 图27 是偏差参数。

6.8.1.3. 记忆细胞

我们可以通过元素值域在

6.8. 长短期记忆(LSTM) - 图28 的输入门、遗忘门和输出门来控制隐藏状态中信息的流动,这一般也是通过使用按元素乘法(符号为 6.8. 长短期记忆(LSTM) - 图29 )来实现的。当前时间步记忆细胞 6.8. 长短期记忆(LSTM) - 图30 的计算组合了上一时间步记忆细胞和当前时间步候选记忆细胞的信息,并通过遗忘门和输入门来控制信息的流动:

6.8. 长短期记忆(LSTM) - 图31

如图6.9所示,遗忘门控制上一时间步的记忆细胞

6.8. 长短期记忆(LSTM) - 图32 中的信息是否传递到当前时间步,而输入门则控制当前时间步的输入 6.8. 长短期记忆(LSTM) - 图33 通过候选记忆细胞 6.8. 长短期记忆(LSTM) - 图34 如何流入当前时间步的记忆细胞。如果遗忘门一直近似1且输入门一直近似0,过去的记忆细胞将一直通过时间保存并传递至当前时间步。这个设计可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较大的依赖关系。

长短期记忆中记忆细胞的计算。这里的\ :math:`\odot`\ 是按元素乘法 图 6.9 长短期记忆中记忆细胞的计算。这里的

6.8. 长短期记忆(LSTM) - 图36 是按元素乘法

6.8.1.4. 隐藏状态

有了记忆细胞以后,接下来我们还可以通过输出门来控制从记忆细胞到隐藏状态

6.8. 长短期记忆(LSTM) - 图37 的信息的流动:

6.8. 长短期记忆(LSTM) - 图38

这里的tanh函数确保隐藏状态元素值在-1到1之间。需要注意的是,当输出门近似1时,记忆细胞信息将传递到隐藏状态供输出层使用;当输出门近似0时,记忆细胞信息只自己保留。图6.10展示了长短期记忆中隐藏状态的计算。

长短期记忆中隐藏状态的计算。这里的\ :math:`\odot`\ 是按元素乘法 图 6.10 长短期记忆中隐藏状态的计算。这里的

6.8. 长短期记忆(LSTM) - 图40 是按元素乘法

6.8.2. 读取数据集

下面我们开始实现并展示长短期记忆。和前几节中的实验一样,这里依然使用周杰伦歌词数据集来训练模型作词。

  1. In [1]:
  1. import d2lzh as d2l
  2. from mxnet import nd
  3. from mxnet.gluon import rnn
  4.  
  5. (corpus_indices, char_to_idx, idx_to_char,
  6. vocab_size) = d2l.load_data_jay_lyrics()

6.8.3. 从零开始实现

我们先介绍如何从零开始实现长短期记忆。

6.8.3.1. 初始化模型参数

下面的代码对模型参数进行初始化。超参数num_hiddens定义了隐藏单元的个数。

  1. In [2]:
  1. num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
  2. ctx = d2l.try_gpu()
  3.  
  4. def get_params():
  5. def _one(shape):
  6. return nd.random.normal(scale=0.01, shape=shape, ctx=ctx)
  7.  
  8. def _three():
  9. return (_one((num_inputs, num_hiddens)),
  10. _one((num_hiddens, num_hiddens)),
  11. nd.zeros(num_hiddens, ctx=ctx))
  12.  
  13. W_xi, W_hi, b_i = _three() # 输入门参数
  14. W_xf, W_hf, b_f = _three() # 遗忘门参数
  15. W_xo, W_ho, b_o = _three() # 输出门参数
  16. W_xc, W_hc, b_c = _three() # 候选记忆细胞参数
  17. # 输出层参数
  18. W_hq = _one((num_hiddens, num_outputs))
  19. b_q = nd.zeros(num_outputs, ctx=ctx)
  20. # 附上梯度
  21. params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
  22. b_c, W_hq, b_q]
  23. for param in params:
  24. param.attach_grad()
  25. return params

6.8.4. 定义模型

在初始化函数中,长短期记忆的隐藏状态需要返回额外的形状为(批量大小,隐藏单元个数)的值为0的记忆细胞。

  1. In [3]:
  1. def init_lstm_state(batch_size, num_hiddens, ctx):
  2. return (nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx),
  3. nd.zeros(shape=(batch_size, num_hiddens), ctx=ctx))

下面根据长短期记忆的计算表达式定义模型。需要注意的是,只有隐藏状态会传递到输出层,而记忆细胞不参与输出层的计算。

  1. In [4]:
  1. def lstm(inputs, state, params):
  2. [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
  3. W_hq, b_q] = params
  4. (H, C) = state
  5. outputs = []
  6. for X in inputs:
  7. I = nd.sigmoid(nd.dot(X, W_xi) + nd.dot(H, W_hi) + b_i)
  8. F = nd.sigmoid(nd.dot(X, W_xf) + nd.dot(H, W_hf) + b_f)
  9. O = nd.sigmoid(nd.dot(X, W_xo) + nd.dot(H, W_ho) + b_o)
  10. C_tilda = nd.tanh(nd.dot(X, W_xc) + nd.dot(H, W_hc) + b_c)
  11. C = F * C + I * C_tilda
  12. H = O * C.tanh()
  13. Y = nd.dot(H, W_hq) + b_q
  14. outputs.append(Y)
  15. return outputs, (H, C)

6.8.4.1. 训练模型并创作歌词

同上一节一样,我们在训练模型时只使用相邻采样。设置好超参数后,我们将训练模型并根据前缀“分开”和“不分开”分别创作长度为50个字符的一段歌词。

  1. In [5]:
  1. num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
  2. pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']

我们每过40个迭代周期便根据当前训练的模型创作一段歌词。

  1. In [6]:
  1. d2l.train_and_predict_rnn(lstm, get_params, init_lstm_state, num_hiddens,
  2. vocab_size, ctx, corpus_indices, idx_to_char,
  3. char_to_idx, False, num_epochs, num_steps, lr,
  4. clipping_theta, batch_size, pred_period, pred_len,
  5. prefixes)
  1. epoch 40, perplexity 210.201758, time 0.73 sec
  2. - 分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我
  3. - 不分开 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我 我不的我
  4. epoch 80, perplexity 65.295333, time 0.73 sec
  5. - 分开 我想你这你 我想要这 我不要这 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要
  6. - 不分开 我想你这 我想要这样 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不要 我不
  7. epoch 120, perplexity 16.060283, time 0.73 sec
  8. - 分开 我想你这已经着 一想 你来了看着我 别发抖 快给我抬起 说有你对我有多 难是我 你给我 是你怎
  9. - 不分开 我想你这已经著 想想你 你想是看看 我想你这样活 你知你 说你的让我 想想好觉 你不了看个着 我知
  10. epoch 160, perplexity 3.975131, time 0.74 sec
  11. - 分开 我想了这样嵩堡 想想是没不活 唱天歌 一直走 我想就这样牵着你的手不放开 爱可不可以简简单单没有伤
  12. - 不分开你的爱面 温过过美过 想想要这想相信运运 我不能这远远让知知道 我已定会呵护著你 不要你这样打我妈妈

6.8.5. 简洁实现

在Gluon中我们可以直接调用rnn模块中的LSTM类。

  1. In [7]:
  1. lstm_layer = rnn.LSTM(num_hiddens)
  2. model = d2l.RNNModel(lstm_layer, vocab_size)
  3. d2l.train_and_predict_rnn_gluon(model, num_hiddens, vocab_size, ctx,
  4. corpus_indices, idx_to_char, char_to_idx,
  5. num_epochs, num_steps, lr, clipping_theta,
  6. batch_size, pred_period, pred_len, prefixes)
  1. epoch 40, perplexity 222.361727, time 0.05 sec
  2. - 分开 我不的 我你的 我你的 我你的 我你的 我你的 我你的 我你的 我你的 我你的 我你的 我你的
  3. - 不分开 我想你的 我不你 我你你的 我不你的 我不你的 我不你的 我不你的 我不你的 我不你的 我不你的
  4. epoch 80, perplexity 66.490616, time 0.05 sec
  5. - 分开 我想你这你 我不要这不 我不要这不 我不要这不 我不要这不 我不要这不 我不要这不 我不要这不
  6. - 不分开 我想想你的你 我不要你不 我不要 我不要 我不要这生 我不要这不 我不要这不 我不要这生 我不要这
  7. epoch 120, perplexity 14.220221, time 0.05 sec
  8. - 分开 一直在 在什么 一九我 一九我 一场我 印诉安 一诉我 印一安 一片段 有一段 装片么 装片么
  9. - 不分开 我想好这我 你着 一直走 我想想这生活 我知不觉 我该好好节奏 后知后觉 我该好好节活 后知后觉
  10. epoch 160, perplexity 3.566475, time 0.05 sec
  11. - 分开 一小我 是子 什么 一诉我的见袋 干真是 干什么 什么我的功袋 干真用 干什么 什么我的爱袋 干真
  12. - 不分开 我不了假不经 我想就 我不了 我不要 说我么么的久有 是是是看医着我 说说是我满腔的怒火 我想揍

6.8.6. 小结

  • 长短期记忆的隐藏层输出包括隐藏状态和记忆细胞。只有隐藏状态会传递到输出层。
  • 长短期记忆的输入门、遗忘门和输出门可以控制信息的流动。
  • 长短期记忆可以应对循环神经网络中的梯度衰减问题,并更好地捕捉时间序列中时间步距离较大的依赖关系。

6.8.7. 练习

  • 调节超参数,观察并分析对运行时间、困惑度以及创作歌词的结果造成的影响。
  • 在相同条件下,比较长短期记忆、门控循环单元和不带门控的循环神经网络的运行时间。
  • 既然候选记忆细胞已通过使用tanh函数确保值域在-1到1之间,为什么隐藏状态还需要再次使用tanh函数来确保输出值域在-1到1之间?

6.8.8. 参考文献

[1] Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory.Neural computation, 9(8), 1735-1780.