Using Deeplearning4J in Android Applications

DL4JImageRecognitionDemo

This example application uses a neural network trained on the standard MNIST dataset of 28x28 greyscale images (pixel values 0..255) of hand-drawn digits 0..9. The application's user interface allows the user to draw a number on the device screen, which is then tested against the trained network. The output displays the most probable numeric value and its probability score. This tutorial covers the use of a trained neural network in an Android application, the handling of user-generated images, and the output of the results to the UI from a background thread. More information on general prerequisites for building DL4J Android applications can be found here.

[Figure 1: Android Image Classifier]

Setting the Dependencies

Deeplearning4J applications require application-specific dependencies in the build.gradle file. The Deeplearning4J library in turn depends on the ND4J and OpenBLAS libraries, so these must also be added to the dependencies declaration. Starting with Android Studio 3.0, annotation processors also need to be defined. If you are working in Android Studio 3.0 or later, include the dependencies for -x86 or -arm processors, depending on your device; note that both can be included without conflict, as is done in the example app.

    compile (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '1.0.0-beta4') {
        exclude group: 'org.bytedeco.javacpp-presets', module: 'opencv-platform'
        exclude group: 'org.bytedeco.javacpp-presets', module: 'leptonica-platform'
        exclude group: 'org.bytedeco.javacpp-presets', module: 'hdf5-platform'
        exclude group: 'org.nd4j', module: 'nd4j-base64'
    }
    compile group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta4'
    compile group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta4', classifier: "android-arm"
    compile group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta4', classifier: "android-arm64"
    compile group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta4', classifier: "android-x86"
    compile group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta4', classifier: "android-x86_64"
    compile group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3'
    compile group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm"
    compile group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm64"
    compile group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86"
    compile group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86_64"
    compile group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3'
    compile group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm"
    compile group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm64"
    compile group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86"
    compile group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86_64"
    compile group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3'
    compile group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm"
    compile group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm64"
    compile group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86"
    compile group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86_64"
    implementation 'com.google.code.gson:gson:2.8.2'
    annotationProcessor 'org.projectlombok:lombok:1.16.16'
    // This corrects for a junit version conflict
    configurations.all {
        resolutionStrategy.force 'junit:junit:4.12'
    }

Compiling these dependencies pushes the application past the 64K method reference limit of a single DEX file, so it is necessary to set multiDexEnabled to true in defaultConfig.

    multiDexEnabled true
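
For context, this flag belongs inside the defaultConfig block of the module-level build.gradle. A minimal sketch is shown below; the application ID and SDK versions are illustrative, not requirements of DL4J.

    android {
        defaultConfig {
            applicationId "com.example.dl4jdemo"  // illustrative value
            minSdkVersion 21                      // illustrative value
            targetSdkVersion 28                   // illustrative value
            multiDexEnabled true                  // required for the DL4J dependencies
        }
    }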

Finally, a conflict in the junit module versions will give the following error:

> Conflict with dependency 'junit:junit' in project ':app'. Resolved versions for app (4.8.2) and test app (4.12) differ.

This can be suppressed by forcing all of the junit modules to use the same version.

    configurations.all {
        resolutionStrategy.force 'junit:junit:4.12'
    }

Training and loading the MNIST model in the Android project resources

Using a neural network requires a significant amount of processor power, which is in limited supply on mobile devices. Therefore, a background thread must be used for loading the trained neural network and testing the user-drawn image; we will use an AsyncTask for this. In this application we will run the canvas draw code on the main thread and use an AsyncTask to load the drawn image from internal memory and test it against the trained model on a background thread. First, let's look at how to save the trained neural network we will be using in the application.

You will need to begin by following the Deeplearning4j quick start guide to set up, train, and save neural network models on a desktop computer. The DL4J example which trains and saves the MNIST model used in this application is MnistImagePipelineExampleSave.java, included in the quick start guide referenced above. The code for the MNIST demo is also available here. Running this demo will train the MNIST neural network model and save it as "trained_mnist_model.zip" in the dl4j\target folder of the dl4j-examples directory. You can then copy the file and save it in the raw folder of your Android project.

[Figure 2: Android Image Classifier]
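
If you train your own model instead, the desktop-side save step boils down to a ModelSerializer call along these lines (a minimal sketch; the model and file variable names are illustrative):

    // Persist the trained network to disk. The final boolean controls whether
    // the updater state (needed only to resume training) is also saved.
    File locationToSave = new File("trained_mnist_model.zip");
    boolean saveUpdater = false;
    ModelSerializer.writeModel(model, locationToSave, saveUpdater);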

Accessing the trained model using an AsyncTask

Now let's start by writing our AsyncTask<Params, Progress, Results> to load and use the neural network on a background thread. The AsyncTask will use the parameter types <String, Integer, INDArray>. The Params type is set to String, which passes the path of the saved image to the AsyncTask when it is executed. This path is used in the doInBackground() method to locate and load the drawn image, while the trained MNIST model is loaded from the application's raw resources. The Results parameter is of type INDArray, which stores the results from the neural network and passes them to the onPostExecute() method, which has access to the main thread for updating the UI. For more on NDArrays, see https://nd4j.org/userguide. Note that the AsyncTask requires us to override two more methods (onProgressUpdate and onPostExecute), which we will get to later in the demo.

    private class AsyncTaskRunner extends AsyncTask<String, Integer, INDArray> {

        // Runs on the UI thread before the background thread is called
        @Override
        protected void onPreExecute() {
            super.onPreExecute();
        }

        @Override
        protected INDArray doInBackground(String... params) {
            // Main background thread: this will load the model and test the input image.
            // The dimensions of the MNIST images are set here.
            int height = 28;
            int width = 28;
            int channels = 1;
            // Now we load the model from the raw folder with a try / catch block
            try {
                // Load the pretrained network
                InputStream inputStream = getResources().openRawResource(R.raw.trained_mnist_model);
                MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(inputStream);
                // Load the image file to test
                File f = new File(absolutePath, "drawn_image.jpg");
                // Use NativeImageLoader to convert the image to a numerical matrix
                NativeImageLoader loader = new NativeImageLoader(height, width, channels);
                // Put the image into an INDArray
                INDArray image = loader.asMatrix(f);
                // The 0..255 pixel values need to be scaled to the 0..1 range
                DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
                // Then call that scaler on the image
                scaler.transform(image);
                // Pass the image through the neural net and store the result in the output array
                output = model.output(image);
            } catch (IOException e) {
                e.printStackTrace();
            }
            return output;
        }

Handling images from user input

Now let's add the code for the drawing canvas that will run on the main thread and allow the user to draw a number on the screen. This is a generic drawing program written as an inner class within the MainActivity. It extends View and overrides a series of methods. The drawing is saved to internal memory, and the AsyncTask is executed with the image path passed to it in the onTouchEvent case statement for MotionEvent.ACTION_UP. This streamlines the user experience by automatically returning results for an image as soon as the user completes the drawing.

    // Code for the drawing input
    public class DrawingView extends View {

        private Path mPath;
        private Paint mBitmapPaint;
        private Paint mPaint;
        private Bitmap mBitmap;
        private Canvas mCanvas;

        public DrawingView(Context c) {
            super(c);
            mPath = new Path();
            mBitmapPaint = new Paint(Paint.DITHER_FLAG);
            mPaint = new Paint();
            mPaint.setAntiAlias(true);
            mPaint.setStrokeJoin(Paint.Join.ROUND);
            mPaint.setStrokeCap(Paint.Cap.ROUND);
            mPaint.setStrokeWidth(60);
            mPaint.setDither(true);
            mPaint.setColor(Color.WHITE);
            mPaint.setStyle(Paint.Style.STROKE);
        }

        @Override
        protected void onSizeChanged(int W, int H, int oldW, int oldH) {
            super.onSizeChanged(W, H, oldW, oldH);
            mBitmap = Bitmap.createBitmap(W, H, Bitmap.Config.ARGB_4444);
            mCanvas = new Canvas(mBitmap);
        }

        @Override
        protected void onDraw(Canvas canvas) {
            canvas.drawBitmap(mBitmap, 0, 0, mBitmapPaint);
            canvas.drawPath(mPath, mPaint);
        }

        private float mX, mY;
        private static final float TOUCH_TOLERANCE = 4;

        private void touch_start(float x, float y) {
            mPath.reset();
            mPath.moveTo(x, y);
            mX = x;
            mY = y;
        }

        private void touch_move(float x, float y) {
            float dx = Math.abs(x - mX);
            float dy = Math.abs(y - mY);
            if (dx >= TOUCH_TOLERANCE || dy >= TOUCH_TOLERANCE) {
                mPath.quadTo(mX, mY, (x + mX) / 2, (y + mY) / 2);
                mX = x;
                mY = y;
            }
        }

        private void touch_up() {
            mPath.lineTo(mX, mY);
            mCanvas.drawPath(mPath, mPaint);
            mPath.reset();
        }

        @Override
        public boolean onTouchEvent(MotionEvent event) {
            float x = event.getX();
            float y = event.getY();
            switch (event.getAction()) {
                case MotionEvent.ACTION_DOWN:
                    invalidate();
                    clear();
                    touch_start(x, y);
                    invalidate();
                    break;
                case MotionEvent.ACTION_MOVE:
                    touch_move(x, y);
                    invalidate();
                    break;
                case MotionEvent.ACTION_UP:
                    touch_up();
                    absolutePath = saveDrawing();
                    invalidate();
                    clear();
                    loadImageFromStorage(absolutePath);
                    onProgressBar();
                    // Launch the AsyncTask now that the image has been saved
                    AsyncTaskRunner runner = new AsyncTaskRunner();
                    runner.execute(absolutePath);
                    break;
            }
            return true;
        }

        public void clear() {
            mBitmap.eraseColor(Color.TRANSPARENT);
            invalidate();
            System.gc();
        }
    }

Now we need to build a series of helper methods. First we will write the saveDrawing() method. It uses getDrawingCache() to retrieve the drawing from the drawingView and store it as a bitmap. We then create a file directory and a file for the bitmap called "drawn_image.jpg". Finally, a FileOutputStream is used in a try / catch block to write the bitmap to the file location. The method returns the absolute path to the file location, which will be used by the loadImageFromStorage() method.

    public String saveDrawing() {
        drawingView.setDrawingCacheEnabled(true);
        Bitmap b = drawingView.getDrawingCache();
        ContextWrapper cw = new ContextWrapper(getApplicationContext());
        // Set the path to internal storage
        File directory = cw.getDir("imageDir", Context.MODE_PRIVATE);
        // Create imageDir and store the file there; each new drawing will overwrite the previous one
        File mypath = new File(directory, "drawn_image.jpg");
        // Use a FileOutputStream to write the file to the location in a try / catch block
        FileOutputStream fos = null;
        try {
            fos = new FileOutputStream(mypath);
            b.compress(Bitmap.CompressFormat.JPEG, 100, fos);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                // Guard against a null stream if the FileOutputStream constructor threw
                if (fos != null) {
                    fos.close();
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        return directory.getAbsolutePath();
    }
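
Note that setDrawingCacheEnabled() and getDrawingCache() were deprecated in API level 28. If you target newer API levels, one alternative (a sketch under that assumption, not part of the original example) is to render the view into a bitmap yourself:

    // Draw the view's current contents into a fresh bitmap, without the drawing cache
    Bitmap b = Bitmap.createBitmap(drawingView.getWidth(), drawingView.getHeight(),
            Bitmap.Config.ARGB_8888);
    Canvas canvas = new Canvas(b);
    drawingView.draw(canvas);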

Next we will write the loadImageFromStorage() method, which uses the absolute path returned from saveDrawing() to load the saved image and display it in the UI as part of the output display. It uses a FileInputStream inside a try / catch block to set the image on the ImageView img in the UI layout.

    private void loadImageFromStorage(String path) {
        // Use a FileInputStream to read the file in a try / catch block
        try {
            File f = new File(path, "drawn_image.jpg");
            Bitmap b = BitmapFactory.decodeStream(new FileInputStream(f));
            ImageView img = (ImageView) findViewById(R.id.outputView);
            img.setImageBitmap(b);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
    }

We also need to write two helper methods that extract the predicted number and its confidence score from the neural network output; we will call these later when we complete the AsyncTask.

    // Helper method to return the largest value in the output array
    public static double arrayMaximum(double[] arr) {
        double max = Double.NEGATIVE_INFINITY;
        for (double cur : arr)
            max = Math.max(max, cur);
        return max;
    }

    // Helper method to find the index (and therefore the numerical value)
    // of the largest confidence score
    public int getIndexOfLargestValue(double[] array) {
        if (array == null || array.length == 0) return -1;
        int largest = 0;
        for (int i = 1; i < array.length; i++) {
            if (array[i] > array[largest]) largest = i;
        }
        return largest;
    }

Finally, we need a few methods we can call to control the visibility of an ‘In Progress…’ message while the background thread is running. These will be called when the AsyncTask is executed and in the onPostExecute method when the background thread completes.

    public void onProgressBar() {
        TextView bar = findViewById(R.id.processing);
        bar.setVisibility(View.VISIBLE);
    }

    public void offProgressBar() {
        TextView bar = findViewById(R.id.processing);
        bar.setVisibility(View.INVISIBLE);
    }

Now let’s go to the onCreate method to initialize the draw canvas and set some global variables.

    public class MainActivity extends AppCompatActivity {

        MainActivity.DrawingView drawingView;
        String absolutePath;
        public static INDArray output;

        @Override
        public void onCreate(Bundle savedInstanceState) {
            super.onCreate(savedInstanceState);
            setContentView(R.layout.activity_main);
            RelativeLayout parent = findViewById(R.id.layout2);
            drawingView = new MainActivity.DrawingView(this);
            parent.addView(drawingView);
        }

Updating the UI

Now we can complete our AsyncTask by overriding the onProgressUpdate and onPostExecute methods. Once the doInBackground method of the AsyncTask completes, the classification results are passed to onPostExecute, which has access to the main thread and UI, allowing us to update the UI with the results. Since we will not be using the onProgressUpdate method, a call to its superclass implementation will suffice.

    @Override
    protected void onProgressUpdate(Integer... values) {
        super.onProgressUpdate(values);
    }

The onPostExecute method will receive an INDArray containing the neural network results as a 1x10 array of probability values, one for each possible digit (0..9). From this we need to determine which position in the array contains the largest value and what that value is. These two values determine which digit the neural network classified the drawing as and how confident the network is in that classification; they are referred to in the UI as the Prediction and the Confidence, respectively. In the code below, the individual values at each position of the INDArray are copied into an array of type double using the getDouble() method on the result INDArray. We then get references to the TextViews to be updated in the UI and call our helper methods on the array to return the array maximum (confidence) and the index of the largest value (prediction). Note that we also limit the number of decimal places reported for the probability by setting a DecimalFormat pattern.

    @Override
    protected void onPostExecute(INDArray result) {
        super.onPostExecute(result);
        // Used to control the number of decimal places for the output probability
        DecimalFormat df2 = new DecimalFormat(".##");
        // Transfer the neural network output to an array
        double[] results = {result.getDouble(0, 0), result.getDouble(0, 1), result.getDouble(0, 2),
                result.getDouble(0, 3), result.getDouble(0, 4), result.getDouble(0, 5), result.getDouble(0, 6),
                result.getDouble(0, 7), result.getDouble(0, 8), result.getDouble(0, 9)};
        // Find the UI TextViews that display the prediction and confidence values
        TextView out1 = findViewById(R.id.prediction);
        TextView out2 = findViewById(R.id.confidence);
        // Display the values using the helper methods defined above
        out2.setText(String.valueOf(df2.format(arrayMaximum(results))));
        out1.setText(String.valueOf(getIndexOfLargestValue(results)));
        // Helper method to turn off the progress text
        offProgressBar();
    }
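
As an aside, ND4J can compute the prediction directly on the INDArray, which avoids the manual copy into a double[]. A minimal sketch, assuming the same result variable as above:

    // Index of the most probable digit, taken along dimension 1 of the 1x10 output
    int prediction = Nd4j.argMax(result, 1).getInt(0);
    double confidence = result.getDouble(0, prediction);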

Conclusion

This tutorial provides a basic framework for image recognition in an Android application using a DL4J neural network. It illustrates how to load a pre-trained DL4J model from the raw resources and how to test user-generated input images against the model. The AsyncTask then returns the output to the main thread and updates the UI.

The complete code for this example is available here.