Sparse is Enough in Scaling Transformers (aka Terraformer) | ML Research Paper Explained

Captions
Hello there! Today we'll look at "Sparse is Enough in Scaling Transformers" by researchers of the University of Warsaw, Google Research and OpenAI. On a high level, this paper proposes a set of building blocks to introduce sparsity into transformers, which results in an architecture called the Scaling Transformer. In the second half of the paper they then add further features to the Scaling Transformer to turn it into the Terraformer. Both the Scaling Transformer and the Terraformer are really fast at what they call unbatched decoding. Decoding is essentially inference in such a transformer model, and unbatched means they can do it for a single sample. Of course they're also faster in batched decoding, but the effects are not as pronounced, and we're going to see why: the sparsity really shines when you have single examples and can only activate very small parts of the network at the same time.

The effect of all of this, at least for the Scaling Transformer, is the following. If you have a model with 800 million parameters (I guess today that would be called a small model), the baseline transformer has a decoding time of about 0.16 seconds, whereas if you add all the tricks of the Scaling Transformer, you speed that up by a factor of about 2.6x. That's not that pronounced yet. The effect really shines if you go to bigger models: for a 17 billion parameter model, the baseline transformer takes about 3.6 seconds on this particular hardware to decode, while the Scaling Transformer with all the tricks activated takes about 0.18 seconds, a speedup of 20x. In different settings and configurations these speedups can get even higher; I've seen up to around 37x, which is quite fast.

And all of this while the performance doesn't degrade, which is surprising. They say: "Surprisingly, the sparse layers are enough to obtain the same perplexity as the standard transformer with the same number of parameters." So the models have the same number of parameters; it's just that they activate them sparsely when forward propagating, which is much faster and needs much less memory, and this results in the same perplexity when language modeling. Essentially, performance is on par. They also say that if they integrate with prior sparsity approaches (which is how they arrive at the Terraformer), they can do fast inference on long sequences even with limited memory. This results in performance competitive with the state of the art on long text summarization, so their model is state of the art or equivalent to it while being much more sparse, much more memory efficient and much faster.

We'll dive into the architecture. It's quite a mess: there are engineering tricks upon engineering tricks, and you have to wonder a little bit which trick came first and which trick necessitated which other trick. But we'll go through all the different pieces, and you'll see what this is all about and where the savings come from. If you enjoy content like this, don't hesitate to subscribe.
The point with these sparsity gains is that if you implement them in one place, that part is fine, but another part is still dense and remains the bottleneck, so you kind of have to introduce sparsity everywhere. If we look at a classic transformer model (they specifically refer to the stack from Attention Is All You Need), it basically has two attention modules, attention one and attention two, and then a feed-forward layer, and we're going to take care of all of those. Attention one is self-attention: if I have a sequence coming in, self-attention is attention between the elements of that sequence. The second attention block is encoder-decoder attention (the variants differ a little bit): there is an input sequence on the encoder side and a target sequence that I'm about to decode, maybe with causal attention, and the second attention layer is attention that goes from the decoder to the encoder sequence. These two attention blocks mix the information of the different tokens together. The feed-forward layer, in contrast, takes the embedding of a single token and feeds it through a feed-forward function, and all tokens are handled by the same feed-forward function.

The first thing this paper does is essentially eliminate the distinction between self-attention and encoder-decoder attention, and I think that makes sense; it's also what a lot of other models do. Famously, BERT is an encoder-only model and GPT is a decoder-only model, and if I understand this paper correctly, they simply take the encodings from the source and prepend them to the target, or something like this. There are lots of things one could do here, but the point is that we now need to replace each of those parts with a sparse version: we need a sparse feed-forward layer and a sparse attention block.

Let's start with the sparse feed-forward layer. Remember, a feed-forward layer takes the sequence of embedding vectors that came out of the attention module and passes each one through, in fact, two matrices: the first matrix blows up the dimension to the feed-forward dimension, there is a ReLU non-linearity in between, and the second matrix projects back down to the input dimension. Every single token goes through this function by itself, independently of the others. So we have a token, a vector, and somehow we need to make this sparse; right now it's two dense matrix multiplications per token (a minimal sketch of this standard block is below). So what do we do?
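To make the baseline concrete, here is a minimal NumPy sketch of that standard per-token feed-forward block; the sizes d_model and d_ff are placeholder values I picked, not the paper's:

```python
import numpy as np

d_model, d_ff = 512, 2048          # hypothetical sizes
rng = np.random.default_rng(0)
W1 = rng.standard_normal((d_model, d_ff)) * 0.02   # blow-up matrix
W2 = rng.standard_normal((d_ff, d_model)) * 0.02   # projection back down

def dense_ffn(x):
    """Standard transformer feed-forward, applied to each token independently."""
    h = np.maximum(x @ W1, 0.0)    # ReLU(x W1), shape (d_ff,)
    return h @ W2                  # back to (d_model,)

token = rng.standard_normal(d_model)
out = dense_ffn(token)             # two dense multiplications per token
```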
The first thing they say is that, given the ReLU non-linearity, a lot of the intermediate activations are going to end up being zero anyway, so it makes sense to introduce sparsity here. I don't entirely follow that reasoning: roughly half of the units will end up being zero, yet the sparsity they use goes much further, so I found the justification a bit shaky. But you don't really need a reason to introduce sparsity; if it works, it's good.

Here is how it works (the figure essentially reads right to left, but it's easier to start on the left). We have the input vector and the first matrix, of dimension d_model (the same as the input) by d_ff (the feed-forward dimension). Usually I just multiply the two, which gives me a vector of the feed-forward dimension that I then send through the ReLU. However, what I want to do is compartmentalize: I want only certain columns to be activated. Essentially I accept in advance that a lot of entries in my result are going to be zero, because they would go through a ReLU anyway. So for all the columns I've decided will be zero, I don't even need to compute the inner product between the input vector and that column: after the ReLU they would become zero anyway, so who cares. They justify this by the ReLU, but really it's more aggressive: something like six out of eight entries are set to zero, and I only need to calculate the remaining columns. That is the sparsity.

Effectively, they subdivide the whole matrix into compartments, say two compartments, and in each compartment only one column can be activated at a time. Only that one needs to be loaded from memory, and only that one needs to be computed as an inner product with the input vector, so the cells where an actual value ends up are sparse. And notice that the same mask carries over to the second matrix: if column number three was activated in the first matrix, then row number three of the second matrix needs to be used; the other rows don't matter, because the corresponding inputs are zero, and a zero input times anything is zero. So you don't even have to do those multiplications; you only load the rows that you know are potentially non-zero.

So the question becomes: how do we decide which columns to activate, that is, which ones to load from memory? Essentially, you pre-commit to a ReLU pattern before you compute it. (A sketch of what inference looks like once such a mask is given follows below.)
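Here is a minimal sketch of what inference could look like once such a mask is given, reusing W1, W2 and token from the dense sketch above. The block size and the way I pick the active columns are my own toy choices; this is just the bookkeeping of "compute one column per block of W1, use the matching rows of W2", not the paper's actual implementation:

```python
def sparse_ffn_given_mask(x, W1, W2, active_cols):
    """active_cols[i] is the index (along the full d_ff axis) of the single
    column of W1 that is active in block i; all other units are treated as zero."""
    d_model = W1.shape[0]
    y = np.zeros(d_model)
    for col in active_cols:
        h = max(x @ W1[:, col], 0.0)   # one inner product + ReLU per active column
        y += h * W2[col, :]            # only the matching row of W2 is needed
    return y

# e.g. with d_ff = 2048 split into blocks of 32, one active column per block:
block_size = 32
active_cols = [b * block_size + 3 for b in range(W1.shape[1] // block_size)]  # toy choice
y = sparse_ffn_given_mask(token, W1, W2, active_cols)
```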
Here is how the mask is decided. You take your input vector and, in a way we'll get to in a second, come up with a vector of numbers between 0 and 1, something like 0.1, 0.5, 0.3, 0.8, so every single entry gets a value: the probability that this particular unit should be non-zero. Then you sample from that distribution and use a straight-through Gumbel-softmax to backpropagate. There are a lot of engineering tricks here; I think they mention that during forward propagation they sometimes even need to pass the softmax output instead of the actual hard sample to get this to work. But that's during training. During inference, you select exactly one non-zero unit per module.

So you have two workflows: one decides what needs to be non-zero, and given that information you can run the feed-forward layer in a sparse way. But that is all useless if the controller itself is not cheap. The controller is actually not sparse, but it is low rank. They say that to figure out which units should be non-zero, we don't need as much information as for actually propagating information. So the controller is essentially another feed-forward layer that again goes up to the feed-forward dimension, but it is low rank: instead of blowing up the dimension in between, we first shrink it down to a low dimension and only then go up to the feed-forward dimension to decide which things are one and which are zero. That's a pattern you're going to see often in this model: they make heavy use of low rank combined with sparsity.

This is also a bit of a problem I have, because for some things a low-rank approximation is fine, but there's a reason we have dense multiplications everywhere: with a low-rank multiplication you restrict your function space to a very small subspace. Still, it seems to work. The trade-off is this: you get to run the big layer sparsely, which decreases time and memory, but you have to compute the controller, which is new; you didn't have to do that before. So the question is whether you can make the controller sufficiently low rank that the gains from sparsity outweigh the cost of computing the mask in the first place. For the particular problems they look at, it seems to work, but this kind of trade-off is not guaranteed; it's not straightforward that it would be positive in general. There might very well be problems where this rank is just too small to carry meaningful information, and making it bigger would eat up the savings, because the savings are essentially linear in the sparsity, and the controller cost is essentially linear in the low-rank dimension. So there's the trade-off. (A minimal sketch of such a low-rank controller follows below.)
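Continuing the sketch from above, here is one way such a low-rank controller could look at inference time: two small matrices produce one logit per feed-forward unit, and an argmax within each block picks the winner. The controller rank, the block size and the use of a plain argmax (instead of the paper's straight-through Gumbel-softmax sampling during training) are my assumptions:

```python
d_low = 64                              # hypothetical controller rank, much smaller than d_ff
C1 = rng.standard_normal((d_model, d_low)) * 0.02
C2 = rng.standard_normal((d_low, W1.shape[1])) * 0.02

def controller_active_cols(x, block_size=32):
    """Low-rank controller: logits = (x C1) C2, then one winner per block.
    During training the paper samples with a straight-through Gumbel-softmax;
    this only shows the cheap inference path."""
    logits = (x @ C1) @ C2                               # (d_ff,), cost ~ d_model*d_low + d_low*d_ff
    logits = logits.reshape(-1, block_size)              # (num_blocks, block_size)
    winners = logits.argmax(axis=1)                      # index within each block
    return winners + np.arange(len(winners)) * block_size  # absolute column indices

active_cols = controller_active_cols(token)
y = sparse_ffn_given_mask(token, W1, W2, active_cols)
```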
Here is how you can express the whole thing: the original multiplication with the first matrix, passed through the ReLU, element-wise multiplied by the controller output, and all of that goes into the second multiplication, roughly Y = (ReLU(x W1) ⊙ Controller(x)) W2. That's how you can represent it mathematically; it's not literally what you compute, because written like that you'd still do the full multiplications with the weight matrices, but it gives the same result.

So that is the sparse feed-forward layer, and they show that it decreases decoding time quite a bit while, interestingly, not degrading performance too much. In their plot, the blue line is the average of the baseline models, and if you don't go too sparse, you stay quite close to it; only if you go more sparse does the perplexity start to suffer. I think that is one of the surprising things: there is a level of sparsity at which you are considerably faster while your performance doesn't degrade yet. Then again, it could very well be that the problems they look at are not difficult enough to really exercise the capacity of the dense models.

Okay, the feed-forward layer is done; now we go to the attention layer, which again is split into two parts. In fact, they don't really deal with the attention mechanism itself. Attention is something like: I have my queries and my keys, I take their outer product, normalize by the square root of the key dimension, and then multiply by my values; that's the attention formula. What they care about is how you get the queries, the keys and the values in the first place. To make attention itself sparse, long-range or efficient, they rely on techniques from other papers; for example, they will later bring in the Reformer's LSH attention. In this particular paper, though, the question is how we even get these matrices. Usually you get Q by multiplying your input by a weight matrix W_Q, you get K by multiplying your input by a key weight matrix W_K, and likewise V. All of these are dense multiplications, and obviously they now become the bottleneck once we have the sparse feed-forward layers.

Can we use the same trick here as before? They say no, because the structure of the feed-forward layer had the ReLU in between; that's why they could argue that a lot of things naturally end up being zero and exploit that by making a few more things zero. Here, none of the entries in Q, K or V are necessarily going to be zero, so it might not be justified to go sparse and just force entries to zero. So what do they do instead? Look at their diagram: on top is what the current attention mechanism looks like, with a dense layer in front of each of the three matrices, which is exactly how you get each matrix in the first place. In its place they introduce something they call a multiplicative layer. (For reference, a minimal sketch of the standard dense Q/K/V projections being replaced is below.)
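For reference, here is a minimal sketch (continuing the NumPy examples above) of the standard dense Q/K/V projections with head splitting that this replaces; each projection is a full d_model-by-d_model multiplication, which is exactly the cost they want to avoid:

```python
n_heads = 8
d_head = d_model // n_heads
WQ = rng.standard_normal((d_model, d_model)) * 0.02
WK = rng.standard_normal((d_model, d_model)) * 0.02
WV = rng.standard_normal((d_model, d_model)) * 0.02

def qkv_dense(X):
    """X: (seq_len, d_model). Three dense projections, then split into heads."""
    Q, K, V = X @ WQ, X @ WK, X @ WV                         # each (seq_len, d_model)
    split = lambda A: A.reshape(A.shape[0], n_heads, d_head)  # (seq_len, heads, d_head)
    return split(Q), split(K), split(V)

X = rng.standard_normal((10, d_model))
Q, K, V = qkv_dense(X)
```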
The multiplicative layer could in principle replace the dense layer one-for-one, but they go a step further: they end up with an architecture where there is one multiplicative layer shared for all three matrices, followed by one convolutional layer for each of the different matrices, which makes things even faster. They also drop this dense mechanism in the diagram and simply add instead. I'm pretty sure this works for these particular problems; maybe the problems don't make full use of the parameters, or the original models were just over-parameterized and never actually needed all of them. That could also be the case.

So we have two things to look at inside the attention module: the multiplicative layer and the conv layers. These go together, and they also go together with what's usually done in attention, namely multi-head attention. Let me draw an attention mechanism for about the 500th time: you have a sequence, and every element of the sequence emits a query, which is a vector, and every element also emits a key, also a vector, and routing is done via inner-product overlap: elements whose queries and keys overlap strongly get routed together (an element can also be routed to multiple others). That's how you construct the query-key matrix that you then multiply by the values. The idea behind multi-head attention, which is what's usually done, is to not have just one such block but many in parallel. And instead of using the entire vectors as queries (Q is a matrix, and let's say every row is one of these query vectors; if it's actually the columns, just re-imagine), you split each vector into, say, three parts: the first part becomes the query for the first attention head, the second part the query for the second head, and the third part the query for yet another head. That's multi-head attention, and the same goes for the keys and the values.

Now we're prepared. What we want to do is take a token and produce, from that token, not just one query vector but number-of-heads many query vectors, using some sort of linear function. They do it with a weight matrix D, which has the same dimension as the input and as many rows as there are attention heads, and we element-wise multiply the input with it, with broadcasting.
If you've used NumPy, TensorFlow or PyTorch, you know the broadcasting operation: the input x has a singleton dimension here, and it gets broadcast against the S rows of D, so the element-wise multiplication gives you S differently scaled versions of x, one per row, that is, one vector per attention head. Since an element-wise multiply alone is probably not going to get us very far, we also multiply this by an actual matrix, but instead of a d_model-by-d_model matrix, which would be expensive, we again go into the low-rank regime: there is a number M, a reduction of the dimensionality, and the matrix E is d_model by M. Out of that comes one vector per head: this one is the query vector for the first attention head, and this one is the query vector for the second attention head.

They don't choose M arbitrarily; in fact they choose S times M equal to d_model, so if they split into S different heads, say S is 2, then M is d_model over 2. And that has a very particular reason: with this construction of the element-wise multiply followed by the multiplication by the weight matrix E, they can prove a theorem which essentially says that the layer can represent an arbitrary permutation. The minimum thing we need to be able to do is take x and permute it, that is, place every single element of x wherever we want in the output, so that every part of x can be forwarded to any of the attention heads. The theorem says that with this construction, any permutation is within the realm of possibility for some weight matrices D and E. That's their justification: we can represent all permutations, so it can't be too bad.

I found another way of seeing this: if you look at the formula, this is in fact a matrix multiplication again. You can write it as D times diag(x) times E, where diag(x) is a matrix that is zero everywhere except that it has x on the diagonal. D and E are fixed (learned) matrices, so what this multiplicative layer is doing is defining S outputs of dimensionality M each. (A minimal sketch of this computation is below.)
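Here is how I read that computation as code, continuing the sketch above: broadcast-multiply the token against each row of D, then project with the shared low-rank matrix E, so that y[s, m] is the sum over d of x[d]·D[s, d]·E[d, m]. Treat this as a sketch of my understanding, not the paper's exact implementation:

```python
S = n_heads                   # number of heads / modules
M = d_model // S              # low-rank output dimension, so that S * M = d_model
D = rng.standard_normal((S, d_model)) * 0.02
E = rng.standard_normal((d_model, M)) * 0.02

def multiplicative_layer(x):
    """x: (d_model,). Element-wise multiply x with each row of D (broadcasting),
    then multiply each of the S scaled copies by the shared low-rank matrix E."""
    scaled = D * x[None, :]    # (S, d_model): S differently-scaled copies of x
    return scaled @ E          # (S, M): one low-rank vector per head
    # equivalently: np.einsum('d,sd,dm->sm', x, D, E)

Y = multiplicative_layer(token)  # (S, M)
```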
What you're able to do with it is, in some higher-dimensional space, manipulate the coordinate system scaling, somewhat arbitrarily, but you cannot mix the individual dimensions freely. For a given mixing of dimensions, which is what the fixed matrices D and E do, that is, for given linear projections between the low-dimensional and the high-dimensional space, you're able to manipulate the coordinate system. So during training you need to find matrices D and E such that for arbitrary samples this manipulation of the coordinate system makes sense; it's a little bit like doing a PCA on a data set, except learned during training. Note that the D times diag(x) times E view is only conceptual; if you actually computed it that way, you would lose all the efficiency benefits. Again, this is quite a trade-off compared to an actual dense layer, but it's interesting to see that it works. And again, you can see that the trick here isn't really sparsity; it's mostly low rank. This is mostly a low-rank function.

Okay, so with the multiplicative layer we end up with the queries, the keys and the values for each attention head. They essentially say: we could do this once for each of the three matrices, which would give each of them this permutation property, or we do it only once, shared, and then do something even cheaper to get the individual matrices. That's the trade-off: before, every permutation was possible separately for the different matrices, so Q could have a different permutation than K or V; now we resort to one shared mixing or shuffling of the dimensions, and then do something even cheaper afterwards, which is the convolutional module.

The convolutional module is also fairly simple to see. The output Y of the multiplicative layer is, per token, S vectors (one per attention head) of dimension M, and you have those for every token in the sequence. So what you get is a tensor with the sequence length L, the number of heads (or modules) S, and M, the low-rank dimensionality that the keys, queries and values live in. They simply treat this as an image and run a convolution across it: the filter spans F positions in the S dimension and F positions in the L dimension, is M deep, and there are M such filters, so an S-by-L-by-M tensor goes in and an S-by-L-by-M tensor comes out. Essentially you can just think of this as a regular convolutional layer. (A toy sketch is below.)
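And here is a toy sketch of that convolutional module as I understand it, continuing the example above: stack the per-token outputs into an (L, S, M) tensor and run a small 2-D convolution over the length and head dimensions with M channels in and out. The filter size, the "same" zero padding and the naive loop implementation are my own choices; a real implementation would use an optimized convolution primitive:

```python
def conv_module(Y, W_conv):
    """Y: (L, S, M) tensor of per-token, per-head vectors.
    W_conv: (F, F, M, M) filter over (length, heads) with M input/output channels.
    Returns an (L, S, M) tensor ('same' zero padding, stride 1)."""
    L, S_, M_ = Y.shape
    F = W_conv.shape[0]
    pad = F // 2
    Yp = np.pad(Y, ((pad, pad), (pad, pad), (0, 0)))
    out = np.zeros_like(Y)
    for l in range(L):
        for s in range(S_):
            patch = Yp[l:l + F, s:s + F, :]            # (F, F, M) neighborhood
            out[l, s] = np.einsum('fgm,fgmn->n', patch, W_conv)
    return out

F = 3
W_conv = rng.standard_normal((F, F, M, M)) * 0.02
Y_tokens = rng.standard_normal((10, S, M))             # e.g. 10 tokens in the sequence
Z = conv_module(Y_tokens, W_conv)                      # same shape, locally mixed
```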
So what does the convolution go over, and what does it buy us? Remember that the multiplicative layer works on a single token: it is able to shuffle around the token's dimensions a little bit, to permute them in the best case, and in all other cases it manipulates the scaling in a high-dimensional space. With the convolutional layer, we can now bridge a little bit of information between tokens even before we go into the attention module. Given that the convolution runs across the L and the S dimensions, information can be passed between neighboring attention heads (along S) and between neighboring tokens in the sequence (along L). That potentially gives some positionality to the tokens, because now there is a notion of being close together, and it maybe also gives a little bit of meaning to the different attention heads, because up until this point the attention heads have just been unordered, independent things, and now they hang together a little bit.

All of this is one of the reasons why the exact conclusions of this paper are going to be hard to assess, even with ablations: at the same time as they introduce efficiency, they also introduce entirely new ways of doing things, new paths along which information can be passed, and so it's very hard to pin down exactly where things go right or wrong.

So that was the sparse, or rather low-dimensional, attention module: first one of these multiplicative layers, which is an element-wise multiply followed by a matrix multiplication to a lower dimension, and then the convolutional layers. They call this whole thing MultConv. If you combine all of this together, you can see in their results that the blue shaded region is the average of the baselines; this is perplexity, so lower is presumably better, and up to some noise all of these variants are fairly consistent: they follow the trajectory of the baselines quite neatly, and some are even a bit lower. One thing confuses me, though: F is the filter size and S is the sparsity in the multiplicative layer, essentially how many heads it splits things into, and in the table there is a "conv" entry and a "mult" entry, but the F is listed with the mult, which confuses me, because the filter size should technically belong with the conv. If the authors are watching, please leave a comment if I'm wrong; otherwise I'm confused. In any case, they show that the baseline transformers don't do particularly better on these NLP tasks, and sometimes even do worse, though everything is pretty much within a standard deviation of these Scaling Transformers. This architecture that we've discussed so far is the Scaling Transformer.

The last thing to do would be to add a sparse loss layer: they can replace the final dense layer with a multiplicative layer, similar to the previous sections. This speeds up decoding time, they say, but may degrade perplexity; results are in the appendix. So the loss layer might be the last refuge of really dense computation.
But remember that in the feed-forward layers we sample from this distribution to really be sparse, and in fact we might just take the argmax during inference; that's where the speedup comes from. During training we actually have to forward-propagate the softmax from time to time so that training works, and that means the benefits of sparsity are lost: if we don't hard-sample ones and zeros but soft-sample them, all the rows are still activated and we need to track everything. The same goes, I think, a little bit for batched inference: even if I hard-sample, different samples in the batch are going to have different activation patterns, and with enough samples everything is going to be one somewhere, so I probably need to load the entire matrix from memory and do the multiplication with the entire matrix, possibly not for all the vectors. Also, something like a GPU probably wouldn't care that some entries are zero; it's going to be just as fast to do everything at once. But that might be a hardware limitation.

Okay, so that was the Scaling Transformer, and now we're going to supercharge the Scaling Transformer into the Terraformer. I don't think there's any relation to the tool Terraform, but we're running out of names for -formers, so I guess this was the last refuge. What they do is essentially use the locality-sensitive hashing (LSH) attention from the Reformer. They have an architecture for long sequences, and they write that while integrating sparse attention layers into a Scaling Transformer, they noticed the architecture is sub-optimal (that's what I said at the beginning): separating decoder self-attention and encoder-decoder attention is not necessary anymore from the perspective of efficiency, so they remove the encoder-decoder attention and just concatenate the encoder representations before the decoder tokens. So the encoder-decoder attention is replaced by essentially concatenation plus regular attention blocks.

I've done a video on locality-sensitive hashing. If you have really long sequences, you need to compute inner products between all pairs of tokens, which is cumbersome, and there are various techniques to speed that up. One is LSH, where you essentially create hash buckets, hash all the vectors, and look for hash collisions: collisions indicate where you want to actually compute and check, and everything that's not a collision you don't need to check. Locality-sensitive hashing has been a long-standing technique to make inner-product search in high dimensions, that is, finding the closest vectors by inner product among very many elements, very fast. They borrow that from the Reformer. (A toy sketch of the bucketing idea follows below.)
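To make the LSH idea a bit more concrete, here is a toy sketch of the bucketing step only. This is not the Reformer's actual implementation (which shares queries and keys, sorts and chunks the buckets, and uses several hash rounds); it just shows the principle that vectors are hashed with random projections and attention is restricted to pairs that land in the same bucket:

```python
def lsh_buckets(vecs, n_buckets, rng):
    """vecs: (L, d). Random-rotation LSH: the bucket id is the argmax over
    projections onto n_buckets/2 random directions and their negations."""
    L, d = vecs.shape
    R = rng.standard_normal((d, n_buckets // 2))
    proj = vecs @ R                                    # (L, n_buckets // 2)
    proj = np.concatenate([proj, -proj], axis=1)       # (L, n_buckets)
    return proj.argmax(axis=1)                         # one bucket id per position

def lsh_attention_mask(vecs, n_buckets, rng):
    """Boolean (L, L) mask: position i may attend to j only within the same bucket."""
    b = lsh_buckets(vecs, n_buckets, rng)
    return b[:, None] == b[None, :]

keys = rng.standard_normal((16, d_model))
mask = lsh_attention_mask(keys, n_buckets=4, rng=rng)  # only collisions are computed
```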
The next ingredient, also from the Reformer, is reversibility (the recurrent blocks come later). What reversibility means is this kind of architecture: per block we again have two attention layers and one feed-forward layer (the second attention replaces the encoder-decoder attention), and reversible means that instead of having one strand, one flow of forward-propagating information, we have two. There is input 1 and input 2, two information flows going forward, and every function is applied to one flow and added to the other flow, while the other one is simply forward-propagated, essentially as a residual connection, and then they swap roles. So there is a path that hits all the functions, and at the same time, for each function, there is always a signal that travels past it without being touched by it. That makes the blocks reversible, and it means I don't have to keep activations around.

An example of something non-reversible is an ordinary block: unless the function is a non-degenerate linear map from a space to a space of the same dimension, I cannot reconstruct the input x from the output y, not even for a single block. Reversibility changes that: from the output signals I can always reconstruct the intermediate activations, and therefore I don't need to store them. In a normal neural network, as I forward-propagate, I need to store a lot of intermediate activations, because during backpropagation I need them to calculate the gradients. Reversible blocks don't have this requirement, and they are made reversible not by changing the individual modules but simply by this construction of two strands of information with the modules applied between them. It's a pretty smart architecture, but one has to say it often comes with significant trade-offs: keeping everything reversible also restricts what you can express. Again, I think it might work for the problems they look at here, and it might not work for all problems. That's a bit of a general theme in this paper: we'll have to test, for every new task, new challenge or new modality, whether these things still hold. (A minimal sketch of a reversible block is below.)
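Here is a minimal sketch of a reversible block in this two-stream style: each sub-function is applied to one stream and added to the other, so the inputs can be reconstructed exactly from the outputs and no intermediate activations need to be stored. The functions f and g here are just stand-ins for the attention and feed-forward blocks:

```python
def rev_forward(x1, x2, f, g):
    """Two-stream reversible coupling: y1 = x1 + f(x2); y2 = x2 + g(y1)."""
    y1 = x1 + f(x2)
    y2 = x2 + g(y1)
    return y1, y2

def rev_inverse(y1, y2, f, g):
    """Recover the inputs from the outputs; no stored activations needed."""
    x2 = y2 - g(y1)
    x1 = y1 - f(x2)
    return x1, x2

f = lambda v: np.tanh(v)        # stand-in for the attention block
g = lambda v: np.maximum(v, 0)  # stand-in for the feed-forward block
x1, x2 = rng.standard_normal(d_model), rng.standard_normal(d_model)
y1, y2 = rev_forward(x1, x2, f, g)
r1, r2 = rev_inverse(y1, y2, f, g)
assert np.allclose(r1, x1) and np.allclose(r2, x2)
```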
The last thing they build in is recurrence, and they say it's for generalization. If I understand it correctly, they use simple recurrent units, not an LSTM, because they say that would be too slow. Simple recurrent units are still fairly complicated (I had to look them up); they have gates and so on, a bit like GRU or LSTM cells. And if I understand correctly, this goes into the feed-forward block: as I said before, in the feed-forward layer every single token goes through independently, and this introduces a recurrent connection between the tokens. They write: "We also add a recurrence to the feed-forward block of Terraformer. Recurrent layers allow information to propagate in time, even in a single decoder block." So within the feed-forward block there is now a recurrent connection between the different tokens: every token still goes through the feed-forward function, but now there is also a small recurrent network running from the first token to the second to the third, and so on.

One can only speculate why they have this in here. They say the gains on C4, their language modeling task, are minimal, and that the biggest benefits show up on toy tasks like copying decimal digits, where they train on sequences of 128 digits and then test on 256, so more than twice as long as seen in training. They really make the point that it's for generalization. Still, it's a very odd addition. I could follow them up to here: long sequences, cool, it's nice if your model can handle long sequences; memory efficiency, okay, given that everything is sparse and low rank you also might want to use less memory, cool. But recurrence for this is quite an odd choice, I feel, and it could be that the other parts simply didn't quite work. They also say that the Terraformer, on tasks like summarization, beats or matches the state of the art, matches much, much larger models, and I can imagine that their numbers were slightly worse than the baselines and they were just looking for something to add to pump those numbers up, and this worked. If that's the case (a big if), it's dangerous, because it might work for these particular problems and not for others. If instead it was really just an idea they had and thought would be cool to include, then fine, I'm willing to accept that as well.

All right, so that was the Terraformer. The Terraformer has over a 37x speedup on a considerably large model, and for this large model it requires less than 100 milliseconds per token of decoding time, while not degrading performance too much. That is quite an achievement, I think, even if it's only for particular types of tasks. It's a bit of a shame that the speedups are only so huge for the really large models; I guess it makes sense, because these effects compound. For you and me, with our regular old laptops, it maybe won't make that much of a difference in terms of speed; it might make a difference in terms of memory because of the reversibility. But if you want to work with larger models without much compute and you do inference, this might be something for you. They specifically say that not everything has been tried yet: they still don't do quantization, which could deliver yet another speedup.
There are also lots of things to do to actually speed up training; maybe there's a way to get around this Gumbel-softmax requirement of forward-propagating the true softmax from time to time, and so on. Lots of engineering, lots of interleaved choices, very hard to say exactly where the gain comes from, but undeniably a gain has been made, in huge form, and that's cool. All right, tell me what you think, and I'll see you next time. Bye.
Info
Channel: Yannic Kilcher
Views: 23,055
Keywords: deep learning, machine learning, arxiv, explained, neural networks, ai, artificial intelligence, paper, terraformer, scaling transformers, nli, nlp, natural language processing, transformers memory, deep learning memory, fast transformer, fast transformers, attention, attention mechanism, attention is all you need, bert, gpt-3, google research, reversible layers, reformer, sparse attention, sparse feedforward, low-rank
Id: hgSGHusDx7M
Length: 57min 6sec (3426 seconds)
Published: Thu Dec 02 2021