使用Java客户端

简介

Java应用可以直接访问TensorFlow serving加载模型提供的服务,我们需要编写Java的gRPC客户端代码。

完整例子

这里有一个导出模型使用Java来访问模型的例子 https://github.com/tobegit3hub/deep_recommend_system/tree/master/java_predict_client

使用时通过Maven编译即可,不同模型只需要修改一个Java文件,其他外部依赖已经管理好,建议在此项目中修改使用。

Java客户端实现原理

Java无论是服务端还是客户端都是在独立于grpc的项目中实现,代码在 https://github.com/grpc/grpc-java 。使用时需要引入grpc实现的类,建议使用maven管理依赖,在pom.xml中加入下面的依赖。

  1. <dependency>
  2. <groupId>io.grpc</groupId>
  3. <artifactId>grpc-netty</artifactId>
  4. <version>1.0.0</version>
  5. </dependency>
  6. <dependency>
  7. <groupId>io.grpc</groupId>
  8. <artifactId>grpc-protobuf</artifactId>
  9. <version>1.0.0</version>
  10. </dependency>
  11. <dependency>
  12. <groupId>io.grpc</groupId>
  13. <artifactId>grpc-stub</artifactId>
  14. <version>1.0.0</version>
  15. </dependency>

由于使用grpc还需要用到protobuf生成的Java代码,如果通过命令生成再拷贝jar文件不好管理,可以使用maven插件,把proto文件拷贝到指定目录,在编译时就会自动生成java文件放到target目录。

  1. <build>
  2. <extensions>
  3. <extension>
  4. <groupId>kr.motd.maven</groupId>
  5. <artifactId>os-maven-plugin</artifactId>
  6. <version>1.4.1.Final</version>
  7. </extension>
  8. </extensions>
  9. <plugins>
  10. <plugin>
  11. <groupId>org.xolstice.maven.plugins</groupId>
  12. <artifactId>protobuf-maven-plugin</artifactId>
  13. <version>0.5.0</version>
  14. <configuration>
  15. <!--
  16. The version of protoc must match protobuf-java. If you don't depend on
  17. protobuf-java directly, you will be transitively depending on the
  18. protobuf-java version that grpc depends on.
  19. -->
  20. <protocArtifact>com.google.protobuf:protoc:3.0.0:exe:${os.detected.classifier}</protocArtifact>
  21. <pluginId>grpc-java</pluginId>
  22. <pluginArtifact>io.grpc:protoc-gen-grpc-java:1.0.0:exe:${os.detected.classifier}</pluginArtifact>
  23. </configuration>
  24. <executions>
  25. <execution>
  26. <goals>
  27. <goal>compile</goal>
  28. <goal>compile-custom</goal>
  29. </goals>
  30. </execution>
  31. </executions>
  32. </plugin>
  33. </plugins>
  34. </build>

注意我们需要加入TensorFlow serving和TensorFlow项目的proto文件,由于我们不使用bazel编译,因此proto文件的依赖路径需要修改,建议参考上面的完整项目。

构造TensorProto对象

使用protobuf定义了请求的接口,但我们还需要构建protobuf生成代码中的TensorProto对象,本质上是一个多维数据,在C++和Python中都有函数可以直接生成。

Java可以定义多维数据,然后参考这个Stackoverflow答案来构建 http://stackoverflow.com/questions/39443019/how-can-i-create-tensorproto-for-tensorflow-in-java ,下面是一个构建二位TensorProto的代码。

  1. // Generate features TensorProto
  2. float[][] featuresTensorData = new float[][]{
  3. {10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
  4. {10f, 10f, 10f, 8f, 6f, 1f, 8f, 9f, 1f},
  5. };
  6. TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
  7. featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT);
  8. for (int i = 0; i < featuresTensorData.length; ++i) {
  9. for (int j = 0; j < featuresTensorData[i].length; ++j) {
  10. featuresTensorBuilder.addFloatVal(featuresTensorData[i][j]);
  11. }
  12. }
  13. TensorShapeProto.Dim dim1 = TensorShapeProto.Dim.newBuilder().setSize(2).build();
  14. TensorShapeProto.Dim dim2 = TensorShapeProto.Dim.newBuilder().setSize(9).build();
  15. TensorShapeProto shape = TensorShapeProto.newBuilder().addDim(dim1).addDim(dim2).build();
  16. featuresTensorBuilder.setTensorShape(shape);
  17. TensorProto featuresTensorProto = featuresTensorBuilder.build();

