Introduction to Federated Learning

Enabling on-device training, model personalization, and more

There are over 5 billion mobile device users around the world. These users generate massive amounts of data through cameras, microphones, and other sensors such as accelerometers, and this data can, in turn, be used to build intelligent applications. Traditionally, such data is collected in data centers, where it is used to train machine/deep learning models that power those applications.

However, due to data privacy concerns and bandwidth limitations, common centralized learning techniques aren’t appropriate: users are much less likely to share their data, so the data will only be available on their devices.

This is where federated learning comes into play. In Google’s research paper titled Communication-Efficient Learning of Deep Networks from Decentralized Data [1], the researchers describe federated learning, at a high level, as an approach that leaves the training data distributed on the mobile devices and learns a shared model by aggregating locally computed updates.

The outline of the article is as follows:

  • Data is Available Everywhere
  • What is Federated Learning?
  • Steps for Federated Learning
  • Properties of Problems Solved using Federated Learning
  • Federated Averaging Algorithm

Let’s get started.

Data is Available Everywhere

In the data era, data is a primary requirement for building intelligent applications. So where, and how, can we get it? The good news is that data is available everywhere; the bad news is that much of it is inaccessible.

Mobile, embedded, and sensor-laden IoT devices are major sources of data nowadays. Because people use them frequently and keep them at hand all the time, mobile devices are the primary source of data.

According to a recent GSMA Mobile Economy report, the number of mobile users reached 5.2 billion in 2019 and was expected to increase to 5.8 billion by 2025. Of those 5.2 billion mobile users, 3.8 billion were connected to the internet.

This has a couple of implications: internet connectivity signals increased data generation, and connected users will need intelligent applications for better experiences. The report also shows that, with the growth of smart buildings, there were 12 billion IoT devices in 2019, a number expected to reach 24.6 billion by 2025.

According to a Pew Research Center report, most of these mobile devices are smartphones. The following figure shows the percentage of adults who use smartphones in a number of countries:

The existence of so many data generators means data is indeed available everywhere. Each click by a mobile user adds more data about what interests that user, and this data can be used to build intelligent applications with better user experiences.

Federated learning makes it possible to use this private data without compromising users’ privacy.

What is Federated Learning?

Federated learning is a new type of learning introduced by Google in 2016 in a paper titled Communication-Efficient Learning of Deep Networks from Decentralized Data [1]. Building on the definition given at the beginning of the article, let’s explain federated learning in more detail.

Federated learning avoids centralized data collection and model training. In a traditional machine learning pipeline, data is collected from different sources (e.g. mobile devices) and stored in a central location (i.e. a data center). Once all the data is available centrally, a single machine learning model is trained on it. Because the data must be moved from the users’ devices to a central server to build and train the model, this approach is called centralized learning.

Federated learning, on the other hand, trains machine learning models on the mobile devices themselves (referred to as clients) and then combines the results of all these models into a single model that resides on a server. Each model is trained on its device using that device’s ground-truth data, and only the trained model is shared with the server. This way, users’ data is leveraged to build machine/deep learning models while the data itself stays private.
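
To make the contrast concrete, here is a minimal sketch under simplifying assumptions: a one-parameter linear model y ≈ w * x, made-up data for three hypothetical devices, and a single full-batch gradient step per device (the FedSGD baseline discussed in [1]). It shows that when each device trains locally and uploads only its new weight, and the server averages those weights in proportion to each device’s number of examples, the result matches one centralized step on the pooled data. The function name mse_gradient and all data values are illustrative.

```python
"""Centralized vs. federated: one training step on the same data.

Illustrative assumptions: a one-parameter model y ≈ w * x, mean squared error,
and each client taking a single full-batch gradient step (FedSGD in [1]).
"""


def mse_gradient(w, data):
    """Gradient of the mean squared error of y ≈ w * x over (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)


# Hypothetical private datasets held by three devices.
clients = [
    [(1.0, 2.0), (2.0, 4.1)],
    [(3.0, 5.9)],
    [(4.0, 8.2), (5.0, 9.8), (6.0, 12.1)],
]
w, lr = 0.0, 0.01

# Centralized learning: raw data is pooled on the server before training.
pooled = [pair for data in clients for pair in data]
w_central = w - lr * mse_gradient(w, pooled)

# Federated learning: each device trains locally and uploads only its new weight;
# the server averages those weights, weighting each device by its example count.
local_weights = [w - lr * mse_gradient(w, data) for data in clients]
n_total = sum(len(data) for data in clients)
w_federated = sum(len(data) / n_total * w_k for data, w_k in zip(clients, local_weights))

print(w_central, w_federated)  # equal up to floating-point rounding
```

The equality holds only for this special case of one local gradient step; with several local steps per device, the federated result generally differs from centralized training.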

In this way, federated learning benefits from users’ data without compromising their privacy. The raw data stays on the users’ devices and is never moved to a data center; instead, a model is trained on this data, and that model is sent to the server.

With federated learning, the user’s data is not uploaded to the server, so there is no DIRECT access to it; however, it is still possible for the data to be exposed. Privacy breaches in federated learning will be discussed in a later post.

Steps for Federated Learning

Federated learning in theory is fairly simple and can be summarized in the following steps:

  1. A generic (shared) model is trained server-side.
  2. A number of clients are selected for training on top of the generic model.
  3. The selected clients download the model.
  4. The generic model is trained on the devices, leveraging the users’ private data, using an optimization algorithm such as stochastic gradient descent.
  5. A summary of the changes made to the model (e.g. the updated weights of the trained neural network) is sent to the server.
  6. The server aggregates the updates from all devices to improve the shared model. Update aggregation is done using a new algorithm called the federated averaging algorithm.
  7. The process of sending the generic model to mobile devices and updating it according to the received summaries of changes is repeated.
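
The seven steps above can be simulated in a few lines of plain Python. This is a toy sketch under stated assumptions: the "model" is a single weight w for y ≈ w * x, the server’s generic model is simply initialized rather than pre-trained, local training is a few steps of gradient descent, the datasets are made up, and step 6 uses a plain unweighted average of the uploaded changes. The function names local_train and run_rounds are hypothetical.

```python
"""Toy simulation of the seven federated learning steps (illustrative only).

The "model" is one weight w for y ≈ w * x; every name here is hypothetical.
"""
import random


def local_train(w, data, lr=0.01, steps=5):
    """Step 4: refine the downloaded model with gradient descent on local data."""
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


def run_rounds(device_data, rounds=20, clients_per_round=3, seed=0):
    rng = random.Random(seed)
    w_global = 0.0                                   # Step 1: generic model (just initialized here)
    for _ in range(rounds):                          # Step 7: repeat the process
        # Step 2: select a subset of clients for this round.
        selected = rng.sample(range(len(device_data)), clients_per_round)
        changes = []
        for k in selected:
            w_local = w_global                       # Step 3: client downloads the model
            w_local = local_train(w_local, device_data[k])  # Step 4: local training
            changes.append(w_local - w_global)       # Step 5: upload only the change
        # Step 6: aggregate the changes (a plain average here) into the shared model.
        w_global += sum(changes) / len(changes)
    return w_global


# Hypothetical private datasets, each roughly following y = 2x.
devices = [
    [(1.0, 2.1), (2.0, 3.9)],
    [(3.0, 6.2)],
    [(1.5, 2.9), (2.5, 5.1)],
    [(4.0, 7.9)],
    [(0.5, 1.1)],
]
print(round(run_rounds(devices), 3))
```

Only the change in w crosses from a device to the server in this simulation; the (x, y) pairs stay in each device’s list.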

The previous steps are summarized in the next figure, based on a blog post by Google research scientists:

A. Your phone personalizes the model locally, based on your usage.

B. Many users’ updates are aggregated.

C. A change to the shared model is made according to the aggregated updates, after which the procedure is repeated.

You can also watch this video from Google that summarizes the definition of federated learning.

Properties of Problems Solved using Federated Learning

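According to the Google research paper [1], ideal problems for federated learning have 3 properties. The first property is that training on real-world data from mobile devices provides a distinct advantage over training on proxy data that is generally available in the data center.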

A single machine learning model created at the server uses data from many different users to produce one generic model. Because users vary in how they use their mobile devices, this model must be generic enough to cope with that variety.

Unfortunately, a generic model does little to enhance the user experience; a customized model that seems created specifically for the device does. Federated learning makes such personalization possible and can make the device feel as if it were built just for the user.

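The second property is that this data is privacy-sensitive or large in size (compared to the size of the model), so it is preferable not to log it to the data center purely for model training.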

It isn’t practical to ask users to upload large amounts of data to create a generic model at the server, as this adds costs for them. Users are also likely to refuse to upload private data to help build a model, especially for applications that require sensitive information. Where data is private or large in scale, federated learning is a better option than centralized learning.

This large data size introduces a new challenge for federated learning: mobile devices have limited resources, and working with large amounts of data takes time and, therefore, more power.

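To work around this issue, Google’s on-device training uses a miniature version of TensorFlow. (TensorFlow Federated, a separate open-source framework from Google, is aimed at experimenting with and simulating federated learning rather than serving as an on-device runtime.) In addition, training only takes place when the device is IDLE, plugged into a charger, and on a free wireless connection.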
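
The scheduling rule above can be sketched as a simple gate that a client checks before starting a training session. The DeviceState fields and the eligible_for_training function are hypothetical names chosen for illustration; they are not part of TensorFlow, TensorFlow Federated, or any real client library, and a real device would read these conditions from platform APIs.

```python
from dataclasses import dataclass


@dataclass
class DeviceState:
    """Hypothetical snapshot of the conditions a client checks before training."""
    is_idle: bool
    is_charging: bool
    on_unmetered_wifi: bool


def eligible_for_training(state: DeviceState) -> bool:
    """Train only when the device is idle, plugged in, and on a free wireless network."""
    return state.is_idle and state.is_charging and state.on_unmetered_wifi


print(eligible_for_training(DeviceState(True, True, True)))   # True
print(eligible_for_training(DeviceState(True, False, True)))  # False
```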

The third property is that, for supervised tasks, labels on the data can be inferred naturally from user interaction.

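According to the Google research paper [1], two examples of problems that fit these 3 properties are image classification (e.g. predicting which photos are most likely to be viewed multiple times or shared) and language models (e.g. improving voice recognition and text entry on touch-screen keyboards).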

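In supervised learning, a model is trained on labeled data, so the label of every training sample must be known. For federated learning to work with supervised learning, labels for the user’s private data must therefore be available on the device. The Google research paper explains that, for the problems above, labels come directly from the data: entered text is self-labeled for learning a language model, and photo labels can be defined by natural user interaction with a photo app (which photos are deleted, shared, or viewed).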
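
As a hypothetical illustration of labels that need no manual annotation, the sketch below turns text a user typed into (context, next word) training pairs for a next-word-prediction model: the label is simply the word the user actually typed next. The function name next_word_examples and the fixed two-word context are assumptions made for this example, not details from [1].

```python
def next_word_examples(typed_text: str, context_size: int = 2):
    """Turn text a user typed into (context, next word) training pairs.

    The label is the word the user actually typed next, so the data labels itself.
    """
    words = typed_text.split()
    return [
        (tuple(words[i - context_size:i]), words[i])
        for i in range(context_size, len(words))
    ]


for context, label in next_word_examples("see you at the station"):
    print(context, "->", label)
```

Running it on "see you at the station" yields three pairs, ending with ('at', 'the') -> station. On a real device these pairs would stay local and feed the on-device training step.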

Federated Averaging Algorithm

As discussed above, the server aggregates the changes (i.e. weights) received from the devices. How is this aggregation done? With a new algorithm called the federated averaging algorithm.

Each device trains the generic neural network model with (stochastic) gradient descent and sends the trained weights back to the server, which averages all of these updates to produce the new weights. The following pseudocode shows how the federated averaging algorithm works.

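At each round, the server selects a random subset of clients (a fraction C of the K available clients, each indexed by the variable k). In parallel, each selected client updates the current model weights with the ClientUpdate() function, which runs a few epochs of minibatch stochastic gradient descent on the client’s local data and returns the trained weights w to the server. Finally, the server averages the weights received from the selected clients, weighting each client by its number of training examples (n_k / n in [1]). This weighted average becomes the new set of weights for the generic model.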
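
To make the server loop and ClientUpdate() concrete, here is a minimal, illustrative Python sketch that follows the structure described above under stated assumptions: the model is a linear-regression weight vector trained with squared error, the client datasets are synthetic and noise-free, clients run sequentially rather than in parallel, and C, E, B, and lr stand for the client fraction, number of local epochs, local minibatch size, and learning rate. The helper names grad, client_update, and fed_avg are hypothetical; this is not the reference implementation from [1].

```python
"""Minimal FedAvg-style sketch (illustrative assumptions, not the code from [1]).

Model: linear regression weights w for y ≈ w · x with squared-error loss.
C: fraction of clients per round, E: local epochs, B: local minibatch size.
"""
import random


def grad(w, batch):
    """Mean gradient of the squared error (w · x - y)^2 over a minibatch."""
    g = [0.0] * len(w)
    for x, y in batch:
        err = sum(wi * xi for wi, xi in zip(w, x)) - y
        for j, xj in enumerate(x):
            g[j] += 2 * err * xj / len(batch)
    return g


def client_update(w, data, E, B, lr, rng):
    """ClientUpdate: E epochs of minibatch SGD on local data, return new weights."""
    w = list(w)
    data = list(data)  # copy so shuffling does not reorder the client's list
    for _ in range(E):
        rng.shuffle(data)
        for i in range(0, len(data), B):
            g = grad(w, data[i:i + B])
            w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


def fed_avg(clients, dim, rounds=100, C=0.5, E=2, B=2, lr=0.1, seed=0):
    """Server loop: sample clients, collect their weights, take a weighted average."""
    rng = random.Random(seed)
    K = len(clients)
    w = [0.0] * dim                              # initial shared model
    for _ in range(rounds):
        m = max(int(C * K), 1)                   # clients selected this round
        selected = rng.sample(range(K), m)       # random subset of clients
        n_sel = sum(len(clients[k]) for k in selected)
        new_w = [0.0] * dim
        for k in selected:                       # run sequentially here, not in parallel
            w_k = client_update(w, clients[k], E, B, lr, rng)
            for j in range(dim):
                new_w[j] += len(clients[k]) / n_sel * w_k[j]  # weight by n_k
        w = new_w
    return w


# Usage: synthetic, noise-free client datasets generated from true weights [3.0, -1.0].
data_rng = random.Random(42)
true_w = [3.0, -1.0]
clients = []
for size in [4, 6, 3, 8, 5, 2]:
    points = []
    for _ in range(size):
        x = [data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)]
        points.append((x, sum(t * xi for t, xi in zip(true_w, x))))
    clients.append(points)

print([round(v, 2) for v in fed_avg(clients, dim=2)])
```

Because every synthetic client here shares the same underlying weights, the averaged model should end up close to [3.0, -1.0]. With real, heterogeneous client data, the local models would disagree more, and convergence would likely be slower and less clean.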

What’s Next

This article introduced federated learning, a new way of training machine learning models that leverages ground-truth data generated on end devices (e.g. mobile phones) to update a generic, shared model that is distributed to those devices. The article summarized the federated learning pipeline in 7 steps, from preparing a generic model to receiving and aggregating the trained/updated versions from the mobile devices.

A primary motivation behind federated learning is to keep the data private and just share a model trained by such data. Unfortunately, this privacy could be broken, which I’ll discuss in my next article.

References

  1. McMahan, H. Brendan, et al. “Communication-efficient learning of deep networks from decentralized data.” arXiv preprint arXiv:1602.05629 (2016).

Fritz

Our team has been at the forefront of Artificial Intelligence and Machine Learning research for more than 15 years and we're using our collective intelligence to help others learn, understand and grow using these new technologies in ethical and sustainable ways.
