Tom Goldstein: "What do neural loss surfaces look like?"

So in this talk I want to explore structures that we see inside of neural network loss functions, and some of the implications that has for the design of optimization algorithms. Just to start fairly basic, and to be clear about what we're talking about: a neural net, at least in the computer vision context, takes in some image (if it's sampled from the internet, it's probably a cat or something like that). You pass it through a bunch of convolutional filters, and it spits out a label; in this case it says we're looking at Grumpy Cat. So you can think, in some abstract sense, of a neural network as a function that maps images X onto labels Y, and it has some set of parameters W that are called the weights. The way that we train these neural networks is by minimizing a loss function. We collect a whole bunch of these X's and Y's, and then we have a loss function that penalizes the difference between f(X) and Y. You might pick something like a least-squares loss, or, in the more sophisticated context we'll look at later, a cross-entropy loss. But the point is that the loss function is a function of the weight parameters, not of the inputs. So the loss function has a very high dimensionality: it has the dimensionality of the weights, not of the input features. It's also very expensive to evaluate, because you have to average over all of the data points in your training set, and N, the number of data points, could potentially be very large. If f were a linear function, this would simply be a linear least-squares problem; it would be a convex quadratic loss function, and everything would be fine. But these neural networks have nonlinearities layered on top of nonlinearities, so you would think this would be a very, very non-convex loss function. The question is, how non-convex? We know that we can minimize these things fairly easily; oftentimes when you minimize these things you get global minimizers, and so maybe
it's not so non-convex. But at the same time, the sophisticated structure of this function would lead you to think that it is. So how non-convex is it? There's been a lot of theoretical work on this, and I'll just mention the categories of results that people have been able to prove.

There are theoretical results that say the local minima are not so bad: if you get a local minimum, it's not too far in objective value from a global minimizer. There are results that show there are no local minimizers under various assumptions, for example if your neurons are linear (so there's no nonlinearity), if you have certain very wide layers, or if there is certain structure in the weights. There's also a variety of results that show you won't have local minima for some types of shallow nets, for example neural nets with one hidden layer. But this is an area where we mostly rely on theory, and the reason we rely on theory is that evaluating the loss function is really expensive. You have to loop over the entire training dataset; you have to do a full epoch of computation just to evaluate the loss function once, and so it's really hard to see what these things look like.

I've been interested in this problem for a long time, and I've had a lot of discussions with my grad students about solutions. We could investigate the loss function with harmonic analysis, or random matrix theory, or statistical mechanics, but you have this fundamental problem that even with a GPU it's too expensive to plot the loss function. But grad students don't like to do harmonic analysis, so I posed this problem to my grad students, and they came back with their own solution, which is more GPUs. So we did a bunch of loss function computations on a GPU cluster, and we were actually able to produce visualizations of these loss functions, and they look like this. This is the loss function of a 56-layer neural network trained on the CIFAR-10 dataset. Okay, so we're going to look at a bunch
of these loss functions, and we'll talk about how things like neural net architecture design affect the structure of these loss functions. But first I want to talk about our motivations, and a little bit of background on how you visualize these things and how to interpret the visualizations.

One of the things that motivated us to look at this problem is a debate that's going on; I call it the sharp-versus-flat dilemma. It is widely believed that flat minimizers of neural networks generalize better than sharp minimizers. Flat means it looks more like this: it has a really wide basin. Sharp minimizers look like this: they have a really narrow basin. There's a variety of theoretical and empirical works that suggest flat minimizers generalize better than sharp minimizers. In fact, Pratik Chaudhari talked earlier today about some of his Entropy-SGD work, which provides a nice way of measuring the flatness of some of these minimizers. But this is up for debate. There was a recent paper at ICML 2017 that exploits certain types of scaling invariances, which we'll talk about in a bit, to show that sharp minimizers can also generalize. Sharp versus flat is really important, because the kinds of algorithms that we use to train affect what kind of minimizers we get. It's believed that the implicit regularization of SGD is important, which means we need to have noise in SGD. So small batch sizes tend to give us flat minimizers, and the SGD optimizer tends to give us flat minimizers. But there are also bad optimization methods, in some sense. For example, big-batch optimization or the Adam optimizer tends to result in sharper minimizers that don't generalize well.

Okay, so this is a simple experiment that's appeared in the literature, and I'm going to use it to show some of the caveats of visualizing loss landscapes. I'll start with a weight vector; this is in weight space, from a very
high-dimensional space. We'll start with some random initialization, and we're going to train our neural network twice. We'll train it with a small batch, go along some path, and find a small-batch minimizer. Then we'll train it with a large batch and find a large-batch minimizer. The thought is that the small-batch minimizer is going to be flatter and generalize better, and the large-batch minimizer is going to be sharper and generalize worse. There have been a number of papers that use visualizations of this type, including one by Goodfellow and colleagues. What we do is take these two minimizers and pass a line through them, a line that interpolates between these two points in weight space. Then we walk along that line and plot the loss function at every one of those locations, and we get something that looks like this. The red line at the top is plotting accuracy; you can just focus on the blue lines, which plot the loss function. The solid line is the training loss, and the dotted line, a little bit higher, is the test loss. If you look over here, the small-batch minimizer has this really wide, flat basin, so that's nice. Then we look over here at the large-batch minimizer: it comes down like a needle and shoots back up, so it's really, really sharp. So it's very apparent that there's a sharpness difference between these two things. One of the students in my lab, Hao Li, ran this experiment and replicated the results of these papers, but then he started turning some knobs, and weird things started to happen. He turned on weight decay, which is like an L2 penalty on the size of the weights, and you get completely opposite results. With weight decay, the small-batch minimizer is now really sharp, and the large-batch minimizer is noticeably
flatter. Right, it's only projected onto a one-dimensional vector; that's one of the caveats of this method. These plots can be a little bit deceiving, which is sort of the point of looking at these two different scenarios.

Okay, so the small-batch minimizer still generalizes better, but it's way sharper than this large-batch minimizer. So what's going on? It turns out that what you're really plotting is not the sharpness or flatness of these minimizers; what you're plotting is just the size of the weights. If we look at a histogram of the weights, this blue histogram shows the weights for the small-batch minimizer. The weights are really large, so this histogram is really spread out. When we use a large batch, the histogram concentrates more toward the origin and you get smaller weights.

So say your weights are small, about size one. Adding a perturbation of size 10 to them makes a huge difference and destroys the performance of your neural net. But if your weights live on a scale between 0 and 100, adding a perturbation of size 10 might not be so disastrous. So when the weights are big, the loss function doesn't depend so much on small perturbations; when the weights are small, it depends a lot. When the weights are small you get these really sharp-looking minimizers, and when the weights are big it looks flat. But this difference in scale doesn't correlate at all with generalization error: we saw here that we get good generalization error from the sharp minimizer. And in general we don't really care about how big the weights are, because of the scale invariance property. So if you have a neuron that... Yes, it did. Oh, sorry, let me go back. It did; well, they are relatively sharper, right? This flat minimizer is still quite a bit sharper than this one. Sorry, I'm not sure which one you're referring to. This
one? That's interesting. Yeah, it doesn't, with the large batch. Okay, well, there's a reason why turning on weight decay flips the balance. It's because weight decay doesn't really affect large batches very much. We train these nets for the same number of epochs, and a really large batch doesn't generate very many updates per epoch, whereas with a small batch you have lots and lots of updates. So you might actually expect this to look a little bit sharper than it does, but when the batch is large, weight decay doesn't affect the size of these weights as much as you would think. But these are two completely different sets of minimizers; we're not plotting the same minimizers in the two cases. So they don't necessarily have to be the same; they don't even have to be comparable. We'll see later, when we do a more normalized visualization, that comparisons across these different weight decay scenarios make sense: you can correlate sharpness with generalization error.

Okay, so there's a scale invariance property. If you use ReLU nonlinearities, then I can pull out the weights in one layer and multiply them by ten, and I can compensate by multiplying the weights in another layer by 1/10. When I do that, you get exactly the same thing out of the network that you did before, so you get an equivalent neural network. Batch normalization takes the outputs from a layer, subtracts the mean, divides by the standard deviation, and renormalizes everything, so when you use it the weight scaling is completely irrelevant. You can multiply the weights by 100, and that factor of a hundred just gets taken out by the batch normalization layer. So these differences in scaling are not really a meaningful thing to
visualize; they don't really affect the performance of the neural network. These sorts of simple plots aren't really visualizing the intrinsic sharpness or flatness of these curves, so they can be very misleading. What you're really visualizing in this case is some combination of the true sharpness and flatness with the scaling of the weights.

So how can we get a more meaningful measure of sharpness? We proposed something called filter normalization. It's just a very simple normalization scheme to normalize out the different scales between minimizers. The way that we're going to make these two-dimensional plots is like this. We start by training a neural network: we start with some random initialization and train down the contours until we find some minimizer. Then we take two random directions in weight space, so two random vectors, and slice through that point; these should be roughly orthogonal. These two random directions define a plane in weight space. Then we apply filter normalization to these random directions, which I'll describe in a second. Finally we just raster over this plane in 2D: we scan over it and compute many, many loss function evaluations, and that produces a visualization of the loss function.

Okay, so what do I mean by filter normalization? We look at our neural network and select out a filter. So I take a convolutional filter out, look at it, and say, okay, this filter has magnitude ten: its two-norm is ten. But when I produce my random direction, which is just random Gaussian variables with the same dimensions as this filter, it has a different norm, say norm 20. I want those to be comparable. I don't want to be adding a huge
perturbation to small weights, and I don't want to be adding a small perturbation to huge weights. So I'm going to multiply that direction by one-half, which brings its norm down to ten. Now I have a direction that lives on the same distance scale as the filter it's modifying.

Okay, so as a sanity check: does this actually identify the correct distance scale for these problems? Is this a reasonable way to use random directions and filter normalization to plot these loss functions? We find that with filter normalization we get curvatures, meaning sharpness-versus-flatness measures, at least visually, that correlate really well with the differences in generalization that theory predicts. This is plotting along a random direction through a minimizer obtained via SGD with batch size 128 and with batch size 4096. With 4096 it gets sharper, and there's an increase in generalization error, so that minimizer is a little bit worse. We can do the same thing with weight decay. In this case I turned on weight decay, and if you look at these two minimizers, this one's a little bit sharper than this one. This other minimizer is a little bit wider, and it also does a little bit better in terms of generalization error. We can even compare across optimizers: here's Adam versus SGD. This Adam minimizer is visually sharper than this SGD minimizer, and it also has worse generalization error. So the filter normalization process identifies the natural distance scale of these things, and that lets us actually compare different minimizers to each other.

Okay, so now we have a reasonable way to plot two different minimizers, put them side by side, and compare them. We can use this to explore what minimizers look like when you change the optimizer or change the structure of your neural network. There are different kinds of neural networks we want to look into. The simplest neural networks are what I call VGG-like nets.
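The filter-normalization step described above can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation:
- A network's weights are represented as a list of filters, each filter a flat list of floats.
- The helper names (`l2_norm`, `random_direction_like`, `filter_normalize`) are hypothetical.
- Each filter of a Gaussian random direction is rescaled to have the same two-norm as the matching weight filter.

```python
import math
import random


def l2_norm(vec):
    """Euclidean (two-)norm of a flat list of floats."""
    return math.sqrt(sum(x * x for x in vec))


def random_direction_like(weights, rng):
    """Draw a Gaussian random direction with the same filter shapes as `weights`."""
    return [[rng.gauss(0.0, 1.0) for _ in filt] for filt in weights]


def filter_normalize(direction, weights):
    """Rescale each filter of `direction` to the norm of the matching weight filter.

    Both arguments are lists of filters (hypothetical representation: each
    filter is a flat list of floats). A zero-norm direction filter stays zero
    to avoid dividing by zero.
    """
    normalized = []
    for d_filt, w_filt in zip(direction, weights):
        d_norm = l2_norm(d_filt)
        scale = l2_norm(w_filt) / d_norm if d_norm > 0.0 else 0.0
        normalized.append([scale * x for x in d_filt])
    return normalized


# The numbers from the talk: a weight filter with norm 10 and a direction
# filter with norm 20, so the direction is multiplied by one-half.
weights = [[6.0, 8.0]]       # norm 10
direction = [[12.0, 16.0]]   # norm 20
print(filter_normalize(direction, weights))  # [[6.0, 8.0]]

# Scale invariance: if a layer's weights are multiplied by 10, the normalized
# direction for that layer is also multiplied by 10, so the relative size of
# the perturbation does not change.
rng = random.Random(0)
layer = [[rng.gauss(0.0, 1.0) for _ in range(9)] for _ in range(4)]  # 4 filters of 3x3
d = random_direction_like(layer, rng)
layer_x10 = [[10.0 * x for x in f] for f in layer]
ratio = l2_norm(filter_normalize(d, layer_x10)[0]) / l2_norm(filter_normalize(d, layer)[0])
print(round(ratio, 6))  # 10.0
```

The per-filter scaling (rather than a single global rescale of the whole direction) is what makes the perturbation track each filter's own scale. This is exactly the property the scale-invariance argument calls for.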
VGG-like networks are very simple convolutional networks. We do include batch normalization, so they alternate between convolutions, batch normalizations, and ReLUs. We're also interested in skip-connection networks, things like ResNets. ResNets have connections where information from shallower layers skips over these convolutions and gets added into the outputs of deeper convolutional layers.

Okay, so this is a 56-layer VGG-like net, a very simple neural network. This is plotting with two random directions, and at the center of the plot is the minimizer. You can see that in two dimensions we can actually visualize the non-convexity of these landscapes really effectively: there is a convex attractor here, but it's surrounded by a very hazardous landscape. One of the most surprising things about these experiments is what happens if you just change the neural architecture. We add skip connections, and then the loss surface looks like this. These skip connections radically change the behavior of the loss surface; we can do a side-by-side comparison here. We know that it's much easier to train deep networks that have skip connections. All the modern state-of-the-art networks that manage to get good results on CIFAR with really deep architectures utilize these sorts of skip connections.

Yeah, right, in some sense this is the whole landscape, in the following sense. We generate a filter-normalized direction, which should give you a natural distance scale for the problem, and then we go one unit of that direction each way. So when you get all the way out to the edge here, you've added one full filter-normalized perturbation to the weights. That's a perturbation of the same magnitude as the weights, so that's pretty far. If you look at this loss surface, it starts to curl up at the edges here, and we've actually done experiments
where we zoom out and look even further. This loss surface actually starts to shoot up, and it just goes up and up monotonically until you get NaNs and can't even plot it anymore, so there's no more interesting behavior. Yeah, yes, we don't see any other really deep troughs. These are actually on a logarithmic scale, which makes this appear a little bit artificially deeper than it is. On a non-logarithmic scale it's not as dramatic, but then it's hard to capture some of this structure. This one's actually cropped down a little bit; if you zoom out, this one also explodes up. The problem is that if you go out far enough that the edges of this thing come up, they block your view and you can't see into the loss function anymore, so this is trimmed a little bit.

Yeah, yep, right, I don't make any claims about there being local minimizers here. These don't have to be minimizers; they look like minimizers in 2D, but there could be little wormholes in here that escape to other attractors. What I'm really more interested in is convexity versus non-convexity. A convex function has convex contours, and if you dimensionality-reduce it severely, it'll still have convex contours. So I think one thing you can take away from these dimensionality reductions is that there's substantial non-convexity here that you can visualize. That non-convexity isn't present in these more well-behaved networks, though later on we'll see some examples where it creeps back in. Yeah, yep. I think you need to be careful how you interpret them. There are some other interesting differences that I'll point out later. But what I take away from these is not that there aren't other minimizers, because there are. It's that this landscape is populated with very large convex-like attractors that have very
wide basins, whereas this landscape is populated by a lot of more poorly conditioned, I guess, curvy parts; I don't want to say local minima. I think you need to be careful how you interpret these because, like you said, there are a lot of dimensions, and there probably are many dimensions in which this is flat. But I still think it's pretty undeniable that there are major qualitative differences here between these two minimizers, and we'll see that this has interesting implications too.

Yeah, no, this is on everything. I take two random directions in weight space that perturb all of the weights. Yeah, you take everything. One of the things I like about the random planes is this (later on we're not going to use random planes; we'll use certain kinds of PCA directions that capture the most important directions). People have done experiments where you start somewhere on a loss surface and plot the loss function over time as you descend. You see effectively no non-convexity; it's just a monotonic decrease to the minimizer. The problem comes if you don't select random directions and instead pick special directions. If I stand on a loss surface and plot in a gradient direction, it always goes down, and it's really steep. Things tend to be really well behaved in those gradient directions, so if you make those kinds of plots you don't see the non-convex behavior. There have been 1D visualizations in the literature that fail to capture any of the non-convexity at all. That's because if you cherry-pick the directions too much, you can't see the difference in behavior between these loss surfaces, but it is there. Like I said, you should take this with a grain of salt, because it is a big dimensionality reduction. Later on we'll see some loss surfaces that look quite a bit different when we use PCA directions.
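The random-plane slicing discussed above can be illustrated with a toy example in plain Python. This is a sketch under explicit assumptions:
- `toy_loss` is a hypothetical stand-in for a network's training loss: a quadratic bowl plus a non-convex ripple term, with its global minimizer at the origin.
- The two directions are plain Gaussian vectors; the method in the talk would filter-normalize them first.
- The grid size and span are arbitrary choices.

The snippet also checks the claim that two random directions in a high-dimensional space are roughly orthogonal.

```python
import math
import random


def toy_loss(w):
    """Hypothetical stand-in for a training loss: a bowl plus a non-convex ripple.

    Both terms are non-negative and vanish only at w = 0, so the origin is the
    global minimizer. A real loss would average over the whole training set.
    """
    bowl = sum(x * x for x in w) / len(w)
    ripple = sum(math.sin(5.0 * x) ** 2 for x in w) / len(w)
    return bowl + 0.3 * ripple


def cosine(u, v):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / math.sqrt(sum(a * a for a in u) * sum(b * b for b in v))


def loss_surface(loss, center, d1, d2, steps=21, span=1.0):
    """Raster loss(center + a*d1 + b*d2) over a steps x steps grid on [-span, span]^2."""
    coords = [-span + 2.0 * span * k / (steps - 1) for k in range(steps)]
    grid = []
    for b in coords:
        row = []
        for a in coords:
            w = [c + a * x + b * y for c, x, y in zip(center, d1, d2)]
            row.append(loss(w))
        grid.append(row)
    return coords, grid


rng = random.Random(0)
dim = 1000
center = [0.0] * dim  # the minimizer of toy_loss
d1 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
d2 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
print(f"cosine(d1, d2) = {cosine(d1, d2):.3f}")  # near 0 in high dimension

coords, grid = loss_surface(toy_loss, center, d1, d2)
mid = len(coords) // 2
print(f"loss at center: {grid[mid][mid]:.3f}, at a corner: {grid[0][0]:.3f}")
```

The resulting `grid` is what a contour or surface plot would be drawn from. Each grid cell costs one full loss evaluation, which for a real network means a pass over the training set. That cost is why the talk needed a GPU cluster.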
Yeah, oh, I forget how many this is. Um, it's on the order of ten million. Okay, all right. So the behavior that we've seen in these loss surfaces is a little bit more complicated than chaotic versus non-chaotic. In fact, there's an interesting chaos transition that doesn't take place until you get deep enough. This is the loss surface of a 9-layer VGG-like network, and you'll see it's a pretty well-behaved convex function. We can increase that all the way up to about 20 layers, and it doesn't really change the qualitative behavior of this loss function at all; it stays about the same. In the original VGG paper, the Visual Geometry Group released two networks, VGG-16 and VGG-19. For depths beyond 19, they found a degradation in performance. One of the interesting things is that if we go beyond 20 layers, we get into a sort of phase transition between these nice convex behaviors and chaotic behaviors. Suddenly the loss surface becomes chaotic, and if we keep going it only gets more chaotic, although not dramatically so. This is a 110-layer network, and you'll see there's a very chaotic landscape around a very, very sharp minimizer. So there's some sort of phase transition that occurs when you get deep enough: a spontaneous transition from convex behavior to chaotic behavior. It also seems to correspond to an increase in generalization error.

One of the interesting things in these plots is that there's a contour I like to call the shattering contour. There's this problem called the shattered gradients problem. If you try to train a network that's too deep and isn't well behaved, you'll run into a situation where the gradients that you produce from different mini-batches
aren't correlated. The gradients effectively become random, and once that happens neural networks are basically untrainable; you can't find good minimizers anymore. For CIFAR-10, which is what we're looking at here, a good random initialization gives you a starting loss value of about 2.5. There's a reason for that. With a good initialization, the weights are scaled so that the outputs of each layer look like mean-zero, standard-deviation-one Gaussian noise. If you put that kind of noise into a softmax layer, you get a loss value of about 2.5. When we have this nice convex-like landscape, the 2.5 contour here is well within the well-behaved region. Then, as we move in the more chaotic direction, this chaotic landscape starts to drop down lower and lower. By the time you get to 110 layers, the 2.5 contour is right here, on the boundary between the convex-like attractor and the chaotic landscape. If you go beyond this, the network becomes untrainable: we can no longer find good minimizers, and this visualization method fails. So I suspect that as the landscape becomes more chaotic, there's a partitioning between the well-behaved region and the chaotic area. When the chaotic region drops below this shattering contour, you run into shattered gradients and can no longer train. The initialization methods seem to fail in that case; you end up initializing way out in this chaotic landscape.

Yep, that's a great question. Yes, adding skip connections completely prevents this chaos transition from happening, at least for networks that are fairly deep. We've tried this up to about 300 layers; we haven't gone deeper than that because we've run out of memory on our GPUs. When you add skip connections, you'll see that there isn't a substantial qualitative
difference between these landscapes. If you go from 20 to 56 to 110 layers, qualitatively these are very similar, and so if there is a chaos transition, it's got to be somewhere way out here.

Yeah, I'm sorry, what's the question? I'm not sure how to answer that. I mean, it seems like it is a property of the model. Some neural networks are well behaved and we can train them very deep, say ResNets; some neural networks are not well behaved. I believe pretty strongly that the difference between these networks has to do with the loss landscape, because the optimization, when you train, takes place on the loss landscape. Whether it's a symptom or a cause I don't know. But I think well-designed neural networks seem to be designed, at least implicitly, so that their loss landscapes are populated by very wide convex attractors instead of these chaotic regions.

Yeah, yeah, my intuition is it's the latter; it has to do with the conditioning. In fact, if you look at this minimizer it's really interesting, because this is not such a well-conditioned minimizer. It looks like it has a condition number of about three, but there are theoretical reasons to believe that these dimensionality reductions compress the condition number by a factor of about a thousand; I can talk to you about why. So this is probably a condition number on the order of 10,000. Is that because it's paired with the structure of a convnet for images? I don't know; we haven't tried that. That's a really interesting thing to try, though, and I have no intuition for what you would get. That's an interesting question.

Yeah, okay. So if you're still not convinced that there are qualitative differences between these landscapes, this is the loss landscape of that VGG-like network with 110 layers. You'll see that there's basically a mountain range that passes through this region, and there's a little... so the color
is: red is high, blue is low. There's a little blue lake nestled in this valley, with this big mountain range sweeping around it. Compare that to something like DenseNet. This is a 121-layer DenseNet, and it's so smooth, at least in two dimensions, that it's basically a parabola; there's no visible non-convexity at all. DenseNet has a very sophisticated set of skip connections, many more skip connections than you have in a residual network. So it really seems like these skip connections are the secret sauce that makes these loss landscapes well behaved.

Another thing that we've looked at is how the width of the network affects the loss landscape. Take this 56-layer network again: here's the version with residual skip connections, and here's the version without skip connections. You can relax the landscape without skip connections considerably by making the network wider. If we double the width, you'll see that we go from a fairly chaotic landscape back to a fairly well-behaved landscape, and if we go 4x wider, we relax it even more. Beyond this, the structure of the loss function doesn't change much; in the paper we go to 8x, and it looks very similar to this. I think this has interesting implications. One of my students used to work on compressing neural networks, specifically on weight pruning. One of the things that's well known in the weight-pruning literature concerns really skinny networks with very few filters. If you want one, it's not effective to build a skinny network and then train it, because you get bad minimizers with poor generalization error. So what people do is start with a wide network and train that. Then they prune the weights, so they prune the filters and condense it down to a skinny network, and then you get good generalization error. I think that's interesting, because you'll see here that width really convexifies this landscape and makes it much easier to
find these sorts of wide, well-behaved minimizers. Once you have them, you can just throw out the meaningless weights and get back something high-performance.

Okay, so what does this mean for optimization? If we have time, we'll try to return to this visualization issue a little later. Can we treat neural loss functions like they're convex? I just said that well-behaved loss functions look convex, at least locally. Can we use convex analysis to say useful things about them? The answer is sometimes. I'm going to show two examples of things that we've studied in my lab relating to optimization of neural networks. In one of them, convex analysis (math that really comes from the convex properties of functions) is really informative about how to design good algorithms, and in the other case it's just not.

Okay, a bunch of people have already talked about GANs in this workshop, so I'll give you my quick three-slide overview of GANs. Suppose that you have a probability distribution, like a Gaussian curve, and I want to train a neural network that can represent this distribution. But I'm only going to train it from samples of the distribution: I take some samples from this Gaussian curve, and I want to reconstruct the Gaussian in some sense. Or maybe you have a really complex distribution. I have samples from the distribution of natural images, and I want to learn the distribution of those images. Then I want to produce a neural network that can generate new images, new samples from that distribution. In this case we train on real images taken from a real image distribution, and I want this thing to produce fake images that are synthetic. So we cook up a neural network that takes in a random noise vector, which we call z. It passes z through a generator network that convolves it up into an image, and then we could pass this through
some sort of loss function that measures how image-like it is. So you need some sort of measurement of how good an image is. You could use something like total variation, but that won't produce very good-looking images; it's not a fine enough instrument. What you would like is a measure of image goodness that has human-like characteristics, and so we stick a neural network in there called the discriminator network. The discriminator gives us a rating of how good-looking an image is, how "image-y" it is, and the goal of the generator, when I train its weights, is to fool the discriminator into thinking that the images it produces come from the real distribution: the discriminator should think that the images coming out look really good. But if you train something like this, what ends up happening is that instead of producing good images, the generator just takes advantage of weaknesses in the discriminator. There are various adversarial exploits you could make on this discriminator, and the generator will learn to take advantage of them. To prevent that from happening, we keep updating the discriminator while we're training the generator, and we do that by training the discriminator to tell the difference between real and fake images. So the generator is trying to minimize a loss function that measures how well it fools the discriminator, and the discriminator is fighting back: it's trying to maximize the same loss function by not being fooled and by telling the difference between real and synthetic images. We train these things with an alternating gradient descent process: I want to minimize this loss over the generator weights and maximize it over the discriminator weights. I do that by computing a gradient for the generator weights and marching down that gradient with some step size tau, and then on the next iteration I compute the
gradient of the loss function with respect to the discriminator weights and march up that gradient (that should be a plus on the slide, not a minus) with some step size sigma. But this doesn't always work so well; you get what are called collapse events. These are plots of the discriminator loss and the generator loss for DCGAN, and you'll see that as it trains, these losses evolve in a reasonable way, and then all of a sudden the discriminator loss falls to zero: the discriminator is winning the game, and the generator loss shoots up. This is called a collapse event, and when it happens, the images coming out of the GAN look like this; they're just random noise. What's happening is that the discriminator is dominating the game and the generator can't compete. But we know from convex analysis that we can do better than this. It's known in the convex analysis literature that these kinds of alternating minimization-maximization methods are not numerically stable. When you're just solving a minimization problem, gradient descent gets stuck in a minimizer and can't go any further, and that keeps things stable. But if you're trying to find a saddle point, which is a minimum in one direction and a maximum in another, then you have updates that compete against each other: the generator is trying to minimize and the discriminator is trying to maximize. If the minimization updates are stronger than the maximization updates, you never settle into the saddle point; you just fall off the side of the saddle and get a collapse event. These behaviors are really well studied in the convex optimization literature. Back when dinosaurs ruled the earth, people studied convex optimization problems, and there's a whole literature on saddle-point optimization, so there's the Chambolle-Pock method
from 2010, there's similar work from Esser, Zhang, and Chan, and I even dabbled in this area for a while back in 2014. Here's what that literature says: if you want to stabilize this kind of training, you just have to make a very simple change. We have this descent step and then this ascent step, and I'm going to add in a trivial thing called a prediction step. The prediction step says: after you update your generator weights, look at the step you took from your old generator weights to your new generator weights, and add it back. However far you just went, go that much further. Call that G-hat; that's where I predict the generator weights are going to be in the future, assuming the dynamics of the problem don't change too much. Then, when I update my discriminator, I update it just like before, except that when I evaluate the gradient of the loss function, the generator weights I plug in are the predicted generator weights. So it's basically the same algorithm: you do this simple look-ahead step, and when you compute your discriminator update, you plug in the predicted weights. It's known in the convex optimization literature that this produces a stable algorithm that's attracted to saddle points, whereas without the prediction step it doesn't converge, at least not for problems that aren't strongly convex. So we can look at some experiments. These are DCGAN images: this one is trained with no prediction, and this one is trained with prediction. This uses a GitHub repository we got that had very finely tuned parameters to make this thing converge, and when we look at its loss functions, we see many collapse events; things collapse multiple times, and it just happens that if you stop at a point far enough from a collapse event, you get good results. But if you train this with a prediction method, it's numerically stable and you don't have any
collapse events, and it's not such a brittle thing. We can actually multiply the learning rate by 10, so this is training with 10x the learning rate. In this case the method without prediction immediately collapses, and the method with prediction still works fine. These are the loss curves: the learning rate is so high that the run without prediction immediately explodes, but the run with prediction is still really well behaved, and it actually trains quite quickly. You might think that's not such a big deal: just don't use a big learning rate. But the learning rates you need to train GANs are very finely tuned, and to demonstrate that, we can cut the learning rate by 50% from the fine-tuned value, and when we do that, it also diverges. So in this case, without prediction the method diverges, and with prediction it still converges. If you look at these loss curves, you'll see that DCGAN without prediction is stable for a while, then it has a collapse event, and it can't recover from that. We've also done experiments with momentum tuning. These training algorithms are very sensitive to the momentum parameter, but with prediction we can train using a much wider range of momentum parameters than we can without it. There are a few other test problems. This one is learning a mixture of eight Gaussians arranged in a circle. If you learn this distribution with a prediction method, it converges really quickly and stays stable, but without prediction it oscillates around: it picks up certain modes, then switches to other modes, and it doesn't really stabilize. We can turn the knob up on this and try to fit a hundred Gaussians arranged in a circle. In this case the prediction method can't quite resolve the difference between these Gaussians, but it at least gets a nice
circular distribution that overlaps with the Gaussians we wanted, but without prediction you get something very unstable. I'd actually like to make a video of this, because there are some interesting dynamics: these points wave around and never really seem to converge. It's not a very stable training method. Arthur mentioned yesterday that everyone seems to ask whether you can do the bedrooms. I don't have nice curves to show you for this, but our paper has some tables you can look at; we do the bedrooms. So here are stacked GAN results on the LSUN data set. After the first version of our paper came out, my student said that people on Reddit would think we're uncool if we didn't do LSUN, so we did that. Hopefully Reddit now thinks we're cool. All right, so here's an example of a problem where convex analysis doesn't really tell us very much about how training proceeds. Neural networks are very big. Everyone's seen this chart; it's been shown a few times in this workshop. It shows the size of neural networks versus the top-1 accuracy they achieve on ImageNet, and we see that many of the state-of-the-art networks need tens of giga-operations to process an image, and tens of millions of weights. (Quick time check here.) The question is: how do we put these things on low-power and embedded systems? If you want to put this on a drone or a cell phone, how do you fit these huge networks on there? One solution is to binarize them. I take the weights of my neural network and force them all to be +1 or -1. Now every weight is just one bit, and there are no multiplications in the network, because multiplying by a weight is just a sign flip. Going from 32-bit arithmetic down to one bit is a 32x compression, although in practice you could probably get away with fewer than 32 bits without
using any tricks. Binarization reduces your storage costs and strips out all the multiplications, and these are really effective types of neural networks for things like ASICs, FPGAs, and other low-power devices. So how do we train quantized nets? You train a standard neural network by computing the gradient of the loss and marching down it with some learning rate alpha. If you want to train, say, an integer-valued neural net, so all the weights are integers, you need to make sure the weights remain integers on every iteration. So I have some integer weights, I compute a gradient update and multiply it by alpha, and when I add it to the weights I don't want to change the fact that they're integers, so I put a quantization operator here that rounds the update to an integer. But I need a smart rounding rule. I don't want to just round to the nearest integer, because when alpha, the learning rate, gets small, these updates are all going to be smaller than 0.5 in magnitude, and they'll just get chopped off, rounded to zero. So we use a smarter method called stochastic rounding. This is a well-known method in the hardware literature; it was studied at ICML 2015. The idea is that if your update is 0.3, you round it down with a 70% chance and up with a 30% chance, so the expected value of the integer you get back is still 0.3. That's a much more reasonable way to train, and the advantage is that these training methods require no floating-point weights: you can train with no floating-point arithmetic at all. If you want to train better nets, you have to use something like the BinaryConnect algorithm, which rounds the weights before it computes the gradient of the loss. In this algorithm the weights are floating point, but before I compute the gradients, I snap them to the nearest point on the integer grid and compute the
gradient at that location. This method is very popular and a lot of people use it, but the disadvantage is that it still requires floating-point computation: even though you're training a quantized net, you have to train using high-precision floating point. And what we find... these are training curves, and these dotted lines show generalization error over time for... okay. Well, I'll just mention a few things here. Deterministic rounding is a disaster, and stochastic rounding is better. If we use the BinaryConnect algorithm, which uses floating-point arithmetic, you can do really well, and this red line is full precision: with full-precision networks and full-precision training, you get about 0.8% better than you can with a quantized network, but not much better. Okay, so I want to study the behavior of these algorithms, and I can do that using convexity assumptions. If I study their behavior under convexity assumptions, I find that I can put a bound on the error of these algorithms. This is the optimality gap, and it decays over time: there's a decay term with an almost constant numerator and a T in the denominator, where T is the number of iterations, so this term decays. But there's also a noise-floor term that doesn't decay, and it scales with the parameter Delta, the discretization width, so if you use a really fine discretization, you get a lower noise floor. That's for stochastic... sorry, I had this backwards. This is stochastic rounding, and for the BinaryConnect algorithm, which uses floating point, we can prove exactly the same result: you converge with linear speedup to a noise floor, and that's pretty much it. So we get the same results under convexity assumptions for both algorithms, and that doesn't explain why one is so much better than the other. I mean,
to understand why one is much better than the other, we have to look at the dynamics of these algorithms on nonconvex problems. What we can do is run SGD on a nonconvex problem, where the iterates bounce around, and plot their equilibrium distribution, which is what you get if you run it for a really long time. (I can do it in three minutes; okay, sorry.) So we plot the equilibrium distribution, and if we drop the learning rate, we see the distribution start to concentrate on these minimizers. If we drop the learning rate even more, it concentrates further, and if you drop it more still, it becomes very concentrated on a global minimizer. But if we use a fully quantized training algorithm, one that doesn't use floating-point arithmetic, and we drop the learning rate, it starts to concentrate; we drop it more and it concentrates a little more; we drop it again and nothing really changes. So we don't get these concentration behaviors. What we can actually prove is that algorithms that keep floating-point weights, like BinaryConnect, have an exploration-exploitation trade-off: when the learning rate is large, they explore the landscape, and as the learning rate decays, they aggressively settle into local minimizers and concentrate on them. But stochastic rounding, which is a fully quantized method, doesn't have these concentration properties. I'll skip ahead a few slides, but in a nutshell, the lack of concentration explains the difference between these algorithms. BinaryConnect requires floating point to train a quantized net, and it gets great generalization error, but if you don't have those concentration properties, those simulated-annealing-type global optimization properties, you get really bad results. I have some results I won't have time to tell you about, but we can do
some more visualizations of the descent paths that SGD takes through these loss functions. This plots the contours of a loss function, and this is the descent path of stochastic gradient descent. We find that algorithms with good generalization error, like SGD with a small batch size here, start with an exploration path that runs parallel to the contours; it's not strongly attracted to this minimizer. As soon as you drop the learning rate, it becomes strongly attracted and falls into the center of this local minimizer. But if you train with a large mini-batch, which generally gives poor generalization error, you don't get this kind of trade-off: the large-batch algorithm is immediately attracted down into this minimizer, and you don't see the 90-degree kink, right here, where the exploration-exploitation trade-off happens. Okay, so to wrap up: neural loss functions are sometimes really well behaved. Good, well-designed neural networks have really nice loss functions, with landscapes populated by large, flat, convex-like minimizers. There are still things we don't understand; I don't think we understand the transitions to chaos that happen in deep networks, or the effect of skip connections. And I think there's an interesting question of whether sharp versus flat is even the right question; maybe we should look at things like chaotic versus non-chaotic. I'll leave you with that. Thanks. [Applause]
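The wide-then-prune recipe mentioned in the talk (train a wide network, then prune its filters down to a skinny one) can be sketched as below. This is a minimal sketch under assumptions the talk does not state: filters are ranked by the L1 norm of their weights, which is a common criterion but only one possible choice, and a "filter" is modeled as a flat list of floats. The names `prune_filters` and `keep_fraction` are hypothetical.

```python
def l1_norm(filt):
    """Sum of absolute weights in one filter (the assumed pruning criterion)."""
    return sum(abs(w) for w in filt)


def prune_filters(filters, keep_fraction):
    """Keep the filters with the largest L1 norms, preserving their original order.

    filters: list of filters, each a flat list of float weights (hypothetical layout).
    keep_fraction: fraction of filters to keep, in (0, 1].
    Returns (kept_indices, kept_filters).
    """
    if not 0.0 < keep_fraction <= 1.0:
        raise ValueError("keep_fraction must be in (0, 1]")
    n_keep = max(1, round(len(filters) * keep_fraction))
    # Rank filter indices by L1 norm, largest first, and keep the top n_keep.
    ranked = sorted(range(len(filters)), key=lambda i: l1_norm(filters[i]), reverse=True)
    kept = sorted(ranked[:n_keep])
    return kept, [filters[i] for i in kept]


if __name__ == "__main__":
    wide_layer = [
        [0.9, -0.8, 0.7],     # large weights
        [0.01, 0.02, -0.01],  # nearly dead filter
        [-0.5, 0.4, 0.6],
        [0.03, -0.02, 0.0],   # nearly dead filter
    ]
    idx, skinny_layer = prune_filters(wide_layer, keep_fraction=0.5)
    print(idx)  # [0, 2]
```

In a real convolutional network, removing a filter also removes the matching input channel of the next layer, and the pruned network is usually fine-tuned afterwards; both steps are omitted from this sketch.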
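The alternating descent/ascent update and the prediction (look-ahead) step described in the talk can be tried on a toy saddle-point problem. This is a minimal sketch, not the authors' code: it assumes the bilinear loss L(x, y) = x * y, with a scalar x standing in for the generator weights and a scalar y for the discriminator weights, so the saddle point is (0, 0). The step sizes tau and sigma follow the talk's notation; the function name `alternating_gda` is hypothetical.

```python
import math


def alternating_gda(steps, tau, sigma, predict, x0=1.0, y0=1.0):
    """Alternating gradient descent (on x) / ascent (on y) for L(x, y) = x * y.

    Here dL/dx = y and dL/dy = x. With predict=True, the ascent step evaluates
    its gradient at the predicted point x_hat = x_new + (x_new - x_old).
    """
    x, y = x0, y0
    for _ in range(steps):
        x_old = x
        x = x - tau * y                                  # descent step for the "generator"
        x_hat = (x + (x - x_old)) if predict else x      # prediction (look-ahead) step
        y = y + sigma * x_hat                            # ascent step for the "discriminator"
    return x, y


if __name__ == "__main__":
    for predict in (False, True):
        x, y = alternating_gda(steps=200, tau=0.5, sigma=0.5, predict=predict)
        print(f"predict={predict}: distance to saddle = {math.hypot(x, y):.2e}")
```

Without the prediction step the iterates circle the saddle point forever (with tau = sigma = 0.5, the alternating updates conserve x² - 0.5xy + y², so the distance to the saddle never shrinks); with it they spiral inward. The toy problem only illustrates the stability claim from the talk; it does not reproduce the DCGAN experiments.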
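The rounding rules and the two quantized training updates described in the talk can be sketched as follows. This is a minimal sketch under stated assumptions: weights live on a grid with spacing `delta` (delta = 1 gives the integer grid from the talk), `grad` is any gradient function you supply, and the names `round_to_grid`, `stochastic_round`, `sr_step`, and `binary_connect_step` are hypothetical. `sr_step` is fully quantized (the weights never leave the grid), while `binary_connect_step` keeps floating-point weights and only snaps them to the grid to evaluate the gradient.

```python
import math
import random


def round_to_grid(v, delta=1.0):
    """Deterministic rounding to the nearest multiple of delta (Python's round: ties to even)."""
    return delta * round(v / delta)


def stochastic_round(v, delta=1.0, rng=random):
    """Round v to a neighbouring multiple of delta so that the result is unbiased.

    v is rounded up with probability equal to its fractional distance past the
    lower grid point, so the expected value of the result equals v.
    """
    lower = delta * math.floor(v / delta)
    p_up = (v - lower) / delta
    return lower + delta if rng.random() < p_up else lower


def sr_step(w, grad, alpha, delta=1.0, rng=random):
    """Fully quantized SGD step: w stays on the grid; only the update is rounded."""
    return w + stochastic_round(-alpha * grad(w), delta, rng)


def binary_connect_step(w_float, grad, alpha, delta=1.0):
    """BinaryConnect-style step: keep float weights, evaluate the gradient at the
    rounded weights, and apply that gradient to the float weights."""
    return w_float - alpha * grad(round_to_grid(w_float, delta))


if __name__ == "__main__":
    rng = random.Random(0)
    # A 0.3 update is always chopped to zero by nearest rounding...
    print(round_to_grid(0.3))  # 0.0
    # ...but stochastic rounding preserves it on average.
    samples = [stochastic_round(0.3, rng=rng) for _ in range(100_000)]
    print(sum(samples) / len(samples))  # close to 0.3
```

The design difference the talk emphasizes shows up directly here: `sr_step` never needs a floating-point copy of the weights, whereas `binary_connect_step` accumulates small updates in `w_float` even when they are far smaller than `delta`.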
Info
Channel: Institute for Pure & Applied Mathematics (IPAM)
Views: 10,430
Rating: 4.9572191 out of 5
Keywords: ipam, ucla, math, mathematics, deep learning, machine learning, neural nets, tom goldstein
Id: 78vq6kgsTa8
Length: 50min 25sec (3025 seconds)
Published: Fri Feb 16 2018