Transformer Neural Networks Derived from Scratch

Captions
You’ve probably heard of a cool little application called ChatGPT, OpenAI’s latest iteration of Generative Pre-trained Transformers. ChatGPT is an amazing chatbot: it’s capable of understanding and replying to text messages better than some humans that I know. And the secret behind ChatGPT’s amazing abilities is the transformer neural network architecture (after all, it’s in the name). But have you ever wondered what a transformer actually is? Of course you have, which is why there are already dozens of videos on YouTube explaining what a transformer is. But all of these videos merely describe the transformer architecture. So in this video I want to take a deeper look at the design decisions behind the transformer and the motivations for them. Starting with a simple convolutional neural net, I will take you on a step-by-step journey to derive the transformer. The goal is for you to come away feeling like you could have invented the transformer yourself. I will be assuming that you are already familiar with the convolutional neural net, or CNN for short; if you aren’t, I’d recommend first watching my previous video, which explains everything you need to know about CNNs. With that out of the way, let’s get started.
The transformer was actually invented all the way back in 2017. To set the scene, this was just after CNNs had achieved major breakthroughs in image processing, and were solving all kinds of tasks that had been thought impossible just a decade prior. Inspired by these successes in image processing, researchers tried to apply CNNs to natural language processing. Natural language processing (NLP for short) refers to any task that involves automatically processing text data, such as translating between two languages, classifying the sentiment of a product review, or just chatting. Unfortunately, CNNs weren’t nearly as good at these NLP tasks. In fact, on almost all of them CNNs were significantly worse than humans.
For many tasks, CNNs were so bad as to be completely unusable. So what’s going on? Why is it that CNNs can handle images fine, but struggle so much with text? The most obvious difference between text and images is that text is not described by numbers like an image is. And neural nets need numbers as inputs; after all, you can’t apply a linear function to text! Fortunately, statisticians had already figured out how to train models on non-numerical data a hundred years ago, using a method called one-hot encoding. One-hot encoding is a technique that replaces each possible value of a variable with a unique vector. To one-hot encode text, we first make a list of every unique word in the language (a typical English vocabulary has around 50,000 words), and then associate each word with its position in this list. We then convert these positions to one-hot vectors, that is, vectors with 0 in every component except for the given position, which has a 1. The exact arrangement of which word has which position doesn’t matter, so long as each word is consistently associated with the same unique vector. So whenever I show a word being input to a neural net, it is actually the one-hot vector representation of that word that is being fed to the neural net.
But one-hot encoding solves the input problem, and CNNs still struggle with text, so the true cause of their failure is a bit more subtle. In a convolutional layer, a neural net is applied to each group of 3 consecutive words. This neural net does its neural magic and spits out a new vector that contains information from all of its inputs. The vectors output by the first convolutional layer are representations of the particular sequences of 3 words that occur at each position. The next convolutional layer again combines consecutive groups of these new vectors, so the vectors output by the second layer each contain information from 5 consecutive words in the original input. And so on: each layer combines information from larger and larger groups, until the final layer contains information from the entire input.
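To make one-hot encoding concrete, here is a minimal Python sketch. The two-sentence toy vocabulary and the helper names (`build_vocab`, `one_hot`, `encode`) are illustrative assumptions; a real vocabulary would have tens of thousands of entries, and which word gets which position is arbitrary as long as it stays fixed.

```python
def build_vocab(sentences):
    """Give each unique word a fixed position, in order of first appearance."""
    vocab = {}
    for sentence in sentences:
        for word in sentence.split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def one_hot(index, size):
    """A vector of `size` zeros with a single 1 at `index`."""
    vec = [0] * size
    vec[index] = 1
    return vec


def encode(sentence, vocab):
    """Replace each word with the one-hot vector of its vocabulary position."""
    return [one_hot(vocab[word], len(vocab)) for word in sentence.split()]


# Toy vocabulary built from two sentences (hypothetical; a real one is far larger).
vocab = build_vocab(["the fat cat sat", "the dog sat"])
print(vocab)                     # {'the': 0, 'fat': 1, 'cat': 2, 'sat': 3, 'dog': 4}
print(encode("the cat", vocab))  # [[1, 0, 0, 0, 0], [0, 0, 1, 0, 0]]
```

Each word becomes a vector with exactly one 1, so no word is numerically "closer" to any other; any similarity between words has to be learned by the network.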
The thing is, though, combining multiple vectors into a single vector representation is actually really difficult. Neural nets can learn to do it, but only if all of their inputs are strongly related to each other. If there is only a weak relationship, neural nets tend to get confused about which information to keep. When I say related here, what I specifically mean is that knowing the value of one input should help you to better predict the values of the others. For images, this means that knowing the color of a pixel should help you better predict the colors of the neighbouring pixels. This is generally true for natural images, where nearby pixels tend to have similar colors, which is why CNNs have no problem learning from images. If you want to know more about why CNNs break down when their inputs aren’t related, my previous video goes into it in more detail. But what about text? Does knowing what word occurs at a given position help you to predict the neighbouring words? Well, it is true that neighbouring words in a sentence are often related to each other. But this isn’t always the case. There are plenty of times when two words on opposite ends of a sentence are the most strongly related. For example, consider the sentence “the dog spun around so fast that it caught its own tail”. Here it’s clear that the “tail” belongs to the “dog”, so these two words are strongly related. But since information from these words can’t be combined until the later layers, in the early layers the CNN will attempt to combine each of them with other words in the sentence that are less related, and get itself confused. This is the problem that transformers seek to resolve: CNNs can’t handle long-range relations. And the key idea behind the transformer is actually very simple: just replace convolutional layers with… pairwise convolutional layers.
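As a rough sketch of how slowly a CNN’s view widens, assuming stride-1 convolutions that each combine 3 consecutive vectors, here is how many layers it takes before “dog” and “tail” can meet in the example sentence. `receptive_field` and `layers_needed` are hypothetical helper names for this illustration.

```python
def receptive_field(num_layers, kernel_size=3):
    """How many consecutive input words one output vector can see after
    stacking `num_layers` stride-1 convolutions of width `kernel_size`."""
    return 1 + num_layers * (kernel_size - 1)


def layers_needed(pos_a, pos_b, kernel_size=3):
    """Smallest number of layers whose receptive field covers both positions."""
    span = abs(pos_a - pos_b) + 1
    layers = 0
    while receptive_field(layers, kernel_size) < span:
        layers += 1
    return layers


words = "the dog spun around so fast that it caught its own tail".split()
print([receptive_field(k) for k in range(1, 4)])               # [3, 5, 7]
print(layers_needed(words.index("dog"), words.index("tail")))  # 5
```

Each layer only widens the window by 2 words, so the 10-word gap between “dog” and “tail” takes 5 layers to bridge, and until then both words are being mixed with less related neighbours.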
In a pairwise convolutional layer, we apply a neural net to each pair of words from the sentence. Now it doesn’t matter how far apart two words are: information from every pair of words can be combined immediately, in a single layer. Even if most of these pairs are unrelated to each other, it doesn’t matter, so long as somewhere we have the pairs of related words. And we can repeat this operation to combine pairs of pair-vectors together. Now each one of the resulting vectors represents a group of 4 words from the original sentence. Each time we apply this operation, we get representations of larger and larger groups of words. At the end, we can just average all of the vectors together to get the final prediction. Each one of the resulting vectors is created by combining the words from the original sentence in a different order. In fact, this whole operation is equivalent to rearranging the words of the original sentence into every possible ordering (with words allowed to repeat, since each word is also paired with itself), applying a regular CNN that combines neighbouring pairs at each layer to every arrangement, and then averaging the outputs. The idea is that somewhere in this list there must be a good ordering: one where all of the related words occur right next to each other, and hence can be efficiently combined by the CNN. Problem solved. Well, not quite. There’s one small problem with this model: it ignores the order of words in the input sentence. Consider these two phrases: “the fat cat” and “the cat fat”. Clearly, these have different meanings. But since the set of rearrangements of each phrase is the same, and the model output is an average over the rearrangements, the model output has to be the same in both cases. If we want our model to understand the meaning of natural language, then it needs to be able to take into account the order in which words appear. Fortunately, there is a simple solution: we just need to attach the position of each word to its vector representation.
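Here is a toy version of the pairwise layer, assuming made-up 2-component word vectors and a fixed `tanh` function standing in for the learned pair network (`EMBED` and `pair_net` are hypothetical). It applies the pair function to every ordered pair, including each word paired with itself, stacks two layers, and averages the result, which reproduces the 81 vectors after two layers and the word-order blindness.

```python
import math
from itertools import product

# Hypothetical 2-component word vectors, made up for illustration.
EMBED = {"the": [0.1, 0.5], "fat": [0.9, -0.3], "cat": [-0.4, 0.8]}


def pair_net(a, b):
    """Stand-in for the learned pair network: a fixed non-linear function of
    the ordered pair (a, b). It is deliberately not symmetric in a and b."""
    return [math.tanh(x + 2.0 * y) for x, y in zip(a, b)]


def pairwise_layer(vectors):
    """Apply pair_net to every ordered pair, including each vector with itself."""
    return [pair_net(a, b) for a, b in product(vectors, repeat=2)]


def model(sentence, num_layers=2):
    """Stack pairwise layers, then average every final vector into one output."""
    vectors = [EMBED[w] for w in sentence.split()]
    for _ in range(num_layers):
        vectors = pairwise_layer(vectors)
    dim = len(vectors[0])
    mean = [sum(v[k] for v in vectors) / len(vectors) for k in range(dim)]
    return len(vectors), mean


count_a, out_a = model("the fat cat")
count_b, out_b = model("the cat fat")
print(count_a)  # 81
# Same words in a different order give the same output (up to rounding):
print(all(math.isclose(x, y, abs_tol=1e-12) for x, y in zip(out_a, out_b)))  # True
```

Reordering the input only reorders the list of pairs, so the final average cannot tell “the fat cat” from “the cat fat”.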
As before, each word is converted into a one-hot vector, but now we also attach a one-hot encoding of the word’s position in the sentence. These position indexes go up to some pre-specified cutoff, often in the thousands. Now each vector represents both the identity of the word and where that word occurs in the sentence, so the CNN can learn to change its output based on the position information. Perfect! There’s just one more little problem with this model: it’s big. Very big. In this example the input sequence was only 3 words, and after just 2 layers we already have 81 vectors to deal with: in every new layer, the number of vectors is squared. And for these vectors to learn useful representations, each one is going to need hundreds or, more likely, thousands of components. And that was with only 3 input words; we want to be able to run this model on long passages of text. This is going to get too big to handle very quickly. To make this model practical, we need to somehow reduce the number of vectors used. Ideally, for an input with n words, the model would never need to use more than n^2 vectors. For this, we would need to reduce the n^2 output vectors back down to just n vectors before applying the next layer. If we could do that, then the number of vectors would never grow beyond n^2, no matter how many layers we apply. So the question is: given n^2 vectors representing all pairs of input vectors, how can we reduce them down to just n vectors while still keeping as much useful information as possible? One way to do it would be to simply sum the vectors down each column. This neatly reduces each column of n vectors to a single vector. The downside is that we are going to lose a lot of information, since a single vector simply cannot contain all of the information from the n original vectors.
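The sketch below, using hypothetical helper names, shows three pieces of this step: attaching a one-hot position code to each word vector, counting how many vectors the uncapped pairwise model would hold (n, n^2, n^4, and so on), and the column-sum reduction that takes an n × n grid of pair vectors back to n vectors. The 3-word one-hot vocabulary and the cutoff of 4 positions are assumptions chosen to keep the output small.

```python
def one_hot(index, size):
    """A vector of `size` zeros with a single 1 at `index`."""
    vec = [0] * size
    vec[index] = 1
    return vec


def attach_positions(word_vectors, max_len):
    """Concatenate a one-hot position code (up to a fixed cutoff) onto each word vector."""
    if len(word_vectors) > max_len:
        raise ValueError("sentence is longer than the position cutoff")
    return [vec + one_hot(pos, max_len) for pos, vec in enumerate(word_vectors)]


def naive_vector_count(n, layers):
    """Vectors held after `layers` uncapped pairwise layers: n, n^2, n^4, ..."""
    return n ** (2 ** layers)


def column_sums(pair_grid):
    """Reduce an n x n grid (pair_grid[i][j] is the vector for pair (i, j))
    to n vectors by summing down each column j."""
    n = len(pair_grid)
    dim = len(pair_grid[0][0])
    return [[sum(pair_grid[i][j][k] for i in range(n)) for k in range(dim)]
            for j in range(n)]


WORDS = {"the": [1, 0, 0], "fat": [0, 1, 0], "cat": [0, 0, 1]}  # toy one-hot words
a = attach_positions([WORDS[w] for w in "the fat cat".split()], max_len=4)
b = attach_positions([WORDS[w] for w in "the cat fat".split()], max_len=4)
print(sorted(a) == sorted(b))                        # False: positions now differ
print([naive_vector_count(3, k) for k in range(4)])  # [3, 9, 81, 6561]
print(column_sums([[[1], [2]], [[3], [4]]]))         # [[4], [6]]
```

With positions attached, the two phrases no longer produce the same set of input vectors, and the column sum keeps the vector count at n after every layer.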
To get an idea of why this is a problem, imagine what happens when you take the average of a bunch of images: you just end up with a blurry mess. The same thing happens with our vectors: when you cram a bunch of different vectors together, you just get noise, and you lose the information contained in the original vectors. For any information to survive the summation, we basically need one of the original vectors to have a relatively large magnitude and the rest to be close to 0. That way, when we do the sum, the information contained in the large vector will be preserved. So ideally, the neural net will learn to output a large vector for the most important pair, the one containing the most strongly related words, and small vectors for all other pairs. Fortunately, this is exactly the kind of thing that neural nets excel at! The only problem is that sometimes it’s impossible to tell whether a pair of words is important just by looking at the pair itself. Consider the sentence “there was a tree on fire and it was roaring”. Here the words “fire” and “roaring” are related, because it is the fire which is roaring. So we would like our neural net to output a large vector for the pair (“roaring”, “fire”) and a zero vector for the pair (“roaring”, “tree”). Now consider the sentence “there was a lion on fire and it was roaring”. Now it is the lion that is roaring, so we want the neural net to output a large vector for the pair (“roaring”, “lion”) and a zero vector for the pair (“roaring”, “fire”). The problem is that in one sentence the pair (“roaring”, “fire”) needs a large vector, and in the other it needs a zero vector. But the neural net sees the same input in both cases, so it has to produce the same output. Somehow, we need to be able to tune the neural net’s output based on the context of which other words are in the sentence.
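A quick numerical sketch of the blurring effect, assuming random Gaussian vectors as stand-ins for pair vectors: it compares the sum of ten similar-sized vectors against the sum where one vector is scaled up and the rest are scaled towards zero, using cosine similarity to measure how much each sum still resembles the originals. The dimensions and scale factors are arbitrary choices for illustration.

```python
import math
import random


def cosine(u, v):
    """Cosine similarity: 1 means same direction, near 0 means unrelated."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def vector_sum(vectors):
    """Component-wise sum of a list of equal-length vectors."""
    return [sum(components) for components in zip(*vectors)]


random.seed(0)
dim, n = 64, 10
vectors = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n)]

# Case 1: all vectors have similar magnitudes, so the sum resembles none of them.
blurred = vector_sum(vectors)
# Case 2: vector 0 is scaled up and the others are scaled down towards zero.
scaled = [[x * (10.0 if i == 0 else 0.01) for x in v] for i, v in enumerate(vectors)]
sharp = vector_sum(scaled)

print(max(cosine(blurred, v) for v in vectors))  # well below 1
print(cosine(sharp, vectors[0]))                 # very close to 1
```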
If we assign each pair a score indicating how important it is, then we can compare the scores of all of the pairs in a column to pick out the most important one. For example, we might assign the pair (“roaring”, “fire”) an importance of 10, the pair (“roaring”, “lion”) an importance of 90, and all other pairs an importance of 0. Then we can get relative importances by dividing each score by the sum of all of the scores in the column. This way, in our first sentence, which contains no lion, (“roaring”, “fire”) has a relative importance of 10/10 = 1, but in the second sentence it only has a relative importance of 10/100 = 0.1. If we simply multiply each pair vector by its relative importance, then a pair vector will be shrunk towards 0 only when it has a low relative importance. This means that when we sum the weighted vectors, we keep information mainly from the most important pair, while discarding information from less important pairs. Now the only question is: how do we get these importance scores? And the answer, of course, is that we train a neural net to produce them. So we’ll have two neural nets in each layer: the first maps pairs to their new vector representation, and the second maps pairs to their importance score. We then normalize across each column to get relative importance scores, and take a weighted sum of the vectors in each column to produce our final n vectors. Oh, but there’s one small problem: the output of the scoring neural net could be negative, which causes trouble when we normalize. Negative scores would give negative “importances”, and if the scores in a column happened to sum to 0, we would be dividing by 0! This is easily fixed, though: we can just run the importance scores through some function that always outputs positive numbers, such as the exponential function. (Exponentiating and then normalizing like this is known as the softmax function.) Okay, to summarize the model so far: in each layer, we grab all pairs of input vectors and run each pair through two different neural nets.
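Here is a minimal sketch of the scoring and reduction step for a single column, in plain Python. `normalize` reproduces the divide-by-the-total arithmetic from the roaring example, `softmax` is the exponentiate-then-normalize version that also handles negative scores, and `weighted_column_sum` collapses the column’s pair vectors. All names are hypothetical, and the pair vectors and scores in the usage lines are made up.

```python
import math


def normalize(scores):
    """Relative importance: divide each non-negative score by the column total."""
    total = sum(scores)
    return [s / total for s in scores]


def softmax(scores):
    """Exponentiate, then normalize, so any real-valued scores are allowed.
    Subtracting the max first avoids overflow without changing the result."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_column_sum(pair_vectors, scores):
    """Collapse one column of pair vectors into a single vector, weighting
    each pair by its softmax-normalized importance score."""
    weights = softmax(scores)
    dim = len(pair_vectors[0])
    return [sum(w * v[k] for w, v in zip(weights, pair_vectors)) for k in range(dim)]


# Scores for ("roaring", fire), ("roaring", tree or lion), ("roaring", anything else).
print(normalize([10, 0, 0]))   # tree sentence: [1.0, 0.0, 0.0]
print(normalize([10, 90, 0]))  # lion sentence: [0.1, 0.9, 0.0]
# Scores that sum to zero would break normalize(), but softmax() is fine:
print(softmax([2.0, -2.0]))
# A high score lets one pair vector dominate the column's output:
print(weighted_column_sum([[1.0, 0.0], [0.0, 1.0]], [10.0, 0.0]))
```

Note that softmax exaggerates score gaps compared with plain division: with raw scores of 10 and 90, the exponentiated weight for “fire” would be tiny rather than 0.1.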
One outputs the vector representation of the pair, and the other outputs a scalar score which, when exponentiated and normalized, gives the relative importance of this pair. Then we take a weighted sum down each column. This entire operation is known as a self-attention layer. It’s called that because it’s like the model deciding which inputs it should pay attention to. Now, the model that I’ve described so far is not quite a standard transformer. We still need to make a few more optimizations. However, these optimizations don’t fundamentally change how the model works; they just make it more efficient. That’s important in practice, but if you only want to know why the transformer architecture works, then you really only need to understand this simplified version. With that out of the way, let’s get optimising. We’ve managed to cap our model so that it uses at most n^2 vectors in each layer, but n^2 is still a lot, especially for long passages of text containing thousands of words. And remember, each vector contains thousands of numbers. So the major bottleneck in this model is applying our representation neural nets to the n^2 pair vectors. The first thing we are going to do is replace the representation neural nets with simple linear functions. These use a tiny amount of computation compared to full neural nets. Much better. But now we have a problem! We’ve just lost the non-linear processing power of the representation neural net, and we really do need that in order to combine pair vectors effectively throughout the layers. So, after we’ve done the self-attention operation, we will apply the neural net to each of the n output vectors. We would also like to do the same thing with the scoring neural nets, but the scoring function itself needs some non-linearity, and it can’t wait until after the reduction, since we use the scores in the reduction process. A purely linear score of a pair is just a term that depends on one word plus a term that depends on the other. Within a column, one of those words is the same for every pair, so its term shifts all the scores equally and cancels out after exponentiating and normalizing; every column would then rank the words in exactly the same order, no matter which word the column belongs to.
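To see why the scoring function can’t just be linear, here is a small sketch with hypothetical 2-component word vectors and weights. A linear score of a pair is `a·x_i + b·x_j`; within one column the word `x_j` the column is built around is fixed, so the `b·x_j` part shifts every score by the same amount and softmax cancels it. For contrast, the sketch also computes a bilinear score `x_iᵀ W x_j` with an arbitrary example matrix `W`, whose column weights do change with `x_j`.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def linear_score(x_i, x_j, a, b):
    """A linear function of the concatenated pair [x_i; x_j]: a.x_i + b.x_j."""
    return dot(a, x_i) + dot(b, x_j)


def bilinear_score(x_i, x_j, W):
    """x_i^T W x_j: non-linear in the pair, so it mixes both words."""
    return dot(x_i, [dot(row, x_j) for row in W])


# Hypothetical word vectors and weights, made up for illustration.
words = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
a, b = [0.3, -0.7], [2.0, 1.5]
W = [[0.0, 1.0], [1.0, 0.0]]

# Attention weights for the columns built around words[0] and words[1].
lin_0 = softmax([linear_score(x, words[0], a, b) for x in words])
lin_1 = softmax([linear_score(x, words[1], a, b) for x in words])
bil_0 = softmax([bilinear_score(x, words[0], W) for x in words])
bil_1 = softmax([bilinear_score(x, words[1], W) for x in words])

print(all(math.isclose(p, q) for p, q in zip(lin_0, lin_1)))  # True: b.x_j cancels
print(all(math.isclose(p, q) for p, q in zip(bil_0, bil_1)))  # False
```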
Fortunately, the scoring function only needs to output one number, not an entire vector. Since this job is so much simpler, we can get away with using a really small neural net to produce scores. You can also replace the scoring neural nets with simple non-linear functions, such as bilinear forms. You get pretty much the same performance using small neural nets as you do with bilinear forms, but bilinear forms use a little less computation, so they are preferred in practice. Each layer now consists of two steps: 1) do self-attention with linear representation functions and bilinear-form scoring functions, and 2) apply a neural net to each of the n resulting vectors. The crucial part is that now we only need to apply the large neural net n times, down from n^2. Also, now that we are using linear representation functions, some of the computation is redundant. Each column consists of pairs that contain the same word, and since we apply the same linear function to each pair, every item in the column gets the same term from that word added to it. Since the attention weights in a column sum to 1, this simplifies to adding that term once to the result. So we can just do that and save ourselves a bit of computation. One last thing. Remember how, when we sum n vectors together, their information gets blurred together, and the only way information can be preserved is if one vector is large and the rest are 0? Well, sometimes we want to keep information from more than one vector. And there’s a very simple way to do this: just apply the self-attention operation multiple times, each time with its own separately learned representation and scoring functions. This way, each self-attention operation can potentially select a different pair vector to keep. We can then concatenate their outputs into one big vector, thereby retaining information from each. This is known as multi-head self-attention. In practice it is common to use anywhere from a handful to a few dozen heads in each self-attention layer. And there we have it: the complete transformer architecture in all of its glory.
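Putting the optimized layer together, here is a sketch of one attention head written the pairwise way (pair vector `A x_i + B x_j`, bilinear score `x_iᵀ W x_j`, softmax-weighted sum down each column), a simplified version that adds the shared `B x_j` term once after the sum, and a multi-head wrapper that concatenates the heads’ outputs. The matrices `A`, `B`, `W` are random stand-ins for learned parameters, and all names are hypothetical. This follows the derivation in this video rather than any particular library: standard implementations typically factor `W` into separate query and key projections, scale scores by 1/√d, and apply a linear projection after concatenating heads, and the per-vector neural net from step 2 is not shown.

```python
import math
import random


def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_head(xs, A, B, W):
    """One head, written the pairwise way. Pair (i, j) gets the vector
    A x_i + B x_j and the bilinear score x_i^T W x_j; column j is reduced
    by a softmax-weighted sum over i."""
    outputs = []
    for x_j in xs:
        weights = softmax([dot(x_i, matvec(W, x_j)) for x_i in xs])
        pair_vecs = [[p + q for p, q in zip(matvec(A, x_i), matvec(B, x_j))]
                     for x_i in xs]
        outputs.append([sum(w * v[k] for w, v in zip(weights, pair_vecs))
                        for k in range(len(A))])
    return outputs


def attention_head_simplified(xs, A, B, W):
    """Same result: the weights in a column sum to 1, so the shared
    B x_j term can be added once, after the weighted sum."""
    projected = [matvec(A, x) for x in xs]  # A x_i, computed once per word
    outputs = []
    for x_j in xs:
        weights = softmax([dot(x_i, matvec(W, x_j)) for x_i in xs])
        mixed = [sum(w * p[k] for w, p in zip(weights, projected))
                 for k in range(len(A))]
        outputs.append([m + q for m, q in zip(mixed, matvec(B, x_j))])
    return outputs


def multi_head(xs, heads):
    """Run each head with its own (A, B, W) and concatenate per position."""
    per_head = [attention_head_simplified(xs, A, B, W) for A, B, W in heads]
    return [sum((h[j] for h in per_head), []) for j in range(len(xs))]


def rand_matrix(rows, cols, rng):
    """Random stand-in for a learned weight matrix."""
    return [[rng.gauss(0, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_in, d_head, n_heads, n_words = 4, 3, 2, 5
xs = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(n_words)]
heads = [(rand_matrix(d_head, d_in, rng), rand_matrix(d_head, d_in, rng),
          rand_matrix(d_in, d_in, rng)) for _ in range(n_heads)]

A, B, W = heads[0]
full = attention_head(xs, A, B, W)
fast = attention_head_simplified(xs, A, B, W)
print(all(math.isclose(p, q, abs_tol=1e-9)
          for u, v in zip(full, fast) for p, q in zip(u, v)))  # True
out = multi_head(xs, heads)
print(len(out), len(out[0]))  # 5 6
```

Both head functions agree up to floating-point rounding, but the simplified one applies `B` once per column instead of once per pair; the multi-head output has `n_heads × d_head` components per word.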
With this, you should now have a complete understanding of the transformer as it is implemented in ChatGPT. Well, not quite. Standard transformer implementations use a few more tricks that I haven’t covered here, such as layer normalization, residual connections, and byte-pair encoding. But these things aren’t unique to the transformer; they are general-purpose techniques that have been widely used in deep learning for years, so I don’t feel the need to go over them here. In any case, I hope you enjoyed this video and came away with some appreciation of the design decisions behind the transformer.
Info
Channel: Algorithmic Simplicity
Views: 128,959
Keywords: neural networks, neural nets, convolutional neural nets, CNN, transformer, convnet, nlp, text processing, transformers, Transformer Architecture, deep learning, transformers explained, attention, self-attention, attention explained
Id: kWLed8o5M2Y
Length: 18min 7sec (1087 seconds)
Published: Fri Aug 18 2023