You’ve probably heard of a cool little application
called ChatGPT, OpenAI’s latest iteration of Generative Pre-trained Transformers. ChatGPT is an amazing chatbot, capable of understanding and replying to text messages better than some humans that I know. And the secret behind ChatGPT’s amazing
abilities is the transformer neural network architecture (after all, it’s in the name). But have you ever wondered what a transformer
actually is? Of course you have, which is why there are
dozens of videos explaining what a transformer is on YouTube already. But all of these videos merely describe the
transformer architecture. So in this video I want to take a deeper look
at the design decisions behind the transformer and the motivations for them. Starting with a simple convolutional neural
net, I will take you on a step-by-step journey to derive the transformer. The goal is for you to come away feeling like
you could have invented the transformer yourself. Now I will be assuming that you are already
familiar with the convolutional neural net, or CNN for short. If you aren’t, then I’d
recommend you first watch my previous video which explains everything you need to know
about CNNs. With that out of the way, let’s get started. The transformer was actually invented all
the way back in 2017. To set the scene, this was just after CNNs
had achieved major breakthroughs in image processing, and were now solving all kinds
of tasks that were thought to be impossible just a decade prior. Inspired by these amazing successes in image
processing, researchers tried to apply CNNs to natural language processing. Natural language processing (NLP for short)
refers to any task that involves automatically processing text data, such as translating
between 2 languages, classifying the sentiment of a product review, or just chatting. Unfortunately, CNNs weren’t quite as good
at these NLP tasks. In fact, for almost all of them CNNs were
significantly worse than humans. For many tasks, CNNs were so bad as to be
completely unusable. So what’s going on? Why is it that CNNs can handle images fine,
but struggle so much with text? The most obvious difference between text and
images is that text is not described by numbers the way an image is. And neural nets need numbers as inputs; after all, you can’t apply a linear function to text! Fortunately, statisticians had already figured
out how to train models on non-numerical data a hundred years ago, using a method called
one-hot encoding. One-hot encoding is a technique that replaces
each possible variable value with a unique vector. To one-hot encode text, we first make a list of every unique word in the language (there are about 50,000 for English) and then associate each word with its position in this list. We then convert these positions into one-hot vectors, that is, vectors with a 0 in every component except for the given position, which has a 1. The exact arrangement of which word has which
position doesn’t matter, so long as each word is consistently associated with the same
unique vector. So, whenever I show a word being input to a neural net, it is actually the one-hot vector representation of that word that is being fed to the neural net.
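To make that concrete, here is a minimal Python sketch of one-hot encoding (the tiny vocabulary is just for illustration; a real one would have tens of thousands of entries):

```python
import numpy as np

# Toy vocabulary; a real English vocabulary would have roughly 50,000 entries.
vocab = ["the", "dog", "spun", "around", "cat", "fat"]
word_to_index = {word: i for i, word in enumerate(vocab)}

def one_hot(word: str) -> np.ndarray:
    """A vector of zeros with a 1 at the word's position in the vocabulary."""
    vec = np.zeros(len(vocab))
    vec[word_to_index[word]] = 1.0
    return vec

print(one_hot("dog"))   # [0. 1. 0. 0. 0. 0.]
```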
The true cause of the CNN’s failure to learn from text is a bit more subtle. In a convolutional layer, a neural net is
applied to each group of 3 consecutive words. This neural net does its neural magic and
spits out a new vector that contains information from all of the inputs. The vectors output by the first convolutional
layer will be representations of the particular sequences of 3 words that occur at each position. The next convolutional layer again combines
consecutive groups of these new vectors. So vectors output by the second layer each
contain information from 5 consecutive words in the original input. And so on. Each layer combines information from larger and larger groups, until the final layer contains information from the entire input.
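To picture this, here is a rough NumPy sketch of a convolutional layer over a sequence of word vectors, using random vectors and a single linear-plus-ReLU map as a stand-in for the neural net (all the sizes here are made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

def conv_layer(vectors: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Apply the same little network to every window of 3 consecutive vectors."""
    outputs = []
    for i in range(len(vectors) - 2):
        window = vectors[i:i + 3].reshape(-1)          # concatenate 3 neighbouring vectors
        outputs.append(np.maximum(0, W @ window + b))  # linear map + ReLU as the "neural magic"
    return np.array(outputs)

words = rng.normal(size=(7, 16))         # 7 words, each a 16-dimensional vector
W, b = rng.normal(size=(16, 48)), np.zeros(16)
layer1 = conv_layer(words, W, b)         # 5 vectors, each covering 3 consecutive words
layer2 = conv_layer(layer1, W, b)        # 3 vectors, each covering 5 consecutive words
```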
The thing is, though, it’s actually really difficult to do this whole “combine multiple vectors into a single vector representation” thing. Neural nets can learn to do it, but only if
all of their inputs are strongly related to each other. If there is only a weak relationship then
neural nets tend to get confused about which information to keep. When I say related here, what I specifically
mean is that knowing the value of one input should help you to better predict the value
of the others. For images, this would mean that knowing the
color of a pixel should allow you to better predict what color the neighbouring pixels
will be: this is generally true for natural images, where nearby pixels tend to have similar
colors, which is why CNNs have no problem learning from images. If you want to know more about why CNNs break
down when their inputs aren’t related, my previous video goes into it in more detail. But what about text? Does knowing what word occurs at a given position
help you to predict what the neighbouring words are? Well, it is true that neighbouring words in
a sentence are often related to each other. But this isn’t always the case. There are plenty of times when two words on
opposite ends of a sentence are the most strongly related. For example, consider this sentence “the
dog spun around so fast that it caught its own tail”. In this sentence it’s clear that the “tail”
belongs to the “dog” and so these 2 words are strongly related. But since it isn’t until the final layer
that information from these words can be combined, in the early layers the CNN will attempt to
combine them with other words in the sentence which are less related, and get itself confused. This is the problem that transformers seek
to resolve: CNNs can’t handle long-range relations. And the key idea behind the transformer is
actually very simple: just replace convolutional layers with… pairwise convolutional layers. In a pairwise convolutional layer, we apply
a neural net to each pair of words from the sentence. Now it doesn’t matter how far apart two
words are, information from every pair of words can immediately be combined in a single
layer. Even if most of these pairs are unrelated
to each other, it doesn’t matter, so long as somewhere among them we have pairs of the related
words. And we can repeat this operation again to
combine pairs of pair-vectors together. Now each one of these resulting vectors represents
a group of 4 words from the original sentence. Each time we apply this
operation, we get representations of larger and larger groups of words. At the end, we can just average all of the vectors together to get the final prediction.
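Here is the same kind of sketch for a pairwise layer, again with random vectors and a linear-plus-ReLU stand-in for the neural net; the point is just that every pair of positions gets combined, regardless of how far apart they are:

```python
import numpy as np

rng = np.random.default_rng(0)

def pairwise_layer(vectors: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Apply the same little network to every ordered pair of input vectors."""
    outputs = []
    for v_i in vectors:
        for v_j in vectors:
            pair = np.concatenate([v_i, v_j])
            outputs.append(np.maximum(0, W @ pair + b))
    return np.array(outputs)                     # n inputs -> n^2 outputs

words = rng.normal(size=(3, 16))                 # 3 words
W, b = rng.normal(size=(16, 32)), np.zeros(16)
pairs = pairwise_layer(words, W, b)              # 9 vectors, one per pair of words
pairs_of_pairs = pairwise_layer(pairs, W, b)     # 81 vectors, one per pair of pairs
prediction = pairs_of_pairs.mean(axis=0)         # average everything for the final output
```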
Each one of the resulting vectors is created by combining the words from the original sentence in a different order. In fact, this whole operation
is exactly equivalent to rearranging the original input sentence into all possible permutations of the words,
then applying a regular CNN to each permutation, and then averaging the output from each. The idea is that, somewhere in this list there
must be a good ordering. One where all of the related words occur right
next to each other, and hence can be efficiently combined together by the CNN. Problem solved. Well, not quite. There’s one small problem with this model:
it ignores the order of words in the input sentence. Consider these 2 phrases: “the fat cat”
and “the cat fat”. Now clearly, these each have different meanings. But, since the set of permutations of each
phrase is the same, and the model output is an average over the permutations, the model
output has to be the same in each case. If we want our model to understand the meaning
of natural language, then it needs to be able to take into account the order in which words
appear. Fortunately, there is a simple solution to
this: we just need to attach the position of each word to its vector representation. As before, each word is converted into a one-hot
vector, but now we also attach a one-hot encoding of the word’s position in the sentence. These position indexes go up to some pre-specified
cutoff, usually in the tens of thousands. Now each vector represents both the identity
of the word as well as where that word occurs in the sentence.
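In code, this is just a concatenation of two one-hot vectors (the toy vocabulary and the position cutoff of 512 are arbitrary choices for the sketch):

```python
import numpy as np

vocab = ["the", "dog", "spun", "around", "cat", "fat"]   # same toy vocabulary as before
word_to_index = {word: i for i, word in enumerate(vocab)}
MAX_POSITION = 512   # the pre-specified cutoff on positions (an arbitrary choice here)

def encode(word: str, position: int) -> np.ndarray:
    """Concatenate a one-hot word vector with a one-hot position vector."""
    word_vec = np.zeros(len(vocab))
    word_vec[word_to_index[word]] = 1.0
    pos_vec = np.zeros(MAX_POSITION)
    pos_vec[position] = 1.0
    return np.concatenate([word_vec, pos_vec])

sentence = ["the", "fat", "cat"]
inputs = np.array([encode(word, i) for i, word in enumerate(sentence)])   # shape (3, 518)
```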
The CNN can learn to change its output based on the available position information. Perfect! There’s just one more little problem with
this model: it’s big. Very big. In this example the input sequence was only
3 words and after just 2 layers we already have 81 vectors to deal with. In every new layer, the number of vectors
is squared. And for these vectors to
learn useful representations, each one is going to need to use hundreds or, more likely, thousands of components. And this was with only 3 input words; we want
to be able to run this model on long passages of text. This is going to get too big to handle very
quickly. In order to make this model practical, we
need to somehow reduce the number of vectors used. Ideally, for an input with n words, the model
would never need to use more than n^2 vectors. For this, we would need to somehow reduce
the n^2 output vectors back down to just n vectors before applying the next layer. If we could do that, then the number of vectors
would never grow beyond n^2 no matter how many layers we apply. So the question is, given
n^2 vectors, representing all pairs of input vectors, how can we reduce this down to just n vectors,
while still maintaining as much useful information as possible? One way to do it would be to simply sum the
vectors down each column. This neatly reduces each column of n vectors
into just a single vector. The downside is that we are going to lose
a lot of information, since a single vector simply cannot contain all of the information
from the n original vectors. To get an idea for why this is a problem,
imagine what happens when you take the average of a bunch of images: you just end up with
a blurry mess. The same thing happens with our vectors: when
you try to cram a bunch of different vectors together, you just get noise and you lose
all of the information contained in the original vectors. In order for any information to survive the
summation, we basically need 1 of the original vectors to have a relatively large magnitude,
and the rest of the vectors to be close to 0. This way when we do the sum, the information
contained in the large vector will be preserved. So ideally, the neural net will learn to output
a large vector for the most important pair, which contains the most strongly related words,
and output small vectors for all other pairs. And fortunately, this is exactly the kind
of thing that neural nets excel at! The only problem is that sometimes it’s
impossible to tell if a pair of words is important just by looking at the pair itself. Consider the following sentence:
“there was a tree on fire and it was roaring”. In this sentence, the words “fire” and
“roaring” are related because it is the fire which is roaring. So we would like our neural net to output
a large vector for the pair (“roaring”, “fire”) and a zero vector for the pair
(“roaring”, “tree”). Now consider this sentence:
“there was a lion on fire and it was roaring”. Now it is the lion that is roaring. So we want the neural net to output a large
vector for the pair (“roaring”, “lion”) and a zero vector for the pair (“roaring”,
“fire”). The problem is that in one sentence the pair
(“roaring”, “fire”) needs to have a large vector, and in the other it needs
to have a zero vector. But the neural net sees the same input in
both cases, so it has to produce the same output. Somehow, we need to be able to tune the neural
net’s output based on the context of which other words are in the sentence. If we assign each pair a score, indicating
how important it is, then we can compare the scores of all of the pairs in a column to
pick the most important one. For example, suppose the scoring function assigns the pair (“roaring”, “fire”) an importance of 10 and the pair (“roaring”, “lion”) an importance of 90, with all other pairs getting an importance of 0. Then we can get the relative importance of a pair by dividing its score by the sum of all of the scores in its column. This way, in our first sentence, where there is no “lion”, (“roaring”, “fire”) has a relative importance of 10 / 10 = 1, but in the second sentence it only has a relative importance of 10 / (10 + 90) = 0.1. If we simply multiply each pair vector by
the relative importance, then the pair vector will be shrunk down to 0 only when it has
a low relative importance. This means that the sum of the weighted vectors will always keep information mainly from the most important pair, while discarding information from less important pairs.
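As a tiny numerical illustration of this weighting (with made-up scores and 2-dimensional pair vectors):

```python
import numpy as np

scores = np.array([10.0, 90.0, 0.0])     # e.g. ("roaring", "fire"), ("roaring", "lion"), another pair
relative = scores / scores.sum()         # [0.1, 0.9, 0.0]

pair_vectors = np.array([[4.0, 0.0],     # made-up pair vectors
                         [0.0, 2.0],
                         [1.0, 1.0]])
reduced = (relative[:, None] * pair_vectors).sum(axis=0)   # [0.4, 1.8]: dominated by the lion pair
```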
Now the only question is: how do we get these importance scores? And the answer is, of course, that we will train a neural net to produce them. So, we’ll have 2 neural nets in each layer: the first maps pairs to their
new vector representation, and the second maps pairs to their importance
score. We then normalize across each column to get
relative importance scores, and take a weighted sum of vectors in each column to produce our
final n vectors. Oh, there’s one small problem though: the output of the scoring neural net could be negative, which can introduce problems when we try to normalize. We don’t want to accidentally divide by 0! This is easily fixed, though: we can just run
these importance scores through some function which always outputs positive numbers, such
as the exponential function. Okay, to summarize the model so far: in each layer, we grab all pairs of input vectors and run each pair through 2 different neural nets. One outputs the vector representation of the
pair, and the other outputs a scalar score which, when exponentiated and normalized,
gives the relative importance of this pair. Then we take a weighted sum down each column. This entire operation is known as a self-attention
layer. It’s called that because it is like the model trying to decide which inputs it should pay attention to.
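To make that summary concrete, here is a minimal NumPy sketch of this simplified self-attention layer. The two per-pair neural nets are stood in for by single linear maps (with a ReLU on the representation), and all the weights are random; this is only meant to show the shape of the computation, not a tuned implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def simplified_self_attention(x, W_repr, b_repr, w_score, b_score):
    """Reduce the n^2 pairs of input vectors down to n output vectors."""
    n, d = x.shape
    outputs = []
    for i in range(n):                        # one output per column of pairs
        pair_vecs, scores = [], []
        for j in range(n):
            pair = np.concatenate([x[i], x[j]])
            pair_vecs.append(np.maximum(0, W_repr @ pair + b_repr))   # "representation" net
            scores.append(w_score @ pair + b_score)                   # "scoring" net (a scalar)
        weights = np.exp(np.array(scores))
        weights /= weights.sum()              # exponentiate and normalize the column
        outputs.append(sum(w * v for w, v in zip(weights, pair_vecs)))
    return np.array(outputs)

x = rng.normal(size=(5, 16))                  # 5 word vectors
W_repr, b_repr = rng.normal(size=(16, 32)), np.zeros(16)
w_score, b_score = rng.normal(size=32), 0.0
out = simplified_self_attention(x, W_repr, b_repr, w_score, b_score)   # shape (5, 16)
```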
Now, the model that I’ve described thus far is not quite a standard transformer. We still need to make a few more optimizations. However, these optimizations don’t fundamentally change how the model works; they just make it more efficient. That matters in practice, but if you
only want to know why the transformer architecture works, then you really only need to understand
this simplified version. And with that out of the way, let’s get
optimising. We’ve managed to cap our model so that it
only uses at most n^2 vectors in each layer, but n^2 is still a lot! Especially for large passages of text, containing
thousands of words. And remember, each vector contains thousands
of numbers. So the major bottleneck in this model is the
application of our representation neural nets to the n^2 pair vectors. The first thing we are going to do is replace
the representation neural nets with simple linear functions. These use a tiny amount of computation compared
to full neural nets. Much better. But now we have a problem! We’ve just lost the non-linear processing
power of the representation neural net, and we really do need that in order to be able
to combine pair vectors effectively throughout the layers. So, after we’ve done the self-attention
operation, we will apply the neural net to each of the output vectors. We would also like to do the same thing with
the scoring neural nets, but the scoring function itself needs some non-linearity, it can’t
wait until after the reduction, since we use the scores in the reduction process. Fortunately, the scoring function only needs
to output 1 number, not an entire vector. Since this job is so much simpler, we can
get away with using a really small neural net to produce scores. You can also replace the scoring neural nets
with simple non-linear functions, such as bi-linear forms. You get pretty much the same performance using
small neural nets as you do with bi-linear forms, but bi-linear forms use a little bit
less computation, so they are preferred in practice. Each layer now consists of 2 steps: 1) do
self-attention with linear representation functions and bi-linear form scoring functions,
2) apply a neural net to each of the n resulting vectors. The crucial part is that now we only need
to apply the large neural net n times, down from n^2. Also, now that we are using linear representation
functions, some inputs are redundant. Since each column consists of pairs which all contain the same word, and we apply the same linear function to each pair, every pair vector in the column has the same term added to it (the term coming from the shared word). And since the attention weights sum to 1, the weighted sum simplifies to just adding this term once to the result. So we can just do that and save ourselves a bit of computation.
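Putting these optimizations together, one layer might look roughly like the following sketch. It assumes a bilinear form x_i · A · x_j for scoring, a linear pair representation split into a “self” term and an “other” term, and a small two-layer ReLU network standing in for the per-vector neural net; all the names and sizes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def transformer_layer(x, A, W_self, W_other, W1, b1, W2, b2):
    """One layer: cheap self-attention, then a full neural net on each of the n outputs."""
    scores = x @ A @ x.T                             # bilinear scoring: scores[i, j] = x_i . A . x_j
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)    # normalize over the pairs sharing the same first word
    # The linear representation of pair (i, j) decomposes as W_self x_i + W_other x_j,
    # and because the weights sum to 1, the W_self term is simply added once per output.
    attended = x @ W_self.T + weights @ (x @ W_other.T)
    hidden = np.maximum(0, attended @ W1.T + b1)     # the big neural net, applied only n times
    return hidden @ W2.T + b2

d = 16
x = rng.normal(size=(5, d))                          # 5 input vectors
A, W_self, W_other = (rng.normal(size=(d, d)) for _ in range(3))
W1, b1 = rng.normal(size=(4 * d, d)), np.zeros(4 * d)
W2, b2 = rng.normal(size=(d, 4 * d)), np.zeros(d)
out = transformer_layer(x, A, W_self, W_other, W1, b1, W2, b2)   # shape (5, 16)
```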
One last thing. Remember how when we sum n vectors together, the information from each gets blurred together, and the only way information can be preserved
is if one vector is large and the rest are 0? Well, sometimes we want to keep information
from more than 1 vector. And there’s a very simple way to do this:
Just apply the self-attention operation multiple times. This way, each of the self-attention operations
can, potentially, select a different pair vector to keep. We can then just concatenate the outputs from
each into 1 big vector, thereby retaining information from all of them. This is known as multi-head self-attention. In practice, it is common to use a few dozen
heads in each self-attention layer.
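Here is a sketch of the multi-head idea, reusing the same bilinear-scoring head as above (without the per-vector neural net, for brevity) and simply concatenating the outputs of each head; the head count and sizes are, again, just illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)

def attention_head(x, A, W_other):
    """One self-attention head: bilinear scores, then a weighted sum of linear representations."""
    scores = x @ A @ x.T
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ (x @ W_other.T)

def multi_head(x, heads):
    """Run several independent heads and concatenate their outputs at each position."""
    return np.concatenate([attention_head(x, A, W_other) for A, W_other in heads], axis=1)

d, num_heads, head_dim = 16, 4, 8
x = rng.normal(size=(5, d))
heads = [(rng.normal(size=(d, d)), rng.normal(size=(head_dim, d))) for _ in range(num_heads)]
out = multi_head(x, heads)   # shape (5, num_heads * head_dim) = (5, 32)
```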
And there we have it. This is the complete transformer architecture in all of its glory. With this, you should now
have a complete understanding of the transformer as it is implemented in ChatGPT. Well, not quite. Standard transformer implementations use a
few more tricks that I haven’t covered here, such as layer normalization, residual connections, and byte-pair encoding. But these things aren’t unique to the transformer; they’ve been used in pretty much every neural net for the last 10 years, so I don’t feel
the need to go over them here. In any case, I hope you enjoyed this video,
and came away with some appreciation of the design decisions behind the transformer.