使用飞桨实现Skip-gram

接下来我们将学习使用飞桨实现Skio-gram模型的方法。在飞桨中,不同深度学习模型的训练过程基本一致,流程如下:

  • 数据处理:选择需要使用的数据,并做好必要的预处理工作。

  • 网络定义:使用飞桨定义好网络结构,包括输入层,中间层,输出层,损失函数和优化算法。

  • 网络训练:将准备好的数据送入神经网络进行学习,并观察学习的过程是否正常,如损失函数值是否在降低,也可以打印一些中间步骤的结果出来等。

  • 网络评估:使用测试集合测试训练好的神经网络,看看训练效果如何。

在数据处理前,需要先加载飞桨平台(如果用户在本地使用,请确保已经安装飞桨)。

  1. #encoding=utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import io
  16. import os
  17. import sys
  18. import requests
  19. from collections import OrderedDict
  20. import math
  21. import random
  22. import numpy as np
  23. import paddle
  24. import paddle.fluid as fluid
  25. from paddle.fluid.dygraph.nn import Embedding

数据处理

首先,找到一个合适的语料用于训练word2vec模型。我们选择text8数据集,这个数据集里包含了大量从维基百科收集到的英文语料,我们可以通过如下代码下载数据集,下载后的文件被保存在当前目录的text8.txt文件内。

  1. #下载语料用来训练word2vec
  2. def download():
  3. #可以从百度云服务器下载一些开源数据集(dataset.bj.bcebos.com)
  4. corpus_url = "https://dataset.bj.bcebos.com/word2vec/text8.txt"
  5. #使用python的requests包下载数据集到本地
  6. web_request = requests.get(corpus_url)
  7. corpus = web_request.content
  8. #把下载后的文件存储在当前目录的text8.txt文件内
  9. with open("./text8.txt", "wb") as f:
  10. f.write(corpus)
  11. f.close()
  12. download()

接下来,把下载的语料读取到程序里,并打印前500个字符看看语料的样子,代码如下:

  1. #读取text8数据
  2. def load_text8():
  3. with open("./text8.txt", "r") as f:
  4. corpus = f.read().strip("\n")
  5. f.close()
  6. return corpus
  7. corpus = load_text8()
  8. #打印前500个字符,简要看一下这个语料的样子
  9. print(corpus[:500])
  1. anarchism originated as a term of abuse first used against early working class radicals including the diggers of the english revolution and the sans culottes of the french revolution whilst the term is still used in a pejorative way to describe any act that used violent means to destroy the organization of society it has also been taken up as a positive label by self defined anarchists the word anarchism is derived from the greek without archons ruler chief king anarchism as a political philoso

一般来说,在自然语言处理中,需要先对语料进行切词。对于英文来说,可以比较简单地直接使用空格进行切词,代码如下:

  1. #对语料进行预处理(分词)
  2. def data_preprocess(corpus):
  3. #由于英文单词出现在句首的时候经常要大写,所以我们把所有英文字符都转换为小写,
  4. #以便对语料进行归一化处理(Apple vs apple等)
  5. corpus = corpus.strip().lower()
  6. corpus = corpus.split(" ")
  7. return corpus
  8. corpus = data_preprocess(corpus)
  9. print(corpus[:50])
  1. ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against', 'early', 'working', 'class', 'radicals', 'including', 'the', 'diggers', 'of', 'the', 'english', 'revolution', 'and', 'the', 'sans', 'culottes', 'of', 'the', 'french', 'revolution', 'whilst', 'the', 'term', 'is', 'still', 'used', 'in', 'a', 'pejorative', 'way', 'to', 'describe', 'any', 'act', 'that', 'used', 'violent', 'means', 'to', 'destroy', 'the']

在经过切词后,需要对语料进行统计,为每个词构造ID。一般来说,可以根据每个词在语料中出现的频次构造ID,频次越高,ID越小,便于对词典进行管理。代码如下:

  1. #构造词典,统计每个词的频率,并根据频率将每个词转换为一个整数id
  2. def build_dict(corpus):
  3. #首先统计每个不同词的频率(出现的次数),使用一个词典记录
  4. word_freq_dict = dict()
  5. for word in corpus:
  6. if word not in word_freq_dict:
  7. word_freq_dict[word] = 0
  8. word_freq_dict[word] += 1
  9. #将这个词典中的词,按照出现次数排序,出现次数越高,排序越靠前
  10. #一般来说,出现频率高的高频词往往是:I,the,you这种代词,而出现频率低的词,往往是一些名词,如:nlp
  11. word_freq_dict = sorted(word_freq_dict.items(), key = lambda x:x[1], reverse = True)
  12. #构造3个不同的词典,分别存储,
  13. #每个词到id的映射关系:word2id_dict
  14. #每个id出现的频率:word2id_freq
  15. #每个id到词典映射关系:id2word_dict
  16. word2id_dict = dict()
  17. word2id_freq = dict()
  18. id2word_dict = dict()
  19. #按照频率,从高到低,开始遍历每个单词,并为这个单词构造一个独一无二的id
  20. for word, freq in word_freq_dict:
  21. curr_id = len(word2id_dict)
  22. word2id_dict[word] = curr_id
  23. word2id_freq[word2id_dict[word]] = freq
  24. id2word_dict[curr_id] = word
  25. return word2id_freq, word2id_dict, id2word_dict
  26. word2id_freq, word2id_dict, id2word_dict = build_dict(corpus)
  27. vocab_size = len(word2id_freq)
  28. print("there are totoally %d different words in the corpus" % vocab_size)
  29. for _, (word, word_id) in zip(range(50), word2id_dict.items()):
  30. print("word %s, its id %d, its word freq %d" % (word, word_id, word2id_freq[word_id]))
  1. there are totoally 253854 different words in the corpus
  2. word the, its id 0, its word freq 1061396
  3. word of, its id 1, its word freq 593677
  4. word and, its id 2, its word freq 416629
  5. word one, its id 3, its word freq 411764
  6. word in, its id 4, its word freq 372201
  7. word a, its id 5, its word freq 325873
  8. word to, its id 6, its word freq 316376
  9. word zero, its id 7, its word freq 264975
  10. word nine, its id 8, its word freq 250430
  11. word two, its id 9, its word freq 192644
  12. word is, its id 10, its word freq 183153
  13. word as, its id 11, its word freq 131815
  14. word eight, its id 12, its word freq 125285
  15. word for, its id 13, its word freq 118445
  16. word s, its id 14, its word freq 116710
  17. word five, its id 15, its word freq 115789
  18. word three, its id 16, its word freq 114775
  19. word was, its id 17, its word freq 112807
  20. word by, its id 18, its word freq 111831
  21. word that, its id 19, its word freq 109510
  22. word four, its id 20, its word freq 108182
  23. word six, its id 21, its word freq 102145
  24. word seven, its id 22, its word freq 99683
  25. word with, its id 23, its word freq 95603
  26. word on, its id 24, its word freq 91250
  27. word are, its id 25, its word freq 76527
  28. word it, its id 26, its word freq 73334
  29. word from, its id 27, its word freq 72871
  30. word or, its id 28, its word freq 68945
  31. word his, its id 29, its word freq 62603
  32. word an, its id 30, its word freq 61925
  33. word be, its id 31, its word freq 61281
  34. word this, its id 32, its word freq 58832
  35. word which, its id 33, its word freq 54788
  36. word at, its id 34, its word freq 54576
  37. word he, its id 35, its word freq 53573
  38. word also, its id 36, its word freq 44358
  39. word not, its id 37, its word freq 44033
  40. word have, its id 38, its word freq 39712
  41. word were, its id 39, its word freq 39086
  42. word has, its id 40, its word freq 37866
  43. word but, its id 41, its word freq 35358
  44. word other, its id 42, its word freq 32433
  45. word their, its id 43, its word freq 31523
  46. word its, its id 44, its word freq 29567
  47. word first, its id 45, its word freq 28810
  48. word they, its id 46, its word freq 28553
  49. word some, its id 47, its word freq 28161
  50. word had, its id 48, its word freq 28100
  51. word all, its id 49, its word freq 26229

