Multihead Attention's Impossible Efficiency Explained

Captions
In the last video, I said that the multi-head attention layer was just like the linear layer, in that each output depended on the value of every input, but that it did this with dramatically less computation and far fewer parameters. In this video, I thought I'd give you the big picture on how that's possible.

It all starts with the backbone of modern machine learning: weighted sums. Nothing exemplifies this better than the linear layer, because each output is literally just a weighted sum of the inputs, using a different set of weights for each output. On one hand, this is fantastic, because if there's one thing we've learned from all this work in neural networks, it's that if we get data in front of a large weighted sum, we can make intelligence happen (see the universal approximation theorem). On the other hand, with large data these giant weighted sums are just too much computation, and they require way too many weights.

Now, the multi-head attention layer is still effectively doing these giant weighted sums, which means it's still able to give us that intelligence that we so crave, but it does two very clever things to make this computationally feasible.

The first ingenious thing the multi-head attention layer does is that it doesn't store the weights for these giant weighted sums directly. Instead, it stores a much smaller set of weights that it then combines with the input to generate all the individual weights it needs for the weighted sums. This dramatically reduces the number of weights the layer needs. As we mentioned in the last video, in a real-life situation like ChatGPT, this can mean something like 4 million times fewer parameters than a linear layer would need.

Also, please note that the visualizations in this video are simplifying the heck out of how this actually works. Right now I'm not trying to be accurate; I just want to give you the big picture of how the attention layer can be so efficient. In a future video, after we've talked about the internal implementation, I'll revisit this idea, and we'll go through it in full detail and find those giant weighted sums in that implementation. And yes, we'll actually have to do a little bit of work to find them, because, spoiler alert, the implementation looks nothing like the visualizations here.

Okay, so the first big idea here is that the attention layer stores small weights that it uses to build up the full weights for the weighted sum, and this drastically reduces the number of weights it needs.

The second brilliant idea is how it reduces the computation needed, because at the moment we've actually added a lot of computation to build up those weights. Here it does something that's going to look very familiar if you've watched my videos on depthwise separable convolution: it performs the weighted sum in two parts. First it does a row-wise computation, and then it does a column-wise computation. Unfortunately, it's still just as expensive to create the weights, but now using them is incredibly cheap. So when all the computation is said and done, the attention layer still comes out well ahead of the linear layer. In our real-world ChatGPT use case, the attention layer is roughly a thousand times more computationally efficient.

So let's recap what we've done here. The first, really, really important point is that the attention layer is effectively doing a weighted sum of all of its inputs for each output, just like the linear layer does, and that this makes it incredibly powerful. The next main idea is that the attention layer saves on its parameter count by constructing the weights for the weighted sum dynamically. And finally, it cuts down on the computation it needs by performing the weighted sum row-wise and then column-wise instead of all at once.

And now I just want to take a moment and appreciate the beauty in all of this. As far as I can tell, this is actually a really novel efficiency improvement in neural networks, and efficiency improvements on this scale are few and far between. The only other comparable examples that I can think of are the convolutional layer and the recurrent layers, and the convolutional layer was invented back in 1980 and the LSTM was invented back in 1995. So it's been quite a while since we've had a breakthrough in efficiency of the same magnitude as the attention layer.

Also, both the convolutional layer and the LSTM approached efficiency in kind of the same way, which you can see from them having the same big trade-offs for the performance that they provide. (By the way, please don't judge the animation of this recurrent layer.) The first obvious drawback in both cases is that a particular output is no longer based on the entire input. In convolution, it's only based on a small spatial area of the input, and in the recurrent cell, it's only based on a single point in the sequence, with a very limited aggregation of the past. The second, more subtle trade-off is that the outputs no longer get their own set of weights. In convolution, they share the exact same weights with the other outputs in their feature. This isn't necessarily detrimental (in fact, it works quite well for images), but there's no denying that it does absolutely reduce the expressive power of the convolutional layer. And in the recurrent layer, the weights are shared with every other point in the sequence.

The attention layer approaches efficiency in a completely new way. Each output still depends on the entire input and still effectively gets its own unique set of weights. The brilliance lies in the dynamic weight creation and in the decomposition of the computation. Ultimately, it's paving the way for a whole new class of layers and a brand-new way of thinking about neural network efficiency, and I really hope this is just the first step. I want to believe that this has opened the minds of researchers to a whole new area of study, and that we'll find even better ways to execute on this new idea.

With all that being said, I think it's finally time to dive into the implementation. I've already started playing around with the animations, and I promise they don't look anything like this. So if you're interested in a fresh perspective on understanding the implementation, then please subscribe, hit the bell icon to be notified, and consider supporting me on Patreon. Thanks for watching!
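To make the idea that each output of a linear layer is a weighted sum of every input concrete, here is a minimal pure-Python sketch (biases omitted; the function names are hypothetical, chosen only for illustration). The last helper treats a whole sequence of n tokens with d channels each as one flat input and one flat output, which is the "giant weighted sum" setting the transcript compares attention against, and shows how the weight count grows as (n*d)^2.

```python
from typing import List


def linear_layer(x: List[float], weights: List[List[float]]) -> List[float]:
    """Each output o is a weighted sum of every input, using its own row of weights."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def linear_param_count(num_inputs: int, num_outputs: int) -> int:
    # One weight per (output, input) pair; biases are ignored.
    return num_inputs * num_outputs


def sequence_linear_param_count(n_tokens: int, d_channels: int) -> int:
    # Hypothetical helper: treat the whole n x d sequence as one flat input
    # and one flat output, so every output value can depend on every input value.
    flat = n_tokens * d_channels
    return linear_param_count(flat, flat)


x = [1.0, 2.0, 3.0]
W = [[1.0, 0.0, 0.0],   # output 0 only looks at x[0]
     [0.5, 0.5, 0.5],   # output 1 is half the sum of all inputs
     [0.0, -1.0, 1.0]]  # output 2 is x[2] - x[1]
print(linear_layer(x, W))                 # [1.0, 3.0, 1.0]
print(sequence_linear_param_count(4, 3))  # 144
```

Even this toy sequence of 4 tokens with 3 channels needs 144 stored weights; the count grows quadratically in both sequence length and width.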
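The "dynamic weights" idea can be illustrated with single-head scaled dot-product attention, the formulation from the original Transformer paper, used here as an assumed stand-in since the video deliberately simplifies and does not commit to these details. In this sketch only the two small matrices `w_q` and `w_k` (hypothetical names) are stored; the n x n table of mixing weights is computed from the input itself, so the stored parameter count does not grow with sequence length. Masking, multiple heads, and biases are left out.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # (m x k) @ (k x p) -> (m x p), plain nested loops.
    return [[sum(a[i][t] * b[t][j] for t in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention_weights(x: Matrix, w_q: Matrix, w_k: Matrix) -> Matrix:
    """Build the n x n mixing weights from the input x (n tokens x d channels).

    Only w_q and w_k (d x d_k each) are stored, so the stored parameter
    count does not depend on the sequence length n.
    """
    d_k = len(w_q[0])
    q = matmul(x, w_q)
    k = matmul(x, w_k)
    scores = matmul(q, transpose(k))
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    return softmax_rows(scaled)


identity = [[1.0, 0.0], [0.0, 1.0]]
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
a = attention_weights(x, identity, identity)
print(len(a), len(a[0]))              # 3 3
print([round(sum(r), 6) for r in a])  # [1.0, 1.0, 1.0]
```

Here 8 stored numbers (two 2 x 2 matrices) produce a 3 x 3 weight table, and feeding in a different `x` produces a different table; a longer sequence would produce a larger table from the same 8 numbers.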
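One way to see the "row-wise, then column-wise" decomposition, assuming a layout where rows are tokens and columns are channels: if the effective weight connecting input (j, c) to output (i, e) factors as `a[i][j] * w_v[c][e]` (an assumption that matches the value path of single-head attention, with `a` the mixing weights), then the giant weighted sum can be computed as a channel mix inside each row followed by a token mix inside each column. This sketch computes it both ways and counts the multiply-adds for d input and d output channels. It only covers using the weights, not the cost of creating `a`, which the transcript notes is still expensive; all names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def full_weighted_sum(a: Matrix, w_v: Matrix, x: Matrix) -> Matrix:
    """Each output y[i][e] as one giant weighted sum over every input x[j][c].

    The effective weight from input (j, c) to output (i, e) is
    a[i][j] * w_v[c][e]; it is formed on the fly here for clarity.
    """
    n, d_in, d_out = len(x), len(x[0]), len(w_v[0])
    return [[sum(a[i][j] * w_v[c][e] * x[j][c]
                 for j in range(n) for c in range(d_in))
             for e in range(d_out)]
            for i in range(n)]


def two_stage(a: Matrix, w_v: Matrix, x: Matrix) -> Matrix:
    n, d_in, d_out = len(x), len(x[0]), len(w_v[0])
    # Stage 1, within each row (token): mix channels.
    v = [[sum(x[j][c] * w_v[c][e] for c in range(d_in)) for e in range(d_out)]
         for j in range(n)]
    # Stage 2, within each column (channel): mix tokens.
    return [[sum(a[i][j] * v[j][e] for j in range(n)) for e in range(d_out)]
            for i in range(n)]


def cost_full(n: int, d: int) -> int:
    # n*d outputs, each summing over n*d inputs.
    return (n * d) * (n * d)


def cost_two_stage(n: int, d: int) -> int:
    # n*d*d for the channel mix plus n*n*d for the token mix.
    return n * d * d + n * n * d


a = [[0.5, 0.5], [0.25, 0.75]]  # mixing weights over 2 tokens (rows sum to 1)
w_v = [[1.0, 2.0], [3.0, 4.0]]  # channel-mixing weights
x = [[1.0, 0.0], [0.0, 1.0]]    # 2 tokens x 2 channels
print(two_stage(a, w_v, x))          # [[2.0, 3.0], [2.5, 3.5]]
print(full_weighted_sum(a, w_v, x))  # [[2.0, 3.0], [2.5, 3.5]]
print(cost_full(8, 4), cost_two_stage(8, 4))  # 1024 384
```

Both paths give the same outputs, but the two-stage version replaces one (n*d)^2 sum with two much smaller ones; the gap widens quickly as n and d grow.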
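The transcript quotes "something like 4 million times" fewer parameters and "roughly a thousand times" less computation without showing the arithmetic. Here is one back-of-envelope reconstruction under loudly stated assumptions: a single attention layer with query, key, value and output projections (d x d each) versus a dense layer over the entire n x d sequence, a model width of d = 12288 (GPT-3 175B's published width) and a context of n = 4096 tokens. Neither size comes from the video, and softmax, biases and the rest of the Transformer block are ignored. Under these assumptions the ratios land near the quoted figures, but the video may well have used different ones.

```python
def linear_layer_params(n: int, d: int) -> int:
    # Dense layer from all n*d inputs to all n*d outputs (biases ignored).
    return (n * d) ** 2


def attention_params(d: int) -> int:
    # Q, K, V and output projections, d x d each (biases ignored).
    # Splitting into heads does not change this total.
    return 4 * d * d


def linear_layer_macs(n: int, d: int) -> int:
    # Every one of the n*d outputs sums over all n*d inputs.
    return (n * d) ** 2


def attention_macs(n: int, d: int) -> int:
    projections = 4 * n * d * d  # Q, K, V and output projections
    scores = n * n * d           # Q @ K^T: creating the n x n weights
    mixing = n * n * d           # weights @ V: using them
    return projections + scores + mixing  # softmax cost ignored


n, d = 4096, 12288  # ASSUMED sizes for illustration, not taken from the video
print(linear_layer_params(n, d) // attention_params(d))       # 4194304
print(round(linear_layer_macs(n, d) / attention_macs(n, d)))  # 878
```

Note that d cancels out of the parameter ratio, which simplifies to n^2 / 4, so under these assumptions the parameter saving is driven entirely by sequence length; the compute ratio simplifies to n*d / (4d + 2n).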
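For contrast, the convolution trade-offs described in the transcript (each output sees only a small window of the input, and all outputs share one set of weights) show up clearly in a 1D sketch. This uses a "valid" cross-correlation, which is what most deep-learning libraries compute under the name convolution; the helper names are hypothetical.

```python
from typing import List


def conv1d_valid(x: List[float], kernel: List[float]) -> List[float]:
    """1D 'valid' cross-correlation with a single shared kernel.

    Output t only looks at x[t : t + k], and every output reuses the same k weights.
    """
    k = len(kernel)
    return [sum(kernel[m] * x[t + m] for m in range(k))
            for t in range(len(x) - k + 1)]


def receptive_field(t: int, k: int) -> range:
    # Indices of the inputs that output t depends on.
    return range(t, t + k)


x = [1.0, 2.0, 3.0, 4.0, 5.0]
kernel = [1.0, 0.0, -1.0]
print(conv1d_valid(x, kernel))          # [-2.0, -2.0, -2.0]
print(list(receptive_field(1, len(kernel))))  # [1, 2, 3]
```

With a kernel of length 3, the layer stores 3 weights no matter how long the input is, whereas a dense layer from 5 inputs to 3 outputs would need 15; the price is that output t can never see anything outside `x[t : t + 3]`.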
Info
Channel: Animated AI
Views: 4,016
Id: TZxRJOLrhLA
Length: 6min 27sec (387 seconds)
Published: Fri May 10 2024