Federated Learning: Machine Learning on Decentralized Data (Google I/O'19)

Video Statistics and Information

  • Original Title: Federated Learning: Machine Learning on Decentralized Data (Google I/O'19)
  • Author: TensorFlow
  • Description: Meet federated learning: a technology for training and evaluating machine learning models across a fleet of devices (e.g. Android phones), orchestrated by a ...
  • Youtube URL: https://www.youtube.com/watch?v=89BGjQYA0uE
Captions
[MUSIC PLAYING] EMILY GLANZ: Hi, everyone. Thanks for joining us today. I'm Emily, a software engineer on Google's federated learning team. DANIEL RAMAGE: And I'm Dan. I'm a research scientist and the team lead. We'll be talking today about federated learning: machine learning on decentralized data. The goal of federated learning is to enable edge devices to do state-of-the-art machine learning without centralizing data and with privacy by default. By privacy, we mean an aspiration that app developers, centralized servers, and the models themselves learn only common patterns. In today's talk, we'll cover decentralized data and what it means to work with decentralized data in a centralized fashion, which is what we call federated computation. We'll talk a bit about learning on decentralized data. Then we'll give you an introduction to TensorFlow Federated, which lets you experiment with federated computations in simulation today. Along the way, we'll introduce a few privacy principles, like ephemeral reports, and privacy technologies, like federated model averaging, that embody those principles. All right, let's start with decentralized data. A lot of data is born at the edge, with billions of phones and IoT devices generating data. That data can enable better products and smarter models. In yesterday's keynote you saw many ways that data can be used locally at the edge with on-device inference, such as automatic captioning and the next-generation Assistant. On-device inference improves latency, lets things work offline, often has battery-life advantages, and can also have substantial privacy advantages, because a server doesn't need to be in the loop for every interaction you have with that locally generated data. But if you don't have a server in the loop, how do you answer analytics questions?
How do you continue to improve models based on the data those edge devices have? That's what we'll be looking at in the context of federated learning. The app we'll focus on today is Gboard, Google's mobile keyboard. People don't think much about their keyboards, but they spend hours on them each day, and typing on a mobile keyboard is 40% slower than on a physical one. It is easier to share cute stickers, though. Gboard uses machine-learned models for almost every aspect of the typing experience. Tap typing and gesture typing both depend on models, because fingers are a little wider than the key targets and you can't rely on people hitting exactly the right keys. Similarly, auto-corrections and predictions are powered by learned models, as are voice-to-text and other aspects of the experience. All these models run on device, of course, because your keyboard needs to work offline and quickly. For the last few years, our team has been working with the Gboard team to experiment with decentralized data. Gboard aims to be the best and most privacy-forward keyboard available, and one of the ways we're aiming to do that is by making use of an on-device cache of local interactions: things like touch points, typed text, context, and more. This data is used exclusively for federated learning and computation. EMILY GLANZ: Cool. Let's jump into federated computation. Federated computation is basically MapReduce for decentralized data, with privacy-preserving aggregation built in. Let's introduce some of the key concepts of federated computation using a simpler example than Gboard. Here we have our clients: a set of devices such as cell phones, sensors, et cetera. Each device has its own data. In this case, let's imagine it's the maximum temperature that device saw that day, which gets us to our first privacy technology: on-device data sets.
Each device keeps the raw data local, and this comes with some obligations. Each device is responsible for managing its data assets locally, for example by expiring old data and ensuring that data is encrypted when it's not in use. So how do we get the average maximum temperature experienced by our devices? Imagine we had a way to communicate only the average of all client data items to the server. Conceptually, we'd like to compute an aggregate over the distributed data in a secure and private way, which we'll build up to throughout this talk. Now let's walk through an example where the engineer wants to answer a specific question about the decentralized data, like what fraction of users saw a daily high over 70 degrees Fahrenheit. The first step is for the engineer to input this threshold to the server. Next, the threshold is broadcast to the subset of available devices that the server has chosen to participate in this round of federated computation. Each device then compares the threshold to its local temperature data to compute a value: 1 if the temperature was greater than the threshold, or 0 otherwise. These values are then aggregated using an aggregation operator. In this case, it's a federated mean, which encodes a protocol for computing the average value over the participating devices. The server is responsible for collating device reports throughout the round and emitting this aggregate, which contains the answer to the engineer's question. This demonstrates our second privacy technology, federated aggregation: the server combines reports from multiple devices and persists only the aggregate. That leads into our first privacy principle, only an aggregate. Performing federated aggregation makes only the final aggregate data, the sums and averages over the device reports, available to the engineer, without giving them access to any individual report.
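The threshold round just described can be simulated in a few lines of plain Python. This is a minimal sketch, not TensorFlow Federated and not Google's production protocol: the function names `local_indicator` and `federated_round`, the sample size, and the dropout rate are hypothetical choices for illustration. Only the mean of the 0/1 reports leaves `federated_round`, mirroring the only-an-aggregate principle.

```python
import random
from statistics import fmean


def local_indicator(daily_max_f: float, threshold_f: float) -> int:
    """Runs on each device: 1 if the local daily high exceeded the threshold, else 0."""
    return 1 if daily_max_f > threshold_f else 0


def federated_round(device_temps, threshold_f, sample_size, dropout_rate, rng):
    """Hypothetical simulation of one round of a federated mean.

    The server samples participants and broadcasts the threshold; each device
    computes a 0/1 report; devices that drop out send nothing. Only the mean of
    the surviving reports is returned, never the individual reports.
    """
    participants = rng.sample(device_temps, k=min(sample_size, len(device_temps)))
    reports = [
        local_indicator(temp, threshold_f)
        for temp in participants
        if rng.random() >= dropout_rate  # this device survived the round
    ]
    return fmean(reports) if reports else None


# Four devices, no dropout: two of them saw a high above 70 degrees F.
print(federated_round([60.0, 75.0, 80.0, 65.0], 70.0, 4, 0.0, random.Random(1)))  # 0.5

# A larger simulated fleet; repeating rounds gives the engineer more estimates.
rng = random.Random(0)
fleet = [rng.uniform(50.0, 90.0) for _ in range(10_000)]
estimates = [federated_round(fleet, 70.0, 1000, 0.05, rng) for _ in range(5)]
```

Repeating the round with fresh samples, as the server does in practice, smooths over devices that were unavailable or dropped out in any single round.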
This ties into our second privacy principle, ephemeral reports. We don't need to keep the per-device messages after they've been aggregated, so what we collect stays around only as long as we need it and can be discarded immediately. In practice, what we've just shown is one round of computation. The server repeats this process multiple times to get a better estimate for the engineer's question, because some devices may not be available at the time of computation, and some may drop out during a round. DANIEL RAMAGE: So what's different between federated computation and distributed computation in the data center with things like MapReduce? Federated computation has challenges that go beyond what we usually experience in distributed computation. Edge devices like phones tend to have limited communication bandwidth, even when they're connected to a home Wi-Fi network. They're also intermittently available, because devices generally participate only if they are idle, charging, and on an unmetered network. And because each compute node keeps the only copy of its data, the data itself has intermittent availability. Finally, devices participate only with the user's permission, depending on an app's policies. Another difference is that the federated setting is much more distributed than a traditional data center computation. To give you a sense of the orders of magnitude: in a data center you might be looking at thousands or maybe tens of thousands of compute nodes, whereas the federated setting might have something like a billion compute nodes. Maybe 10 million are available at any given time, something like 1,000 are selected for a given round of computation, and maybe 50 drop out. That's a rough sense of the scales we're interested in supporting.
And, of course, as Emily mentioned, privacy-preserving aggregation is fundamental to the way we think about federated computation. Given this set of differences, what does it actually look like when you run a computation in practice? This is a graph of the round completion rate by hour over the course of three days for a Gboard model that was trained in the United States. You see a periodic structure of peaks and troughs, which represent day versus night. Because devices participate only when they're otherwise idle and charging, the peaks in round completion rate occur when more devices are plugged in, which is usually when they're charging on someone's nightstand as they sleep. Rounds complete faster when more devices are available, and device availability changes over the course of the day. That in turn implies dynamic data availability, because the data from users who plug in their phones at night might be slightly different from the data of those who plug in during the day, which is something we'll come back to when we talk about federated learning in particular. Let's take a more in-depth example of a federated computation: the relative typing frequencies of common words in Gboard. Typing frequencies are useful for improving the Gboard experience in a few ways. If someone has typed the letters H-I, "hi" is much, much more likely than "hieroglyphic," so knowing relative word frequencies allows the Gboard team to make the product better. How would we compute these relative typing frequencies as a federated computation? Instead of specifying a single threshold, the engineer now specifies something like a snippet of code that runs on each edge device. In practice, that will often be TensorFlow code, but here I've written it as Python-esque pseudocode.
Think of the device data as each device's record of what was typed in recent sessions on the phone. For each word in that device data, if the word is one of the common words we're trying to count, we increment its count in the local device's update. That little program is shipped to the edge and run locally to compute a small map saying that, perhaps, this phone typed the word "hello" 18 times and "world" 0 times. That update is then encoded as a vector, where the first element represents the count for "hello" and the second the count for "world." The vectors are then combined and summed using the federated aggregation operators Emily mentioned before. At the server, the engineer sees the counts summed over all the devices that participated in that round, not from any single device, which brings up a third privacy principle: focused collection. Devices report only what is needed for this specific computation. There's a lot more richness in the on-device data set that isn't being shared, and if the analyst wanted to ask a different question, for example counting a different set of words, they would run a different computation. This repeats over multiple rounds, with the aggregate counts growing higher and higher, which in turn gives us better and better estimates of the relative frequencies of words typed across the population. EMILY GLANZ: Awesome. Let's talk about our third privacy technology, secure aggregation. In the previous example, we saw that the server only needs to emit the sum of the vectors reported by the devices. The server could compute this sum from the device reports directly, but we've been researching ways to provide even stronger guarantees. Can we make it so the server itself cannot inspect individual reports? That is, how do we enforce the only-an-aggregate privacy principle from before in our technical implementation?
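The word-count computation can be sketched the same way: a local function that counts only the requested words and encodes them as a fixed-order vector, plus a server-side element-wise sum. This is plain Python standing in for the TensorFlow code the talk mentions; `COMMON_WORDS`, `local_word_counts`, and `federated_sum` are hypothetical names chosen for illustration.

```python
from collections import Counter

# Hypothetical list of words the engineer asks about, in a fixed order.
COMMON_WORDS = ("hello", "world")


def local_word_counts(device_data, common_words=COMMON_WORDS):
    """Runs on each device: count only the requested words (focused collection)
    and encode the result as a vector in the order of common_words."""
    wanted = set(common_words)
    counts = Counter(word for word in device_data if word in wanted)
    return [counts[word] for word in common_words]


def federated_sum(vectors):
    """Server side: element-wise sum of the device vectors."""
    return [sum(column) for column in zip(*vectors)]


device_a = ["hello"] * 18 + ["typing", "test"]  # "hello" 18 times, "world" 0 times
device_b = ["hello", "world", "world"]
totals = federated_sum([local_word_counts(device_a), local_word_counts(device_b)])
print(totals)  # [19, 2]
relative = [count / sum(totals) for count in totals]
```

Other words on the device, like "typing" above, never leave it; counting a different set of words would mean running a new computation with a different word list.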
Secure aggregation is an optional extension to the client/server protocol that embodies this privacy principle. Here's how it works. This is a simplified overview that demonstrates the key idea of how a server can compute a sum without being able to decrypt the individual messages. In practice, the protocol also has to handle phones that drop out partway through; see the paper for details. So let's jump in. Through coordination by the server, pairs of devices agree on large masks that sum to 0. Each device adds these masks to its vector before reporting, and all devices participating in this round of computation exchange these zero-sum pairs. The reports are completely masked by these values, so each individual report looks randomized. But when the reports are aggregated together, the paired masks cancel out, and we're left with only the sum we were looking for. Again, in practice this protocol is more complicated in order to handle dropout. So we've shown you what you can do with federated computation. But what about the much more complex workflows associated with federated learning? Before we jump into federated learning, let's look at the typical workflow a model engineer goes through when performing machine learning. Typically, they have some data in the cloud, where they start training and evaluation jobs, potentially in grids to experiment with different hyperparameters, and they monitor how well these jobs are performing. They end up with a model that is a good fit for the distribution of cloud data that's available. So how does this workflow translate into a federated learning workflow? The model engineer might still have some data in the cloud, but now it's proxy data that's similar to the on-device data.
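The pairwise-masking idea behind secure aggregation can be illustrated with modular arithmetic. This sketch assumes integer report vectors summed modulo 2^32 and generates each pair's mask in one place for convenience; in the real protocol the masks come from agreement between the devices themselves, and dropout handling (omitted here) is a large part of the design. `masked_reports` and `server_sum` are hypothetical names, and this code provides no actual security.

```python
import secrets
from itertools import combinations

MODULUS = 2**32  # assumption: reports are integer vectors summed modulo 2^32


def masked_reports(vectors, modulus=MODULUS):
    """Simplified pairwise masking with no dropout handling.

    For every pair of devices (i, j) with i < j, a random mask is added to
    device i's vector and subtracted from device j's, so each pair sums to 0.
    """
    dim = len(vectors[0])
    masked = [list(v) for v in vectors]
    for i, j in combinations(range(len(vectors)), 2):
        mask = [secrets.randbelow(modulus) for _ in range(dim)]
        for k in range(dim):
            masked[i][k] = (masked[i][k] + mask[k]) % modulus
            masked[j][k] = (masked[j][k] - mask[k]) % modulus
    return masked


def server_sum(reports, modulus=MODULUS):
    """The server sums the masked reports; the pairwise masks cancel."""
    return [sum(column) % modulus for column in zip(*reports)]


device_vectors = [[18, 0], [1, 2], [5, 7]]
reports = masked_reports(device_vectors)  # each report looks random on its own
print(server_sum(reports))                # [24, 9]
```

Because every mask is added exactly once and subtracted exactly once, the server recovers the true sum while each individual masked vector is uniformly distributed modulo 2^32.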
This proxy data might be useful for training and evaluating in advance, but the main training loop now takes place on the decentralized data. The model engineer still does things that are typical of a machine learning workflow, like starting and stopping tasks, trying out different learning rates or other hyperparameters, and monitoring performance as training progresses. If the model performs well on the decentralized data set, the model engineer has a good release candidate. They evaluate this release candidate using whatever validation techniques they typically use before deploying to users; these are things you can do with ModelValidator and TFX. After validation, they distribute the final model for on-device inference with TensorFlow Lite, perhaps with a rollout or A/B testing. This deployment workflow is a step that comes after federated learning, once they have a model that works well. Note that the model does not continue to train after it's been deployed for inference on device, unless the model engineer is doing something more advanced, like on-device personalization. So how does the federated learning part itself work? If a device is idle and charging, it checks in with the server. Most of the time, it's told to go away and come back later, but some of the time the server has work to do. The initial model, as dictated by the model engineer, is sent to the phone. For the initial model, zeros or a random initialization is usually sufficient, or, if the engineer has relevant proxy data in the cloud, they can use a pre-trained model. The client computes an update to the model using its own local training data, and only this update, not the raw data, is sent to the server to be aggregated. Other devices participate in this round as well, performing their own local updates to the model. Some clients may drop out before reporting their update, but this is OK.
The server aggregates the client updates into a new model by averaging them, optionally using secure aggregation. The updates are ephemeral and are discarded after use. The engineer monitors the performance of federated training through metrics that are themselves aggregated along with the model. Training rounds continue as long as the engineer is happy with model performance: the server chooses a different subset of devices and gives them the new model parameters. This is an iterative process that continues through many training rounds. What we've just described is our fourth privacy technology, federated model averaging. Our diagram showed federated averaging as the flavor of aggregation the server performs for distributed machine learning. Federated averaging works by computing a data-weighted average of the model updates, each produced by many steps of gradient descent on a device. Other federated optimization techniques could also be used. DANIEL RAMAGE: So what's different between federated learning and traditional distributed learning inside a data center? It's all the differences we saw with federated computation, plus some additional ones that are specific to learning. First, the data sets in a data center are usually balanced in size: most compute nodes have a roughly equal-size slice of the data. In the federated setting, each device has one user's data, and some users use Gboard much more than others, so those data set sizes can be very different. Second, the data in federated computation is highly self-correlated. It's not a representative sample of all users' typing, because each device has only one user's data in it, while many distributed training algorithms in the data center assume that every compute node gets a representative sample of the full data set.
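A toy version of federated averaging may make the data-weighted average concrete. The sketch below assumes a one-feature linear model y ≈ w·x + b trained with plain SGD on squared error; the learning rate, epoch count, client data sizes, and number of rounds are arbitrary illustration choices, and `local_update` and `federated_averaging_round` are hypothetical names, not the TFF API. The clients deliberately hold unbalanced amounts of data, echoing the differences from data center training described above.

```python
import random


def local_update(global_params, examples, lr=0.05, epochs=5):
    """Client: start from the global model, run several epochs of SGD on local
    data, and return (model delta, number of local examples)."""
    w, b = global_params
    for _ in range(epochs):
        for x, y in examples:
            err = (w * x + b) - y  # derivative of 0.5 * err**2 w.r.t. the prediction
            w -= lr * err * x
            b -= lr * err
    return [w - global_params[0], b - global_params[1]], len(examples)


def federated_averaging_round(global_params, client_datasets, lr=0.05, epochs=5):
    """Server: average the client deltas, weighted by local example counts."""
    results = [local_update(global_params, data, lr, epochs) for data in client_datasets]
    total = sum(n for _, n in results)
    avg_delta = [sum(delta[k] * n for delta, n in results) / total
                 for k in range(len(global_params))]
    return [p + d for p, d in zip(global_params, avg_delta)]


# Hypothetical clients with unbalanced amounts of data drawn from y = 3x + 1.
rng = random.Random(0)
clients = []
for n in (5, 20, 80):
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    clients.append([(x, 3.0 * x + 1.0) for x in xs])

params = [0.0, 0.0]  # zero initialization
for _ in range(50):
    params = federated_averaging_round(params, clients)
```

After 50 rounds, `params` should be close to `[3.0, 1.0]`. Only the deltas and example counts reach the server function; the raw `(x, y)` pairs stay inside `local_update`.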
And, third, there's the variable data availability I mentioned earlier. Because the people whose phones are plugged in at night versus during the day might actually be different, for example night-shift versus day-shift workers, we might have different kinds of data available at different times of day. That's a potential source of bias when training federated models, and an active area of research. What's exciting is that federated model averaging works well for a variety of state-of-the-art models despite these differences. That's an empirical result: when we started this line of research, we didn't know whether it would be true, or whether it would apply widely to the kinds of state-of-the-art models teams like Gboard are interested in pursuing. The fact that it does work well in practice is great news. So when is federated learning most applicable? It's when the on-device data is more relevant than the server-side proxy data, or it's privacy sensitive, or it's so large that uploading it wouldn't make sense. And, importantly, it works best when the labels for your machine-learned model can be inferred naturally from user interaction. What does a naturally inferred label look like? Let's look at some examples from Gboard. Language modeling is one of the most essential models, powering a bunch of Gboard experiences. The key idea in language modeling is to predict the next word based on the text typed so far. This powers the prediction strip, of course, but also other aspects of the typing experience: Gboard also uses the language model to help understand which words are more likely as you're tap typing or gesture typing. The model input in this case is the sequence typed so far, and the output is whatever word the user typed next. That's what we mean by self-labeling.
If you take a sequence of text, you can use every prefix of that text to predict the next word, which gives a series of training examples as a result of people's natural use of the keyboard itself. The Gboard team ran dozens of experiments in order to replace their prediction-strip language model with a new one based on a more modern recurrent neural network architecture, described in the paper linked below. On the left, we see a server-trained recurrent neural network compared to the old Gboard model; on the right, a federated model compared to that same baseline. These two model architectures are identical. The only difference is that one is trained in the data center using the best available server-side proxy data, and the other is trained with federated learning. Note that the newer architecture is better in both cases, but the federated model does even better than the server model, because the decentralized data better represents what people actually type. On the x-axis for the federated model, we see the training round, that is, how many rounds of computation it took to reach a given accuracy on the y-axis. The model tends to converge after about 1,000 rounds, which is something like a week of wall-clock time. That's longer than in the data center, where the x-axis measures SGD steps and we reach similar quality in about a day or two. But a week-long time frame is still practical for machine learning engineers, because they can start many models in parallel and work productively in this setting, even though it takes a little longer. What's the impact of that relatively small difference? It's actually pretty big. Next-word prediction accuracy improves by 25% relative, and it makes the prediction strip itself more useful: users click it about 10% more. Another example the Gboard team has been working on is emoji prediction.
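Self-labeling, where every prefix of typed text becomes an input and the word typed next becomes its label, can be shown with a tiny helper. Whitespace tokenization and the name `next_word_examples` are simplifying assumptions for illustration; a production language model would use its own tokenizer.

```python
def next_word_examples(text):
    """Turn one typed sequence into (prefix, next word) training pairs.

    Every prefix of the text becomes an input, and the word the user actually
    typed next becomes its label, so no manual labeling is needed.
    """
    words = text.split()  # assumption: simple whitespace tokenization
    return [(words[:i], words[i]) for i in range(1, len(words))]


for prefix, label in next_word_examples("this party is lit"):
    print(" ".join(prefix), "->", label)
# this -> party
# this party -> is
# this party is -> lit
```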
Software keyboards have a nice emoji interface, but many users don't know to look there or find it inconvenient. So Gboard has introduced the ability to predict emoji inline on the prediction strip, just like next words. The federated model was able to learn that the fire emoji is an appropriate completion for "this party is lit." On the bottom, you can see a histogram of the overall frequency of the emoji people tend to type, in which the laugh/cry emoji is much more represented. That's how you know context really matters for emoji: we wouldn't want to suggest the laugh/cry emoji all the time. This model ends up with 7% more accurate emoji predictions, and Gboard users click the prediction strip 4% more. And, I think most importantly, there are 11% more users who've discovered the joy of including emoji in their texts, and untold numbers of users receiving those wonderfully emojiful texts. So far, we've focused on text entry, but federated learning can apply to other components too, such as action prediction in the UI itself. Gboard isn't just used for typing; a key feature is enabling communication. So much of what people type is in messaging apps, and those apps become more lively when you share the perfect GIF. Helping people discover great GIFs to search for and share from the keyboard at the right times, without getting in the way, is one of Gboard's differentiating product features. This model was trained to predict, from the context so far, a query suggestion for a GIF, sticker, or emoji search, and whether that suggestion is actually worth showing to the user at this time. An earlier iteration of this model is described in the paper linked below.
This model resulted in a 47% reduction in unhelpful suggestions, while simultaneously increasing the overall rate of emoji, GIF, and sticker shares by better indicating when a GIF search would be appropriate, which is what you can see in that animation. As someone types "good night," the little "g" turns into a GIF icon, indicating that a good GIF is ready to share. One final example from Gboard is the problem of discovering new words. What words are people typing that Gboard doesn't know? It can be really hard to type a word the keyboard doesn't know, because it will often auto-correct it to something it does know. Gboard engineers can use the top typed unknown words to improve the typing experience. They might add common new words to the dictionary in the next model release after manual review, or they might find out which kinds of typos are common, suggesting possible fixes to other aspects of the typing experience. Here is a sample of words people tend to type that Gboard doesn't know. How did we get this list if we're not sharing the raw data? We trained a recurrent network to predict the sequence of characters people type when they're typing words the keyboard doesn't know. That model, just like the next-word prediction model, can be used to sample words letter by letter. We then take that model in the data center and generate from it: millions and millions of samples that are representative of the words people are typing out in the wild. If we break these down a little, there is a mix of things. There are abbreviations, like "really" and "sorry" missing their vowels. There are extra letters added to "hahah" and "ewwww," often for emphasis. There are typos that are common enough to show up even though Gboard tends to auto-correct away from them. There are new names.
We also see examples of non-English words being typed on an English-language keyboard; this model was trained on English in the US. Those non-English words indicate another way Gboard might improve. Gboard of course has an experience for typing in multiple languages, and perhaps that multilingual experience, or switching languages more easily, could be improved. This also brings us to our fourth privacy principle: don't memorize individuals' data. We're careful in this case to use only models aggregated over lots of users, trained only on out-of-vocabulary words with a particular flavor, such as not containing a sequence of digits. We definitely don't want a model trained with federated learning to be able to memorize someone's credit card number. We're also looking at techniques that can provide other, even stronger and more provable privacy properties. One of those is differential privacy, the statistical science of learning common patterns in a data set without memorizing individual examples. It's a field that's been around for a number of years, and it is very complementary to federated learning. The main idea is that when you're training a model, with federated learning or in the data center, you use appropriately calibrated noise to obscure any individual's impact on the learned model. You can experiment with this a little today in the TensorFlow Privacy project, linked here, for more traditional data center settings, where you have all the data available and want an optimizer that adds the right kind of noise to guarantee this property, that individual examples aren't memorized. The combination of differential privacy and federated learning is still very fresh.
Google is working to bring this to production, so I'm giving you a preview of some early results. Let me give you a flavor of how this works with privacy technology number five, differentially private model averaging, which is described in the ICLR paper linked here. The main idea is that in every round of federated learning, just as Emily described for a normal round, an initial model is sent to the device, and that model is trained on that device's data. Here's where the first difference comes in: rather than sending the model update straight back to the server for aggregation, the device first clips the update, which is to say it ensures the model update is limited to a maximum size. By maximum size, we mean, in a technical sense, that the update must lie within an L2 ball in parameter space. Then the server adds noise when combining the device updates for that round. How much noise? Noise roughly on the same order of magnitude as the maximum size of any one user's update. With those two properties combined and properly tuned, any particular aspect of the updated model from that round might be there because some user's contribution pushed the model in that direction, or it might be because of the random noise. That gives an intuitive notion of plausible deniability about whether any change was due to a user or to the noise, but it actually provides an even stronger formal property: the model you learn with differentially private model averaging will be approximately the same whether or not any one user participated in training. A consequence is that if there is something only one user has typed, this model can't learn it. We've built a production system for federated computation here at Google, which is what the Gboard team used in the examples I've talked about today.
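The two ingredients of differentially private model averaging, clipping each update to an L2 ball on the device and adding noise scaled to that bound on the server, can be sketched as follows. This is an illustration under assumptions, not the algorithm from the ICLR paper: the Gaussian noise, the hypothetical `noise_multiplier` knob, and adding noise to the sum before averaging are choices made here, and no privacy accounting is performed, so the code by itself guarantees nothing. `clip_update` and `dp_average` are hypothetical names.

```python
import math
import random


def clip_update(update, clip_norm):
    """Device side: scale the update so its L2 norm is at most clip_norm."""
    norm = math.sqrt(sum(u * u for u in update))
    scale = min(1.0, clip_norm / norm) if norm > 0 else 1.0
    return [u * scale for u in update]


def dp_average(clipped_updates, clip_norm, noise_multiplier, rng):
    """Server side: sum the already-clipped updates, add Gaussian noise whose
    standard deviation is noise_multiplier * clip_norm, then average."""
    sigma = noise_multiplier * clip_norm
    dim = len(clipped_updates[0])
    noisy_sum = [sum(u[k] for u in clipped_updates) + rng.gauss(0.0, sigma)
                 for k in range(dim)]
    return [s / len(clipped_updates) for s in noisy_sum]


raw_updates = [[3.0, 4.0], [0.1, -0.2], [-0.5, 0.5]]           # hypothetical updates
clipped = [clip_update(u, clip_norm=1.0) for u in raw_updates]  # done on each device
new_delta = dp_average(clipped, clip_norm=1.0, noise_multiplier=1.0, rng=random.Random(0))
```

With only three clients the noise swamps the signal; the approach relies on rounds with many participants, so that noise on the order of one user's maximum contribution becomes small relative to the averaged update.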
You can learn more about it in the paper we published at SysML this year, "Towards Federated Learning at Scale: System Design." This system is currently used internally. It's not yet a system we expect external developers to be able to use, but that's something we're certainly very interested in supporting. EMILY GLANZ: Awesome. We're excited to share our community project that allows everyone to develop the building blocks of federated computations: TensorFlow Federated. TFF offers two APIs, the Federated Learning (FL) API and the Federated Core (FC) API. The FL API comes with implementations of federated training and evaluation that can be applied to your existing Keras models, so you can experiment with federated learning in simulation. The FC API allows you to build your own federated computations. TFF also comes with a local runtime for simulations. Earlier, we showed you how federated computation works conceptually. Here's what it looks like in TFF. We refer to these sensor readings collectively as a federated value. Each federated value has a type, consisting of both the placement, here at clients, and the type of the data items themselves, here a float32. The value at the server also has a federated type. This time, we've dropped the curly braces to indicate that it is one value and not many. That brings us to our next concept: a distributed aggregation protocol that runs between the clients and the server, in this case tff.federated_mean. This is a federated operator that you can think of as a function, even though its inputs and outputs live in different places. A federated operator represents an abstract specification of a distributed communication protocol, and TFF provides a library of these federated operators that represent the common building blocks of federated protocols. Now I'm going to run through a brief code example using TFF.
I'm not going to go too in-depth, so it might look a little confusing, but at the end I'll put up a link to a site that provides more tutorials and more walkthroughs of the code. The section of code I have highlighted right now declares the federated type that represents our input. You can see we're defining both the placement, here at the TFF clients, and that each data item is a tf.float32. Next, we pass this as an argument to a special function decorator that declares this a federated computation. And here we invoke our federated operator, in this case tff.federated_mean, on those sensor readings. Now let's jump back to the example where the model engineer had a specific question: what fraction of sensors saw readings greater than a certain threshold? Here's what that looks like in TFF. Our first federated operator in this case is tff.federated_broadcast, which is responsible for broadcasting the threshold to the devices. Our next federated operator is tff.federated_map, which you can think of as the map step in MapReduce. It produces the 1s and 0s representing whether each device's local value is greater than the threshold. Finally, we perform a federated aggregation, tff.federated_mean, to get the result back at the server. Let's look at this again in code. We again declare our inputs. Let's pretend we've already declared our readings type; now we also define our threshold type. This time, it has a placement at the server, we indicate that there is only one value with all_equal=True, and it's a tf.float32. Again we pass that into the function decorator to declare this a federated computation, and we invoke the federated operators in the appropriate order. We have tff.federated_broadcast working on the threshold.
We perform our mapping step, which takes a computation I'll talk about in a second and applies it to the readings and the threshold we just broadcast. This chunk of code represents the local computation each device performs, comparing its own data item to the threshold it received. I know that was a fast, brief introduction to coding with TFF. Please visit tensorflow.org/federated to get more hands-on with the code. And if you like links, we have one more that covers all the ideas we've introduced today about federated learning. Please check out our comic book at federated.withgoogle.com. We were fortunate enough to work with two incredibly talented comic book artists to illustrate these ideas as graphic art. And it even has corgis. That's pretty cool. DANIEL RAMAGE: All right, so in today's talk, we covered decentralized data, federated computation, and how we can use federated computation building blocks to do learning, and we gave you a quick introduction to the TensorFlow Federated project, which you can use today to experiment in simulation with how federated learning might work on data sets you already have on the server. As you might have seen, the TF Lite team has also announced that training is a big part of their roadmap, and that's something we're also really excited about, because it could let external developers run the kinds of things we're running internally sometime soon. We also introduced privacy technologies: on-device data sets, federated aggregation, secure aggregation, federated model averaging, and its differentially private version, which embody the privacy principles of only an aggregate, ephemeral reports, focused collection, and not memorizing individuals' data. We hope we've given you a flavor of the kinds of things federated learning and computation can do.
To learn more, check out the comic book and play a little bit with TensorFlow Federated for a preview of how you can write your own kinds of federated computations. Thank you very much. [APPLAUSE] [MUSIC PLAYING]
Info
Channel: TensorFlow
Views: 38,200
Rating: 4.9166665 out of 5
Keywords: type: Conference Talk (Full production);, pr_pr: Google I/O, purpose: Educate
Id: 89BGjQYA0uE
Length: 41min 11sec (2471 seconds)
Published: Thu May 09 2019