How To Deal With Imbalanced Datasets In PyTorch – Weighted Random Sampler Tutorial

Aladdin Persson

What is going on, guys? Welcome back to another video. In this video we're going to be looking at how to deal with the very common problem of an imbalanced dataset. So let's roll that beautiful intro first, and then let's get into the video. [MUSIC]

As an example, I have images of two different dog breeds. For the golden retriever, I'll show you one image, and for that class we have 50 images, so 50 golden retrievers. For the Swedish Elkhound we only have one image, which is a very clear imbalance. What we want is for the network to still give equal weight to both classes. From my understanding, there are two methods for dealing with imbalanced datasets. The first one is oversampling, and it is exactly what it sounds like: we will oversample, in this case, that single image, performing different data augmentations and so on, so the network sees that example much more frequently than the others. The other method is class weighting, which means that when we compute the loss and the example is, say, the Swedish Elkhound, we give it a higher priority by multiplying that loss by some number. From what I've seen, oversampling seems to be the preferred method, although I haven't seen any studies or papers comparing the two. It's the one I'll focus on and the one I see most in practice. But I'll actually show you the second one first, just because it's so much shorter and easier. For class weighting, all you do is pass a weight when you create your loss function, in this case cross entropy. Oh yeah, we need the imports as well.
So let's do all the imports: import torchvision.datasets as datasets, import os, from torch.utils.data import WeightedRandomSampler and DataLoader. We're also going to need import torchvision.transforms as transforms and import torch.nn as nn.

Now we can use nn.CrossEntropyLoss, and all you do here is send in a weight, a torch.tensor with one value per class. We have two classes; let's say the golden retriever is class 0 and the Swedish Elkhound is class 1. We send in the weight for the golden retriever first, so let's put that at 1, and since we have 50 times more golden retriever examples than Elkhound examples, let's give the Elkhound a weight of 50. As I explained earlier, this will multiply the loss by 50 whenever we see the image of the Swedish Elkhound. That's all you need to do for class weighting; of course, if you have more classes you would need to send in additional weights, but this is how you do it.

Let's now move on to the preferred method that I more commonly see. For that, let's write a function get_loader, which takes a root directory and a batch size. This will all make sense soon; I'm just writing some skeleton code right now. We're also going to have a main function, and under if __name__ == "__main__" we'll run it. Then in get_loader, the first thing I want to do is create my transforms: transforms.Compose, with transforms.Resize to 224 and transforms.ToTensor to convert the images to tensors.
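The class-weighting step described above could be sketched like this. It's a minimal snippet: the 1 and 50 weights follow the 50:1 class imbalance from the video, while the dummy logits are purely illustrative.

```python
import torch
import torch.nn as nn

# Class weighting: class 0 = golden retriever (50 images),
# class 1 = Swedish Elkhound (1 image), so the Elkhound loss gets 50x weight.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 50.0]))

# Sanity check with dummy logits. Note: the default 'mean' reduction divides
# by the sum of the per-sample weights, which cancels the effect for a single
# sample, so we use reduction='sum' here to make the weighting visible.
check = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 50.0]), reduction="sum")
logits = torch.tensor([[2.0, 1.0]])
loss_retriever = check(logits, torch.tensor([0]))  # true label: class 0
loss_elkhound = check(logits, torch.tensor([1]))   # true label: class 1
```

Misclassifying the rare Elkhound now costs far more than misclassifying a retriever, which is exactly the "higher priority" the narration describes.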
Okay, so for loading the data we'll use ImageFolder: dataset = datasets.ImageFolder, where root is the root directory we send in (in this case that dataset folder, but we'll get to that soon) and transform is my_transforms, the one we just created.

Now that we have a dataset, we're going to use the WeightedRandomSampler. First of all, we want to create some class weights, and for those we can send in 1 and 50, similar to the weights we used before. They don't have to be those exact numbers; it's the relative weight difference that matters, so for example we could use 1/50 and 1 and that would be equivalent. For simplicity, let's just use 1 and 50. Right now we're specifying those class weights by hand; I'll show you a way to compute them in code later, for when you have, say, over 100 classes and don't want to go through each one counting exactly how many examples it has. Then we're going to create sample weights, which start out as a list of zeros, multiplied by the length of our dataset. The way the WeightedRandomSampler works is that we need to specify the weight for each example in our entire dataset.
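The per-example weights idea can be illustrated with a tiny standalone example before touching any real images. The labels below are made up: a five-image dataset with four retrievers and one Elkhound.

```python
# Toy version of the per-example weights: a 5-image dataset with four
# golden retrievers (label 0) and one Swedish Elkhound (label 1).
labels = [0, 0, 0, 0, 1]
class_weights = [1, 50]

sample_weights = [0] * len(labels)          # one weight per example
for idx, label in enumerate(labels):
    sample_weights[idx] = class_weights[label]

print(sample_weights)  # [1, 1, 1, 1, 50]
```

Each example simply inherits the weight of its class; the sampler then draws examples with probability proportional to these per-example weights.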
So how we do that is we first create the sample weights as all zeros, meaning each example in our dataset starts with a sample weight of zero. Then we go through the dataset: for idx, (data, label) in enumerate(dataset). The first thing we do is look up the class weight for that particular class, which is why we created class_weights: we take class_weights[label] and call it class_weight. Then we set sample_weights[idx] = class_weight for that particular sample of our dataset, and that's pretty much it; now we've created those sample weights.

Then we create our sampler, a WeightedRandomSampler, where we send in the sample_weights, num_samples equal to the length of the dataset (or, equivalently, the length of sample_weights), and replacement set to True or False. When editing the video, I noticed that I didn't really explain why we set replacement=True: if we set it to False, then we'll only see each example once when we iterate through the entire dataset, which is obviously not what we want when we're oversampling. So when we're dealing with an imbalanced dataset and using oversampling, we always want replacement=True.

Now that we've created our sampler, we want to create our loader. Our loader is just a DataLoader over that dataset, the way we're all used to creating data loaders, with batch_size set to the batch size we send into this function.
What's different is that we also specify a sampler, which in this case is our WeightedRandomSampler.

That might have been a little bit quick, so let's go through it and make sure we know what's actually going on. First we create our transforms; here that's just Resize and ToTensor, though in practice you would normally add some data augmentation. Then we load our dataset using datasets.ImageFolder, sending in the root directory (in our case, that dataset folder), and ImageFolder handles the loading for us automatically. Then we specify class weights, in this case 1 and 50, because we have far fewer examples of the second class and want to prioritize it much more than the first one. Then we create the sample weights, which are the weights for each individual example in the dataset: we initialize them to zero, go through every example, and set each weight according to which class that example belongs to. Then we create our sampler, sending in those sample weights, how many samples we want, and replacement=True. Finally we create our DataLoader as normal, and the only difference is that we send in this sampler.

I'm going through this very step by step because in the beginning this didn't feel very intuitive to me, but when you get used to it, it sort of makes sense.
Now I want to generalize the class-weights bit right here, because I don't want to write out all of the class weights individually; that might take some time when you have over 100 different classes. So we'll create an empty list and do: for root, subdirectories, files in os.walk(root_dir). If you're not familiar with os.walk, it simply walks through each of the subfolders in that root directory. Then we check if len(files) is greater than zero, and if so we do class_weights.append. We're not going to append the length of those files directly, because then we would prioritize the classes that have more examples; instead we append 1 divided by the length of those files. I also added that length check because if there are no files in a subfolder, we would simply be dividing by zero. There are other ways, but this is a simple way of dealing with that problem.

Okay, so now we've created our get_loader. Let's write the main function and make sure this works. And did I make a mistake here? Oh yeah, sorry, this should be two equals signs. So now let's do: loader = get_loader, sending in the root directory of our dataset folder and a batch size of 8. Just to make sure it works, we can go through it: for data, labels in loader, print(labels). And that should be transform rather than transforms, so hopefully that works now. All right, and as you can see here.
If we had not used those class weights, we would not be seeing this very balanced output right here. It might be difficult to count all of those labels by eye, so to make sure it actually is balanced, we can do something like: for epoch in range(10), go through the dataset and count the number of retrievers and the number of Elkhounds. So num_retrievers += torch.sum(labels == 0), then we copy that and do num_elkhounds += torch.sum(labels == 1), and at the end we just print num_retrievers and num_elkhounds. Let's see what that looks like. All right, so here we can see there's some randomness to exactly what numbers come up, and I think if we rerun it we're probably going to see a different result, but as we can see, they're at least relatively balanced, much more balanced than they were in the beginning.

So that was it for dealing with imbalanced datasets. Hopefully this is useful to you. If it is, then please do subscribe to the channel, because that helps a lot. Damn, I feel like a sellout when I'm saying that, but anyways, thank you for watching the video, and I hope to see you in the next one.
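The balance check above could be sketched like this. Since a runnable example needs data, the loader here is a stand-in that yields random 50/50 labels, mimicking what the weighted sampler produces; a real run would iterate the loader returned by get_loader instead.

```python
import torch

# Stand-in loader: 20 batches of 8 labels drawn roughly 50/50, the way the
# weighted sampler would draw them (a real run would use get_loader's output).
torch.manual_seed(0)
loader = [(None, torch.randint(0, 2, (8,))) for _ in range(20)]

num_retrievers = 0
num_elkhounds = 0
for epoch in range(10):
    for data, labels in loader:
        num_retrievers += torch.sum(labels == 0).item()
        num_elkhounds += torch.sum(labels == 1).item()

print(num_retrievers, num_elkhounds)  # roughly equal counts
```

As in the video, the two counts fluctuate from run to run because the sampling is random, but they stay in the same ballpark rather than at the original 50:1 ratio.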
