Deep Learning: Run a PyTorch+FastAI trained model on Android using PyTorch Mobile (also applicable to iOS)

Mariano Crosetti
Jul 30, 2021

Introduction

In this brief post I will train a simple classifier with the FastAI library and explain how to run it in a simple demo Android app using PyTorch Mobile. Follow the links for guides and more information. If you are already familiar with deep learning and FastAI, go straight to the [Let’s do it] section. If you don’t know anything about Android, that’s OK (I didn’t when I started writing this).

The app lets the user pick an image from the gallery and classifies it as cat or dog. Classification runs natively, entirely on the device.

Preconcepts

Feel free to skip these subsections if you already know the topic in the subtitle.

Pytorch

A tensor library that runs on GPUs and CPUs. It is similar to TensorFlow but better designed, because it is more Pythonic, more flexible, and just as fast. It is mainly designed for deep learning (it has a wide collection of deep learning abstractions), but you can also use it as a GPU-optimized NumPy.

FastAI

FastAI is an amazing library that makes it incredibly easy to train PyTorch deep learning models while staying flexible and powerful, thanks to its very well designed layered API. It was created by Jeremy Howard, a distinguished research scientist (and very practical guy!) who wants to “make deep learning accessible for all”. I highly recommend the FastAI courses as some of the best deep learning courses out there. They also recently launched a fantastic book based on the library.

Pytorch Mobile

A runtime (still in beta when this post was written) that makes it easy to deploy models on mobile while staying entirely within the PyTorch ecosystem. Here is a Hello World (this post is strongly based on it) and here is An Overview of the PyTorch Mobile Demo Apps.

Motivation

I wrote this as a step-by-step guide to running FastAI-trained models natively on mobile. I found many other very helpful blog posts, but they provided solutions based on cloud web services:

But often we want to run the model directly on the device (such as a cellphone), for several reasons: latency or connectivity problems, privacy, or interactive use cases (like real-time predictions tightly integrated with the camera). Some FastAI forum threads ask about or discuss this:

The last of those posts mentions that it is very easy to run a PyTorch model with PyTorch Mobile, but I wanted to write a guide explaining how. It is also important to port to the native platform (possibly with a non-trivial re-implementation) all the preprocessing used at test time. Fortunately, in deep learning we usually keep preprocessing minimal (for images, just a crop or resize plus normalization) and leave feature extraction entirely to the classifier. Sometimes, however (for instance in many FastAI application examples), the image is not a natural image but the product of other signal processing (see this or this), and that processing has to be ported as well.

Requirements

To follow this tutorial you need:

  • FastAI + PyTorch installed (the FastAI conda installation also installs PyTorch).
  • Ideally, set up the item above in an environment with a GPU (for training). If you don’t have one, you can use cloud solutions such as Colab or Paperspace. If you don’t feel like learning how to use them, you can train on the CPU (on my 7th-gen i7 it took around 30 minutes and used ~7 GB of RAM). But I strongly suggest taking the time to see how these free services work!
  • Android SDK + Android NDK. I’m going to use Android Studio to open the project (but it’s not strictly required).

Let’s do it

The code is available on github.com/mariano22/fastai-android-demo

Training the model

The training pipeline can be found in this notebook, which also explains each step. Let’s review the most important points:

  • Data loading and training are almost the same as in Chapter 5 of fastbook. A one-epoch fine-tuning cycle (just training the head) is enough: the pretrained model was trained on ImageNet, so it can already extract features that make classifying cats vs dogs trivial.
  • Don’t forget any piece of the pipeline: on the phone we start from a picture and want the prediction (the probabilities of cat vs dog).

When we call Learner.predict, it does: Preprocessing → learn.model → Activation

We must add the Softmax activation to the model before exporting it. After that, we can follow the same steps as the PyTorch Mobile Hello World to trace and save the model.
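For reference, here is a minimal sketch of that export step, assuming a recent PyTorch install. The helper name export_with_softmax is hypothetical (it is not taken from the repo): it appends nn.Softmax(dim=1) to the classifier so the exported file returns probabilities instead of raw logits, then traces it with torch.jit.trace and saves it. In the notebook, `model` would be learn.model and `example` an input batch with the size used for training.

```python
import torch
import torch.nn as nn


def export_with_softmax(model: nn.Module, example: torch.Tensor, path: str) -> torch.jit.ScriptModule:
    """Append Softmax to a classifier, trace it, and save the traced module."""
    # The classifier outputs logits of shape (batch, n_classes);
    # Softmax over dim=1 turns each row into class probabilities.
    wrapped = nn.Sequential(model, nn.Softmax(dim=1)).eval()
    with torch.no_grad():
        # Tracing records the operations run on `example` into a TorchScript module.
        traced = torch.jit.trace(wrapped, example)
    traced.save(path)
    return traced
```

The demo app loads a `.pt` file such as my_model.pt, which is the kind of file this produces. Newer PyTorch Mobile releases moved to the lite interpreter (torch.utils.mobile_optimizer.optimize_for_mobile plus _save_for_lite_interpreter), so match the saving call to the runtime version your app depends on.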

The notebook explains in more detail why we add Softmax. Don’t skip that part unless you already understand why; it will sharpen your PyTorch and deep learning intuition.
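A few numbers show why the activation matters. This stdlib-only sketch uses made-up logits (hypothetical values, not real model output): without Softmax the exported model would return raw scores like these, which can be negative and do not sum to 1. Softmax turns them into the cat/dog probabilities we want on the phone while keeping the same argmax, so the predicted class does not change, only its interpretation.

```python
import math


def softmax(logits):
    """Convert raw model outputs (logits) into probabilities that sum to 1."""
    # Subtracting the max keeps exp() from overflowing; the result is unchanged.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw outputs for one image (index 0 = cat, index 1 = dog).
logits = [2.0, -1.0]
probs = softmax(logits)  # probs ≈ [0.953, 0.047]
```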

The notebook also dives into how to inspect the preprocessing pipeline, so we can be sure we mimic it natively on the device (where we have neither DataLoaders nor any other FastAI machinery):

Preprocessing = dls.after_item → dls.before_batch → dls.after_batch
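As a sketch of what this pipeline does to one image, assume (as the app’s use of ImageNet stats suggests) that dls.after_batch converts bytes to floats (IntToFloatTensor) and then applies Normalize with ImageNet stats, and that dls.after_item has already resized the image to the training size. The hypothetical pure-Python function below scales each pixel to [0, 1], normalizes it per channel, and reorders height × width × channel into the channel-first layout PyTorch models expect.

```python
# Assumed: fastai's imagenet_stats, applied by a Normalize transform in dls.after_batch.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def preprocess(pixels):
    """Mimic IntToFloatTensor + Normalize on an already-resized RGB image.

    `pixels` is a list of rows; each row is a list of (r, g, b) ints in 0..255,
    i.e. height x width x channel. Returns a nested list shaped
    3 x height x width (channel first), as the model expects.
    """
    chw = []
    for c in range(3):
        channel = []
        for row in pixels:
            # Scale 0..255 to 0..1 (IntToFloatTensor), then normalize channel c.
            channel.append(
                [(px[c] / 255.0 - IMAGENET_MEAN[c]) / IMAGENET_STD[c] for px in row]
            )
        chw.append(channel)
    return chw


# Hypothetical 1x2 image: one white pixel, one black pixel.
image = [[(255, 255, 255), (0, 0, 0)]]
tensor = preprocess(image)
```

On Android, the PyTorch Mobile Hello World performs the equivalent conversion with TensorImageUtils.bitmapToFloat32Tensor and the TORCHVISION_NORM_MEAN_RGB / TORCHVISION_NORM_STD_RGB constants, which hold these same ImageNet values.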

Integrating into the Android App

The main file of the Android project is MainActivity.java. The code is heavily commented and easy to follow.

  • The onCreate function initializes everything. It is also where we load my_model.pt (stored in the assets folder).
  • The button triggers the function addFileOnClick (the association is made in the activity_main.xml file), which opens the gallery. onActivityResult is where we take the image chosen by the user, call showPredictionCatOrDog, and display the result.
  • showPredictionCatOrDog is where the (very simple) magic happens: it loads the image as a tensor, normalizes it with the ImageNet stats, runs the model on it, and computes the predicted class (the one with the maximum probability).
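The last step of showPredictionCatOrDog is an argmax over the model’s probabilities. Here it is sketched in Python for comparison with the notebook; the helper name and the LABELS order are hypothetical, and in practice the order must match the learner’s dls.vocab.

```python
# Hypothetical label order: it must match dls.vocab from the notebook.
LABELS = ["cat", "dog"]


def predicted_class(probs, labels=LABELS):
    """Return (label, probability) for the class with the maximum probability."""
    if len(probs) != len(labels):
        raise ValueError("expected one probability per label")
    # Index of the largest probability (first one wins on ties).
    best = max(range(len(probs)), key=lambda i: probs[i])
    return labels[best], probs[best]


print(predicted_class([0.12, 0.88]))  # ('dog', 0.88)
```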

Inference (module.forward) can be slow. In a real-world application it should not run on the main (UI) thread, or it will block the app.

For people as mistrustful as me, I left some code (in the last part of the notebook and in the onCreate function, before the model is loaded) that checks that image loading produces the same result on Android as in Python.
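One way to run such a comparison, sketched in Python under assumptions: you have printed the flattened input tensor on both sides (for example with tensor.flatten().tolist() in the notebook and a log statement on Android) and pasted the values as lists. The helper names and the 1e-4 tolerance are hypothetical choices, not the repo’s code; a tolerance is used because small floating-point differences between platforms can appear even when the two pipelines match.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two flat float sequences."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} vs {len(b)}")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def same_preprocessing(python_values, android_values, tol=1e-4):
    """True if both pipelines produced the same tensor up to `tol`."""
    return max_abs_diff(python_values, android_values) <= tol
```

If the check fails, printing max_abs_diff on its own tells you whether the mismatch is float noise or a real difference such as a wrong resize or missing normalization.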

Validation on cellphone

As I said, I’m a very mistrustful person, so I decided to include a piece of code that validates the model on the dataset directly on the cellphone. That’s why I included the whole dataset in the assets folder. Feel free to set RUN_VALIDATION to true to see the accuracy on some of the dataset images (RUN_VALIDATION_PCT limits the percentage of images used). This on-phone validation is not methodologically sound, since some of these images were used for training (a sound implementation would use only the images from the validation DataLoader).
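The same idea can be expressed as a Python sketch (a hypothetical helper, not the app’s Java code): evaluate a predictor on a random fraction of labeled items. Here `pct` is a fraction in (0, 1], which may not match how the app’s RUN_VALIDATION_PCT is scaled; for a methodologically sound number, pass only the items from the validation DataLoader.

```python
import random


def sampled_accuracy(items, predict, pct, seed=0):
    """Accuracy of `predict` on a random `pct` fraction of (input, label) pairs."""
    if not items:
        raise ValueError("items must not be empty")
    if not 0 < pct <= 1:
        raise ValueError("pct must be in (0, 1]")
    rng = random.Random(seed)  # fixed seed so runs are reproducible
    k = max(1, round(len(items) * pct))
    sample = rng.sample(items, k)
    correct = sum(1 for x, label in sample if predict(x) == label)
    return correct / k
```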

Final comment

I always think the best part of an article is its comments. I love them! So I would really appreciate it if you commented with what you think about the article, whether it helped you, or any improvements or suggestions. Thanks!

Mariano Crosetti

NLP & Computer Vision Engineer 🧠 Only posts that add VALUE to GROWTH as SWE AI / ML 📈 ICPC LATAM Champion 🏆 (defeated Berkeley, Stanford & ETH) Ex-@Google