Transfer Learning for Mobile ML

Can transfer learning help enable AI on every smartphone?

Machine learning and deep learning are enabling amazing applications all around us. However, traditional models of deep neural networks are designed to take advantage of the vast memory and raw computing power of centralized servers.

In this article, we examine how the concept of ‘transfer learning’ may help bring efficient deep learning to the edge, and more specifically, into your smartphone.

What is ‘Edge Computing’ and how is it relevant to mobile machine learning?

Our modern world runs on the cloud. A vast infrastructure of liquid-cooled computing machinery humming away 24/7 inside a cold and robotic data center in a remote wilderness — that’s the engine powering everything from economics to entertainment.

In a cloud computing architecture, data is processed far away from where it is generated. Edge computing, by contrast, combines the generation, collection, and analysis of data at the point of origin, rather than transmitting it to a data center or cloud. Following the explosive growth of smartphones and IoT-enabled devices, edge computing is now expanding into the realm of machine learning. The natural question, therefore, is: “What kind of novel applications can a smartphone enable through embedded machine learning?”

Potential applications of mobile machine learning

A modern smartphone has a dizzying array of sensors and actuators. As such, it’s easy to imagine some wonderful ‘intelligent functions’ that this palm-sized piece of hardware could perform. Here are a few fun and interesting ones:

  • Dog breed classifier (point your phone at a dog walking by and identify its breed instantly)
  • Game predictor (point your phone at a chessboard and delight the player with the prediction that their chance of winning is high). Predicting the winner from a board position is the sort of task handled by the now-famous AlphaGo program from Google’s DeepMind, although AlphaGo plays Go rather than chess.
  • On-the-spot preliminary medical diagnosis of tumors and scars before a full analysis can be performed: particularly useful in remote locations in developing nations where Internet connectivity and cloud infrastructure are hard to come by. Read this article for more discussion about this type of diagnosis.
  • Smart Poet: Generate simple poems and short stories based on spoken keywords or seed sentences/themes
  • Personal health adviser: A machine learning-based predictor of health issues based on the steps you took in a day, the exercise you did, and the basic health information you punch into the software. It’s far more reassuring to know this kind of personal data is being analyzed right inside your phone and not being transmitted to and stored on a remote server somewhere beyond your control.
  • Personalized finance adviser: Financial analysis tools employing machine learning and statistical modeling, based on your buying habits and bank/stock/bond transactions. Again, you don’t want to transmit this kind of sensitive data out of your phone if possible.

Apps like these already exist everywhere and are often free to use. But for the most part, they reside in the cloud and leverage shared computing and storage resources to analyze data and build models for prediction.

However, it will be a grand success for mobile platforms (and a boost to consumer confidence) when these kinds of predictive and prescriptive analyses of sensitive data can be performed on the edge rather than in the cloud.

Why is deep learning so expensive on mobile platforms?

At its core, deep learning is the technique of building neural networks with multiple hidden layers and training them on a vast amount of (already labeled) input data. Its goal is to optimize the weights of these complex networks so that they approximate a prediction function, which can then be applied to new, unseen inputs to predict outcomes. Recent advances and breathtaking successes of deep learning are often attributed to three main factors:

  • Advances in algorithms
  • Availability of huge amounts of raw data (or training data, as we call it)
  • Quantum leap in computing power and specialized processors (read: GPUs, or Graphics Processing Units) that efficiently handle the vectorized matrix computations employed by deep learning networks

It isn’t difficult to see that clever algorithms can be easily codified and packaged inside smartphone software. So that’s not the bottleneck. But how do we harness the power of huge amounts of data and raw computing resources?

Anybody can quickly take a huge number of photos with today’s advanced semi-automated phone cameras, or record a lot of spoken words. In fact, for class identification problems (breed of dog, species of flower), a medium-resolution grayscale photo is often more than adequate for the machine learning task (and can be more robust from an algorithmic point of view). So raw data may not be the critical bottleneck for such mobile machine learning software after all.

Here is an example image of a convolutional neural network (CNN), the central ‘brain’ of such an intelligent model. It performs image recognition tasks like the one shown in the figure, i.e., recognizing the class of the input image: whether it is a car, truck, airplane, ship, or horse.

We don’t need to delve into the details of all the elements of this model. For enthusiastic readers who want to understand the nuts and bolts of such a CNN model, there’s a plethora of resources, such as (a) the Stanford CS231n class notes, (b) this excellent article on Medium, (c) the great (and free) course by Andrew Ng on Coursera, or (d) the great video explanation by Brandon Rohrer of Facebook.

The point is that a CNN requires a large number of power-intensive matrix computations (and many passes/rounds of them, for that matter) across a number of layers (consisting of individual neurons or computing nodes) before it becomes good and stable at recognizing the correct classes from a set of input images.

This is a computationally intensive process for a mobile processor. Standard matrix multiplication (or inversion) of n × n matrices has a complexity of O(n^3), so doubling the matrix size requires roughly 8 times as much computation. One can see that the problem quickly grows to an intractable size when thinking about training on an average mobile hardware platform.
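To make the cubic growth concrete, here is a minimal Python sketch that counts the scalar multiply-add operations performed by the textbook triple-loop matrix multiplication. It is only an illustration: real libraries use blocked, vectorized kernels (and asymptotically faster algorithms exist), but the naive count shows why doubling n means about 8 times the work.

```python
def naive_matmul(a, b):
    """Multiply two n x n matrices (lists of lists) with the textbook triple loop.

    Returns the product and the number of scalar multiply-add operations performed.
    """
    n = len(a)
    c = [[0.0] * n for _ in range(n)]
    ops = 0
    for i in range(n):
        for j in range(n):
            acc = 0.0
            for k in range(n):
                acc += a[i][k] * b[k][j]
                ops += 1
            c[i][j] = acc
    return c, ops


def identity(n):
    """n x n identity matrix, used here only as convenient input."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


for n in (16, 32, 64):
    _, ops = naive_matmul(identity(n), identity(n))
    print(n, ops)  # 16 4096, then 32 32768, then 64 262144: each doubling costs 8x
```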

Furthermore, the latest deep learning algorithms may demand sizable amounts of memory (RAM) during these training passes through the hidden layers, for example to keep every layer’s intermediate outputs around for the backward pass instead of recomputing them, which speeds up the whole compute process. If the cumulative size of these memory chunks grows beyond the physical limit of the smartphone RAM, the whole process may collapse. That’s not a pretty scenario for the user, who might have spent considerable time collecting and labeling the input images/data for training the deep learning model.
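A back-of-the-envelope estimate shows how quickly this memory adds up. The sketch below counts only the layer outputs (activations) kept for the backward pass, stored as 4-byte float32 values; weights, gradients, optimizer state, and framework overhead are ignored, so real usage is higher. The layer shapes are hypothetical, loosely modeled on the first layers of a VGG-style network with 224 × 224 RGB input.

```python
def activation_memory_mb(layer_shapes, batch_size, bytes_per_value=4):
    """Rough memory (in MB) needed to keep every listed layer's output for backprop.

    layer_shapes: list of (height, width, channels) tuples, one per layer output.
    Ignores weights, gradients, optimizer state and framework overhead.
    """
    values_per_example = sum(h * w * c for h, w, c in layer_shapes)
    return values_per_example * batch_size * bytes_per_value / (1024 ** 2)


# Hypothetical early layers of a VGG-style network on a 224 x 224 RGB image
shapes = [(224, 224, 64), (224, 224, 64), (112, 112, 64),
          (112, 112, 128), (112, 112, 128)]

for batch in (1, 8, 32):
    print(batch, round(activation_memory_mb(shapes, batch), 1))  # 1 39.8, then 8 318.5, then 32 1274.0
```

Memory grows linearly with the batch size, which is one reason on-device training typically has to use small batches.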

What is “Transfer Learning” and how can it help?

Transfer learning is nothing but a cute little idea to short-circuit a sizable portion of the compute-intensive training process for a CNN and use previously trained optimized models for your specific task. The idea is illustrated in the following figure:

Basically, the new model reuses all the complex convolutional and associated pooling/averaging layers, together with their trained weights, as a ready-made feature extractor whose inner details are abstracted away; only the final layers are replaced with new ones designed for the task at hand.
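The bookkeeping behind this can be sketched in a few lines of Python. This is a hypothetical, framework-free illustration (the layer names and parameter counts are made up), loosely mirroring the `trainable` flag that Keras layers expose: the pre-trained body is kept and frozen, its old classifier is dropped, and a new, trainable classifier sized for the new task is appended.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    name: str
    num_params: int
    trainable: bool = True


def dense_params(n_in, n_out):
    """Weights plus biases of a fully-connected layer."""
    return n_in * n_out + n_out


def prepare_for_transfer(pretrained_layers, feature_dim, num_new_classes):
    """Keep and freeze the pre-trained body, drop its old classifier (the last
    layer), and append a new trainable classifier for the new classes."""
    body = [Layer(layer.name, layer.num_params, trainable=False)
            for layer in pretrained_layers[:-1]]
    head = Layer("new_classifier", dense_params(feature_dim, num_new_classes))
    return body + [head]


def count_params(layers, trainable):
    return sum(layer.num_params for layer in layers if layer.trainable == trainable)


# Hypothetical pre-trained network: convolutional blocks plus a 1000-class classifier
pretrained = [
    Layer("conv_block_1", 40_000),
    Layer("conv_block_2", 300_000),
    Layer("conv_block_3", 2_000_000),
    Layer("old_classifier", dense_params(512, 1000)),
]
model = prepare_for_transfer(pretrained, feature_dim=512, num_new_classes=5)
print(count_params(model, trainable=False), count_params(model, trainable=True))  # 2340000 2565
```

Only the 2,565 parameters of the new head would be updated during training; the millions in the frozen body are reused as they are.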

However, this raises the obvious question: “What if the output classes of the original model were quite different from the ones I want to predict in my task?” Although it may not be obvious, the answer turns out to be a hopeful one. It’s been shown that, in many cases, image classification tasks achieve remarkably high accuracy with this transfer learning approach (i.e., using a pre-trained model), even when the final classes to be predicted are completely different from the original classes the model was trained on.

Of course, to achieve this, one may have to add a fully-connected layer followed by a task-specific softmax layer (a special layer at the output of the neural network that squashes the output numbers from the previous layers into probabilities corresponding to the output classes) and train these new layers on her own data set for a small number of epochs.

Still, this operation is supposed to be much less compute intensive than fully training the deep network.
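Here is a minimal sketch, in plain Python, of both pieces just described: a numerically stable softmax, and a new softmax layer trained for a few epochs on top of a frozen feature extractor. The `frozen_features` function is a hypothetical stand-in for the pre-trained convolutional body (a real one would be a pre-trained CNN). Because it never changes, its outputs are computed once and cached, and only the new weights `W` and biases `b` are updated, using stochastic gradient descent on the cross-entropy loss. The toy data set is made up for illustration.

```python
import math


def softmax(logits):
    """Squash raw scores into probabilities that sum to 1.
    Subtracting the max first avoids overflow in exp() without changing the result."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def frozen_features(x):
    """Hypothetical stand-in for a pre-trained CNN body: a fixed function that is
    never updated. Here it just returns the mean and the max of the input."""
    return [sum(x) / len(x), max(x)]


def scores(W, b, f):
    return [sum(w_i * f_i for w_i, f_i in zip(W[k], f)) + b[k] for k in range(len(b))]


def train_head(examples, num_classes, epochs=20, lr=0.5):
    """Train only a new softmax layer (W, b) on top of the frozen features."""
    cached = [(frozen_features(x), y) for x, y in examples]  # body is frozen: compute once
    dim = len(cached[0][0])
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        for f, y in cached:
            p = softmax(scores(W, b, f))
            for k in range(num_classes):
                g = p[k] - (1.0 if k == y else 0.0)  # d(cross-entropy)/d(logit_k)
                for i in range(dim):
                    W[k][i] -= lr * g * f[i]
                b[k] -= lr * g
    return W, b


def predict(W, b, x):
    s = scores(W, b, frozen_features(x))
    return max(range(len(s)), key=lambda k: s[k])


print([round(p, 3) for p in softmax([2.0, 1.0, 0.1])])  # [0.659, 0.242, 0.099]

# Toy "images" (flattened pixel intensities): class 0 is dark, class 1 is bright
data = [([0.0, 0.1, 0.2], 0), ([0.9, 1.0, 0.8], 1),
        ([0.1, 0.1, 0.1], 0), ([0.9, 0.9, 0.9], 1),
        ([0.2, 0.0, 0.1], 0), ([1.0, 0.8, 0.9], 1)]
W, b = train_head(data, num_classes=2)
print([predict(W, b, x) for x, _ in data])
```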

In addition, this saves the mobile machine learning practitioner countless hours of hyperparameter tuning for all those hidden layers (filter size, stride, pooling layer strategy, dropout probability, learning rate, etc.), and lets them benefit from the sound engineering knowledge and established best practices built into the pre-trained model.

Although there’s no rigorous theoretical proof of why this approach works, the success of transfer learning is widely attributed to the hierarchical nature of feature extraction at the various stages of hidden layers in a CNN.

In practice, the first few layers of a CNN tend to extract simple features from the input images, like straight or curved edges, while the deeper layers build more intricate features, such as body parts of an animal or complicated automobile shapes.
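As a rough illustration of what such an early-layer feature detector computes, the sketch below slides a hand-made 3×3 vertical-edge filter (the classic Sobel kernel) over a tiny synthetic image. It is a toy under clear assumptions: trained first-layer filters are learned from data rather than hand-written, although they often end up resembling edge detectors like this one.

```python
def conv2d_valid(image, kernel):
    """Slide kernel over image (stride 1, no padding) and sum the element-wise
    products at each position. Like most CNN libraries, this is technically
    cross-correlation (the kernel is not flipped)."""
    kh, kw = len(kernel), len(kernel[0])
    out_h, out_w = len(image) - kh + 1, len(image[0]) - kw + 1
    return [[sum(image[i + u][j + v] * kernel[u][v]
                 for u in range(kh) for v in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]


# Sobel kernel: responds to left-to-right increases in brightness (vertical edges)
sobel_x = [[-1, 0, 1],
           [-2, 0, 2],
           [-1, 0, 1]]

# 5 x 6 toy grayscale image: dark left half, bright right half
image = [[0, 0, 0, 1, 1, 1] for _ in range(5)]

for row in conv2d_valid(image, sobel_x):
    print(row)  # each row is [0, 4, 4, 0]: strong response only around the edge
```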

In the end, most of these features turn out to be fairly generic, i.e., similar features are likely to appear in many common images.

Consequently, even if the original model was trained on images of cats and dogs, its learned feature-extracting layers can often be reused to tell the difference between a tiger and a lion, even though the model was never directly trained on those animals.

Watch this video to understand more about this hierarchical feature extraction concept:

Related resources: Here is a Medium article on an example of transfer learning using Keras. And here is a more in-depth tutorial on transfer learning. You can also check this wonderful blog post from Sebastian Ruder for an honest discussion of the possibilities and limitations of this approach.

Pre-built models for transfer learning

It turns out there are plenty of famous models that we can use for transfer learning, particularly for image classification tasks:

AlexNet: This network, proposed by Alex Krizhevsky (and winner of the ImageNet LSVRC-2012 competition), popularized the use of the ReLU (Rectified Linear Unit) for the non-linear part, instead of the tanh(x) or sigmoid functions that were the earlier standard for traditional neural networks. It is composed of 5 convolutional layers followed by 3 fully connected layers.
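One commonly cited reason for the switch is that sigmoid and tanh saturate: for large inputs their derivatives shrink toward zero, which slows gradient-based training, while ReLU’s derivative stays at 1 for any positive input. A small Python sketch comparing the derivatives:

```python
import math


def relu(x):
    return max(0.0, x)


def sigmoid(x):
    # Numerically stable for large |x|
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def relu_grad(x):
    # Derivative of max(0, x); by convention taken as 0 at x = 0
    return 1.0 if x > 0 else 0.0


def sigmoid_grad(x):
    s = sigmoid(x)
    return s * (1.0 - s)


def tanh_grad(x):
    return 1.0 - math.tanh(x) ** 2


for x in (0.5, 2.0, 5.0, 10.0):
    print(f"x={x:5}: relu'={relu_grad(x):.1f}  "
          f"sigmoid'={sigmoid_grad(x):.6f}  tanh'={tanh_grad(x):.6f}")
```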

In total, AlexNet has 62.3 million parameters and needs 1.1 billion computations in a forward pass. The convolution layers, which account for only 6% of all the parameters, consume 95% of the computation. It’s surely a good thing that we can download the network weights and potentially start optimizing a mobile machine learning model on top of it by adding only a softmax layer.
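The gap between parameters and computation comes from weight sharing: a convolutional filter is reused at every spatial position, while a fully-connected weight is used only once per forward pass. The sketch below applies the standard counting formulas to two layers, using the commonly cited AlexNet dimensions (first convolution: 96 filters of 11×11×3 producing a 55×55 output; first fully-connected layer: 6×6×256 = 9216 inputs to 4096 outputs). These dimensions are an assumption stated here rather than taken from this article, and multiply-accumulates (MACs) are only one way to count "computation".

```python
def conv_params(k, c_in, c_out):
    """k x k x c_in weights per filter, plus one bias per filter."""
    return k * k * c_in * c_out + c_out


def conv_macs(k, c_in, c_out, out_h, out_w):
    """Every output value costs k * k * c_in multiply-accumulates."""
    return k * k * c_in * c_out * out_h * out_w


def dense_params(n_in, n_out):
    return n_in * n_out + n_out


def dense_macs(n_in, n_out):
    return n_in * n_out


print("conv1 params:", conv_params(11, 3, 96))           # 34,944
print("conv1 MACs:  ", conv_macs(11, 3, 96, 55, 55))     # 105,415,200
print("fc6 params:  ", dense_params(6 * 6 * 256, 4096))  # 37,752,832
print("fc6 MACs:    ", dense_macs(6 * 6 * 256, 4096))    # 37,748,736
```

Under these assumptions, the first convolution has roughly a thousand times fewer parameters than the first fully-connected layer, yet needs almost three times as many multiply-accumulates.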

VGG16/VGG19: These are very deep CNNs that researchers entered in the 2014 ImageNet competition, where they took first place in localization and second place in classification. The idea is to use small 3×3 convolutional filter windows along with 16–19 layers of depth. The Keras model for VGG16 is available here. Here is the original paper describing the model architecture in detail.
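The reasoning behind small filters can be checked with a little arithmetic, along the lines of the argument made in the VGG paper: two stacked 3×3 convolutions (stride 1) see the same 5×5 input region as a single 5×5 convolution, and three see a 7×7 region, yet they need fewer weights and add extra non-linearities in between. A sketch, assuming C input and C output channels per layer and ignoring biases (C = 256 is an arbitrary example):

```python
def stacked_receptive_field(num_layers, k=3):
    """Input region seen by num_layers stacked k x k, stride-1 convolutions."""
    return 1 + num_layers * (k - 1)


def conv_weights(k, channels, num_layers=1):
    """Weights (no biases) of num_layers k x k convolutions, C channels in and out."""
    return num_layers * k * k * channels * channels


C = 256
print(stacked_receptive_field(2), conv_weights(3, C, 2), conv_weights(5, C))  # 5 1179648 1638400
print(stacked_receptive_field(3), conv_weights(3, C, 3), conv_weights(7, C))  # 7 1769472 3211264
```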

GoogLeNet/Inception: While VGG achieves phenomenal accuracy on the ImageNet dataset, deploying it even on modestly sized GPUs is a problem because of its huge computational requirements, both in terms of memory and time.

It becomes inefficient because of the large width of its convolutional layers. So GoogLeNet introduced the so-called inception module, which approximates a sparse CNN with a normal dense construction (shown in the figure). Since only a small number of neurons are effective, the width (number) of convolutional filters of a particular kernel size is kept small.
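Inception modules keep the width small partly by placing cheap 1×1 convolutions in front of the larger 3×3 and 5×5 ones, reducing the number of channels first. The sketch below counts multiply-accumulates for a single 5×5 branch with and without such a reduction; the layer sizes (a 28×28×192 input, 32 output filters, reduction to 16 channels) are illustrative assumptions, not values from this article.

```python
def conv_macs(k, c_in, c_out, out_h, out_w):
    """Multiply-accumulates for a k x k convolution producing an out_h x out_w x c_out output."""
    return k * k * c_in * c_out * out_h * out_w


H = W = 28            # spatial size, kept constant by 'same' padding and stride 1
C_IN, C_OUT = 192, 32
C_REDUCED = 16        # hypothetical bottleneck width

direct = conv_macs(5, C_IN, C_OUT, H, W)
bottleneck = conv_macs(1, C_IN, C_REDUCED, H, W) + conv_macs(5, C_REDUCED, C_OUT, H, W)
print(direct, bottleneck, round(direct / bottleneck, 1))  # 120422400 12443648 9.7
```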

Also, it uses convolutions of different sizes (5×5, 3×3, 1×1) to capture details at varied scales. Another change that GoogLeNet made was to replace the stack of fully-connected layers at the end with a simple global average pooling (followed by a single linear classifier layer), which averages the values of each channel across the 2D feature map after the last convolutional layer.

This drastically reduces the total number of parameters. Here’s an article showing how to write simple Python code that downloads the latest checkpoints/weights of the Inception model and uses them for image classification with minimal computational load.
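Global average pooling itself takes only a few lines. The sketch below also compares the parameter count of a final 1000-class layer placed after flattening a 7×7×1024 feature map versus after global average pooling; these shapes are chosen to resemble GoogLeNet’s final feature map and are used here only for illustration.

```python
def global_average_pool(feature_map):
    """Average each channel over all spatial positions.

    feature_map is indexed as feature_map[row][col][channel]; the result has one
    value per channel, no matter how large the spatial grid is.
    """
    h, w, c = len(feature_map), len(feature_map[0]), len(feature_map[0][0])
    return [sum(feature_map[i][j][k] for i in range(h) for j in range(w)) / (h * w)
            for k in range(c)]


def dense_params(n_in, n_out):
    return n_in * n_out + n_out


H, W, C, NUM_CLASSES = 7, 7, 1024, 1000
print("flatten + dense:", dense_params(H * W * C, NUM_CLASSES))  # 50177000
print("GAP + dense:    ", dense_params(C, NUM_CLASSES))          # 1025000

tiny = [[[1.0, 2.0], [3.0, 4.0]],
        [[5.0, 6.0], [7.0, 8.0]]]   # 2 x 2 grid, 2 channels
print(global_average_pool(tiny))    # [4.0, 5.0]
```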

ResNet: At ILSVRC 2015, the so-called Residual Neural Network (ResNet) by Kaiming He et al. introduced a novel architecture with “skip connections” and heavy use of batch normalization. These skip connections are related to the gated shortcut connections of highway networks and to the gating units used in recurrent neural networks (RNNs) such as LSTMs, but in ResNet they are plain identity shortcuts without any gates.

Thanks to this technique, they were able to train a network with 152 layers while still having lower complexity than VGGNet.
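The core of a skip connection is simply adding a block’s input to its output, y = ReLU(F(x) + x), so the block only has to learn the residual F(x). The minimal sketch below uses small dense weight matrices as a stand-in for ResNet’s convolution layers and omits batch normalization for brevity. It shows one appealing property: with all-zero weights, the block passes a non-negative input through unchanged, so adding such a block need not make the network worse.

```python
def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def relu(v):
    return [max(0.0, x) for x in v]


def residual_block(x, W1, W2):
    """y = ReLU(F(x) + x) with F(x) = W2 @ ReLU(W1 @ x).

    The shortcut is a plain identity: x is added back with no weights and no gate.
    """
    f = matvec(W2, relu(matvec(W1, x)))
    return relu([f_i + x_i for f_i, x_i in zip(f, x)])


x = [1.0, 2.0]
zeros = [[0.0, 0.0], [0.0, 0.0]]
eye = [[1.0, 0.0], [0.0, 1.0]]
print(residual_block(x, zeros, zeros))  # [1.0, 2.0]: identity when F(x) = 0
print(residual_block(x, eye, eye))      # [2.0, 4.0]: F(x) = x, plus the shortcut
```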

Here is a good summary discussion of these architectures.

Summary and Conclusion

The potential for applying deep learning methods to mobile machine learning can unleash a torrent of smart apps and disrupt a wide range of industries. App developers and data scientists will be free of the burden of maintaining a dedicated high-speed data connection for transmitting their data to a remote cloud resource to do any meaningful analysis.

The full power of advanced machine learning will come to the ‘edge’. This will be particularly beneficial to the billions of mobile phone users in developing nations where high-speed data connections are patchy at best. Therefore, this kind of edge computing could revolutionize transportation, medicine and healthcare, personal finance, communication, and travel.

However, without finding a solution for reducing the burden of the compute-intensive model training process, this dream may not be realized to its full extent.

In this article, we discussed one such ‘short-circuiting’ idea which aims to reuse the already optimized models for the specific task at hand while significantly reducing the compute and storage requirements on mobile hardware. We also looked at a few well-known examples of pre-trained models for image classification tasks.

I hope that some enthusiastic readers will try these models for transfer learning on their next mobile deep learning app and solve an amazing problem or two to bring the full power of artificial intelligence to the realm of edge computing.

If you have any questions or ideas to share, please contact the author at tirthajyoti[AT] You can also check the author’s GitHub repositories for other fun code snippets in Python, R, or MATLAB and for machine learning resources. If you are, like me, passionate about machine learning/data science, please feel free to add me on LinkedIn or follow me on Twitter.
