Building our first simple GAN

Captions
What is going on, guys? Hope you're doing awesome. In this video we're going to build our very first GAN, but before we get started, let's roll that intro.

All right, so these are the imports we're going to use. They should all be pretty basic, so I'm not going to go through them. We'll start by creating the discriminator. From the example I showed in the previous video, this is going to be the detective: the discriminator judges an image and says whether it's real or fake. First we inherit from nn.Module and write our __init__ as usual, taking some in_features; for the MNIST dataset, in_features is going to be 784. The first thing we do is call the parent class with super(), and then we build self.disc as an nn.Sequential. This is a very, very simple model: an nn.Linear taking in_features to 128 nodes, a LeakyReLU in between with a slope of, say, 0.1 (you can play around with this, it's a hyperparameter, and you can also try plain ReLU, but in GANs LeakyReLU is often the better choice, or at least the better default), then another linear layer from 128 to 1, because we just output a single value for fake or real: fake is 0, real is 1. To ensure that output is between 0 and 1, we call Sigmoid on that last layer, or last output rather. Then we write forward, which takes some input x and returns self.disc(x). Pretty simple; that was all for the discriminator.

Then we create our generator. The generator also inherits from nn.Module. In its __init__ we pass a z dimension, which is the dimension of the latent noise the generator takes as input, and an image dimension (and actually I should rename in_features to img_dim in the discriminator as well). We call super().__init__() and build self.gen, also a very simple nn.Sequential: nn.Linear from z_dim to 256, nn.LeakyReLU(0.1) again, and another linear layer from 256 to img_dim. As I said, the image dimension is going to be 784, because MNIST is 28 × 28 × 1, which flattened is 784. Then we add nn.Tanh, and we use Tanh here to make sure the output pixel values are between -1 and 1. The reason is that we're going to normalize the MNIST inputs so they are between -1 and 1, and if the input is normalized to that range, it only makes sense that the output should be in that range as well. Finally we write forward(self, x) and just return self.gen(x).
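A minimal sketch of the two models described above (the class and attribute names are my own choices and may not match the video's code exactly):

```python
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, img_dim):            # img_dim = 784 for flattened MNIST
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),               # LeakyReLU is a common default in GANs
            nn.Linear(128, 1),
            nn.Sigmoid(),                    # squash to (0, 1): fake = 0, real = 1
        )

    def forward(self, x):
        return self.disc(x)


class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):      # z_dim = latent noise size, img_dim = 784
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),                       # outputs in (-1, 1) to match normalized inputs
        )

    def forward(self, x):
        return self.gen(x)
```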
Hopefully you've followed this far: the discriminator and the generator are just very simple models, two linear layers each, with activation functions in between and at the end.

Now let's do our hyperparameters and so on. We set the device to cuda if torch.cuda.is_available(), otherwise the CPU, so this will run whether you're on GPU or CPU. We set a learning rate; this is actually incredibly important, and I encourage you to play around with it to see what happens when you change it, but this is one learning rate I found that seems to work pretty well after trying a few different ones. The z dimension is going to be 64; again, this is a parameter you can vary, try 128, 256, and smaller as well. One thing to know is that GANs are incredibly sensitive to hyperparameters, especially at this stage where we're essentially replicating the original GAN paper. Newer papers have found better methods and ways of stabilizing GAN training, and we'll see those later in this series of videos, but for now there are going to be a lot of hyperparameters that the training is very sensitive to. So if you're coding along and trying different hyperparameters, just know that might be the reason things aren't working if you're not exactly replicating what I'm doing. The image dimension is 28 × 28 × 1, which is just 784, the batch size is 32 as a standard choice, and we'll run for 50 epochs.
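A sketch of those hyperparameters, assuming the 3e-4 learning rate that the narration settles on later in the video; treat all of these as starting points, since GAN training is sensitive to them:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4                  # the value the narration ends up using
z_dim = 64                 # latent noise dimension; try 128, 256, or smaller
image_dim = 28 * 28 * 1    # 784, flattened MNIST
batch_size = 32
num_epochs = 50
```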
That seems pretty good. Next we initialize the discriminator with the image dimension and move it to the device, and the generator with the z dimension and the image dimension, also moved to the device. We also set up some fixed noise with torch.randn of (batch_size, z_dim). The reason for fixed noise is that we can then see how the generated images change across the epochs; we'll do that in TensorBoard, which will give us a pretty nice visualization. Then we define our transforms with transforms.Compose: transforms.ToTensor() and transforms.Normalize with mean 0.5 and standard deviation 0.5. That isn't the actual mean and standard deviation of the MNIST dataset, so you could try the exact values and see how that works. Actually, maybe we should do that; let me google what it is. All right, this is the actual mean and standard deviation of the MNIST dataset, so I'm just going to copy those in. As I said, the training is very sensitive to hyperparameters, so I might have screwed everything up by taking these, but hopefully it should work. Then we create our dataset with datasets.MNIST, setting root to "dataset", transform=transforms, and download=True. Creating our DataLoader, we pass the dataset, set batch_size to batch_size, and shuffle the data. Then we create our optimizers, using Adam for both: one over disc.parameters() and one over gen.parameters(), each with lr equal to the learning rate. Let's also initialize our loss, which is going to be BCELoss. I'll talk a bit more about the BCE loss; it follows pretty much exactly the form we saw in the previous video, the introduction to GANs, and I'll explain it in more detail when we set up the training loop. For TensorBoard we create writer_fake, a SummaryWriter pointing into a runs directory ("runs/GAN_MNIST/fake"), which will only log the fake images the generator has generated, and writer_real, which we simply point to a different folder, plus a step counter for TensorBoard. If you're confused about TensorBoard and such, I have a separate video for that, but it's not really relevant to the GAN itself.

Anyways, we loop for epoch in range(num_epochs), and inside that we loop over the batch index and the images and labels; I'm just going to call the images "real", and we're not going to use the actual labels at all, which is pretty cool: GANs are unsupervised in that way.
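Continuing from the earlier sketches, the setup might look roughly like this (the dataset root and run directories are the names mentioned in the narration; the import paths are the standard torchvision/TensorBoard ones, and the 0.5/0.5 normalization is the one the video falls back to at the end):

```python
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)   # reused each epoch to track progress

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = torch.optim.Adam(disc.parameters(), lr=lr)
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

writer_fake = SummaryWriter("runs/GAN_MNIST/fake")   # logs generated images
writer_real = SummaryWriter("runs/GAN_MNIST/real")   # logs real images
step = 0
```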
The first thing we do in the loop is reshape with real.view(-1, 784): the -1 keeps the number of examples in our batch, and everything else gets flattened to 784. We move that to the device and check our batch size real quick, which is real.shape[0], the first dimension. Then we set up the training for the discriminator. If you remember, training the discriminator means maximizing log(D(x)), or log(D(real)) I guess, plus log(1 - D(G(z))), where z is just some random noise that gets fed into the generator. Let's generate that noise with torch.randn, which samples from a Gaussian distribution with mean 0 and standard deviation 1, with shape (batch_size, z_dim), and move it to the device. We generate some fake images with fake = gen(noise). Then disc_real, what the discriminator outputs on the real images, corresponds to the first part of the loss function: that's disc(real), with .view(-1) to flatten everything. lossD_real, the first term, is the criterion applied to disc_real and torch.ones_like(disc_real).

Just remember what we're looking for with lossD_real. Here's the BCE loss from the documentation: l_n = -w_n * [y_n * log(x_n) + (1 - y_n) * log(1 - x_n)]. The only difference is that we're passing in y_n as torch.ones. Because these are all ones, the (1 - y_n) * log(1 - x_n) part is just zero, so that term completely disappears, and the only thing left is log(x_n). And remember we're not sending in x_n directly, we're sending in the discriminator of the real images, so this is exactly the log(D(real)) part of the expression. Then you can see there's a minus sign in front, and also a w_n, which is just going to be 1, so you can ignore it. The important part is the minus sign: maximizing log(D(x)) is the same thing as minimizing the negative of that expression, and since in ordinary neural-network training you normally minimize the loss function, that's why the negative sign is there by default. So now that we understand that, we have the same expression, except we take the minimum of its negative, which is just an alternative way of formulating it.

Now we do the same thing for the other term. We take the discriminator of the fake images, disc(fake).view(-1), flattening everything, which gives us D(G(z)); that's disc_fake. Then lossD_fake is the criterion applied to disc_fake and torch.zeros_like(disc_fake). If we go back to the loss function, we're now sending in zeros, so the y_n * log(x_n) term completely disappears, and all we have left is (1 - y_n) = 1 times log(1 - D(G(z))), which is exactly what we want, again with the negative sign in front indicating that we minimize the negative of that expression rather than maximizing it.
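Inside the training loop, the discriminator step described above might look like this (continuing from the setup sketch; the rest of the loop body follows in the next sketch):

```python
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):    # labels are never used
        real = real.view(-1, 784).to(device)          # flatten to (batch_size, 784)
        batch_size = real.shape[0]

        # Train discriminator: maximize log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))    # -log(D(real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))   # -log(1 - D(G(z)))
```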
All right, so hopefully you followed that; it's one of the trickier parts, I guess. Then lossD is lossD_real plus lossD_fake, divided by 2. We call disc.zero_grad(), then lossD.backward(), then opt_disc.step(). So now we've trained the discriminator for one step.

One thing here that's also a bit tricky is that we're going to reuse something, so let me write that up first. We now want to train the generator, and remember, the original objective is to minimize log(1 - D(G(z))). But recall from the last video that this expression leads to saturating gradients, which means weak gradients, slower training, and sometimes no training at all. What works better is to maximize log(D(G(z))), so that's the expression we're going for. Notice that this involves G(z), which is exactly what we computed over here when we calculated fake. We want to reutilize what we calculated there, since there's no point in doing it again and wasting computation. But there's one problem, and this is a little bit more advanced: when we called lossD.backward(), everything that was used in the forward pass to calculate those gradients was cleared from the cache, to save memory. That's normally what you want, but we actually want to reuse fake at a later point. One thing you can do that works is to detach, disc(fake.detach()): when we detach it, running the backward pass won't clear those intermediate computations. Another thing you can do is lossD.backward(retain_graph=True). Both approaches are equivalent and both work.

So now, training the generator: first we compute output = disc(fake).view(-1), and perhaps here it's also clear that to compute the gradients we need the intermediate results from fake, which are only still in the computational graph because we used retain_graph=True. Then lossG is the criterion applied to output and torch.ones_like(output). We're doing the exact same thing as before: instead of maximizing log(D(real)), we're now maximizing log(D(G(z))), which is why we pass in torch.ones, very similar to what we did earlier. Then we call gen.zero_grad() and lossG.backward(), so we're only modifying the generator's weights, and finally opt_gen.step(). All right, that's the training setup. I'm just going to add some additional code here for TensorBoard.
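Still inside the inner loop, the rest of the step might look like this. retain_graph=True keeps the intermediate results for fake around so they can be reused in the generator update (detaching fake in the discriminator step is the equivalent alternative mentioned above). The TensorBoard logging at the end is only a rough sketch of the kind of code being added off-screen; the image-grid helper and tag names are my assumptions:

```python
        # Combine the two discriminator terms and update D only
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)    # keep the graph so `fake` can be reused below
        opt_disc.step()

        # Train generator: maximize log(D(G(z))) instead of minimizing log(1 - D(G(z)))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))   # -log(D(G(z)))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Occasionally log fixed-noise samples and real images to TensorBoard
        if batch_idx == 0:
            with torch.no_grad():
                fake_imgs = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_imgs = real.reshape(-1, 1, 28, 28)
                grid_fake = torchvision.utils.make_grid(fake_imgs, normalize=True)
                grid_real = torchvision.utils.make_grid(real_imgs, normalize=True)
                writer_fake.add_image("MNIST Fake Images", grid_fake, global_step=step)
                writer_real.add_image("MNIST Real Images", grid_real, global_step=step)
                step += 1
```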
Let's try to run this, and hopefully there are no errors. Actually, I'm going to stop it there, because I don't think this is going to converge, and I think the reason is that I used the new normalization with the exact MNIST mean and standard deviation, whereas when I originally came up with the hyperparameters I used 0.5 and 0.5. If we use 0.5 and 0.5, I think it's going to work, but maybe for fun let's try setting the learning rate to 1e-4 instead and see if that works better. And as we can see, again, this is just totally diverging. This is the unfortunate truth when training GANs: they are just incredibly sensitive to hyperparameters. Changing the learning rate, changing the z dimension, perhaps changing the beta values of the Adam optimizer, could all make a difference; just know that they are incredibly sensitive. When I tried this code before, I used the 0.5/0.5 normalization values together with a learning rate of 3e-4, and I don't want to do that entire hyperparameter search again, so I'm just going to keep it as is. That's something you can definitely play around with and see if you can get it to work.

The images don't look that good, but they look all right, and if you train for longer they'll become even better. We can definitely at least see that they're starting to bear some resemblance to numbers. What I want to talk about a little bit is what you can try to get even better performance. Things to try: first, what happens if you use a larger network? We used a very, very small network with just two linear layers, so that's probably the first thing I would try. Then better normalization, with batch norm, is another thing you could try, and also different learning rates, to see whether there's a better one than what I'm using. The fourth and last thing I would try is changing the architecture to a CNN, and that's what I'll do in the next video as I continue this series; I think we're going to get a lot better performance with a CNN. That's it for this video. Hopefully you were able to follow all of the steps and get an understanding of how to implement a simple GAN, and I hope to see you in the next video when we build an even more powerful GAN using convolutional neural networks.
Info
Channel: Aladdin Persson
Views: 24,594
Rating: 4.9381442 out of 5
Keywords: pytorch simple gan, pytorch first gan tutorial, pytorch original gan, pytorch gan mnist, simple gan tutorial, simple gan deep learning
Id: OljTVUVzPpM
Length: 24min 23sec (1463 seconds)
Published: Sun Nov 01 2020