得到word2id词典后,我们还需要进一步处理原始语料,把每个词替换成对应的ID,便于神经网络进行处理,代码如下:

  1. #把语料转换为id序列
  2. def convert_corpus_to_id(corpus, word2id_dict):
  3. #使用一个循环,将语料中的每个词替换成对应的id,以便于神经网络进行处理
  4. corpus = [word2id_dict[word] for word in corpus]
  5. return corpus
  6. corpus = convert_corpus_to_id(corpus, word2id_dict)
  7. print("%d tokens in the corpus" % len(corpus))
  8. print(corpus[:50])
  1. 17005207 tokens in the corpus
  2. [5233, 3080, 11, 5, 194, 1, 3133, 45, 58, 155, 127, 741, 476, 10571, 133, 0, 27349, 1, 0, 102, 854, 2, 0, 15067, 58112, 1, 0, 150, 854, 3580, 0, 194, 10, 190, 58, 4, 5, 10712, 214, 6, 1324, 104, 454, 19, 58, 2731, 362, 6, 3672, 0]

接下来,需要使用二次采样法处理原始文本。二次采样法的主要思想是降低高频词在语料中出现的频次,从而优化整个词表的词向量训练效果,代码如下:

  1. #使用二次采样算法(subsampling)处理语料,强化训练效果
  2. def subsampling(corpus, word2id_freq):
  3. #这个discard函数决定了一个词会不会被替换,这个函数是具有随机性的,每次调用结果不同
  4. #如果一个词的频率很大,那么它被遗弃的概率就很大
  5. def discard(word_id):
  6. return random.uniform(0, 1) < 1 - math.sqrt(
  7. 1e-4 / word2id_freq[word_id] * len(corpus))
  8. corpus = [word for word in corpus if not discard(word)]
  9. return corpus
  10. corpus = subsampling(corpus, word2id_freq)
  11. print("%d tokens in the corpus" % len(corpus))
  12. print(corpus[:50])
  1. 8743041 tokens in the corpus
  2. [5233, 3080, 194, 3133, 45, 58, 476, 10571, 27349, 102, 854, 0, 15067, 58112, 150, 854, 3580, 10712, 1324, 454, 2731, 3672, 708, 371, 539, 97, 1423, 2757, 567, 686, 7088, 247, 5233, 1052, 320, 44611, 2877, 792, 186, 5233, 602, 1134, 2621, 25, 8983, 4147, 59, 6437, 4186, 362]

在完成语料数据预处理之后,需要构造训练数据。根据上面的描述,我们需要使用一个滑动窗口对语料从左到右扫描,在每个窗口内,中心词需要预测它的上下文,并形成训练数据。

