Using FDS FUSE in a TrainJob

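The previous section introduced two ways to use FDS in Cloud-ML, but neither is very flexible. This section describes how to use FDS FUSE in Cloud-ML. FDS FUSE is a FUSE-based file system that mounts an FDS bucket onto the local file system. You read and write local files, and FUSE automatically synchronizes the changes to the remote FDS. We again use the MNIST example from the previous section, with the TensorFlow framework.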
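Because a mounted bucket looks like an ordinary directory, code that writes to it needs no FDS-specific API. The following is a minimal sketch of that idea, not Cloud-ML code: `save_text` is a hypothetical helper, and a temporary directory stands in for the mount point so the example runs anywhere.

```python
import os
import tempfile


def save_text(root, relative_path, text):
    """Write text under root with plain file I/O and return the full path.

    On a FUSE mount, root would be the mount point; here it is any directory.
    """
    full_path = os.path.join(root, relative_path)
    os.makedirs(os.path.dirname(full_path), exist_ok=True)
    with open(full_path, "w", encoding="utf-8") as handle:
        handle.write(text)
    return full_path


# Assumption: a temporary directory stands in for the mount point.
demo_root = tempfile.mkdtemp()
written = save_text(demo_root, "tf-mnist/mnist-fuse/notes.txt", "hello")
with open(written, encoding="utf-8") as handle:
    content = handle.read()
print(content)
```

On Cloud-ML the same call would be given the mounted FDS directory as `root`, and the synchronization to FDS would be handled by FUSE rather than by the code.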

Steps

Create a directory in FDS

Suppose we want to save the training checkpoints and the final trained model to the following directory:

test-bucket-xg/tf-mnist/mnist-fuse

The result looks like this (figure: tf_fuse_folders).

Specify the paths in code

Specify the FDS paths in the code, as shown below:

    dataset_path = "/fds/tf-mnist/dataset"
    checkpoint_path = "/fds/tf-mnist/mnist-fuse"
    export_path = "/fds/tf-mnist/mnist-fuse"

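Here, /fds is the FDS mount point that Cloud-ML creates in the Docker image; think of it as the root directory of FDS inside the container. The dataset directory holds the data files (converted to TFRecord format, as described in the previous section). The mnist-fuse directory holds the training checkpoints and the resulting model. For the complete training code, see Appendix 1, "Using FUSE in code".

Next, package the code and upload it to FDS. The job can then be submitted for training.

Specify the FDS address when submitting the job

When submitting the job, you must tell the Cloud-ML platform to use FUSE. The command is as follows:
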
    cloudml jobs submit -n tf-fuse -m trainer.task -u fds://test-bucket-xg/tf-mnist/tf_fuse_test-1.0.tar.gz -c 4 -M 8G -g 1 -fe cnbj2.fds.api.xiaomi.com -fb test-bucket-xg -fc "ls -al /fds/tf-mnist/mnist-fuse"

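This command uses three new parameters:

-fe: specifies the FDS endpoint; by default the c3 FDS is used;

-fb: specifies the FDS bucket;

-fc: a post-command; the command inside the quotes runs after training finishes. Here we use it to check, once the job is done, that the directory under "/fds" contains the expected files. For more on post-commands, see the advanced features section later in this documentation.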
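When jobs are submitted from a script, the command shown above can be assembled as an argument list so that the quoted post-command stays a single argument. This is a sketch only: `build_submit_command` is a hypothetical helper, and its parameter names for `-n`, `-m`, `-u`, `-c` and `-M` are guesses based on the example values, since this section documents only `-fe`, `-fb`, `-fc` and `-g`. It builds and prints the command without running it.

```python
import shlex


def build_submit_command(name, module, package_uri, cpu, memory, gpu,
                         fds_endpoint, fds_bucket, post_command):
    """Return the argument list for the submit command used in this section."""
    return [
        "cloudml", "jobs", "submit",
        "-n", name,
        "-m", module,
        "-u", package_uri,
        "-c", str(cpu),
        "-M", memory,
        "-g", str(gpu),
        "-fe", fds_endpoint,   # FDS endpoint
        "-fb", fds_bucket,     # FDS bucket to mount
        "-fc", post_command,   # command run after training finishes
    ]


args = build_submit_command(
    name="tf-fuse",
    module="trainer.task",
    package_uri="fds://test-bucket-xg/tf-mnist/tf_fuse_test-1.0.tar.gz",
    cpu=4,
    memory="8G",
    gpu=1,
    fds_endpoint="cnbj2.fds.api.xiaomi.com",
    fds_bucket="test-bucket-xg",
    post_command="ls -al /fds/tf-mnist/mnist-fuse",
)

# Quote each argument so the printed line can be pasted into a shell.
rendered = " ".join(shlex.quote(arg) for arg in args)
print(rendered)
```

Passing `args` to `subprocess.run` would avoid shell quoting issues entirely, because the post-command is already one list element.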

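Training results

After training completes, we check that all the result files exist. The output of our example program is shown below:

(figure: tf_fuse_train_result)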
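The same check can be scripted instead of read off a listing. The sketch below assumes checkpoint files are named with the `model.ckpt-<step>` prefix, which follows from the `saver.save(..., global_step=step)` call in Appendix 1; the suffixes after the step number vary with the TensorFlow version, so only the prefix is matched. `saved_steps` is a hypothetical helper, and the demo uses made-up file names in a temporary directory rather than real training output.

```python
import os
import re
import tempfile

# Matches names such as "model.ckpt-900" or "model.ckpt-900.meta".
CHECKPOINT_PATTERN = re.compile(r"^model\.ckpt-(\d+)")


def saved_steps(directory):
    """Return the sorted training steps that have checkpoint files."""
    steps = set()
    for name in os.listdir(directory):
        match = CHECKPOINT_PATTERN.match(name)
        if match:
            steps.add(int(match.group(1)))
    return sorted(steps)


# Demo on a temporary directory with illustrative (not real) file names.
demo_dir = tempfile.mkdtemp()
for name in ("checkpoint", "model.ckpt-900", "model.ckpt-900.meta",
             "model.ckpt-1000", "model.ckpt-1000.meta"):
    open(os.path.join(demo_dir, name), "w").close()

print(saved_steps(demo_dir))  # [900, 1000]
```

Inside the container the function would be pointed at /fds/tf-mnist/mnist-fuse, the directory that the post-command lists.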

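Cloud-ML provides ModelService and Tensorboard features for working further with these results; see the corresponding documentation.

Now look at the log output (figure: tf_fuse_train_log).

In the log:

  • the fdsfuse command mounts test-bucket-xg on /fds;
  • the -g parameter indicates that this training job uses one Tesla P40 GPU;
  • the results are written to the directory we specified;
  • the post-command prints the output of ls -al /fds/tf-mnist/mnist-fuse.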
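Since the bucket test-bucket-xg is mounted on /fds, the FDS directory test-bucket-xg/tf-mnist/mnist-fuse and the container path /fds/tf-mnist/mnist-fuse refer to the same place. A small sketch of that mapping follows; `to_mount_path` is a hypothetical helper written for illustration, not part of Cloud-ML.

```python
import posixpath


def to_mount_path(fds_path, bucket, mount_point="/fds"):
    """Map '<bucket>/<key>' to the path seen inside the container."""
    if fds_path == bucket:
        return mount_point
    prefix = bucket + "/"
    if not fds_path.startswith(prefix):
        raise ValueError("%r is not inside bucket %r" % (fds_path, bucket))
    # Drop the bucket name; the rest of the key sits under the mount point.
    return posixpath.join(mount_point, fds_path[len(prefix):])


print(to_mount_path("test-bucket-xg/tf-mnist/mnist-fuse", "test-bucket-xg"))
# /fds/tf-mnist/mnist-fuse
```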

Appendix 1: Using FUSE in code

    # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================

    """Train and Eval the MNIST network.

    This version is like fully_connected_feed.py but uses data converted
    to a TFRecords file containing tf.train.Example protocol buffers.
    See tensorflow/g3doc/how_tos/reading_data.md#reading-from-files
    for context.

    YOU MUST run convert_to_records before running this (but you only need to
    run it once).
    """
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import os.path
    import time

    import numpy
    import tensorflow as tf
    from tensorflow.python.platform import gfile
    from tensorflow.contrib.session_bundle import exporter

    from tensorflow.examples.tutorials.mnist import mnist

    # Paths under /fds, the FDS mount point inside the container.
    dataset_path = "/fds/tf-mnist/dataset"
    checkpoint_path = "/fds/tf-mnist/mnist-fuse"
    export_path = "/fds/tf-mnist/mnist-fuse"

    # Basic model parameters as external flags.
    flags = tf.app.flags
    FLAGS = flags.FLAGS
    flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
    flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
    flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
    flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
    flags.DEFINE_integer('batch_size', 100, 'Batch size.')
    flags.DEFINE_string('train_dir', dataset_path,
                        'Directory with the training data.')
    flags.DEFINE_string('checkpoint_dir', checkpoint_path,
                        'Directory for periodic checkpoints.')
    flags.DEFINE_string('export_dir', export_path,
                        'Directory to export the final trained model.')
    flags.DEFINE_integer('export_version', 1, 'Export version')

    # Constants used for dealing with the files, matches convert_to_records.
    TRAIN_FILE = 'train.tfrecords'
    VALIDATION_FILE = 'validation.tfrecords'

    def read_and_decode(filename_queue):
      reader = tf.TFRecordReader()
      _, serialized_example = reader.read(filename_queue)
      features = tf.parse_single_example(
          serialized_example,
          # Defaults are not specified since both keys are required.
          features={
              'image_raw': tf.FixedLenFeature([], tf.string),
              'label': tf.FixedLenFeature([], tf.int64),
          })

      # Convert from a scalar string tensor (whose single string has
      # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
      # [mnist.IMAGE_PIXELS].
      image = tf.decode_raw(features['image_raw'], tf.uint8)
      image.set_shape([mnist.IMAGE_PIXELS])

      # OPTIONAL: Could reshape into a 28x28 image and apply distortions
      # here.  Since we are not applying any distortions in this
      # example, and the next step expects the image to be flattened
      # into a vector, we don't bother.

      # Convert from [0, 255] -> [-0.5, 0.5] floats.
      image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

      # Convert label from a scalar uint8 tensor to an int32 scalar.
      label = tf.cast(features['label'], tf.int32)

      return image, label

    def inputs(train, batch_size, num_epochs):
      """Reads input data num_epochs times.

      Args:
        train: Selects between the training (True) and validation (False) data.
        batch_size: Number of examples per returned batch.
        num_epochs: Number of times to read the input data, or 0/None to
           train forever.

      Returns:
        A tuple (images, labels), where:
        * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
          in the range [-0.5, 0.5].
        * labels is an int32 tensor with shape [batch_size] with the true label,
          a number in the range [0, mnist.NUM_CLASSES).
        Note that an tf.train.QueueRunner is added to the graph, which
        must be run using e.g. tf.train.start_queue_runners().
      """
      if not num_epochs: num_epochs = None
      filename = os.path.join(FLAGS.train_dir,
                              TRAIN_FILE if train else VALIDATION_FILE)

      with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)

        # Even when reading in multiple threads, share the filename
        # queue.
        image, label = read_and_decode(filename_queue)

        # Shuffle the examples and collect them into batch_size batches.
        # (Internally uses a RandomShuffleQueue.)
        # We run this in two threads to avoid being a bottleneck.
        images, sparse_labels = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=2,
            capacity=1000 + 3 * batch_size,
            # Ensures a minimum amount of shuffling of examples.
            min_after_dequeue=1000)

        return images, sparse_labels


    def run_training():
      """Train MNIST for a number of steps."""
      gfile.MkDir(FLAGS.checkpoint_dir)

      # Tell TensorFlow that the model will be built into the default Graph.
      with tf.Graph().as_default():
        # Input images and labels.
        images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                                num_epochs=FLAGS.num_epochs)

        # Build a Graph that computes predictions from the inference model.
        logits = mnist.inference(images,
                                 FLAGS.hidden1,
                                 FLAGS.hidden2)

        # Add to the Graph the loss calculation.
        loss = mnist.loss(logits, labels)

        # Add to the Graph operations that train the model.
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # The op for initializing the variables.
        init_op = tf.group(tf.initialize_all_variables(),
                           tf.initialize_local_variables())

        # Create a session for running operations in the Graph.
        sess = tf.Session()

        # Create checkpoint saver
        saver = tf.train.Saver()

        # Initialize the variables (the trained variables and the
        # epoch counter).
        sess.run(init_op)

        # Start input enqueue threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
          step = 0
          while not coord.should_stop():
            start_time = time.time()

            # Run one step of the model.  The return values are
            # the activations from the `train_op` (which is
            # discarded) and the `loss` op.  To inspect the values
            # of your ops or variables, you may include them in
            # the list passed to sess.run() and the value tensors
            # will be returned in the tuple from the call.
            _, loss_value = sess.run([train_op, loss])

            duration = time.time() - start_time

            # Save a checkpoint and print an overview fairly often.
            if step % 100 == 0:
              saver.save(sess, FLAGS.checkpoint_dir + '/model.ckpt',
                         global_step=step)
              print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                         duration))
            step += 1
        except tf.errors.OutOfRangeError:
          print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
        finally:
          # When done, ask the threads to stop.
          coord.request_stop()

        # Wait for threads to finish.
        coord.join(threads)

        print('Exporting trained model to ' + FLAGS.export_dir)
        # NOTE this format is deprecated, please refer to tensorflow_serving for
        # more examples
        saver = tf.train.Saver(sharded=True)
        model_exporter = exporter.Exporter(saver)
        signature = exporter.classification_signature(input_tensor=images,
                                                      scores_tensor=logits)
        model_exporter.init(sess.graph.as_graph_def(),
                            default_graph_signature=signature)
        model_exporter.export(FLAGS.export_dir, tf.constant(FLAGS.export_version),
                              sess)
        print('Done exporting!')

        sess.close()


    def main(_):
      run_training()


    if __name__ == '__main__':
      tf.app.run()
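As a rough guide to how many checkpoints the training loop above writes to the mounted directory, the arithmetic can be worked through in a few lines. This sketch rests on two assumptions not stated in this document: that the training TFRecord file holds 55,000 examples (the usual training split in the TensorFlow MNIST tutorial), and that `tf.train.Saver` is left at its default of keeping only the five most recent checkpoints. `checkpoint_steps` is a hypothetical helper.

```python
def checkpoint_steps(num_examples, num_epochs, batch_size, save_every=100):
    """Return (total training steps, steps at which a checkpoint is saved)."""
    # Each step consumes one full batch; the loop ends when the input
    # queue has served every example num_epochs times.
    total_steps = (num_examples * num_epochs) // batch_size
    # The loop saves whenever step % save_every == 0, starting at step 0.
    return total_steps, list(range(0, total_steps, save_every))


# Assumption: 55,000 training examples; flags at their default values.
total, steps = checkpoint_steps(num_examples=55000, num_epochs=2,
                                batch_size=100)
print(total)       # 1100
print(len(steps))  # 11
print(steps[-5:])  # [600, 700, 800, 900, 1000]
```

Under these assumptions the job saves 11 checkpoints, and only those for the last five steps would still be present when the post-command lists the directory.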


Source: http://docs.api.xiaomi.com/cloud-ml/trainjob/04_trainjob_fuse.html