Hey guys and welcome to this video! Today, I
will implement the Vision Transformer which was introduced in the paper: An Image is Worth
16x16 Words: Transformers for Image Recognition at Scale. What you see on the screen is the
official github repository for this paper, however, since I know basically nothing about
JAX, I will use a different repository that has a PyTorch implementation. This repository is called
pytorch-image-models and it is also available on PyPI under the name timm. You can also see their
documentation to learn more about it. Anyway the reason why I decided to use their implementation
is that they have pre-trained weights. Note that there are other amazing implementations
on GitHub and I will link them in the description. What I will do in this video is that I
will just rewrite the following class called VisionTransformer from scratch. To verify
that my implementation is correct I will use the official implementation that you can see here
together with pre-trained weights. Before I start though let me point out that I will more or
less just copy-paste this implementation and add a couple of minor modifications/simplifications here
and there. All the credit for the implementation goes to the authors and the contributors of
this amazing library. So before we start let me give you a quick overview of the architecture.
First of all, the Vision Transformer is an image classifier. That means that it takes in an image
and then it outputs a class prediction. However, the reason why it is special is that it does
so without any convolutional layers. Instead, it uses the attention layers that are prevalent
in the natural language processing application. So the main question is ... How can one turn an
image into a sequence of one-dimensional tokens so that we can use the transformer architecture?
The idea is to split the input image into patches, flatten these patches into one-dimensional vectors
and then just use a linear mapping to create the initial embedding for each of the patches. And
honestly this diagram actually explains the entire idea amazingly well. Note that the transformer
encoder shown on the right is virtually identical to the one proposed in the "Attention is all you
need" paper. Also let me stress that similarly to BERT we prepend the sequence with a so-called
CLS token whose goal is to capture the meaning of the entire image as a whole. It is exactly
the final CLS token's embedding after the last transformer block that we are going to be using
for classification. Finally, we will learn a position embedding that will allow the network to
determine what part of the image a specific patch came from. Again this exact idea was already used
in BERT. Let us start with the implementation. We'll start with arguably the most important
module and that is the patch embedding. We will give this module an img_size
and we are kind of implicitly assuming that it's a square. However, in general it doesn't
have to be. Similarly, we provide the patch_size and again we assume that it's a square
so both the height and width are equal. The next parameter represents the number of channels of the
image. For example if you have a grayscale image it will be equal to 1, however, if we have an RGB
image it will be equal to 3. Finally, we have the embedding dimension and it will determine how
big an embedding of our patch is going to be and note that this embedding dimension is going
to stay constant across the entire network. Internally, we will have two attributes. One of
them will be n_patches representing the number of patches we split our image into. The second
attribute is going to be a convolutional layer and we will use it to split the image into
patches. I know what you're thinking. I said that there wouldn't be any convolutional
layers inside of our architecture. However, this one is not your regular convolution.
I'll explain in a couple of moments. I save some of the parameters as attributes
and then I compute the number of patches. In the actual model that I'm
going to show you the img_size will be perfectly divisible by the patch_size
so we'll be covering the entire image. We define the attribute proj and you probably
noticed that it looks suspicious if we look at the kernel size and the stride compared to, let's say, regular convolutional neural networks.
What we are doing here is that we take the kernel size and we take the stride and we put both of
them equal to the patch size. This way, when we're sliding the kernel along the input tensor we'll
never slide it in an overlapping way and the kernel will exactly fall into these patches
that we're trying to divide our image into. The input tensor is nothing else than a
batch of images. Note that "number of samples" and "batch size" are synonymous and we will be
using "number of samples" across the entire tutorial. Also, since we are using
PyTorch, the channels are the second dimension, i.e. dimension 1 in Python's zero-based indexing.
Finally, the image size is the height and the width of our images and it's the number that we
declared in the constructor. What we will get as an output is a three-dimensional tensor
and the second dimension represents different patches that we divided the image into and the
last dimension will be the embedding dimension. We run the input tensor through the convolutional
layer and we will get a 4-dimensional tensor. We then take the last two dimensions
that represent the grid of patches and flatten them into a single dimension. Finally, we swap two dimensions. Arguably,
at this point we have implemented all the novelty of this paper, and what comes next can be more or less copied from NLP applications.
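For reference, here is a minimal sketch of such a patch embedding module. It follows the description above, but the names and defaults are my own and may differ slightly from the exact code shown in the video.

```python
import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into patches and linearly embed them."""

    def __init__(self, img_size, patch_size, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # kernel_size == stride == patch_size: the kernel never overlaps and
        # falls exactly onto the patches we want to split the image into
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (n_samples, in_chans, img_size, img_size)
        x = self.proj(x)        # (n_samples, embed_dim, img_size // patch_size, img_size // patch_size)
        x = x.flatten(2)        # (n_samples, embed_dim, n_patches)
        x = x.transpose(1, 2)   # (n_samples, n_patches, embed_dim)
        return x
```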
Now we want to write the attention module. We provide the embed_dim and note that we
will set things up in a way that our input dimension of the tokens is going to be equal to
the output dimension. n_heads is another hyperparameter related to the attention mechanism.
The next parameter determines whether we want to include a bias in our linear projections
of the keys, queries and values. Finally, we have two different dropout probabilities. Note
that in this video I will only run the network in an inference mode so there is no need for a
dropout. However, it gets really important during training because it fights overfitting. Let me
give you a little bit more insight into the dropout layer. As you can see, by default the dropout
module is going to be set to the training mode. Each forward pass will remove
approximately 50% of the elements and set them equal to zero.
However, to make up for this removal it will multiply the remaining elements
with the following constant: 1/(1 - p), which in our case is 2. Let us now
set the module to the evaluation mode. We see that internally this
training boolean got set to False. As you can see, in the evaluation mode the dropout layer behaves exactly like an identity mapping.
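As a quick illustration, here is roughly what that experiment looks like (a small sketch with p=0.5; the exact tensors in the video will differ):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)   # training mode by default
x = torch.ones(2, 4)

print(drop(x))   # roughly half the entries are zeroed, the rest are scaled by 1 / (1 - 0.5) = 2

drop.eval()      # sets drop.training = False
print(drop(x))   # identical to x: in evaluation mode dropout is an identity mapping
```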
Back in the attention module: internally, we save a scale factor and it
will be used to normalize the dot product. We will have a linear mapping that can be
actually split up into three separate ones. One for the keys, one for the values and one for
the queries. The projection is another linear layer and it is the last step of the attention
mechanism. Finally, we have dropout layers. Here we define the dimensionality for
each of the heads. The reason why we set it up in this way is that once we
concatenate all the attention heads we'll get a new tensor that will have
the same dimensionality as the input. The scale is coming from the "Attention is
all you need" paper and the idea behind it is not to feed extremely large values into the
softmax, which could lead to small gradients. Here we create a linear mapping
that is going to take in a token embedding and generate a query, key
and a value. Note that you could also write three separate linear mappings that
are more or less doing the same thing. Here we define two dropout layers and we
create a linear mapping that takes the concatenated heads and maps them into a new space. What's really important about the forward pass
is that the input and the output tensors are going to have the same shape. Note
that the second dimension is going to have a size of number of patches plus
one and the reason why I include the plus one is that we will always have the class
token as the first token in the sequence. Here we just check whether the embedding dimension
of the input tensor is the same as the one we declared in the constructor. Note that I could
probably write many more of these sanity checks; however, I just decided that this one is kind
of important. Here we take the input tensor and we turn it into the queries, keys and values. Before
I continue, I will take a small side step and explain how the linear layer behaves when we have
three-dimensional or higher-dimensional tensors. The most common way to use the linear
layer is to give it a two-dimensional input. The first dimension represents the
samples or the batch and the second one is equal to the input features
we declared in the constructor. As you can see for each sample the linear layer
simply took the input features and mapped them into the output features. However, you can use
the linear layer on tensors of arbitrary dimension higher than two as well and in that case the
only thing you need to make sure of is that the input tensor's last dimension is equal to the
input features you declared in the constructor. As you can see, the output tensor was
just created by applying the linear layer across all the samples and across
the entire second dimension. I created a tensor with seven dimensions
and not surprisingly the behavior stays the same.
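Here is a small sketch of that experiment (the feature sizes are arbitrary choices of mine):

```python
import torch
import torch.nn as nn

lin = nn.Linear(in_features=5, out_features=7)

x2d = torch.rand(3, 5)           # (samples, in_features) -- the standard use case
print(lin(x2d).shape)            # torch.Size([3, 7])

x3d = torch.rand(3, 10, 5)       # only the last dimension needs to equal in_features
print(lin(x3d).shape)            # torch.Size([3, 10, 7])

x7d = torch.rand(2, 1, 2, 1, 2, 1, 5)   # seven dimensions work just as well
print(lin(x7d).shape)            # torch.Size([2, 1, 2, 1, 2, 1, 7])
```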
In our implementation, we actually applied the linear layer to a three-dimensional tensor, and
the last dimension of the result is three times the embedding dimension. In the reshape step we create an extra
dimension for the heads and we also create an extra dimension for the query, key and value,
and in the permute step we change their order. Given the previous permutation, it is now very
easy to extract the queries, keys and values. Here we transpose our keys because we're
getting ready to compute the dot product. We compute the dot product and we use the scale
factor. Note that the matrix multiplication of the two respective tensors is going to work out
because the last two dimensions are compatible. Here we apply a softmax over the
last dimension and the reason is that we want to create a discrete
probability distribution that sums up to 1 and can be used as
weights in a weighted average. We compute a weighted average of all the values. Here we just swap two dimensions and I just
realized that I forgot the attention dropout. Finally, we flatten the last two
dimensions. In other words the last two operations concatenated the
attention heads and note that we end up with a three-dimensional tensor that
has exactly the dimensions that we want. Finally, we do the last linear projection, follow it up by a dropout, and we're done.
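Putting all of these steps together, a minimal sketch of the attention module could look like this; the names and defaults are my own and roughly mirror the walkthrough above (imports as in the earlier sketch).

```python
class Attention(nn.Module):
    """Multi-head self-attention that preserves the input shape."""

    def __init__(self, dim, n_heads=12, qkv_bias=True, attn_p=0.0, proj_p=0.0):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads          # concatenating the heads gives back `dim`
        self.scale = self.head_dim ** -0.5      # 1 / sqrt(d_k) from "Attention is all you need"

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)   # one mapping for queries, keys and values
        self.attn_drop = nn.Dropout(attn_p)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_p)

    def forward(self, x):
        # x: (n_samples, n_patches + 1, dim) -- the +1 is the CLS token
        n_samples, n_tokens, dim = x.shape
        if dim != self.dim:
            raise ValueError  # the sanity check mentioned above

        qkv = self.qkv(x)                                      # (n_samples, n_tokens, 3 * dim)
        qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)                       # (3, n_samples, n_heads, n_tokens, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale          # (n_samples, n_heads, n_tokens, n_tokens)
        attn = attn.softmax(dim=-1)                            # discrete probability distribution
        attn = self.attn_drop(attn)

        weighted_avg = attn @ v                                # (n_samples, n_heads, n_tokens, head_dim)
        weighted_avg = weighted_avg.transpose(1, 2).flatten(2) # concatenate heads -> (n_samples, n_tokens, dim)

        x = self.proj(weighted_avg)                            # final linear projection
        x = self.proj_drop(x)
        return x
```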
Let us now implement the multi-layer perceptron. This multi-layer perceptron is
going to have one hidden layer and there's nothing special about it.
Maybe one interesting thing is that we are going to be using the Gaussian
error linear unit (GELU) activation function. We simply instantiate all the layers. Similarly to the attention block, we are going to be applying the linear mapping
to a three-dimensional tensor. We apply the activation, then a dropout,
then the second linear layer and finally another dropout. Note that none of these
operations are changing the shape of the tensor.
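A minimal sketch of this multi-layer perceptron, under the same naming assumptions as the sketches above:

```python
class MLP(nn.Module):
    """One-hidden-layer MLP with a GELU activation; it preserves the token shape."""

    def __init__(self, in_features, hidden_features, out_features, p=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(p)

    def forward(self, x):
        # x: (n_samples, n_patches + 1, in_features)
        x = self.drop(self.act(self.fc1(x)))   # linear -> GELU -> dropout
        x = self.drop(self.fc2(x))             # linear -> dropout
        return x
```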
All right, we have everything we need and now it's time to start putting things together. We've seen most of these parameters
before; however, this mlp_ratio is a new one and it determines the hidden
layer size of the multi-layer perceptron. When it comes to the attributes, we will have
two normalization layers, one attention module and one multi-layer perceptron module. We
instantiate the first layer normalization module and we set the epsilon equal to 10^(-6). If
you're wondering why, it's because we just want to match the pre-trained model. Anyway, let me now
show you the basic properties of the LayerNorm. I created a tensor with three samples and
two features and I instantiate the LayerNorm. However, I set elementwise_affine equal to False;
this way there will be no learnable parameters. Let us now compute the mean and the standard
deviation for each sample of our input. The layer norm will use these to normalize
the data and it will do this for each sample. As you can see the layer norm made sure
that the mean and the standard deviation are 0 and 1 respectively for each sample. Let me just point out that this process is
independent for different samples. In other words, the batch size doesn't really play any role.
Let me now reinstantiate the module, however, this time I will set the elementwise_affine
equal to True and that is actually the default. As you can see now we have 4 learnable parameters. They are actually contained in the bias and
the weight parameter of the module. They represent the new per-feature mean and standard deviation that will be used to rescale the data.
If we run the forward pass, it seems like nothing changed compared to the
elementwise_affine=False case; however, this
time around we actually did two things. First we applied the normalization as
before and then we used our learnable parameters to rescale the data. The
parameters are initialized in a way that makes it seem as if the second step
never happened. The parameters would get learned during training, but for the purposes of
this example I can just manually change them. As you can see after updating the parameters
the forward pass returns different tensors. So now it's clear that the second rescaling
step is actually taking place. Finally, let me just point out that our input
tensor can have an arbitrary number of dimensions as long as the last dimension
is equal to the number of features. The logic stays the same and it is always
the last dimension that is being normalized.
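A compact sketch of the LayerNorm experiments just described (three samples, two features; the exact numbers in the video will differ):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand(3, 2)   # 3 samples, 2 features

ln = nn.LayerNorm(2, elementwise_affine=False)   # no learnable parameters
out = ln(x)
print(out.mean(dim=-1))                          # ~0 for each sample
print(out.std(dim=-1, unbiased=False))           # ~1 for each sample

ln_affine = nn.LayerNorm(2)                      # elementwise_affine=True is the default
print(sum(p.numel() for p in ln_affine.parameters()))   # 4 learnable parameters (weight + bias)
print(ln_affine.weight, ln_affine.bias)          # initialised to ones and zeros, so the rescaling
                                                 # step looks like it never happened
```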
We're back in the implementation and we continue with the attention layer. We define the absolute number of
hidden features and it's nothing else than the dimensionality times the mlp_ratio. As you can see the input features and output
features are going to be the same number so we probably could have simplified the multi-layer
perceptron class a little bit, but whatever. This entire block again has the property that the input tensor and the output tensor
are going to have the same shape. Here we create a residual block so we take
the original input tensor and we add to it a new tensor and this new tensor is created
by applying the layer norm and the attention. And the second tensor is created by applying the
second layer normalization and the multi-layer perceptron. Note that we are using two
separate layer normalization modules and it's really important because both of
them will have their separate parameters.
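Here is a minimal sketch of such a block, reusing the Attention and MLP sketches from above (names and defaults are again my own):

```python
class Block(nn.Module):
    """Transformer encoder block: two residual sub-blocks, each with its own LayerNorm."""

    def __init__(self, dim, n_heads, mlp_ratio=4.0, qkv_bias=True, p=0.0, attn_p=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, n_heads=n_heads, qkv_bias=qkv_bias, attn_p=attn_p, proj_p=p)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)   # separate parameters from norm1
        hidden_features = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_features, out_features=dim, p=p)

    def forward(self, x):
        # x: (n_samples, n_patches + 1, dim); the shape is preserved
        x = x + self.attn(self.norm1(x))   # first residual connection
        x = x + self.mlp(self.norm2(x))    # second residual connection
        return x
```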
Now we're finally ready to put everything together and write the Vision Transformer. We've seen most of these parameters
before; one that is new is the depth and it will represent the
number of transformer blocks. We'll instantiate the patch embedding as the very
first layer of our network and then we will have two parameters. The first one is the CLS token and
it will represent the first token in the sequence and will always be prepended to the patch
embeddings. The second parameter is going to be the positional embedding and we include it
to encode the information about where exactly that given patch is located in the image.
We also create a module list that will hold all the Block modules. Patch embedding
is going to be the very first layer. We define the class token parameter
and we initialize it with zeros. Note that the first two dimensions
are there just for convenience. We add the positional embedding parameter and
its goal is to determine where exactly a given patch is in the image. We also want to learn
the positional encoding for the class token, and that's why there is the plus one. Again, the
first dimension is there just for convenience. Here we iteratively create the transformer
encoder. Note that the hyperparameters of each block are the same; however, each of the
blocks will have its own learnable parameters. We add a normalization layer and we also create a linear mapping that is going to take in
the final embedding of the CLS token and output logits over all the classes.
Let us now write the forward pass. The input tensor is nothing
else than a batch of images and the output tensor will represent
the logits for each sample. We take our input images and we
turn them into patch embeddings. We take the learnable class token and we
just replicate it over the sample dimension. Then we just take it and prepend
it to the patch embeddings. Then we add the positional embeddings
that we learned. Note that PyTorch will take care of the broadcasting. We apply a dropout. Here we iteratively run through all the blocks
of our transformer encoder. We apply layer normalization. Out of all the patch embeddings
that we have we only select the CLS embedding. And it is exactly this embedding that continues
to the classifier. So in a way we're hoping that this embedding encodes the meaning of
the entire image because we threw away all the other patch embeddings.
And at this point we're done.
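To summarize, here is a minimal sketch of the whole network built from the modules sketched above; the default hyperparameter values are illustrative and not necessarily the ones used in the video.

```python
class VisionTransformer(nn.Module):
    """Minimal Vision Transformer built from the sketches above."""

    def __init__(self, img_size=384, patch_size=16, in_chans=3, n_classes=1000,
                 embed_dim=768, depth=12, n_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 p=0.0, attn_p=0.0):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim))
        self.pos_drop = nn.Dropout(p)

        self.blocks = nn.ModuleList(
            [Block(embed_dim, n_heads, mlp_ratio, qkv_bias, p, attn_p) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        # x: (n_samples, in_chans, img_size, img_size) -> (n_samples, n_classes) logits
        n_samples = x.shape[0]
        x = self.patch_embed(x)                                # (n_samples, n_patches, embed_dim)

        cls_token = self.cls_token.expand(n_samples, -1, -1)   # replicate over the sample dimension
        x = torch.cat((cls_token, x), dim=1)                   # prepend the CLS token
        x = self.pos_drop(x + self.pos_embed)                  # broadcasting adds the position embedding

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        cls_token_final = x[:, 0]                              # only the CLS embedding goes to the head
        return self.head(cls_token_final)
```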
Let us now write a script that will verify that our implementation is correct. This helper function will count
the number of learnable parameters. Here we take two tensors and we
compare whether they are equal. Here we take the timm package that corresponds
to the GitHub repository I described at the beginning of the video and we load one of
the pre-trained vision transformer models. Here I define hyperparameters that
correspond to the pre-trained model. We instantiate our custom model that we just
implemented and we set it to the evaluation mode. Here we iterate through all the parameters of
the official network and our custom network. First of all, for each parameter we check whether
the number of elements is equal. Note that we are making an assumption that the order in which the
modules were instantiated in the constructor is the same for the official and the custom
model; however, I made sure that this is the case. We just take our custom parameter and we redefine it to be equal to the parameter
of the official implementation. Here we just double check
that the assignment worked. Here we create a random tensor that has the
right shape that our networks are expecting and then we just run the forward pass both for
the custom network and for the official one. First of all, we check whether the number
of trainable parameters is the same for both of the networks and then we take the two output
tensors and make sure they are identical. Finally, if we make it through all the checks we just
save our model that contains the correct weights.
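A sketch of what such a verification script could look like; it reuses the VisionTransformer sketch from above, and the model name, hyperparameter values and file name are my own choices that may not match the ones used in the video.

```python
import numpy as np
import timm
import torch

def get_n_params(module):
    """Count the learnable parameters of a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

def assert_tensors_equal(t1, t2):
    """Check that two tensors hold the same values."""
    np.testing.assert_allclose(t1.detach().numpy(), t2.detach().numpy())

# load a pretrained ViT from timm (the exact model name is an assumption)
model_official = timm.create_model("vit_base_patch16_384", pretrained=True)
model_official.eval()

# instantiate the custom model with hyperparameters matching that pretrained model
model_custom = VisionTransformer(img_size=384, patch_size=16, embed_dim=768,
                                 depth=12, n_heads=12, mlp_ratio=4.0, qkv_bias=True)
model_custom.eval()

# copy the weights; this relies on both constructors registering parameters in the same order
for (n_o, p_o), (n_c, p_c) in zip(model_official.named_parameters(), model_custom.named_parameters()):
    assert p_o.numel() == p_c.numel()
    p_c.data[:] = p_o.data               # redefine the custom parameter
    assert_tensors_equal(p_c.data, p_o.data)

inp = torch.rand(1, 3, 384, 384)
res_c = model_custom(inp)
res_o = model_official(inp)

assert get_n_params(model_custom) == get_n_params(model_official)
assert_tensors_equal(res_c, res_o)

torch.save(model_custom, "model.pth")    # hypothetical file name for the saved checkpoint
```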
Let me first show you the version of timm I'm using. First of all, the script downloads the
checkpoint. The script ran with no assertion error and, as you can see, we now have the saved model
checkpoint lying in our working directory. Let me reiterate: the only reason why this worked
was that I was very careful about when and where exactly I define each of these layers in the
module constructors. One way to break this is to just swap the order, so here, for example, I will
just swap the class token and position embedding. As you can see we're getting an assertion
error which means that one of the checks failed. Now the only thing left to do is to
run the forward pass on a real image instead of a random tensor. Because I don't want you to
feel like I clickbaited you, we will use the cat from the thumbnail. As you can see, it has
exactly the right dimensions that we need. Also I have this text file that
contains all the ImageNet classes. I loaded the classes, I loaded the
model and I also loaded the image and pre-processed it. We run the forward
pass and we turn the logits into probabilities.
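A rough sketch of that inference step; the file names and the quick-and-dirty preprocessing are placeholders of mine, not necessarily what the video uses.

```python
import numpy as np
import torch
from PIL import Image

k = 10

# hypothetical file names for the class list and the saved checkpoint mentioned above
imagenet_labels = dict(enumerate(open("classes.txt")))
model = torch.load("model.pth")
model.eval()

# load the cat image and roughly scale pixel values to [-1, 1]
img = (np.array(Image.open("cat.png")) / 128) - 1
inp = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(torch.float32)

logits = model(inp)
probs = torch.nn.functional.softmax(logits, dim=-1)

top_probs, top_idxs = probs[0].topk(k)
for i, (idx, prob) in enumerate(zip(top_idxs, top_probs)):
    print(f"{i}: {imagenet_labels[idx.item()].strip()} --- {prob:.4f}")
```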
It seems like the model is making the right predictions. Anyway, that's it for the video. I hope you found it interesting and I also hope
I did not butcher the original implementation too much. As mentioned at the beginning of
this video all the credit goes to the authors and the contributors of the timm library
and I more than encourage you to check it out and give it a try. Please let me know in the
comments what you liked and what you didn't like. Also, if you have any suggestions for future
topics I would be more than happy to hear them. Currently I'm just trying out different things
and then using YouTube statistics to guess what interests you and what doesn't. So again, do
not hesitate to directly share your feedback. I would more than appreciate it. Anyway,
have a nice day and see you next time!!!