在实际操作中,由于词表往往很大(50000,100000等),对大词表的一些矩阵运算(如softmax)需要消耗巨大的资源,因此可以通过负采样的方式模拟softmax的结果,代码实现如下。

  • 给定一个中心词和一个需要预测的上下文词,把这个上下文词作为正样本。
  • 通过词表随机采样的方式,选择若干个负样本。
  • 把一个大规模分类问题转化为一个2分类问题,通过这种方式优化计算速度。
  1. #构造数据,准备模型训练
  2. #max_window_size代表了最大的window_size的大小,程序会根据max_window_size从左到右扫描整个语料
  3. #negative_sample_num代表了对于每个正样本,我们需要随机采样多少负样本用于训练,
  4. #一般来说,negative_sample_num的值越大,训练效果越稳定,但是训练速度越慢。
  5. def build_data(corpus, word2id_dict, word2id_freq, max_window_size = 3, negative_sample_num = 4):
  6. #使用一个list存储处理好的数据
  7. dataset = []
  8. #从左到右,开始枚举每个中心点的位置
  9. for center_word_idx in range(len(corpus)):
  10. #以max_window_size为上限,随机采样一个window_size,这样会使得训练更加稳定
  11. window_size = random.randint(1, max_window_size)
  12. #当前的中心词就是center_word_idx所指向的词
  13. center_word = corpus[center_word_idx]
  14. #以当前中心词为中心,左右两侧在window_size内的词都可以看成是正样本
  15. positive_word_range = (max(0, center_word_idx - window_size), min(len(corpus) - 1, center_word_idx + window_size))
  16. positive_word_candidates = [corpus[idx] for idx in range(positive_word_range[0], positive_word_range[1]+1) if idx != center_word_idx]
  17. #对于每个正样本来说,随机采样negative_sample_num个负样本,用于训练
  18. for positive_word in positive_word_candidates:
  19. #首先把(中心词,正样本,label=1)的三元组数据放入dataset中,
  20. #这里label=1表示这个样本是个正样本
  21. dataset.append((center_word, positive_word, 1))
  22. #开始负采样
  23. i = 0
  24. while i < negative_sample_num:
  25. negative_word_candidate = random.randint(0, vocab_size-1)
  26. if negative_word_candidate not in positive_word_candidates:
  27. #把(中心词,正样本,label=0)的三元组数据放入dataset中,
  28. #这里label=0表示这个样本是个负样本
  29. dataset.append((center_word, negative_word_candidate, 0))
  30. i += 1
  31. return dataset
  32. dataset = build_data(corpus, word2id_dict, word2id_freq)
  33. for _, (center_word, target_word, label) in zip(range(50), dataset):
  34. print("center_word %s, target %s, label %d" % (id2word_dict[center_word],
  35. id2word_dict[target_word], label))

