Autoregressive Diffusion Models (Machine Learning Research Paper Explained)

Video Statistics and Information

Captions
Hi there! Today we'll look at Autoregressive Diffusion Models by Emiel Hoogeboom and others of Google Research. On a high level, this paper proposes a new type of autoregressive model, specifically one where variables can be decoded in arbitrary orders. This is akin to the new types of diffusion models that have been used for generative modeling, and it essentially amounts to something like BERT in sequence. The training objective is set up such that we can decode variables in any order we like, and I can show you the results. The results are going to be that we can, for example, sample pictures pixel by pixel to make a generative model. So rather than GANs, which produce pictures all at once, or the autoregressive models we had so far, which use a fixed order, for example left to right, now we can do it in any order. In addition, they introduce techniques where you don't have to go pixel by pixel but can decode multiple pixels at the same time and speed things up by a lot.

This is also a community-informed paper review, which means that on our Discord server we have regular paper discussions, and this was one of them. I tried to pay attention; I can't say yet whether that has worked, but I'm trying to recount it here a little bit. My opinions are influenced a lot by what was said in the paper discussion, so if you want to influence my opinion, feel free to join our paper discussions.

Okay, so there we go. They say they introduce these autoregressive diffusion models, a model class encompassing and generalizing order-agnostic autoregressive models and absorbing discrete diffusion models, which they show are special cases. They say the models are simple to implement and easy to train. Unlike standard autoregressive models, which you might know as LSTMs or GPT-type transformers, they do not require causal masking of model representations, and they can be trained using an efficient objective, similar to modern probabilistic diffusion models, that scales favorably to high-dimensional data. At test time, the ARDMs support parallel generation, which can be adapted to fit any given generation budget. So you can trade off how long you need to produce a sample against its quality: you can say "I want it faster" and you'll still get a sample, just a lower-quality one. They find that ARDMs require significantly fewer steps than discrete diffusion models to attain the same performance, and they also do lossless compression with them.

Okay, so what's the deal with autoregressive models? If I have a bunch of variables, let's say a piece of text, what you'd usually do in GPT is give a prefix and then decode token by token from left to right: "a cat", and then the model has to predict "sat", "on", "the", and so on. You predict from left to right, one by one, and that's also how you train. With text that makes some sense, because we also read from left to right. However, it would also make sense to do this in a different order. If the sentence ends in "mat" and you decode that word first, then it becomes pretty clear what has to go in between. So in order to give the model the biggest freedom, you could let it decode in other places first: it could decode "mat" first, which would largely determine the rest of the sentence, whereas in the left-to-right case the model already has to have in mind what it wants to say later, like the fact that "mat" comes at the end, in order to produce all the words before it. Decoding the constraining word first lets the model impute the rest. All of this is just to show you that left to right is not the only way to decode.
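To make the order question concrete, here is the factorization being gestured at, written out. This is standard probability, not anything specific to the paper's notation: the left-to-right chain rule is just one of many equally valid factorizations, one per permutation of the variables.

```latex
% Left-to-right (GPT-style) factorization over D variables:
\log p(x) = \sum_{t=1}^{D} \log p\big(x_t \mid x_{<t}\big)

% Order-agnostic factorization: for ANY permutation \sigma of \{1,\dots,D\},
\log p(x) = \sum_{t=1}^{D} \log p\big(x_{\sigma(t)} \mid x_{\sigma(<t)}\big)
```

The paper trains a single network to serve as all of these conditionals at once, with the permutation sampled uniformly.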
Even more so, consider something like Image GPT. With a GAN you'd produce the whole picture at once, but in something like Image GPT you start at the top left and simply produce the pixels left to right, top to bottom. That's it, and there is not really a reason why this is the best order in which to produce things; it's simply that we train this way, and that means we have to predict this way. What the autoregressive diffusion models do is say: we're going to train a model that can produce a sample in any order, and it doesn't matter which one. So we could start off with this pixel, then go to that one, then ask for another. We can even ask the model something like "which one do you feel best about?", i.e. which position it is most sure about, and the model can tell us, and that's the one we decode next. We can also tell the model to decode, say, three pixels at a time, then the next three pixels, and so on. That's the trade-off I mentioned.

So this is how it looks in practice. The vector here is your sample, and usually you would decode top to bottom, which is the analogue of left to right. In this model, however, the sample starts out empty: nothing is decoded yet. You have your neural network, your predictor, which predicts a distribution for every single item in the sample. These are categorical variables, so if the entries are pixels, each of them gets a predicted distribution over colors. A prediction is made for the whole image, not just for the position you want to decode. After that, you decide on one position that you actually want to decode; you sample from its distribution, or take the maximum class, or whatever, and then you continue. In the next step you have the same sample, except that one of the values is now already decoded and the others are still empty. Again you use the neural network to predict a distribution for the entire image. For technical reasons, even the already-decoded position is predicted again, though it doesn't need to be, but the important part is that you predict the entire image at once and then decide again which single position to decode. Specifically which positions you decode is given by this sigma: sigma is a variable that stands for a given permutation. So before you sample, you can select a permutation, saying "here is the order in which I want to decode", and then you decode according to that. In my mind it doesn't even matter if you decide on the fly; you can choose your desired order as you go.

Now, if this seems familiar to you, if you think you've seen a model like this before, and you're thinking of BERT, you would be sort of correct.
Even the paper says this is kind of like taking the BERT model and stacking it, or rather repeating it. Notice that this is always the same neural network: the same network makes the prediction at every single step. That's why it's an autoregressive model, because you feed the output back into the same network again. So what do you do in BERT? You have a sentence, "a cat sat on ...", and if you do masked language modeling, you put it through the neural network and out comes one output per token. When you train BERT, you mask some of the tokens, for example this one and this one, and BERT predicts the masked tokens at once. Each prediction is a categorical distribution, a classification into your vocabulary: which word was masked here? So BERT needs to infer, from the words that exist, what other words could be there.

Now the question is, of course: why do we even have to do this in a particular order? If we're already predicting all pixels at once, and the network already predicts a categorical distribution for each pixel that's not yet there, why can't we just sample all of them? The answer is that these things are not independent. Say I have a bunch of variables, and every node that is not filled in yet gives me a distribution; suppose two pixels, or two elements, are not filled in yet. I take my input and use it to predict, for each of these two pixels, the distribution of values that could be there. For the first one, maybe class number one is really popular, number two not so much, number three a little bit; for the other one, maybe number one is also popular, number two a little bit, number three not that much. Now, if those two were independent, we could totally fill both in at the same time. But they might not be, and pixels typically aren't independent if they're in the same image: if this pixel is blue, that is not independent of whether the pixel right next to it is blue, and that doesn't only hold for adjacent pixels but also for pixels further away (the further apart, the less dependent they probably are, but still). I can't just sample both independently. In order to sample one, I need to know what the other is; I don't just need its distribution, I need to commit to one of the outcomes before I even try to sample the other one. And by committing to one, I actually change the distribution of the other, because the predicted distribution assumed the other pixel would merely follow its own distribution; once it's sampled, it's no longer a distribution, it's one particular value for sure, and that in turn changes the distribution of the remaining node. So what I want to do is put the whole thing through the neural network again in order to get the true distribution of that remaining node. Maybe it was really likely that class number one was the right choice, but now that the model sees that the other node has chosen number one, it says: I'm probably not number one, I'm class number two.
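As a small formal restatement of that argument (my notation, not the paper's): the per-position outputs are marginals given the observed context, and marginals generally don't multiply to the joint, so you have to commit to one value and re-condition.

```latex
% For two still-masked variables x_i, x_j given the observed context x_o, in general
p(x_i, x_j \mid x_o) \;\neq\; p(x_i \mid x_o)\, p(x_j \mid x_o)

% so sampling them independently is only correct if they happen to be conditionally
% independent. Committing to x_i first and re-running the network is exact:
p(x_i, x_j \mid x_o) \;=\; p(x_i \mid x_o)\, p(x_j \mid x_i, x_o)
```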
I hope this makes it a bit clearer: even though we can train in BERT style, predicting all the missing things at once, we cannot decode all of them at once, because some (or all) of the elements depend on the other elements, and being dependent means they need to know what the other elements are before they themselves commit to one of the classes of their distribution. And that's the whole point: these models train like BERT, but they decode like autoregressive models, except that the order isn't fixed; the order can be any order you want.

They do actually apply this to text, so you can see how this looks. Here is a character-level language model. It starts off with a relatively empty sentence: the underscores are variables that are not chosen yet. Then it fills in a bunch of characters, then some more, and some more. You'll notice that all of the characters that already existed should still exist; let me check... the x still exists, this i still exists, yeah, okay, all the ones that were there are still there, there are just more of them now. More and more are imputed until you finally arrive at the fully imputed sentence. These are actual samples from their model, and on character-level text it's not yet super good; the sentence doesn't really make sense, and I don't think that's actually an English word; it sounds English, but it may not exactly be one: "a potentially unsucked proof or inject operational weapons in the game car us individual model". So yeah. Because these are the beginnings of these types of models, it's unclear whether that's just early days or whether it's simply a much better objective to train autoregressively from left to right, because there are also trade-offs. If you predict every single thing at once, your loss function has to split between all the things there are to predict, whereas if you just train left to right, your loss function can focus fully on what the next token is in the given order. So you gain the ability to decode in any order you want, but that comes with a performance trade-off, because a model that specializes in one particular order will always beat you on that order.

So let's go back, because I think that's the entire point, and I've found you can simplify this quite a bit by saying: this is BERT training, but you decode one variable after another. I'm pretty sure you could even take the pretrained BERT checkpoints and decode like this. The problem is, of course, that those BERT checkpoints have been trained with a fixed percentage of tokens masked out, usually something like 10 to 20 percent, whereas in order to really get these models to produce samples from scratch, they also have to have seen cases where 100 percent of the tokens are masked.
The way you train this is you mask tokens like BERT and then predict all of them at once, so the model has to have seen every proportion of masked tokens. That's not exactly what BERT is trained for, but in essence you could do it.

So what's the background? The background is that these models usually say: look, the whole sample has a given probability, and I can decompose that probability, via the product rule, into a product of conditionals, or in log space a sum. This part is what autoregressive models use: if I have a bunch of nodes, the probability of a given node is conditioned on everything before it, so I can factorize the joint into factors where every probability is conditioned on the ones before. These models go and say: well, there's no particular reason why you have to factorize in this left-to-right way; you can in fact factorize in any order you want. And once you recognize that, you can also say that you don't only train in the order that you decode in; you can train for all orders at once. If my chosen order goes through these nodes and I'm currently at the purple node, then in this particular order I would go to one specific node next, but in many other orders that share the same history I would go to a different node next, and in yet another order to yet another node. These orders are sampled uniformly, so I can reasonably assume that the next time I see this sample, I'll be in one of those other orderings, and therefore the expectation of my loss is just the average over predicting this node, or that node, or that node, at this time. So why wait for the next samples? I can simply predict all of them at the same time right now and take the mean classification error as my loss, rather than only the error on the single next variable of the order I happen to be in. Left-to-right models don't need to do that, because they are always left to right: the next time they see the sample, they will have to decode exactly the same next variable. These models, however, are trained to work in arbitrary orders, and therefore we might as well predict all remaining variables at once and take the mean as the loss.

And there again you see the trade-off. This allows us to decode in any order we want, but only a fraction of one over the number of remaining nodes of the loss is really trained on the order we're eventually going to use; all the other terms are essentially superfluous. Well, they might help a bit with generalization, but you significantly reduce the loss mass on the order that you actually care about in the end.
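Written out, this is roughly the objective (my paraphrase; the exact constants in the paper's likelihood bound may differ slightly). At a step t of a uniformly sampled permutation sigma, the loss averages the negative log-likelihoods of all still-masked positions:

```latex
% With D variables, permutation \sigma and step t, the decoded context is x_{\sigma(<t)}
% and the still-masked set \sigma(\geq t) has D - t + 1 elements:
\mathcal{L}_t(x) \;=\; \mathbb{E}_{\sigma}\!\left[
  \frac{1}{D - t + 1} \sum_{k \in \sigma(\geq t)}
  -\log p_\theta\big(x_k \mid x_{\sigma(<t)}\big)
\right]

% Training also samples t uniformly, so in expectation every order is covered, but only
% a 1/(D - t + 1) fraction of each step's loss lies on the one variable that any
% particular order would actually decode next.
```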
When you sample, here's how it works; it's pretty simple, it's what I said. You initialize x empty and you sample one order uniformly. As I said, you don't have to commit to one order at the beginning, but that's how it's specified. Then you go through the ordering: sigma is the permutation over the nodes to decode. It's written in a fairly complicated way: they build these masks, where m is essentially whatever has been decoded so far and n marks the one node to predict right now. You build a categorical distribution by putting the masked sample, m times x, into your neural network. That masked sample is what you've predicted so far; the network, which is the learned part of this, outputs a categorical distribution for every remaining node. Then you take n, the entry in the ordering that you chose, pick the one node you want to decode, and simply amend the sample with it. That's all it is, written in a very complicated way.

Training these models isn't too hard either. You take a data point, which I guess you sample from the data set, and you sample one particular time step. Notice that at sampling time we go over all the time steps, because we actually want a complete sample, but the individual training example is just one particular time step in one particular ordering, much like with transformer autoregressive models (though there you can actually train all time steps at once). So we select an ordering and, within that ordering, a time step. For a picture, what this amounts to is: we mask a bunch of the pixels, we black them out, and that corresponds to some time step in some ordering. We assume we have already predicted all of the pixels that we haven't masked, and now we try to predict all of the masked ones at once. You'll notice there is no n here: the n specified the one pixel to predict next during sampling, but during training we simply mask out a bunch of pixels and predict all of them at once. Again we have m, what has been "predicted" so far, we input m times x into the neural network, the network outputs a distribution for every position we haven't predicted so far, and rather than selecting n from it we now select 1 minus m, i.e. everything that hasn't been predicted so far, average the loss over those positions, and that becomes our loss function. Since we know what the pixels are that we masked during training, we can actually compute this loss. And that's it, that's how you train. Pretty simple, and as I said, it should remind you of BERT.
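Here is a minimal sketch of what I just described, in PyTorch-style code. This is my own illustration, not the authors' code: the callable `model(x_masked, mask)` is an assumption and is supposed to return per-position logits of shape (D, K), and masked values are simply zeroed out rather than given a dedicated mask token.

```python
import torch
import torch.nn.functional as F

# Assumptions (mine, for illustration): x is a LongTensor of D categorical variables
# with values in {0, ..., K-1}; model(x_masked, mask) returns logits of shape (D, K).
# A real implementation would batch this and use a dedicated mask token / channel
# instead of reusing class 0 for "not decoded yet".

def sample(model, D):
    """Order-agnostic sampling: decode one variable per step in a random order."""
    sigma = torch.randperm(D)                   # sample a permutation uniformly
    x = torch.zeros(D, dtype=torch.long)        # placeholders for undecoded positions
    mask = torch.zeros(D)                       # m: 1 where a value has been decoded
    for t in range(D):
        n = sigma[t]                            # the one position to decode this step
        logits = model(x * mask.long(), mask)   # predict a distribution for EVERY position
        probs = F.softmax(logits[n], dim=-1)    # ...but only commit to position n
        x[n] = torch.multinomial(probs, 1).item()
        mask[n] = 1.0
    return x

def training_loss(model, x):
    """One training example: a random order and step, BERT-style masked prediction."""
    D = x.shape[0]
    sigma = torch.randperm(D)
    t = torch.randint(1, D + 1, (1,)).item()    # step index in 1..D
    mask = torch.zeros(D)
    mask[sigma[: t - 1]] = 1.0                  # m: the t-1 positions "already decoded"
    logits = model(x * mask.long(), mask)
    still_masked = mask == 0                    # 1 - m: everything still to predict
    # mean negative log-likelihood over all still-masked positions
    return F.cross_entropy(logits[still_masked], x[still_masked])
```

The control flow is the whole point: predict everything, commit to a single position at sampling time, and average over all masked positions at training time.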
Now, they have several extensions that I just briefly want to touch on. First they say: if we allow a few of these independence mistakes, then, given that we have, I don't know, a million pixels in an image, can't we just assume that a pixel up here and a pixel down there are roughly independent of each other? So couldn't we sample them at once? We can sample multiple pixels at once if they're far away from each other and we're fine with the approximation. By predicting multiple pixels at a time we trade off speed against accuracy, because the pixels predicted in the same step have no knowledge of the other pixels decoded in that same step; that's the problem we talked about before.

Then they go a step further and say: rather than deciding up front that we want to decode, say, five pixels at a time instead of one, we're going to give the algorithm a budget. You have an entire image and, say, 20 steps, and you need to decide how to spend them; this is the visualization right here: maybe one pixel first, then two pixels, then three, then five, then the rest, and those are your time steps. That's your budget, and you decide. They use a dynamic programming algorithm for this. As far as I understand it, they go through their training data and compute what they call loss components: one axis is your budget, the other is the number of variables in the data point, and an entry says, for example, for step number three, if I were to decode five variables at step number three, how much would that cost. Then you find, in classic dynamic programming fashion, a path through this matrix, and at the end this path tells you how many pixels you should decode at which step. For example here, in step one we decode two, then we decode one, then... I don't know what this actually means, one, no, zero? That makes no sense. And then we decode the rest. But you know how dynamic programming works, and this figure is actually from a different paper; the point is that, given that we train for any order at all and predict everything at the same time, this is an option, so you can trade this off.

The other thing they do is depth upscaling. The idea is this: if we're trying to predict a pixel value, that value is one of 256 classes, which is a big thing for the model to commit to immediately, as in "that's my pixel value". What if we could instead have the model just predict which half of the pixel values it's in: are you bright in the blue channel, or are you dark? We do this for all the pixels, so in the first iteration every pixel simply decides "am I light or am I dark", and so on. Once everyone has decided on that, we go over the image again and say: okay, you, pixel, who previously decided you were light, now that you see all the other pixels and their crude decisions, which sub-part of "light" do you fall into, are you very light or just a bit light? So we go through the image multiple times, and it can even be in different orders. The advantage is that you first let the other parts make crude decisions, so you don't have to decide out of the blue: you know approximately what all the others are before you refine, and then you refine, and refine, until you get to the final choice. I think this is a neat idea, and they specify exactly how to do it.
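As a toy illustration of that coarse-to-fine idea (my own sketch, not the paper's exact scheme): an 8-bit pixel value can be decomposed into a sequence of binary "which half am I in" decisions, i.e. its bits from most to least significant, and each upscaling stage then only has to model one extra bit of every pixel given all the coarser bits already decided.

```python
import numpy as np

# Toy decomposition of 8-bit pixel values into coarse-to-fine stages (my illustration).
# Stage 0 answers "light or dark?" (the most significant bit), stage 1 refines within
# that half, and so on. An ARDM-style model would be run once per stage, decoding the
# pixels within each stage in any order it likes.

def to_stages(pixels: np.ndarray, n_stages: int = 8) -> list:
    """Split 8-bit values into per-stage bits, most significant first."""
    return [(pixels >> (7 - s)) & 1 for s in range(n_stages)]

def from_stages(stages: list) -> np.ndarray:
    """Recombine the per-stage bits back into full 8-bit values."""
    pixels = np.zeros_like(stages[0])
    for s, bits in enumerate(stages):
        pixels |= bits << (7 - s)
    return pixels

img = np.array([[12, 200], [255, 0]], dtype=np.uint8)
stages = to_stages(img)
assert np.array_equal(from_stages(stages), img)  # lossless coarse-to-fine decomposition
```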
However, I can't help noticing that the ordering here, where you first predict the crude part, then the less crude part, and finally the full detail, is again a fixed-order autoregressive model. This is exactly what they're trying to get away from, and they just reintroduce it in a sub-part of their model, which I find funny. On the other hand, and this is my other problem with it, this only really works if the variable isn't truly categorical. A pixel value is a continuous variable that we merely discretize, and that's why this works: you can decide on a crude value and then go more and more detailed. If you have a true classification, say into tokens of a vocabulary like a, b, c, d, e, it makes no sense to ask the model "which half of the alphabet are you in"; it can't make a crude decision, because it already needs to know the answer to give you that. So unless you have a way to split the vocabulary in a meaningful fashion, this doesn't make sense. It's really a workaround for the fact that they need categorical variables for their model and therefore discretize the brightness of the pixels, and this trick is a result of that.

In any case, I don't want to dive too much into the results; you've already seen them. They don't do large scale, as far as I can tell; they do CIFAR-10 generation and they also do lossless compression. What they can do with their model is get a pretty good handle on the trade-off, so it gives the user of the model a good way of trading off performance for speed, and you can do this on the fly: you can say "I want less performance, I want more performance, I have less of a budget to infer this sample, or more", and you can change that from time to time. And yeah, these models, as I said, are young, therefore they have a way to go. We've put so much work into GANs and autoregressive text models that the fact that these here are not state of the art yet might just be an artifact of that, or they might just suck, who knows.

All right, thank you so much for listening. As I said, join our Discord to get in on the paper discussions; they're usually very entertaining. And I'll see you next time, bye.
Info
Channel: Yannic Kilcher
Views: 26,394
Keywords: deep learning, machine learning, arxiv, explained, neural networks, ai, artificial intelligence, paper, diffusion models, autoregressive models, generative models, nlp, natural language processing, gpt, image-gpt, gpt-3, gpt-2, order agnostic, order agnostic diffusion, generative diffusion models, bert, autoregressive bert, bert text generation, character level language model, upscaling, dynamic programming, pixelwise sampling
Id: 2h4tRsQzipQ
Length: 34min 23sec (2063 seconds)
Published: Wed Nov 10 2021