How to deal with Imbalanced Datasets in PyTorch - Weighted Random Sampler Tutorial

Captions
What is going on guys, welcome back to another video. In this video we're going to look at how to deal with the very common problem of an imbalanced dataset. So let's roll that beautiful intro first, and then let's get into the video.

[Music]

All right, so as an example I have images of two different dog breeds. For the Golden Retriever class we have 50 images, and for the Swedish Elkhound we only have a single image, which is a very clear imbalance. What we want is for the network to still give roughly equal weight to both classes. From my understanding, there are two methods for dealing with imbalanced datasets. The first is oversampling, and it is pretty much exactly what it sounds like: we sample that single image more frequently, usually combined with data augmentation so the network sees slightly different versions of it each time. The second method is class weighting: when computing the loss, we give the rare class (the Swedish Elkhound in this case) a higher priority by multiplying its loss by some number. From what I've seen, oversampling seems to be the preferred method, although I haven't seen any papers comparing the two. Oversampling is the one I see most in practice and the one I'll focus on, but I'll actually show you class weighting first, because it's so much shorter and easier. For class weighting, all you do is pass a weight when you create your loss function, in this case cross-entropy.
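The effect of class weighting can be sketched in plain Python before wiring it into PyTorch. The probability values and the helper function here are made up for illustration; only the weights 1 and 50 come from the example above:

```python
import math

# Weight per class: 0 = Golden Retriever, 1 = Swedish Elkhound.
class_weights = [1.0, 50.0]

def weighted_ce(prob_of_true_class, label):
    # Plain cross-entropy for one example is -log(p of the true class);
    # class weighting just multiplies it by the weight of that class.
    return class_weights[label] * -math.log(prob_of_true_class)

# Same predicted probability, very different loss:
loss_retriever = weighted_ce(0.5, 0)   # weight 1
loss_elkhound = weighted_ce(0.5, 1)    # weight 50
```

So a mistake on the lone Elkhound image pulls on the network 50 times harder than the same mistake on a Retriever image.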
First we need our imports: import torchvision.datasets as datasets, import os, from torch.utils.data import WeightedRandomSampler, DataLoader, import torchvision.transforms as transforms, and import torch.nn as nn. Now we can create the loss, nn.CrossEntropyLoss, and all you do is send in a weight: torch.tensor([1, 50]). We have two classes; let's say the Golden Retriever is class 0 and the Swedish Elkhound is class 1, so we send in the weight for the Golden Retriever first and put it at 1. Since we have 50 times more examples of the Golden Retriever than of the Elkhound, we give the Elkhound a weight of 50. As I explained earlier, this multiplies the loss by 50 whenever we see that image of the Swedish Elkhound. That's all you need to do for class weighting; of course, if you have more classes, you would need to send in additional weights here.

Let's now move on to the preferred method, the one I more commonly see. For that, let's write a function get_loader that takes a root directory and a batch size. This will all make sense soon; I'm just writing some skeleton code right now: a main function, plus an if __name__ == '__main__' guard that runs it. Inside get_loader, the first thing is the transforms: my_transforms = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]). To load the data we'll use ImageFolder.
So we do dataset = datasets.ImageFolder, where root is the root directory we send in (in this case it's that dataset folder) and transform is my_transforms, the one we just created.

Now that we have a dataset, we're going to use the WeightedRandomSampler, and the first thing we want is some class weights: class_weights = [1, 50], similar to the weights from before. They don't have to be those exact numbers; we could use 1/50 and 1 and it would amount to the same thing, since it's only the relative weight difference that matters, but for simplicity let's use 1 and 50. Right now we're hard-coding these class weights; later I'll show you how to compute them in code, for when you have, say, over 100 classes and don't want to count the examples in each one by hand.

Then we create sample_weights, initialized as zeros times the length of our dataset. The way WeightedRandomSampler works is that we need to specify a weight for every single example in the entire dataset. So each example starts out with a sample weight of zero, and then we loop through the dataset, for index, (data, label) in enumerate(dataset), and for each example we look up the class weight for its label, class_weights[label].
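That loop can be sketched with plain lists standing in for the ImageFolder dataset. The labels list here is a made-up stand-in for the real dataset; the weights 1 and 50 are the ones from above:

```python
# Stand-in for the dataset's labels: 50 retrievers (0), one elkhound (1).
labels = [0] * 50 + [1]
class_weights = [1, 50]

# Every example starts with a sample weight of zero...
sample_weights = [0] * len(labels)

# ...then gets the weight of the class it belongs to.
for index, label in enumerate(labels):
    class_weight = class_weights[label]
    sample_weights[index] = class_weight
```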
We call that class_weight and set sample_weights[index] = class_weight for that particular sample, and that's pretty much it for the sample weights. Then we create our sampler, a WeightedRandomSampler, where we send in the sample_weights, num_samples, which is just len(sample_weights), and replacement=True or False.

When editing the video I noticed that I didn't really explain why we set replacement=True. If we set it to False, we'll only see each example once when we iterate through the entire dataset, which is obviously not what we want when we're doing oversampling. So when dealing with an imbalanced dataset using oversampling, we always want replacement=True.

Now that we've created our sampler, we create our loader. It's just a DataLoader of that dataset, like we're all used to: we pass batch_size=batch_size, which is sent into this function, and the only difference is that we also specify sampler=sampler, the weighted random sampler.

That might have been a little quick, so let's go through what's actually happening. First we create our transforms, in this case just Resize and ToTensor; in practice you would normally add some data augmentation as well. Then we load our dataset using datasets.ImageFolder, sending in the root directory, which in our case is that dataset folder.
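Why replacement=True matters can be simulated with random.choices, which also samples with replacement. The 50-to-1 label list is a made-up stand-in for the dataset:

```python
import random

random.seed(0)                           # fixed seed so the run is repeatable
labels = [0] * 50 + [1]                  # 50 retrievers, 1 elkhound
sample_weights = [1 / 50] * 50 + [1.0]   # inverse class frequency

# One "epoch" of the sampler = len(dataset) weighted draws with replacement.
indices = random.choices(range(len(labels)), weights=sample_weights,
                         k=len(labels))
elkhound_draws = sum(1 for i in indices if labels[i] == 1)
# Without replacement the lone elkhound could appear at most once per epoch;
# with replacement it is drawn roughly half the time.
```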
ImageFolder automatically handles the loading for us. Then we specify the class weights, in this case 1 and 50, because we have far fewer examples of the Elkhound class and therefore want to prioritize it much more than the first class. Then we create the sample weights, the weight for each individual example in our dataset: we initialize them to zero, go through all the examples, and assign each one the weight of the class it belongs to. Then we create our sampler from those sample weights, specifying how many samples we have and replacement=True, and finally we create our DataLoader as normal, the only difference being that we send in the sampler. I'm going through this very step by step because in the beginning it didn't feel very intuitive to me, but once you get used to it, it makes sense.

Now I want to generalize the class-weights part, because I don't want to write out every class weight by hand; that would take a while with over 100 different classes. So we create an empty list and do for root, subdirectories, files in os.walk(root_directory). If you're not familiar with os.walk, it simply walks through each subfolder in that root directory. For each folder we check whether len(files) is greater than zero, and if so we append to class_weights. We don't append the number of files itself, because then we would prioritize the classes that have more examples; instead we append 1 divided by the number of files.
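The os.walk version can be tried out on a throwaway directory tree that mimics an ImageFolder layout. The folder names and file counts here are made up to match the example:

```python
import os
import shutil
import tempfile

# Build root/<breed>/ with 50 and 1 empty placeholder files respectively.
root_dir = tempfile.mkdtemp()
for breed, n in [("golden_retriever", 50), ("swedish_elkhound", 1)]:
    os.makedirs(os.path.join(root_dir, breed))
    for i in range(n):
        open(os.path.join(root_dir, breed, f"{i}.jpg"), "w").close()

# Walk the subfolders; inverse file count becomes the class weight.
# Skipping empty folders avoids dividing by zero.
class_weights = []
for subdir, _, files in sorted(os.walk(root_dir)):
    if len(files) > 0:
        class_weights.append(1 / len(files))

shutil.rmtree(root_dir)  # clean up the throwaway tree
```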
I also added that check on len(files) because if there were no files in a subfolder, we would be dividing by zero; I guess there are other ways, but this is a simple way of dealing with that problem.

Okay, now that we've written get_loader, let's fill in the main function and make sure this works. Did I make a mistake here? Oh yeah, this should be two equal signs: if __name__ == '__main__'. Now let's do loader = get_loader with the root directory 'dataset' and a batch size of 8. To make sure it works, we can just iterate through it: for data, labels in loader: print(labels). And that should be transform rather than transforms; hopefully it runs now. All right, and as you can see, if we had not used those class weights, we would not see such a balanced set of labels. It might be difficult to count all of them by eye, but this should be balanced, and to make sure it actually is, we can do something like for epoch in range(10).
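The whole pipeline can also be exercised end to end without any image files by swapping ImageFolder for a synthetic TensorDataset; the weights, sampler, and loader are exactly the pieces built above. A sketch under that assumption:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)                           # repeatable sampling
features = torch.randn(51, 3)                  # stand-ins for 51 images
labels = torch.tensor([0] * 50 + [1])          # 50 retrievers, 1 elkhound
dataset = TensorDataset(features, labels)

class_weights = [1 / 50, 1.0]
sample_weights = [class_weights[int(label)] for label in labels]
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

# Count class occurrences over one pass through the loader.
num_retrievers = 0
num_elkhounds = 0
for _, batch_labels in loader:
    num_retrievers += int((batch_labels == 0).sum())
    num_elkhounds += int((batch_labels == 1).sum())
```

Rerunning with a different seed gives different counts, but both classes should come out roughly balanced instead of 50 to 1.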
We can go through the dataset and count the number of retrievers and the number of Elkhounds: num_retrievers += torch.sum(labels == 0), and similarly num_elkhounds += torch.sum(labels == 1), and at the end we print both. Let's see what that looks like. All right, so there's some randomness in exactly which numbers come up, and if we rerun it we'll probably see a different result, but as we can see they're at least relatively balanced, much more balanced than they were in the beginning.

So that was it for dealing with imbalanced datasets; hopefully this is useful to you. If it is, then please do subscribe to the channel, because that helps a lot. Damn, I feel like a sellout when I'm saying that, but anyways, thank you for watching the video, and I hope to see you in the next one.
Info
Channel: Aladdin Persson
Views: 7,951
Keywords: imbalanced dataset pytorch, class weighting dataset, oversampling dataset, pytorch weighted random sampler
Id: 4JFVhJyTZ44
Length: 15min 25sec (925 seconds)
Published: Fri Jan 08 2021