训练数据准备好后,把训练数据都组装成mini-batch,并准备输入到网络中进行训练,代码如下:

  1. #构造mini-batch,准备对模型进行训练
  2. #我们将不同类型的数据放到不同的tensor里,便于神经网络进行处理
  3. #并通过numpy的array函数,构造出不同的tensor来,并把这些tensor送入神经网络中进行训练
  4. def build_batch(dataset, batch_size, epoch_num):
  5. #center_word_batch缓存batch_size个中心词
  6. center_word_batch = []
  7. #target_word_batch缓存batch_size个目标词(可以是正样本或者负样本)
  8. target_word_batch = []
  9. #label_batch缓存了batch_size个0或1的标签,用于模型训练
  10. label_batch = []
  11. for epoch in range(epoch_num):
  12. #每次开启一个新epoch之前,都对数据进行一次随机打乱,提高训练效果
  13. random.shuffle(dataset)
  14. for center_word, target_word, label in dataset:
  15. #遍历dataset中的每个样本,并将这些数据送到不同的tensor里
  16. center_word_batch.append([center_word])
  17. target_word_batch.append([target_word])
  18. label_batch.append(label)
  19. #当样本积攒到一个batch_size后,我们把数据都返回回来
  20. #在这里我们使用numpy的array函数把list封装成tensor
  21. #并使用python的迭代器机制,将数据yield出来
  22. #使用迭代器的好处是可以节省内存
  23. if len(center_word_batch) == batch_size:
  24. yield np.array(center_word_batch).astype("int64"), \
  25. np.array(target_word_batch).astype("int64"), \
  26. np.array(label_batch).astype("float32")
  27. center_word_batch = []
  28. target_word_batch = []
  29. label_batch = []
  30. if len(center_word_batch) > 0:
  31. yield np.array(center_word_batch).astype("int64"), \
  32. np.array(target_word_batch).astype("int64"), \
  33. np.array(label_batch).astype("float32")
  34. for _, batch in zip(range(10), build_batch(dataset, 128, 3)):
  35. print(batch)
  1. [ 1265],
  2. [ 3],
  3. [ 65],
  4. [ 1244],
  5. [ 9598],
  6. [ 953],
  7. [130727],
  8. [ 2855],
  9. [ 94577],
  10. [ 6823],
  11. [ 38536],
  12. [213517],
  13. [ 515],
  14. [ 3621],
  15. [ 44],
  16. [ 4],
  17. [ 5263],
  18. [244680],
  19. [ 1463],
  20. [ 1360],
  21. [ 13489],
  22. [ 1120],
  23. [ 369],
  24. [ 9670],
  25. [ 1143],
  26. [ 2282],
  27. [ 2800],
  28. [ 1311],
  29. [ 109],
  30. [ 1840],
  31. [ 300],
  32. [ 10293],
  33. [ 5],
  34. [ 43],
  35. [ 54],
  36. [ 1760],
  37. [ 82],
  38. [ 14908],
  39. [ 10663],
  40. [ 58],
  41. [ 2296],
  42. [ 1965],
  43. [ 5199],
  44. [ 7810],
  45. [ 1439],
  46. [ 4882],
  47. [ 2],
  48. [123768],
  49. [ 230],
  50. [ 428],
  51. [ 415],
  52. [ 7468],
  53. [ 32],
  54. [ 13],
  55. [ 16],
  56. [ 1927],
  57. [ 1554],
  58. [ 2478],
  59. [ 10998],
  60. [ 1421],
  61. [ 6733],
  62. [ 68],
  63. [ 5244],
  64. [ 1495],
  65. [ 4581],
  66. [ 13],
  67. [ 44],
  68. [ 196],
  69. [ 2104],
  70. [ 15],
  71. [ 7030],
  72. [ 1894],
  73. [ 184],
  74. [ 1855],
  75. [ 3448],
  76. [ 679],
  77. [ 11555],
  78. [ 1634],
  79. [ 1278]]), array([[115149],
  80. [227413],
  81. [114660],
  82. [135450],
  83. [ 51389],
  84. [ 57271],
  85. [ 536],
  86. [228469],
  87. [241920],
  88. [175377],
  89. [ 385],
  90. [109229],
  91. [ 34997],
  92. [ 10101],
  93. [165518],
  94. [ 6388],
  95. [ 29172],
  96. [ 4193],
  97. [212822],
  98. [ 90],
  99. [ 19506],
  100. [ 11286],
  101. [208224],
  102. [ 3322],
  103. [159231],
  104. [ 93980],
  105. [ 13407],
  106. [149501],
  107. [201439],
  108. [ 86815],
  109. [ 22627],
  110. [ 1004],
  111. [ 48841],
  112. [220094],
  113. [227446],
  114. [ 83887],
  115. [199594],
  116. [119689],
  117. [ 55337],
  118. [173126],
  119. [172462],
  120. [104876],
  121. [ 32],
  122. [ 54893],
  123. [121031],
  124. [ 10926],
  125. [ 78656],
  126. [ 20295],
  127. [209221],
  128. [204856],
  129. [115120],
  130. [153562],
  131. [189289],
  132. [131977],
  133. [ 3250],
  134. [ 330],
  135. [ 37386],
  136. [171896],
  137. [133547],
  138. [141805],
  139. [ 88031],
  140. [193686],
  141. [ 8467],
  142. [147281],
  143. [124915],
  144. [162296],
  145. [131701],
  146. [ 26096],
  147. [ 31606],
  148. [ 84569],
  149. [ 88545],
  150. [210167],
  151. [130172],
  152. [ 62037],
  153. [ 2282],
  154. [ 65160],
  155. [ 82106],
  156. [ 5836],
  157. [ 67230],
  158. [180747],
  159. [ 20330],
  160. [ 241],
  161. [178782],
  162. [ 9447],
  163. [ 50076],
  164. [ 69035],
  165. [165865],
  166. [ 2043],
  167. [180836],
  168. [ 17345],
  169. [ 4653],
  170. [115248],
  171. [228045],
  172. [101045],
  173. [252862],
  174. [ 68642],
  175. [ 22668],
  176. [230309],
  177. [ 26582],
  178. [137038],
  179. [ 78219],
  180. [ 18370],
  181. [147975],
  182. [113080],
  183. [162421],
  184. [251731],
  185. [ 14666],
  186. [137719],
  187. [ 57],
  188. [121412],
  189. [ 13281],
  190. [ 9252],
  191. [ 59957],
  192. [188529],
  193. [ 66243],
  194. [ 90980],
  195. [ 86217],
  196. [162392],
  197. [ 586],
  198. [251889],
  199. [ 6018],
  200. [123037],
  201. [ 8],
  202. [ 36632],
  203. [206473],
  204. [166437],
  205. [ 16612],
  206. [ 21334]]), array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0.,
  207. 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0.,
  208. 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
  209. 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
  210. 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
  211. 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
  212. 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
  213. 0., 0., 0., 1., 0., 0., 0., 1., 0.], dtype=float32))
  214. (array([[ 97],
  215. [ 9313],
  216. [ 6173],
  217. [230581],
  218. [ 1855],
  219. [ 10223],
  220. [ 16021],
  221. [ 2347],
  222. [ 878],
  223. [ 18486],
  224. [ 7],
  225. [ 49],
  226. [ 4077],
  227. [ 8264],
  228. [ 153],
  229. [ 1468],
  230. [ 202],
  231. [ 3166],
  232. [ 278],
  233. [ 34330],
  234. [ 21349],
  235. [ 101],
  236. [ 755],
  237. [ 2],
  238. [ 8289],
  239. [ 259],
  240. [ 3535],
  241. [ 532],
  242. [ 19977],
  243. [ 12752],
  244. [ 93],
  245. [ 8635],
  246. [ 310],
  247. [ 24024],
  248. [ 186],
  249. [ 7051],
  250. [ 1385],
  251. [ 25567],
  252. [ 97880],
  253. [ 1011],
  254. [ 1250],
  255. [ 1268],
  256. [ 21282],
  257. [ 265],
  258. [ 2947],
  259. [ 6034],
  260. [ 8134],
  261. [ 15349],
  262. [ 1283],
  263. [ 14395],
  264. [ 34],
  265. [ 10268],
  266. [ 4856],
  267. [ 150],
  268. [ 8205],
  269. [ 914],
  270. [ 66],
  271. [ 1],
  272. [ 138],
  273. [ 15628],
  274. [ 1289],
  275. [ 4377],
  276. [ 15254],
  277. [ 47492],
  278. [ 5606],
  279. [ 90163],
  280. [ 1516],
  281. [ 283],
  282. [ 1609],
  283. [ 676],
  284. [ 2213],
  285. [ 2052],
  286. [ 3702],
  287. [ 1893],
  288. [ 3399],
  289. [ 4187],
  290. [ 1820],
  291. [ 434],
  292. [ 7597],
  293. [ 461],
  294. [ 3025],
  295. [ 9721],
  296. [ 4721],
  297. [ 1243],
  298. [ 959],
  299. [ 495],
  300. [ 38],
  301. [ 386],
  302. [ 7685],
  303. [ 49754],
  304. [ 811],
  305. [ 7893],
  306. [ 4782],
  307. [ 500],
  308. [ 93],
  309. [ 72],
  310. [ 52],
  311. [ 1183],
  312. [ 1977],
  313. [ 1769],
  314. [ 1336],
  315. [ 1762],
  316. [ 2107],
  317. [ 7293],
  318. [ 707],
  319. [ 87],
  320. [ 1218],
  321. [ 58268],
  322. [ 1406],
  323. [ 7855],
  324. [ 4419],
  325. [ 3974],
  326. [ 19668],
  327. [ 9496],
  328. [ 14167],
  329. [ 126],
  330. [ 42939],
  331. [ 1197],
  332. [ 8757],
  333. [ 1844],
  334. [ 6099],
  335. [ 613],
  336. [ 89],
  337. [ 220],
  338. [ 719],
  339. [ 33031],
  340. [ 512],
  341. [ 2275]]), array([[139722],
  342. [117611],
  343. [ 71213],
  344. [137778],
  345. [181620],
  346. [246897],
  347. [245809],
  348. [ 69878],
  349. [203753],
  350. [140350],
  351. [151153],
  352. [ 61792],
  353. [175281],
  354. [182660],
  355. [103808],
  356. [ 53618],
  357. [ 4565],
  358. [172698],
  359. [ 10613],
  360. [141260],
  361. [147810],
  362. [ 9122],
  363. [151272],
  364. [109948],
  365. [186090],
  366. [142546],
  367. [ 68926],
  368. [ 33715],
  369. [183932],
  370. [ 24213],
  371. [ 223],
  372. [177061],
  373. [157981],
  374. [108024],
  375. [ 14482],
  376. [127396],
  377. [132514],
  378. [ 6154],
  379. [ 6268],
  380. [ 17657],
  381. [233814],
  382. [ 46499],
  383. [165408],
  384. [ 21361],
  385. [ 60],
  386. [218114],
  387. [219449],
  388. [134553],
  389. [195442],
  390. [ 82144],
  391. [177449],
  392. [134277],
  393. [ 21885],
  394. [252401],
  395. [ 53781],
  396. [201727],
  397. [ 1213],
  398. [119397],
  399. [ 84729],
  400. [ 21069],
  401. [ 46868],
  402. [ 12511],
  403. [ 75372],
  404. [ 66137],
  405. [249827],
  406. [154956],
  407. [251185],
  408. [174851],
  409. [ 10662],
  410. [ 10855],
  411. [ 46892],
  412. [ 3855],
  413. [ 685],
  414. [195300],
  415. [ 94768],
  416. [ 41961],
  417. [220710],
  418. [ 92609],
  419. [ 20530],
  420. [193166],
  421. [203658],
  422. [ 638],
  423. [ 42027],
  424. [201522],
  425. [ 571],
  426. [ 37403],
  427. [ 22850],
  428. [232513],
  429. [157891],
  430. [ 96068],
  431. [ 62911],
  432. [ 97486],
  433. [ 61068],
  434. [116719],
  435. [ 16304],
  436. [222864],
  437. [ 502],
  438. [ 401],
  439. [173313],
  440. [ 34471],
  441. [ 1792],
  442. [146981],
  443. [215698],
  444. [191866],
  445. [ 11224],
  446. [203336],
  447. [146134],
  448. [218555],
  449. [125028],
  450. [134726],
  451. [152016],
  452. [112412],
  453. [252208],
  454. [231831],
  455. [ 6483],
  456. [ 92447],
  457. [ 24749],
  458. [ 79838],
  459. [140294],
  460. [187425],
  461. [171602],
  462. [ 2612],
  463. [207986],
  464. [223307],
  465. [ 11],
  466. [129691],
  467. [ 1141],
  468. [ 67218]]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
  469. 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
  470. 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
  471. 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
  472. 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
  473. 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0.,
  474. 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.,
  475. 0., 0., 1., 0., 0., 1., 0., 1., 0.], dtype=float32))
  476. (array([[ 1947],
  477. [ 682],
  478. [ 8161],
  479. [ 411],
  480. [ 18],
  481. [ 63],
  482. [ 70],
  483. [ 3306],
  484. [ 682],
  485. [ 67748],
  486. [ 55906],
  487. [ 32],
  488. [ 214],
  489. [ 145],
  490. [ 3671],
  491. [ 9095],
  492. [ 2],
  493. [ 493],
  494. [ 420],
  495. [ 1354],
  496. [ 8961],
  497. [ 14303],
  498. [ 415],
  499. [ 15926],
  500. [ 17],
  501. [ 349],
  502. [ 2630],
  503. [ 22314],
  504. [ 16501],
  505. [ 62953],
  506. [ 42],
  507. [ 22323],
  508. [ 1971],
  509. [ 11897],
  510. [ 676],
  511. [ 331],
  512. [ 5370],
  513. [ 8495],
  514. [ 97945],
  515. [202893],
  516. [ 1269],
  517. [ 1408],
  518. [ 7126],
  519. [ 667],
  520. [ 848],
  521. [ 2270],
  522. [ 9923],
  523. [ 1570],
  524. [ 766],
  525. [ 81448],
  526. [ 441],
  527. [ 2573],
  528. [ 3848],
  529. [ 17],
  530. [ 17356],
  531. [ 1104],
  532. [ 680],
  533. [ 13898],
  534. [ 291],
  535. [ 6910],
  536. [ 192],
  537. [ 465],
  538. [ 14],
  539. [ 47230],
  540. [ 70948],
  541. [ 25],
  542. [ 601],
  543. [ 1597],
  544. [ 17998],
  545. [ 41060],
  546. [ 64031],
  547. [ 9050],
  548. [ 20561],
  549. [ 1643],
  550. [ 1249],
  551. [ 6884],
  552. [ 5027],
  553. [ 3243],
  554. [ 435],
  555. [ 7433],
  556. [ 154],
  557. [125449],
  558. [ 7280],
  559. [ 250],
  560. [ 5331],
  561. [ 1531],
  562. [ 6366],
  563. [ 7196],
  564. [ 7541],
  565. [ 2688],
  566. [ 8177],
  567. [ 218],
  568. [ 23766],
  569. [ 839],
  570. [ 17025],
  571. [ 481],
  572. [ 2094],
  573. [ 9013],
  574. [ 11217],
  575. [ 45],
  576. [ 1902],
  577. [ 1118],
  578. [ 5940],
  579. [ 1203],
  580. [ 5910],
  581. [ 61184],
  582. [ 2033],
  583. [ 1596],
  584. [ 1836],
  585. [ 13],
  586. [ 3412],
  587. [ 24],
  588. [ 656],
  589. [ 8913],
  590. [ 2386],
  591. [ 2073],
  592. [ 26514],
  593. [ 2],
  594. [ 48],
  595. [ 76355],
  596. [ 8790],
  597. [ 9909],
  598. [ 132],
  599. [ 18374],
  600. [143106],
  601. [ 408],
  602. [ 220],
  603. [ 3538]]), array([[139296],
  604. [ 4880],
  605. [ 1583],
  606. [152999],
  607. [ 7458],
  608. [186204],
  609. [227328],
  610. [ 12555],
  611. [220131],
  612. [ 1572],
  613. [140465],
  614. [ 1711],
  615. [ 68521],
  616. [ 84679],
  617. [ 1278],
  618. [152164],
  619. [ 30258],
  620. [175972],
  621. [246667],
  622. [148560],
  623. [ 4338],
  624. [ 82572],
  625. [182056],
  626. [ 6849],
  627. [ 21765],
  628. [ 721],
  629. [146832],
  630. [ 16958],
  631. [228760],
  632. [144261],
  633. [ 3670],
  634. [247237],
  635. [ 26],
  636. [105189],
  637. [ 96],
  638. [156600],
  639. [159880],
  640. [163536],
  641. [189112],
  642. [217793],
  643. [ 25662],
  644. [ 2612],
  645. [ 921],
  646. [ 95435],
  647. [187418],
  648. [102884],
  649. [ 1519],
  650. [ 93926],
  651. [ 81387],
  652. [ 41618],
  653. [ 13819],
  654. [ 1329],
  655. [ 10597],
  656. [105710],
  657. [ 25],
  658. [ 75211],
  659. [232527],
  660. [ 50488],
  661. [205148],
  662. [169814],
  663. [205857],
  664. [ 4131],
  665. [162602],
  666. [ 5645],
  667. [232430],
  668. [ 37982],
  669. [192939],
  670. [ 452],
  671. [116465],
  672. [ 93774],
  673. [ 26099],
  674. [ 11962],
  675. [ 14143],
  676. [177118],
  677. [169795],
  678. [ 35447],
  679. [133547],
  680. [195258],
  681. [ 2469],
  682. [ 9862],
  683. [ 26991],
  684. [215255],
  685. [ 1998],
  686. [220857],
  687. [ 64919],
  688. [216418],
  689. [ 54092],
  690. [ 66917],
  691. [ 66814],
  692. [207736],
  693. [253680],
  694. [ 48401],
  695. [ 7593],
  696. [228541],
  697. [107926],
  698. [ 60510],
  699. [182870],
  700. [116335],
  701. [ 60402],
  702. [198790],
  703. [ 71236],
  704. [ 6240],
  705. [ 14794],
  706. [225419],
  707. [ 3723],
  708. [ 72610],
  709. [149417],
  710. [211664],
  711. [225128],
  712. [221956],
  713. [246062],
  714. [134126],
  715. [ 16404],
  716. [ 2],
  717. [ 97037],
  718. [ 94503],
  719. [139927],
  720. [ 24884],
  721. [125569],
  722. [170146],
  723. [158652],
  724. [ 386],
  725. [150282],
  726. [174588],
  727. [ 7842],
  728. [ 95128],
  729. [ 21300],
  730. [135342]]), array([0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 1.,
  731. 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
  732. 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 0.,
  733. 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1.,
  734. 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0.,
  735. 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  736. 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
  737. 0., 0., 1., 0., 0., 1., 0., 0., 0.], dtype=float32))
  738. (array([[ 15],
  739. [ 2232],
  740. [ 6767],
  741. [ 3783],
  742. [ 6292],
  743. [ 837],
  744. [ 31127],
  745. [ 379],
  746. [ 13571],
  747. [ 788],
  748. [ 1185],
  749. [ 30],
  750. [ 2833],
  751. [ 743],
  752. [ 1025],
  753. [ 189],
  754. [ 11324],
  755. [ 10769],
  756. [ 52],
  757. [ 2487],
  758. [ 55118],
  759. [ 3382],
  760. [ 15371],
  761. [ 2363],
  762. [ 113],
  763. [ 212],
  764. [181366],
  765. [ 12],
  766. [ 4],
  767. [ 965],
  768. [ 990],
  769. [ 2454],
  770. [ 608],
  771. [ 2994],
  772. [ 1878],
  773. [ 4605],
  774. [ 117],
  775. [ 2876],
  776. [ 1833],
  777. [ 6922],
  778. [ 1186],
  779. [ 351],
  780. [ 1660],
  781. [ 71],
  782. [ 13],
  783. [ 1745],
  784. [ 5385],
  785. [ 2011],
  786. [ 749],
  787. [ 975],
  788. [ 565],
  789. [ 3407],
  790. [ 1088],
  791. [ 765],
  792. [ 10828],
  793. [ 4355],
  794. [ 752],
  795. [ 1221],
  796. [ 685],
  797. [ 55],
  798. [ 6973],
  799. [ 431],
  800. [ 1272],
  801. [ 4],
  802. [ 4550],
  803. [125872],
  804. [ 4294],
  805. [ 9],
  806. [ 167],
  807. [ 17],
  808. [ 484],
  809. [ 150],
  810. [113235],
  811. [ 14437],
  812. [ 18026],
  813. [ 1475],
  814. [ 249],
  815. [ 246],
  816. [ 64],
  817. [ 96],
  818. [ 898],
  819. [ 6666],
  820. [ 18386],
  821. [ 192],
  822. [ 13195],
  823. [ 1386],
  824. [ 1460],
  825. [ 6199],
  826. [ 1908],
  827. [ 12516],
  828. [ 9133],
  829. [ 6],
  830. [ 3045],
  831. [ 546],
  832. [ 13105],
  833. [ 233],
  834. [ 145],
  835. [ 7],
  836. [ 599],
  837. [ 101],
  838. [ 210],
  839. [ 1482],
  840. [ 609],
  841. [ 6791],
  842. [ 505],
  843. [ 9947],
  844. [ 32617],
  845. [ 244],
  846. [ 26780],
  847. [ 14995],
  848. [ 17006],
  849. [ 480],
  850. [ 33061],
  851. [ 9874],
  852. [ 41],
  853. [ 8843],
  854. [ 24314],
  855. [ 2536],
  856. [ 4115],
  857. [ 3276],
  858. [ 23391],
  859. [ 1100],
  860. [ 57609],
  861. [ 62795],
  862. [ 3518],
  863. [ 12520],
  864. [ 485],
  865. [ 4]]), array([[ 72035],
  866. [244894],
  867. [ 50962],
  868. [175834],
  869. [ 74710],
  870. [200406],
  871. [139250],
  872. [182131],
  873. [129015],
  874. [ 96053],
  875. [ 52805],
  876. [202195],
  877. [ 833],
  878. [ 89438],
  879. [ 28316],
  880. [ 3428],
  881. [204154],
  882. [196151],
  883. [ 91264],
  884. [179355],
  885. [248265],
  886. [176000],
  887. [ 4501],
  888. [155139],
  889. [160741],
  890. [206366],
  891. [215399],
  892. [ 76430],
  893. [157923],
  894. [198098],
  895. [ 1899],
  896. [ 589],
  897. [ 752],
  898. [126527],
  899. [ 521],
  900. [222232],
  901. [ 5082],
  902. [213949],
  903. [148254],
  904. [ 21493],
  905. [ 2306],
  906. [140278],
  907. [150602],
  908. [203313],
  909. [ 2755],
  910. [248285],
  911. [236541],
  912. [ 49277],
  913. [226719],
  914. [208999],
  915. [145374],
  916. [196864],
  917. [167906],
  918. [ 22199],
  919. [ 57607],
  920. [ 56],
  921. [179402],
  922. [162834],
  923. [ 76853],
  924. [ 9909],
  925. [163437],
  926. [113844],
  927. [ 15286],
  928. [222847],
  929. [ 7638],
  930. [ 22422],
  931. [158440],
  932. [ 489],
  933. [189023],
  934. [ 65049],
  935. [ 41135],
  936. [215161],
  937. [ 78584],
  938. [ 20391],
  939. [192482],
  940. [ 34],
  941. [ 45896],
  942. [248027],
  943. [225976],
  944. [ 1146],
  945. [ 43379],
  946. [211909],
  947. [249712],
  948. [ 36],
  949. [123084],
  950. [132495],
  951. [ 6385],
  952. [ 80],
  953. [140223],
  954. [207572],
  955. [ 3398],
  956. [115641],
  957. [ 47690],
  958. [228621],
  959. [128907],
  960. [142325],
  961. [157140],
  962. [169023],
  963. [ 10191],
  964. [145893],
  965. [178573],
  966. [178342],
  967. [ 1388],
  968. [ 7439],
  969. [ 791],
  970. [139124],
  971. [144230],
  972. [ 2216],
  973. [106770],
  974. [161260],
  975. [166117],
  976. [179646],
  977. [122992],
  978. [ 3362],
  979. [ 68495],
  980. [ 83020],
  981. [ 73023],
  982. [ 804],
  983. [249352],
  984. [203801],
  985. [244814],
  986. [ 15],
  987. [131522],
  988. [ 17738],
  989. [ 857],
  990. [107369],
  991. [107460],
  992. [113488]]), array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
  993. 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.,
  994. 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  995. 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1.,
  996. 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1., 0.,
  997. 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
  998. 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
  999. 0., 0., 1., 0., 0., 1., 0., 0., 0.], dtype=float32))

