How to explain Graph Neural Networks (with XAI)

Video Statistics and Information

Captions
Hello everyone! In this video I want to talk about how explainable AI (XAI) can be applied to graph neural networks. Before we look at graphs, let's see what typical explanations for machine learning models look like in other domains. Generally, there are many methods available, and the examples I show here are just a small subset. I've also uploaded a video series about explainable AI, in case you're interested in further details.

Let's say we have this tabular dataset, for which we built a model that predicts whether a person will have a stroke or not. For a new instance, we now want to explain why our model predicts stroke. A typical type of explanation is to assign importance weights to each of the features of this new data point. In this example, the fact that this person is a smoker has the highest impact. For images, we can simply explain which area of an image led to a certain prediction. Just like in the tabular example, we look for the features with the highest impact, which for image data are pixels. For this brain MRI dataset, for example, we can explain what made a model predict cancer. Using such approaches, we can shed some light on black-box machine learning models, because we understand the reasoning behind their predictions.

Now, what about graph data? When it comes to explaining graph neural networks, the situation is a bit different, because the data structure is more complex than in the previous examples. Our feature space consists of a graph in the form of an adjacency matrix, a dynamic number of nodes, and potentially node and edge features. Let's say we have a GNN that creates embeddings for each of the nodes in this graph. Using these final embeddings, we could perform all sorts of prediction tasks, such as graph-level predictions, node-level predictions, or link prediction. What could useful explanations look like in all of these cases, so that we better understand why the neural network outputs specific values? Useful explanations on graph data are, for example, which nodes
and edges were relevant, like the blue ones here; how relevant they were, for which we could assign importance scores; and how relevant the node or edge features in the graph were. This gives us a clear idea of what information was used to arrive at a certain prediction.

The next question is how we obtain these explanations. There is a survey about explainability in GNNs, published in 2021, which contains this tree giving a good overview of most of the currently available methods. Many of these were developed in the XAI research field and simply redesigned for graph-structured data, so not all of them come from the GNN literature itself. Of course, I cannot talk about all of them, so I've picked out one method that I will explain in detail later, and I'll quickly go over each of the sections here.

Our first decision is whether we want to explain the whole machine learning model or just individual predictions. It's much easier to explain individual predictions, because the decision boundary of the whole model might be very complex.

For these instance-level explanations, one research branch uses gradients to backpropagate a model output to the input space. This backpropagation allows us to determine how important the input features were for the model. To give you an example, this image shows the result of Grad-CAM, which stands for Gradient-weighted Class Activation Mapping, applied to molecule data. It tells us which atoms had the highest contribution to the classification result.

Another branch applies perturbations to the input graphs in order to analyze the reaction of the model. Doing this, we find out which nodes and edges are most important for the prediction: whenever these elements are present in the input graph, the model should stick with its current prediction, and once they are removed, this might lead to a different outcome. Now, how do you perturb a graph? This can be done, for example, by masking elements out, and that's the basis for all of the methods in this branch. They simply remove
parts of the graph and then check how the predictions of the model change. Later we will have a closer look at GNNExplainer. This example shows SubgraphX, which samples many different subgraphs, analyzes the changes in the model predictions, and combines them to assign a contribution to each node and edge. This approach is based on Shapley values.

Decomposition methods decompose the prediction into the input space. This means the output is transferred back layer by layer until the input layer is reached. The resulting values then indicate which of the inputs had the highest impact on the output. A popular method in this field is called layer-wise relevance propagation, or LRP for short, which is based on Taylor decomposition. Its graph variant, GNN-LRP, uses so-called relevant walks on the graph to identify the areas with a high contribution.

The last branch for instance-level explanations are surrogates. These are simply interpretable approximations of the GNN, so the idea of all of these approaches is to fit another, interpretable machine learning model in a local area around a prediction of the complex model. An adaptation of the popular XAI method LIME is called GraphLIME, and it explains the importance of the features of each node in the graph.

Finally, to explain the whole graph neural network, there is currently only one method in the survey, which is called XGNN. The idea is to generate graphs that lead to the highest activation in the model for a certain class, and XGNN uses reinforcement learning to achieve this. The generated graph then provides an explanation by showing what the model thinks belongs to a specific class.

So I hope this gave you a quick overview of some of the available methods. Of course, this is not an exhaustive summary, as new methods are published all the time. Also worth mentioning are attention-based models, where the attention weights can be utilized for explanations.

Let's now dive into one specific method, which is GNNExplainer. It was published in 2019 by these researchers (Ying et al.). Before we talk about the
mathematical details, I want to give you a high-level idea of how the method works. Let's assume this is our input graph for a GNN, and we want to perform node classification for this greenish node. We have labels for the other nodes, in this case simply yes or no. We also have node features, which I will leave out here for simplicity. Let's say our trained GNN predicts no for this node. How do we find out which nodes and edges contributed to this prediction?

A basic observation of GNNExplainer is that all information that went into this model output stems from the computation graph of the green node. That's especially relevant for large graphs with millions of nodes, because it's sufficient to look only at the computation graph of a node. In the case of two GNN layers, it looks like this for the green node: we first collect the information of the pink and purple nodes, and then the information of all of the nodes that they are connected to. This computation graph is denoted G_c in the paper. You can see that the gray node is actually out of scope, because it's too many hops away, and therefore it has no influence on the green node.

What we can do now is remove some of the edges and nodes from this graph and then check how the prediction for the green node changes. We could remove the blue nodes and the dark purple nodes, as illustrated here. This subset of the computation graph is denoted G_S in the paper. We see that the model still predicts no, which means these three nodes were not really relevant for the result. If, however, we remove the pink nodes, the GNN outputs a different class, which tells us that these nodes and edges were quite relevant for the outcome. That's in fact a counterfactual explanation, because it says: if we remove this node, the prediction will change. The same applies to node features: we simply remove some of the values and see how the prediction changes. This way we get a feeling for what information the model uses when we limit the
computation graph to a subset. I want to emphasize that we can use not only the predicted class as a binary value but also the output probabilities. For the green node and the full computation graph, the model might have this output distribution, predicting around eighty percent no and twenty percent yes. When we remove nodes and edges, the output probabilities change and tell us something about the certainty of the model's prediction. The ultimate goal is to find the nodes and edges that maintain the output probabilities while using only a subset of the original graph.

More formally, this can be formulated as an optimization problem using mutual information (MI). Mutual information comes from information theory and measures the information that two random variables share; a high MI value indicates a strong dependence between the two variables. Here we want to maximize this term by finding a subgraph G_S that preserves as much information as possible compared to using the full graph. On the right side we can see the entropy of the original model prediction, H(Y), and the entropy when limiting the computation graph to a subset, H(Y | G = G_S, X = X_S). More specifically, G_S is the subgraph of the computation graph and X_S is a subset of the node features. So overall, MI quantifies the change in the predicted probabilities, just like we've seen in the previous example.

Now, how do you optimize this formula? For a given graph and the trained neural network, the first term is constant, so the only interesting part is the second term: to maximize the overall formula, we need to minimize this conditional entropy. The entropy H can be expressed like this: it's the negative expectation of the logarithm of the model probabilities. Entropy measures uncertainty, so here it measures how uncertain the model Φ remains when it only uses the subgraph for its prediction. In the ideal case the uncertainty would be zero, which means no information is lost when limiting the input to the subgraph. Now
we are one step further, because we only need to minimize this term to maximize the overall formula at the top. There is one more obstacle to get out of the way, and that is the fact that there are exponentially many subgraphs G_S. Instead of trying out all of them, the authors propose to use a continuous mask on the computation graph. This leads us to the final formula. We can see that the subgraph G_S is now replaced by the term A_c ⊙ σ(M), which is simply an element-wise multiplication of the adjacency matrix of the computation graph with a learnable mask. The overall optimization problem is now reduced to learning a mask that selects the subgraph that maximizes the mutual information.

We will come back to this formula in a second, but first I want to give you a better understanding of the masking part. This is the graph from the previous example, and this is what the adjacency matrix of this graph looks like; I've put the node colors on the diagonal. This is the masking part of the previous formula. A_c is the adjacency matrix of the computation graph for a specific node. It can also be denoted like this. In the paper there's not much more detail about it, but the only logical interpretation for me is that it is the adjacency matrix restricted to the local neighborhood of the node we are interested in. So for our two-layer GNN and a classification of the green node, the gray node, for example, would not be considered, because it's not part of the computation graph; therefore all of its connections are zero.

Finally, there is the mask M, which has the same dimensions as the adjacency matrices. It is a continuous matrix whose entries determine how strongly each edge of the computation graph is kept. The sigmoid function σ is applied to squash these continuous values into the range between zero and one, so that each entry acts as a soft edge weight rather than a strictly discrete choice. When A_c and σ(M) are multiplied element-wise, only a few edges end up with a high weight after optimization, and these define the subgraph G_S. The objective is minimized using gradient descent. There are a
couple of additional considerations necessary to arrive at this formula, but I will not go into more detail about them in this video. The important part is that the mask M is adjusted in such a way that the entropy is minimized. Besides the mask for the edges, GNNExplainer also learns a mask for the node features. This one, however, also takes the interactions between the features into account, and therefore Monte Carlo sampling is used to estimate the importance.

That was a lot of theoretical information, but what do the generated explanations look like in practice? For the previous example, it might look like this: a subgraph, highlighted in green here, that contains all of the relevant nodes. Furthermore, for the node features of these nodes, we get a subset as well. This provides a pretty clear picture of why the model made a certain prediction.

Besides the things I've mentioned so far, the paper contains a couple more extensions. For example, some regularization terms are added to the objective function: one encourages the generated masks to be discrete, and another allows us to include specific constraints, for example to limit the size of the explanation. Furthermore, there are some details about how GNNExplainer can be extended to link-level and graph-level prediction tasks. Another option is to generate explanations for multiple instances, so basically for a set of predictions. Finally, it's worth mentioning that the method is agnostic to the type of GNN layer, which means it can be applied to pretty much any GNN model.

So I hope this video gave you some insights into explainability for GNNs. I will soon upload another video where I demonstrate the use of GNNExplainer on a real dataset. Finally, I can recommend having a look at this library, which implements most of the explainability methods I've mentioned in the overview. That's all for this video, thanks for watching, and see you soon in the next part!
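For the gradient-based branch, the simplest variant is plain input-gradient saliency: differentiate the model output with respect to the input node features. This is not Grad-CAM (which weights intermediate activations by their gradients); it is only a minimal sketch. The toy graph, its features, and the "trained GNN" (two rounds of sum aggregation, h ← h + A h) are hypothetical stand-ins loosely modeled on the green/pink/purple example, and because the standard library has no autodiff, the gradient is approximated with central finite differences.

```python
# Hypothetical toy graph: 0 = green target, 1 = pink, 2 = purple,
# 3/4 = blue, 5 = dark purple, 6 = gray (three hops from the target).
FEATURES = [0.0, 1.0, -0.2, 0.5, 0.5, -0.5, 3.0]
EDGES = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (5, 6)]
TARGET = 0


def target_logit(features):
    """Toy 2-layer GNN (h <- h + A h, twice); returns the target node's score."""
    h = list(features)
    for _ in range(2):
        new_h = list(h)  # self-contribution
        for i, j in EDGES:  # undirected: messages flow both ways
            new_h[i] += h[j]
            new_h[j] += h[i]
        h = new_h
    return h[TARGET]


def feature_saliency(features, eps=1e-4):
    """Central-difference estimate of d(score)/d(x_i) for every node feature x_i."""
    grads = []
    for i in range(len(features)):
        up, down = list(features), list(features)
        up[i] += eps
        down[i] -= eps
        grads.append((target_logit(up) - target_logit(down)) / (2 * eps))
    return grads


saliency = feature_saliency(FEATURES)
print([round(g, 3) for g in saliency])  # → [3.0, 2.0, 2.0, 1.0, 1.0, 1.0, 0.0]
```

Because this toy model is linear, the gradient is exactly row 0 of (I + A)², and the gray node, which lies outside the two-hop computation graph, gets a saliency of zero.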
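SubgraphX, mentioned in the overview, scores subgraphs with Shapley values. The sketch below is not SubgraphX (which searches over connected subgraphs with Monte Carlo tree search); it only illustrates the underlying Shapley idea at node level, estimated by sampling random node orderings and averaging each node's marginal contribution. The toy graph, features, and model (two rounds of sum aggregation, then a sigmoid giving P(no) for the target) are hypothetical.

```python
import math
import random

# Hypothetical toy graph: 0 = green target, 1 = pink, 2 = purple,
# 3/4 = blue, 5 = dark purple, 6 = gray (three hops from the target).
FEATURES = [0.0, 1.0, -0.2, 0.5, 0.5, -0.5, 3.0]
EDGES = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (5, 6)]
TARGET = 0
PLAYERS = [1, 2, 3, 4, 5, 6]  # every node except the target is a "player"


def value(coalition):
    """Toy stand-in for a trained 2-layer GNN; returns P(no) for the target.

    Only nodes in `coalition` (plus the target) keep their edges; all other
    nodes are masked out. Each layer computes h <- h + A h (sum aggregation).
    """
    keep = set(coalition) | {TARGET}
    kept = [(i, j) for i, j in EDGES if i in keep and j in keep]
    h = list(FEATURES)
    for _ in range(2):
        new_h = list(h)
        for i, j in kept:
            new_h[i] += h[j]
            new_h[j] += h[i]
        h = new_h
    return 1.0 / (1.0 + math.exp(-h[TARGET]))


def shapley_sampling(num_permutations=200, seed=0):
    """Monte Carlo estimate of each node's Shapley value via random orderings."""
    rng = random.Random(seed)
    totals = {p: 0.0 for p in PLAYERS}
    for _ in range(num_permutations):
        order = PLAYERS[:]
        rng.shuffle(order)
        coalition, prev = [], value([])
        for p in order:
            coalition.append(p)
            cur = value(coalition)
            totals[p] += cur - prev  # marginal contribution of node p
            prev = cur
    return {p: t / num_permutations for p, t in totals.items()}


shap = shapley_sampling()
for node, phi in sorted(shap.items(), key=lambda kv: -abs(kv[1])):
    print(node, round(phi, 4))
```

Whatever orderings are sampled, the estimates add up to value(all nodes) minus value(no nodes), which is the efficiency property of Shapley values; sampling only affects how that total is split between nodes.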
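The computation-graph argument from the GNNExplainer part can be made concrete with a breadth-first search: for an L-layer message-passing GNN, the target's output can only depend on nodes within L hops. The sketch builds A_c as the adjacency matrix restricted to that neighborhood, which follows the interpretation given in the video rather than a definition quoted from the paper; the toy graph and node numbering are hypothetical.

```python
from collections import deque

# Hypothetical toy graph: 0 = green target, 1 = pink, 2 = purple,
# 3/4 = blue, 5 = dark purple, 6 = gray (three hops from the target).
EDGES = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (5, 6)]
N = 7
ADJ = [[0] * N for _ in range(N)]
for a, b in EDGES:
    ADJ[a][b] = ADJ[b][a] = 1  # undirected graph


def k_hop_nodes(adj, target, num_layers):
    """Nodes reachable from `target` in at most `num_layers` hops (BFS)."""
    dist = {target: 0}
    queue = deque([target])
    while queue:
        node = queue.popleft()
        if dist[node] == num_layers:
            continue  # do not expand beyond the receptive field of the GNN
        for nbr, connected in enumerate(adj[node]):
            if connected and nbr not in dist:
                dist[nbr] = dist[node] + 1
                queue.append(nbr)
    return set(dist)


def computation_graph_adjacency(adj, target, num_layers):
    """A_c: keep an edge only if both endpoints lie in the k-hop neighborhood."""
    keep = k_hop_nodes(adj, target, num_layers)
    n = len(adj)
    return [[adj[i][j] if i in keep and j in keep else 0 for j in range(n)]
            for i in range(n)]


A_c = computation_graph_adjacency(ADJ, target=0, num_layers=2)
print(sorted(k_hop_nodes(ADJ, 0, 2)))  # → [0, 1, 2, 3, 4, 5] (gray node 6 excluded)
```

Keeping every edge whose endpoints both lie in the neighborhood (an induced subgraph) is a simple, slightly generous choice; PyTorch Geometric's `k_hop_subgraph` utility serves a similar purpose on real data.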
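The perturbation experiment from the walkthrough (drop the blue and dark purple nodes, then drop the pink node, and compare the predictions for the green node) can be sketched on a hypothetical toy graph. The "trained GNN" is a stand-in: two rounds of sum aggregation, h ← h + A h, with the target's final value passed through a sigmoid to give P(no). Removing a node is modeled as deleting all of its edges.

```python
import math

# Hypothetical toy graph: 0 = green target, 1 = pink, 2 = purple,
# 3/4 = blue, 5 = dark purple, 6 = gray (three hops from the target).
FEATURES = [0.0, 1.0, -0.2, 0.5, 0.5, -0.5, 3.0]
EDGES = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (5, 6)]
TARGET = 0


def predict_proba(removed=frozenset()):
    """Return (P(no), P(yes)) for the target after deleting the edges of `removed`."""
    kept = [(i, j) for i, j in EDGES if i not in removed and j not in removed]
    h = list(FEATURES)
    for _ in range(2):  # two message-passing layers
        new_h = list(h)
        for i, j in kept:
            new_h[i] += h[j]
            new_h[j] += h[i]
        h = new_h
    p_no = 1.0 / (1.0 + math.exp(-h[TARGET]))
    return p_no, 1.0 - p_no


for removed in [(), (3, 4, 5), (1,), (6,)]:
    p_no, p_yes = predict_proba(set(removed))
    label = "no" if p_no > p_yes else "yes"
    print("removed", removed, "P(no) =", round(p_no, 3), "->", label)
```

With these made-up features, dropping the blue and dark purple nodes leaves the prediction at no, dropping the pink node flips it to yes (a counterfactual), and dropping the gray node changes nothing, since it lies outside the two-hop computation graph.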
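The chain of formulas from the GNNExplainer part can be summarized as follows, using the notation of the transcript (G_c, G_S, X_S, A_c, M, Φ). The last objective is my reading of the paper's final formula, written with the predicted class y as the cross-entropy target, and is worth checking against the paper. The worked value at the end is the entropy of the 80%/20% output distribution from the example.

```latex
% Mutual information between the prediction Y and the explanation (G_S, X_S)
\max_{G_S} \; \mathrm{MI}\big(Y, (G_S, X_S)\big) = H(Y) - H\big(Y \mid G = G_S, X = X_S\big)

% Conditional entropy of the model prediction restricted to the subgraph
H\big(Y \mid G = G_S, X = X_S\big) = -\,\mathbb{E}_{Y \mid G_S, X_S}\big[\log P_\Phi\big(Y \mid G = G_S, X = X_S\big)\big]

% Continuous relaxation: G_S is replaced by A_c \odot \sigma(M); with predicted class y among C classes
\min_{M} \; -\sum_{c=1}^{C} \mathbb{1}[y = c] \, \log P_\Phi\big(Y = c \mid G = A_c \odot \sigma(M), X = X_c\big)

% Worked example: entropy of the 80% / 20% output distribution
H = -(0.8 \ln 0.8 + 0.2 \ln 0.2) \approx 0.50 \text{ nats}
```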
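The continuous-mask optimization can be sketched end to end on a toy example. Every modeling choice here is an assumption, not the paper's code or any library's API: a hypothetical linear toy GNN in place of a trained model, one mask logit per undirected edge of A_c (so σ(M) is symmetric), the class predicted on the full computation graph as the target of -log P_Φ, a size penalty plus an element-wise mask-entropy penalty with arbitrarily chosen coefficients, and finite-difference gradients instead of autodiff. The node feature mask and its Monte Carlo estimate are left out.

```python
import math

# Hypothetical computation graph A_c of the green target (gray node already
# excluded): 0 = green, 1 = pink, 2 = purple, 3/4 = blue, 5 = dark purple.
FEATURES = [0.0, 1.0, -0.2, 0.5, 0.5, -0.5]
EDGES = [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5)]  # undirected edges of A_c
TARGET = 0


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def toy_gnn_score(edge_weights):
    """Stand-in 2-layer GNN: h <- h + A_w h, twice; P(no) = sigmoid(score).

    `edge_weights` maps each undirected edge to a weight in [0, 1], i.e. the
    nonzero entries of A_c * sigmoid(M).
    """
    h = list(FEATURES)
    for _ in range(2):
        new_h = list(h)
        for (i, j), w in edge_weights.items():
            new_h[i] += w * h[j]
            new_h[j] += w * h[i]
        h = new_h
    return h[TARGET]


# Class predicted on the full computation graph: +1 means "no", -1 means "yes".
PRED_SIGN = 1.0 if toy_gnn_score({e: 1.0 for e in EDGES}) > 0 else -1.0


def binary_entropy(w, tiny=1e-12):
    w = min(max(w, tiny), 1.0 - tiny)
    return -(w * math.log(w) + (1.0 - w) * math.log(1.0 - w))


def explainer_loss(logits, size_coef=0.01, ent_coef=0.1):
    """-log P(predicted class | masked graph) + size and mask-entropy penalties."""
    mask = {e: sigmoid(m) for e, m in logits.items()}
    score = toy_gnn_score(mask)
    pred_term = math.log1p(math.exp(-PRED_SIGN * score))  # = -log P(predicted class)
    size_term = size_coef * sum(mask.values())  # prefers small explanations
    ent_term = ent_coef * sum(binary_entropy(w) for w in mask.values()) / len(mask)
    return pred_term + size_term + ent_term  # entropy term pushes masks toward 0/1


def numerical_grad(f, logits, eps=1e-5):
    """Central finite differences, standing in for autodiff."""
    grad = {}
    for e in logits:
        up, down = dict(logits), dict(logits)
        up[e] += eps
        down[e] -= eps
        grad[e] = (f(up) - f(down)) / (2 * eps)
    return grad


def explain(steps=300, lr=1.0):
    logits = {e: 0.0 for e in EDGES}  # sigmoid(0) = 0.5: every edge starts half-on
    loss_before = explainer_loss(logits)
    for _ in range(steps):
        grad = numerical_grad(explainer_loss, logits)
        logits = {e: m - lr * grad[e] for e, m in logits.items()}  # gradient descent on M
    mask = {e: sigmoid(m) for e, m in logits.items()}
    return mask, loss_before, explainer_loss(logits)


mask, loss_before, loss_after = explain()
for edge, w in sorted(mask.items(), key=lambda kv: -kv[1]):
    print(edge, round(w, 3))
```

On this toy graph the pink edge (0, 1) should end up with the largest mask value, while the edges of the purple branch, which push the score toward yes, are driven toward zero; thresholding σ(M), for example at 0.5, then yields the explanation subgraph G_S.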
Info
Channel: DeepFindr
Views: 10,638
Keywords: XAI, GNN, GNNExplainer, Graph Neural Network, Explainable AI
Id: NvDM2j8Jgvk
Length: 15min 7sec (907 seconds)
Published: Thu Oct 21 2021