Trains two recurrent neural networks based upon a story and a question.

The resulting merged vector is then queried to answer a range of bAbI tasks.

The results are comparable to those for an LSTM model provided in Weston et al., "Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks": http://arxiv.org/abs/1502.05698

Task Number                  | FB LSTM Baseline (% acc.) | Keras QA (% acc.)
---                          | ---                       | ---
QA1 - Single Supporting Fact | 50                        | 52.1
QA2 - Two Supporting Facts   | 20                        | 37.0
QA3 - Three Supporting Facts | 20                        | 20.5
QA4 - Two Arg. Relations     | 61                        | 62.9
QA5 - Three Arg. Relations   | 70                        | 61.9
QA6 - Yes/No Questions       | 48                        | 50.7
QA7 - Counting               | 49                        | 78.9
QA8 - Lists/Sets             | 45                        | 77.2
QA9 - Simple Negation        | 64                        | 64.0
QA10 - Indefinite Knowledge  | 44                        | 47.7
QA11 - Basic Coreference     | 72                        | 74.9
QA12 - Conjunction           | 74                        | 76.4
QA13 - Compound Coreference  | 94                        | 94.4
QA14 - Time Reasoning        | 27                        | 34.8
QA15 - Basic Deduction       | 21                        | 32.4
QA16 - Basic Induction       | 23                        | 50.6
QA17 - Positional Reasoning  | 51                        | 49.1
QA18 - Size Reasoning        | 52                        | 90.8
QA19 - Path Finding          | 8                         | 9.0
QA20 - Agent's Motivations   | 91                        | 90.7

For the resources related to the bAbI project, refer to: https://research.facebook.com/researchers/1543934539189348

Notes

  • With default word, sentence, and query vector sizes, the GRU model achieves:
      • 52.1% test accuracy on QA1 in 20 epochs (2 seconds per epoch on CPU)
      • 37.0% test accuracy on QA2 in 20 epochs (16 seconds per epoch on CPU)

    In comparison, the Facebook paper achieves 50% and 20% for the LSTM baseline.

  • The task does not traditionally parse the question separately. This likely improves accuracy and is a good example of merging two RNNs.

  • The word vector embeddings are not shared between the story and question RNNs (a sketch of how they could be shared follows this list).

  • See how the accuracy changes given 10,000 training samples (en-10k) instead of only 1,000. 1,000 was used in order to be comparable to the original paper.

  • Experiment with GRU, LSTM, and JZS1-3, as they give subtly different results.

  • The length and noise (i.e. 'useless' story components) impact the ability of LSTMs / GRUs to provide the correct answer. Given only the supporting facts, these RNNs can achieve 100% accuracy on many tasks (see the sketch after the code listing). Memory networks and neural networks that use attentional processes can efficiently search through this noise to find the relevant statements, improving performance substantially. This becomes especially obvious on QA2 and QA3, both far longer than QA1.
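On the embedding-sharing note above: applying a single `Embedding` layer instance to both inputs would tie the two weight matrices together. A minimal sketch using the Keras functional API; the sizes here are hypothetical placeholders, not the values the listing below computes from the data:

```python
from keras import layers

# Hypothetical sizes for illustration only; the full script below derives
# vocab_size, story_maxlen, and query_maxlen from the dataset.
vocab_size = 40
story_maxlen = 68
query_maxlen = 5
EMBED_HIDDEN_SIZE = 50

sentence = layers.Input(shape=(story_maxlen,), dtype='int32')
question = layers.Input(shape=(query_maxlen,), dtype='int32')

# One Embedding instance applied to both inputs reuses the same weight
# matrix, unlike the two independent Embedding layers in the script below.
shared_embedding = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)
encoded_sentence = shared_embedding(sentence)
encoded_question = shared_embedding(question)
```

The rest of the model (the two RNN encoders and the concatenation) would be unchanged.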

```python
from __future__ import print_function
from functools import reduce
import re
import tarfile

import numpy as np

from keras.utils.data_utils import get_file
from keras.layers.embeddings import Embedding
from keras import layers
from keras.layers import recurrent
from keras.models import Model
from keras.preprocessing.sequence import pad_sequences


def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]


def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbi tasks format

    If only_supporting is true,
    only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        line = line.decode('utf-8').strip()
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if '\t' in line:
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]
            data.append((substory, q, a))
            story.append('')
        else:
            sent = tokenize(line)
            story.append(sent)
    return data


def get_stories(f, only_supporting=False, max_length=None):
    '''Given a file name, read the file, retrieve the stories,
    and then convert the sentences into a single story.

    If max_length is supplied,
    any stories longer than max_length tokens will be discarded.
    '''
    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    flatten = lambda data: reduce(lambda x, y: x + y, data)
    data = [(flatten(story), q, answer) for story, q, answer in data
            if not max_length or len(flatten(story)) < max_length]
    return data


def vectorize_stories(data, word_idx, story_maxlen, query_maxlen):
    xs = []
    xqs = []
    ys = []
    for story, query, answer in data:
        x = [word_idx[w] for w in story]
        xq = [word_idx[w] for w in query]
        # let's not forget that index 0 is reserved
        y = np.zeros(len(word_idx) + 1)
        y[word_idx[answer]] = 1
        xs.append(x)
        xqs.append(xq)
        ys.append(y)
    return (pad_sequences(xs, maxlen=story_maxlen),
            pad_sequences(xqs, maxlen=query_maxlen), np.array(ys))


RNN = recurrent.LSTM
EMBED_HIDDEN_SIZE = 50
SENT_HIDDEN_SIZE = 100
QUERY_HIDDEN_SIZE = 100
BATCH_SIZE = 32
EPOCHS = 20
print('RNN / Embed / Sent / Query = {}, {}, {}, {}'.format(RNN,
                                                           EMBED_HIDDEN_SIZE,
                                                           SENT_HIDDEN_SIZE,
                                                           QUERY_HIDDEN_SIZE))

try:
    path = get_file('babi-tasks-v1-2.tar.gz',
                    origin='https://s3.amazonaws.com/text-datasets/'
                           'babi_tasks_1-20_v1-2.tar.gz')
except:
    print('Error downloading dataset, please download it manually:\n'
          '$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2'
          '.tar.gz\n'
          '$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz')
    raise

# Default QA1 with 1000 samples
# challenge = 'tasks_1-20_v1-2/en/qa1_single-supporting-fact_{}.txt'
# QA1 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa1_single-supporting-fact_{}.txt'
# QA2 with 1000 samples
challenge = 'tasks_1-20_v1-2/en/qa2_two-supporting-facts_{}.txt'
# QA2 with 10,000 samples
# challenge = 'tasks_1-20_v1-2/en-10k/qa2_two-supporting-facts_{}.txt'

with tarfile.open(path) as tar:
    train = get_stories(tar.extractfile(challenge.format('train')))
    test = get_stories(tar.extractfile(challenge.format('test')))

vocab = set()
for story, q, answer in train + test:
    vocab |= set(story + q + [answer])
vocab = sorted(vocab)

# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
story_maxlen = max(map(len, (x for x, _, _ in train + test)))
query_maxlen = max(map(len, (x for _, x, _ in train + test)))

x, xq, y = vectorize_stories(train, word_idx, story_maxlen, query_maxlen)
tx, txq, ty = vectorize_stories(test, word_idx, story_maxlen, query_maxlen)

print('vocab = {}'.format(vocab))
print('x.shape = {}'.format(x.shape))
print('xq.shape = {}'.format(xq.shape))
print('y.shape = {}'.format(y.shape))
print('story_maxlen, query_maxlen = {}, {}'.format(story_maxlen, query_maxlen))

print('Build model...')

sentence = layers.Input(shape=(story_maxlen,), dtype='int32')
encoded_sentence = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(sentence)
encoded_sentence = RNN(SENT_HIDDEN_SIZE)(encoded_sentence)

question = layers.Input(shape=(query_maxlen,), dtype='int32')
encoded_question = layers.Embedding(vocab_size, EMBED_HIDDEN_SIZE)(question)
encoded_question = RNN(QUERY_HIDDEN_SIZE)(encoded_question)

merged = layers.concatenate([encoded_sentence, encoded_question])
preds = layers.Dense(vocab_size, activation='softmax')(merged)

model = Model([sentence, question], preds)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

print('Training')
model.fit([x, xq], y,
          batch_size=BATCH_SIZE,
          epochs=EPOCHS,
          validation_split=0.05)

print('Evaluation')
loss, acc = model.evaluate([tx, txq], ty,
                           batch_size=BATCH_SIZE)
print('Test loss / test accuracy = {:.4f} / {:.4f}'.format(loss, acc))
```
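As a starting point for the experiments suggested in the notes, both variations are small changes to the script above. The sketch below reuses `path`, `challenge`, `get_stories`, and `recurrent` from the listing; it is illustrative, not a tested configuration:

```python
# Keep only each story's supporting facts, dropping the 'noise' sentences;
# the loader above already accepts only_supporting for exactly this purpose.
with tarfile.open(path) as tar:
    train = get_stories(tar.extractfile(challenge.format('train')),
                        only_supporting=True)
    test = get_stories(tar.extractfile(challenge.format('test')),
                       only_supporting=True)

# Swapping the recurrent layer is equally direct, provided it happens
# before the model is built:
RNN = recurrent.GRU
```

With only the supporting facts, the story sequences become short enough that the LSTM/GRU readers approach the 100% accuracy mentioned in the notes on many tasks.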