网络定义

定义skip-gram的网络结构,用于模型训练。在飞桨动态图中,对于任意网络,都需要定义一个继承自fluid.dygraph.Layer的类来搭建网络结构、参数等数据的声明。同时需要在forward函数中定义网络的计算逻辑。值得注意的是,我们仅需要定义网络的前向计算逻辑,飞桨会自动完成神经网络的反向计算,代码如下:

  1. #定义skip-gram训练网络结构
  2. #这里我们使用的是paddlepaddle的1.7.0版本
  3. #一般来说,在使用fluid训练的时候,我们需要通过一个类来定义网络结构,这个类继承了fluid.dygraph.Layer
  4. class SkipGram(fluid.dygraph.Layer):
  5. def __init__(self, vocab_size, embedding_size, init_scale=0.1):
  6. #vocab_size定义了这个skipgram这个模型的词表大小
  7. #embedding_size定义了词向量的维度是多少
  8. #init_scale定义了词向量初始化的范围,一般来说,比较小的初始化范围有助于模型训练
  9. super(SkipGram, self).__init__()
  10. self.vocab_size = vocab_size
  11. self.embedding_size = embedding_size
  12. #使用paddle.fluid.dygraph提供的Embedding函数,构造一个词向量参数
  13. #这个参数的大小为:[self.vocab_size, self.embedding_size]
  14. #数据类型为:float32
  15. #这个参数的名称为:embedding_para
  16. #这个参数的初始化方式为在[-init_scale, init_scale]区间进行均匀采样
  17. self.embedding = Embedding(
  18. size=[self.vocab_size, self.embedding_size],
  19. dtype='float32',
  20. param_attr=fluid.ParamAttr(
  21. name='embedding_para',
  22. initializer=fluid.initializer.UniformInitializer(
  23. low=-0.5/embedding_size, high=0.5/embedding_size)))
  24. #使用paddle.fluid.dygraph提供的Embedding函数,构造另外一个词向量参数
  25. #这个参数的大小为:[self.vocab_size, self.embedding_size]
  26. #数据类型为:float32
  27. #这个参数的名称为:embedding_para_out
  28. #这个参数的初始化方式为在[-init_scale, init_scale]区间进行均匀采样
  29. #跟上面不同的是,这个参数的名称跟上面不同,因此,
  30. #embedding_para_out和embedding_para虽然有相同的shape,但是权重不共享
  31. self.embedding_out = Embedding(
  32. size=[self.vocab_size, self.embedding_size],
  33. dtype='float32',
  34. param_attr=fluid.ParamAttr(
  35. name='embedding_out_para',
  36. initializer=fluid.initializer.UniformInitializer(
  37. low=-0.5/embedding_size, high=0.5/embedding_size)))
  38. #定义网络的前向计算逻辑
  39. #center_words是一个tensor(mini-batch),表示中心词
  40. #target_words是一个tensor(mini-batch),表示目标词
  41. #label是一个tensor(mini-batch),表示这个词是正样本还是负样本(用0或1表示)
  42. #用于在训练中计算这个tensor中对应词的同义词,用于观察模型的训练效果
  43. def forward(self, center_words, target_words, label):
  44. #首先,通过embedding_para(self.embedding)参数,将mini-batch中的词转换为词向量
  45. #这里center_words和eval_words_emb查询的是一个相同的参数
  46. #而target_words_emb查询的是另一个参数
  47. center_words_emb = self.embedding(center_words)
  48. target_words_emb = self.embedding_out(target_words)
  49. #center_words_emb = [batch_size, embedding_size]
  50. #target_words_emb = [batch_size, embedding_size]
  51. #我们通过点乘的方式计算中心词到目标词的输出概率,并通过sigmoid函数估计这个词是正样本还是负样本的概率。
  52. word_sim = fluid.layers.elementwise_mul(center_words_emb, target_words_emb)
  53. word_sim = fluid.layers.reduce_sum(word_sim, dim = -1)
  54. word_sim = fluid.layers.reshape(word_sim, shape=[-1])
  55. pred = fluid.layers.sigmoid(word_sim)
  56. #通过估计的输出概率定义损失函数,注意我们使用的是sigmoid_cross_entropy_with_logits函数
  57. #将sigmoid计算和cross entropy合并成一步计算可以更好的优化,所以输入的是word_sim,而不是pred
  58. loss = fluid.layers.sigmoid_cross_entropy_with_logits(word_sim, label)
  59. loss = fluid.layers.reduce_mean(loss)
  60. #返回前向计算的结果,飞桨会通过backward函数自动计算出反向结果。
  61. return pred, loss