注意除了设置data,shape和dtype都需要我们手动设置,否则服务端无法解析TensorProto成tensor对象。

读取图片文件生成TensorProto

在图像分类等场景中,我们需要读取图片文件生成TensorProto对象,才可以通过gRPC请求TensorFlow serving服务,这里提供一个Java例子,测试支持jpg和png图片格式。

这里有完整的使用CNN训练模型和inference的例子,Java客户端可以直接读取本地文件来请求服务进行预测和分类 https://github.com/tobegit3hub/deep_cnn/tree/master/java_predict_client

  1. // Generate image file to array
  2. int[][][][] featuresTensorData = new int[2][32][32][3];
  3. String[] imageFilenames = new String[]{"../data/inference/Mew.png", "../data/inference/Pikachu.png"};
  4. for (int i = 0; i < imageFilenames.length; i++) {
  5. // Convert image file to multi-dimension array
  6. File imageFile = new File(imageFilenames[i]);
  7. try {
  8. BufferedImage image = ImageIO.read(imageFile);
  9. logger.info("Start to convert the image: " + imageFile.getPath());
  10. int imageWidth = 32;
  11. int imageHeight = 32;
  12. int[][] imageArray = new int[imageHeight][imageWidth];
  13. for (int row = 0; row < imageHeight; row++) {
  14. for (int column = 0; column < imageWidth; column++) {
  15. imageArray[row][column] = image.getRGB(column, row);
  16. int pixel = image.getRGB(column, row);
  17. int red = (pixel >> 16) & 0xff;
  18. int green = (pixel >> 8) & 0xff;
  19. int blue = pixel & 0xff;
  20. featuresTensorData[i][row][column][0] = red;
  21. featuresTensorData[i][row][column][1] = green;
  22. featuresTensorData[i][row][column][2] = blue;
  23. }
  24. }
  25. } catch (IOException e) {
  26. logger.log(Level.WARNING, e.getMessage());
  27. System.exit(1);
  28. }
  29. }
  30. // Generate features TensorProto
  31. TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();
  32. for (int i = 0; i < featuresTensorData.length; ++i) {
  33. for (int j = 0; j < featuresTensorData[i].length; ++j) {
  34. for (int k = 0; k < featuresTensorData[i][j].length; ++k) {
  35. for (int l = 0; l < featuresTensorData[i][j][k].length; ++l) {
  36. featuresTensorBuilder.addFloatVal(featuresTensorData[i][j][k][l]);
  37. }
  38. }
  39. }
  40. }
  41. TensorShapeProto.Dim featuresDim1 = TensorShapeProto.Dim.newBuilder().setSize(2).build();
  42. TensorShapeProto.Dim featuresDim2 = TensorShapeProto.Dim.newBuilder().setSize(32).build();
  43. TensorShapeProto.Dim featuresDim3 = TensorShapeProto.Dim.newBuilder().setSize(32).build();
  44. TensorShapeProto.Dim featuresDim4 = TensorShapeProto.Dim.newBuilder().setSize(3).build();
  45. TensorShapeProto featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).addDim(featuresDim2).addDim(featuresDim3).addDim(featuresDim4).build();
  46. featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT).setTensorShape(featuresShape);
  47. TensorProto featuresTensorProto = featuresTensorBuilder.build();

原文: http://docs.api.xiaomi.com/cloud-ml/modelservice/0903_use_java_client.html