CycleGAN implementation from scratch

Captions
In this video we will be implementing CycleGAN from scratch, and I'm pretty excited for this one. The overview: in the first part I'll do a quick summary of CycleGAN — what it is, what the training looks like, what the loss looks like, and the architecture. That part is going to be quite compact, because I've already done a paper walkthrough of this paper, so if you want more details you can look at that video. The main part of this video is the implementation, so without further ado, let's get started.

So, this is CycleGAN. It can do these awesome things: take an image of a zebra and convert it to an image of a horse, and also vice versa, as we see down here. What is quite remarkable is that this is unsupervised — we don't have a target image for the corresponding input. And it can do other amazing things as well, like these summer-to-winter translations and so on.

So let's move on and ask the question: how do we train this? (I'll also get into the architecture a little later.) To train this we're going to have two generators, G and F, and two discriminators, D_X and D_Y. To make things a little simpler, imagine that X is horses and Y is zebras. G takes an image of a horse and tries to convert it to an image of a zebra; the generator F does the opposite, taking an image of a zebra and trying to convert it to a horse. The discriminator D_Y tries to say whether an image of a zebra is a real zebra or a generated one, and D_X does the opposite, classifying whether an image is a real horse or a fake, generated horse. So you can think of this as a standard GAN setup, except that we have double the number of discriminators and generators — normally you have just one of each.

They do make things a little more complicated, though, by adding an additional loss term: what they call cycle consistency. The idea is that if we take an image of a horse — this image is the real horse — generate an image of a zebra with G, and then take the other generator F, feed it that generated zebra, and try to generate a horse, we should ideally get back the original image. That's the reconstructed horse on the right, and if you look closely you can see some color differences, but honestly it does a pretty good job — this looks pretty close. So we add this additional loss saying that x should be approximately equal to x-hat, and we enforce it with an L1 loss. So we have the standard GAN loss over here — and if you're confused, I think it will become clearer with the implementation — where we train these two generators and these two discriminators, but we also add this cycle consistency. Here I only showed it for the horse, but you can imagine we do the same thing for zebras.
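Written out — roughly as in the paper, with the least-squares GAN loss they use — the objective looks like this, where G maps X to Y, F maps Y to X, and λ weights the cycle term (the paper uses λ = 10):

    \mathcal{L}_{GAN}(G, D_Y) = \mathbb{E}_y\big[(D_Y(y) - 1)^2\big] + \mathbb{E}_x\big[D_Y(G(x))^2\big]

    \mathcal{L}_{cyc}(G, F) = \mathbb{E}_x\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_y\big[\lVert G(F(y)) - y \rVert_1\big]

    \mathcal{L}(G, F, D_X, D_Y) = \mathcal{L}_{GAN}(G, D_Y) + \mathcal{L}_{GAN}(F, D_X) + \lambda \, \mathcal{L}_{cyc}(G, F)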
Alright, so what they also do is add this identity mapping loss. Basically, if we take this image of a horse and send it into the generator that is supposed to generate horses, then it shouldn't do anything. (The notation might be a little confusing with G_Y and G_X here — in the earlier figure I think this was F and this was just G.) We do the same thing for the zebra: if we send a zebra into the generator that is supposed to generate zebras, it shouldn't touch the image. Again, we enforce this with an L1 loss. The reason we do this is that we don't want the generators to change the coloring or the tint of the image too much, so it's a way of enforcing that the coloring should be preserved.

So, the last bit of the summary. For the two generators we have the standard GAN loss — the adversarial loss — with the one difference that they use a least-squares loss instead of BCE. Then we have the cycle consistency, meaning we go horse to zebra and back to horse, and then we also have this identity loss; and similarly down here for the inverse direction. One thing to mention is that the identity loss is actually not used for all of the datasets — I think it was just used for one of them, specifically the paintings dataset, though I'm not really sure what it was called. I didn't use it for horses-to-zebras at least, and from my understanding the paper didn't use it there either.
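As a formula, the identity term is just an L1 penalty on each generator leaving its own domain untouched — roughly as in the paper, which, if I remember right, weights it at half the cycle weight when it's used at all:

    \mathcal{L}_{identity}(G, F) = \mathbb{E}_y\big[\lVert G(y) - y \rVert_1\big] + \mathbb{E}_x\big[\lVert F(x) - x \rVert_1\big]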
And by the way, all of that is just for the generators — the discriminators are trained in a pretty standard way, just classifying whether an image is real or fake.

So what does the architecture look like? I don't want to go too in-depth here, because I've already made a paper-walkthrough video, and you'll get the details as we implement it. On the left side here is the generator: we do two downsamplings, just using a stride of two on the convolutional layers, then a bunch of residual blocks, then we upsample here and here using transpose convolutions, and then a final conv layer that maps back to RGB channels. The discriminator is really, really simple: four convolutional layers, all with a stride of two. The idea — what it's called — is a PatchGAN, because we're not going to output a single scalar between 0 and 1; we output a grid of values, say a 3×3 grid, each value between 0 and 1. It's called a PatchGAN because each of those values corresponds to seeing a patch of the original image.

So that's the architecture; what do the results look like? This is from the paper, and it looks pretty good. These are cherry-picked, but in general they do look quite okay. This is the input and this is the output: this whole row here is horse to zebra, and this one is zebra to horse. They don't show the reconstructions, but in my implementation the reconstruction is oftentimes really, really good — surprisingly good. The results from the implementation we're about to do look like this for horses to zebras — I'd say pretty equivalent to what they did in the paper, perhaps slightly different — and these are zebras to horses.

So now that we've summarized CycleGAN, let's try to implement it from scratch. For the implementation we're going to create one file for the discriminator model, another for the generator model, and then the dataset loading. There's going to be a link where you can download the dataset from Kaggle, and probably also the pre-trained weights if you want to load them and try it out. So: the models first, then the dataset loading, and then the train file.

Let's start with the discriminator — this one is probably the easiest. Let me make the font a little bigger... alright, that should be pretty good. We import torch and import torch.nn as nn, and I'll just switch over to the PyTorch environment. What I usually do is create a block first: basically a conv–norm–ReLU block, except we're not using batch norm here, we're using instance norm. So it's an nn.Module with an __init__ that takes in_channels, out_channels, and a stride. We call super() first, and then we create our conv block, an nn.Sequential: an nn.Conv2d from in_channels to out_channels with kernel size 4 — that's always the case in the discriminator, and it's detailed in the paper — the stride we pass in, padding set to 1, bias True, and padding_mode set to "reflect". They mention in the paper that reflect padding helps reduce artifacts, so that's the reasoning there. Then an instance norm over out_channels, and then a leaky ReLU — we use leaky ReLU throughout the discriminator and regular ReLU in the generator. For the forward part we just return self.conv(x). Nothing difficult here; we're just creating a conv block.
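As a sketch, the block described above might look like this (the 0.2 LeakyReLU slope is my assumption, matching the paper's discriminator):

    import torch
    import torch.nn as nn

    class Block(nn.Module):
        def __init__(self, in_channels, out_channels, stride):
            super().__init__()
            self.conv = nn.Sequential(
                # kernel size 4 everywhere in the discriminator, per the paper;
                # reflect padding is what they say reduces artifacts
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride,
                          padding=1, bias=True, padding_mode="reflect"),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )

        def forward(self, x):
            return self.conv(x)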
For the actual discriminator, we do class Discriminator, inheriting from nn.Module, with an __init__ that takes in_channels — 3, for RGB, as the default — and some features, which is the channel progression: in_channels goes to 64, then 128, then 256, then 512. Then we call super(). We're going to use the conv block for pretty much all of these, but there's one exception to the rule of using instance norm in every block, which is the initial layer. For the initial layer we just do an nn.Sequential — wait a second, this is supposed to be nn.Conv2d — so an nn.Conv2d from in_channels to features[0], with kernel_size=4, stride=2, padding=1, and padding_mode="reflect" again, pretty much everywhere; then we just call a leaky ReLU on that. So the only idea in the initial part is to do the same as the block, except with no instance norm. I suppose you could instead pass a flag into the block saying whether to use instance norm or not — that's one thing you could improve to make it a little cleaner.

Afterwards, we create a bunch of layers using these blocks. The in_channels is initially features[0], because we've run the input through the initial block, which changes it to 64 channels. Then we go through feature in features, skipping the first one since that was handled by the initial layer, and append a Block from in_channels to feature, using a stride of 1 if feature is the last one, features[-1], and otherwise 2. So the strides are 2, 2, 2, and then 1 for the last block. Then we just set in_channels equal to feature. In the end we append one additional conv layer that maps in_channels — which by then is 512 — down to 1. (I said RGB earlier by mistake: that's for the generator. In the discriminator we just output a single value between 0 and 1 indicating whether the image is real or fake; we're not generating an image here.) So that conv goes to 1 channel with kernel_size=4, stride=1, padding=1, and padding_mode="reflect". Finally, self.model is an nn.Sequential with *layers — we're just unwrapping that list into the nn.Sequential — and that's pretty much it. In the forward we take an input x, run it through self.initial(x), and then return self.model of that.

Let's see if anything is missing... and I actually just noticed an error that I think I made in the original — surprising that it still worked so well, since the results we saw looked pretty close to the paper; I must have made a mistake, perhaps when editing afterwards: we should use a sigmoid on the output as well, to make sure it's between zero and one.
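Putting the discriminator together as described — initial layer without instance norm, stride-2 blocks with a stride-1 block at the end, a final one-channel conv, and the sigmoid — a sketch could look like this:

    class Discriminator(nn.Module):
        def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
            super().__init__()
            # initial layer: like Block, but without the instance norm
            self.initial = nn.Sequential(
                nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2,
                          padding=1, padding_mode="reflect"),
                nn.LeakyReLU(0.2, inplace=True),
            )
            layers = []
            in_channels = features[0]
            for feature in features[1:]:
                # stride 2 everywhere except the last block
                layers.append(Block(in_channels, feature,
                                    stride=1 if feature == features[-1] else 2))
                in_channels = feature
            # map the 512 feature maps down to a one-channel patch grid
            layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1,
                                    padding=1, padding_mode="reflect"))
            self.model = nn.Sequential(*layers)

        def forward(self, x):
            return torch.sigmoid(self.model(self.initial(x)))

A torch.randn(5, 3, 256, 256) input comes out as a (5, 1, 30, 30) patch grid, which is the shape the test below checks for.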
Then I'm just going to copy-paste a quick test: we send in a batch of five examples at 256×256 with three channels, run the model, and print the input shape and the output shape. This is kind of a test for you — though I guess it's kind of impossible to tell exactly. The channel dimension is going to be 1, but the spatial size is harder to guess: it should be a 30×30 output. When they mention it in the paper, they say it's a 70×70 PatchGAN, and what they mean is that each value we see in this 30×30 grid sees a 70×70 patch of the original image. (You can check this with the receptive-field recurrence r ← (r−1)·s + k, going backwards through the five kernel-4 convs with strides 1, 1, 2, 2, 2: that gives 4 → 7 → 16 → 34 → 70.)

So let's move on to the generator. We start the same way: import torch and import torch.nn as nn. We're going to make a ConvBlock, an nn.Module whose __init__ takes some in_channels and out_channels, a flag for whether it's down or not — basically whether we're downsampling or upsampling — a flag for whether to use an activation, and then some keyword arguments. We call super().__init__(), and then self.conv is an nn.Sequential of: an nn.Conv2d from in_channels to out_channels with padding_mode="reflect" plus the keyword arguments — the kwargs here are basically the kernel size, stride, and padding — and we use that nn.Conv2d if down; otherwise we want an nn.ConvTranspose2d from in_channels to out_channels with the same kwargs. After either the Conv2d or the transpose convolution for upsampling, we do an instance norm over out_channels, and in the end a ReLU with inplace=True (I think that just adds a small performance benefit) — that's if use_act is set; otherwise we do nn.Identity, which just passes its input through without doing anything. For the forward we just return self.conv(x). So that's the main block, handling both downsampling and upsampling.

Then we have another one, a ResidualBlock, also inheriting from nn.Module, whose __init__ just takes channels. Here we create self.block, an nn.Sequential of a ConvBlock from channels to channels with kernel_size=3 — this is where the keyword arguments come in — and padding=1, and a stride of, sorry, 1, but that's the default in nn.Conv2d anyway; and then another ConvBlock from channels to channels, this time with use_act set to False, kernel_size=3, and padding=1.
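A sketch of those two generator building blocks, with the residual skip connection (discussed next) already filled in:

    class ConvBlock(nn.Module):
        def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
            super().__init__()
            self.conv = nn.Sequential(
                # Conv2d when downsampling, ConvTranspose2d when upsampling;
                # kwargs carry kernel_size, stride, and padding
                nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
                if down
                else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
                nn.InstanceNorm2d(out_channels),
                nn.ReLU(inplace=True) if use_act else nn.Identity(),
            )

        def forward(self, x):
            return self.conv(x)

    class ResidualBlock(nn.Module):
        def __init__(self, channels):
            super().__init__()
            self.block = nn.Sequential(
                ConvBlock(channels, channels, kernel_size=3, padding=1),
                ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
            )

        def forward(self, x):
            # kernel 3 with padding 1 gives "same" convolutions, so shapes line up
            return x + self.block(x)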
That's just what I found when going through the paper — it's what they mention in the implementation details. They use an activation for the first conv block but not for the second; I'm not entirely sure why that makes sense, but it's what they do. Then in the forward we take x and just return x + self.block(x). The reason we can do this is that we're not changing the number of channels, and all of these are "same" convolutions — kernel 3 with padding 1 — so these blocks really don't change the shape of the input at all. I guess it's a pretty interesting architecture that way.

Now we create our Generator, again inheriting from nn.Module. We define the __init__ and send in the image channels and the number of residual blocks to use — that's nine if the images are 256×256 or larger, and six if they're 128×128 or smaller, so I guess we can default it to nine. We call the parent's __init__, and, similarly to the discriminator, we create an initial block that doesn't use instance norm. And I notice another thing I think I got wrong in my original implementation: here I just did an nn.Conv2d, but I think this should be an nn.Conv2d plus a ReLU, just skipping the instance norm like we did in the discriminator. Actually, going through their implementation — this is the paper's source code — the initial block in their ResNet generator uses a ReflectionPad, a Conv2d, and then an nn.ReLU. So I'm pretty sure I made it incorrectly in my implementation — obviously it still worked pretty well — but let's just fix it. So here we do an nn.Conv2d from image channels to 64.
The kernel size there is 7, stride 1, padding 3, and padding_mode="reflect", followed by an nn.ReLU(inplace=True). Next, the down blocks: self.down_blocks is an nn.ModuleList where we put a ConvBlock taking 64 to 128 — actually, it's perhaps better to add a num_features argument defaulting to 64, so we can write num_features to num_features*2 — with kernel_size=3, stride=2, and padding=1. One thing you could also spell out here is down=True, but we set that as the default in the ConvBlock, so we can just leave it off. Then the second down block goes from num_features*2 to num_features*4.

Then self.res_blocks is an nn.Sequential, unwrapping with an asterisk all of the blocks we create: a ResidualBlock(num_features*4) for _ in range(num_residuals). So that's nine of those residual blocks, and those residual blocks use the conv blocks from before. Oh, and here we also need to call super() — I just noticed that would give us an error later on if we didn't fix it.

A quick recap so far: the initial layer is basically a conv block without the instance norm; the down blocks are two conv blocks with stride 2, so we downsample; and then a bunch of residual blocks that don't change the shape or the number of channels. In the end we do the up blocks: an nn.ModuleList, pretty similar to what we did above, with a ConvBlock from num_features*4 to num_features*2, down=False this time, kernel_size=3 again, stride=2, padding=1 — and I also found that they used output_padding=1. Output padding adds an additional padding after the transpose convolution; I'm not really sure it's essential, but it's what they did, so let's just copy it. (I wonder why it formats it like that... let's put that over there.) Then we change the numbers for the second up block, which goes from num_features*2 to num_features*1, otherwise the same. And then we should change the channels here, right — so now we're at 64.
We need one additional layer that converts back to RGB. Let's just call it self.last: an nn.Conv2d from num_features*1 to image channels, kernel_size=7, stride=1, padding=3, and padding_mode="reflect". Now that we've done all of that, in the forward we send in x, run it through the initial block, then for layer in self.down_blocks we do x = layer(x), then run it through all of the residual blocks with self.res_blocks(x), then for layer in self.up_blocks, x = layer(x), and finally we return torch.tanh(self.last(x)). So really there's nothing fancy here: initial, down blocks, residual blocks, up blocks, and then the last layer to convert to RGB — plus the torch.tanh, which maps the output into the minus-one-to-one range, as always.

I'm going to copy-paste in a quick test case: image channels 3, image size 256, run it through the generator, and check the shape. A quick question for you: what should the shape be after running it through the generator? The answer is that it should be exactly identical to the input, because we're only converting the pixels of the original into — hopefully — a converted zebra, or a fake horse or a fake zebra. Let's try to run this... and it works, magically. That is really awesome.
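Collected into one class, the generator as described might look like this sketch (with the ReLU in the initial block, matching their source code):

    class Generator(nn.Module):
        def __init__(self, img_channels=3, num_features=64, num_residuals=9):
            super().__init__()
            # initial block: conv + ReLU, no instance norm
            self.initial = nn.Sequential(
                nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1,
                          padding=3, padding_mode="reflect"),
                nn.ReLU(inplace=True),
            )
            # two stride-2 downsamplings: 64 -> 128 -> 256 channels
            self.down_blocks = nn.ModuleList([
                ConvBlock(num_features, num_features * 2,
                          kernel_size=3, stride=2, padding=1),
                ConvBlock(num_features * 2, num_features * 4,
                          kernel_size=3, stride=2, padding=1),
            ])
            self.res_blocks = nn.Sequential(
                *[ResidualBlock(num_features * 4) for _ in range(num_residuals)]
            )
            # two upsamplings mirror the downsamplings
            self.up_blocks = nn.ModuleList([
                ConvBlock(num_features * 4, num_features * 2, down=False,
                          kernel_size=3, stride=2, padding=1, output_padding=1),
                ConvBlock(num_features * 2, num_features, down=False,
                          kernel_size=3, stride=2, padding=1, output_padding=1),
            ])
            self.last = nn.Conv2d(num_features, img_channels, kernel_size=7,
                                  stride=1, padding=3, padding_mode="reflect")

        def forward(self, x):
            x = self.initial(x)
            for layer in self.down_blocks:
                x = layer(x)
            x = self.res_blocks(x)
            for layer in self.up_blocks:
                x = layer(x)
            return torch.tanh(self.last(x))  # map outputs into [-1, 1]

Running a (2, 3, 256, 256) tensor through this returns a (2, 3, 256, 256) tensor — the same shape as the input, as the test checks.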
Now we've done one of the trickiest parts. Next is the dataset loading, and for that I'm going to need the dataset itself. Alright, I've now copied over the dataset — there will be a link where you can download it. Basically we have a data folder with train, and we should also have a val... there we go. Inside val and train we have horses and zebras, and inside those it's just images of horses, and similarly for the zebras folder, just images of zebras. So this shouldn't be too difficult. I'll copy-paste in the imports: torch, PIL to load the images, torchvision for transforms — actually we won't need that — os, import config (which we haven't created yet), and import numpy as np.

For the dataset we're going to use this config file. I've copied over the config and the utils files, because I don't think walking through them is too important. In the utils file we just have save_checkpoint and load_checkpoint, and that's pretty much it — we won't need the rest. In the config file we have the device, the train directory, the val directory, the batch size, the learning rate, and all of that — these are just taken from the paper — plus where to save the checkpoints, and the transforms we're going to use for the images. We're using albumentations here, pretty similar to what you'd see with torchvision: we resize, do a horizontal flip, and you could also add a color jitter if you want — I didn't find it necessary, but you could. And that's pretty much it in those files; I didn't want to show them in detail because it would take a bunch of time and not really add anything.

So, for the dataset we create class HorseZebraDataset, which inherits from Dataset — from torch.utils.data import Dataset. In the __init__ we send in root_zebra, root_horse, and a transform defaulting to None, and we set self.root_zebra = root_zebra, self.root_horse = root_horse, and self.transform = transform. Then we list all the image files inside the root_zebra and root_horse directories: self.zebra_images = os.listdir(root_zebra), and pretty much the same thing for self.horse_images with root_horse.

Now we also set the length of the dataset, and the tricky part here is that we have two datasets and need to load one image from each, but the lengths of the two are not the same. Usually you have (x, y) pairs with an exact match, so you'd load the same index for x as for y; here we don't have pairs, and the lengths aren't equal. So we take the maximum: self.length_dataset = max(len(self.zebra_images), len(self.horse_images)) — and maybe horse_images is the more accurate name, so let's use zebra_images and horse_images. So we take whichever is larger: maybe one is 1000 and the other is 1500.
Then we store the zebra length, len(self.zebra_images), and the horse length, len(self.horse_images). That's the initialization — oh, and what's wrong here: we need len() in both places, like that. Then __len__ just returns self.length_dataset — the maximum length, so 1500 in our example.

For __getitem__ we get an index, say between 0 and 1499. We get the zebra image as self.zebra_images[index % self.zebra_len]. The reason for the modulo is that the index can be greater than the number of images in one of the folders, since we took the maximum of the two; taking the index modulo the length of that folder ensures we stay in the valid range and don't get an index error. Now that I think about it, there might be flaws in this approach, in that some examples might be shown more often than others — I haven't really thought that through, so if you have thoughts about it, let me know. For the horse image we do the same thing: self.horse_images[index % self.horse_len].

What we have at this point is just the JPEG filename inside the root directory; to be able to load the image we need a full path, so zebra_path = os.path.join(self.root_zebra, zebra_img), and similarly horse_path from root_horse and the horse filename. Then we convert both to PIL images — or rather numpy arrays, in this case, since we're using albumentations: np.array(Image.open(zebra_path).convert("RGB")). I'm not entirely sure the convert is necessary, but if some of the images are grayscale, this takes care of it. The horse image is the same, but with horse_path.

Then, if self.transform is set: augmentations = self.transform(image=zebra_img, image0=horse_img). We don't strictly have to do it this way, but it means we perform the same transformations on both images, the zebra and the horse. After performing the augmentation we take out zebra_img = augmentations["image"] and horse_img = augmentations["image0"] — the reason it's "image0" is that that's what I set as the additional target in the config file. Then we just return zebra_img and horse_img. Normally I would do a test case here and save some of the images, to check that nothing is wrong. I also just noticed we don't need the config import, since the transform is passed in, and it doesn't look like we need torch either, so we can remove those. And that's what the dataset looks like.
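Here's the dataset as a sketch, together with the albumentations pipeline kept in the config file — the additional_targets entry is what makes the image0 keyword work, so both images get the same augmentation:

    import os
    import numpy as np
    from PIL import Image
    from torch.utils.data import Dataset
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    class HorseZebraDataset(Dataset):
        def __init__(self, root_zebra, root_horse, transform=None):
            self.root_zebra = root_zebra
            self.root_horse = root_horse
            self.transform = transform
            self.zebra_images = os.listdir(root_zebra)
            self.horse_images = os.listdir(root_horse)
            self.zebra_len = len(self.zebra_images)
            self.horse_len = len(self.horse_images)
            # the two folders differ in size, so the dataset length is the max
            self.length_dataset = max(self.zebra_len, self.horse_len)

        def __len__(self):
            return self.length_dataset

        def __getitem__(self, index):
            # modulo wraps the index back into the smaller folder's range
            zebra_file = self.zebra_images[index % self.zebra_len]
            horse_file = self.horse_images[index % self.horse_len]
            zebra = np.array(Image.open(os.path.join(self.root_zebra, zebra_file)).convert("RGB"))
            horse = np.array(Image.open(os.path.join(self.root_horse, horse_file)).convert("RGB"))
            if self.transform:
                aug = self.transform(image=zebra, image0=horse)
                zebra, horse = aug["image"], aug["image0"]
            return zebra, horse

    # the transforms kept in the config file
    transforms = A.Compose(
        [
            A.Resize(width=256, height=256),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
            ToTensorV2(),
        ],
        additional_targets={"image0": "image"},
    )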
Alright — so we've now done pretty much all of the difficult parts. Hopefully it didn't feel too quick; let me know if you think I should make it a little slower in the future, or if this is a good tempo. I guess you can always rewatch it if you want to.

For the train file, I'm just going to copy-paste the imports: torch, the dataset, save_checkpoint and load_checkpoint from utils; we won't use torchvision transforms here, so we can remove that; the DataLoader; we won't use ImageFolder, I think; and then torch.nn as nn, optim, config, tqdm for a progress bar, save_image to look at the outputs and see if they're okay, and the Discriminator and the Generator. The setup, as usual, is a train function and a main function, and then if __name__ == "__main__" we run main.

This could be one of the trickiest parts, because here we do all of the loss functions and such. I'm going to paste this part in — it makes things a little quicker, since there's really nothing difficult here — and then go through it with you so we understand everything. I'm initializing the discriminators we created: disc_H is for classifying images of horses, that's the H, and disc_Z is for discriminating whether an image is a real zebra or a fake zebra. Then I'm initializing gen_Z, which takes in an image and tries to generate a zebra, and gen_H, which takes in an image of a zebra and tries to generate a horse. Then we initialize the discriminator optimizer, sending in both discriminators' parameters: we take list(disc_H.parameters()) and concatenate the list of disc_Z's parameters onto it, and we use config.LEARNING_RATE, which we set to 2e-4.
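Concretely, the two optimizers might look like this sketch — each Adam gets a concatenated parameter list, so two optimizers cover all four networks (the betas are the paper's values, discussed next; the config names are the ones from the config file):

    opt_disc = torch.optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=config.LEARNING_RATE,  # 2e-4
        betas=(0.5, 0.999),
    )
    opt_gen = torch.optim.Adam(
        list(gen_Z.parameters()) + list(gen_H.parameters()),
        lr=config.LEARNING_RATE,
        betas=(0.5, 0.999),
    )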
For the beta terms, the paper specifies 0.5 for the momentum term and 0.999 for beta2. We do the same thing for the generator optimizer, adding the parameters together just as we did for the discriminators — this way we don't need four optimizers. Same config learning rate, same beta values. For the losses, we use an L1 loss for the cycle consistency loss and also for the identity loss, and for the adversarial loss we use a mean squared error. Then I'm checking whether we should load a checkpoint, and if so we load both of the generators and both of the critics — actually, I'm not sure why I called them critics here; that terminology comes from WGAN, which is kind of confusing, because in this case they're discriminators.

Then we create the dataset: we send in root_horse, which we can do as config.TRAIN_DIR plus the horses folder, and root_zebra as config.TRAIN_DIR plus the zebras folder (I'll double-check these folder names in a moment), and for the transform we use config.transforms, which is the albumentations pipeline. You could do the same thing for a val dataset, which is important for evaluating the model. For the loader we use DataLoader, sending in the dataset, batch size, shuffle, num_workers, pin_memory — standard stuff. We also define a g_scaler and a d_scaler; these are for float16 training, so you could remove them to run in float32, but it's always nice to run in float16. Then we run the train function — which we haven't created yet, and which is, I guess, the genuinely difficult part: getting all the losses right, so we'll do it step by step — and each epoch I save a checkpoint for each of the four networks.

Hopefully you see that there was nothing really difficult there. Now, the train function. Into it we send disc_H, disc_Z, gen_Z, and gen_H, plus the loader, opt_disc, opt_gen, the L1 and MSE losses, and d_scaler and g_scaler. First we create a loop: loop = tqdm(loader, leave=True) — just to get a nice progress bar so we can see what's going on. Then the actual training loop: for idx, (zebra, horse) in enumerate(loop) — that's the order we returned them from the dataset, zebra then horse. We move them to the device: zebra = zebra.to(config.DEVICE), and the same for horse.

First we train the discriminators, H and Z. We run everything under with torch.cuda.amp.autocast(), which is necessary for float16. The first thing we do is generate a fake horse: fake_horse = gen_H(zebra). Then D_H_real = disc_H(horse) — that's the real one — and for the fake one, D_H_fake = disc_H(fake_horse.detach()). We detach here because we're going to reuse this fake_horse later when we train the generator, and detaching means we don't have to repeat that line in the generator step.
Now we compute D_H_real_loss as the MSE of D_H_real against what it should be — real is one, fake is zero — so that's mse(D_H_real, torch.ones_like(D_H_real)); and pretty much the opposite for D_H_fake_loss: mse(D_H_fake, torch.zeros_like(D_H_fake)). That's it for the first discriminator: D_H_loss = D_H_real_loss + D_H_fake_loss, just adding those together. Then we can copy-paste all of this and do the opposite side: fake_zebra = gen_Z(horse); we send the real zebra into disc_Z to get D_Z_real and the fake zebra to get D_Z_fake; then D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real)), and the opposite for D_Z_fake_loss; and we add those two together into D_Z_loss.

To put it all together, the discriminator loss D_loss is D_H_loss plus D_Z_loss, and I believe they divide it by two — I think they mention that in the paper. I don't think it makes any real difference, but just to be sure we can do it. Then opt_disc.zero_grad(), d_scaler.scale(D_loss).backward(), d_scaler.step(opt_disc), and d_scaler.update() — just what we normally do.

Alright, we're halfway there: we've set up the training for the discriminators, and now we set up the training for the generators, H and Z. We do pretty much the same thing, under with torch.cuda.amp.autocast(). First, D_H_fake = disc_H(fake_horse) and D_Z_fake = disc_Z(fake_zebra) — the same computation as before, the only difference being that there we detached and here we don't. What the generator wants to do is fool the discriminator, so: loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake)) — we want to trick the discriminator into believing it's a real one, and real is one — and loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake)). (We could reorder these so H comes first.) That's the adversarial loss for both generators — let me label it that, to be clear.

Then we do the cycle loss, then the identity loss, and then we add everything together. The cycle loss: if we take this fake horse and try to generate a zebra from it — cycle_zebra = gen_Z(fake_horse) — that should hopefully give us back the original zebra; and similarly in the other direction, cycle_horse = gen_H(fake_zebra). Then cycle_zebra_loss is just l1(zebra, cycle_zebra), and similarly for the horse: cycle_horse_loss = l1(horse, cycle_horse). So that's the cycle loss.

For the identity loss we do identity_zebra = gen_Z(zebra) — that is, we send a zebra into the generator that is already supposed to
generate zebras — and then it shouldn't do anything, because the input is already, in a sense, a perfect image. For the identity horse we do gen_H(horse). Then identity_zebra_loss = l1(zebra, identity_zebra), and similarly for the opposite: identity_horse_loss = l1(horse, identity_horse).

Let's put this together. G_loss is loss_G_Z — the adversarial loss for the generator that generates zebras — plus loss_G_H, plus cycle_zebra_loss times config.LAMBDA_CYCLE, plus cycle_horse_loss times config.LAMBDA_CYCLE, plus identity_horse_loss times config.LAMBDA_IDENTITY, plus identity_zebra_loss times config.LAMBDA_IDENTITY. (Wait a second — I had that slightly wrong: these should be cycle_zebra_loss and cycle_horse_loss, and we need to add those in as well. It's easy to get lost in this.)

Some of you might be observant and point out that, as we saw in the paper, they didn't use the identity loss when training the generators on this particular dataset, so really this is unnecessary and we could remove those terms. The reason it doesn't matter here is that I've set the identity lambda to zero, so that part isn't used — it's unnecessary computation, and it would be best to remove it, since the code would run faster. The reason I include it is that I wanted this to be a general implementation, so that you can just change that constant; removing it is something you can do for efficiency reasons. But that's pretty much all we need there.

Then opt_gen.zero_grad(), g_scaler.scale(G_loss).backward(), g_scaler.step(opt_gen), and g_scaler.update(). And that's it. One more thing we can do: if idx % 200 == 0, we call save_image on fake_horse — multiplying by 0.5 and adding 0.5 to invert the normalization, so we get the correct coloring — saving to the saved_images folder as horse_{idx}.png, and similarly for fake_zebra with a zebra filename. So we're just saving images every now and then.

And I guess that's it — hopefully there are no errors here. I'll change LOAD_MODEL to False, and SAVE_MODEL... I was going to set that to False too, but actually we can keep saving the model; I'm just not going to load anything now. Let's try to run this and see what happens. I'm probably not going to train it for very long — just enough to see that something is happening — and then we'll load the pre-trained weights. Okay: in the dataset loading in the main file, this should be "horses" and this should be "zebras", and we also need to create that saved_images directory. Hopefully it works now. And just for fun's sake: we're at about 5.5 iterations per second — 5.45, somewhere around there. I just want to see whether there's any big difference if we remove all of the identity computation, so commenting those lines out... yes, now we're up to almost seven iterations per second.
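To have it all in one place, here's the whole training step as a sketch — assuming the config constants (DEVICE, LAMBDA_CYCLE, LAMBDA_IDENTITY) from the config file and torchvision's save_image:

    import torch
    from torchvision.utils import save_image
    from tqdm import tqdm
    import config

    def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader,
                 opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
        loop = tqdm(loader, leave=True)
        for idx, (zebra, horse) in enumerate(loop):
            zebra = zebra.to(config.DEVICE)
            horse = horse.to(config.DEVICE)

            # train discriminators H and Z
            with torch.cuda.amp.autocast():
                fake_horse = gen_H(zebra)
                D_H_real = disc_H(horse)
                D_H_fake = disc_H(fake_horse.detach())  # detach: fake_horse is reused below
                D_H_loss = (mse(D_H_real, torch.ones_like(D_H_real))
                            + mse(D_H_fake, torch.zeros_like(D_H_fake)))

                fake_zebra = gen_Z(horse)
                D_Z_real = disc_Z(zebra)
                D_Z_fake = disc_Z(fake_zebra.detach())
                D_Z_loss = (mse(D_Z_real, torch.ones_like(D_Z_real))
                            + mse(D_Z_fake, torch.zeros_like(D_Z_fake)))

                D_loss = (D_H_loss + D_Z_loss) / 2

            opt_disc.zero_grad()
            d_scaler.scale(D_loss).backward()
            d_scaler.step(opt_disc)
            d_scaler.update()

            # train generators H and Z
            with torch.cuda.amp.autocast():
                # adversarial loss: fool both discriminators
                D_H_fake = disc_H(fake_horse)
                D_Z_fake = disc_Z(fake_zebra)
                loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
                loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

                # cycle loss: translating back should recover the original
                cycle_zebra = gen_Z(fake_horse)
                cycle_horse = gen_H(fake_zebra)
                cycle_zebra_loss = l1(zebra, cycle_zebra)
                cycle_horse_loss = l1(horse, cycle_horse)

                # identity loss: a zebra fed to the zebra generator should pass through
                identity_zebra = gen_Z(zebra)
                identity_horse = gen_H(horse)
                identity_zebra_loss = l1(zebra, identity_zebra)
                identity_horse_loss = l1(horse, identity_horse)

                G_loss = (loss_G_Z + loss_G_H
                          + cycle_zebra_loss * config.LAMBDA_CYCLE
                          + cycle_horse_loss * config.LAMBDA_CYCLE
                          + identity_zebra_loss * config.LAMBDA_IDENTITY
                          + identity_horse_loss * config.LAMBDA_IDENTITY)

            opt_gen.zero_grad()
            g_scaler.scale(G_loss).backward()
            g_scaler.step(opt_gen)
            g_scaler.update()

            # occasionally save de-normalized samples to inspect
            if idx % 200 == 0:
                save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
                save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")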
Alright — I waited for about four minutes and then I got bored, but we can maybe see some things already. If we look at this one, it's starting to turn a little brown in places, and here we can see it's really doing some weird stuff to the image — similar behavior to what I got when I trained the model for longer. I think what you need to do is train this for maybe 150 epochs; you'll probably start to see some results after maybe 20 epochs or so. But I just realized that's going to take an hour, and I don't want to wait an hour, so I'm just going to copy over some trained weights, and we'll load those and see what they look like.

And I just noticed that in the implementation I trained, I didn't use that ReLU in the generator's initial block — the one I said should be there, as we saw in their implementation — so perhaps we can just remove it for now so we can run the pre-trained weights. I also think I named those layers res_blocks in that version, so I need to change that as well. Alright, I'm just going to stop it here. Here we can see the horses — it's doing some stuff that makes sense — and then for the zebras, not so much: here it's changing the road instead of the actual horse, which is pretty funny. Does that horse have sunglasses? Those eyes look kind of weird. Anyway, I might actually retrain this. Let's change it back to the way we had it: res_blocks, and the nn.Sequential with the nn.ReLU in the initial block, like that — I'm pretty sure there should be a ReLU there, and that would maybe speed up the training a little bit.

I'm trying to think whether there was anything else to add, but I think that was it — that is how you implement CycleGAN from scratch. Hopefully this video was useful to you. If you found it useful, please like the video, and that way I'll make more of these from-scratch videos. With that said, thank you so much for watching, and I hope to see you in the next video.
Info
Channel: Aladdin Persson
Views: 11,768
Id: 4LktBHGCNfw
Length: 58min 25sec (3505 seconds)
Published: Tue Mar 16 2021