GNN Project #3.1 - Graph-level predictions

Captions
Hi everyone, I'm back with an update on the GNN project I'm currently working on. The next point on the agenda was to build a simple graph neural network that can classify whether a molecule is an HIV inhibitor or not, so this is a classical graph classification task. I haven't uploaded for a while, as I didn't have much time recently and getting a working model also took me some time, so here's an update on the current progress.

Before we jump into the code changes, let's quickly talk about what the architecture of the graph neural network looks like. The input to our GNN is a molecule with the node and edge features that I calculated in the last video. In the GNN we apply a couple of layers, such as graph attention layers, and the outputs are the updated node embeddings, which hold information about their neighborhood or even the whole graph. In our case these embeddings are task-specific, meaning that the automatic feature extraction by the GNN works best for the HIV classification task. It would also be possible to learn task-independent embeddings that can be used for any molecule prediction task.

GNNs are perfectly suited for node-level predictions or link prediction, as we can work directly with the per-node outputs of the model. In our case, however, we want to learn a representation of the whole molecule, which means one single embedding vector instead of one per node. The question is: how do we get this embedding? There are different approaches, and I quickly want to talk about three of them.

Approach number one is naive pooling on the set of node embeddings. This set can have a different number of elements for each molecule, so we cannot, for instance, use it as input for a fully connected neural network. Instead, we need simple set aggregation functions such as mean, max or sum, applied over the individual node embeddings; the result gives us a representation of the whole graph. We can also combine different
aggregation functions: for instance, we calculate the mean, max and sum and then simply concatenate them, as shown here. This is a very straightforward approach, but first, we lose a lot of information by using these simple aggregation functions, and second, this pooling only takes place in the final layer of the network. It would be much nicer if we could aggregate the information progressively, as is done in the pooling layers of a convolutional neural network.

This brings us to the next option: similar to pooling on images, we can also apply pooling on graphs. This means we iteratively reduce the graph by removing nodes. The idea is that, after message passing, the feature information of a specific node is already distributed among other nodes, so we don't need all of them anymore. By doing this we can reduce the graph until we end up with one single node embedding, which at the same time is our graph embedding. Furthermore, and this is something I also did in the code, we can use the intermediate node embeddings and include them in our final representation. In this example, before the first pooling operation we have five nodes and we aggregate the information of these node embeddings somehow; after the next pooling operation we only have two nodes and we aggregate that information again; then we simply concatenate everything into one big representation.

Now you might ask: how do we remove the nodes, that is, how do we reduce the size of the graph? There are different possibilities in the literature, which are also implemented in PyTorch Geometric. The DiffPool layer, for instance, learns how to cluster the nodes of a graph, which makes it a learnable pooling operator. These clusters are then aggregated, so the output of that layer is basically a new, smaller graph. Another example is top-k pooling, which simply drops specific nodes to reduce the size of the graph. The choice of which nodes to drop depends on a learnable vector that
squeezes each node embedding into one dimension, i.e. a single score; based on these scores, the top-k nodes are selected and used to build a new graph. To keep this video short I won't go into further detail on these approaches, but I can talk about them in another video.

The last approach I want to quickly mention is adding a super node. This means we simply extend the graph with an additional node that is connected to all other nodes in the graph. When we perform message passing, every node is allowed to share its information with the super node, but the super node doesn't pass information back, which means these edges are directed. After several layers of message passing we can simply use this node as the graph representation, as it has collected information from all other nodes.

Okay, so these were a couple of approaches to obtain a graph representation with graph neural networks. Now let's have a look at the code changes. I decided to split this section into two parts: first the dataset and what changed there, and then the GNN and the training.

Let's start with the dataset changes. I discovered a library called DeepChem, which provides several featurizers that make it easy to build a dataset for PyTorch Geometric. I used a featurizer called MolGraphConvFeaturizer, which uses RDKit in the background to do exactly what I did in the last video: it provides several node and edge features for a molecule with only a few lines of code. To simplify the dataset, I've created a new dataset featurizer file that uses this featurizer; with the function to_pyg_graph we can easily generate a PyTorch Geometric Data object. So if you also work with molecule data, I highly recommend taking a look at the DeepChem featurizers, which make it much easier to get started.

Another thing I improved is that the data is now cached and reloaded whenever this dataset class is instantiated. This way I don't have to recreate the dataset over and
over again. In the code this takes place in the processed_file_names function: whenever the PyTorch Geometric Dataset class looks for the processed file names, I return the list of my processed files, and if they are found on my local machine, in this folder, the data is automatically loaded from there. If one of them is missing, the whole dataset is automatically recreated.

Next, I manually divided the dataset into a train and a test set of roughly 20,000 samples each, so a 50/50 split with the same class distribution. I called these files HIV train and HIV test, and I will also push them to GitHub. As you might know, the dataset is quite imbalanced, with 40,000 samples of the negative class and only 1,000 of the positive class. Therefore I additionally applied oversampling to the train data, which simply means that I replicated the minority samples to get an even class distribution. The oversampled train data is stored in the file HIV train oversampled, and you can see that we now have many more samples of the positive class. The oversampling itself is done in the oversample data script, which is not too important because I push the data anyway; all it does is replicate the positive samples and then shuffle everything.

That's all regarding the dataset, so let's move on to the GNN model, and more precisely the model architecture. The model is defined in the model file: I use three graph attention convolution layers and three top-k pooling layers. As the molecule graphs are not too big, I thought three layers should be sufficient to start with. Additionally, I use three attention heads, and a fully connected network transforms their output back to the initial node feature size: three heads produce three times the node embedding size, and I need to convert that back to the original embedding size to pass it to the next layer. So in PyTorch Geometric, three attention heads generate three
different output vectors, which are concatenated.

This is the forward function. Here the data, x, which holds the node embeddings, is passed through the first convolutional layer with three attention heads. Then I transform it back to the initial shape and apply the pooling, which essentially gives a new graph, like I've shown before. For this graph I apply global max pooling and global average pooling and store the intermediate result. Then I apply the second and the third block, and eventually I concatenate the three pooled vectors of the intermediate graphs and pass this final graph representation through two linear layers, up to an output linear layer with two output classes. As you can see, it's a pretty basic architecture. I was also playing around with other layer types, such as graph isomorphism layers or transformer layers, but they seem to be more difficult to train.

As you might have already noticed, one thing that's still missing at the moment is the edge features: the current implementation of graph attention layers in PyTorch Geometric doesn't support including edge attributes. That's why I will most likely keep working on this architecture over the next weeks, so that I can also include the edge attributes.

Next, let's have a look at the training and optimization of the network. The training can be started by running the train file. Here I simply load the two datasets, the train dataset and the test dataset, and I also load the model from the other file. Let's quickly run this. As you can see, I print the architecture of the model and also count the number of parameters. The architecture looks just like the one I've just shown you, and we have 17 million parameters, which is quite big. As loss function I use the weighted cross-entropy loss, because I want to put more emphasis on the positive class. If I don't use the weight parameter, the model mainly predicts zero, so no HIV inhibitor; on the other hand,
if the weight for the positive class is too high, it will only predict class 1. My optimizer is classical stochastic gradient descent with a learning rate of 0.1, which is exponentially decayed over the epochs to smooth the training. I also tried the Adam and RMSprop optimizers, but they trained much worse than the current one. Then I perform classical mini-batch training with 256 graphs per batch. I use the PyTorch Geometric data loaders to iterate over the train and test datasets, with one separate loader for each dataset, and I also shuffle the data. I have a train and a test function that pass the data through the model and calculate the loss; the train function also backpropagates it. That's pretty much it.

Another important point is the metrics. At the moment I mainly look at the area under the (ROC) curve as well as precision and recall, because the dataset is imbalanced and we want the model to get the HIV inhibitors right. If our model only predicted "no inhibitor" on the test set, we would already get more than 95% accuracy, but a recall of zero (and precision would be undefined, since there are no positive predictions).

Another thing I want to mention here is MLflow. As you might have already seen, I use this library to log all the metrics and the model. This not only allows me to track the training, but I can also use it later to deploy the model, as it provides a REST API endpoint that I can use to get predictions from other applications. Logging works by simply wrapping everything in start_run and then calling log_metric or log_model, and everything is logged to a local database. By running mlflow ui in the console, a server is launched that I can use to look at the logs. For example, here I can see the previous training runs; I aborted most of them, as they were mainly for testing. If I click on this one, for instance, I've logged the ROC AUC and the test and train loss, and here under
the artifacts, I will also log the model once I run the last cell, where it says log_model. We can also smooth the curve a little bit, and we see that the loss is more or less decreasing. I can also look at other metrics; for example, the ROC AUC is increasing. In general, it's a simple tool to log everything.

So now I'd say we run everything: let's go down here and run all cells above. Again, I load the dataset, load the model and define the optimizer, and then in this loop I simply iterate over the train and test sets; every fifth epoch I evaluate the current test performance. This training will take some time, so I'll cut the video here.

All right, now let's have a look at the current results of this model. I'm back in MLflow, where I tracked the training, and I ran one training for 70 epochs. As you can see, the train loss, which is the orange curve, decreases, and the test loss decreases as well, which is good. However, at this specific point the model starts to slightly overfit. I will run a training for a couple of hundred epochs to see if this is really true, as this run already took a couple of hours. If we now look at the metrics, we can see that the model is certainly improving: at the moment we get an area under the curve of 0.86 on the train data and 0.74 on the test data. Of course, I'm not satisfied with this result yet, but I have to say I'm already happy that the training works, as it really took some time to get everything set up.

Now let's look at the confusion matrix. I plot all of these metrics after each epoch for the train and test data, and this is the confusion matrix for the test data. We have approximately 20,000 molecules that are not HIV inhibitors and around 600 or 700 that are. We would have a perfect model if all values were on the diagonal of this matrix. You can see that some of the non-inhibitors, around 300, are classified as inhibitors, and we also
see that some of the inhibitors are misclassified as non-inhibitors. We would like to see more of these values on the diagonal. I'd say misclassifying a couple of them is not too bad, because I assume we will never get a perfect model here, but the goal is to get a higher precision, currently approximately 0.4, and a higher recall, currently also about 0.4.

All right, finally, what are my next steps? As I said, I want the model to reach an area under the curve of at least 0.8. I also want to mention that I already tried out different loss functions, such as focal loss, which are better suited for imbalanced datasets. My biggest hope is that incorporating the edge features will improve the model further, so let's see. Of course, if you have more ideas, I'm happy to try them out in the current model, so just leave a comment if you have any idea how I could improve the current results. With that, I look forward to the next video, where I hopefully can finish this part of the series and then continue with building the generative model. Thanks for watching and have a nice day!
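The first approach from the video, naive pooling with set aggregation functions (mean, max and sum, concatenated), can be sketched in plain Python as below. This is a minimal illustration on lists of floats, not the project's code, and the function name naive_readout is hypothetical. In PyTorch Geometric the batched equivalents are global_mean_pool, global_max_pool and global_add_pool.

```python
from typing import List

Vector = List[float]


def naive_readout(node_embeddings: List[Vector]) -> Vector:
    """Hypothetical helper: concatenate mean, max and sum over a set of node embeddings.

    The output length is 3 * d for any number of nodes, so graphs of different
    sizes map to vectors that a fully connected network can consume.
    """
    if not node_embeddings:
        raise ValueError("graph has no nodes")
    n = len(node_embeddings)
    d = len(node_embeddings[0])
    # Each aggregate runs over all nodes per feature dimension (order does not matter).
    sums = [sum(node[j] for node in node_embeddings) for j in range(d)]
    means = [s / n for s in sums]
    maxes = [max(node[j] for node in node_embeddings) for j in range(d)]
    return means + maxes + sums


small = [[1.0, 2.0], [3.0, 0.0]]  # 2 nodes, d = 2
large = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]  # 4 nodes, d = 2
print(naive_readout(small))  # [2.0, 1.0, 3.0, 2.0, 4.0, 2.0]
print(len(naive_readout(large)))  # 6
```

Because the output length does not depend on the number of nodes, the same readout can be applied to the intermediate graph after each pooling block and the per-block vectors concatenated, which is the hierarchical variant described in the video.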
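Top-k pooling, as described in the video, scores every node with a learnable vector and keeps only the highest-scoring nodes. The sketch below roughly follows the Graph U-Nets formulation (score = x . p / ||p||, kept features gated by tanh(score)); the projection vector p is passed in instead of being learned, and topk_pool is a hypothetical name. PyTorch Geometric's TopKPooling layer is the trainable, batched counterpart.

```python
import math
from typing import List, Tuple

Vector = List[float]
Edge = Tuple[int, int]


def topk_pool(x: List[Vector], edges: List[Edge], p: Vector, ratio: float = 0.5):
    """Hypothetical helper: keep the ceil(ratio * N) highest-scoring nodes.

    p stands in for the learnable projection vector; here it is simply given.
    Returns (new node features, relabelled edges, kept original indices).
    """
    norm = math.sqrt(sum(v * v for v in p))
    # Project every node embedding onto p: one scalar score per node.
    scores = [sum(a * b for a, b in zip(node, p)) / norm for node in x]
    k = math.ceil(ratio * len(x))
    ranked = sorted(range(len(x)), key=lambda i: scores[i], reverse=True)
    keep = sorted(ranked[:k])  # kept node indices, in original order
    # Gate kept features with tanh(score); during training this lets gradients reach p.
    new_x = [[v * math.tanh(scores[i]) for v in x[i]] for i in keep]
    # Keep only edges between surviving nodes and relabel them to 0..k-1.
    remap = {old: new for new, old in enumerate(keep)}
    new_edges = [(remap[u], remap[v]) for u, v in edges if u in remap and v in remap]
    return new_x, new_edges, keep


x = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 0.5]]
edges = [(0, 1), (1, 2), (0, 2), (2, 3)]
new_x, new_edges, keep = topk_pool(x, edges, p=[1.0, 0.0], ratio=0.5)
print(keep, new_edges)  # [0, 2] [(0, 1)]
```

Applying such a step after each attention layer is what shrinks the graph block by block.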
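The super-node approach can be illustrated on a plain edge list. In this sketch, edges are assumed to be (source, target) pairs, the super node only receives messages, and a single toy mean-aggregation step stands in for a real GNN layer; add_super_node and mean_message_passing are hypothetical names.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]  # (source, target): messages flow from source to target


def add_super_node(num_nodes: int, edges: List[Edge]) -> Tuple[int, List[Edge]]:
    """Hypothetical helper: append a super node with index num_nodes."""
    super_idx = num_nodes
    # Directed edges only towards the super node: it receives but never sends.
    return super_idx, edges + [(i, super_idx) for i in range(num_nodes)]


def mean_message_passing(x: List[List[float]], edges: List[Edge]) -> List[List[float]]:
    """Toy aggregation step (stand-in for a GNN layer): each node takes the mean of
    its incoming messages; nodes without incoming edges keep their features."""
    incoming: Dict[int, List[List[float]]] = {i: [] for i in range(len(x))}
    for src, dst in edges:
        incoming[dst].append(x[src])
    out = []
    for i, feats in enumerate(x):
        msgs = incoming[i]
        out.append([sum(col) / len(msgs) for col in zip(*msgs)] if msgs else list(feats))
    return out


# Path graph 0-1-2 stored with both edge directions.
edges = [(0, 1), (1, 0), (1, 2), (2, 1)]
super_idx, edges_sn = add_super_node(3, edges)
x = [[1.0], [2.0], [6.0], [0.0]]  # last row: initial super-node features
out = mean_message_passing(x, edges_sn)
print(out[super_idx])  # [3.0]
```

After one step the super node already holds the mean of all node features; with more layers it would collect information that the regular nodes have propagated from further away.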
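The oversampling step (replicate the positive samples, then shuffle) can be sketched as follows. This is not the project's oversample data script: the (SMILES, label) tuple layout, the function name and the fixed seed are assumptions for illustration.

```python
import random
from typing import List, Tuple

Sample = Tuple[str, int]  # assumed layout: (SMILES string, label)


def oversample(samples: List[Sample], seed: int = 0) -> List[Sample]:
    """Hypothetical helper: replicate minority-class samples until both classes
    have the same count, then shuffle."""
    pos = [s for s in samples if s[1] == 1]
    neg = [s for s in samples if s[1] == 0]
    minority, majority = (pos, neg) if len(pos) < len(neg) else (neg, pos)
    if not minority:
        raise ValueError("minority class is empty")
    # Repeat the whole minority set, then top up with a few extra copies to match exactly.
    reps, rest = divmod(len(majority), len(minority))
    balanced = majority + minority * reps + minority[:rest]
    random.Random(seed).shuffle(balanced)
    return balanced


data = [("C" * (i + 1), 0) for i in range(10)] + [("N", 1), ("O", 1), ("S", 1)]
balanced = oversample(data)
print(sum(label for _, label in balanced), len(balanced))  # 10 20
```

Oversampling is only applied to the train split; the test split keeps its natural class distribution so that the metrics reflect the real imbalance.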
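The weighted cross-entropy loss makes errors on the positive class count more. The sketch below mirrors, in plain Python, how torch.nn.CrossEntropyLoss(weight=...) averages under the default reduction='mean': the weighted sum of per-sample losses is divided by the sum of the weights of the target classes. The class weights used in the video are not stated, so [1.0, 10.0] below is hypothetical.

```python
import math
from typing import List, Sequence


def weighted_cross_entropy(logits: List[Sequence[float]], targets: List[int],
                           class_weights: Sequence[float]) -> float:
    """Class-weighted cross-entropy with weighted-mean reduction."""
    total, weight_sum = 0.0, 0.0
    for row, y in zip(logits, targets):
        m = max(row)
        # Stable log-sum-exp; the per-sample loss is -log softmax(row)[y].
        log_norm = m + math.log(sum(math.exp(z - m) for z in row))
        w = class_weights[y]
        total += w * (log_norm - row[y])
        weight_sum += w
    return total / weight_sum


# Hypothetical logits of a model that leans towards class 0 for both samples.
logits = [[2.0, 0.0], [2.0, 0.0]]
targets = [0, 1]
plain = weighted_cross_entropy(logits, targets, [1.0, 1.0])
weighted = weighted_cross_entropy(logits, targets, [1.0, 10.0])
print(plain < weighted)  # True: the missed positive now dominates the loss
```

This is why a too-small positive weight lets the model get away with predicting zero, while a too-large one pushes it towards predicting class 1 everywhere.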
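The point about accuracy being misleading on this imbalanced data can be checked from confusion-matrix counts. The counts below are hypothetical, chosen to resemble the test set described in the video (about 20,000 non-inhibitors and 600 to 700 inhibitors); binary_metrics is a hypothetical helper.

```python
from typing import Dict, Optional


def binary_metrics(tp: int, fp: int, fn: int, tn: int) -> Dict[str, Optional[float]]:
    """Accuracy, precision and recall from confusion-matrix counts.

    Precision is None when there are no positive predictions (tp + fp == 0),
    and recall is None when there are no positive samples (tp + fn == 0).
    """
    total = tp + fp + fn + tn
    return {
        "accuracy": (tp + tn) / total,
        "precision": tp / (tp + fp) if tp + fp else None,
        "recall": tp / (tp + fn) if tp + fn else None,
    }


# A model that always predicts "no inhibitor" on 20,000 negatives and 650 positives.
always_negative = binary_metrics(tp=0, fp=0, fn=650, tn=20_000)
print(always_negative["precision"], always_negative["recall"])  # None 0.0
```

Accuracy for this trivial model is above 0.96, which is why precision, recall and the ROC AUC are the more informative metrics here.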
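The area under the ROC curve tracked in MLflow can be read as the probability that a randomly chosen inhibitor receives a higher score than a randomly chosen non-inhibitor, with ties counted as one half. A minimal pairwise sketch (quadratic time, for illustration only; roc_auc is a hypothetical name); in practice sklearn.metrics.roc_auc_score computes the same value efficiently.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as P(score of random positive > score of random negative), ties = 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

An AUC of 0.5 corresponds to random ranking and 1.0 to perfect separation, which puts the reported 0.74 on the test data and the 0.8 target in context.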
Info
Channel: DeepFindr
Views: 1,327
Rating: 5 out of 5
Keywords: GNN, Graph-level predictions, MLFlow, Pytorch Geometric
Id: CJ-KdKCeiYg
Length: 18min 34sec (1114 seconds)
Published: Tue Jun 08 2021