PyTorch Image Segmentation Tutorial with U-NET: everything from scratch baby

Captions
What is going on, guys? Hope you're doing awesome. In this video we will go through how to do image segmentation in PyTorch, and this means that we will set up everything from scratch. We're going to build a model that is very similar to U-Net, and then we're going to set up our data loading pipeline, which will include data augmentation for segmentation tasks; for that we will be using the Albumentations library. If you're new to U-Net or to the Albumentations library, I have separate videos for both of those which you could check out. Lastly, we will train this model on the Carvana dataset, which is from a Kaggle competition from a few years back. All right, let's roll that intro and then let's get started with the model. [Music]

All right, so let's just create a model.py file. Before I start, let me talk a little bit about the differences between our implementation and the paper. Here's an image that I went through in the paper review video. The major difference is that we're going to use padded convolutions, because that simplifies a lot. Maybe not for the actual model implementation, where it's pretty trivial to change the padding to zero and get a valid convolution, but the data loading part becomes much more difficult if we use valid convolutions. You would need to implement mirror padding, which might not be too challenging, but with larger images you would have to do a sort of multi-crop of your image and then stitch the results together afterwards. It just makes things a lot more complicated. And when you think about the fact that the winners of the Carvana Kaggle competition, for example, used same convolutions in their
implementation, and I think the prize was thirty thousand dollars or something like that, you can only imagine that if their accuracy would have improved using valid convolutions then that's exactly what they would have done. So we're not losing too much performance here. All right, that's enough rambling about the architecture, let's just get started.

We're going to start with import torch, import torch.nn as nn, and we're also going to use import torchvision.transforms.functional as TF. Looking at the architecture, you can see that in all of these steps they use two convolutions in between: they have some input, they run it through two three-by-three convs, and then they do the same here, here, and here, and similarly on the opposite side. So one idea is that we can create a class, which we'll call DoubleConv, that inherits from nn.Module. We'll create our init function, which takes in some in channels and some out channels, and first of all we'll call super(DoubleConv, self).__init__().

All we'll do here is set self.conv to an nn.Sequential, which will use an nn.Conv2d with in channels and out channels. We'll set the kernel size to 3, the stride to 1, and the padding to 1, which makes it a same convolution. I'm assuming you know what a same convolution is, but it just means that the input height and width are going to be the same after the convolution. I'm also going to set bias equals False here, and the reason for that is that I'm going to use batch norm. I think the original paper was from 2015, and the batch norm paper came out that same year, and they didn't use batch norm in the original, but I figured it would only be beneficial, so let's use it. So we're using bias equals
False just because we're using batch norm: the bias is an unnecessary parameter because it's going to be cancelled by the batch norm. That's just a detail. Then we'll use a ReLU with inplace set to True, and we'll just copy-paste this and go from out channels to out channels for the second conv, and that is it. Forward is going to take some input x, and we'll just return self.conv(x).

All right, so what did we do? We created the steps that sit in between these. What we want next is a pool, a max pool with stride 2 and kernel size 2, and then again these double convs in between. So we'll create our class UNET, which inherits from nn.Module again, and create our init function. We're going to specify some in channels; in our case that's going to be 3, so let's set the default to 3. The out channels is going to be one. In the paper they had out channels set to two, and I guess this is a parameter you can choose, but we're going to do binary image segmentation. Of course you're going to be able to extend this to several classes, but if we're doing binary we can just output a single channel, so let's set out channels to one as the default. Then we're going to specify some features, which are the channel counts shown in the figure: 64, 128, 256, 512. They're going to be exactly the same on the way up, so we'll just set 64, 128, 256, and 512.
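The double-convolution block described above can be sketched roughly as follows. This is a minimal reconstruction from the spoken description, not the exact code on screen; the class name DoubleConv and the argument order are taken from the narration, and the docstring and comments are added for illustration.

```python
import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    """Two 3x3 same convolutions, each followed by batch norm and ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            # kernel 3, stride 1, padding 1 keeps height and width unchanged;
            # bias is dropped because the following batch norm has its own shift
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)
```

Passing a batch of shape (2, 3, 32, 32) through DoubleConv(3, 64) should give (2, 64, 32, 32): only the channel count changes, which is what makes the later skip connections line up.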
All right, so what we'll do next is call super, can't forget that part: super(UNET, self).__init__(). Then we'll create two lists, self.ups and self.downs. You can't just use a regular Python list here, because we're going to be storing the convolutional layers and all that stuff, so what we've got to do is use an nn.ModuleList. Why this is important is that we want to be able to do things like model.eval() for the batch norm layers and so on, and that's why it's important to have this module list. So we'll have one for ups and one for downs, and then we'll have a pooling layer: nn.MaxPool2d with kernel size 2 and stride 2.

Let's do the down part of U-Net first. For the down part we'll go through the features, so for feature in features, and we'll do self.downs.append, so we add a layer to this module list. It's going to be a DoubleConv, and since the DoubleConv we just created takes some in channels and some out channels, we'll send in in channels initially and then feature. If you look at the architecture, it maps some input, in their case one channel, to 64. That's exactly what we have: in the beginning we have three here, and it's going to map three to 64.
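To make the point about nn.ModuleList concrete, here is a small sketch comparing it with a plain Python list. The two toy classes (PlainListNet, ModuleListNet) are hypothetical and exist only for this comparison; they are not part of the video's model.

```python
import torch.nn as nn


class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Regular Python list: the layers inside are not registered as submodules
        self.layers = [nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)]


class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers every layer with the parent module
        self.layers = nn.ModuleList([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)])


plain, registered = PlainListNet(), ModuleListNet()

# Parameters visible to an optimizer via .parameters()
num_plain = len(list(plain.parameters()))            # 0
num_registered = len(list(registered.parameters()))  # 4

# .eval() only reaches registered submodules
plain.eval()
registered.eval()
plain_bn_training = plain.layers[1].training            # still True
registered_bn_training = registered.layers[1].training  # False
```

With the plain list, the optimizer would see no parameters and the batch norm layer would stay in training mode after model.eval(), which is the failure the narration is warning about.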
For the next feature we need to make sure that we set in channels to feature. This is going to go through the loop for all the values we have in features and create all of those conv layers. We're going to use the pooling layer in between in our forward method, but you'll see that pretty soon.

So what's next? We've got to create the up part. If we look at the figure, what we just did is create all of these steps on the way down, and the one at the bottom we'll create outside of the for loop; we'll do that pretty soon. For the upsampling part we're going to be using transposed convolutions. What you could also do here is use bilinear upsampling and then a conv layer afterwards. That's what people have started doing, or rather they've done it for a long time; if you look at GANs, for example, in ProGAN they realized that transposed convolutions create these artifacts and they started to use that instead. So I would imagine that would be a better option, and probably a cheaper one, but we're not going to do that; we're going to use a similar approach to the paper.

What we'll do is go through the features, and what's important here is that we're going from the bottom up, so we'll do for feature in reversed(features). In our nn.ConvTranspose2d we'll set the in channels to feature times two, and the reason for this is that we're going to be adding the skip connection, so it's going to be 512 times 2, and that's going to be the case for all of the transposed convs. We're going to concatenate this gray skip connection along the channel dimension for all of the up layers. The output is then just going to be feature: as you can see, here it's 512, then 256, and then 128.
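The channel bookkeeping for one upsampling step can be sketched like this. It assumes a kernel size and stride of 2 for the transposed convolution, and the bilinear alternative uses a 1x1 convolution purely as an illustrative choice (the video does not specify one); the tensor sizes are made up for the example.

```python
import torch
import torch.nn as nn

feature = 512
x = torch.randn(1, feature * 2, 10, 10)  # e.g. 1024 channels coming from below
skip = torch.randn(1, feature, 20, 20)   # matching skip connection from the down path

# Approach used in the video: transposed convolution, 1024 -> 512 channels
up_transpose = nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)

# Alternative mentioned in passing: bilinear upsampling followed by a convolution
up_bilinear = nn.Sequential(
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
    nn.Conv2d(feature * 2, feature, kernel_size=1),  # kernel size is an assumption
)

out_a = up_transpose(x)  # (1, 512, 20, 20)
out_b = up_bilinear(x)   # (1, 512, 20, 20)

# Concatenating with the skip connection brings the channels back to 1024,
# which is why the DoubleConv that follows takes feature * 2 input channels
merged = torch.cat((skip, out_a), dim=1)  # (1, 1024, 20, 20)
```

Both variants double the height and width and halve the channel count, so either could feed the same concatenation step.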
So we'll set the out channels to feature; the kernel size is going to be two and the stride is going to be two, so this will double the height and the width of the image. Next we'll also append a DoubleConv, from feature times two to feature. Why we're doing that is that we're going to go up and then do two convs, up and two convs, up and two convs, so that's why we're appending this right here.

What's missing right now is the part at the bottom, which I'll call a bottleneck layer. We'll do self.bottleneck equals DoubleConv of features[-1]. The reason for features[-1] is that we want this 512, which is the last element in our features list, and we're going to map that to features[-1] times two, so it's going to be 1024. Then the absolute last thing we're going to do, at the top, is a final conv, which is going to be a one-by-one conv. It doesn't change the height and the width of the image; all it does is change the number of channels. So self.final_conv is just going to be an nn.Conv2d from features[0] to out channels, with a kernel size of one.

All right, hopefully I didn't go through this too fast. That's just the init part; now we've created all of the modules that we're going to need for the forward step. So let's define forward(self, x). I'm going to create a list called skip_connections and, you guessed it, we're going to store all of the skip connections in that one. We're going to do for down in self.downs, since the downward part is what we do first, and for the downward part we just do x equals down(x). Then we add the skip
connection right before the downsampling. So we do these two convs, save the output as a skip connection, which we're going to use later on, and we do that for every level, so skip_connections.append(x). The ordering is important to know here: the first is going to be the one with the highest resolution and the last will be the one with the smallest resolution. What we'll do next is x equals self.pool(x), and that's it for the downward part.

Then we'll go through the bottleneck layer. What we've done in this for loop is all of these steps until we're right here at the bottom: we've done the pooling and we just want to do this double conv, and that double conv is self.bottleneck.

One thing we're going to do to make the skip connections a little bit easier: we want to go backwards through them when we're doing the concatenation, and as I said, the first element is the one with the highest resolution, so I'll just reverse that list. So skip_connections equals skip_connections reversed. Then we'll go through self.ups. In self.ups we added a ConvTranspose2d and we also added the DoubleConv, so what I'll do here is for index in range(0, len(self.ups), 2), with a step of two. The reason for that is that I want to do the up and then the double conv, and that's going to be a single step; that's why we're using a step of two here. Again, there are probably better ways of doing this, let me know if you think of one. So we'll do x equals self.ups[index](x), with index zero in the beginning, and what we're doing here is the ConvTranspose2d. Then we'll add the skip connection, so we'll do skip
connection equals skip_connections[index // 2], with integer division by two. Why we're doing that is just because we're stepping by two in the loop, and we still want to take the skip connections in a linear, step-of-one ordering. That might be a little confusing, but it's nothing too difficult. Then we'll concatenate them: concat_skip equals torch.cat of the skip connection and x, and we'll join them along the channel dimension. We have batch, channel, height, width, so that's dimension one. Lastly we'll do x equals self.ups[index + 1](concat_skip), and that's it. So what we're doing here is the upsampling, then we're getting the skip connection, then we're concatenating, and then we're running it through the double conv. The last thing we're going to do is return self.final_conv(x).

I think we're done, but one improvement we could make is for the case where you're not inputting something that is perfectly divisible by two at every step, which matters for the self.pool layer. To give a concrete example, let's say we input 161 by 161.
What's going to happen here is that the max pool is going to take this to 80 by 80, so it floors the division by two. We then have 80 by 80, and when we upsample, the output will be 160 by 160, meaning we won't be able to concatenate the two, because they need to have the same height and width. That's a problem. You could solve it by always choosing an input that is divisible by 16, because there are four downsampling steps and each one divides by two. That's one solution, but of course we want to make our implementation general.

What we can do is check when we're concatenating whether they match, so if x.shape is not equal to skip_connection.shape. You can solve this with different methods. In the paper, which is a little bit different, they used cropping. You could also add padding. What I did is a resizing, so I'll use TF, the torchvision functional module, and do TF.resize of x and set the size. If you think about it, the x coming from the upward part is always going to be the smaller one compared with what was downsampled, because the max pool always floors the shapes, so it's x that we resize. I don't think this is going to impact the accuracy too much, I hope, because it's just a one-pixel difference, although I haven't tried this extensively. For the size we use skip_connection.shape and take out just the height and the width, so we're skipping the batch size and the number of channels; that's all we need for the resizing. All right, so that was, you know,
U-Net in, I guess, 67 lines of code. Not too bad, and hopefully it's still understandable and readable. So we'll just create a test right here: def test, x equals torch.randn, say with a batch size of three, one channel, and then 160 by 160, which is perfectly divisible. Then we'll do model equals UNET with in channels set to 1 and out channels set to 1, and preds equals model(x). What we want to make sure is that the output has the exact same shape as the input, so we'll print preds.shape, print x.shape, and maybe also assert preds.shape equals x.shape. We'll do if __name__ equals "__main__" and run the test there. All right, pretty good, let's see if we can run this now. We get the exact same shape and no errors, that's amazing, I always make mistakes.

What we're going to do next is look at the data and set up the data loading part, and then we'll set up the training. So let's create a file called dataset. This is code that I've written beforehand, and what I've got to do first is copy the data into our folder, so I've just put the data right here; you can download this from Kaggle. All I've done is take the training images, which look like this, let's see if I can pull this up, so that's what they look like. What they want is the segmentation of the car, so the training masks look like this, which is binary, just black and white. Then I took a couple of examples, 48 to be specific, and moved them to a very small validation set. So that's what the dataset looks like. You can download it from Kaggle, there's going to be a link in the description. So what we want to
do now is set up the data loading part. For that we're going to import os, then from PIL import Image, then from torch.utils.data import Dataset, and import numpy as np.

I just thought of one thing we could do on the model, actually. In the test I chose 160, which is perfectly divisible; let's change it to 161 and make sure that it still works. And yes, there's no error, so the assertion passed, which means that our resizing also worked.

Anyway, moving on. We're going to create a class called CarvanaDataset that inherits from Dataset. Let's create the init method: we'll take in an image directory and a mask directory, and we'll also set transform to None by default. Then we'll do self.image_dir equals image dir, copy-paste that two times, and do the same thing for the mask directory and the transform. We'll also set self.images to os.listdir of the image directory, which lists all the files in that folder. The length of the dataset is pretty simple: we'll just return the length of self.images.

For our __getitem__ we take some index as input. We'll first get the image path, which is os.path.join of the image directory, where we stored the images, and the file name of that particular image, so self.images[index]. Hopefully you're following so far. The mask path is pretty much the same thing, os.path.join but with self.mask_dir and self.images[index]. The only difference, which I didn't show you exactly, is that the train masks have the same name but with an underscore mask at the end, and instead of being, I think, a JPEG or PNG, they are actually a .gif, so we
have to adapt to that. We'll just do self.images[index].replace, replacing .jpg with _mask.gif. Now that we have the mask path and the image path, we'll load those two. Image equals np.array of Image.open of the image path, and we'll also do .convert("RGB") here; I think it's RGB by default so we might not have to, but just to make sure. The reason we're converting to NumPy is that we're going to be using the Albumentations library, which needs a NumPy array if you're loading with PIL. Then mask equals np.array of Image.open of the mask path with .convert("L"), and the reason we use "L" is that the mask is going to be grayscale, so that's how you do that in PIL. One more thing: we'll specify the dtype to be np.float32.

The mask is going to be loaded with most values being 0 and some being 255 for the white parts. So we'll do a bit of preprocessing for the mask: we'll look for where the mask is equal to 255.0 and change that to 1. The reason is that we're going to use a sigmoid as our last activation, indicating the probability that a pixel is white, so to make the labels correct for that we convert those values to one.

Then if self.transform is not None, we'll perform the data augmentation using the Albumentations library. So augmentations equals self.transform of image equals image and mask equals mask. You should be pretty familiar with this if you watched that video; otherwise, we're just sending in the image and the mask. Then to obtain the image we'll get the augmented image
from the dictionary with the key image, and similarly for the mask we just change the key to mask. Pretty simple, to be honest. Then we return image and mask. All right, that's it for the dataset, so what we've done so far is create the model and the dataset.

What we'll move on to now is the training part. Let's create a new file called train, and I'm going to copy-paste in the imports here. We're going to use torch, albumentations as A, ToTensorV2, tqdm for the progress bar, and then we're going to import our model as well. We're also going to create a utils file from which we'll import load checkpoint, save checkpoint, get loaders, check accuracy, and save predictions as images. We'll create those later on, so I'll comment that import out for now.

For the hyperparameters, I'm going to copy-paste those in as well. I'm setting the learning rate to 1e-4, the device to cuda if it's available and otherwise the cpu, and the batch size to 32, although I might have to decrease that now when I'm recording. Number of epochs, let's set it to three, and num workers to two. For the image height and width I'm going to set something very small; the images are originally 1280 high and 1918 wide, so we're only using a very small fraction of that, but you would just change this if you were doing the competition and wanted to obtain a better score. Pin memory is True, load model we'll set to False to begin with, and then we specify the paths where we have the dataset, that is the training images and the validation images.

All right, I'll show you the general structure that we're going to be creating. The
general structure will have a train function where we send in a loader, a model, an optimizer, a loss function, and also a scaler, which I will go through. So that's the train function. We'll have a main function, and then we'll check if __name__ equals "__main__". Why this is important is that on Windows you need it so you don't get any issues when you're running with num workers.

We could start with the train function. The train function is going to do one epoch of training, and we'll use tqdm here for a progress bar, so loop equals tqdm(loader). If you're not familiar with this you don't have to use it, but I have a separate video on tqdm for progress bars as well. We'll go through for batch index and then data and targets in enumerate(loop), and move the data to the GPU with data.to(device=device). Same for the targets, and one more thing we need to do for the targets is convert them to float. They should be float already, but for some reason they might not be; I don't remember exactly, but I think it's important for the binary cross-entropy loss that we're going to be using, and I think I got an error for that at some point. So I'm converting to float, then doing unsqueeze(1) to add a channel dimension, and then .to(device=device).

For the forward pass we'll use float16 training, which I have a separate video on. Float16 reduces our VRAM usage and speeds up training, and it's not too difficult to do either. We'll do with torch.cuda.amp.autocast(), get the predictions with model(data), and compute the loss by sending the predictions and targets into the loss function. For the backward pass we'll first do optimizer.zero_grad() to zero all the gradients from the previous step, then scaler.scale(loss).backward(), then scaler.step(optimizer), and then
scaler.update(). Then we can update the tqdm loop to show the loss so far, so loop.set_postfix(loss=loss.item()). That's it for training one epoch.

In our main function I'll copy in the two transforms, just because I covered them in the Albumentations video. All I'm doing here is a resize to image height and image width, which we wrote at the top, then a rotation, horizontal flip, vertical flip, and normalize, where we're actually just dividing by 255 so we get values between zero and one, and then ToTensorV2. The validation transforms are similar, although there we just resize and normalize. That's all we're doing here, very similar to PyTorch transforms, nothing really difficult.

Next we'll create our model, which is going to come from the UNET we created. In channels is going to be three, out channels is going to be one, and we'll move it to the device. The loss function is going to be nn.BCEWithLogitsLoss, so binary cross-entropy, and the reason we're using the with-logits version is that we're not doing a sigmoid on the model's output. You could use plain BCE loss instead if you did torch.sigmoid on the output of the model. What would you do if you wanted to extend this to multi-class segmentation? Let's say you had three different classes: you would change out channels to three and change the loss function to a cross-entropy loss, and that's all. So this is very easily extendable to multiple classes. Let's leave it as binary for now.

Then our optimizer: we're just going to use Adam with model.parameters() and lr equals learning rate. To get our loaders, obviously we could create the CarvanaDataset and then a DataLoader from that right here, but I
think it looked a little bit ugly and took up too much space in the main function, so we're just going to do train loader, val loader equals get loaders. Here we send in the train image directory, the train mask directory, the validation directories, the batch size, the train transforms, and the val transforms, and we also need to send in num workers and pin memory. We'll go and create that in the utils file pretty soon, but we could finish this part first. We'll set scaler equals torch.cuda.amp.GradScaler(), and then for epoch in range of num epochs we call our train function, sending in the train loader, the model, the optimizer, the loss function, and the scaler. Nothing too difficult here. Later on we'll also save the model, check the accuracy, and print some examples to a folder so we can see that it actually looks good.

But first let's create the utils file. I'll copy-paste in the imports, so torch, torchvision, CarvanaDataset, and then DataLoader. I'll paste in two functions, save checkpoint and load checkpoint; again, I have separate videos on those if you want to check them out. I'll also copy-paste in get loaders and go through it, since I think it's honestly pretty simple. All we're doing here is creating the Carvana dataset and specifying the image directory, the mask directory, and the transforms. For the train loader we specify the training dataset, batch size, num workers, pin memory, and shuffle. Similar thing for the validation. The difficult part here was actually creating that Carvana
dataset, which we did before. All we've got to do then is create something similar for the validation loader, which we won't shuffle, and return the train loader and the val loader.

We could do check accuracy next; that might be a little bit more difficult. We'll send in a loader, a model, and a device with cuda as the default. First we set num correct to zero and num pixels to zero. Remember that for segmentation we're outputting a prediction, a class, for each individual pixel. We'll also do model.eval(), wrap everything in torch.no_grad(), and go through the loader: for x, y in loader, we move x and y to the device and run x through the model to get predictions, which is torch.sigmoid(model(x)). Then we convert all values higher than 0.5 to one and everything below that to zero. Obviously you would need to adapt check accuracy if you had more classes, which I don't think would be too difficult, but just know that this is for binary. So preds equals (preds > 0.5).float(). Num correct gets plus-equals (preds == y).sum(), which sums over all of the pixels, and num pixels gets plus-equals torch.numel(preds). Then we print an f-string: got num correct out of num pixels, with accuracy num correct divided by num pixels times 100, formatted with two decimals. So that's it for the accuracy.

One thing I actually want to add here as well, and if you don't follow this it doesn't matter too much: one thing that's kind of flawed when measuring the
accuracy is that if we just output all black pixels it will have an accuracy of greater than 80 percent so they found that there are better metrics for measuring this you know similar to object detection using intersection over union and so on it's a better metric they also found that there's a better metric for this which is using a dice score and this is just one line so i'm just gonna add it but what it is is it's gonna be two times the pixels where they're both outputting a white pixel so we're gonna just element wise multiply them so where they are both one the product is going to be one and we're gonna sum those so we're summing the number of pixels where they are both outputting a white pixel then we'll divide that by predictions plus y dot sum so we're dividing by the number of white pixels in the predictions plus the number of white pixels in the target and we'll just add an epsilon of 1e-8 yeah so actually i don't know too much about dice score this is just for binary so it's pretty simple here you can probably google and find out how you do it for multiple classes and so we'll just use this for now just knowing that it's a little bit of a better measurement to evaluate how good it is but if you don't follow this don't bother too much about that all right and then we'll just print dice score we'll do dice score and then divide that by the length of the loader all right and then we'll set model.train again all right so that's the check accuracy for the save predictions as images that i also imported i'll copy paste that in and so what we're doing here is we're just going through the loader and then saving the prediction and the corresponding correct one associated with that one it doesn't matter too much here this is just so that we can get a visualization of what it's doing all right that was probably too fast but yeah what we'll do now is
we'll just remove from utils we'll just import all of that stuff and yeah so now right here we can save the model and we can do all of those things that we wanted to do so for saving the model we'll create a checkpoint we'll just do state dict as model.state_dict and we'll also store optimizer as optimizer.state_dict then save checkpoint of checkpoint from the utils file and then check accuracy we'll just do check accuracy of validation loader send in the model send in the device and then we'll print some examples to a folder so let's do save predictions as images we'll send in the val loader model the folder we'll set to saved images and then the device is just device so what we'll do here is we'll just create a folder right there a directory called saved images and yeah that should be it for this one and there are probably going to be some errors but you know that's how it is so let's try and run this now wow it actually worked on the first try i mean i haven't made a single mistake this entire video that's insane yeah so we're not using the load checkpoint i guess what we can do in the meanwhile is also let's see right here maybe we'll do if load model we'll do load checkpoint of torch.load of my checkpoint that's just what i call the file mycheckpoint.pth.tar and then model yeah i'll just let it run and we'll see how good it is at the end of training all right so i also noticed now after one epoch the accuracy stuff is actually wrong and i wonder what we did wrong there exactly i think we missed one thing which is that we gotta do y dot unsqueeze one right there and the reason why we have to do that is because the label doesn't have a channel because it's grayscale so we just have to do that unsqueeze and then the accuracy should work so now we're just outputting rubbish but the dice score should be okay so what we'll do is we'll just let this train until it's done and then we can rerun it and evaluate and see what
we get for the accuracy and the dice score all right so now that it's trained three epochs what we'll do is we'll just copy this check accuracy and put it just after we load the model and then i guess we'll go to load model and we'll just change that to true and rerun it and see what we get all right so we see that we get an accuracy of 99.52 and we've just trained three epochs we get a dice score of 98.7 and you can just train this for longer and it would definitely improve then you would change the image height and image width make them larger and that's probably what i would do and then perhaps consider adding more data augmentation but yeah so i mean this is for sure working let's see if we can look at some saved images so here are some predictions let's put them side by side yeah all right so here it is so on the right we have the correct target segmentations on the left we have the predictions for the corresponding targets as you can see it's obviously very close to being correct which is why it also has that good of a score i think that what the winners of that kaggle competition had was 99.7 in dice score so obviously we're pretty far off but yeah and also we're looking at smaller images what you would do is you would want it for the original images and then you would just do a resizing and i guess the resizing has to be done with a nearest interpolation but yeah so what we did was a lot in this video it's probably going to be quite long but if you consider all that we did i think it was pretty compact we built the entire structure from scratch the model the data set loading part all of that stuff and we trained it as well all right that's it thank you so much for watching the video like the video if you thought it was good and yeah see you in the next one
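The check accuracy function described in the video, including the unsqueeze fix and the binary dice score, could look roughly like the sketch below. It follows the spoken description rather than the author's actual file, so treat the details as assumptions: the device defaults to cpu here (the video uses cuda) so it runs anywhere, the function returns its values instead of only printing them, and the toy loader and nn.Identity model at the bottom are hypothetical stand-ins used only to exercise the function.

```python
import torch
import torch.nn as nn


def check_accuracy(loader, model, device="cpu"):
    """Return (pixel accuracy in percent, mean dice score) for binary masks."""
    num_correct = 0
    num_pixels = 0
    dice_score = 0.0
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            # Masks come as (N, H, W); add a channel dimension so they
            # line up with the (N, 1, H, W) model output.
            y = y.to(device).unsqueeze(1)
            preds = torch.sigmoid(model(x))
            preds = (preds > 0.5).float()  # threshold to 0/1
            num_correct += (preds == y).sum().item()
            num_pixels += torch.numel(preds)
            # Binary dice: 2 * |pred AND target| / (|pred| + |target|)
            intersection = (preds * y).sum()
            dice_score += (2 * intersection / ((preds + y).sum() + 1e-8)).item()
    model.train()
    accuracy = num_correct / num_pixels * 100
    mean_dice = dice_score / len(loader)
    return accuracy, mean_dice


# Hypothetical toy data: the "model" is an identity, so x already holds logits.
x = torch.tensor([[[[5.0, -5.0], [5.0, -5.0]]]])  # shape (1, 1, 2, 2)
y = torch.tensor([[[1.0, 0.0], [1.0, 1.0]]])      # shape (1, 2, 2), no channel
loader = [(x, y)]
model = nn.Identity()

acc, dice = check_accuracy(loader, model)
print(f"accuracy {acc:.2f}, dice score {dice:.4f}")  # accuracy 75.00, dice score 0.8000
```

In the toy batch three of the four pixels match, while only two of the three white target pixels are predicted, which is why the dice score (0.8) differs from the plain accuracy (75 percent).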
Info
Channel: Aladdin Persson
Views: 46,493
Keywords: Pytorch image Segmentation, Pytorch unet from scratch, Pytorch semantic segmentation
Id: IHq1t7NxS8k
Length: 51min 53sec (3113 seconds)
Published: Tue Feb 02 2021