执行阶段

这部分要短得多,更简单。 首先,我们加载 MNIST。 我们可以像之前的章节那样使用 ScikitLearn,但是 TensorFlow 提供了自己的助手来获取数据,将其缩放(0 到 1 之间),将它洗牌,并提供一个简单的功能来一次加载一个小批量:

  1. from tensorflow.examples.tutorials.mnist import input_data
  2. mnist = input_data.read_data_sets("/tmp/data/")

现在我们定义我们要运行的迭代数,以及小批量的大小:

  1. n_epochs = 10001
  2. batch_size = 50

现在我们去训练模型:

  1. with tf.Session() as sess:
  2. init.run()
  3. for epoch in range(n_epochs):
  4. for iteration in range(mnist.train.num_examples // batch_size):
  5. X_batch, y_batch = mnist.train.next_batch(batch_size)
  6. sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
  7. acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
  8. acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})
  9. print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)
  10. save_path = saver.save(sess, "./my_model_final.ckpt")

该代码打开一个 TensorFlow 会话,并运行初始化所有变量的init节点。 然后它运行的主要训练循环:在每个时期,通过一些小批次的对应于训练集的大小的代码进行迭代。 每个小批量通过next_batch()方法获取,然后代码简单地运行训练操作,为当前的小批量输入数据和目标提供。 接下来,在每个时期结束时,代码评估最后一个小批量和完整训练集上的模型,并打印出结果。 最后,模型参数保存到磁盘。