Explainable AI explained! | #6 Layerwise Relevance Propagation with MRI data

Captions
Hi guys, welcome back to the last part of this Explainable AI series. Today I want to talk about layer-wise relevance propagation, a method designed to explain the relevance of an input for a deep neural network. Later in the video I'll show you how to use the method on an example dataset for brain cancer classification on medical MRI scans. MRI stands for magnetic resonance imaging and looks like the image shown here; you've probably seen it before. We not only want the neural network to tell us whether brain cancer can be found in the image, but also why the model thinks so, and for this we will use layer-wise relevance propagation.

Let's start with an intuitive explanation of what the method does. Layer-wise relevance propagation, or LRP for short, was presented in this paper and has the overall goal of explaining the relevance of the inputs for a certain prediction. Typically this is applied to images, so that we can see which parts of the image, or more precisely which pixels, made our model arrive at a certain prediction. LRP is a model-specific method, as it is mainly designed for neural networks. Generally, the method assumes that the classifier can be decomposed into several layers of computation; it can also be applied to support vector machines.

Let's have a look at this example from the paper. In the forward pass, the image goes through a convolutional neural network, that means several layers of convolution, pooling and other operations, in order to extract the features from the image. Then we pass these extracted features through a classifier, which is typically a fully connected neural network. This gives us the output, so the prediction. Here we have two classes, one for "cat" and one for "no cat". At this point we are interested in why the model predicts "cat", so the red square. Now LRP goes in reverse order over the layers we visited in the forward pass and calculates the relevance scores for each of the neurons in each of the layers. When we arrive back at the input, we can calculate the relevance for the individual pixels. Positive relevance scores indicate that the pixels were relevant for the prediction, and negative values that these pixels speak against it. This leads to the heat map we can see on the right side. In the paper, this process is also called pixel-wise decomposition. The visualization can be extremely helpful in many areas, such as medical imaging, because humans often require an explanation of why the algorithm spits out a certain prediction. We will see more of this in a few minutes.

Now that we are familiar with the basic concept, let's have a look at how we calculate these relevance scores. Let's assume we have this cat image that we pass through this CNN to get a vector of extracted features. This vector is further passed through a multi-layer perceptron to get the prediction for our input, so in this case "cat". That's how the basic architecture of our model could look. To better understand LRP, let's go over an example. The intuition is that we calculate the relevance layer by layer until we get the values for each individual pixel. We begin with the output neuron of our model and assign it the function value as its relevance score. In a classification problem we would, for instance, have 80% "cat", so we choose the corresponding neuron and 0.8 as the value. We don't care about the probabilities for the other classes and therefore set them to zero.

Let's focus on the classifier part of our network. Here we have four layers, and we can denote the relevance of the output neuron with R_1^(4): subscript 1 because it's the first neuron, superscript 4 because we are in layer 4 of the classifier part. Now we can move to the next layer. The relevance of a neuron is calculated according to this formula (the basic LRP-0 rule: R_i^(l) = Σ_j (z_ij / Σ_i' z_i'j) · R_j^(l+1), with z_ij = x_i · w_ij). It gives us the relevance R for a neuron i in layer l. So our current layer l is the third layer, and the output layer becomes l+1, the fourth layer. The calculation for neuron i works as follows: for each neuron j in layer l+1 (in this case we have only one neuron) we calculate the activation due to neuron i compared to all neurons in the layer. The activation is given by the z_ij on the right side: it's simply the input of neuron i in our current layer multiplied with the weight that goes into neuron j in the next layer. This input x comes from passing the pixel values through the previous layers. This tells us how strong the connection between these neurons is activated; intuitively, a high value means the neuron was very important for the output. So we interpret this fraction as the relative activation of a specific neuron compared to all activations in that layer. Finally, we multiply the relevance score of the neuron in the next layer with this relative number to propagate the relevance one layer backwards. Doing so, we get the new relevance scores for the neurons, as visualized here; the intensity of the colors indicates the relevance of the neurons. When we move to the next layer, we do the same procedure. Now we have several neurons in layer l+1, which is layer 3. The sum over j means that we calculate the relative activation of a neuron with respect to each of the neurons in the next layer and multiply it with their corresponding relevance scores. And again we get the relevances for each of the neurons in layer 2.
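This backward rule can be sketched for a single dense layer in a few lines of NumPy. This is a toy illustration with made-up sizes and values, not the code from the video:

```python
import numpy as np

def lrp0_dense(x, W, R_next, eps=1e-9):
    """Propagate relevance from layer l+1 back to layer l for a dense layer.

    x:      activations of layer l, shape (n_l,)
    W:      weight matrix, shape (n_l, n_next)
    R_next: relevance scores of layer l+1, shape (n_next,)
    """
    z = x[:, None] * W              # z_ij = x_i * w_ij
    denom = z.sum(axis=0) + eps     # sum over i for each neuron j
    frac = z / denom                # relative activation of neuron i w.r.t. j
    return (frac * R_next[None, :]).sum(axis=1)  # R_i = sum_j frac_ij * R_j

# Toy example: 3 neurons feeding 1 output neuron with relevance 0.8
x = np.array([1.0, 2.0, 1.0])
W = np.array([[0.5], [1.0], [0.5]])
R = lrp0_dense(x, W, np.array([0.8]))
print(R, R.sum())  # the relevances sum to 0.8 (conservation)
```

Note how the strongly activated middle neuron receives the largest share of the 0.8 relevance, while the total stays constant, which is exactly the conservation property discussed below.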
Note that the sum of the relevance values stays the same for each layer; in this case the relevances will always sum up to 0.8. Of course, this approach is not limited to dense layers; we can also apply it to convolutional layers. That's why we can propagate the relevances further through the network until we get a specific value for each pixel of the input image. This image also nicely summarizes the idea of layer-wise relevance propagation: neurons that lead to higher activations have stronger relevance flows than neurons that were not important for the prediction. Before we jump into the code, I want to point out that the formula I just showed is the basic variant of LRP, also called LRP-0. Depending on the layers, there are small extensions of this formula that make the flow of the relevances more robust, but we will see this in a second.

Okay, now let's have a look at how the method works in practice. I downloaded a brain MRI dataset from Kaggle, on which we will apply this method. All the code is pushed to GitHub; you find the link in the video description. The goal for this dataset is to not only get predictions of whether there is potential cancer tissue on the MRI scan, but also to know which regions were important for that specific prediction, such as shown here.

All right, here we are in VS Code. I'll quickly show you the dataset. I've put it into data/brain_mri; the link to it is stored in this README. You need to download the data yourself if you want to use it, because I won't push it to GitHub as it's pretty big. There is a Testing folder and a Training folder, and each of them contains four additional folders holding the images. The no-tumor folder, for instance, has JPEG images that look like this, for example also taken from the side, and in the other folders we have patients that have a tumor, as you can see here; you can clearly see it. Some of the images are duplicates; I don't know why, but I didn't look further into this. These are the images we will work with, and now let's have a look at how we use them.

First of all, I built a model using PyTorch. On my local machine I have a GPU setup, and this will print True if the GPU is available in PyTorch, to speed up the training. Then we can load the dataset from the brain MRI Testing and Training folders, and we could have a look at some of the images, but I've actually already shown them to you, so we can skip that, and I will go quickly over these sections to focus on the LRP part itself.

Then we build a model. I use a pre-trained model here, in this case VGG16, so we can quickly have a look at the architecture; if I print the model itself, it will print the summary. All I do here is adjust the final output layer, so the final linear layer of the model, to output four classes, as we have "no tumor" and the three tumor types, and we don't want the 1000 classes from ImageNet. This model is downloaded automatically, so it was downloaded just now, I think. The architecture looks like this: we have several convolutional layers, as you can see here, then this pooling layer, and then a couple of linear output layers, just like I've shown on the slides previously; it's basically the same architecture. One thing I want to point out here is that layer-wise relevance propagation can be quite tedious for more complex architectures. Here, as you can see, it's pretty simple, just a sequential stack of layers. For more complex things like ResNets, where you have splits, so different branches of layers, propagating the relevances will not be as easy as in this example, because here we can just iterate over the layers and don't have to care about branches or forks in the network.

All right, then I move this CNN model to the GPU, and then I prepare the train and test datasets again, but with some additional transformations like resizing, and I create two data loaders, one for the test data and one for the train data, with a batch size of 32; that means they will load 32 images for each batch we want to train on. By the way, the data is loaded automatically using ImageFolder: as you can see in the documentation, ImageFolder expects the different class types, so the different labels, to be in different folders, and then we can simply build a dataset using that function, which is pretty convenient for the folder structure we have. And then we can already train. As always for classification problems, we use the cross-entropy loss and a basic Adam optimizer. We actually only need to optimize the final part of the network, because it's a pre-trained network, so most of the weights should already be fine; that's why I only need to train for 10 epochs. I run that, and meanwhile, yeah, I think I'll just cut the video.

All right, now the training has finished and we have a pretty small loss. I didn't check for overfitting; it probably overfitted a little bit. To investigate further, let's have a look at the test data. We use our test loader and get the first batch, so the first 32 images, and all I do here is pass them through the model and compare the labels to the outputs, so basically calculate the accuracy for that specific batch. For the first batch we get an accuracy of 65%.
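The training and batch-accuracy pattern described above can be sketched as follows. The video fine-tunes a pre-trained VGG16 on batches of 32 MRI images; to keep this sketch self-contained, it uses a tiny stand-in model and random tensors instead, so the model, sizes and data here are placeholders, not the real setup:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))  # 4 classes
loss_fn = nn.CrossEntropyLoss()      # standard choice for classification
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

images = torch.randn(32, 3, 8, 8)    # one "batch" of 32 fake images
labels = torch.randint(0, 4, (32,))  # fake labels for the 4 classes

first_loss = None
for step in range(50):               # a handful of optimisation steps
    optimizer.zero_grad()
    loss = loss_fn(model(images), labels)
    if first_loss is None:
        first_loss = loss.item()
    loss.backward()
    optimizer.step()

# Batch accuracy, computed the same way as for the test batches above
with torch.no_grad():
    preds = model(images).argmax(dim=1)
    accuracy = (preds == labels).float().mean().item()
print(first_loss, loss.item(), accuracy)
```

With real data you would iterate over the DataLoader batches inside each epoch instead of reusing one fixed batch.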
If we run it again for another batch, we get 81%, and so on. Again, this is not super precise, but here we want to focus on the LRP implementation. By the way, the classes we see here are the numeric labels of our classes; we had four classes, and that would be the third one, so zero, one, two.

Okay, now let's have a look at the implementation. I got this implementation, or the inspiration for it, with some slight adjustments, from two sources: a tutorial from a university in Berlin and this paper, both of which describe quite nicely how to implement it in PyTorch. Let's maybe first jump to the function call. We call this function, apply_lrp_on_vgg16, and pass our model, so this function only works for VGG16, and we select a specific input image. This is a local explainability technique, because we can only explain individual images; we cannot explain the full model, but we can say that for this specific image a certain section of pixels was most important.

Now let's have a look at what happens in this function. This is the function where we apply LRP on the model; that's why we get the model and the image as input. The first thing we need to do is add the batch dimension again, because, remember, our input had a shape like that; I've printed this already here, so the image now has a shape like this, but our model expects four dimensions. That's why we need to add a one here, so that we essentially have one batch with one image. Then, in the first step, we need to extract the layers from our model. Remember, our model looked like this (ah, it's "model", sorry) and has three sections: the features section, which is basically a CNN; the average pooling, which is the connection between the CNN and the classifier; and the classifier itself. Those are the three sections, and these are the sections we extract here. We need to treat each kind of layer individually, because, for example, the propagation we apply to a convolutional layer is different from that for a linear layer or a ReLU and so on. That's why I collect here the indices of the layers that are linear in our model. If we run this (okay, it's not defined, one second), so if we run this on our model, we get those three indices that tell us at which positions in our layer list we have linear layers.

The next step is to create a placeholder for the activations. We have all the layers stored in this list now, and we know the list has a length of, let's say, 40; we can also quickly print this (okay, all of the functions are not defined right now, I think I'll run this, then let's do it again). Okay, len(L) tells us we have 39 layers in our VGG16. Now we create a placeholder for the activations in each of the layers: a list where the first entry is the image and all other entries are None. That's what happens in this line. Then we iterate over the layers and, as it says here, propagate the image through the layers and store the activations; so we go over each layer and store its activation in this list.

Maybe one more thing: you might have seen this function here, dense_to_conv. All of the linear layers are converted to convolutional layers, so that we can treat all layers the same way. How do you convert a linear layer to a convolutional layer? Let's have a look at this function. Basically, all we need to do is reshape the weights so that they have a convolution-style shape, like a kernel shape; so it's just a realignment of the weights. That's also why I had to distinguish between the different layers here: if we pass this specific layer, I need to reshape the data we pass in. And if the layer is an instance of average pooling, we flatten, because that is the transition from the feature-extraction part to the classifier part.

All right, then we have all activations inside this list. The next thing is to specify the relevance of our last layer, and to do so we need the maximum activation in the last layer. So this is the last layer, and then we say: keep the value if it equals the maximum, otherwise put a zero there. This gives us a sort of one-hot encoding of the outputs, because we said we are only interested in the explanation for a specific class, and we zero out all other activations of the output layer.

Okay, now we back-propagate the relevance scores. We already have the relevances for our final layer, coming from this float tensor here, so we have this one-hot output; you can think of it as something like [0, 0, 5.5, 0], so all zeros, and only the predicted class is non-zero, and that will be the relevance value for our output layer. For the other layers we do the same thing as for the activations previously: we set None for each of the relevances, but now in reverse order, so all of them are None except the last one. Before, in the forward pass, the first entry was the image and the others were None; now we back-propagate the relevances, so we iterate over the layers in reverse order. This expression here makes sure that we iterate in reverse; for example, if we take a range of 10 and print it, it looks like this, and reversed gives us the reverse order of its content.

What we do here now is the following: if we have a max-pooling layer, we replace it by an average-pooling layer; that is also mentioned in the paper as something you can do to propagate the relevances. Then we check: if we have a convolutional layer, an average-pooling layer or a linear layer, we calculate the relevances for that specific layer; otherwise we just pass the relevances one layer further, for example for layers that don't really have weights, so no real activations, like ReLU or dropout.

Now, the section in the middle is what calculates the relevances. It might look a bit complex at first, so I'll quickly show you the paper. In this paper, called "Layer-Wise Relevance Propagation: An Overview", there is a section called "Implementing LRP efficiently", and they even state how to do it in PyTorch. Basically, they say this is the basic formula, and they also mention the other formulas I talked about previously: the basic rule is LRP-0, then there is LRP-epsilon and LRP-gamma, and each of them is usually used for different layers. For example, the epsilon rule adds a term epsilon in the denominator, which is not present here, and the gamma rule adds an additional transformation to the weights; basically it increases the weights. As they put it, another enhancement is favoring the effect of positive contributions over negative ones, so all the positive values will be increased to get better relevances for them. Regarding the layers, there is also a section called "Which LRP rule for which layer?", where they state: upper layers LRP-0, middle layers LRP-epsilon, lower layers LRP-gamma. And that's exactly what we do in this code: if the layer index is smaller than 16, we apply the lower-layer rule, so the gamma rule; here, for the middle layers, the epsilon rule; and for the upper layers the zero rule. For this I've copied from the code which I linked here. Basically, we define a couple of lambda functions, which are passed to new_layer, a function defined here. That function uses this rho as g and takes all of the weights of that layer as argument. Let's go to the paper one more time: for the weights we apply, in the gamma rule, a transformation like the one here, and which transformation is applied depends on the rule. Here you can see that this lambda function simply returns the same value, so for the non-gamma layers we essentially apply no transformation, while for the gamma layers we apply the transformation that increases all of the positive weights by a factor of 0.25. The same happens for the epsilon rule: here we add this epsilon term, just as it's done here.

Then, let's go to the paper one last time. It says the computation of this propagation formula can be decomposed into four steps. If you look at those four steps: this section here is exactly the denominator, so we first calculate the denominator, simply by doing a forward pass on the activations. Then we calculate this fraction; if you look closely, this term is this term, and that term is what we calculated in the previous step, so it is essentially the fraction of that one over that one. This can be seen as swapping the numerator with that term here, and then we calculate that fraction; this is happening here. Regarding that part, it is happening in this calculation, and they also state that the third step, so this section here, can be expressed via gradients, and that's exactly what we do in the code. So here are the four steps: in the first one we calculate the forward pass, then we calculate the fraction, then we do the gradient computation, and in the final step we multiply to get the relevances of that specific layer. That's how the function works. Again, if you want more explanations about this, check out those two links, but that basically summarizes the idea.

All right, now we are set to use the function. We call it and pass the model and a specific input image, at a specific index, here zero. Then there is one more thing we need to do; let's have a look at the input shape. The shape is like this, so the channel is in the second position, and typically, when using libraries like matplotlib, you need to put it in the last position; that's what this permute function does. In the final step we normalize the relevances to lie between 0 and 1. If the image is classified correctly, we visualize why that prediction was made by the model.

Let's have a look at the first data point. Here we can see a prediction for this specific tumor type, and this area clearly sticks out: we have high relevances in that area, and all the other areas are less relevant. Let's have a look at another data point; here we see the same. Okay, these images are pretty clear, I would say. Let's have a look at another one. Here it says "no tumor", and in contrast to the previous two, there is no clear section that is highlighted, which is actually exactly what we want. Same here. It might be that these points were confusing for the model, which is why they are highlighted here. And here it's the same again: you see, it works roughly as we expect; we visualize the sections that are relevant.

All right, that was my little Explainable AI series. I had a lot of fun going over the different topics, and I hope you found it helpful or interesting. The code is uploaded on GitHub, and I appreciate any kind of feedback. Thank you for watching!
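As a footnote to the implementation walkthrough above, the paper's four-step backward pass (forward pass for the denominator, element-wise division, gradient computation, multiplication with the activation) can be sketched roughly like this. This is a simplified single-layer version with my own naming, not the exact code from the repository:

```python
import copy
import torch

def lrp_backward_step(layer, activation, relevance, rho=lambda w: w, eps=1e-9):
    """One LRP backward step in the paper's four-step, gradient-based form.

    layer:      a linear or convolutional layer with weights
    activation: input activation of that layer
    relevance:  relevance scores of the layer's output
    rho:        weight transformation (identity for LRP-0; for LRP-gamma
                something like lambda w: w + 0.25 * w.clamp(min=0))
    """
    A = activation.clone().detach().requires_grad_(True)
    layer = copy.deepcopy(layer)                 # don't touch the real model
    layer.weight = torch.nn.Parameter(rho(layer.weight))
    z = layer(A) + eps                 # step 1: forward pass (denominator)
    s = (relevance / z).detach()       # step 2: element-wise division
    (z * s).sum().backward()           # step 3: gradient computation
    return A * A.grad                  # step 4: multiply by the activation

# Toy check on a bias-free linear layer: total relevance is conserved
torch.manual_seed(0)
lin = torch.nn.Linear(4, 2, bias=False)
x = torch.rand(1, 4)
R_out = lin(x).detach()                # use the raw outputs as relevance
R_in = lrp_backward_step(lin, x, R_out)
print(float(R_in.sum()), float(R_out.sum()))  # approximately equal
```

Chaining this step over the layers in reverse order, with the per-depth choice of rho described in the walkthrough, is essentially what the repository's LRP loop does.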
Info
Channel: DeepFindr
Views: 2,633
Keywords: Explainable AI, Layerwise Relevance Propagation, LRP, MRI Data, Brain Cancer, XAI, Pytorch
Id: PDRewtcqmaI
Length: 28min 27sec (1707 seconds)
Published: Sat Apr 03 2021