Explainable AI explained! | #2 By-design interpretable models with Microsoft's InterpretML

Video Statistics and Information

Captions
Hi everyone! Today let's have a look at a couple of by-design interpretable machine learning models. The easiest way to ensure interpretability is simply to use interpretable algorithms. You don't always work on the most complex problems, and sometimes, for instance, a simple decision tree can be sufficient, so there's no need to over-engineer the situation. Just like Occam's razor says: the simplest solution is almost always the best. Furthermore, in this video you will also get familiar with another algorithm that achieves high accuracy while maintaining explainability. If you recall the graph we had previously, we would now be on the bottom right. Of course, we kind of need to ignore the y-axis here, as the accuracy can also be very high, or even the highest, if the complexity of our data science problem matches the capabilities of our algorithm.

More precisely, we will have a look at how to interpret three algorithms. Decision tree and logistic regression should usually be familiar to you; if not, I will also quickly summarize how these models work. Additionally, we have the Explainable Boosting Machine, or EBM, which is a model developed by Microsoft Research. All of these models are sometimes also called glass-box models, as we can directly look into the model and understand what is going on.

Before we get started, let me introduce the data set we will use for this video series. I decided to choose a health data set for which the goal is to predict whether a person is likely to get a stroke, so it's a classical binary classification problem. I pulled the data from Kaggle, and you can find the link in the video description. In the data set we have information about several properties, such as age, gender, work type, where the people live, and also their body mass index. The target variable is stroke: yes or no. According to the World Health Organization, stroke is the second leading cause of death globally, responsible for approximately 11% of total deaths. Let me introduce John, for whom we will predict whether he's likely to get a stroke. If he received our data-driven prediction, he would be able to decrease his probability of getting a stroke in the future, but his immediate response after being told the prediction would be "why?", and that's exactly where explainability is required.

Let's quickly import the data and the code and have a look at all the columns we have available. This time I'm working in VS Code and not a Jupyter or Colab notebook, but I will also push this code to a repository on GitHub. As you can see on the left, we have a data directory which holds the data in a CSV file. We have different properties like gender, age and so on, and in the final column the target, stroke; the rows are simply our data points, first all the stroke-yes rows (I think the file is ordered) and then the stroke-no rows. Next, let me quickly introduce the utils file. What I put in here is a class DataLoader that has a function load_dataset: it loads the data set from a path which links to the CSV file we just saw, calls pandas' read_csv function, and stores the data in the member variable self.data defined in the constructor, our init function. Furthermore, this utils file has a couple of additional functions, such as preprocess_data, get_data_split and so on, but we will look at those in a second.
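For reference, here is a minimal sketch of what that DataLoader might look like; the file path and method name are assumptions based on the description in the video, not the actual repository code.

import pandas as pd

class DataLoader:
    def __init__(self):
        # The loaded data set lives in this member variable
        self.data = None

    def load_dataset(self, path="data/healthcare-dataset-stroke-data.csv"):
        # Read the stroke CSV from the data directory into self.data
        self.data = pd.read_csv(path)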
Let's do a quick data exploration, which I will do in a first Python file. What I use here is VS Code cell magic, which under the hood uses a Jupyter notebook, but we don't have to worry about the notebook; it is all handled by VS Code, so we can run cells right here and create new ones using a cell prefix. What I do now is import the DataLoader from our utils file, initialize it, load the data set, and assign the data set of this DataLoader to a local variable called data. Let's run those two cells. Now we can have a look at the data. If we print the shape, we see we have about 5,000 data points, and if we print the head, we see the first five data points with columns like hypertension and heart disease (binary yes/no), married, work type, and different other things. We can also run data.info(), which gives us basic information: we have a couple of object columns, which basically refers to categorical or non-continuous columns, and furthermore things like id, which we probably won't need, and the body mass index, which is a float value.

Next we can iterate over all columns and simply print a histogram for each of them. What we can see is: id is not important, but for gender, for instance, we have more females than males; we have an age distribution; only a couple of people have hypertension, let's say 500 of our 5,000, so maybe 10%, and the same for heart disease; most of the people are married; and we have different work types like private, self-employed, government job, "job with children" I guess, never worked, and others. Most important is our target variable, and as you can see it's quite imbalanced: we have only around 300 patients with a stroke in our data set, and the rest without.

Next we can run the preprocess function. Let's quickly have a look at what it does: it gets the one-hot encodings for all of the categorical columns, which we can do with the pandas get_dummies function, appends them to the original data, and drops the old categorical columns. If we run this function, we see we get columns like gender_female, gender_male and gender_other, so instead of the single column gender we now have three columns which are zero or one for each row: this row would be a male, this row would be a female, and so on. Another thing we saw in the info output is that we have missing values for the body mass index; I simply impute them with zero. There would be more intelligent strategies, but we will do it like this for now. Finally, as I said, id is not important for our prediction, so I will drop it. Usually you should also standardize the data to enable better learning for the algorithms, but here, as we want to interpret a couple of values, we would need to convert them back, and to avoid this I will simply skip the standardization.
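A rough sketch of those preprocessing steps, assuming the column names of the Kaggle stroke data set; the function name and exact column list are assumptions.

import pandas as pd

def preprocess(data):
    # One-hot encode the categorical (object) columns and drop the originals
    categorical = ["gender", "ever_married", "work_type", "Residence_type", "smoking_status"]
    data = pd.concat([data.drop(columns=categorical), pd.get_dummies(data[categorical])], axis=1)
    # Impute missing body mass index values with zero (a crude but simple strategy)
    data["bmi"] = data["bmi"].fillna(0)
    # The id column carries no predictive information, so drop it
    return data.drop(columns=["id"])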
Now that we are familiar with the data set, let's quickly have a look at the first algorithm: logistic regression. Logistic regression is an extension of linear regression to classification problems: using the sigmoid function displayed here on the right, the output of the linear regression is squeezed between 0 and 1. That means we get probabilities, and using these probabilities we can, for a binary classification, assign the class 0 or 1 depending on the output value. Under the hood we use the same concept as for linear regression, which means we get the output as a weighted sum of the input variables, and the betas are learned during the optimization of our model, which simply means we adjust them to minimize the loss function.

Now back to explainability. For classical linear regression it is straightforward to interpret the coefficients that led to our prediction: if x1, for instance, is a continuous feature such as age, increasing x1 by one unit will make the output y increase by beta1. This way we immediately understand the impact of a change in the input features on our output y; in other words, we have full transparency on how the prediction is calculated and also know the contribution of each input to the output. For logistic regression we additionally need to consider the sigmoid transformation: after a couple of equation reshapes, not displayed in detail here, we find that if x1 increases by one unit, the odds of the output change by a factor of e^beta1.

Okay, now let's have a look at our data set and see how this works in practice. We will use a Microsoft library called interpret, which contains most of the algorithms we will cover in this XAI series. Let's have a look at the first file, interpretable_models, for this video. Here we again import our DataLoader, which we also used in the data exploration, and additionally we import the modules from the interpret library. All of those are glass-box algorithms, which simply means transparent, by-design interpretable algorithms. Furthermore we also import the show function, which generates an interactive window, and a couple of metrics from scikit-learn, such as the F1 score and the accuracy score. Just like before, we now load the data set with the DataLoader and preprocess it, and then I call the get_data_split function, which simply returns a train-test split with 20% test data. Additionally, to mitigate the impact of our class imbalance, I apply an oversampling function using a RandomOverSampler. What that means is that for our minority class we simply sample more data points to get a larger representation, so our learning algorithm sees more of those data points in the training data. As input we take the original train data, do the oversampling, and return the oversampled population; that's why our data set size increases, but note that I only do this for the train data.

Back in our interpretable_models file, let's run the first two cells, the imports and the data preprocessing. You see that previously our train data had around 4,000 data points, and now we have almost doubled it. Next we will use the LogisticRegression model that comes from the interpret.glassbox package. Under the hood this simply uses the scikit-learn model, but additionally we can call extra functions such as explain_local and so on; we will see this in a second. First let's train it by calling the fit function, passing the train data and the train labels. Training finished, and now we can perform predictions and evaluate them using the F1 score, which is important because we have an imbalanced data set: just looking at accuracy would not be enough information, as we could, for instance, always predict "no stroke" and would already get a pretty high accuracy.
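Putting the pieces together, here is a condensed sketch of the training and evaluation flow just described. The DataLoader helper methods (get_data_split, oversample) are assumptions based on the description, while LogisticRegression really does come from interpret.glassbox.

from interpret.glassbox import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

data_loader = DataLoader()
data_loader.load_dataset()
X_train, X_test, y_train, y_test = data_loader.get_data_split()  # 20% held out as test data
X_train, y_train = data_loader.oversample(X_train, y_train)      # RandomOverSampler, train data only

lr = LogisticRegression(max_iter=1000)  # wraps the scikit-learn estimator
lr.fit(X_train, y_train)
y_pred = lr.predict(X_test)
print("F1:", f1_score(y_test, y_pred), "Accuracy:", accuracy_score(y_test, y_pred))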
The F1 score combines precision and recall (it is their harmonic mean) and tells us more about how well the minority class is predicted. We get an accuracy of 73% and an F1 score of 0.51. Now we can call the logistic regression's explain_local function, which we do for the first 100 data points. What this does is explain individual predictions: in the interactive summary (let me enlarge this) we see the different data points, and if we click on one, we get explanations which refer to the coefficients of the logistic regression. Here, for instance, age has a strongly positive impact, while things like the person being married (which is 1 in this case) or the fact that this person never smoked reduce the stroke risk. But we can also interpret the model globally, not for single predictions but rather for global features: this way we can see the overall importance of the coefficients across all predictions. As we've seen, the interpretation of the coefficients is straightforward, and therefore our test patient John approves this by-design interpretable model. For completeness I want to mention that in case of correlations between the inputs, the interpretation of the weights might sometimes not be reasonable; if you're interested in more details, read the book I've linked in the video description.

Next on our list we have the decision tree, which is another commonly used algorithm. Important: that doesn't include ensemble methods like random forest, but just the plain, original decision tree. The big advantage is that these trees are able to model non-linear relationships. The way it works is that we grow a tree, such as the one displayed here, by applying feature splits that are able to separate the data; on the leaves of the tree we obtain the predictions, such as stroke or no stroke. There are different ways to grow a tree, and I'll just quickly explain the ID3 tree generation algorithm here. The way the feature splits are selected is based on a concept from information theory called entropy. A feature split that is able to separate all of our data into the two classes, stroke and no stroke, has an entropy of zero and therefore the highest information gain. On the other hand, if we are not able to correctly separate all of our data into the two buckets, we have some impurity for that split, which is expressed by the entropy. So we can calculate the information gain for each candidate feature split and simply select the best one.

Interpreting a decision tree works by starting at the root node and then simply following the edges to the leaves of the tree. In the example on the right, for John we would check his age (he's older than 50) and then his body mass index (which is 23), and we end up with a prediction of no stroke. That means a decision tree can be interpreted as a set of if-else rules, which are human-friendly explanations. Furthermore, we are also able to read the feature importance from the grown tree, as we know which features led to better splits than others. Finally, a decision tree allows us to transparently follow the reasoning process, and we can for instance say: if the body mass index of John were higher, he would be more likely to get a stroke. Now let's quickly implement this for our data set. Back in VS Code we use the ClassificationTree, which again comes from interpret.glassbox (it's the second import here), and we fit it using the fit function.
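The explanation calls for the tree follow the same pattern as for the logistic regression above; a short sketch, reusing the train/test variables from the previous snippet:

from interpret import show
from interpret.glassbox import ClassificationTree

# Fit the plain decision tree from interpret's glassbox module
tree = ClassificationTree()
tree.fit(X_train, y_train)

# Local explanations: per-sample view of the first 100 test predictions,
# including the decision path taken through the tree
show(tree.explain_local(X_test[:100], y_test[:100], name="Classification Tree"))

# Global explanation: the overall feature importances of the grown tree
show(tree.explain_global(name="Classification Tree"))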
Again we predict on our test data, and that was pretty quick: we get an accuracy of 73% and an F1 score of 0.52. If we call the explain_local function from interpret again, we can even get a graph which shows us, with a red line, the path our prediction followed. Here, for instance, we have a split on age smaller than 56.5, and this subject's age was higher than that, so we go down the red path. Another important thing for decision trees is the depth: here we have one, two, three decisions to make until we end up with a prediction. If the tree had, for instance, 500 decisions to make, it would get pretty complex, so decision trees are only interpretable if the depth is reasonable. After our prediction, John said: okay, those rule sets totally make sense to me, and that's why he approves this model as well.

Now let's talk about the last by-design interpretable model on our list. Explainable Boosting Machines were developed by Microsoft Research, and their goal is to mitigate the trade-off between interpretability and accuracy; if we think about the graph we've seen before, EBMs would be in the top right. They build several decision trees, where each of them is only allowed to use one specific feature, and they do this for many iterations, such as 10,000. Then they summarize all the results of the fitted trees for each feature. What they end up with is a graph like the one shown here, which shows the impact of a feature on the target variable, some sort of dependence plot. Such visual explanations are easy for humans to understand, and we can directly see why certain predictions are made. If you are interested in further details, there's also a video on EBMs uploaded by Microsoft on YouTube; I will link it in the video description.

Okay, now let's apply the EBM to our data set. Again we initialize the EBM model and train it just like the other models using our train data. It takes a little longer because, as I said, it needs to fit several thousand decision trees. The result is an F1 score of 0.55, so we got better here, and an accuracy of 87%, which is much better than the previous models. So we got better performance, but as we will see now, we maintain the interpretability. If I run the explain_local cell again, we can have a look at the explainability of this model for individual predictions. For this one, we see the impact of specific values: for example, here the subject had an age of 8, and that's why the prediction was rather 0 than 1. Here we see an age that is relatively high, which would point towards a stroke, but other things, like the relatively normal body mass index, shift the prediction to the left. We get that kind of explanation for each individual prediction. And again we can also do this globally: if we run this cell, we get the feature importance plot, and it says the most important feature by far is the age, then the body mass index and the glucose level, and after that combinations of specific features.
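A minimal sketch of that EBM usage; ExplainableBoostingClassifier is interpret's actual class name, and the rest carries over from the earlier snippets.

from interpret import show
from interpret.glassbox import ExplainableBoostingClassifier

ebm = ExplainableBoostingClassifier()
ebm.fit(X_train, y_train)   # noticeably slower: it fits many small per-feature trees
y_pred = ebm.predict(X_test)

# Local explanations for single predictions, plus the global view with
# feature importances and the dependence-style shape plots
show(ebm.explain_local(X_test[:100], y_test[:100], name="EBM"))
show(ebm.explain_global(name="EBM"))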
One last thing we can do here is investigate individual features. Let's quickly look at the most important one, age. We can see that the impact of age at this point is about minus five (the second term in the tuple), and the higher the age gets (here we have 80), the more likely it is that we get a positive impact, which means we shift more towards stroke instead of no stroke; for example, here we have a value of around plus three. So you see we have a positive trend here, and that's the plot I also showed previously: we can investigate the impact, basically the dependence of a feature on the target variable. For EBMs we could show on our data set that the accuracy and the F1 score get better while the explainability remains, which is pretty cool, and therefore John of course also approves this method.

That's it for today. As you've seen, in some cases we can get decent results with by-design interpretable models; however, sometimes neural networks or random forests work even better on the data, and that's why we focus on those black-box algorithms in the next videos. I hope you enjoy the series so far, and I'll see you in the next video!
Info
Channel: DeepFindr
Views: 2,526
Rating: 4.951807 out of 5
Keywords: Explainable AI, XAI, Microsoft Interpret, Glassbox, Explainable Boosting Machine, Explainable Artificial Intelligence
Id: qPn9m30ojfc
Length: 20min 54sec (1254 seconds)
Published: Tue Feb 16 2021