FNet: Mixing Tokens with Fourier Transforms (Machine Learning Research Paper Explained)

Video Statistics and Information

Captions
Hello there! Today we're looking at "FNet: Mixing Tokens with Fourier Transforms" by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein and Santiago Ontañón of Google Research. I know I'm a bit late with this one, but it's not only this paper: it's a really interesting direction that's happening right now in machine learning in general, in deep learning, in sequence models, in image models and so on, and that is the giving up of attention mechanisms.

For the longest time we've been focusing on transformers. In a transformer you have some sort of a sequence as an input, and you push that through layers that are always made up of attention sub-layers and feed-forward sub-layers; every layer has an attention sub-layer and a feed-forward sub-layer, or multiple of them. The feed-forward sub-layers act individually on the elements: the weights are shared, there is one feed-forward layer, and every token goes through it. That can be efficiently parallelized or sharded, or you can do things like mixture of experts, where tokens go to different experts; a lot is possible there. The attention part, however, has always been a bit of a thorn in most people's eye, because while the attention mechanism is definitely a cool mechanism, it needs a lot of memory and compute. The attention mechanism needs to decide which information in this layer's sequence goes to which position in the next layer's sequence: from this token, does the information go here or here? Who knows. The attention mechanism's job is to figure out what information goes where. It's a routing problem, and as such it has a compute complexity of O(n²), if n is your sequence length, and memory requirements of O(n²) as well, and that prohibits it from scaling to larger sequence lengths. So we were always limited in the length of the sequences we could input, which for example prevented transformers from being applied to computer vision for a long time, until people figured out that you don't need to go pixel by pixel, you can just subdivide the image into patches. Still, this limitation on sequence length is a result of the attention mechanism having this quadratic complexity, and people have been chipping away at that complexity for a while now. We've had about one or two years of constant invention of linearized attention mechanisms, trying to get from O(n²) to some O(n), or maybe O(n log n), or something manageable, maybe a constant, maybe n times k, anything but n². So we had Linformer and Longformer and Reformer and Synthesizer (I don't even know if Synthesizer is in the same area) and Performer and the Linear Transformer; there are so many of what you would call linear or non-quadratic attention mechanisms trying to approximate this attention routing problem.

Now we've entered a new era: people are questioning whether we even need the attention layer at all, and this all comes at very similar times; even after this paper there have been at least three papers trying to actively get rid of the attention layer in sequence models, which is super interesting.
So we're going to have a look at how you get rid of the attention layer that has apparently given sequence models such a boost, and what you replace it with, and in this particular paper the answer is very much: Fourier transforms. We're going to get into why Fourier transforms, but essentially they present a model that looks like this, and if you've seen my video on attention or anything since then, this should look very familiar. There is an input down here; the input is split into sequences of words or word pieces, and each of these word pieces gets a word embedding (a table where you look it up), a position embedding, and maybe a type embedding; for the most direct reference, maybe go watch the video on BERT. The next step is then N times this layer right here, and this is where the attention would usually be, but instead of attention you now have what's called the Fourier layer, which we're going to look at in quite a bit of detail. After that there is a dense layer, an output projection and then an output prediction. So as you can see, this is very much like a transformer, except it says Fourier instead of attention; just so you're aware of what's going on, this is the only thing they change, they don't change anything else except this sub-part.

And what is this sub-part? It is characterized in this formula right here, but essentially you take your input to the layer, x (so x is whatever goes into the layer; this would be x0, and x1 goes back in, N times), and you apply a Fourier transform to it. Now you might ask: how can you do that? x is not a continuous signal like a sound wave. Remember that the way we view sequences here is as a series of vectors: every input element at the bottom gets mapped to some sort of vector, as many vectors as you have tokens, with as many dimensions as you decide yourself. So you have a bunch of vectors right here, and you do a Fourier transform first over the hidden domain and then over the sequence domain: a 1D Fourier transform over the hidden dimension of each token individually, and then a 1D Fourier transform in each dimension but across the time domain. And that's it. There are no parameters involved in this thing; it is simply a Fourier transform in the hidden-dimension domain and a Fourier transform in the time domain, and that's all. The only learned parameters in this whole setup (well, the normalization might have some affine parameters) are the feed-forward parameters. This is quite a departure.
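To make that concrete, here is a minimal NumPy sketch of the mixing sub-layer as I read the paper's formula, y = ℜ(F_seq(F_hidden(x))); the function name and the toy shapes are mine, and I'm leaving out the residual connection and normalization that wrap this sub-layer in the full block.

```python
import numpy as np

def fourier_mixing(x):
    """FNet mixing sub-layer (sketch): y = Re( FFT_seq( FFT_hidden(x) ) ).

    x has shape (seq_len, hidden_dim); note there are no learned parameters.
    """
    x = np.fft.fft(x, axis=-1)   # 1D FFT along each token's hidden dimension
    x = np.fft.fft(x, axis=0)    # 1D FFT across the sequence (token) dimension
    return np.real(x)            # keep only the real part

# toy example: 4 tokens, hidden size 8
tokens = np.random.randn(4, 8)
mixed = fourier_mixing(tokens)
print(mixed.shape)  # (4, 8) -- same shape as the input
```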
Now, if you are a bit confused, let me go a bit more into this Fourier transform. You might notice in the formula that at the end we are only interested in the real part of the output. What does the Fourier transform do? What it usually does is take some sort of signal and transform it, in a reversible, linear fashion, into a superposition of basis functions, and in the case of the Fourier transform these basis functions are sine and cosine waves of different frequencies, very much what you might be used to from the position encodings. So the Fourier transform would tell you, say, that the top signal is three times this basis wave plus five times this one plus nine times the bottom one; this signal right here gets transformed into this representation right here, and you can do an inverse Fourier transform as well. The formula for the Fourier transform is pretty simple. You decide how many components you want; you could represent any signal exactly with infinitely many components, but since we deal with finite data we just cut off somewhere. The inverse transform is simply the same formula without the negative sign in the exponent. In fact, you can do all of this by constructing this matrix ahead of time and then multiplying by it, and there you really see that this is just a linear transformation of your data. You do it once column-wise and once row-wise to your signal, and there you have it: that's your layer, no learned parameters at all.
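To see the "it's just a matrix multiplication" point in code, here is a small NumPy check I wrote (not code from the paper): it builds the N×N DFT matrix with entries e^{-2πi·km/N}, compares it against np.fft.fft, and shows that flipping the sign in the exponent (plus a 1/N factor) gives the inverse.

```python
import numpy as np

N = 8
n = np.arange(N)

# DFT matrix: W[k, m] = exp(-2*pi*i * k*m / N), built once ahead of time
W = np.exp(-2j * np.pi * np.outer(n, n) / N)

x = np.random.randn(N)
print(np.allclose(W @ x, np.fft.fft(x)))   # True: the FFT is exactly this matrix multiply

# the inverse transform just drops the minus sign (and rescales by 1/N)
W_inv = np.exp(2j * np.pi * np.outer(n, n) / N) / N
print(np.allclose(W_inv @ (W @ x), x))     # True: we get the original signal back
```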
Now, why might this work? The second part of the paper, which we haven't really looked at yet, is what they call mixing tokens, and they put an emphasis on this, which I think is really smart. This paper isn't about the Fourier transform; it is not advocating that the Fourier transform as such is in any way special. Rather, I think what they advocate for is that the mixing of tokens is special, the mixing of information between the tokens. What do we mean by that? If you have any sort of sequence and you want to do computation with it, if you want to understand the whole sequence, then at some point information needs to flow between the elements of the sequence. How does a convolutional neural network flow information? A convolutional neural network restricts information flow to a neighborhood: it lets information flow within this neighborhood, and (let's do non-overlapping kernels) within this neighborhood and this one, and then in the next layer, where there are only three elements left, it lets information flow within this neighborhood, and so on, so that this node at the top has a global overview over the whole sequence, whereas this node down here only had an overview over a local sub-sequence. We accept this, and for images it makes a lot of sense: our prior for images is exactly that what's first and foremost relevant to a pixel is probably the surrounding pixels, then the objects, if the image contains objects, are probably in the broader neighborhood of that area, and on the highest level we want to understand the relationships of objects to each other. That seems like a natural prior to have.

In text, however, it's a little bit different. It might very well be that here at the end (if anyone has ever tried to learn German) there is a word that intrinsically, as a first layer of information, references the second word in the sentence, like a helper-verb construction; this is very common in language. So there is not at all this locality of information, and therefore routing information quickly between elements of the sequence is very important, especially when it comes to language, though it also matters for images, because as we've seen, the vision transformers work quite well too. This locality prior might not be as helpful and might actually be damaging if you only get to learn about your distant tokens three, four or five layers down; that limits your ability to do computation. Now, the attention mechanism is exactly what facilitated these connections between elements across the whole sequence, because it analyzed every single possible connection between two things and then decided which connections are the important ones. What this paper is saying, and I guess other papers that have come out since, like the MLP-Mixer and "Pay Attention to MLPs", is that it might not be so important to decide exactly how information should flow between far-away elements; it might just be enough for most tasks that information flows at all. If we somehow get information from one token to all the other tokens, then we facilitate this transfer of information, and that might be enough; the exact routing might not be as important as the fact that information is flowing.

And that's what the Fourier transform ultimately does here, because if you transform the time domain (this is step one, step two, step three, step four), then a little bit of token one influences this output number, and this one, and this one, and the same for tokens two, three and four. The time domain is completely mixed together, while the frequency domain is split up. Then in the next step, when you do a Fourier transform again, you do very much the reverse: you sort of go back into the time domain, though I'm not convinced that applying this twice, like in the next layer, brings you back exactly. Is applying this twice and taking the real part after each application, if I normalize correctly, equivalent to performing the Fourier transform and then its inverse? I'm not sure; someone with more knowledge of this should probably evaluate that. What I am sure of is that the Fourier transform will absolutely stack the time domain on top of itself while splitting up the frequency domain, and if you apply it again it will do the opposite: it will stack all the frequencies on top of one another and split up the time domain. The signal is the same, but the feed-forward layers are applied differently. Remember, the feed-forward layer is applied individually: there is one feed-forward layer, one box, and it is applied individually to each element of the sequence, the same transformation for each. Now, what happens if you do the Fourier transform and then apply the feed-forward to each element? Each element no longer corresponds to a token; each element corresponds to one frequency across all the tokens in the entire sequence. So the feed-forward layers can alternately work on the individual tokens or on the individual frequencies across all tokens.
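Since I raised the "is applying it twice the inverse?" question, here is a quick numerical check I did myself (on a single real vector, ignoring the hidden-dimension transform, the feed-forward layers and any normalization); take it as a sketch, not as anything from the paper.

```python
import numpy as np

N = 16
x = np.random.randn(N)
rev = (-np.arange(N)) % N   # index map n -> (-n) mod N

# applying the DFT twice (no real parts) gives N times the index-reversed signal
print(np.allclose(np.fft.fft(np.fft.fft(x)), N * x[rev]))          # True

# applying the mixing step twice, taking the real part each time, is NOT the
# identity, even after rescaling by N ...
y = np.real(np.fft.fft(np.real(np.fft.fft(x)))) / N
print(np.allclose(y, x))                                            # False in general

# ... for a real input it returns the even-symmetric part of the signal instead
print(np.allclose(y, (x + x[rev]) / 2))                             # True
```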
I think this is a bit like axial attention, remember that? If these are two pixels, the attention matrix between all pixels would be too expensive, so you compute the attention within the columns and within the rows, and it takes two layers: first this pixel can attend to that one, and then in the next layer that pixel can attend to this one. It's a bit like that: you can route information from anything to anything in two steps instead of one. So that's what the Fourier transform does here. Now you might ask: why the Fourier transform? To be honest, I think that's also the opinion of this paper; in the conclusion (I'm just going to skip a bunch of stuff right here) they say they found the Fourier transform to be a particularly effective mixing mechanism, in part due to the highly efficient FFT, the fast Fourier transform, and that it is quite remarkable that an unparameterized mixing mechanism can yield a relatively accurate model. On a practical note, they write that they only performed a cursory survey of other linear transformations and therefore believe there may be value in exploring other fast transformations. So the Fourier transform was chosen because it is readily available in libraries, but it is just one mixing technique. I'm even open to the idea that the Fourier transform is the optimal one of all the linear mixing techniques you could come up with, but what seems to be important is just the fact that you do somehow get information around between the tokens, and that your transformations operate sometimes on the individual tokens and sometimes across the tokens. For a lot of tasks it might not be that crucial exactly how the information is routed; I think that's the takeaway message here.

Now, with respect to experiments: it is not better than transformers. We've left the era of "here's a new state of the art" and entered the era of "it works almost as well, but it is faster, and in one very particular plot with very particular axes it is better"; you're going to see that. Not that it is bad, but essentially the claim is: look, we have something that is way faster, you're going to sacrifice a bunch of accuracy for that, and depending on your task that might or might not be worth it. Here is what they compare. BERT-Base, which is the transformer model. FNet, where every self-attention sub-layer is replaced with a Fourier sub-layer as described in Section 3.2, which is what we just looked at. Then, interestingly, a random encoder, where each self-attention sub-layer is replaced with two constant random matrices, one applied to the hidden dimension and one applied to the sequence dimension. That is just a constant scrambling, like the Fourier transform except less structured, and that's why I say the Fourier transform might be the most effective non-parametric mixing method here: it kind of makes sense, and I do think it outperforms this random encoder quite a bit. Then there is a feed-forward-only baseline, which only does feed-forward and no token mixing at all. And there is the linear encoder, where each self-attention sub-layer is replaced with two learnable dense linear sub-layers, one applied to the hidden dimension and one applied to the sequence dimension. I mean, this is the MLP-Mixer. I get it, MLP-Mixer was specifically for vision, and people might have tried this before; I'm not saying they invented this particular thing. But it's funny that exactly this appears again right here.
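Just to picture these two baselines, here is a rough NumPy sketch of that mixing step; the naming, shapes and initialization are my own, since the paper's text as I read it only says "two matrices, one per dimension". Keep the matrices as frozen random constants and you have the random encoder; make them learnable and you have the linear encoder.

```python
import numpy as np

seq_len, hidden_dim = 4, 8
rng = np.random.default_rng(0)

# two mixing matrices: frozen constants in the "random" baseline,
# learnable parameters in the "linear" baseline
W_seq = rng.standard_normal((seq_len, seq_len)) / np.sqrt(seq_len)
W_hidden = rng.standard_normal((hidden_dim, hidden_dim)) / np.sqrt(hidden_dim)

def matrix_mixing(x):
    """Dense mixing sub-layer: one matrix per dimension, no attention."""
    x = x @ W_hidden     # mix within each token's hidden vector
    return W_seq @ x     # mix across the sequence positions

x = rng.standard_normal((seq_len, hidden_dim))
print(matrix_mixing(x).shape)  # (4, 8)
```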
In fact, when you look at the results, this linear encoder performs quite well. It of course has more parameters, because the Fourier layer has no parameters in place of attention, whereas the linear encoder does; it is just not as compute- and memory-intensive as attention. So what works well is this linear encoder, which gives credit to MLP-Mixer as well, and also, as they claim later, a hybrid version, where they use the FNet but in the last few layers they actually use attention. Again, it's not better; it's a trade-off, and the trade-off is speed and longer context size versus accuracy.

Here you have the number of parameters, and there you have the pre-training losses, so this is pre-training loss in masked language modeling and next-sentence prediction, with accuracy on the right-hand side. BERT is just winning here; the other ones aren't that close, or I guess a bit close. You can also see that the linear encoder outperforms FNet, and interestingly FNet outperforms the random encoder by quite a margin, so it's not like any mixing is fine; the random one just mixes information somehow. That's interesting to see, and it gives hope that we might come up with even better transformations than the Fourier transform. Didn't the Synthesizer also try to learn the attention matrix? At the time I said that doesn't make sense, but maybe we'll find some sort of universal attention matrix that is just better; I have no idea, I'm just speculating now. You can also see that the hybrid performs fairly well, but this is just pre-training for now. The speed-up is, of course, large: a decent speed-up on TPUs and a massive speed-up on GPUs, so that's where these models shine, they are very fast.

In terms of evaluating these models there is the GLUE benchmark. I think it's debated how useful these benchmarks really are, but at least it's a number you can measure, and you can see that BERT is very much winning on most of them, though there are some where it is not. I don't even know what all of these tasks are, but the authors say that, especially in the BERT-Large case, training is quite unstable. This is fine-tuning, by the way: they pre-train on the big corpus and then fine-tune right here, and that can be unstable. For example, look here: BERT-Large is actually worse than BERT-Base on this one, which I guess is only due to training instability, but they did say they tried a bunch of times.
I guess it's also a factor: if a model is unstable and you really want to go into production with it, that's an issue, so you might opt for something more stable. You can see that in most of these tasks BERT wins, and there are some where something else wins, like FNet or the FNet hybrid, though keep in mind that these benchmarks are sometimes just that, a benchmark, a number. Overall BERT wins by quite a bit, followed by the hybrid model, and then the linear model and the FNet model aren't too far behind. If you look at the large variant, I think the large one is simply kind of bad because it's unstable, so this might be more of a training-instability issue than evidence that this model is somehow exceptionally good. It's quite interesting, because I also compared these numbers to Jacob Devlin's original paper, and the GLUE numbers were quite different, so I'm a little bit wary about these numbers and about how much variance they actually have between different implementations, between different runs and so on; that makes me a bit cautious with these things.

They also, as I said, plot masked-language-model accuracy versus time per training step for 64 examples, on a log scale, and in one region of this plot the FNet and the linear model are better, which, I hope you agree with me, is a rather specific plot to plot. Even in the conclusion they say something like: for a fixed speed and accuracy budget, small FNet models outperform Transformer models. Okay, so there is a measure on which you are better, which is cool, but at the same time I think the message is really that here is a trade-off that you can make.

Lastly, they evaluate on the Long-Range Arena. The Long-Range Arena is a set of tasks where it's somehow important that you remember things for a long time, or that you can relate sequence elements over large distances; there's ListOps, for example. These are not necessarily natural-language tasks but more constructed tasks with the explicit goal of testing the long-range capabilities of these models. The transformers still seem to be best, but of course the question here is very often: if you have long sequences, you can't use a transformer at all, and therefore you have these other models, which, as you can see, are not too far behind, but they use considerably less memory and compute, they don't fail as often, and they train way faster. I'm also a bit skeptical of the Long-Range Arena results, because it sort of seems like as soon as you can remember whatever it is you need to remember, you solve the tasks; it's more of a binary thing, you either get there or you don't, rather than there being some nuance to it. That might change once we get more robust models that work on longer sequences. In any case, it's cool to see that in the average numbers these models are not too far behind the transformers, and they train way faster, as I said.
Okay, so that was it for this particular paper. As I said, it is a paper about Fourier transforms instead of attention, but it's much more a paper about the importance of mixing information between tokens, which is an important concept, and about the available trade-offs: there are tasks, there are situations, where you don't need the attention mechanism, you don't need its full power, its full analysis, and in those cases it might be enough to just somehow mix the information, the Fourier transform being one attractive option because it has no parameters, it has very fast implementations, and it makes sense on a conceptual level. So that was it from me. Do check out the paper, and I think they provide code too, if I'm not mistaken; if not, it should be relatively easy to implement this. All right, that was it from me, bye-bye.
Info
Channel: Yannic Kilcher
Views: 28,436
Keywords: deep learning, machine learning, arxiv, explained, neural networks, ai, artificial intelligence, paper, what is deep learning, deep learning tutorial, introduction to deep learning, fnet, fnets, fourier nets, fourier neural networks, attention fourier, fourier attention, deep learning fft, machine learning fft, deep learning fourier transform, attention mechanism fourier transform, fourier transform in deep learning, attention networks, do we need attention in deep learning
Id: JJR3pBl78zw
Length: 34min 22sec (2062 seconds)
Published: Fri May 21 2021