Getting Started with TensorFlow Integration

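The MLeap TensorFlow integration lets users include a TensorFlow graph as a transformer in an ML pipeline. We plan to improve TensorFlow compatibility in the future. For now, the TensorFlow integration in MLeap is an experimental feature, and we are still working to stabilize it.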

Building the MLeap-TensorFlow Module

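The MLeap-TensorFlow module is not hosted on Maven Central. It depends on the JNI (Java Native Interface) support provided by TensorFlow, so you must build it from source. Follow the relevant guide to build the TensorFlow module from source.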

Using MLeap-TensorFlow

Once the module is built, you can integrate TensorFlow into an MLeap pipeline.

First, add MLeap-TensorFlow as a project dependency:

    libraryDependencies += "ml.combust.mleap" %% "mleap-tensorflow" % "0.13.0"
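Because the module relies on a locally built JNI library, it can help to confirm that the native TensorFlow library actually loads before wiring it into a pipeline. The minimal sketch below assumes the TensorFlow Java 1.x bindings (`org.tensorflow.TensorFlow`) are on the classpath alongside `mleap-tensorflow`; the names `CheckTensorflowJni` and `nativeVersion` are hypothetical and are not part of MLeap.

```scala
// Assumption: TensorFlow Java 1.x bindings (org.tensorflow) and the locally
// built JNI library are available at runtime. Names here are hypothetical.
object CheckTensorflowJni {
  // Calling TensorFlow.version() requires the native library to be loaded,
  // so a missing or broken JNI build surfaces here as a linkage error
  // (typically UnsatisfiedLinkError) instead of later inside a pipeline.
  def nativeVersion(): String = org.tensorflow.TensorFlow.version()

  def main(args: Array[String]): Unit =
    println(s"TensorFlow native library version: ${nativeVersion()}")
}
```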

Next, you can use a TensorFlow graph in your code. Let's build a simple graph that takes two input tensors and multiplies them.

    import ml.combust.mleap.core.types._
    import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
    import ml.combust.mleap.tensor.Tensor
    import ml.combust.mleap.tensorflow.{TensorflowModel, TensorflowTransformer}
    import org.tensorflow

    // Initialize our TensorFlow demo graph
    val graph = new tensorflow.Graph

    // Build placeholders for our input values
    val inputA = graph.opBuilder("Placeholder", "InputA").
      setAttr("dtype", tensorflow.DataType.FLOAT).
      build()
    val inputB = graph.opBuilder("Placeholder", "InputB").
      setAttr("dtype", tensorflow.DataType.FLOAT).
      build()

    // Multiply the two placeholders and put the result in
    // the "MyResult" tensor
    graph.opBuilder("Mul", "MyResult").
      setAttr("T", tensorflow.DataType.FLOAT).
      addInput(inputA.output(0)).
      addInput(inputB.output(0)).
      build()

    // Build the MLeap model wrapper around the TensorFlow graph
    val model = TensorflowModel(graph,
      // Must specify inputs and input types for converting to TF tensors
      inputs = Seq(("InputA", TensorType.Float()), ("InputB", TensorType.Float())),
      // Likewise, specify the output values so we can convert back to MLeap
      // types properly
      outputs = Seq(("MyResult", TensorType.Float())))

    // Connect our leap frame values to the TensorFlow graph
    // inputs and outputs
    val shape = NodeShape().
      // Column "input_a" gets sent to the TF graph as the input "InputA"
      withInput("InputA", "input_a").
      // Column "input_b" gets sent to the TF graph as the input "InputB"
      withInput("InputB", "input_b").
      // TF graph output "MyResult" gets placed in the leap frame as column
      // "my_result"
      withOutput("MyResult", "my_result")

    // Create the MLeap transformer that executes the TF model against
    // a leap frame
    val transformer = TensorflowTransformer(shape = shape, model = model)

    // Create a sample leap frame to transform with the TensorFlow graph
    val schema = StructType(StructField("input_a", ScalarType.Float),
      StructField("input_b", ScalarType.Float)).get
    val dataset = Seq(Row(5.6f, 7.9f),
      Row(3.4f, 6.7f),
      Row(1.2f, 9.7f))
    val frame = DefaultLeapFrame(schema, dataset)

    // Transform the leap frame and make sure it behaves as expected
    val data = transformer.transform(frame).get.dataset
    assert(data(0)(2).asInstanceOf[Tensor[Float]].get(0).get == 5.6f * 7.9f)
    assert(data(1)(2).asInstanceOf[Tensor[Float]].get(0).get == 3.4f * 6.7f)
    assert(data(2)(2).asInstanceOf[Tensor[Float]].get(0).get == 1.2f * 9.7f)

    // Clean up the transformer
    // This closes the TF session and graph resources
    transformer.close()
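For comparison, the sketch below evaluates the same "Mul" graph directly through the TensorFlow Java `Session` API, without MLeap, to show what feeding the `InputA`/`InputB` placeholders by name and fetching `MyResult` looks like at the TensorFlow level. It assumes the TensorFlow Java 1.x bindings (`org.tensorflow.Graph`, `Session`, `Tensors`) are on the classpath; `DirectSessionSketch` and `multiply` are hypothetical names, and this is not a description of how `TensorflowTransformer` is implemented internally.

```scala
// Assumption: TensorFlow Java 1.x API (org.tensorflow). Hypothetical names;
// this bypasses MLeap entirely and runs one (a, b) pair through the graph.
import org.tensorflow.{DataType, Graph, Session, Tensors}

object DirectSessionSketch {
  def multiply(a: Float, b: Float): Float = {
    val graph = new Graph
    try {
      // Same graph as the MLeap example: two float placeholders and a Mul op
      val inputA = graph.opBuilder("Placeholder", "InputA").
        setAttr("dtype", DataType.FLOAT).build()
      val inputB = graph.opBuilder("Placeholder", "InputB").
        setAttr("dtype", DataType.FLOAT).build()
      graph.opBuilder("Mul", "MyResult").
        setAttr("T", DataType.FLOAT).
        addInput(inputA.output(0)).
        addInput(inputB.output(0)).
        build()

      val session = new Session(graph)
      val tensorA = Tensors.create(a) // scalar float32 tensor
      val tensorB = Tensors.create(b)
      try {
        // Feed each placeholder by op name, then fetch the output op by name
        val result = session.runner().
          feed("InputA", tensorA).
          feed("InputB", tensorB).
          fetch("MyResult").
          run().
          get(0)
        try result.floatValue()
        finally result.close()
      } finally {
        // Native resources must be released explicitly
        tensorA.close()
        tensorB.close()
        session.close()
      }
    } finally graph.close()
  }

  def main(args: Array[String]): Unit =
    println(multiply(5.6f, 7.9f))
}
```

With the raw API you manage the lifetimes of the graph, session, and every tensor yourself; in the MLeap version, the leap frame handles the per-row inputs and outputs, and `transformer.close()` releases the TF session and graph resources.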

For more details on how the TensorFlow integration works:

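  1. For details on data integration and conversion, see the corresponding section.
  2. For details on serializing a TensorFlow graph into an MLeap Bundle, see the corresponding section.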