网络训练

完成网络定义后,就可以启动模型训练。我们定义每隔100步打印一次loss,以确保当前的网络是正常收敛的。同时,我们每隔1000步观察一下skip-gram计算出来的同义词(使用 embedding的乘积),可视化网络训练效果,代码如下:

  1. #开始训练,定义一些训练过程中需要使用的超参数
  2. batch_size = 512
  3. epoch_num = 3
  4. embedding_size = 200
  5. step = 0
  6. learning_rate = 0.001
  7. #定义一个使用word-embedding查询同义词的函数
  8. #这个函数query_token是要查询的词,k表示要返回多少个最相似的词,embed是我们学习到的word-embedding参数
  9. #我们通过计算不同词之间的cosine距离,来衡量词和词的相似度
  10. #具体实现如下,x代表要查询词的Embedding,Embedding参数矩阵W代表所有词的Embedding
  11. #两者计算Cos得出所有词对查询词的相似度得分向量,排序取top_k放入indices列表
  12. def get_similar_tokens(query_token, k, embed):
  13. W = embed.numpy()
  14. x = W[word2id_dict[query_token]]
  15. cos = np.dot(W, x) / np.sqrt(np.sum(W * W, axis=1) * np.sum(x * x) + 1e-9)
  16. flat = cos.flatten()
  17. indices = np.argpartition(flat, -k)[-k:]
  18. indices = indices[np.argsort(-flat[indices])]
  19. for i in indices:
  20. print('for word %s, the similar word is %s' % (query_token, str(id2word_dict[i])))
  21. #将模型放到GPU上训练(fluid.CUDAPlace(0)),如果需要指定CPU,则需要改为fluid.CPUPlace()
  22. with fluid.dygraph.guard(fluid.CUDAPlace(0)):
  23. #通过我们定义的SkipGram类,来构造一个Skip-gram模型网络
  24. skip_gram_model = SkipGram(vocab_size, embedding_size)
  25. #构造训练这个网络的优化器
  26. adam = fluid.optimizer.AdamOptimizer(learning_rate=learning_rate, parameter_list = skip_gram_model.parameters())
  27. #使用build_batch函数,以mini-batch为单位,遍历训练数据,并训练网络
  28. for center_words, target_words, label in build_batch(
  29. dataset, batch_size, epoch_num):
  30. #使用fluid.dygraph.to_variable函数,将一个numpy的tensor,转换为飞桨可计算的tensor
  31. center_words_var = fluid.dygraph.to_variable(center_words)
  32. target_words_var = fluid.dygraph.to_variable(target_words)
  33. label_var = fluid.dygraph.to_variable(label)
  34. #将转换后的tensor送入飞桨中,进行一次前向计算,并得到计算结果
  35. pred, loss = skip_gram_model(
  36. center_words_var, target_words_var, label_var)
  37. #通过backward函数,让程序自动完成反向计算
  38. loss.backward()
  39. #通过minimize函数,让程序根据loss,完成一步对参数的优化更新
  40. adam.minimize(loss)
  41. #使用clear_gradients函数清空模型中的梯度,以便于下一个mini-batch进行更新
  42. skip_gram_model.clear_gradients()
  43. #每经过100个mini-batch,打印一次当前的loss,看看loss是否在稳定下降
  44. step += 1
  45. if step % 100 == 0:
  46. print("step %d, loss %.3f" % (step, loss.numpy()[0]))
  47. #经过10000个mini-batch,打印一次模型对eval_words中的10个词计算的同义词
  48. #这里我们使用词和词之间的向量点积作为衡量相似度的方法
  49. #我们只打印了5个最相似的词
  50. if step % 10000 == 0:
  51. get_similar_tokens('one', 5, skip_gram_model.embedding.weight)
  52. get_similar_tokens('she', 5, skip_gram_model.embedding.weight)
  53. get_similar_tokens('chip', 5, skip_gram_model.embedding.weight)

从打印结果可以看到,经过一定步骤的训练,loss逐渐下降并趋于稳定。同时也可以发现skip-gram模型可以学习到一些有趣的语言现象,比如:跟who比较接近的词是"who, he, she, him, himself"。

Skip-gram的有趣使用

在使用word2vec模型的过程中,研究人员发现了一些有趣的现象。比如当得到整个词表的word embedding之后,对任意词都可以基于向量乘法计算跟这个词最接近的词。我们会发现,word2vec模型可以自动学习出一些同义词关系,如:

  1. Top 5 words closest to "beijing" are:
  2. 1. newyork
  3. 2. paris
  4. 3. tokyo
  5. 4. berlin
  6. 5. soul
  7. ...
  8. Top 5 words closest to "apple" are:
  9. 1. banana
  10. 2. pineapple
  11. 3. huawei
  12. 4. peach
  13. 5. orange

除此以外,研究人员还发现可以使用加减法完成一些基于语言的逻辑推理,如:

  1. Top 1 words closest to "king - man + women" are
  2. 1. queen
  3. ...
  4. Top 1 words closest to "captial - china + american" are
  5. 1. newyork

还有更多有趣的例子,赶快使用飞桨尝试实现一下吧。