Stable Diffusion from Scratch in PyTorch | Unconditional Latent Diffusion Models

Captions
Hi everyone. In this video we'll be building Stable Diffusion. We'll work our way towards the Stable Diffusion model in two parts. In the first part, which is this video, we'll understand and build latent diffusion models (LDMs), specifically unconditional latent diffusion models. We'll get into all the components present inside them, why each of them is needed, and how they let us get higher-quality outputs. We'll then go through the implementation, how all of it looks in code, and the results we get from training it on MNIST as well as on a dataset of high-quality face images. In the next part, we'll first turn our LDM into a conditional LDM, where we'll condition on images, text, and classes, get into classifier-free guidance, and finish by seeing how all of this comes together to create Stable Diffusion. So let's go.

I've already covered the theory and math behind diffusion models, and the implementation of DDPM, in the previous two diffusion videos on this channel. I'd suggest checking them out before this video if you aren't already familiar with them, also because I'll be reusing a lot of the implementation I did for DDPM. Here's a quick recap of the major details. Diffusion models have a forward process in which we take an image and add a small amount of noise to it; repeating this over a large number of steps gives us completely random noise. To do this, we have a schedule which decides how much noise we add at each time step, and hence how much of the original information in the image we destroy. We then train a model to do the reverse: given a noisy image at time step t, predict a slightly denoised version of it, which is the image at time step t − 1. If we succeed at this, we can generate an image by starting from random noise and denoising it step by step. For training these models, we take an image, a random noise sample
from a normal distribution, and a time step t. We had seen that, using the properties of the forward process, we can directly compute the noisy image at time step t from the original image as x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε, where α_t = 1 − β_t is one minus the noise variance at time step t and ᾱ_t is the cumulative product of these alphas. After simplifying the math, we concluded that learning to predict the slightly denoised image at t − 1 from a noisier image is equivalent to training the model to predict the original noise. So we built a model with a U-Net architecture which, given a noisy image, predicts the noise that was added, and trained it using the MSE between the predicted and the actual ground-truth noise. For inference, we start with a random noise sample, denoise it a bit to get the image at t − 1 using the trained model's noise prediction, then denoise again to get the image at t − 2, and continue all the way to 0 to get our generated image. In our DDPM implementation video, we trained on 28 × 28 MNIST images and on 56 × 56 images from a dataset of textures, and ultimately got the diffusion model to generate images that indeed look like digits and textures.

So, given that diffusion models work fine, why do we need latent diffusion models? If we are training on small images like 28 × 28, or even 56 × 56 when the dataset isn't very diverse, we are okay. But as we increase the image resolution or the diversity of the target data, training costs grow rapidly. For example, the DDPM authors mention that training on CIFAR-10 with the equivalent of 8 V100 GPUs takes about 11 hours. In another paper, the authors mention that training on a dataset of bedroom images at 64 × 64 and 256 × 256 until convergence takes an enormous amount of time, and this is still with multiple GPUs. So we clearly need something different to even think about training diffusion models on larger
resolution images on a single GPU. The authors of this paper, the paper that introduced latent diffusion models, propose the following to reduce this cost: instead of training diffusion models in pixel space, they train them in the latent space of pre-trained autoencoders. This enabled them to reduce the training cost while maintaining high quality of the generated samples. Effectively, we can think of latent diffusion models as first training an autoencoder, where the encoder takes the image in pixel space and generates a lower-resolution latent representation of it, and the decoder takes this latent image and generates the image at the original resolution. The entire autoencoder is trained using an L1 or L2 reconstruction loss, and the authors experiment with both a VAE and a VQ-VAE as the autoencoder. Training a diffusion model in this latent space then lets us take a noise sample and generate an image in the latent space, which can be fed to the trained decoder to produce an image at the higher resolution. Unfortunately, simply training the autoencoder on this reconstruction loss will not let us generate high-quality images. We'll soon see why, but first let's quickly review both autoencoders.

In a VAE, the encoder, given an input image, generates the mean and variance of the encoder distribution of this image in the latent space. The decoder then receives a sample from this distribution and is responsible for generating the original image. The entire VAE is trained using a reconstruction loss and a KL divergence loss, which keeps the encoder distribution as close as possible to the prior normal distribution and also allows smooth transitions in the latent space. In a VQ-VAE, rather than a continuous latent space like the VAE's, we have a discrete latent space: a codebook of size K, meaning a set of K discrete latent vectors. We take the encoder output feature map and replace each grid cell with its nearest
codebook vector. The decoder then takes this quantized encoder output and learns to reconstruct the original image from it. The VQ-VAE is trained using a reconstruction loss, again L1 or L2, plus a codebook loss and a commitment loss, which are meant to bring the codebook vectors and the encoder feature map vectors assigned to them closer to each other. Feel free to look at the linked videos on these autoencoders if you need further details. For our autoencoder we'll pick the VQ-VAE, but it doesn't really matter, because the two differ only in a small part of the implementation; everywhere I say VQ-VAE, you can assume VAE as well.

Now, coming back to why the reconstruction loss isn't going to be enough. When we train a VQ-VAE using an L1 or L2 loss, the reconstructed image matches the content of the original image, but its quality is not the same: the reconstructions are blurry, and high-frequency information ends up missing. In fact, this is exactly what we saw in our VQ-VAE video. The reason is that even between images where a human would see a perceptual difference, say here because of the blurriness, the L2 loss can still be very low. So if we train diffusion models on autoencoders trained only with a reconstruction loss, our generated images will also be blurry; while we would have reduced the computation requirement, we would also have compromised on image quality. We therefore need an additional loss term that captures these perceptual differences, which L2 does not, and the authors add a perceptual loss for this.

So what exactly is this perceptual loss term? The features of a deep convolutional network, even though it was trained on a classification task, say on ImageNet, turn out to be very useful as a representation for a wide variety of tasks. For example, take style transfer, where the goal is to take a content image and a style image and
generate an image whose content is similar to the first and whose style is similar to the second. To do this we use a pre-trained network, say VGG. Forget about style for now and focus only on matching the content: to generate such an image, we minimize the distance between the feature maps of the two images at different layers of a pre-trained, frozen network. That essentially means the features of a pre-trained network can be used to get some sense of image similarity or dissimilarity, and this measure of distance in the feature space of a deep network is the perceptual loss.

That's all great, but the L2 reconstruction loss will anyway ensure that the content matches, so what's the point of adding this? The authors of this paper found that, when comparing against human judgments of which of two images is more similar to a reference image, decisions made using the features of a pre-trained supervised network agree with humans more than traditional metrics do; specifically, they are better for evaluating perceptual similarity than L1 or L2. As you can see in the last set of images, between a slightly distorted but still very similar patch 0 and a very blurry patch 1, L2 is lower for the blurry image, but using the features of a pre-trained network we reach the right decision that patch 0 is indeed the more similar one. We want to penalize our autoencoder when it generates an image that has low L2 but is still perceptually very dissimilar.

While pre-trained VGG features are already decent at measuring perceptual distance, the authors further improve the network's ability to measure it. They collect a dataset of triplets (a reference image, patch 1, and patch 2) and train a model which, given a pair, outputs a measure of the perceptual distance between the two images. To train this model, they take the feature maps of a pre-trained network for the reference image and, say, patch 1, and then for each pair of feature maps compute the
squared difference between them. After that, they add a 1 × 1 convolutional layer with a single output channel, which gives a single-channel feature map. These values are then spatially averaged to get a distance for the l-th layer, and the per-layer distances are aggregated into a single value, d1. Doing the same between the reference image and patch 2 gives d2. They then put a very small network on top of these distance measures so that the entire model can be trained on the collected dataset. In one setting, which we will also use, the pre-trained network is frozen and only the 1 × 1 convolutional layers are trained. Effectively, we collapse the channels and learn, for each layer, how important each channel's difference is for evaluating perceptual similarity.

The way we use this network is to feed it the input image and the autoencoder's reconstruction as a pair; it then gives us a perceptual distance d. For two identical images d is zero, and the higher d is, the larger the perceptual difference. This metric is called LPIPS, and it is exactly the perceptual loss component used in latent diffusion models: when training the autoencoder, we add the network's output to the loss as well. Together with reducing the L2 loss, the autoencoder is pushed to generate reconstructions with lower LPIPS, and hence perceptually more similar ones. The LPIPS network itself is, of course, frozen. Here is an example of LPIPS between original images and their blurry versions. We have only used blurriness as an example, but the network is trained on a dataset with many different distortions, so adding the perceptual loss encourages the reconstructions to be free of all of them. I'd encourage you to read the paper if you want more details.

Let's quickly look at the LPIPS model code as well. This is basically
copied from the official repo provided by the authors of the LPIPS paper. The first thing we have is the VGG class, where we load a pre-trained VGG from torchvision. The VGG network is divided into five stages, and these slices are just the feature layers in each stage. We freeze the VGG parameters, and in the forward method we simply return the feature maps at the end of all five stages.

LPIPS is the main model class. The scaling layer here is only meant to achieve ImageNet normalization: LPIPS expects images in the range −1 to 1, and this layer converts them so that the result matches ImageNet normalization. You can see that the ImageNet standard deviations are exactly half of this layer's scale values, and the shifts are related to the ImageNet means as 2 × mean − 1. After that, we use our pre-trained VGG model, and these are the channel counts of the final feature maps of those five stages. Then we add these linear layers, which are exactly our 1 × 1 convolutions that collapse the channels, hence one output channel each. We load the weights; this weight file contains only the 1 × 1 convolution weights.

In the forward method, we first scale the images to −1 to 1 if needed and then use the scaling layer to normalize. We get the VGG feature maps for both images and compute the squared difference between them, for the final feature maps of all five slices. After that, as we saw, a 1 × 1 convolution collapses the channels, and we spatially average. We then aggregate the measure over all layers, and here the aggregation is simply a sum. Because the forward method accepts two images as input, we can feed it the input image and the reconstruction, and the network returns the perceptual distance between them. So remember: LPIPS is just VGG plus linear layers indicating how important each channel's difference is for perceptual distance.

For LDMs, apart from the perceptual loss, we also have an
adversarial loss. We can think of the entire autoencoder as a generator that generates samples, here reconstructions. The generated reconstructions can be viewed as fake images and the input images as real samples, and we can then have a discriminator whose job is to identify which is real and which is fake, that is, which is the input image and which is the reconstruction. The generator's goal is to fool the discriminator by making its reconstructions as close to the input samples as possible, while the discriminator tries to get better at telling inputs and reconstructions apart. We have already seen the math behind this setup in our GAN videos, and we saw how, over time, the generator keeps producing better and better samples, or here, reconstructions.

Let's try to understand the benefit of this adversarial loss. Obviously, it penalizes the autoencoder when it generates blurry reconstructions, since the discriminator can use blurriness to tell inputs and reconstructions apart. In our DCGAN video we saw that the generated digits were of very high quality, whereas in our VQ-VAE video the reconstructions were a little blurry. So that's the image-clarity benefit, but there's also another one. The decoder has to generate larger images from a compressed version, so it has to learn to generate the high-frequency textures that are missing from the latent image, and the presence of the discriminator pushes it to generate the right things so that the discriminator can't distinguish which is which. Take a hypothetical scenario: our dataset is images of faces, and our reconstruction gets everything right except the eyes. The reconstruction and perceptual losses would be very low, given that almost the entire image is exactly the same, but a discriminator could latch onto this distinguishing factor between input
images and reconstructions and use it to classify reconstructions as fake. It's the presence of this adversarial loss that leads our autoencoder to learn to generate better reconstructions, leaving no distinguishing factor for the discriminator to pick up. Obviously this was a very hypothetical example, but hopefully it gives the intuition. So ultimately we train our autoencoder with these three losses: reconstruction, perceptual, and adversarial.

Now let's look at the architecture of the autoencoder. For this I've tried to mostly mimic the architecture used in Hugging Face diffusers, which is also very similar to the official Stable Diffusion implementation. The building blocks are exactly the same as what we used for DDPM, namely the ResNet and self-attention blocks. Our ResNet block was a couple of norm, activation, and convolution layers with a residual connection, and the self-attention block was normalization plus self-attention, again with a residual connection. For diffusion we also had to add time step information, which we added between the two convolutions of the ResNet block. Our DDPM U-Net had down blocks, which were layers of ResNet blocks followed by self-attention, with a downsampling at the end. The output of the last down block was passed to the mid block, which was one ResNet block followed by layers of self-attention and ResNet blocks. The mid block output was then passed to the up blocks. An up block first upsamples its input, then concatenates it with the down block output at the same resolution, followed again by layers of ResNet block plus self-attention. We added the time step information to all the ResNet blocks, after projecting the positional embeddings through FC layers, and as we saw, it was added between the two conv layers.

We'll create our autoencoder using the exact same blocks, with minor changes. First, we obviously don't need any time step information, and we also won't concatenate down block outputs with the up blocks. Looking at the encoder first, it will be a series of
down blocks that downsample the input to the resolution at which we want to train our diffusion model. So if our images have height and width 256 and we want to run diffusion at, say, 32 × 32, we'll have three down blocks, each halving the resolution. Since the down blocks deal with high-resolution feature maps, we won't have any attention in them. After the down blocks, the encoder also has a mid block, and here we keep the attention, because the down blocks have already reduced the resolution. Our decoder will first have layers of mid blocks, followed by layers of up blocks, and again we remove the attention in the up blocks. In the encoder, before the first down block, we add a convolutional layer that takes the input from three channels to the number of channels the first down block expects, and after the last mid block we add normalization, activation, and convolutional layers to go from the channels of the last mid block to the latent dimension. On the decoder side, likewise, a convolution layer takes the input from latent-dimension channels to the channels the first mid block expects, and after the last up block we again have normalization, activation, and convolution to produce three output channels, or however many our images have.

Up till now, these components are exactly the same no matter which autoencoder we choose, but the VAE and VQ-VAE differ somewhat in the encoder implementation, so let's see that. For the VAE, instead of the encoder producing an output with latent-dimension channels, it produces twice as many. We split it into two chunks along the channel dimension: the first latent-dimension channels act as the mean, and the second as the variance, or rather the log variance. We then use the reparameterization trick to generate a sample z, and this sample z is what is fed to the decoder. For the VQ-VAE, we quantize the encoder output. We'll have the latent
dimension as the number of channels, but we'll also have a codebook, which is an embedding of K vectors, each with the same dimension as the latent dimension. We replace each of the H × W latent vectors in the encoder output feature map with its nearest codebook vector, which gives us the quantized encoder output, and this quantized output is what is fed to the VQ-VAE decoder.

Let's see how all of this looks in code. As I mentioned, this uses the same down, mid, and up blocks that we created for DDPM, but let's take a quick look at them anyway. First, the down block, which was layers of ResNet and self-attention. The ResNet block had two convolution layers with the time information in between, and a residual connection from the input of the first to the output of the second. This is the first set of conv layers of that ResNet block, and this is the second. The time embedding layers are only added if needed; these conditions let the same down block be used both in diffusion, where time step information is passed, and in the autoencoder, where it isn't. We then have the normalization and self-attention layers, again behind a condition, because for the autoencoder's down blocks we don't want attention. This is the 1 × 1 convolution for the residual connection; it's only required when the input channels differ from the output channels, but I've used it for all residual connections. And then our downsampling. The forward method simply goes through these layers of ResNet and self-attention, with the downsampling at the end. The mid block has the same initialization; the only difference is that we first have one ResNet block and then layers of self-attention and ResNet. The up block is also the same as the down block, with the same ResNet and self-attention; the only difference is that instead of downsampling at the end, we have upsampling at the beginning. Also, for the diffusion U-Net we concatenated the current up block feature map with the down block
output of the same resolution, but in our autoencoder we can't do this, because at generation time the decoder receives only the latent and has no encoder outputs to use; this check is just for that purpose. After that come ResNet and self-attention, again with a condition, because for the autoencoder's up blocks we don't want attention, in order to reduce computation; otherwise we would have to compute self-attention over something like 256 × 256 spatial positions.

Now let's look at our VQ-VAE class. We first specify the down channels and mid channels. The code is written so that the encoder uses these channels and the decoder uses the same channels in reverse, for both the down blocks and the mid blocks. The downsampling parameter is there in case we need to skip downsampling somewhere, but by default we downsample in all the down blocks, and the up blocks again do the reverse: if only the first two down blocks downsample, then only the last two up blocks will upsample. For each block we use two layers, meaning, for example, that a down block has two layers of ResNet and self-attention. We use group norm, and this is just its parameter. Since we don't want attention here, this is the specification for the down blocks, and the same is used for the up blocks. The number of attention heads is four. For the VQ-VAE we need to specify the latent dimension, which we set to three, and also the codebook size, i.e. how many discrete latent vectors the codebook contains; for the high-quality CelebA-HQ faces dataset we set it to 8192. In the repo all of this is configurable, so you can change it based on your dataset and the compute available to you. And as I said, the upsampling order is exactly the reverse of the downsampling order.

Now let's get into the encoder. Remember, our encoder was first a conv layer to take the input to the number of channels the first down block expects, then layers of down blocks and mid blocks, and finally
we have convolutions that bring the output channels to our latent dimension. This conversion uses two convolutions instead of one: a 3 × 3 followed by a 1 × 1. This is our codebook, which is 8192 × 3. Our decoder again has a convolution that takes the input from latent-dimension channels to what the mid block expects, then layers of mid blocks and up blocks, and this final convolution generates the reconstructed image. Our forward method first encodes the input, then decodes it, and finally returns the reconstruction. We'll soon see what z and these losses are, but let's look at the encode method first, which just calls the down block and mid block layers, and at the end calls this quantize method.

As I mentioned, for the VQ-VAE we replace each vector of the encoder output feature map with its nearest codebook vector, and that's just what we do here: find the nearest codebook vector for each of the H × W latent vectors and replace each one with it. This section is needed because we used argmin, which makes the whole operation non-differentiable. So for backprop we use straight-through estimation: in the forward pass this line is just quant_out = quant_out, but in the backward pass it acts as quant_out = x, so the gradients from the loss with respect to the quantized output are passed directly to the non-quantized encoder feature map. Because of this straight-through estimation, though, the codebook won't get any gradients from the reconstruction loss, so we have the codebook and commitment losses to bring the codebook vectors and the encoder outputs assigned to them closer to each other, similar to how you would do clustering. I've talked about this in much more detail in my VQ-VAE video, so please check it out if you have trouble understanding what exactly is happening here and why. The decoder then simply goes through the mid block and up block layers and the output convolutions to generate the reconstruction. The VAE is
exactly the same as the VQ-VAE, at least in terms of architecture, but as we saw earlier, instead of outputting latent-dimension channels, the encoder outputs twice as many. In encode we then extract the mean and log variance and use reparameterization to generate a sample, which is fed to the decoder; everything else is exactly the same.

Now we get into the training code for the autoencoder. To train it we need three models. We have already seen the VQ-VAE and LPIPS; for the adversarial loss, the generator is our autoencoder, but we also need a discriminator, so let's look at that. We'll use a PatchGAN discriminator, which is the same as the DCGAN discriminator we implemented in the last video, with a minor change: instead of one scalar output giving the probability that the discriminator thinks the image is real, it now outputs a score for each patch, indicating whether it thinks that patch is real or fake. So rather than the ground truth being a single 1 or 0, it has the same shape as the discriminator output, with every cell set to 1 for real (input) images and 0 for fake images (reconstructions), and each cell of the output says how likely the discriminator thinks the corresponding region of the image is real. We use convolutional layers that take their parameters as arguments; in our case there are four convolutional layers, followed by batchnorm and a LeakyReLU activation. So it's exactly like our DCGAN setup in terms of layers, except that it can now handle images of any size.

Back to the training code. We create our VQ-VAE model, instantiate the LPIPS model, which is of course frozen, and instantiate the discriminator. Here we just create the dataset class. For the reconstruction loss we use MSE, and for the GAN part we use BCE with logits, but you can also use MSE, as in least
squares GAN; that also works. As in the GAN setup, we have separate optimizers for the generator and the discriminator, where the generator is the entire autoencoder. This disc_start specifies after how many steps the discriminator kicks in. This is present in the official repo as well, and I also found empirically that training is more stable if the discriminator isn't started immediately: we first reach a point where the autoencoder produces the best reconstructions it can, which will be blurry, and then the discriminator kicks in and improves their quality. But this is something you'll have to experiment with on your dataset. So disc_start just makes the adversarial loss component get added only after a certain number of steps.

Then we have the generator part. Here we call the autoencoder, get the reconstruction, and compute our L2 reconstruction loss. To this we add the codebook and commitment losses; if this were a VAE, we would add the KL divergence loss here instead. The adversarial loss is only added if this condition holds, and for optimizing the generator we want the discriminator to believe the reconstructions are real, so all patches have 1 as the ground truth. Then we call LPIPS, get the perceptual distance, and add that to the loss as well. Here I've just used default weights for these losses, but in the repo you can configure each loss's weight based on your experiments. For optimizing the discriminator, we want it to classify the reconstructions as fake, so all their patches have 0 as the ground truth, while the input images keep 1. That's it for training our autoencoder for latent diffusion models.

When we train this on MNIST and on a dataset of 64 × 64 images, we see that at first the reconstructions are definitely not high quality, but once the discriminator kicks in and the adversarial game begins, they improve significantly, and we ultimately get decent outputs. I also trained it on the CelebA-HQ dataset with images of 256
× 256, so here we have a compression factor of 8. I only trained it until I started getting decent results. Because of the higher compression factor, the generator has a greater responsibility here; the reconstructions still need to improve and the autoencoder needs further training, but even with a factor of 8 we don't get blurry images. Here is what the quantized encoder output looks like. Feel free to pause and look at the level of detail present in the encoder output, which is what the decoder receives and from which it has to generate a reconstruction.

As a side note, everything we implemented here, VQ-VAE plus LPIPS plus adversarial loss, was actually introduced in VQGAN. In VQGAN, image generation is done with a Transformer: once the autoencoder is trained, we train a Transformer on sequences of discrete latent tokens, which are nothing but the token sequences obtained by passing the images in our dataset through the VQ-VAE encoder. During inference, the Transformer generates a plausible sequence of tokens, which is then passed to the VQ-VAE decoder. In our case, the role of the Transformer will be played by the diffusion model.

So let's move to the diffusion part, which is almost exactly the same as in our DDPM video. In fact, we use the same architecture, with down blocks, mid blocks, and up blocks, and with time embedding information added to the ResNet blocks. Let's quickly peek at how our U-Net model for diffusion looked in code. We instantiated the same down, mid, and up blocks, except that this time we have self-attention in all the blocks. In the forward method, we kept saving the output of each down block layer, then called the mid block layers, and while calling the up blocks we passed in the down block outputs to have them concatenated with the current up block feature map. Finally we have a convolutional layer with the same number of output channels as the input image. The only thing that changes now is that our image channels
won't be one for grayscale or three for RGB; rather, they'll be the latent dimension.

This is our DDPM training code. We create a noise scheduler, which is responsible for adding noise and for generating the slightly denoised image of the previous time step using the noise predicted by our U-Net. We create the dataset, and since our latent dimension is three, we create the U-Net with that as its input channels. This is the new part we need to add: we create the VAE or VQ-VAE, load it with the trained weights, and freeze its parameters. Finally, the epoch loop, which so far is exactly the same as in DDPM: we take images from the dataset, sample some time steps, sample noise, have the scheduler create noisy images, and then train the U-Net to predict the original noise using MSE. The only difference for latent diffusion is that these two lines are uncommented, so instead of doing everything on the input images, we do everything on the encoder output of the VAE. So from a diffusion perspective nothing changes, other than the fact that the model is trained in latent space instead of pixel space. Once it's trained, you can use the same sampling code as in DDPM; the only difference is that after getting the fully denoised latent image, you feed it to the decoder to generate the image in pixel space.

Here are sample outputs from training this diffusion model on MNIST and on the CelebA-HQ dataset. Here we actually decode at all diffusion time steps to see the progression in both latent space and pixel space. These are obviously not amazing, but I was able to generate 256 × 256 images using a single GPU, and that is possible only because of latent diffusion models, because we are playing the game in the latent space.

So up till now we have seen the components of unconditional latent diffusion models, how we train the autoencoder so that the compression does not compromise quality, and how everything looks in code. In the next video we'll add
conditioning based on text, images, and classes. We'll talk about classifier-free guidance, see how to do text-to-image, image-to-image, and inpainting, actually implement everything, and make the transition to the Stable Diffusion model. I really hope you got some benefit from this, and as always, thank you so much for watching until this point. See you in the next video.
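For reference, the DDPM recap near the start of the video can be sketched as a small noise scheduler. This is a minimal sketch, not the author's code: it assumes the linear β schedule from the DDPM paper (1e-4 to 0.02 over 1000 steps) and uses the posterior variance β̃_t in the reverse step; the names `LinearNoiseScheduler`, `step`, and `sample` are illustrative. `sample` shows the latent-diffusion twist: denoise in latent space, then hand the result to a decoder.

```python
import torch


class LinearNoiseScheduler:
    """Minimal DDPM noise scheduler sketch (linear beta schedule assumed)."""

    def __init__(self, num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
        self.num_timesteps = num_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas                       # alpha_t = 1 - beta_t
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)  # cumulative product of alphas

    def add_noise(self, x0, noise, t):
        """Closed-form forward process: x_t = sqrt(a_bar_t) x0 + sqrt(1 - a_bar_t) eps."""
        a_bar = self.alpha_bars[t].view(-1, *([1] * (x0.dim() - 1)))
        return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

    def step(self, xt, noise_pred, t):
        """One reverse step x_t -> x_{t-1} for an integer timestep t."""
        beta, alpha, a_bar = self.betas[t], self.alphas[t], self.alpha_bars[t]
        mean = (xt - beta / (1.0 - a_bar).sqrt() * noise_pred) / alpha.sqrt()
        if t == 0:
            return mean  # no noise is added at the final step
        var = beta * (1.0 - self.alpha_bars[t - 1]) / (1.0 - a_bar)  # posterior variance
        return mean + var.sqrt() * torch.randn_like(xt)


@torch.no_grad()
def sample(noise_model, decoder, scheduler, latent_shape):
    """Denoise pure noise step by step in latent space, then decode to pixels."""
    z = torch.randn(latent_shape)
    for t in reversed(range(scheduler.num_timesteps)):
        t_batch = torch.full((latent_shape[0],), t, dtype=torch.long)
        z = scheduler.step(z, noise_model(z, t_batch), t)
    return decoder(z)
```

In a real setup `noise_model` would be the trained U-Net and `decoder` the trained VQ-VAE or VAE decoder; for plain pixel-space DDPM the decoder is simply the identity.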
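The LPIPS computation described in the video (squared feature differences, a 1 × 1 convolution per layer, spatial averaging, and a sum over layers) can be sketched as a head that operates on precomputed feature maps. This sketch assumes the feature maps come from a frozen backbone such as VGG (not included here), and it unit-normalizes features along the channel axis before differencing, which, as far as I know, the official LPIPS code also does although the walkthrough does not mention it. The class name `PerceptualDistance` is hypothetical and its 1 × 1 weights are untrained.

```python
import torch
import torch.nn as nn


def normalize_channels(feat, eps=1e-10):
    # Unit-normalize each spatial feature vector along the channel axis
    return feat / (feat.pow(2).sum(dim=1, keepdim=True).sqrt() + eps)


class PerceptualDistance(nn.Module):
    """LPIPS-style head over lists of feature maps (backbone not included)."""

    def __init__(self, channels=(64, 128, 256, 512, 512)):
        super().__init__()
        # One 1x1 conv per layer collapses C channels into one weighted map
        self.lins = nn.ModuleList([nn.Conv2d(c, 1, kernel_size=1, bias=False) for c in channels])

    def forward(self, feats0, feats1):
        total = 0.0
        for lin, f0, f1 in zip(self.lins, feats0, feats1):
            diff = (normalize_channels(f0) - normalize_channels(f1)) ** 2
            total = total + lin(diff).mean(dim=[2, 3], keepdim=True)  # spatial average
        return total  # shape (B, 1, 1, 1): per-layer distances summed
```

The default channel counts (64, 128, 256, 512, 512) are those of the five VGG16 stages; identical inputs give a distance of exactly zero, and with non-negative 1 × 1 weights any difference gives a positive distance.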
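The relationship between the LPIPS scaling layer and ImageNet normalization follows from one line of algebra: for x in [−1, 1], (x + 1)/2 lies in [0, 1], and ((x + 1)/2 − m)/s = (x − (2m − 1))/(2s). The hypothetical helper below, using the standard torchvision ImageNet statistics, computes the shift and scale this implies; as far as I know these match the constants in the official LPIPS `ScalingLayer`.

```python
# Standard torchvision ImageNet statistics for images scaled to [0, 1]
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def lpips_scaling_constants(mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Hypothetical helper: shift/scale that apply ImageNet normalization to x in [-1, 1].

    With x01 = (x + 1) / 2, (x01 - m) / s == (x - (2m - 1)) / (2s).
    """
    shift = tuple(2 * m - 1 for m in mean)
    scale = tuple(2 * s for s in std)
    return shift, scale


shift, scale = lpips_scaling_constants()
print([round(v, 3) for v in shift], [round(v, 3) for v in scale])
# -> [-0.03, -0.088, -0.188] [0.458, 0.448, 0.45]
```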
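The encoder layout described in the video (an input conv, down blocks without attention, a mid block with attention, then norm, activation, and a 3 × 3 plus 1 × 1 convolution to the latent dimension) can be sketched compactly. This is a simplified stand-in, not the diffusers or repo implementation: it assumes one ResNet block per down block, a stride-2 convolution for downsampling, `nn.MultiheadAttention` for the mid block, and illustrative channel counts. A decoder would mirror it, with mid blocks followed by upsampling blocks.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Norm -> SiLU -> conv, twice, with a 1x1-conv residual (no time embedding)."""

    def __init__(self, cin, cout, groups=8):
        super().__init__()
        self.block = nn.Sequential(
            nn.GroupNorm(groups, cin), nn.SiLU(), nn.Conv2d(cin, cout, 3, padding=1),
            nn.GroupNorm(groups, cout), nn.SiLU(), nn.Conv2d(cout, cout, 3, padding=1),
        )
        self.skip = nn.Conv2d(cin, cout, 1)

    def forward(self, x):
        return self.block(x) + self.skip(x)


class MidBlock(nn.Module):
    """ResNet -> self-attention -> ResNet; used only at the low resolution."""

    def __init__(self, ch, heads=4, groups=8):
        super().__init__()
        self.res1, self.res2 = ResBlock(ch, ch, groups), ResBlock(ch, ch, groups)
        self.norm = nn.GroupNorm(groups, ch)
        self.attn = nn.MultiheadAttention(ch, heads, batch_first=True)

    def forward(self, x):
        x = self.res1(x)
        b, c, h, w = x.shape
        seq = self.norm(x).flatten(2).transpose(1, 2)         # (B, H*W, C)
        attn_out, _ = self.attn(seq, seq, seq)
        x = x + attn_out.transpose(1, 2).reshape(b, c, h, w)  # residual connection
        return self.res2(x)


class Encoder(nn.Module):
    def __init__(self, in_ch=3, chans=(32, 64, 128, 128), latent_dim=3):
        super().__init__()
        self.conv_in = nn.Conv2d(in_ch, chans[0], 3, padding=1)
        # Each down block: ResNet block (no attention) + stride-2 conv halving H and W
        self.downs = nn.ModuleList([
            nn.Sequential(ResBlock(chans[i], chans[i + 1]),
                          nn.Conv2d(chans[i + 1], chans[i + 1], 4, stride=2, padding=1))
            for i in range(len(chans) - 1)])
        self.mid = MidBlock(chans[-1])
        self.out = nn.Sequential(
            nn.GroupNorm(8, chans[-1]), nn.SiLU(),
            nn.Conv2d(chans[-1], latent_dim, 3, padding=1),  # 3x3 to latent channels
            nn.Conv2d(latent_dim, latent_dim, 1))            # followed by a 1x1

    def forward(self, x):
        x = self.conv_in(x)
        for down in self.downs:
            x = down(x)
        return self.out(self.mid(x))
```

With three down blocks the spatial size shrinks by 2³ = 8, so a 256 × 256 image becomes a 32 × 32 latent, matching the example in the video.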
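For the VAE variant, the encoder emits twice the latent-dimension channels, which are split into a mean and a log variance before reparameterization. A minimal sketch, assuming a standard-normal prior and the usual closed-form KL term (summed over latent elements, averaged over the batch); the function name `vae_sample` is hypothetical.

```python
import torch


def vae_sample(enc_out):
    """Split a (B, 2*latent_dim, H, W) encoder output and draw z by reparameterization."""
    mean, logvar = torch.chunk(enc_out, 2, dim=1)  # first half: mean, second: log variance
    std = torch.exp(0.5 * logvar)
    z = mean + std * torch.randn_like(std)         # z stays differentiable w.r.t. mean/std
    # KL(N(mean, std^2) || N(0, I)), summed over latent elements, averaged over the batch
    kl = 0.5 * (mean.pow(2) + logvar.exp() - 1.0 - logvar).sum(dim=[1, 2, 3]).mean()
    return z, kl
```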
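The VQ-VAE quantization step discussed in the video, with its argmin lookup, straight-through estimator, and codebook plus commitment losses, can be sketched as follows. This is an illustrative module, not the repo's code; the commitment weight of 0.25 is an assumption taken from the VQ-VAE paper, and `torch.cdist` is used for the nearest-neighbour search.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """Sketch of VQ-VAE quantization with a straight-through estimator."""

    def __init__(self, codebook_size=8192, latent_dim=3, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(codebook_size, latent_dim)  # K x latent_dim
        self.beta = beta  # commitment loss weight (assumed value)

    def forward(self, x):
        b, c, h, w = x.shape
        flat = x.permute(0, 2, 3, 1).reshape(-1, c)      # (B*H*W, latent_dim)
        dists = torch.cdist(flat, self.codebook.weight)  # distance to every code
        idx = dists.argmin(dim=1)                        # nearest code per grid cell
        quant = self.codebook(idx).view(b, h, w, c).permute(0, 3, 1, 2)
        # Codebook loss moves codes toward encoder outputs; commitment loss does the reverse
        codebook_loss = F.mse_loss(quant, x.detach())
        commitment_loss = F.mse_loss(quant.detach(), x)
        # Straight-through: forward value is quant, gradient flows to x as if quant == x
        quant = x + (quant - x).detach()
        return quant, codebook_loss + self.beta * commitment_loss, idx.view(b, h, w)
```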
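A PatchGAN-style discriminator with per-patch targets, as described in the video, can be sketched like this. The exact layer layout (no batchnorm on the first conv, a final stride-1 conv producing one logit per patch, LeakyReLU slope 0.2) follows the common PatchGAN convention and is an assumption; the video only specifies four conv layers with batchnorm and LeakyReLU. Because the targets are built with `torch.ones_like` and `torch.zeros_like` on the output, the same loss works for any input size.

```python
import torch
import torch.nn as nn


class PatchDiscriminator(nn.Module):
    """PatchGAN-style discriminator sketch: outputs one logit per image patch."""

    def __init__(self, in_channels=3, channels=(64, 128, 256)):
        super().__init__()
        layers, prev = [], in_channels
        for i, ch in enumerate(channels):
            layers.append(nn.Conv2d(prev, ch, 4, stride=2, padding=1, bias=(i == 0)))
            if i > 0:
                layers.append(nn.BatchNorm2d(ch))  # no norm on the first layer
            layers.append(nn.LeakyReLU(0.2))
            prev = ch
        layers.append(nn.Conv2d(prev, 1, 4, stride=1, padding=1))  # per-patch logits
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


bce = nn.BCEWithLogitsLoss()


def discriminator_loss(disc, real, fake):
    real_logits = disc(real)
    fake_logits = disc(fake.detach())  # do not backprop into the autoencoder here
    # Targets match the patch map's shape: 1 for inputs, 0 for reconstructions
    return 0.5 * (bce(real_logits, torch.ones_like(real_logits))
                  + bce(fake_logits, torch.zeros_like(fake_logits)))
```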
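The generator-side objective with a delayed discriminator (`disc_start`) can be sketched as a single function. The default weights and `disc_start` value here are placeholders, not the repo's settings; `disc` is any callable returning patch logits, and `perceptual` is any callable returning a per-image distance, such as an LPIPS model.

```python
import torch
import torch.nn.functional as F


def generator_loss(images, recon, quant_loss, disc, perceptual, step,
                   disc_start=10000, disc_weight=0.5, perceptual_weight=1.0):
    """Hypothetical autoencoder (generator) objective; all weights are placeholders."""
    loss = F.mse_loss(recon, images)  # L2 reconstruction loss
    loss = loss + quant_loss          # codebook + commitment (KL for a VAE)
    if step >= disc_start:
        # Adversarial term only after disc_start: every patch should look real (target 1)
        logits = disc(recon)
        loss = loss + disc_weight * F.binary_cross_entropy_with_logits(
            logits, torch.ones_like(logits))
    loss = loss + perceptual_weight * perceptual(recon, images).mean()  # LPIPS term
    return loss
```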
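The latent diffusion training step from the video (encode with the frozen autoencoder, then run the unchanged DDPM objective on the latents) can be sketched end to end. The encoder and noise predictor below are hypothetical toy stand-ins for the trained VQ-VAE/VAE encoder and the time-conditioned U-Net (the toy predictor ignores t), and the linear β schedule is assumed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins: a frozen "encoder" mapping 3x64x64 images to a 3x8x8
# latent, and a toy noise predictor that ignores t. In the video these roles are
# played by the trained VQ-VAE/VAE encoder and the U-Net.
encoder = nn.Sequential(
    nn.Conv2d(3, 16, 4, stride=2, padding=1), nn.SiLU(),
    nn.Conv2d(16, 16, 4, stride=2, padding=1), nn.SiLU(),
    nn.Conv2d(16, 3, 4, stride=2, padding=1),
)
encoder.requires_grad_(False)  # autoencoder weights stay frozen
noise_predictor = nn.Conv2d(3, 3, 3, padding=1)

betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)
optimizer = torch.optim.Adam(noise_predictor.parameters(), lr=1e-4)


def training_step(images):
    with torch.no_grad():
        z0 = encoder(images)                              # pixel space -> latent space
    noise = torch.randn_like(z0)
    t = torch.randint(0, len(betas), (z0.shape[0],))
    a_bar = alpha_bars[t].view(-1, 1, 1, 1)
    zt = a_bar.sqrt() * z0 + (1 - a_bar).sqrt() * noise  # closed-form forward process
    loss = F.mse_loss(noise_predictor(zt), noise)         # predict the added noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), z0.shape
```

Apart from the `encoder` call, this is the same loop as pixel-space DDPM, which is the point the video makes about the "two uncommented lines".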
Info
Channel: ExplainingAI
Views: 9,311
Keywords: stable diffusion, stable diffusion architecture, stable diffusion tutorial, stable diffusion pytorch, stable diffusion from scratch pytorch, latent diffusion, compvis latent diffusion tutorial, latent diffusion models explained, latent diffusion models, what is latent diffusion, build stable diffusion from scratch, train stable diffusion model from scratch, how to train stable diffusion, how latent diffusion works, coding stable diffusion from scratch, deep learning
Id: 1BkzNb3ejK4
Length: 42min 29sec (2549 seconds)
Published: Fri Feb 02 2024