Deep Learning: Run a PyTorch+FastAI trained model on Android using PyTorch Mobile (also applicable to iOS)
Introduction
In this brief blog entry I will train a simple classifier using the FastAI library and explain how to run it in a simple demo Android app using PyTorch Mobile. Follow the links for guides and more information. If you are not a beginner in deep learning and FastAI, go straight to the [Let’s do it] section. If you don’t know anything about Android, that’s fine (I didn’t when I started writing this).
Preconcepts
Skip these subsections if you already know the topic in the subtitle.
PyTorch
A tensor library that runs on GPUs and CPUs. It is similar to TensorFlow but, in my opinion, better designed: it is more Pythonic, more flexible, and just as fast. It is mainly designed for deep learning (it has a wide collection of deep learning abstractions), but you can also use it as a GPU-optimized NumPy.
FastAI
FastAI is an amazing library that makes it incredibly easy to train deep learning PyTorch models while staying flexible and powerful, thanks to its well-designed layered API. It was created by Jeremy Howard, a distinguished research scientist (and very practical guy!) who wants to “make deep learning accessible for all”. I totally recommend the FastAI courses as some of the best deep learning courses out there. They also recently launched a fantastic book based on the library.
PyTorch Mobile
A runtime (still in beta when I wrote this post) that lets you easily deploy models on mobile while staying entirely within the PyTorch ecosystem. Here is a Hello World (this post is strongly based on it), and here is An Overview of the PyTorch Mobile Demo Apps.
Motivation
I wrote this step-by-step guide to show how to run FastAI-trained models natively on mobile. I found many other very helpful blog entries, but they all provided solutions based on cloud web services:
- “FastAi and Render: Quick and easy way to create and deploy Computer Vision Models”: creates a web app using a cloud provider called Render, which makes deploying web services trivial.
- “How to Make a Cross-Platform Image Classifying App With Flutter and Fastai”: uses Render to provide a web service and shows how to use Flutter, an SDK created by Google, to easily build an app that takes pictures from the camera/gallery and queries the web service.
- “Deploying Deep Learning Models On Web And Mobile (uses Heroku)”: uses a well-known combination: Flask (for creating the Python web service) + Heroku (hosting and deployment). It provides an app (for Android AND iOS, since it uses Expo, an open-source platform for making universal native apps for both), but it doesn’t go into how the app is built.
- “How to deploy Machine Learning models on Android and IOS with Telegram Bots”: a wonderful idea for making a model accessible: serve it through the Telegram Bot API! It explains how to do it (very easy!) and how to host/deploy the bot on Heroku.
But often we want to run the model directly on the device (such as a cellphone), for several reasons: latency or connectivity problems, preserving privacy, or interactive use cases (like predictions tightly integrated with the camera in real time). Some FastAI forum entries ask about this:
- “Inference locally on device, state of fast.ai/pytorch” (Mar 19)
- “How to deploy your model in production on android platform” (Mar 19)
- “Pytorch came out with a mobile version. Can we get a Fastai Mobile too?” (Jan 20)
The last post mentions that it is super easy to run a PyTorch model using PyTorch Mobile, but I wanted to write a guide explaining how. It is also important to port all the preprocessing used at test time to our native platform (which may require a non-trivial re-implementation). That said, in deep learning we usually don’t preprocess the data much (just crop or resize the image, plus normalization) and leave all of the feature extraction to the model, BUT sometimes (for instance in many FastAI application examples) the input image is not a natural image but the product of other signal processing (see this or this).
Requirements
To follow this tutorial you need:
- FastAI + PyTorch installed (the FastAI conda installation also installs PyTorch).
- Ideally, the item above in an environment with a GPU (for training). If you don’t have one, you can use the Colab or Paperspace cloud solutions. If you don’t feel like learning how to use them, you can run everything on CPU (on my 7th-gen i7 it took around 30 minutes and used ~7 GB of RAM). But I strongly suggest that you take the time to see how these free sites work!
- Android SDK + Android NDK. I’m going to use Android Studio to open the project (but it’s not strictly required).
Let’s do it
The code is available on github.com/mariano22/fastai-android-demo
Training the model
The training pipeline can be found in this notebook, which also explains each step. Here let’s review the most important points:
- Data loading and training are almost the same as in Chapter 5 of fastbook. A one-epoch fine-tuning cycle (just training the head) is enough: the pre-trained model was trained on ImageNet, so its features already make it easy to classify cats vs dogs.
- Don’t forget any piece of the pipeline: on the phone we start from a picture and want the prediction (the probabilities of cat vs dog).
When we call Learner.predict, it runs: Preprocessing → learn.model → Activation
So we must add the Softmax activation to the model we export. After that, we can follow the same steps as the PyTorch Mobile Hello World to trace and save the model.
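A minimal sketch of that export step, assuming a recent PyTorch. A tiny stand-in network plays the role of learn.model (with a real Learner you would wrap learn.model.cpu() instead), nn.Softmax(dim=1) is appended so the exported model outputs probabilities, and the result is traced with an example input and saved as my_model.pt. The 224×224 input size is an assumption, not necessarily the notebook's exact value.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for learn.model: maps a (N, 3, H, W) batch to
# (N, 2) raw scores (logits). With a real fastai Learner use learn.model.cpu().
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

# Append the activation Learner.predict would apply, so the exported
# model returns probabilities instead of logits.
model_to_export = nn.Sequential(backbone, nn.Softmax(dim=1))
model_to_export.eval()  # inference mode: no dropout, frozen batch-norm stats

# Trace with an example input shaped like what the app will feed
# (224x224 is an assumed training size).
example = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    traced = torch.jit.trace(model_to_export, example)

# Copy this file into the Android project's assets folder.
traced.save("my_model.pt")
```

Tracing records the operations executed for the example input, which is fine for a plain feed-forward CNN like this; a model with data-dependent control flow would need torch.jit.script instead. Newer PyTorch Mobile releases also offer a lite-interpreter format (saved with _save_for_lite_interpreter); which format the app can load depends on the PyTorch Android library version it uses.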
In the notebook we also look at how to display the preprocessing pipeline, so we can be sure we mimic it natively on the device (where we have neither DataLoaders nor any other FastAI machinery):
Preprocessing = dls.after_item → dls.before_batch → dls.after_batch
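To make the mimicking concrete, here is a dependency-free sketch of what those stages typically do for an ImageNet-pretrained vision Learner: ToTensor in after_item turns the H×W×3 image into channels-first layout, and after_batch applies IntToFloatTensor (divide by 255) followed by Normalize with the ImageNet statistics. This is an assumption about a typical pipeline, not a dump of the notebook's; print dls.after_item and dls.after_batch to check the actual transforms. Resize, which also lives in after_item, is assumed to have been done beforehand.

```python
# ImageNet per-channel statistics (RGB order), as used by fastai's
# imagenet_stats and torchvision-style normalization.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_chw(pixels):
    """after_item (ToTensor): H x W rows of (R, G, B) ints -> 3 x H x W."""
    height, width = len(pixels), len(pixels[0])
    return [[[pixels[y][x][c] for x in range(width)] for y in range(height)]
            for c in range(3)]


def int_to_float(chw):
    """after_batch (IntToFloatTensor): scale 0..255 ints to 0..1 floats."""
    return [[[v / 255.0 for v in row] for row in channel] for channel in chw]


def normalize(chw, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """after_batch (Normalize): per-channel (x - mean) / std."""
    return [[[(v - mean[c]) / std[c] for v in row] for row in channel]
            for c, channel in enumerate(chw)]


def preprocess(pixels):
    # Assumes the image was already resized to the training size.
    return normalize(int_to_float(to_chw(pixels)))
```

Plain nested lists keep the sketch runnable anywhere; what matters is the order and arithmetic of the steps, which the native code has to reproduce exactly (channel order, dividing by 255 before normalizing).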
Integrating on Android App
The main file in the Android project is MainActivity.java. The code is heavily commented and easy to understand.
- The onCreate function initializes everything. It is also where we load my_model.pt (stored in the assets folder).
- The button triggers the function addFileOnClick (the association is made in the activity_main.xml file), which opens the gallery. onActivityResult is where we take the image chosen by the user, call showPredictionCatOrDog, and show its result.
- showPredictionCatOrDog is the function where the (very simple) magic happens: it loads the image as a tensor, normalizes it with the ImageNet stats, runs the model on it, and computes the predicted class (the one with the maximum probability).
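The last step of showPredictionCatOrDog boils down to an argmax over the model's output. A minimal Python sketch of that logic, assuming the exported model already ends in Softmax (so no activation is applied here) and a hypothetical LABELS order that must match the training vocab (dls.vocab):

```python
from typing import Sequence, Tuple

# Hypothetical label order: index i must correspond to output i of the
# model, i.e. the order of the classes in the training vocab.
LABELS = ("dog", "cat")


def predict_label(probabilities: Sequence[float],
                  labels: Sequence[str] = LABELS) -> Tuple[str, float]:
    """Return (label, probability) of the most probable class.

    The exported model already ends in Softmax, so its outputs are
    probabilities and are used as-is.
    """
    if len(probabilities) != len(labels):
        raise ValueError("expected one probability per label")
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return labels[best], probabilities[best]
```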
Inference (module.forward) can be slow. In a real-world application it should not run on the main (UI) thread, or it will block the app.
Validation on cellphone
As I said, I’m a very mistrustful person, so I decided to include a piece of code for validating the model on the dataset on the cellphone itself. That’s why I included the whole dataset in the assets folder. Feel free to set RUN_VALIDATION to true to see the accuracy on some of the dataset images (the percentage of images used is limited by RUN_VALIDATION_PCT). Note that this on-device validation is not methodologically sound, because some of these images were used for training (a proper implementation would use only the images in the validation DataLoader).
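The on-device check amounts to computing accuracy over a fraction of labelled images. A minimal sketch of that logic, with hypothetical names: samples is a list of (path, label) pairs, predict is any function returning a label for a path, and pct plays the role of RUN_VALIDATION_PCT, assumed here to be a fraction in (0, 1]. How the app actually picks its images is not described in the post; this sketch samples randomly with a fixed seed so runs are reproducible.

```python
import random
from typing import Callable, Sequence, Tuple


def validate_subset(samples: Sequence[Tuple[str, str]],
                    predict: Callable[[str], str],
                    pct: float,
                    seed: int = 0) -> float:
    """Accuracy of `predict` on a random `pct` fraction of (path, label) pairs."""
    if not samples:
        raise ValueError("no samples to validate")
    if not 0 < pct <= 1:
        raise ValueError("pct must be in (0, 1]")
    rng = random.Random(seed)  # fixed seed -> same subset on every run
    k = max(1, round(len(samples) * pct))
    subset = rng.sample(list(samples), k)
    correct = sum(1 for path, label in subset if predict(path) == label)
    return correct / k
```

Restricting samples to the images of the validation DataLoader would turn this into the conceptually valid measurement described above.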
Final comment
I always think that the best part of an article is its comments. I love them. So I would 100% appreciate it if you commented on what you think about the article: whether it helped you, or if you have improvements or suggestions. Thanks!