PonderNet: Learning to Ponder (Machine Learning Research Paper Explained)

Video Statistics and Information

Captions
Hello there! Today we'll look at "PonderNet: Learning to Ponder" by Andrea Banino, Jan Balaguer and Charles Blundell. On a high level, this paper introduces a recurrent architecture, or a principle of recurrent computation for deep networks, that essentially says: the network computes its output recurrently, and at each step it can decide to stop because it is satisfied with the answer it has. The idea is that on a complex task you can compute for many steps, because it requires many steps of thinking, and then give the output, while on an easy task the network can decide to output right away because it has already computed the solution. This decision is made on a per-sample basis, so for each sample the network can decide when it's time to give the final output. This is not necessarily a paper that just makes something bigger and then pushes state of the art on some benchmark, and that's why it piqued my interest: it tries to rephrase a little bit how we think about the connection between deep learning and classic algorithms. Essentially, this is a dynamic if-condition in the algorithm that decides when it's time to stop. And I appreciate that not everything has to be pushing state of the art; this is simply a cool method to do something that's relatively new. Of course, things like this have been done before, and the paper discusses at length how it differs from other papers that do similar things. And it does push state of the art, just not on benchmarks you might be super duper familiar with. It's a cool paper, it's a short paper, the idea is pretty simple, and it appears to work, so that's exciting stuff. We're going to dive into this paper and have a look at what's new in this particular model and how it works. As always, if you have feedback, leave a comment, and subscribe, I'd be happy for that. And yeah, thanks for being
here. Okay, so in the abstract they say that in a standard neural network the amount of computation used grows with the size of the inputs, but not with the complexity of the problem being learned. Which is true, right? In a standard neural network you have a forward pass. Take a fully connected network: you have your input, then you go layer, layer, layer, layer, and then you have your output. This computation is always the same no matter the input. The same goes for a recurrent neural network: you have an input at the beginning that goes into a layer, then another input that goes into the same layer, then the next input that goes into the same layer, so even a recurrent neural network usually just does the same forward pass. It's a little bit different if you have something like a language model that can at some point emit a stop token or an end-of-sentence token, at which point the computation essentially stops, but that's a different thing from what we consider here. Here we consider a neural network that has to find the answer to a particular problem. We're going to see the problems further down, but one problem they present is the parity problem. In the parity problem you get a string of zeros and ones (I think there are also negative ones in there, but I think they're just there as a distraction), and the answer you're looking for is the parity of the string as a whole: is the number of ones in this string odd or even? This requires, let's say, an integrated view of computation; it's essentially a classic algorithm that you have to perform over the string, and neural networks, as good as they are in computer vision and speech recognition, have trouble with simple algorithmic tasks like this. So the idea of this paper is that it doesn't make sense to apply a neural network that always does the same amount of
compute. It doesn't make sense to just shove the sequence in like this, because if there is just a single one in the string and I see it right away, I can give the answer right away. However, if it's a long string with a bunch of ones, I might need to think about the problem for a while, and thus adapt the number of computation steps I do in my head. If I look at this string, I might first connect these two, and that's two; then I might connect these two, that's two again; then I might connect these two, that's four; there's nothing here, there's nothing here, okay, four. So that's one, two, three steps of computation. That's the rough idea, whereas if the string were shorter and more regular, I might need less computation. So they say: to overcome this limitation, we introduce PonderNet, a new algorithm that learns to adapt the amount of computation based on the complexity of the problem at hand. PonderNet learns end to end the number of computational steps to achieve an effective compromise between training prediction accuracy, computational cost and generalization. We are going to see how they do this. Their experimental tasks in this paper are constructed tasks where people know you need this dynamic computation; they're not going to compete on something like ImageNet. The majority of the paper is about contrasting their model with ACT, Adaptive Computation Time. There have been previous attempts at dynamic computation time, but it turns out they're kind of finicky, and this PonderNet model has a bunch of advantages. They say they present PonderNet, which builds on the previous ideas. It's fully differentiable, which allows for low-variance gradient estimates, unlike REINFORCE. A couple
of previous attempts have used reinforcement learning: just learn the number of steps, or when to stop, using reinforcement learning, and that, as you might know, is very, very noisy. It has unbiased gradient estimates, which is also unlike other models in the past. And they say this has consequences in all aspects of the model: in PonderNet, the halting node predicts the probability of halting conditional on not having halted before. This kind of seems obvious, but apparently no one has done this so far. So what do we need for the PonderNet architecture? They describe it down here; essentially the architecture is an inline formula, but that's the architecture. You need an input, which is x, and x is transformed into a hidden state, let's say the hidden state at step one (you could also reformulate this as just starting from a hidden state). The hidden state goes into s, the so-called step function, and that's the recurrent function right here. You can put anything you want into this step function: you can put a CNN inside, you can treat it as an LSTM. Since we're going to apply it recurrently, anything can be the step function as long as it can be applied recurrently. This step function gives you the next hidden state, so you can see it's a recurrent neural network. However, it also gives you the output at that particular point in time, y1, and it also gives you this number lambda n. From here you could apply the step function again: you'd get h3, the output y2, and lambda 2 (sorry, that one's a one and that one's a two). So far it seems like it's just a recurrent neural network, and if I were to push this to the end, I'd go h, h, h, and at the end I'd get my y_n, and if I treated that as the output of the computation, then
it's just a recurrent neural network. However, as we said, the network can in this case decide to stop anywhere in between. For example, if it decides to stop at this particular step, then that would be the output of the computation. So at every computation step the network computes a potential output, a suggestion for an output, and then it also thinks about whether it really wants to answer with that output, or whether it wants to continue and do another step, essentially take another shot at answering the question because it doesn't have the correct answer yet. That's where this lambda comes in: lambda is essentially the probability of stopping. You can see the output lambda is a number between 0 and 1, and it is the probability of halting, so whenever it is one, the network will halt, conditioned on the fact that it hasn't previously halted. As I said, it seems obvious to formulate it like this, because you can only halt if you haven't previously halted, but apparently previous models simply output a number that is sort of the probability of halting in general, which doesn't give you an unbiased gradient if you try to backpropagate through it. If you consider the lambdas like this and unroll for an entire training run, you get the probability of halting at any particular step; that is what the previous networks would have estimated directly. This network, however, estimates these conditional lambdas. You can see how you can compute the probability that, for example, the network halts after three steps: you multiply the probability that the network has not halted at step one, that it has not halted at step two, and then the probability that it halts at step three given that it hasn't halted at the previous steps. That is a valid probability distribution; it's a
generalization of the geometric distribution, and essentially it encapsulates a decision tree. At the beginning you can halt or continue; if you continue, you can again halt or continue; and again halt or continue, and so on. If you want the probability that the network halts after the third step, you consider this node, which means you multiply up the probabilities along the path to it, and that's the probability that it halts after three steps. So the network outputs this lambda at every step, and if lambda is high, the network halts. At inference this is done probabilistically; at training time it's done a little bit differently. At inference time you simply go forward and get a lambda. Maybe lambda in the first step is 0.1, and then you flip a biased coin: if it comes up heads, which happens with probability 0.1, you stop; if it comes up tails, which has probability 0.9, you continue. Maybe at the second step it's 0.05, so maybe you stop, but probably you won't. And then at the third step it comes up 0.9: the network thinks, yeah, I should probably stop here, and you sample from that, and in nine out of ten cases you indeed stop there. So that's inference. How do we train this thing? During training, we again put our input x into an encoder to get a hidden state, and as I said you can also input x at every step into your step function, as you see right here. But what you do is unroll the network for a number of steps, independent of these output nodes and independent of the halting probabilities. Let's say we unroll it for five steps right here, and at every point we get an output and a value: y3, y4, and lambda 2, lambda 3, lambda 4.
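To make the halting mechanics above concrete, here is a minimal plain-Python sketch. It is not the paper's implementation: all function names are hypothetical, the step function is an untrained toy tanh cell with random weights (standing in for whatever network you would put into s), and the parity encoding (+1/-1 entries, zeros elsewhere, label = parity of the number of +1 entries) follows the description in the video. It shows the three pieces discussed so far: a step function that returns (next hidden state, candidate output, lambda), the conversion of conditional lambdas into unconditional halting probabilities, and biased coin flips at inference time.

```python
import math
import random


def make_parity_example(size, num_nonzero, rng):
    """Toy parity input as described in the video: num_nonzero positions are
    set to +1 or -1, the rest stay 0. Label is 1 if the count of +1 is odd."""
    x = [0] * size
    for i in rng.sample(range(size), num_nonzero):
        x[i] = rng.choice([1, -1])  # -1 entries act as distractors
    label = sum(1 for v in x if v == 1) % 2
    return x, label


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def make_random_step(x, hidden=8, seed=0):
    """Hypothetical, untrained step function s(h) -> (h_next, y, lam).
    A tiny tanh RNN cell with random weights; x is fed in at every step."""
    rng = random.Random(seed)
    W = [[rng.gauss(0, 0.3) for _ in range(hidden)] for _ in range(hidden)]
    U = [[rng.gauss(0, 0.3) for _ in x] for _ in range(hidden)]
    w_y = [rng.gauss(0, 0.3) for _ in range(hidden)]
    w_lam = [rng.gauss(0, 0.3) for _ in range(hidden)]

    def step(h):
        h_next = [
            math.tanh(sum(w * hj for w, hj in zip(W[i], h))
                      + sum(u * xk for u, xk in zip(U[i], x)))
            for i in range(hidden)
        ]
        y = sigmoid(sum(w * hj for w, hj in zip(w_y, h_next)))      # candidate answer
        lam = sigmoid(sum(w * hj for w, hj in zip(w_lam, h_next)))  # P(halt | not halted yet)
        return h_next, y, lam

    return step


def halting_distribution(lambdas):
    """Unconditional halting probabilities p_n = lambda_n * prod_{j<n} (1 - lambda_j)."""
    probs, not_halted = [], 1.0
    for lam in lambdas:
        probs.append(not_halted * lam)
        not_halted *= 1.0 - lam
    return probs


def ponder_inference(step, h0, max_steps, rng):
    """Inference: after each step, flip a coin biased by lambda_n and stop on heads.
    Returning the last output when max_steps is reached is an assumption here."""
    h = h0
    for n in range(1, max_steps + 1):
        h, y, lam = step(h)
        if rng.random() < lam:
            return y, n
    return y, max_steps


# The three lambdas from the example in the video.
print(halting_distribution([0.1, 0.05, 0.9]))  # ≈ [0.1, 0.045, 0.7695]

rng = random.Random(42)
x, label = make_parity_example(size=16, num_nonzero=5, rng=rng)
step = make_random_step(x, hidden=8, seed=1)
y, n_steps = ponder_inference(step, h0=[0.0] * 8, max_steps=20, rng=rng)
print(label, round(y, 3), n_steps)
```

With random weights the sampled answer is meaningless; the point is the control flow: one candidate output and one conditional halting probability per step, and a random, per-sample number of steps at inference. The leftover mass in the first example (1 minus about 0.9145) is the probability of not having halted after three steps.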
So at training time we simply unroll until a given step. There are some technical difficulties with unrolling for a finite number of steps, like how to normalize the probability distribution, because essentially this tree can go on to infinity. They find that we can simply unroll until the rest probability, the probability mass we haven't used yet, is really small, and then just load all of it onto the last step. But these are technical difficulties you really only care about when you go and implement it. So we unroll for a number of steps, and then we consider all the outputs at the same time. This, I believe, is one big difference to one of the previous networks, ACT. What ACT does is always unroll, and the output of the network for ACT is simply a weighted output, the sum of lambda_i times y_i. So the output of the network is always a weighting between the different steps, and the network can decide how it wants to weight the individual outputs. Here it's different: the output is really either y1 or y2 or y3 or y4. In order to pack this into a single loss function, we simply ask: what would the loss be if we answered y1, weighted by the probability; what would the loss of y2 be, weighted by the probability that the network outputs it; and so on. Essentially, we compute the expected loss given the probabilities the network has output. Now if we backprop through these losses, we of course have two backpropagation paths. There's a loss, and both the y's and the probabilities go into it: the loss is how bad this output is times how probable it was,
and so on. So the backpropagation actually attacks along two different paths, as you can see. Backprop goes into y, because you want the network to compute a better output, but backpropagation also goes into lambda, because you want the network to get better at estimating when its output is good and when it's not. I see this as a bit of a tricky situation, because, just from experience with other papers, this usually seems a little bit unstable: you backprop through two different things, especially things that appear to be multiplied together, and the network can now trade one off against the other. You might think that is desirable: it can either choose to make its output better, if it wants to keep the probability of outputting it high, or it can just reduce the probability that it outputs whatever it wants to output. Then it doesn't necessarily have to make the output itself correct, because the loss won't be as high for that particular output when the probability of outputting it is low. The network essentially has a choice. As I said, this might be desirable, but usually it's kind of unstable. This is just my personal opinion, but I think a lot of why this works might rest on the relative complexity of making y better versus adjusting these probabilities. If the output y is very complex, then the same gradient signal might mean much less for it than for simply reducing the probability. So if the output is very, very complex, not the problem but the output itself, like an entire pixel map that has dependencies and so on, the network might just choose to always reduce the probability, because it's like: well, how am I going to make this better at all? I don't
know, I can just reduce the probability that I'm going to output this crap. And it will probably then do this for every single step, which, if it's a complex problem, makes sense, but still, that would be my fear here, and it is not really discussed in the paper itself. So I think the fact that this works might rely on a balance between the complexity, or information content, of the loss you get at the output node versus the loss at the probability node. Okay, enough about that. So during training you simply compute the expected loss, weighted by the probabilities, and then you can backprop through that. I hope you can see the difference between the two approaches. They both seem to sum up the outputs weighted by these factors; however, one considers the actual output of the network to be a weighted combination of the outputs of the individual steps, whereas the other one says: no, no, the network output is actually one of them, we just don't know which one, ergo for the loss we need to compute the expectation of the loss. That seems to be a more reasonable formulation, though in hindsight you can call many things reasonable if they work better. They also discuss things like the maximum number of pondering steps, which again I think is a technical detail. And this is interesting: here you have the training loss we just discussed. We've covered this part, which they call the reconstruction loss, because you have some desired y and a y that comes from the network. I was a little bit wrong in my formulation earlier: for the expectation you don't want to take the lambdas, you actually want to take the probabilities that each thing happens, which means you need to compute this p number along the tree as we did, because p is the actual probability that
you reach that node, whereas lambda is only the conditional probability that you reach a node given you were at the previous node. So keep that in mind if you are crazy enough to implement things straight as I speak in the videos (lucidrains, shout out). For the second part of the loss, you can see there is a hyperparameter, so you trade off two losses here. As we saw, you can either continue or not continue, and for the network it might actually be easier, as I said, if the loss on the output becomes reasonably complex, to simply say: in this case I'm just always going to reduce my probabilities. You might counteract this with the maximum number of steps, but what really counteracts it is this term here: there is a regularization term on these probabilities. We regularize with the KL divergence, which is sort of a distance measure (don't tell this to a mathematician, it's a divergence), between the distribution that the network outputs over the steps and this thing right here, which is a geometric distribution whose parameter lambda_p is another hyperparameter. So what does that mean? If you consider the number of steps the network thinks for, the distribution you regularize toward is a geometric distribution that goes something like this. A geometric distribution computes exactly the tree we computed: at each step you can stop, and the distribution tells you the probability that you stop after one step, two steps, three steps, four steps, considering that in order to stop after four steps you already have to have made three
non-stopping steps, except that in the geometric distribution the probability of continuing is always the same, whereas our network can output a different probability at each node of the tree; otherwise there'd be no point, we could simply put in the fixed distribution. The probability of stopping at each point is exactly this lambda_p hyperparameter. So you regularize toward it with a KL term, which means you tell the network: look, here is a reasonable distribution of when you should stop. It should be somewhat probable that you stop after one step, and somewhat probable, if you've already done one step, that you stop after two steps, and so on. You give it a default probability of stopping after each step. If this is 0.1, for example, you essentially tell the network: at any given step there is a default 10% chance that you should stop; I, as the designer of the algorithm, think that's a reasonable prior to have. Now, the network can decide differently. It can decide: no, no, no, I actually want to stop way earlier, putting much more emphasis on the first steps, which in turn, because you need to normalize, puts less emphasis on the later steps. So the network can still decide to violate this prior if that reduces the loss enough. As I said, this is a trade-off with two hyperparameters: the shape of the geometric distribution and the amount by which you regularize with the KL divergence. Now we come to the experimental results, and these are pretty neat. I think they're straightforward experimental results; they're not super large-scale results or anything like that, but they show that on tasks where we know this dynamic computation has an advantage, their model outperforms both previous attempts at dynamic computation and especially networks that have no dynamic
computation built in whatsoever. So this is the parity task, which we're going to look at. The orange is ACT, the previous work they compare with most and that is most similar to theirs. You can see that in terms of accuracy PonderNet beats this network by quite a bit. I also appreciate the error bars in this one; they almost overlap, but they don't, so you can say you're definitely better. Interestingly, look at the number of compute steps: the error bars overlap here as well, but PonderNet itself needs fewer compute steps. I don't know exactly why that happens, but you can speculate that it's because PonderNet fixes on a single answer, whereas ACT outputs this weighting of things. So when ACT outputs, say, the first step's answer, it always needs to consider that this has to be compatible with potential future steps. Just from how ACT formulates its output, it seems to become a lot less dynamic: because the output is always a weighting of different outputs, the first steps can't just output what they think is the correct solution; they already have to incorporate the future and estimate: well, if I'm going to continue computing, then there's going to be stuff added to my output, and I have to take that into account. So, ironically, it can be a less dynamic network, and that's why I think PonderNet might need fewer steps here. I might be totally wrong, though. So this is the parity task. Specifically, they train with string lengths between, well, this is a string of length one, and before we had something like length eight. They train on lengths from 1 up to 49, and this is a little bit important, I think, because their training set contains all of these lengths, which is a little bit of an
experimental trick. What you want your network to learn is the general principle of parity, independent of string length, so you construct the training dataset as a distribution over string lengths, rather than just strings of a fixed length, and then you assess their parity. So maybe that's a bit of a lesson if you do experiments: construct your tasks such that they already help the network find the correct solution. So they train with strings of length 1 up to 49, and then they try to extrapolate, which is panel b right here. In panel a they train on small strings and test on small strings. In b they train on the same small strings up to length 49, but then, as I understand it, they test on lengths from 50 up to 96 or so (it says so somewhere), just longer strings than the network has been trained with. Now that the setup is clear, it's clear why they put strings of different lengths into the training set and not just fixed-length strings: there's a reasonable chance the network does not learn to extrapolate from just one or two particular string lengths. Nevertheless, they test how the network extrapolates to longer strings, and you can see right here that ACT, even though it has also been trained on the dynamic-length strings, gets 50%, and that's pure chance. It's a parity task, the output is either odd or even, so ACT just gets a random-chance result, whereas PonderNet has an accuracy of about 0.9, which I guess is pretty good, especially on strings so long it has never seen them. So what can we read from this? I'm not exactly sure. There's always the possibility that they just trained ACT wrong or something like this, but it's also reasonable to say that, just from how the previous models were constructed, either
they didn't learn the concept, or their output is just weird the way ACT's is, or ACT has biased gradient estimates and PonderNet doesn't, yada yada; we don't know. What we do know is that in their experiments PonderNet was actually able to solve the extrapolation task. The interesting thing is the number of compute steps. On the smaller strings, PonderNet computes for between two and a half and three steps during inference, let's say about three steps (sorry, that's an alarm). Yet the same model, trained on the same strings, all of a sudden raises its compute to about five steps during inference on the longer strings, whereas ACT, which doesn't work in this one, just sticks around two or three steps, as it does in training. The authors claim that this is good evidence that PonderNet learns to solve the actual task here: as the task gets more complex, PonderNet needs more steps to think about it. This might be exactly what we saw earlier: you have some string of zeros and ones, and during training you learn how to take one of these, maybe in multiple steps, and get an output. Now you have a longer string, so you can also learn an output for this other part, and now you have two outputs, and you can learn a series of steps to transform the two outputs into a single output. That might just need one or two more computation steps, which is exactly what we see happening here. So it's a good indication that something like this is happening. I would be wondering (pondering, one might say, haha) how this actually happens,
like, what do the individual computation steps represent? In this parity task, for example, is the network going about the task in a hierarchical fashion, like I've shown here, or is it something different? Is it going about it in a purely recurrent fashion where, even though, as I understand it, we input the entire string at the beginning, it only looks at the string position by position? How does this work? How does the scaling behave in general? They only show small strings and large strings, but how does it behave as you go up in length, and so on? It would be really interesting to introspect this model a little bit more than simply showing end results on the individual tasks. Okay, what they also find is that the hyperparameter for how you regularize the shape, which we've seen up here, doesn't seem to be terribly important. Again they compare to ACT, which has another hyperparameter, called tau, that does a similar thing: it regularizes the shape of the desired halting distribution. They say tau does not have any straightforward interpretation, though I guess the authors of ACT might disagree. As you can see here, if I draw the means, there is a region where a selection of tau performs well, though you have to see that it's all around the same value, something like 5e-4, and for the other values you might set it to, it simply doesn't work at all. So the authors claim you have to hit this tau pretty precisely in order to even get the network to do anything, whereas in PonderNet this variable, first of all, is between zero and one and not just an arbitrary value, because it's a probability, and they claim that it kind of works for
most values, except this one right here, where you essentially bias the network to output everything after one step. The trick is that for the geometric distribution you take the inverse, 1 over lambda_p, and that gives you the expected number of steps the network would compute according to this prior. So when you put in 0.9, you'd essentially be asking the network to do a single step. For all the other values, well, judge for yourself whether this is really good, but what you can say is that it goes from zero to one, so you have a clear range, and for most of that range the thing seems to work okay-ish. What they highlight is even down here: even if they set lambda_p to 0.1, which would bias the network towards 10 steps, so the prior says "please do 10 steps of computation" in this parity task, as I understand it, you can see the network doesn't do ten steps; most of the time it goes towards three, four or five steps. So the network learns to be somewhat robust to this prior distribution. I guess that's also largely a function of the hyperparameter with which you trade the two losses off; we don't know its effect just from the paper. If they set that really low, then of course the network is robust to the choice of lambda_p. Still, it's good news, because it means you wouldn't have to regularize the model super heavily in order to get it to work. Okay, they go into two other tasks right here. Again, these aren't tasks you might necessarily know; they are tasks where this type of computation particularly shines. As I said, I see the paper more as an interesting niche, a subtask you might say, of connecting deep learning and classic algorithms. There are a number of things that I
think you can do here to extend this. It's completely thinkable that the loss might be a bit different: you don't ask the network to output the direct answer at each step, but you might want to attach memories and so on to these output nodes, or you might want them to output intermediate results or something like this. Another thing you could do is work with adversarial losses instead of reconstruction losses, so you could have some sort of GAN going on inside this in order to decide on the stopping probability. There's lots of stuff one can fiddle around with in this type of network. You can even think of crazier architectures, I don't know, Hopfield-like structures where you decide how far you iterate, because you may not always want to iterate until a fixed point. I don't know, I'm just talking crap right now. Okay, one last shout-out to the broader impact statement of this paper, what a beautiful piece of writing. Essentially they say: this enables neural networks to adapt their computational complexity to the tasks they are trying to solve. Neural networks are good, but currently they require much time and expensive hardware, and they often fail; PonderNet expands their capabilities. They say, look, it can do this, it can do that, which makes it particularly well suited for platforms with limited resources, such as mobile phones, which is a good thing. It can also generalize better, which means it's better for real-world problems. And they say: we encourage other researchers to pursue the questions we have considered in this work; we believe that biasing neural network architectures to behave more like algorithms and less like flat mappings will help develop deep learning methods to their full potential. And that is indeed the broader impact of this work. Like, that is
that's the impact it had on me, and that's the impact it should have. Yeah, at today's conferences this might even get it kicked out, because of course it doesn't say technology good, technology bad, technology biased, but, you know, respect for that. And that was it for me. Let me know what you think, and bye-bye!
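The training objective discussed in the video (expected reconstruction loss weighted by the halting probabilities p, not by the lambdas, plus a KL term toward a geometric prior) can be sketched in plain Python as follows. This is an illustrative sketch under stated assumptions, not the authors' code: the function names are hypothetical; binary cross-entropy as the per-step loss, loading the leftover probability onto the last unrolled step, renormalising the truncated prior, the direction KL(p || p_G), and the defaults lambda_p=0.2 and beta=0.01 are all choices made here for illustration. The ACT-style function is a simplified contrast (loss of the probability-weighted output) and leaves out ACT's ponder cost.

```python
import math


def unrolled_halting_distribution(lambdas):
    """Halting probabilities for a rollout truncated after len(lambdas) steps.
    As described in the video, the leftover mass (probability of never having
    halted) is loaded onto the last step so the result sums to 1."""
    probs, not_halted = [], 1.0
    for lam in lambdas:
        probs.append(not_halted * lam)
        not_halted *= 1.0 - lam
    probs[-1] += not_halted
    return probs


def bce(y_true, y_pred, eps=1e-7):
    """Binary cross-entropy, used here as the per-step loss L(y, y_n) for parity."""
    y_pred = min(max(y_pred, eps), 1.0 - eps)
    return -(y_true * math.log(y_pred) + (1 - y_true) * math.log(1.0 - y_pred))


def reconstruction_loss(y_true, outputs, probs):
    """PonderNet-style expected loss: sum_n p_n * L(y, y_n), weighted by p, not lambda."""
    return sum(p * bce(y_true, y) for p, y in zip(probs, outputs))


def act_style_loss(y_true, outputs, probs):
    """Simplified ACT-style contrast: loss of the probability-weighted output."""
    mixed = sum(p * y for p, y in zip(probs, outputs))
    return bce(y_true, mixed)


def geometric_prior(lambda_p, n_steps):
    """Geometric prior p_G(n) = lambda_p * (1 - lambda_p)^(n - 1), truncated to
    n_steps and renormalised (the renormalisation is an assumption)."""
    raw = [lambda_p * (1.0 - lambda_p) ** (n - 1) for n in range(1, n_steps + 1)]
    total = sum(raw)
    return [r / total for r in raw]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions over the same steps."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def ponder_loss(y_true, outputs, lambdas, lambda_p=0.2, beta=0.01):
    """Reconstruction loss plus beta * KL(p || geometric prior).
    lambda_p=0.2 and beta=0.01 are illustrative defaults, not tuned values."""
    probs = unrolled_halting_distribution(lambdas)
    rec = reconstruction_loss(y_true, outputs, probs)
    reg = kl_divergence(probs, geometric_prior(lambda_p, len(probs)))
    return rec + beta * reg


def expected_steps(lambda_p, n_terms=10_000):
    """Mean of the untruncated geometric prior, which comes out as 1 / lambda_p."""
    return sum(n * lambda_p * (1.0 - lambda_p) ** (n - 1) for n in range(1, n_terms + 1))


# Target parity is 1; the first step's answer is bad (0.2), the second is good (0.95).
outputs = [0.2, 0.95]
early = unrolled_halting_distribution([0.9, 0.5])  # mostly halts on the bad answer
late = unrolled_halting_distribution([0.1, 0.5])   # mostly halts on the good answer
print(early, late)  # ≈ [0.9, 0.1] [0.1, 0.9]
print(reconstruction_loss(1, outputs, early) > reconstruction_loss(1, outputs, late))  # True
print(ponder_loss(1, outputs, [0.1, 0.5]))
print(expected_steps(0.1), expected_steps(0.9))  # ≈ 10.0 and ≈ 1.11
```

The first comparison illustrates the trade-off raised in the video: the expected loss can drop either by improving the bad output or by moving halting mass away from it, and the KL term (weighted by beta) is what pulls the halting distribution back toward the prior. The last line checks the 1 / lambda_p rule of thumb: lambda_p = 0.1 corresponds to about 10 expected steps, and 0.9 to roughly one.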
Info
Channel: Yannic Kilcher
Views: 18,637
Rating: 4.9884224 out of 5
Keywords: deep learning, machine learning, arxiv, explained, neural networks, ai, artificial intelligence, paper, pondernet, deepmind, pondernet learning to ponder, deepmind pondernet, pondernet explained, dynamic computation, deep learning classic algorithms, halting probability, deep learning recurrent computation, dynamic recurrent network, broader impact, deep network learning to stop
Id: nQDZmf2Yb9k
Length: 44min 19sec (2659 seconds)
Published: Mon Aug 23 2021