模型设计的代码需要用到上一节数据处理的Python类,定义如下:

    # Standard library
    import random

    # Third-party
    import numpy as np
    from PIL import Image
    class MovieLen(object):
        """Loader and splitter for the MovieLens-1M dataset.

        Parses ``users.dat`` / ``movies.dat`` / ``ratings.dat`` under
        ``./work/ml-1m/``, builds one sample per (user, movie, score) triple,
        splits the samples 90/10 into train/valid sets, and exposes a batched
        generator via :meth:`load_data`.
        """

        def __init__(self, use_poster):
            # Whether movie poster images are used as an extra feature.
            self.use_poster = use_poster
            # Paths of the raw data files.
            usr_info_path = "./work/ml-1m/users.dat"
            if use_poster:
                # Ratings filtered down to movies that have a poster image.
                rating_path = "./work/ml-1m/new_rating.txt"
            else:
                rating_path = "./work/ml-1m/ratings.dat"
            movie_info_path = "./work/ml-1m/movies.dat"
            self.poster_path = "./work/ml-1m/posters/"
            # Movie data: per-movie features plus word->id vocabularies.
            self.movie_info, self.movie_cat, self.movie_title = self.get_movie_info(movie_info_path)
            # Largest ids / word indices (typically used to size embedding tables).
            self.max_mov_cat = np.max([self.movie_cat[k] for k in self.movie_cat])
            self.max_mov_tit = np.max([self.movie_title[k] for k in self.movie_title])
            self.max_mov_id = np.max(list(map(int, self.movie_info.keys())))
            # User maxima are filled in while parsing users.dat.
            self.max_usr_id = 0
            self.max_usr_age = 0
            self.max_usr_job = 0
            # User data.
            self.usr_info = self.get_usr_info(usr_info_path)
            # Rating data.
            self.rating_info = self.get_rating_info(rating_path)
            # Assemble the full sample list.
            self.dataset = self.get_dataset(usr_info=self.usr_info,
                                            rating_info=self.rating_info,
                                            movie_info=self.movie_info)
            # 90% train / 10% valid split, in file order.
            self.train_dataset = self.dataset[:int(len(self.dataset)*0.9)]
            self.valid_dataset = self.dataset[int(len(self.dataset)*0.9):]
            print("##Total dataset instances: ", len(self.dataset))
            print("##MovieLens dataset information: \nusr num: {}\n"
                  "movies num: {}".format(len(self.usr_info), len(self.movie_info)))

        def get_movie_info(self, path):
            """Parse movies.dat.

            Returns:
                movie_info: dict keyed by movie id (str) with keys
                    ``mov_id`` (int), ``title`` (15 word ids, 0-padded),
                    ``category`` (6 category ids, 0-padded), ``years`` (int).
                movie_cat: category word -> id (ids start at 1; 0 is padding).
                movie_titles: title word -> id (ids start at 1; 0 is padding).
            """
            # The MovieLens files are Latin-1 encoded.
            with open(path, 'r', encoding="ISO-8859-1") as f:
                data = f.readlines()
            movie_info, movie_titles, movie_cat = {}, {}, {}
            # Next free id for title words / category words (0 is the pad value).
            t_count, c_count = 1, 1
            for item in data:
                # Line format: MovieID::Title (Year)::Genre1|Genre2|...
                item = item.strip().split("::")
                v_id = item[0]
                v_title = item[1][:-7]    # drop the trailing " (YYYY)"
                cats = item[2].split('|')
                v_year = item[1][-5:-1]   # the 4-digit year inside the parens
                titles = v_title.split()
                # Assign an id to every previously unseen title word.
                for t in titles:
                    if t not in movie_titles:
                        movie_titles[t] = t_count
                        t_count += 1
                # Assign an id to every previously unseen category word.
                for cat in cats:
                    if cat not in movie_cat:
                        movie_cat[cat] = c_count
                        c_count += 1
                # Pad the title id list to a fixed length of 15.
                v_tit = [movie_titles[k] for k in titles]
                while len(v_tit) < 15:
                    v_tit.append(0)
                # Pad the category id list to a fixed length of 6.
                v_cat = [movie_cat[k] for k in cats]
                while len(v_cat) < 6:
                    v_cat.append(0)
                movie_info[v_id] = {'mov_id': int(v_id),
                                    'title': v_tit,
                                    'category': v_cat,
                                    'years': int(v_year)}
            return movie_info, movie_cat, movie_titles

        def get_usr_info(self, path):
            """Parse users.dat into a dict keyed by user id (str).

            Side effect: updates ``self.max_usr_id`` / ``max_usr_age`` /
            ``max_usr_job`` while scanning.
            """
            def gender2num(gender):
                # Gender encoding: M -> 0, F -> 1.
                return 1 if gender == 'F' else 0

            with open(path, 'r') as f:
                data = f.readlines()
            use_info = {}
            for item in data:
                # Line format: UserID::Gender::Age::Occupation::Zip-code
                item = item.strip().split("::")
                usr_id = item[0]
                use_info[usr_id] = {'usr_id': int(usr_id),
                                    'gender': gender2num(item[1]),
                                    'age': int(item[2]),
                                    'job': int(item[3])}
                self.max_usr_id = max(self.max_usr_id, int(usr_id))
                self.max_usr_age = max(self.max_usr_age, int(item[2]))
                self.max_usr_job = max(self.max_usr_job, int(item[3]))
            return use_info

        def get_rating_info(self, path):
            """Parse ratings into ``{usr_id: {movie_id: score}}`` (str ids)."""
            with open(path, 'r') as f:
                data = f.readlines()
            rating_info = {}
            for item in data:
                # Line format: UserID::MovieID::Rating::Timestamp
                item = item.strip().split("::")
                usr_id, movie_id, score = item[0], item[1], item[2]
                rating_info.setdefault(usr_id, {})[movie_id] = float(score)
            return rating_info

        def get_dataset(self, usr_info, rating_info, movie_info):
            """Join users, movies and ratings into a flat list of samples."""
            trainset = []
            for usr_id in rating_info.keys():
                usr_ratings = rating_info[usr_id]
                for movie_id in usr_ratings:
                    trainset.append({'usr_info': usr_info[usr_id],
                                     'mov_info': movie_info[movie_id],
                                     'scores': usr_ratings[movie_id]})
            return trainset

        def load_data(self, dataset=None, mode='train'):
            """Return a generator factory yielding batches of 256 samples.

            Each batch is ``([usr_id, gender, age, job] arrays,
            [mov_id, category, title, poster] arrays, scores array)``.
            In ``'train'`` mode the sample order is reshuffled on every call
            of the generator; a trailing partial batch is dropped.
            """
            # Honor the flag passed to the constructor.  (The original
            # hard-coded ``use_poster = False`` here, which silently made
            # the poster branch below dead code.)
            use_poster = self.use_poster
            # Batch size of the data iterator.
            BATCHSIZE = 256
            data_length = len(dataset)
            index_list = list(range(data_length))

            def data_generator():
                # Shuffle only the training data.
                if mode == 'train':
                    random.shuffle(index_list)
                # Per-feature accumulators for the current batch.
                usr_id_list, usr_gender_list, usr_age_list, usr_job_list = [], [], [], []
                mov_id_list, mov_tit_list, mov_cat_list, mov_poster_list = [], [], [], []
                score_list = []
                for i in index_list:
                    usr_id_list.append(dataset[i]['usr_info']['usr_id'])
                    usr_gender_list.append(dataset[i]['usr_info']['gender'])
                    usr_age_list.append(dataset[i]['usr_info']['age'])
                    usr_job_list.append(dataset[i]['usr_info']['job'])
                    mov_id_list.append(dataset[i]['mov_info']['mov_id'])
                    mov_tit_list.append(dataset[i]['mov_info']['title'])
                    mov_cat_list.append(dataset[i]['mov_info']['category'])
                    mov_id = dataset[i]['mov_info']['mov_id']
                    if use_poster:
                        # Read the poster image only when poster features are
                        # enabled.  mov_id is an int, so format it directly —
                        # the original's ``mov_id[0]`` raised TypeError.
                        poster = Image.open(self.poster_path + 'mov_id{}.jpg'.format(str(mov_id)))
                        poster = poster.resize([64, 64])
                        # Guarantee 3 channels (grayscale/palette posters).
                        # The original gated this on ``len(poster.size) <= 2``,
                        # which is always true for PIL, so convert is
                        # effectively unconditional.
                        poster = poster.convert("RGB")
                        mov_poster_list.append(np.array(poster))
                    score_list.append(int(dataset[i]['scores']))
                    # Emit a batch once BATCHSIZE samples are collected.
                    if len(usr_id_list) == BATCHSIZE:
                        # Convert lists to arrays with fixed shapes.
                        usr_id_arr = np.array(usr_id_list)
                        usr_gender_arr = np.array(usr_gender_list)
                        usr_age_arr = np.array(usr_age_list)
                        usr_job_arr = np.array(usr_job_list)
                        mov_id_arr = np.array(mov_id_list)
                        mov_cat_arr = np.reshape(np.array(mov_cat_list), [BATCHSIZE, 6]).astype(np.int64)
                        mov_tit_arr = np.reshape(np.array(mov_tit_list), [BATCHSIZE, 1, 15]).astype(np.int64)
                        if use_poster:
                            # Scale pixels to [-1, 1]; PIL arrays are HWC, so
                            # transpose to NCHW.  (The original's plain
                            # reshape to [B, 3, 64, 64] scrambled channels.)
                            poster_arr = np.reshape(np.array(mov_poster_list) / 127.5 - 1,
                                                    [BATCHSIZE, 64, 64, 3])
                            mov_poster_arr = poster_arr.transpose(0, 3, 1, 2).astype(np.float32)
                        else:
                            mov_poster_arr = np.array([0.])
                        scores_arr = np.reshape(np.array(score_list), [-1, 1]).astype(np.float32)
                        # Yield the finished batch.
                        yield [usr_id_arr, usr_gender_arr, usr_age_arr, usr_job_arr], \
                              [mov_id_arr, mov_cat_arr, mov_tit_arr, mov_poster_arr], scores_arr
                        # Reset accumulators for the next batch.
                        usr_id_list, usr_gender_list, usr_age_list, usr_job_list = [], [], [], []
                        mov_id_list, mov_tit_list, mov_cat_list, score_list = [], [], [], []
                        mov_poster_list = []
            return data_generator
    # 解压数据集
    ! cd work && unzip -o -q ml-1m.zip