Trains a memory network on the bAbI dataset.

References:

- Jason Weston, Antoine Bordes, Sumit Chopra, Tomas Mikolov, Alexander M. Rush,
  "Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks",
  http://arxiv.org/abs/1502.05698
- Sainbayar Sukhbaatar, Arthur Szlam, Jason Weston, Rob Fergus,
  "End-To-End Memory Networks",
  http://arxiv.org/abs/1503.08895
Reaches 98.6% accuracy on the 'single_supporting_fact_10k' task after 120 epochs. Time per epoch: 3s on CPU (Core i7).
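For orientation (an illustrative sketch, not part of the original script): the bAbI task files parsed below are plain text, where each story is a block of numbered sentences and each question line carries a tab-separated question, answer and supporting-fact id(s). A story along the lines of the start of the qa1 training set looks like this, with <TAB> standing for a tab character:

    1 Mary moved to the bathroom.
    2 John went to the hallway.
    3 Where is Mary?<TAB>bathroom<TAB>1

The line counter resetting to 1 marks the start of a new story, which is exactly what parse_stories() below relies on.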

from __future__ import print_function

from keras.models import Sequential, Model
from keras.layers.embeddings import Embedding
from keras.layers import Input, Activation, Dense, Permute, Dropout
from keras.layers import add, dot, concatenate
from keras.layers import LSTM
from keras.utils.data_utils import get_file
from keras.preprocessing.sequence import pad_sequences
from functools import reduce
import tarfile
import numpy as np
import re


def tokenize(sent):
    '''Return the tokens of a sentence, including punctuation.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]


def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.

    If only_supporting is true,
    only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        line = line.decode('utf-8').strip()
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if '\t' in line:
            # Question line: question, answer and supporting fact ids
            # are separated by tabs.
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]
            data.append((substory, q, a))
            story.append('')
        else:
            # Regular sentence: add it to the current story.
            sent = tokenize(line)
            story.append(sent)
    return data
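# For illustration only (assumes the sample story shown in the introduction):
# parse_stories would turn those three lines into one
# (substory, question, answer) tuple, roughly:
#   ([['Mary', 'moved', 'to', 'the', 'bathroom', '.'],
#     ['John', 'went', 'to', 'the', 'hallway', '.']],
#    ['Where', 'is', 'Mary', '?'],
#    'bathroom')
# get_stories() below then flattens each substory into a single token list.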


def get_stories(f, only_supporting=False, max_length=None):
    '''Given a file name, read the file, retrieve the stories,
    and then convert the sentences into a single story.

    If max_length is supplied,
    any stories longer than max_length tokens are discarded.
    '''
    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    flatten = lambda data: reduce(lambda x, y: x + y, data)
    data = [(flatten(story), q, answer) for story, q, answer in data
            if not max_length or len(flatten(story)) < max_length]
    return data


def vectorize_stories(data):
    '''Convert token lists into padded index arrays, using the module-level
    word_idx, story_maxlen and query_maxlen defined below.'''
    inputs, queries, answers = [], [], []
    for story, query, answer in data:
        inputs.append([word_idx[w] for w in story])
        queries.append([word_idx[w] for w in query])
        answers.append(word_idx[answer])
    return (pad_sequences(inputs, maxlen=story_maxlen),
            pad_sequences(queries, maxlen=query_maxlen),
            np.array(answers))


try:
    path = get_file('babi-tasks-v1-2.tar.gz',
                    origin='https://s3.amazonaws.com/text-datasets/'
                           'babi_tasks_1-20_v1-2.tar.gz')
except Exception:
    print('Error downloading dataset, please download it manually:\n'
          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2'
          '.tar.gz\n'
          '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
    raise

challenges = {
    # QA1 with 10,000 samples
    'single_supporting_fact_10k': 'tasks_1-20_v1-2/en-10k/qa1_'
                                  'single-supporting-fact_{}.txt',
    # QA2 with 10,000 samples
    'two_supporting_facts_10k': 'tasks_1-20_v1-2/en-10k/qa2_'
                                'two-supporting-facts_{}.txt',
}
challenge_type = 'single_supporting_fact_10k'
challenge = challenges[challenge_type]

print('Extracting stories for the challenge:', challenge_type)
with tarfile.open(path) as tar:
    train_stories = get_stories(tar.extractfile(challenge.format('train')))
    test_stories = get_stories(tar.extractfile(challenge.format('test')))

vocab = set()
for story, q, answer in train_stories + test_stories:
    vocab |= set(story + q + [answer])
vocab = sorted(vocab)

# Reserve index 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
story_maxlen = max(map(len, (x for x, _, _ in train_stories + test_stories)))
query_maxlen = max(map(len, (x for _, x, _ in train_stories + test_stories)))

print('-')
print('Vocab size:', vocab_size, 'unique words')
print('Story max length:', story_maxlen, 'words')
print('Query max length:', query_maxlen, 'words')
print('Number of training stories:', len(train_stories))
print('Number of test stories:', len(test_stories))
print('-')
print('Here\'s what a "story" tuple looks like (input, query, answer):')
print(train_stories[0])
print('-')
print('Vectorizing the word sequences...')

word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
inputs_train, queries_train, answers_train = vectorize_stories(train_stories)
inputs_test, queries_test, answers_test = vectorize_stories(test_stories)

print('-')
print('inputs: integer tensor of shape (samples, max_length)')
print('inputs_train shape:', inputs_train.shape)
print('inputs_test shape:', inputs_test.shape)
print('-')
print('queries: integer tensor of shape (samples, max_length)')
print('queries_train shape:', queries_train.shape)
print('queries_test shape:', queries_test.shape)
print('-')
print('answers: integer tensor of shape (samples,) containing word indices')
print('answers_train shape:', answers_train.shape)
print('answers_test shape:', answers_test.shape)
print('-')
print('Compiling...')

# placeholders
input_sequence = Input((story_maxlen,))
question = Input((query_maxlen,))

# encoders
# embed the input sequence into a sequence of vectors
input_encoder_m = Sequential()
input_encoder_m.add(Embedding(input_dim=vocab_size,
                              output_dim=64))
input_encoder_m.add(Dropout(0.3))
# output: (samples, story_maxlen, embedding_dim)

# embed the input into a sequence of vectors of size query_maxlen
input_encoder_c = Sequential()
input_encoder_c.add(Embedding(input_dim=vocab_size,
                              output_dim=query_maxlen))
input_encoder_c.add(Dropout(0.3))
# output: (samples, story_maxlen, query_maxlen)

# embed the question into a sequence of vectors
question_encoder = Sequential()
question_encoder.add(Embedding(input_dim=vocab_size,
                               output_dim=64,
                               input_length=query_maxlen))
question_encoder.add(Dropout(0.3))
# output: (samples, query_maxlen, embedding_dim)

# encode the input sequence and the question (both indexed)
# into sequences of dense vectors
input_encoded_m = input_encoder_m(input_sequence)
input_encoded_c = input_encoder_c(input_sequence)
question_encoded = question_encoder(question)

# compute a 'match' between the first input vector sequence
# and the question vector sequence
# shape: `(samples, story_maxlen, query_maxlen)`
match = dot([input_encoded_m, question_encoded], axes=(2, 2))
match = Activation('softmax')(match)

# add the match matrix with the second input vector sequence
response = add([match, input_encoded_c])  # (samples, story_maxlen, query_maxlen)
response = Permute((2, 1))(response)  # (samples, query_maxlen, story_maxlen)

# concatenate the response vector sequence with the question vector sequence
answer = concatenate([response, question_encoded])
# the original paper uses a matrix multiplication for this reduction step.
# we choose to use an RNN instead.
answer = LSTM(32)(answer)  # (samples, 32)

# one regularization layer -- more would probably be needed.
answer = Dropout(0.3)(answer)
answer = Dense(vocab_size)(answer)  # (samples, vocab_size)
# output a probability distribution over the vocabulary
answer = Activation('softmax')(answer)

# build the final model
model = Model([input_sequence, question], answer)
model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# train
model.fit([inputs_train, queries_train], answers_train,
          batch_size=32,
          epochs=120,
          validation_data=([inputs_test, queries_test], answers_test))
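
# Optional follow-up (not part of the original example; a minimal sketch using
# only objects defined above): report the test accuracy and decode one
# prediction back into a word.
test_loss, test_acc = model.evaluate([inputs_test, queries_test], answers_test,
                                     verbose=0)
print('Test accuracy:', test_acc)

idx_word = dict((i, c) for c, i in word_idx.items())  # invert the vocabulary map
pred = model.predict([inputs_test[:1], queries_test[:1]])  # shape (1, vocab_size)
print('Predicted answer:', idx_word[int(np.argmax(pred[0]))])
print('Expected answer:', idx_word[int(answers_test[0])])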