Grokking: Generalization beyond Overfitting on small algorithmic datasets (Paper Explained)

Video Statistics and Information

Captions
Hi there! Today we'll look at "Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets" by Alethea Power, Yuri Burda, Harri Edwards, Igor Babuschkin and Vedant Misra of OpenAI. On a high level, this paper presents a phenomenon the researchers call grokking, where a neural network suddenly generalizes long after the point of overfitting on a dataset. You train the network, it completely overfits the dataset — training loss is down, training accuracy is 100% — but it doesn't generalize at all to the validation set. Then, as you keep training, at some point it just snaps into generalizing on the datasets they study, up to 100% accuracy on the validation set. This is extremely interesting. As you can see, the paper was presented at a workshop at ICLR 2021, which means it is still work in progress: there are a lot of unclear things about this phenomenon, and as I understand it, it's a phenomenological paper that simply says, look, here is something interesting that we found — and I think it's pretty cool. So we'll dive into the paper, look at this phenomenon, and see how they dig into it a little and try to come up with some explanation.

The basic premise of grokking is the graph you see on the left. It's a bit pixelated, but I hope you can still see what's happening. The red curve is the training accuracy, and on the x-axis you have the number of optimization steps — on a log scale, which is important to keep in mind. The training accuracy naturally shoots up to 100% after a few steps (we'll get to what these datasets are in a second), so the network can in fact fit the training data extremely well; it just overfits. The validation accuracy, however, stays down: there is a little bump early on — if we should even regard it as a bump — but then it goes back down and stays there. And then, after orders of magnitude more steps — 10^2, 10^3, 10^4, 10^5 — it shoots up, and the network starts to generalize as well. This is very interesting, because it essentially means you keep training for a long time, and when all hope is lost, the network at some point still generalizes. Now, why is this happening? As I understand it, the network does not usually drop back out of generalization, though I haven't actually seen that investigated — say, if they ran for 10 to the who-knows-how-many steps — but it seems like once the network generalizes, it does not fall out of that again. So the question is: how does this happen, why is it so sudden, and what makes it work?

To answer that, it helps to understand a very related — probably connected — phenomenon in deep learning called the double descent phenomenon. The double descent graph looks somewhat similar, but the premise is that on the x-axis you have the number of parameters of a neural network, and on the y-axis you have, let's say, the loss (or the accuracy — I'm not sure, but most of these double descent plots are actually loss plots).
So if you consider the training loss: as you increase the number of parameters in your neural network, you fit the training data better and better, so you get a curve that comes down and then just stays at zero — zero training loss as you keep adding parameters. Every point on this line is a neural network with a given number of parameters that has been optimized to convergence. That's important to remember: the graph on the left shows what happens during optimization, while the graph on the right shows many different networks, all of which have been trained to convergence.

Now look at the validation loss. At first it might come down together with the training loss, but then, in the classic fashion of machine learning, as the number of parameters goes up you start to overfit and the validation loss goes up again, because you start memorizing the training dataset. At the point where the number of parameters roughly equals the number of training data points — let's call it n — the validation loss is really bad, because you're just remembering the training data. However, if you increase the parameters beyond that point, if you scale up your neural network even more, the validation loss comes down again, and it actually ends up lower than anywhere in the under-parameterized regime. So there is a point beyond overfitting, where you have more parameters than data points, and interestingly, neural networks there can achieve generalization — in fact better generalization — with over-parameterization than comparable under-parameterized models, which flies in the face of classical statistics. But we know this phenomenon exists. So we already knew that the training loss can be perfect and we can still get generalization.

The grokking phenomenon, I'm going to guess, relates to this as follows: I'd guess the authors of the double descent work simply ran training to convergence for some number of steps and then looked at the validation loss — so they would have stopped somewhere between 10^3 and 10^4 steps in the plot on the left. This research simply asks what happens if we let training run for a really long time, and then the validation accuracy shoots up as well. And it seems like this works under a lot of conditions.
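To make the double descent picture a bit more concrete, here is a minimal, self-contained sketch — my own toy illustration, not anything from the paper — that fits a small 1-D regression problem with random ReLU features and the minimum-norm least-squares solution, sweeping the number of features (i.e. fitted parameters). With a single random seed the curve is noisy, but the spike near the interpolation threshold (features roughly equal to training points) and the decrease far beyond it are usually visible.

# Minimal double-descent illustration (my own toy, not from the paper): fit a
# small 1-D regression problem with random ReLU features and the minimum-norm
# least-squares solution, sweeping the number of features (parameters).
import numpy as np

rng = np.random.default_rng(0)
n_train = 20
x_train = rng.uniform(-1, 1, (n_train, 1))
y_train = np.sin(3 * x_train[:, 0]) + 0.1 * rng.normal(size=n_train)
x_test = np.linspace(-1, 1, 200).reshape(-1, 1)
y_test = np.sin(3 * x_test[:, 0])

def relu_features(x, w, b):
    # Fixed random first layer; only the output weights are fit ("random features").
    return np.maximum(x @ w + b, 0.0)

for n_features in [2, 5, 10, 20, 40, 200, 2000]:
    w = rng.normal(size=(1, n_features))
    b = rng.normal(size=n_features)
    A = relu_features(x_train, w, b)
    # lstsq returns the minimum-norm solution when there are more features than
    # data points, which plays the role of the implicit smoothness bias.
    coef, *_ = np.linalg.lstsq(A, y_train, rcond=None)
    test_mse = np.mean((relu_features(x_test, w, b) @ coef - y_test) ** 2)
    print(f"{n_features:5d} features  test MSE {test_mse:.3f}")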
Now it's worth looking at what kind of datasets we're interested in here. The datasets in this paper are synthetic: binary operation tables. The datasets they consider are of the form a ∘ b = c, where a, b and c are discrete symbols with no internal structure, and ∘ is a binary operation. Examples of binary operations include addition, composition of permutations, bivariate polynomials, and many more — they list some examples, like addition and multiplication, but also more complicated things like a bivariate polynomial modulo a prime, division modulo a prime, and so on.

The way you create such a dataset is you construct a table: the rows and columns are the symbols, and you define the binary operation by simply filling in that table. So if the operation were addition and a and b are numbers, then a + b = c: if a is 1 and b is 2, then c is 3, and so on. But you can define this as many different things. A lot of the experiments in this paper use the group S5, the group of all permutations of five elements, which has 120 elements, so the table is 120 by 120, and the operation is composition of permutations: every permutation of five elements composed with another one gives you yet another permutation of five elements. So you construct this table, and then you simply cross out a few cells. You train the network on the cells that remain, and it has to predict the cells you crossed out. This way you can measure exactly how good the network is: there is effectively no noise in the data, everything is very well defined. A human would go about this with a logical mind, trying to figure out the rule. A neural network can simply memorize the training data, but then it will not generalize to the hidden cells, because it cannot memorize those. So if a neural network generalizes here, that also kind of means it must have somehow learned the rule, and that is pretty interesting.
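As a concrete illustration of how such a dataset could be put together — a minimal sketch under my own assumptions (the names p, op and train_fraction are mine), not the authors' code — here is addition modulo a prime: build the full operation table, keep a random fraction of the cells as training examples, and hold out the rest as the validation set.

# Sketch of an algorithmic dataset in the spirit of the paper (not the authors' code):
# build the full table for a binary operation, keep a fraction of the cells for
# training, and hold out the rest for validation.
import itertools
import random

p = 97  # number of symbols; the operation here is addition modulo p

def op(a, b):
    # For the S5 experiments, the symbols would instead be the 120 permutations
    # of five elements and op would be composition of permutations.
    return (a + b) % p

# Every cell of the p-by-p table is one example: the pair (a, b) and the answer c.
table = [((a, b), op(a, b)) for a, b in itertools.product(range(p), repeat=2)]

random.seed(0)
random.shuffle(table)
train_fraction = 0.3  # the "training data fraction" the paper varies
split = int(train_fraction * len(table))
train_examples, val_examples = table[:split], table[split:]

print(len(train_examples), "training cells,", len(val_examples), "held-out cells")
# A model that only memorizes the training cells can do no better than chance on the
# held-out cells; generalizing requires recovering the rule behind the table.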
There are three quantities to keep in mind. The first is the operation itself, because some operations are more or less complicated for these networks to learn, just from the complexity of the operation. The second is the dataset size, i.e. the size of the binary operation table — in this case 120 by 120. And the third is how much is left out: the training data fraction, the fraction of the table that is filled in for the network to learn from. All three of these play a crucial role in this grokking phenomenon and in when and how it appears.

For example, here they have trained neural networks on this S5 group — the permutations of five elements — until they reach generalization. They simply run the training and measure how long it takes a network to reach 99% validation accuracy or higher; in the plot, the answer is something between 10^5 and 10^6 steps. And they measure this as a function of the training data fraction, i.e. how much of the table is filled in. You can see pretty clearly: if you only give it about 20% of the table, there are even some runs that do not generalize within this number of steps. Would they generalize if you optimized even longer? Who knows, honestly. But as soon as you give it something like 30% of the table, the runs do in general generalize, though they take something like 10^5 steps to do so. And as you increase the training data fraction, this snap to generalization happens faster and faster. As I understand it, the memorization itself always happens after a fairly similar number of steps — the network fits the training data quickly — but then at some later point it just snaps and completely generalizes to the validation set. So that's one insight: the more training data there is, the sooner the grokking happens.

The other thing they try to figure out is which parts of the optimization algorithm make this grokking phenomenon happen, and they find that weight decay is one of the big drivers. They try a lot of different things — full batch versus minibatch, with and without dropout, modulating the learning rate, and so on — but weight decay seems to be one of the biggest contributors to how fast these networks generalize: the network generalizes much sooner with weight decay turned up than without. They also observe that if the binary operation is symmetric, the grokking happens much faster than for non-symmetric operations. This might just be a property of these architectures: with something like a transformer, if the operation is symmetric, essentially every data point is two data points in disguise, or there is only half as much structure to learn — interpret it however you like, but I think this is not as important as the weight decay.
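Here is a hedged sketch of how one might probe that weight decay observation on a toy scale. This is my own minimal setup with a small MLP on modular addition, not the paper's transformer or hyperparameters, so whether and when the validation accuracy actually takes off will depend heavily on those choices.

# Toy probe of the weight decay observation (my own setup, not the paper's):
# train the same small model on a masked modular-addition table, once without
# and once with weight decay, and compare when validation accuracy takes off.
import torch
import torch.nn as nn

p, train_fraction = 97, 0.3
torch.manual_seed(0)
pairs = torch.cartesian_prod(torch.arange(p), torch.arange(p))   # all (a, b) cells
labels = (pairs[:, 0] + pairs[:, 1]) % p                         # c = a + b mod p
perm = torch.randperm(len(pairs))
split = int(train_fraction * len(pairs))
train_idx, val_idx = perm[:split], perm[split:]

def make_model():
    # Embed the two operands, concatenate, and classify the result symbol.
    return nn.Sequential(nn.Embedding(p, 64), nn.Flatten(),
                         nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, p))

def accuracy(model, idx):
    with torch.no_grad():
        return (model(pairs[idx]).argmax(-1) == labels[idx]).float().mean().item()

loss_fn = nn.CrossEntropyLoss()
for wd in (0.0, 1.0):  # weight decay off vs. strongly on
    model = make_model()
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=wd)
    for step in range(1, 20001):
        batch = train_idx[torch.randint(len(train_idx), (256,))]
        loss = loss_fn(model(pairs[batch]), labels[batch])
        opt.zero_grad()
        loss.backward()
        opt.step()
        if step in (100, 1000, 5000, 10000, 20000):  # roughly log-spaced checkpoints
            print(f"wd={wd}  step={step:6d}  "
                  f"train acc={accuracy(model, train_idx):.2f}  "
                  f"val acc={accuracy(model, val_idx):.2f}")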
And why do I highlight the weight decay result? Because down here they analyze a network that has learned to generalize like this. On the right you see a t-SNE projection of the output layer weights from a network trained on modular addition — x + y modulo 8, I think. The lines show the result of adding 8 to each element, and the colors show the residue of each element modulo 8. So in the t-SNE projection you can see structures — the lines are obviously drawn by the authors — where, if you go along a line, you are always adding 8, adding 8, adding 8. There are structures where the rule for generating the data is clearly present, not just in the data, but in the network's weights. This gives you a strong indication that the network has not simply memorized the data somehow, but has in fact discovered the rule behind the data. And we never incentivized the network to learn these rules — that's the wild point. There are architectures where you specifically tell the network: look, there is a rule behind this, I want you to figure out the rule — you can do symbolic regression, or have the model build an internal graph and reason over it. Not here: we just train plain neural networks, and it turns out these networks learn the rules.
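The kind of inspection shown in that figure could be reproduced roughly like this — a hedged sketch in which symbol_weights is just a random placeholder standing in for whatever per-symbol weight matrix your trained model actually has; the structure only appears if the model has in fact grokked the operation.

# Sketch of the inspection in the paper's figure: take the learned per-symbol
# weight vectors of a trained model, project them with t-SNE, and color each
# symbol by its residue. symbol_weights below is a random placeholder; replace
# it with your model's output layer or token embedding matrix (one row per symbol).
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

n_symbols, mod = 96, 8
symbol_weights = np.random.randn(n_symbols, 64)  # placeholder for trained weights

coords = TSNE(n_components=2, perplexity=5, init="pca",
              random_state=0).fit_transform(symbol_weights)
residues = np.arange(n_symbols) % mod  # color each symbol by its residue mod 8

plt.scatter(coords[:, 0], coords[:, 1], c=residues, cmap="tab10")
for i, (x, y) in enumerate(coords):
    plt.annotate(str(i), (x, y), fontsize=6)
plt.title("t-SNE of per-symbol weight vectors, colored by residue mod 8")
plt.show()
# In a network that has grokked modular arithmetic, symbols that differ by 8 tend
# to form visible structure; with the random placeholder above there is none.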
So why do I relate this to the double descent phenomenon? I've heard the authors of the double descent papers speak about their hypothesis for why it happens — and this is a bit mixed with my own hypothesis — and they mention, for example, weight decay as one possible explanation. Say I have a bunch of data points and I want to do regression on them. If I just do linear regression, I get one line; it's fairly flat and fairly robust, because the model is so simple. If I start to add parameters, I may get to a point with a good number of parameters — some moderate polynomial — that is still fairly robust; you can see how it might generalize to new data. That corresponds to the region where the validation loss goes down together with the training loss. But when I keep adding parameters, classically I start overfitting: the curve passes through the training points, but for points in between, the error just goes up — it does not generalize. That corresponds to the point where the model just starts to interpolate the training data. Then what happens if I go on, if I use even higher-order polynomials, or even bigger neural networks? At that point — at least, so these authors argue — you get a curve that has a lot of parameters but uses them to interpolate the training data smoothly. This curve is quite complicated in terms of the numbers needed to describe it, but it exploits its freedom: it can be whatever it wants as long as it interpolates the training data, and it chooses to be smooth, because of a combination of SGD training and weight decay. The weight decay prevents any of the coefficients from getting too big, and therefore prevents a wildly out-of-whack curve; in effect, weight decay smooths the curve, and that makes the model generalize really well, because the smoothness means data points that lie in between the training points are still represented reasonably — in this picture, even better than by the moderately-sized model.

So the double descent authors argue that weight decay might be an important contributor to why over-parameterized networks generalize, and it's interesting that the authors of the grokking paper find the same thing: with weight decay, the grokking appears to happen much faster. I don't know exactly how they define grokking; I'm just going to call it grokking whenever the validation accuracy suddenly snaps from near zero to near 100% on these datasets. Again, these are algorithmic datasets, so we don't know exactly what's happening. I think they also do experiments where they add noise to some of the data, and I think they find that with noise it's much more difficult — I'm not sure though, maybe I'm confusing papers.

But here is what might be happening. By imposing this smoothness and the over-parameterization, we are biasing these networks towards simple solutions. If I have very few training data points — if most of the cells in the table are blacked out — the simplest solution is simply to memorize the training data. But as I get more and more training data points, which carry more and more information about a potential underlying rule, it becomes simpler to learn the underlying rule than to memorize the training data: memorizing becomes harder than learning the rule. So what might be happening is this: as I train — and the training always happens on the same data; you simply sample the same examples over and over and train on them — I kind of jump around in the optimization procedure; you can see there are some bumps in the training accuracy. In the loss landscape there may be many local minima in which the training data is memorized perfectly, and the optimizer jumps around between them. One of them also fits the training data perfectly, but its solution is so much simpler that the optimizer stays there. Actually, that's not a good way of visualizing it. Better: here are the minima of the loss on the training data; on top of that there is another loss, for example the weight decay loss. The weight decay loss is roughly similar for all of these minima, but for one of them it is much lower, because that solution is so much simpler. So you jump around between the minima until you hit that one, where the loss that comes on top is so much lower that you stay — it's like, wow, I found such an easy solution, I'm not leaving again.

Now the big question is, of course: how and why do something like SGD plus weight decay plus potential
other drivers of smoothness in these models correspond to simplicity of solutions? Because looking for simple solutions is something we humans have built in: when we ask what the rule behind something is, we are essentially assuming there is a simple rule and trying to find it, because a simple explanation would make our lives much easier. The interesting part is that weight decay, or something similar happening in these neural networks, is essentially doing the same thing, even though we never tell it to. Understanding this, I think, is going to be quite an important task for the near future. And maybe weight decay is not exactly the right framing; maybe there is some other constraint we could impose that encourages simple solutions — in the sense of simplicity we actually care about — even more. And once we have that, think of the age-old argument about whether these things actually understand anything. Well, in this case, I'm sorry, but if the network has found this solution, with the rule essentially built into its weights, you can say the network has in fact learned the rule behind these binary operations. So who are we to say these networks don't understand anything at that point? It also gives us the opportunity to train such networks and then, from the structure of their latent spaces, parse out rules we don't yet know: let the networks fit the data, and then read out the underlying physical laws, or social phenomena, from what they have learned.

Oh yes, there is an appendix where they list the binary operations they tried, the models, and the optimization. They use a transformer with two layers and four attention heads, so it's not a big model, and the datasets aren't super complicated either, but it's pretty cool to see this phenomenon.
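For scale, a model in that size range might look roughly like the sketch below — my own guess at a comparable architecture, since the paper's exact tokenization and decoder setup may differ from this encoder-style toy.

# Rough sketch of a model in the size range the appendix describes (2 layers,
# 4 attention heads); the paper's exact architecture and tokenization may differ.
import torch
import torch.nn as nn

class TinyBinaryOpTransformer(nn.Module):
    def __init__(self, n_symbols, d_model=128, n_heads=4, n_layers=2):
        super().__init__()
        # The input sequence here is just the two operands [a, b]; the model
        # predicts the result symbol c.
        self.embed = nn.Embedding(n_symbols, d_model)
        self.pos = nn.Parameter(torch.zeros(2, d_model))
        layer = nn.TransformerEncoderLayer(d_model, n_heads,
                                           dim_feedforward=512, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, n_layers)
        self.head = nn.Linear(d_model, n_symbols)

    def forward(self, ab):               # ab: (batch, 2) integer operand tokens
        h = self.embed(ab) + self.pos    # (batch, 2, d_model)
        h = self.encoder(h)              # two self-attention layers, four heads each
        return self.head(h[:, -1])       # read out the prediction for c

model = TinyBinaryOpTransformer(n_symbols=120)  # e.g. the 120 elements of S5
logits = model(torch.tensor([[3, 17], [42, 99]]))
print(logits.shape)  # torch.Size([2, 120])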
Again, with real-world data, bigger networks, and noisy data, this is not going to happen as drastically. They also say that as you increase the size of the dataset, the phenomenon becomes harder and harder to see: if the entire table is bigger, the grokking is tougher to observe. And here is the experiment I mentioned with outliers, i.e. noisy data points. The axis shows the fraction of correctly labeled data points, and as you increase it, the grokking happens more often, or reaches a better validation accuracy. Down here, with too many outliers, either the validation accuracy just stays at zero or it only turns up quite late. Finally, here is an example of one of these binary operation tables that is a bit larger — I don't know if it's one of the 120-sized ones — this is the kind of thing that would be presented to the network, and the authors invite the reader to guess which operation is represented. Have fun, dear reader.

All right, that was it from me for the grokking paper. As I said, this seems like work in progress, but I think it's pretty cool work in progress; it raises a lot of questions. I wonder how people found this — did they just forget to turn off their computer, come back in the morning and go, whoopsie-daisy, it generalized? Then again, if you build these kinds of datasets, you probably have something like this in mind already. In any case, that was it for me. Tell me what you think is going on in these neural networks — or is there some super easy Occam's razor explanation that I'm missing? I don't know. Tell me what you think, and I'll see you next time. Bye.
Info
Channel: Yannic Kilcher
Views: 9,820
Rating: 4.9935274 out of 5
Keywords: deep learning, machine learning, arxiv, explained, neural networks, ai, artificial intelligence, paper, grokking, openai, double descent, belkin, overfitting, bias variance, steps, training, binary tables, binary operations, binary operation, multiplication table, algorithmic datasets, groups, s5 group, deep learning algorithmic, deep learning generalization, generalization research, why do neural networks generalize
Id: dND-7llwrpw
Length: 29min 47sec (1787 seconds)
Published: Wed Oct 06 2021