MLP-Mixer: An all-MLP Architecture for Vision (Machine Learning Research Paper Explained)

Captions
Hi there. I'm sure you've seen this paper make the rounds: it's called "MLP-Mixer: An all-MLP Architecture for Vision," by Ilya Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer and colleagues at Google Research. This is not going to be a long video, because the concept is pretty simple. There are a lot of authors here, and all of them deserve credit. The paper presents a neural network that is just MLPs, just feed-forward multi-layer perceptrons: no convolutions, no attention mechanism, just matrix multiplications, non-linearities, normalization and, I think, skip connections, though that's not really a layer. So it appears we've come full circle in computer vision, going from MLPs originally, to convolutional neural networks, some pixel RNNs, then vision transformers, and now back to MLPs. By the way, this paper will be much more understandable if you've read the Vision Transformer paper, because it comes from largely the same people and uses the same kinds of experiments and methodology. So the thing we tried at the very beginning works after all? No, I'm kidding. It's not as simple as slapping an MLP onto the problem; there is still a very specific architecture involved, and I think the paper is mostly a lesson in what you can do with scale, and in the fact that good architectures might be good for a particular scale rather than good by themselves. The end result is that the MLP-Mixer architecture performs adequately at large scale (not state of the art, not the best, but adequately), and it appears to benefit much more from scaling up than previous architectures, which raises the question of what happens at even larger scales. But I guess that's for another day, or year, or decade.

So let's dive in. This is the proposed computer vision architecture. It's a classification architecture: at the end there is a fully connected layer and a class label, preceded by a global average pooling, so at the end you collect everything you've computed and put it into a classifier that gives you a class label. That also makes it amenable to fine-tuning, where you freeze the representations that come out of the model, and all of that stuff you might already know. At the beginning of the model you have a picture, and, as in the Vision Transformer, you divide that picture up into patches. In this case you take something like 16 by 16 pixels as a patch, and those patches are what you operate on as you propagate through the network. Unlike a convolutional neural network, where you shrink the resolution but increase the channels, here you just have one layer after another, each layer as big as the last, stacked until the end. In that sense it is much like a transformer; the difference between this and the transformer is in how the individual layer looks. As in the transformer, every patch is first fed through a fully connected layer to bring it into a latent representation, of a size that you choose as a model builder, and that latent size propagates through the network. This is done on a per-patch basis, and in general these repeated, shared operations are the key to this architecture: every patch is projected into the latent space using the same function.
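To make the stem concrete, here is a minimal NumPy sketch of the patching and the shared per-patch projection, assuming a 224 by 224 RGB image, 16 by 16 patches and a hypothetical latent size of 512 (the shapes are illustrative, not the exact configurations from the paper):

```python
import numpy as np

image = np.random.rand(224, 224, 3)   # stand-in for an input picture
patch = 16                            # patch side length ("/16")
hidden_dim = 512                      # latent size chosen by the model builder

# Cut the image into non-overlapping 16x16 patches and flatten each one.
h, w, c = image.shape
patches = image.reshape(h // patch, patch, w // patch, patch, c)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(-1, patch * patch * c)
# patches now has shape (196, 768): one row per patch.

# One shared linear projection maps every patch into the latent space.
W_proj = np.random.randn(patch * patch * c, hidden_dim) * 0.02
tokens = patches @ W_proj             # (196, 512): one latent vector per patch
```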
This is followed by N of these mixer layers. So what does a mixer layer do? Here is where the core idea comes in. In every layer you start out with these latent embeddings, essentially one vector per patch: you unroll the patches, and every patch in the image corresponds to one vector. Technically you can interpret this as a table, which is what they do here, just oriented the other way around: this entry is the lower-left patch, this one is the patch right next to it, and so on, and each patch is described by a vector with some number of channels, say 512 dimensions.

Now, if you solved this traditionally and said "I have an all-MLP architecture for vision," you would take that table, completely unroll it into one single vector (the top patch here, the blue patch next to it, the yellow patch after that, and so on) and put a fully connected layer on top. That's not what happens here. What happens is much more like a convolution, except that the filters have size one by one. In this mixer layer there are two different modes of operation. First, we flip, that is transpose, the table, so that every row is now the same channel from all the patches: always channel one from all the patches in the image. I take channel one from all the patches and feed it through a fully connected layer; I also take channel two from all the patches and feed it through the same fully connected layer. As you can see, these weights are shared, so this is weight sharing across the same channel of the different patches, much like a one-by-one convolution (actually, the other step is more like a one-by-one convolution), but with weight sharing.

So we have a picture, we put it into patches, and in this step what we care about is connecting the same channel, or, you could say, the same type of information, across patches. This builds on the weight sharing of the previous layer: the per-patch fully connected layer is the same for every patch, so it might, for example, look at a patch, and if there is something like a sharp corner in the top-left of that patch, it might put that into channel one. Then all of the patches that have a sharp corner in the top-left will have that in their first channel. So if I now aggregate over the same channel across the patches, I can aggregate all the patches that have that feature, because the feature-producing map was shared.
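Here is a rough sketch of that token-mixing step, continuing the NumPy snippet above (the token-MLP width of 256 and the GELU non-linearity between the two fully connected layers are my assumptions for illustration):

```python
def gelu(x):
    # tanh approximation of GELU, used here as the non-linearity
    # between the two fully connected layers of each mixer MLP
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

num_patches, hidden_dim = tokens.shape   # (196, 512) from the previous sketch
tokens_mlp_dim = 256                     # hypothetical token-mixing width

# The same two weight matrices are applied to every channel (weight sharing).
W1_tok = np.random.randn(num_patches, tokens_mlp_dim) * 0.02
W2_tok = np.random.randn(tokens_mlp_dim, num_patches) * 0.02

x = tokens.T                     # transpose the table: rows are now channels
x = gelu(x @ W1_tok) @ W2_tok    # mix information across patches, per channel
token_mixed = x.T                # flip back to (num_patches, hidden_dim)
```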
All of this builds on the fact that in the previous layer, features were shared too. Here we share the projection, which means the channels in the individual patches mean similar things, because they come from the same function. And since they mean similar things, we now group by those channels and compute over all the patches within a particular channel, and since that channel carries the same kind of information, this lets us compute on a feature-by-feature basis. Of course these weights are shared as well, which means, on a meta level, that I now perform the same computation in all of those channels, which in turn means I can do the reverse trick and flip the table back into patches, and then do a shared computation for all the patches. Ultimately I have one weight matrix with which I forward-propagate all of the channels individually but in the same way, and a second weight matrix with which I forward-propagate all of the patches individually but in the same way. Again, since the same computation was done in the first step, the result is distributed in the same way across patches; I aggregate this into the patch location and forward-propagate it. This second step really is much more like a one-by-one convolution: we take a patch and apply the same computation across all of the channels of that patch, and that prepares the exact same thing for the next layer.

I hope that makes some sense; I have trouble articulating this, but it does make sense when you think about it. There are two steps that you repeat. In one step you look at your patch and ask what kind of features are there, and you put those features into predefined categories: channel one for feature one, channel two for feature two, and so on. In the other step you look across all of the image, but only in one channel, that is only for that particular feature: where in the whole picture does that feature appear, and how? You do some computation across those locations, and then you go back to the first step. One correction: I didn't say this quite right. You don't have one weight matrix per step; each MLP is in fact two fully connected layers separated by a non-linearity. They are shared, though, across channels or across patches, depending on the step. And that's the architecture. As you can see in the diagram, there is always a layer norm involved, and there are skip connections at the top, but largely that's it.
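Putting the pieces together, here is a minimal sketch of one full mixer block and of the classifier head with global average pooling, reusing `gelu`, `tokens`, `num_patches`, `hidden_dim` and the token-mixing weights' shapes from the snippets above; the channel-MLP width, the number of blocks and the number of classes are hypothetical:

```python
def layer_norm(x, eps=1e-6):
    # normalize each token vector (no learned scale/shift in this sketch)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def mlp(x, W1, W2):
    # two fully connected layers separated by a non-linearity
    return gelu(x @ W1) @ W2

def mixer_block(x, W1_tok, W2_tok, W1_ch, W2_ch):
    # Token mixing: layer norm, transpose, shared MLP across patches, skip.
    y = layer_norm(x)
    x = x + mlp(y.T, W1_tok, W2_tok).T
    # Channel mixing: layer norm, shared MLP across channels, skip.
    y = layer_norm(x)
    x = x + mlp(y, W1_ch, W2_ch)
    return x

# Stack several blocks, then global average pooling and a linear classifier head.
tokens_mlp_dim, channels_mlp_dim = 256, 2048   # hypothetical MLP widths
num_blocks, num_classes = 8, 1000              # hypothetical depth / classes
blocks = [
    (np.random.randn(num_patches, tokens_mlp_dim) * 0.02,
     np.random.randn(tokens_mlp_dim, num_patches) * 0.02,
     np.random.randn(hidden_dim, channels_mlp_dim) * 0.02,
     np.random.randn(channels_mlp_dim, hidden_dim) * 0.02)
    for _ in range(num_blocks)
]
W_head = np.random.randn(hidden_dim, num_classes) * 0.02

x = tokens
for params in blocks:
    x = mixer_block(x, *params)
features = layer_norm(x).mean(axis=0)   # global average pooling over patches
logits = features @ W_head              # class scores for one image
```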
So what does this give us? If you've seen the Vision Transformer paper, or the Big Transfer paper, all of this is extremely similar in terms of experiments. They build a bunch of differently sized models with different patch resolutions; the resolution is the number after the slash, so "/16" means 16 by 16 patches, and obviously the lower this number, the higher the resolution at which the model looks at the picture. One advantage compared to, for example, vision transformers is that vision transformers, due to the attention mechanism, have a quadratic compute and memory requirement in the sequence length: as you lower the patch size, the number of patches in the image increases, and they suffer quadratically, while this model only suffers linearly. That is the point they make in the experiments.
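As a rough back-of-the-envelope illustration of that scaling argument (the constants here are mine and only indicative), the cost of the attention map grows with the square of the number of patches, while the token-mixing MLP grows linearly with it:

```python
hidden_dim, tokens_mlp_dim = 512, 256
for patch in (32, 16, 8):                  # the number after the slash
    n = (224 // patch) ** 2                # patches in a 224x224 image
    attention_cost = n * n * hidden_dim    # ~O(n^2 * d): the attention matrix
    token_mix_cost = 2 * n * tokens_mlp_dim * hidden_dim   # ~O(n) per layer
    print(f"/{patch}: {n} patches, attention ~{attention_cost:.1e}, "
          f"token mixing ~{token_mix_cost:.1e}")
```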
The experiments follow a repeating pattern. If you look at the best models on, say, ImageNet top-1, the Mixer is not quite as good. They pre-train on large datasets and then either transfer-learn or train a linear classifier on the frozen features, and the story is always the same: "look at us, we're sometimes even better than this baseline, but not quite as good as that one; however, we are competitive." If this model had been on the market a couple of years ago it would have been state of the art by far; now it is competitive, it achieves okay performance. And since that's not what we like to hear in machine-learning publishing, I think the big lesson, if you want to publish something like this, is to find a metric where you win. They say: we might not be the best in classification accuracy, but we're okay and we have a better trade-off.

There are a number of trade-offs they look at. For example, throughput, measured in images per second per core during inference, which really matters to practitioners who actually have to deploy these models. You can see that the throughput of Mixer is way above the other models: convolutions are a comparatively expensive operation, the Big Transfer model has a lot more layers than the Mixer or the Vision Transformer, and the Vision Transformer has the attention mechanism, with not only its quadratic requirement but also the computation of the softmax itself. If you look instead at how much compute went into training, the Vision Transformer actually outperforms the Mixer; but in all of these tables there is always at least one metric where the Mixer is better, you just have to select the metric.

For example, take linear 5-shot ImageNet top-1, which I like more. If I understand it correctly, you train a linear classifier on the frozen representation the model gives you and evaluate top-1 accuracy, but as a 5-shot classifier, so it's a very particular task. They look at what happens if you vary the size of the pre-training set, and in this framing the model scales much more favorably than the others. Big Transfer, which is good at small dataset sizes, suddenly plateaus and doesn't improve much more when you scale up the dataset by a significant factor, whereas the Mixer model scales really well and at the end is almost on par with the Vision Transformer, sometimes even a bit higher, and specifically also higher than the Big Transfer model. What you can also see is that there is a significant gap at small training-set sizes, but that gap always appears to close as you go up, and, as already said, at the end the curves very often lie on top of one another.

This raises a bunch of interesting questions, and by the way it's not only this task; they show on a bunch of tasks that this model benefits from scale a lot more, has a higher throughput, is a simpler architecture and scales well in terms of the compute you need to put into pre-training. Here you can see ImageNet transfer accuracy against how many core-days on a TPU-v3 went into pre-training, and the Mixer and the Transformer models lie on very similar curves, actually leading the Big Transfer model, so they are computationally more efficient. Also, in terms of throughput you can see that, for a given accuracy, Mixer and Transformer have higher throughput than Big Transfer, and for a given model size, Mixer has a higher throughput than the Vision Transformer, though the Vision Transformer makes up for that by being more accurate. They have very extensive evaluations to show that, if you really care about deploying at large scale, this is a model where you might want to take the performance hit in exchange for better throughput; I think that's fairly clear from these evaluations. It remains to be seen how the model performs in different settings, for different data and different tasks.

These numbers are ImageNet accuracy after pre-training on particular datasets. If you pre-train on a small dataset (ImageNet itself), the model sucks; it really trails the other models. If you pre-train on a slightly larger dataset it still sucks, but not as much compared to the others, and if you pre-train on a really big dataset it only sucks a little bit, so you're hard-pressed to find a number that's higher. That, I think, is the point they make. The interesting question for me is how this continues as we go higher, say one order of magnitude more data and compute. Is it the case that the Mixer keeps rising while the Vision Transformer plateaus? That would be really interesting, because you could then make the case that the Vision Transformer actually has more inductive biases than the Mixer. Both seem very general, but I would personally argue that the Vision Transformer is more general and has fewer inductive biases, because in the Mixer, first of all the weights are fixed, and second of all there is this very particular chessboard pattern to how you interact with the input data; it almost seems like there are lots of biases there. Now, these inductive biases might be exactly right for the particular modality we're dealing with, natural image classification, or it might be that the Mixer transfers to other domains and works really well there too, in which case I might be wrong.
It might also be the case, of course, that both plateau, which would just mean that with enough scale you can get pretty much anything to work. If you're a cynic you can say, well, even a crap architecture like the Mixer can be made to work by just scaling it up and using SGD, and that might also be true. Ultimately, in the limit of scale, if you had the entire space of images as your dataset, you could just perform a k-nearest-neighbour classification and be correct 100% of the time. I don't think we're there yet, but the trend is relatively clear, and it will be really interesting to see how it continues beyond our current limits.

The last thing they show is the weights, and they make a couple of interesting observations. These are the token-mixing weights; every point corresponds to one patch in the original image, and they show how information is aggregated within the same channel across different patches. They observe, for example, that the weights appear in pairs of negative and positive (blue and red, high and low values). Also, if I'm reading this correctly across the first, second and third blocks, the lower layers learn rather large-scale, general features, while the higher layers learn much more specific interactions. This is very reminiscent of how we think, or observe, convolutional neural networks to work, so there's a good case that the model learns something sensible. You can look at all of these weights; I think the appendix has the full set, also for models pre-trained on different datasets, and that is really interesting too. If you pre-train on ImageNet, the weights look qualitatively different than if you pre-train on ImageNet-21k, which is larger and has more classes, and that in turn is significantly different from pre-training on JFT-300M, a super-huge proprietary dataset held by Google. I think it's still unclear whether these differences are an effect of scale, an effect of how much signal there is to learn independent of scale, or actually just a property of the datasets being of a different nature. That would also explain why ImageNet and ImageNet-21k seem visually a bit closer to each other than to JFT-300M, though don't forget that JFT is a huge dataset.

The code is open source, in fact it's right here, and I've already seen a bunch of people implement it. So that was it for me for this paper. Again, it's not very complicated; it's a very simple architecture, and that is exactly its selling point. Its selling point is that it's simple, which means it can scale up really well, and its trade-off between compute and accuracy is really good, so you should consider it if that's something that matters to you. From a research perspective, it raises a lot of questions about inductive biases, how scale behaves, and whether you can get anything and everything to work with SGD and a lot of TPUs. That's it. Thanks for listening, and I'll see you next time. Bye.
Info
Channel: Yannic Kilcher
Views: 30,616
Rating: 4.9668508 out of 5
Keywords: deep learning, machine learning, arxiv, explained, neural networks, ai, artificial intelligence, paper, what is deep learning, deep learning tutorial, introduction to deep learning, google mixer, google ai mixer, vit, bit, mlp mixer, mlpmixer, imagenet mixer, imagenet only feedforward, no convolutions, imagenet without convolutions, image patches, attention mechanism, multilayer perceptron, transfer learning, linear classifier, state of the art, tradeoff
Id: 7K4Z8RqjWIk
Length: 28min 11sec (1691 seconds)
Published: Thu May 06 2021