TensorFlow Serving介绍

简介

TensorFlow的模型文件包含了深度学习模型的Graph和所有参数,其实就是checkpoint文件,用户可以加载模型文件继续训练或者对外提供Inference服务。

使用SavedModel导出模型

模型导出方式参考 https://tensorflow.github.io/serving/serving_basic

使用方法基本如下。

  1. from tensorflow.python.saved_model import builder as saved_model_builder
  2. export_path_base = sys.argv[-1]
  3. export_path = os.path.join(
  4. compat.as_bytes(export_path_base),
  5. compat.as_bytes(str(FLAGS.model_version)))
  6. print 'Exporting trained model to', export_path
  7. builder = saved_model_builder.SavedModelBuilder(export_path)
  8. builder.add_meta_graph_and_variables(
  9. sess, [tag_constants.SERVING],
  10. signature_def_map={
  11. 'predict_images':
  12. prediction_signature,
  13. signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
  14. classification_signature,
  15. },
  16. legacy_init_op=legacy_init_op)
  17. builder.save()

可以参考 https://github.com/tobegit3hub/deep_recommend_system/ 提供的可运行代码示例。

  1. ./dense_classifier.py --mode savedmodel

使用exporter导出模型

这里有导出TensorFlow serving支持的模型文件例子,可以参考使用 https://github.com/tobegit3hub/deep_recommend_system/blob/master/dense_classifier.py

导出的代码也比较简单,用户在inputs和output中填入模型Inference时的输入和输出即可。

  1. from tensorflow.contrib.session_bundle import exporter
  2. flags = tf.app.flags
  3. FLAGS = flags.FLAGS
  4. flags.DEFINE_string("model_path", "./model", "The path to export the model")
  5. flags.DEFINE_integer("export_version", 1, "Version number of the model")
  6. # Define the graph
  7. keys_placeholder = tf.placeholder(tf.int32, shape=[None, 1])
  8. keys = tf.identity(keys_placeholder)
  9. # Start the session
  10. # Export the model
  11. print("Exporting trained model to {}".format(FLAGS.model_path))
  12. model_exporter = exporter.Exporter(saver)
  13. model_exporter.init(
  14. sess.graph.as_graph_def(),
  15. named_graph_signatures={
  16. 'inputs': exporter.generic_signature({"keys": keys_placeholder, "features": inference_features}),
  17. 'outputs': exporter.generic_signature({"keys": keys, "softmax": inference_softmax, "prediction": inference_op})
  18. })
  19. model_exporter.export(FLAGS.model_path, tf.constant(FLAGS.export_version), sess)
  20. print 'Done exporting!'

与SavedModel方法相比,两者都可以直接用TensorFlow Serving加载,我们使用deep_recommend_system导出两种模型方式测试过预测结果一模一样,只是模型文件大小不同。

导入带assert的模型文件

在NLP等场景除了参数文件,还需要导入vocabulary等文件,可以在exporter中设置assets_collection,参考 https://github.com/tensorflow/serving/issues/264

原文: http://docs.api.xiaomi.com/cloud-ml/modelservice/02_tensorflow_serving.html