∞-former: Infinite Memory Transformer (aka Infty-Former / Infinity-Former, Research Paper Explained)

Video Statistics and Information

Captions
Hello there! Today we'll look at "∞-former: Infinite Memory Transformer" by Pedro Henrique Martins, Zita Marinho, and André F. T. Martins. On a high level, this paper proposes a transformer that can attend to unbounded memory in the past. It does so by building up what it calls a long-term memory, which is a continuous signal rather than a discrete signal as in most other transformers. It uses continuous attention to do so, and that essentially enables it to continuously compress the past into this continuous long-term memory and then attend to it as it predicts the next tokens. It also introduces the concept of sticky memories, which are essentially events in the past that are of particular importance to the future. By specifically keeping those sticky memories around, they increase performance yet again. So we'll go through the paper: what the model looks like, how it works, and what it does in the experimental results.

Ha, caught you! You wouldn't have guessed it, but this video is sponsored by Weights & Biases. If you're in the ML space and you don't know about Weights & Biases, what are you doing? Please, if you track your experiments using a spreadsheet, a piece of paper, TensorBoard, or weird folder names like I used to do, stop that. Use Weights & Biases. It's one line of code, and you can log any of your experiments to the cloud: not just metrics, but models, datasets, output images, little videos, anything you want. Say hello to Zurich. Believe me, when I started the PhD I was looking for something like Weights & Biases, and I tried every single thing there is. I tried every productivity tool and every note-taking tool, and I just couldn't get anything to work, partly because the features were lacking and partly because I was just too lazy. Weights & Biases solves both of those problems. It has all the things I need to track my experiments, collaborate with others, and so on, but it's also just a single line of code, and everything else works automatically. It even
boosts my productivity, because whenever I have logged a model, I can just call a function to download that model from the Weights & Biases website. I don't need to place it in the correct folder or keep track of it myself; it's just there. On top of that, it relieves me from the stress of writing stupid Overleaf reports, because I can write a Weights & Biases report and share it with the people I want to show my work to. A Weights & Biases report is so much more useful than a PDF. It's essentially a website, but you don't need to code any HTML or CSS or whatnot. You can include dynamic content, reference the runs you did, pull out data from those runs, and present it in a neat fashion. And it gets even simpler: you don't need to set up anything. In fact, Weights & Biases runs in the cloud by default. You can host it on premise, but it really wants to live in the cloud. All you need is an API key; you log in and you're good to go. So please check it out. Accounts are completely free for personal use. I promise you will not be disappointed. Give it a try, and now let's get into the video. Bye-bye! [Music]

Cool. So there are a couple of good things and a couple of questionable things about this paper. There are also a lot of engineering choices in this paper which I don't necessarily want to go into. There are a lot of things one could do differently, I feel, which also influence the experimental results, I guess, but we'll just take it for what it is. The other thing is that I believe this should be called not the Infinity-former but the Infty-former. That's actually how you find it if you google for it: you can enter "infty former", "infty" of course being the LaTeX abbreviation (\infty) for this symbol right here. And I think, to make it more unique, we should just call this the Infty-former. All right, so what does the ∞-former propose? They say in the abstract right here that transformers
struggle when attending to long contexts, since the amount of computation grows with the context length, and therefore cannot model long-term memories effectively. So there are a number of things hidden right here. They say the amount of computation grows with the context length. Now, for classic transformers it's actually worse: the amount of computation grows quadratically with the context length. But even for, let's say, linear transformers, the amount of computation still grows linearly with the context length, so they see even this as a problem. They say transformers cannot model long-term memories effectively. They say several variations have been proposed to alleviate this problem, but they all have a finite memory capacity and are forced to drop old information. In this paper, they propose the ∞-former, which extends the vanilla transformer with an unbounded long-term memory. By using a continuous-space attention mechanism to attend over the long-term memory, the ∞-former's attention complexity becomes independent of the context length. Now, already remember right here: there is rarely a free lunch. I don't want to say there is no free lunch, because I've definitely eaten free lunches before, but there is rarely a free lunch in these kinds of things. If we have a finite computation, we cannot pack infinite information into it. So if we are attending to an unbounded long-term memory, something else has to give, and of course the thing that gives here is the amount of information you can retain. Now, this can be a good trade: boundedness in time for boundedness in information. Still, you have to keep that in mind. As I said, they also introduce this thing called sticky memories, which keep important things around. Now, as we go through this, in my mind at least, it gets more and more like a classic LSTM model. The classic LSTM model, of course, takes in some sort of input, then models a hidden
state, then propagates that hidden state when it takes in the next input, and so on. It has to keep track of what's important in its own hidden state in order to decide what it wants to remember and what it doesn't. So, as with the ∞-former, the LSTM in fact has an unbounded memory: it can remember things for arbitrarily long, yet it only has finite capacity to do so, so it needs to overwrite some memory every now and then. That is a bit how you can think of this model: it's essentially the same principle as an LSTM, trading off unboundedness for a finite representation space. I'm not saying this is an LSTM; it is a little bit different. It might be a smarter way to do unbounded computation, or it might not be, but in concept it is a similar thing.

Okay, so what's up with this continuous attention they keep talking about? In essence, this is quite a simple concept. Say you have a sequence of tokens, and every token has an embedding vector, so every token is associated with a vector that is its embedding. This can be the first layer, but it can also be the intermediate values of the computation: from one layer to the next, a transformer always has one of these embedding vectors per token traveling through the model, and the next layer transforms them into new embedding vectors, and so on. Now, what the ∞-former does is take this signal right here and change it from a discrete signal into a continuous signal. So you would no longer have dimensions where, you know, the topmost dimension here, the first dimension of all these vectors, might be, whatever, four, five, nine, point one, three. That's no longer the case. What you would have is a continuous signal. Okay, so how do you do that? Pretty easily. The ∞-former takes each of these dimensions separately, and for each of these dimensions it plots these points on a sort of
continuous plane. It labels it from zero to one, so you divide this interval into, I guess, five different points, because we have five tokens. For the first one, you label with a four (sorry about that; where is a four? I suck at this), so here is a four, dot here; then here is a five, I guess, so dot here; then nine, point one, and three, like here. Okay, so here's three. Cool. Then it calculates an interpolation, so the interpolation would be approximately this. It calculates an interpolation of these points, and then it simply stores that interpolation. It forgets about the embedding vectors themselves and just stores that signal, and that signal is its so-called long-term memory. Now you might wonder: why don't we just store the embedding vectors instead of the signal? That is, of course, a good question. The goal is that you can store the signal more efficiently than the embedding vectors. If we can describe this signal with fewer than five numbers, then we might be able to save some space. This is reasonable: if I draw this, it is reasonably a polynomial of degree three, ergo we'd have to store three numbers, maybe plus a bias, so four. But if we agree that we always store polynomials of degree three, then no matter how many embedding vectors we have, we're always going to store the signal as those four numbers, a constant amount of numbers. And that is essentially the trick for how we get away from the sequence length: we simply commit to a fixed representation of a signal, and then we interpolate the embedding vectors using this fixed representation. Now, the fixed representation here isn't a polynomial; it is in fact a series of radial basis functions. So we associate each point in time, which is
a position in the interval from zero to one, with a set of radial basis functions. Radial basis functions are nothing more than, well, this is one, this is one, this is one. Okay, so these are essentially three radial basis functions spaced out right here. How could we represent the signal from up here using them? Maybe we can say: if here is one, that's plus 4.5 of this one, let's call it ψ1; then it goes down, so minus 3 of ψ2; then it goes up again, so plus 4 of ψ3; and maybe some sort of bias, plus 2. Okay, so four numbers, three radial basis functions. These things are completely independent of the data. They're not learned; they're simply fixed once, and they are going to be our basis for representing all of the signals. The way we transform the discrete signal into the continuous one is by running a regression, which you can do by solving this system right here, that is, by figuring out what the matrix B is. That's a linear system: how do I have to mix the radial basis functions in order to match my signal as closely as possible? The way they do it is with ridge regression. Ridge regression is simply a regression with an L2 penalty. You have y = Xw, and you're trying to find w. Your loss is the squared distance between the two sides, plus some regularization constant times the squared L2 norm of the weights. This has a closed-form solution, and this here is the closed-form solution for ridge regression, with F being the matrix containing these basis vectors, this one right here, and that gives you your B matrix. So you transform X, whose size depends on the length of your sequence, into B, whose size only depends on how many basis vectors you decide
to have: in this case three, or three plus one if we want a bias again. All right, and that's how you get a continuous signal. You might already say: wait, isn't this just a special case of a system that compresses a variable-length sequence into a fixed-length one? Isn't this just a way to embed an unbounded sequence? And I'd say yes, absolutely. That's the first thing. The second thing is that the whole procedure is certainly not independent of length, because this system right here absolutely depends on the length of your signal. You can also see that the longer your sequence gets, the more mistakes you'll make in representing it, because you always represent it using the same basis vectors. So here is where the trade-off happens: going from length L to length N, which I believe is what they call the number of basis vectors. The second thing, which really kind of interests me, you see here again. By the way, they then consider this their memory. You can technically do this with all of the past: you take all of the past, remember the vectors, and then interpolate. Or, if you really go for unbounded memory, you take the past and the current sequence, and you contract the past, which means you interpolate the interpolation: you sample it in a more coarse-grained fashion than you originally produced it, which leads to samples like these here. Then you concatenate them with the new signal and simply interpolate again over the whole signal. So you can see the more distant past is now compressed, the more recent past is appended to it, and of course in the next step you'll contract this whole thing
to a shorter sequence, append the more recent part, and interpolate again. This is conceptually no different from an LSTM. It brings about the same problems as an LSTM, namely that recent things are more likely to be in memory than things from way in the past, and so on. So calling this "being able to attend to unbounded memory" is a bit shady. That's just my opinion; you have to be aware of the trade-offs. The second issue is that for this to work (and we haven't even gotten to the attention part yet; we're just representing our signal as a continuous signal), you're counting on some kind of regularity. Here I've drawn these points specifically so that I could draw a neat line through them, yet there is absolutely no reason why the embeddings of neighboring tokens should be in any way continuous, such that you can interpolate them. You count on being able to compress the signal because the samples line up, and you're like, whoa, I can represent this by one line; one radial basis function goes through all of them. Cool. But there is no reason why this should be the case. The signal could be completely random in terms of what the actual floating-point numbers are in the individual dimensions. They mitigate this a little by smoothing the signal before they interpolate it, but in my mind that only makes it less accurate. It doesn't make the problem go away, because if there is actual value in a pattern like this, if that's actually an important pattern, then neither interpolating it very coarsely with only a few basis functions nor smoothing it first will necessarily help. So, just from a principled standpoint, I am skeptical
that these signals are necessarily such that they are easily interpolatable, but of course I might be wrong. Okay, so what do we do with it? Let's say we have the past in this long-term memory. This is all of the past; we've interpolated it into this fixed long-term memory, a continuous signal that we represent as a superposition of a fixed set of basis functions. We have our short-term memory here, which is simply whatever we would put into the context of the transformer anyway, and then we have the sequence we actually want to deal with. The attention within the discrete part of the transformer is as you know it: self-attention, or, I guess, masked self-attention for certain tasks. The question is how we make use of this long-term memory, and here is how. For each location where we want some sort of prediction, we produce a query. As you know, in a transformer layer, every single token produces a query vector to go from one layer to the next. The query vectors say what each token wants to know about the sequence in the last layer. Every token also emits a key and a value vector, so key and value, key and value, and so on (I'm only drawing the keys), and then this is routed by inner product. Now, the query we can keep: it simply says what this token wants to know. So the query is also sent to the long-term memory. The query vector of each discrete token now goes to the long-term memory down here, and we have to find a way to ask the long-term memory something according to this query. So how do we do it? We need some notion of a key and a value for this long-term memory, and here's how we compute them. Remember, the continuous signal is described by this matrix B right here. So if the continuous signal is
described by the matrix B, then of course we can compute keys and values from B. These W matrices right here are learned parameters that turn B into keys and values. Now, the keys and the values are discrete sequences, and they have a different length than the sequence we're dealing with, but that doesn't matter: nothing in a transformer actually requires the next layer to have the same sequence length. The way you can imagine this is that from the long-term memory we're essentially building another sequence. It's not as long as the sequence that generated the long-term memory, but it's essentially another sequence of tokens. These tokens don't necessarily correspond to individual tokens in the input; they correspond to how the memory is constructed. Nevertheless, from them we can certainly generate keys and values as we regularly do. So we essentially compress the past into this pseudo-sequence of fixed length via a continuous representation, and then we just use attention again to match the keys here with the queries. Now, actually computing the thing is not as easy. That's the concept, but in the actual computation we don't want to abstract this into a series; we would like to use continuous attention. Continuous attention essentially means that our attention doesn't go directly to one particular token. It's not "I want this token and this token and this token"; since we have a continuous signal, our attention should be something more like "I want to attend to this part of the sequence." We model that as a probability density over the sequence, and specifically we restrict ourselves to a Gaussian. So the interactions between the queries and the keys will give me a Gaussian where I say I
would like to attend to this particular part of the sequence: this is where in the past I want to attend, and this is how broadly I want to attend, how much of the surroundings I want to consider. So this ultimately defines a Gaussian: where it is and how far it is spread. Per query, that is per token and per head, I can attend to one location in the past and its surroundings, and I can also specify the width, which is also learned. As I understand it, these affine transformations right here are also learned (maybe I'm wrong about that; it just says affine), and then the sigmoid and the softplus are just regular functions. You can see that this is essentially multiplying keys and queries, as you're used to, but then, instead of attending to the tokens themselves, because we don't have tokens, we specify a Gaussian to attend over the continuous signal. Ultimately, we integrate: we integrate the values that we obtain from the sequence according to the probability distribution we get, and those are our output values. Once we have the output values from the long-term memory, we add them to the output values we get from the short-term memory and the sequence itself. I think they go through another affine transformation after that, and there is your output: one output per token in the sequence that you're interested in. Okay, I know this was fairly lengthy, so to recap: we take the past and run a ridge regression to determine the coefficients that represent the past as a continuous signal with respect to a fixed set of radial basis functions. This gives us a fixed-size representation, independent of how long the past is. Then the way
we use the past is this: we take the queries that come from the attention mechanism, we transform the representation of the past, which is this B matrix right here, into keys and values, and we take the inner product between the queries and the keys. This determines a Gaussian window for us, telling us where in the past we want to attend. We integrate the values from that region according to the Gaussian, and that's our output signal from the long-term memory. This gets added to the output signal of the regular attention mechanism, and that gives us the output signal as a whole. Okay, that is essentially it. If we do this step after step, we could simply always go to the past and compress it, but we can also use the unbounded-memory trick I mentioned before, where you always take the signal from the past, compress it by essentially subsampling it, concatenate the new signal, and then interpolate again. On top of this, they introduce sticky memories. Sticky memories say: look, the points at which I sampled this past signal here (don't trust my drawing) were sampled uniformly. Sampling uniformly gives me a decent sampling of the signal, but I can also sample it differently: I can oversample certain regions and undersample others. So they ask: why don't we sample according to the Gaussians we determined during the attention mechanism? These Gaussians are, of course, summed up over all the attention heads and over all the tokens in the current sequence you're looking at, because all of them attend to the same past. If we sum up all these Gaussians, we should get an idea of where most of the attention went and where no attention went, and the idea of sticky memories is simply: let's oversample the
regions where a lot of attention went. Maybe a lot of attention went to this bump right here, so we oversample that, and maybe not much attention went to this region right here, so we don't sample anything there. Once we have sampled, we spread these samples out, I guess equally, and then we interpolate again, and that's how we keep the more important things in memory more accurately. Now, again, this is all heuristics, and that is part of my criticism here as well. In an LSTM, at least, it's learned: how to compress the past, how to read it, how to use it, which memories to keep, and so on. In the LSTM all the gates are learned, and so are the weighting functions. That's also the culprit in an LSTM, because you have to backpropagate through time, and that's just not possible for very long sequences, so that's a bit of the LSTM's downfall as well. Here, we don't have to backprop through time, because everything is a heuristic. However, with everything being a heuristic, how do we know? Maybe it works, but I'd rather not use just heuristics for that kind of stuff. But I guess there's room for improvement. Here they detail that they smooth the signal with a CNN before they do the multivariate ridge regression, and so on. There is also a regularization on the variance of the Gaussian that they predict, but these are details. So the final loss is the training loss plus the KL divergence term. Maybe they added that after they saw that the model simply wants to attend to everything all the time; I don't know. They then evaluate the model on various tasks, such as this sorting task, and I have to say they construct the tasks fairly cleverly, making sure the model can't use simple strategies to solve them. What they compare against are things like the Transformer-XL, which tries to have some sort
of long-term memory but doesn't really do it (I've made a video on Transformer-XL, so if you're interested in that, you can check it out), and also this Compressive Transformer, which seems to be a little bit like the ∞-former, but without going via a continuous signal. The Compressive Transformer seems to be a transformer that always tries to compress the past into a fixed-size memory, if I understand it correctly. Generally, they find that their model is roughly on par with the Compressive Transformer, outperforming it a little. Now, this being machine learning, I would not be confident from these results alone that there is a real difference between the two models, or about which one is actually better. In their results they are better, and when they add the sticky memories they are even better, which I guess makes sense, but again, take that with a grain of salt. They also analyze which parts of the long-term memory this continuous attention goes to, and in general this seems pretty reasonable if you look at where in these long texts the attention goes. Apparently the ground truth here is "U2", I guess as the answer to a question. Or, oh no, I guess this is masked out here, maybe; I'm not exactly sure where the attention is, but it's trying to predict "U2". Maybe it's masked language modeling or some sort of question answering. Either way, it seems reasonable (oh, there's a helicopter), at least in this one example they show. They also do, sorry, not masked language modeling but actual language modeling, compared against something like GPT-2, and they outperform that, and they do some more analysis. Again, I don't want to go too deep into the experimental results, because with this many engineering choices it's tricky to make sense of small differences between models. What I would go for is the general
trends, and the general trends are okay. I don't know if the code is out; I haven't seen any. If it is out, give it a try, I guess. Otherwise, wait about 30 minutes until lucidrains has an implementation available. And with that, I'll see you next time. Bye-bye!
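To put the "free lunch" discussion and the ridge-regression caveat side by side, here are rough per-layer operation counts for one head. These are back-of-the-envelope estimates under stated assumptions, not figures from the paper: L is the number of tokens in the current segment, T the number of past positions, N the number of basis functions, and e the embedding size; constants, multiple heads, and the CNN smoothing step are ignored.

```latex
\begin{aligned}
&\text{self-attention of } L \text{ queries over } L + T \text{ stored tokens:} && \mathcal{O}\big(L\,(L+T)\,e\big)\\
&\text{continuous attention of } L \text{ queries over } N \text{ basis coefficients:} && \mathcal{O}\big(N e^{2} + L N e\big)\\
&\text{ridge fit of } B \in \mathbb{R}^{N \times e} \text{ from } T \text{ positions:} && \mathcal{O}\big(T N e + T N^{2} + N^{3} + N^{2} e\big)
\end{aligned}
```

The attention term no longer depends on T, while fitting B still does. With the contraction trick, the refit only sees the M resampled points plus the L new tokens, so T = M + L stays bounded, and the information that does not fit into N coefficients is simply lost.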
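The degree-three polynomial illustration can be made concrete: whatever the number of samples, a fixed-degree fit always returns the same number of coefficients. This is only a toy version of the idea (the paper uses radial basis functions, not polynomials), and the evenly spaced positions in [0, 1] and the sample values are illustrative assumptions.

```python
import numpy as np

def fit_fixed_size(samples, degree=3):
    """Least-squares polynomial fit of a fixed degree to a 1-D signal whose samples
    are placed at evenly spaced positions in [0, 1]. Always returns degree + 1 numbers."""
    samples = np.asarray(samples, dtype=float)
    t = np.linspace(0.0, 1.0, len(samples))
    return np.polyfit(t, samples, degree)

# One embedding dimension across five tokens (illustrative values).
five_tokens = fit_fixed_size([4.0, 5.0, 9.0, 1.0, 3.0])
# One embedding dimension across 500 tokens of a smooth (illustrative) signal.
many_tokens = fit_fixed_size(np.sin(np.linspace(0.0, 3.0, 500)))

print(five_tokens.shape, many_tokens.shape)  # (4,) (4,)

# The stored coefficients can be read back at any continuous position, e.g. t = 0.5.
print(np.polyval(five_tokens, 0.5))
```

The storage cost is fixed by the chosen degree, not by the sequence length, which is exactly where the loss of information comes from when the sequence grows.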
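The ridge-regression step can be sketched with NumPy. Following the description in the video, each token is placed at an evenly spaced position t_i in [0, 1], F holds the basis functions evaluated at those positions (F[n, i] = ψ_n(t_i)), and the standard closed-form ridge solution B^⊤ = X^⊤F^⊤(FF^⊤ + λI)^{-1} turns an (L × e) sequence X into an (N × e) matrix B. The number of Gaussian basis functions, their centers and width, and λ are illustrative assumptions rather than the paper's settings, and the CNN smoothing step is omitted. The comparison at the end illustrates the concern raised in the video: a smooth signal is reconstructed well, an unstructured one mostly is not.

```python
import numpy as np

N_BASIS = 16                               # number of basis functions N (assumed)
WIDTH = 0.05                               # RBF width (assumed)
CENTERS = np.linspace(0.0, 1.0, N_BASIS)   # fixed, data-independent centers (assumed)

def psi(t):
    """Gaussian RBFs at positions t: returns F with F[n, i] = psi_n(t_i)."""
    t = np.asarray(t, dtype=float)
    return np.exp(-0.5 * ((t[None, :] - CENTERS[:, None]) / WIDTH) ** 2)

def fit_memory(X, lam=1e-3):
    """Ridge regression from an (L, e) sequence X to an (N, e) coefficient matrix B."""
    t = np.linspace(0.0, 1.0, X.shape[0])  # token i is placed at t_i in [0, 1]
    F = psi(t)                             # (N, L)
    G = F @ F.T + lam * np.eye(N_BASIS)    # (N, N), symmetric positive definite
    # B^T = X^T F^T G^{-1} is equivalent to B = G^{-1} F X because G is symmetric.
    return np.linalg.solve(G, F @ X)

def read_memory(B, t):
    """Continuous signal x_bar(t) = B^T psi(t), evaluated at positions t -> (len(t), e)."""
    return psi(t).T @ B

def relative_error(X):
    """How much of X is lost after compressing it into N coefficients per dimension."""
    B = fit_memory(X)
    X_hat = read_memory(B, np.linspace(0.0, 1.0, X.shape[0]))
    return np.linalg.norm(X_hat - X) / np.linalg.norm(X)

rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 1000)
smooth = np.stack([np.sin(2 * np.pi * t), np.cos(2 * np.pi * t), np.sin(4 * np.pi * t)], axis=1)
noisy = rng.standard_normal((1000, 3))     # consecutive "embeddings" with no regularity

print(fit_memory(smooth).shape)  # (16, 3): fixed by N, not by L = 1000
print(relative_error(smooth), relative_error(noisy))
```

Solving the N × N system instead of forming the inverse is the usual numerically safer choice; F and G only depend on L and the fixed basis, so they could be cached when the length does not change.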
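The continuous-attention read-out can be sketched for a single query and a single head. Keys and values come from B (K = B W_K, V = B W_V); the scores K q go through two affine maps followed by a sigmoid (for the mean μ) and a softplus (for the variance σ²), and the output is the Gaussian-weighted integral of the value signal, V^⊤ E_{t∼N(μ,σ²)}[ψ(t)]. With Gaussian basis functions this expectation has a closed form. The names W_K, W_V, w_mu, b_mu, w_sigma, b_sigma are hypothetical stand-ins, randomly initialized here instead of learned, and the integral is taken over the whole real line rather than only [0, 1], which is an approximation whenever the Gaussian spills over the edges.

```python
import numpy as np

rng = np.random.default_rng(0)
N, e = 16, 8                          # number of basis functions, embedding size (assumed)
CENTERS = np.linspace(0.0, 1.0, N)    # RBF centers c_n on [0, 1] (assumed)
WIDTH = 0.05                          # RBF width s (assumed)

# Hypothetical stand-ins for learned parameters (random here, trained in a real model).
W_K = rng.standard_normal((e, e)) / np.sqrt(e)
W_V = rng.standard_normal((e, e)) / np.sqrt(e)
w_mu, b_mu = rng.standard_normal(N) / np.sqrt(N), 0.0
w_sigma, b_sigma = rng.standard_normal(N) / np.sqrt(N), 0.0

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softplus(x):
    return np.logaddexp(0.0, x)       # log(1 + exp(x)), numerically stable

def expected_psi(mu, var):
    """E_{t ~ N(mu, var)}[psi_n(t)] for psi_n(t) = exp(-(t - c_n)^2 / (2 s^2)),
    integrated over the whole real line (closed form, one value per basis function)."""
    total = var + WIDTH ** 2
    return WIDTH / np.sqrt(total) * np.exp(-0.5 * (mu - CENTERS) ** 2 / total)

def ltm_attend(q, B):
    """Continuous attention of one query q (e,) over a long-term memory B (N, e)."""
    K = B @ W_K                                  # (N, e) keys, one row per basis function
    V = B @ W_V                                  # (N, e) values
    scores = K @ q                               # (N,) query-key inner products
    mu = sigmoid(w_mu @ scores + b_mu)           # where in [0, 1] to look
    var = softplus(w_sigma @ scores + b_sigma)   # how broadly to look (> 0)
    return V.T @ expected_psi(mu, var)           # integral of N(t; mu, var) * V^T psi(t) dt

B = rng.standard_normal((N, e))       # e.g. the output of the ridge-regression step
z = ltm_attend(rng.standard_normal(e), B)
print(z.shape)  # (8,)
```

In the full model, this long-term output would be added to the output of the ordinary attention over the short-term memory and the current segment, possibly followed by a further output projection as mentioned in the video; that part is not shown.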
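The unbounded-memory update and the sticky-memory variant can be sketched together. The old continuous signal is sampled at M positions, either uniformly or, for sticky memories, in proportion to a histogram of the attention densities accumulated over heads and tokens; the samples are concatenated with the new embeddings, spread evenly over [0, 1], and refit with ridge regression. This is a minimal sketch under assumptions: the histogram evaluates the summed Gaussian densities at bin centers (an approximation of integrating them over each bin), and all helper names, sizes, and constants are hypothetical rather than taken from the paper.

```python
import numpy as np

N_BASIS, WIDTH, LAM = 16, 0.05, 1e-3        # illustrative sizes and constants (assumed)
CENTERS = np.linspace(0.0, 1.0, N_BASIS)

def psi(t):
    """Gaussian RBFs at positions t -> (N_BASIS, len(t))."""
    t = np.asarray(t, dtype=float)
    return np.exp(-0.5 * ((t[None, :] - CENTERS[:, None]) / WIDTH) ** 2)

def fit(X):
    """Ridge-regress X (T, e), spread evenly over [0, 1], onto the basis -> B (N_BASIS, e)."""
    F = psi(np.linspace(0.0, 1.0, X.shape[0]))
    return np.linalg.solve(F @ F.T + LAM * np.eye(N_BASIS), F @ X)

def uniform_positions(m):
    """Plain contraction: sample the old signal at m evenly spaced positions."""
    return np.linspace(0.0, 1.0, m)

def attention_histogram(mus, variances, n_bins=20):
    """Hypothetical helper: sum the Gaussian attention densities of all heads and tokens,
    evaluated at bin centers (approximating the integral of each density over the bin)."""
    centers = (np.arange(n_bins) + 0.5) / n_bins
    mus = np.asarray(mus, dtype=float)[:, None]
    var = np.asarray(variances, dtype=float)[:, None]
    dens = np.exp(-0.5 * (centers[None, :] - mus) ** 2 / var) / np.sqrt(2 * np.pi * var)
    return dens.sum(axis=0)

def sticky_positions(m, hist, rng):
    """Sticky memories: draw m positions with probability proportional to the histogram,
    uniformly inside each chosen bin, sorted so the time order is preserved."""
    p = hist / hist.sum()
    bins = rng.choice(len(p), size=m, p=p)
    return np.sort((bins + rng.random(m)) / len(p))

def update_memory(B, X_new, positions):
    """Contract the old signal to len(positions) samples, append the new embeddings, refit."""
    old = psi(positions).T @ B                        # (M, e) samples of the old signal
    return fit(np.concatenate([old, X_new], axis=0))  # placed evenly on [0, 1] again

rng = np.random.default_rng(0)
e, M = 4, 64
B = fit(rng.standard_normal((32, e)))
# Suppose the attention of all heads and tokens concentrated around t = 0.5.
hist = attention_histogram(mus=[0.50, 0.52], variances=[0.001, 0.001])
for step in range(3):
    X_new = rng.standard_normal((32, e))
    positions = sticky_positions(M, hist, rng) if step > 0 else uniform_positions(M)
    B = update_memory(B, X_new, positions)
print(B.shape)  # (16, 4): the memory size stays fixed step after step
```

Nothing in this update is learned, which matches the point made in the video: the choice of what to keep is a sampling heuristic rather than a trained gate as in an LSTM.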
Info
Channel: Yannic Kilcher
Views: 31,098
Keywords: deep learning, machine learning, arxiv, explained, neural networks, ai, artificial intelligence, paper, inftyformer, infinityformer, infty former, infinity former, transformer, transformers, transformer linear, linear attention, unbounded memory transformer, continuous attention, attention mechanism, continuous attention mechanism, radial basis function, radial basis functions, ridge regression, long term memory, long term memory explained
Id: 0JlB9gufTw8
Length: 36min 37sec (2197 seconds)
Published: Mon Sep 06 2021