Graph Embeddings

Captions
- So what I'm going to be talking about in this session is graph embeddings. Some of you might have joined my earlier talk about machine learning with graphs; this is a much more technical deep dive into one aspect of that, which is graph embeddings: a way of representing a graph so you can use it in a machine learning or deep learning model, or to calculate similarity.

I'm Alicia Frame, the lead data scientist at Neo4j. I work on the product team, and my role is really to help build out our graph algorithms library and our future data science features and roadmap. I also work with our early adopters to help them get up and running in production with graph algorithms and actually doing data science with Neo4j. I'm super excited to share a topic that's really close to my heart, which is graph embeddings.

In this talk I'm going to start with what an embedding is. I think this word is often used without people actually knowing what it is, what it represents, and what it can be used for. The easiest way to understand embeddings is to start from motivating examples, so I'm going to begin with word embeddings, because we all speak, read, and use words, and it's easier to wrap your head around an embedding when you talk about it with something familiar. Once we have that understanding of what a word embedding is, I'll build on it to talk through a simple graph embedding. From that foundation, I'll give a quick overview of graph embeddings in general: what they are, what you use them for, and different techniques to calculate them. And then I'll finish with graph embeddings in Neo4j. This is much more a state-of-the-science, forward-looking talk: where are we going next, rather than let me sell you something that's already built.

I like my talks to have this one slide where, if you're not going to listen to anything else, this slide is the one that matters. What on earth is an embedding? This is something I honestly struggled with when I first started in this space. If you Google it you get some very confusing and opaque answers. Google says it's a low-dimensional space into which you can translate high-dimensional vectors. That doesn't help me. Wikipedia is worse: it's an instance of some mathematical structure contained in another instance, such as a group. Huh?

The way I think of it is that an embedding is just a way of mapping something complicated (a document, an image, a graph) into something simple: a fixed-length vector, so a bunch of numbers, or a matrix, that captures the key features of the complicated thing while reducing its dimensionality. So think about a word embedding: a word comes out of a book. I have a big thick book, but my word embedding is just, say, 20 numbers. I'm taking something complicated and making it into something simple.

Graph embeddings are a specific type of embedding where you're trying to translate a graph, or part of a graph, into a fixed-length vector. This figure is from the DeepWalk paper. You have your graph, and you could represent it as an adjacency matrix, where you've one-hot encoded which node is attached to which other node. It's a large matrix, and it's going to be very sparse.
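To see the size difference concretely, here is a minimal toy sketch (my own illustration, with a made-up six-node graph, not something from the talk) contrasting the sparse adjacency representation with a short fixed-length vector per node:

```python
import numpy as np

# Toy graph: 6 nodes and a handful of edges (made up for illustration).
edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (4, 5)]
n = 6

# Adjacency ("one-hot") representation: an n x n matrix, mostly zeros.
A = np.zeros((n, n), dtype=int)
for u, v in edges:
    A[u, v] = A[v, u] = 1
print(A.shape)   # (6, 6): grows quadratically with the number of nodes

# An embedding replaces each row with a short dense vector, say 3 numbers
# per node instead of n. (Random values here, standing in for learned ones.)
d = 3
Z = np.random.rand(n, d)
print(Z.shape)   # (6, 3): fixed length no matter how big the graph gets
```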
So your embedding translates that into a much smaller, lower-dimensional matrix that you can then use for some kind of task. What we're doing is learning the important features of your graph and distilling that giant adjacency matrix into something easier to work with.

You may be thinking, well, why bother? You just said we're going to do a lot of math to take something complicated and make it into something simpler. A question that came up in my other talk was: why would I take tabular data, put it into a graph, and then try to make it into a vector? It's this idea that you want to translate something complex into something simple that you can use. The embedding captures the important features of your input object, for the task you want to use it for, in a compact way. That representation can then stand in for your graph in some kind of machine learning or deep learning approach. Generally speaking (not always, but generally), embeddings learn what's important in an unsupervised, generalizable way. So when we talk about how you train a word embedding or how you train DeepWalk, you can apply that technique to any context window for the word, or any graph for DeepWalk, and you end up with a custom representation for your use case that's still generalizable.

On to motivating examples. I feel like this is still maybe a little hand-wavy, so I like to get down into the weeds, and I think word embeddings are the easiest way to wrap your head around what this is. What I want to do here is represent a word in a way that I can use mathematically. How similar are two words? Can I use a representation of this word in a model? I want a way of representing my word so I can predict what word should come next, or what words appear around it.

If you ask someone for the simplest way to do this, they might say: maybe I could just one-hot encode the letters of the alphabet, so "cat" is a one, a zero, a one, and then some more zeros and another one. The problem is that this doesn't tell me anything about what the word means, and anagrams would look the same, so that's probably not what I want.

Or maybe I use hand-engineered rules. I speak English, I know what these words mean, so I could define categories and assign weights. Say I want to encode king, queen, woman, and princess: are they royalty? I can assign a number there. Are they men? I can put a number there. Are they women? I can put a number there. But this doesn't work at scale, because it requires that I know something about every word and hand-encode it.

So what could I do instead? Well, I could start by saying my words exist in documents. I want to encode "cat", and I have a bunch of documents that contain a bunch of words. I bet the documents my words appear in tell me something about those words. So I take each term I'm interested in and count how many times it shows up in each document, and my word vector is just how many times this word appeared in each of the documents I've looked at. The problem is that, while that might be informative (it might help me tell the difference between documents), I'm going to end up with certain words like "is", "and", and "the" being way overrepresented, while words that are actually really important just don't show up that often.
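Here is a minimal sketch of that term-document counting on a toy corpus. The talk doesn't name a library; scikit-learn's CountVectorizer is just one convenient way to build the matrix:

```python
from sklearn.feature_extraction.text import CountVectorizer

# Toy corpus: three tiny "documents".
docs = [
    "the cat sat on the mat",
    "the dog chased the cat",
    "the dog and the cat are pets",
]

vec = CountVectorizer()
X = vec.fit_transform(docs)           # term-document count matrix
print(vec.get_feature_names_out())    # the vocabulary
print(X.toarray().T)                  # one row per term, one column per document
# Note how "the" dominates every column, while rarer and more meaningful
# words barely register.
```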
So what you can do instead is use something like TF-IDF, or weighted term frequency, where instead of taking the raw count, I normalize by how often a word shows up overall. But these still aren't really telling me anything about what the words themselves mean. The first approach just tells me which words are in which documents, which might help me tell the documents apart. The second maybe tells me which words help distinguish documents. But what I really want to ask is: what is the relationship between "woman" and "princess"? If I see the word "princess", what word comes next?

So what I can do instead is start thinking about context windows. Words exist in sentences, and the context around a word helps you understand what it means. This is just like when you're learning another language, or when you run across a word you've never seen before: you use the words around it to guess what it means. A great example: one sentence says "Tylenol is a pain reliever", and another says "Paracetamol is a pain reliever". Even if you don't know what Paracetamol is, you see that the two words appear in the same context, so they probably mean the same thing.

From here you can start looking at co-occurrence: how often do two words show up in the same context window, for a context window you specify? That is, how far away in my sentence or document do I want to go and still believe it carries meaning for this word? A simple example, on the right-hand side of the slide: I have three sentences, "He is not lazy", "He is intelligent", "He is smart". Say I look at "is" with a context window of one word behind and one word ahead. My focal word is "is"; what is the context around it? I can make a matrix where, for every word in this little corpus of three sentences, I record which words appear next to it. And I can see, hey, "smart" and "lazy" never show up together; I bet they don't mean the same thing. That's kind of useful. Maybe I could do this for a lot of words, and maybe I could use a context window to predict something.

Why not just stop here? A context window seems useful; it tells me what comes before and what comes after. The problem is that you need lots and lots of documents to understand context, and the more documents you have, the bigger your matrix gets. Every document I add (and that was three very short sentences) adds to this matrix, and you end up with a very sparse matrix, because a lot of words never show up together. How do I handle this?

What you really want to do is take that giant context matrix you've created and squish it down into something you can work with. If you've ever taken a linear algebra class, you're probably thinking: okay, when are we going to get to singular value decomposition? Linear algebra knows how to fix this; we know how to take a giant matrix and reduce its dimensionality. So you can do something like SVD. It's great: it preserves your relationships, it's accurate, and we know how to do it. The problem is that it requires a lot of memory, and it isn't tailored to a specific task. Singular value decomposition answers the question "how do I decompose a matrix?"
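Here is a minimal sketch of that whole idea, my own illustration rather than code from the talk: build the toy co-occurrence matrix with a window of one, then use plain SVD to squash it down to a couple of dimensions per word:

```python
import numpy as np

sentences = [["he", "is", "not", "lazy"],
             ["he", "is", "intelligent"],
             ["he", "is", "smart"]]
vocab = sorted({w for s in sentences for w in s})
idx = {w: i for i, w in enumerate(vocab)}

# Co-occurrence counts with a context window of one word on each side.
C = np.zeros((len(vocab), len(vocab)))
for s in sentences:
    for i, w in enumerate(s):
        for j in (i - 1, i + 1):
            if 0 <= j < len(s):
                C[idx[w], idx[s[j]]] += 1

# SVD squashes the (potentially huge, sparse) matrix down to d dimensions.
d = 2
U, S, Vt = np.linalg.svd(C)
embeddings = U[:, :d] * S[:d]         # one short vector per word
for word, vec in zip(vocab, embeddings.round(2)):
    print(word, vec)
```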
Maybe I want something specific to the task of predicting the next word, and maybe I can do some optimization so I don't have to pull that entire matrix into memory. That's when you get to predictive methods: for every word I'm interested in, I pull out its context window and I predict something about the relationship between that word and its context.

The two classic models in this space are CBOW and SkipGram. CBOW is: given a context window, what is the word? So if I see "the quick blank" or "blank brown fox", what word is missing? SkipGram is the reverse: I have a word, and I want to know what comes next or what came before. They're pretty similar, but generally speaking a lot of our embeddings are based on SkipGram, just because it performs better. So I'm going to talk about the SkipGram model, which is what's behind word2vec, the most common word embedding. Everyone here has probably heard of it; maybe you've used a pre-calculated word2vec embedding. This is where it comes from.

With SkipGram, you learn a vector representation for each word that maximizes the probability of predicting its neighboring words. Your input is the one-hot encoded vector for that word. Then there's a hidden layer, which is where the weights are assigned. And the output is, for every word in my corpus (my set of documents), the probability that it's the next word. In this example my focal word is "ants", and for every word I have ever seen, I predict how likely it is to come next.

You're probably thinking this sounds very specific and maybe not terribly useful. The cool thing is that we don't actually care about the output layer. What we care about is the hidden layer. The hidden layer is your weight matrix, where you take each of the values in your input vector and assign a weight to calculate the output layer, and you use forward and back propagation with gradient descent to learn what the right weights are. The hidden layer is a weight matrix with one row per word and one column per neuron, and that is what your embedding actually is. You train SkipGram to learn, given a word, what its context is, and what you take away from that is the hidden layer. That hidden layer is your word embedding.

Once you have your word embedding, the exciting thing is that it's a condensed representation of your word that preserves context. Instead of hand-engineering features for king, queen, man, and woman, or carrying around a giant matrix, I have a very concise, low-dimensional representation of each word as a bunch of numbers. What's cool is that it still preserves the context: the relationship between the embeddings of king and queen, and of man and woman, is the same as the relationship between them in the full context. You can do cool things like look at the distance between different word embeddings and understand how those words relate to each other. You can look at things like verb tense: the distance between "walking" and "walked" is the same as the distance between "swimming" and "swam". Or gender: king is to queen as man is to woman. You've probably heard these examples before; they're super powerful.

And you're probably thinking: cool, now I know what a word embedding is, but what does this have to do with graphs?
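Before moving on to graphs, here is a minimal sketch of what training a SkipGram embedding looks like in practice. The talk doesn't prescribe a library; this assumes gensim 4.x and a toy corpus, and a real model would need far more text:

```python
from gensim.models import Word2Vec   # assumes gensim 4.x

# Toy tokenized corpus; a real word2vec model needs far more text.
sentences = [["the", "quick", "brown", "fox", "jumps"],
             ["the", "lazy", "dog", "sleeps"],
             ["a", "quick", "red", "fox", "runs"]]

model = Word2Vec(
    sentences,
    vector_size=50,   # width of the hidden layer, i.e. the embedding length
    window=2,         # context window size
    min_count=1,
    sg=1,             # 1 = SkipGram, 0 = CBOW
)

print(model.wv["fox"])                # the 50-number representation of "fox"
print(model.wv.most_similar("fox"))   # neighbours by embedding similarity
```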
The reason I start with word embeddings is that they're really intuitive: I know what a word is and I want to represent it. With a graph embedding, you can think of all the nodes in your graph like words, and your graph is like the corpus of text, and you want to learn what each node means. Instead of a giant adjacency matrix for my graph, I want a short embedding for each node, just like I had for each word.

One of the first and most widespread graph embeddings, and a very simple one, is DeepWalk. The idea of DeepWalk is just: how do you represent a node in a graph mathematically? It's basically a simple adaptation of word2vec, whose derivation we just walked through. DeepWalk says every node in my graph is like a word, and the neighborhood around every node is like its context window, and I want to use that same SkipGram model to predict, given this node, what its neighborhood looks like.

Stepping through it: you start by taking every node in your graph and making "sentences" for it. You don't have sentences natively, so from every node you do random walks across your graph, a fixed number of them, of a fixed length. In this example I do four random walks from each node. Now I've generated my sentences, from which I get my context windows. Once I have those sentences, I extract the context windows, use the same SkipGram model, and learn the weights. Just like before, the objective is: given my input vector, predict the neighboring nodes. But all I keep is the hidden layer. The embeddings from DeepWalk are the hidden layer weights from your SkipGram model, and now I have embeddings that represent every node in my graph.

This is just one example; I like to walk through two concrete examples to make this intuitive and hands-on, so you now know how to calculate two embeddings and could go off and implement them yourselves. But there are a lot of other methodologies out there: matrix factorization approaches, hand-engineered approaches for graphs, just as there are a lot of different ways to do word embeddings. That's what this next section is about: we understand one simple example, so what are graph embeddings in general?

When I talk about graph embeddings, I think it's really helpful to break them down into different types. First, what type of graph are you trying to create an embedding for? The simplest case is a monopartite graph, where all the nodes are the same type. DeepWalk is for a monopartite graph; think of a social network, where people know each other: Alicia knows Jake, Jake knows Phillip, and we're all nodes in a social graph. But you could also have a multipartite graph with a bunch of different types of nodes: Alicia listens to this song, this song is on this record, this record was produced by this company. There you need to treat each class of node, each node label, differently, and you need different mathematics. This is really applicable to knowledge graphs. Generally speaking, you want to use an embedding that was designed for the kind of graph you're trying to run it on.
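Putting the DeepWalk recipe described above into code, here is a minimal sketch on a toy adjacency list, again assuming gensim 4.x for the SkipGram step. This is my own illustration, not Neo4j's implementation:

```python
import random
from gensim.models import Word2Vec   # assumes gensim 4.x

# Toy adjacency list, standing in for any monopartite graph.
graph = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

def random_walk(start, length):
    walk = [start]
    while len(walk) < length:
        walk.append(random.choice(graph[walk[-1]]))
    return [str(n) for n in walk]    # word2vec expects "words", so stringify

# Every node is a "word" and every walk is a "sentence":
# four walks of length eight from each node.
walks = [random_walk(node, 8) for node in graph for _ in range(4)]

# Train SkipGram on the walks; the hidden layer weights are the embeddings.
model = Word2Vec(walks, vector_size=16, window=3, min_count=1, sg=1)
print(model.wv["0"])                 # the 16-number embedding for node 0
```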
The other thing to consider is what aspect of the graph you're trying to represent as an embedding. You could have a vertex embedding, which is basically an embedding for every node; DeepWalk, which we just walked through, is a vertex embedding, figuring out how to represent every node in the graph. Path embeddings are instead traversals across the graph. If you were at my previous talk, I talked about eBay's recommendation path embeddings. One thing I've worked on quite a bit is patient journey embeddings: you have a graph of patients and their contacts with the hospital system, and every patient's contacts with doctors, prescriptions, and physical therapists form a path across that graph. Every patient has a series of encounters, and I want to pull out that path and represent it as an embedding so I can ask how similar two patients' journeys across my graph are. The final category is the whole-graph embedding, where I take my entire graph and encode it into a single vector. Maybe you want to do this with molecules, where every molecule can be represented as an individual graph, or maybe you have time series data and you want to embed "here is my graph at time one, here is my graph at time two".

The figure at the bottom shows that you can take a single input graph and compute all different kinds of embeddings from it. A node embedding gives me an embedding for every node in the input graph. An edge embedding gives me an embedding for every edge. You can also combine things like graph algorithms with your embeddings: maybe you want a substructure embedding, where I use label propagation to break my graph into subsets and then create an embedding for each of those subsets. And then there's the whole-graph embedding, where you end up with only one data point.

Now, generally speaking, when we talk about node embeddings (and most embeddings), there are four things you need. You need a similarity function that measures how similar any two nodes in your graph are; that helps you decide whether it's okay to just use the neighborhood. You need an encoder function that generates the node embedding, a decoder function that reconstructs pairwise similarity from the embeddings, and a loss function that measures how good that reconstruction is. Think back to training our SkipGram model to learn a vector for every word so I can predict what word comes next: I need a loss function to know whether my prediction of the next word is any good. So every node embedding will, generally speaking, have these pieces: some way to encode your node, some way to decode it, and some way of measuring whether the embedding is any good, because an embedding that doesn't represent anything about my graph is not useful.

Generally speaking, you can split graph embedding techniques into shallow and deep. Shallow techniques are those where your encoder function is just a lookup. What we've talked about so far, matrix factorization and random walks, are lookup methods: techniques that rely on an adjacency matrix or neighborhoods as input. With matrix factorization, you apply some kind of transformation directly to the adjacency matrix, or to some transformation of it.
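To make those four ingredients concrete, here is a minimal toy sketch of a shallow embedding (my own illustration, with made-up numbers): the adjacency matrix plays the role of the similarity function, the encoder is a plain lookup table, the decoder is a dot product, and the loss is the squared reconstruction error. This is also, in essence, what a simple matrix factorization does:

```python
import numpy as np

rng = np.random.default_rng(0)

# Similarity function: here, just the adjacency matrix of a toy graph.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

n, d = A.shape[0], 2
Z = rng.normal(scale=0.1, size=(n, d))   # encoder: an embedding lookup table

for _ in range(1000):
    recon = Z @ Z.T                      # decoder: pairwise dot products
    grad = 4 * (recon - A) @ Z           # gradient of the squared-error loss
    Z -= 0.01 * grad                     # plain gradient descent

print(np.round(Z, 2))                    # learned 2-d embedding per node
print(np.round(Z @ Z.T, 2))              # reconstruction, compare with A
```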
The other category here is random walks, which is what we talked about before: from every node I take random walks and then learn weights that optimize a similarity measure.

These are useful, so why not stop here? Just as with the word embeddings, the problem with matrix factorization is the massive memory footprint and the computational cost: you're basically pulling your entire graph into memory and doing very intensive operations on it, which isn't always the best way to go for a really large graph. So maybe we just use random walks? The limitation of random walks is that you only get a local perspective. Remember how we generate the context window: I'm looking at a random walk from each node, maybe six hops out, and I'm assuming that similar nodes are close together.

So what do I do instead, if these aren't good enough and I want to get more sophisticated? Shallow embeddings, generally speaking, are inefficient: parameters aren't shared between nodes. Someone asked during my last talk whether you can use node attributes, and in this case, no, you can't. And you only get embeddings for nodes that were present when your embedding was trained, so if you're constantly adding data to your graph, you're constantly going to have to retrain your embedding.

Newer methodologies instead compress your information using things like neighborhood autoencoders, neighborhood aggregation, and convolutional autoencoders. Instead of starting from either the full adjacency matrix or a set of specified random walks, you start from a high-dimensional neighborhood vector. Remember what we were computing for every node before: you could think of that as a vector of numbers representing the full neighborhood around your node, your proximity to all the other nodes. You then compress that high-dimensional neighborhood vector into something lower dimensional. This is what you'll find if you've looked at variational autoencoders, if you've ever used DeepChem, or GraphSAGE, which uses neighborhood aggregation. These are the more complicated methods, and they're basically a preview of where the state of the science is; a lot of the current work, and often what you see in production, is still based on simpler methods like DeepWalk.

So, cool: we've gone to all this work to understand what embeddings are and how to get them. What do I use them for? With a graph, one of the easiest things to do is visualization and pattern discovery. When you have a giant graph, it's really cool to explore it locally, but how do I visualize my billion-node graph all at once? With embeddings you can leverage lots of existing methodologies and do things like a t-SNE plot or a PCA, where you project those embeddings down and, instead of trying to look at a billion nodes at once, you can say: hey, in embedded space all of these nodes are close together.
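Projecting embeddings for plotting is a short exercise with standard tools. A minimal sketch, assuming scikit-learn and placeholder random embeddings standing in for real ones:

```python
import numpy as np
from sklearn.manifold import TSNE

# Placeholder embeddings: one row per node, e.g. the vectors DeepWalk produced.
embeddings = np.random.rand(1000, 16)

# Project down to two dimensions so the whole graph can be plotted at once.
coords = TSNE(n_components=2).fit_transform(embeddings)
print(coords.shape)   # (1000, 2): an x, y position for every node
```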
Or you can use those embeddings, which capture high-dimensional information about relationships, with a standard clustering method like k-means, which lets you look at both the functional and structural information from your graph. Or, one of my favorite things: once you have embeddings, you can calculate pairwise similarity between all of them and turn your graph, or your graph data, into a k-nearest-neighbors graph. Say I had that social graph, or my knowledge graph of Alicia listened to these songs, and I want a graph of who is most similar to me based on what songs they listen to: I can use an embedding to represent that similarity and quickly build the nearest-neighbors graph.

You can also use embeddings as inputs to your machine learning models. Typically you see node classification: I have embeddings representing my nodes and I want to predict, is this person male or female, or is this node going to buy my product or not? You can use embeddings to predict missing node attributes: maybe I have a graph where some data is missing and I want to infer it, so I use the nodes whose embeddings are most similar to my node of interest to infer the missing attribute. And you can use embeddings (this is actually what I've spent a lot of time on historically) for link prediction, where you want to predict edges that are currently not present in your graph. You can either use similarity measures or heuristics, like we talk about in the graph algorithm space, or you can build a machine learning model that says yes, there should be a link here, using the embeddings as input features. You can think of embeddings as a way of making the graph algorithms library, or graph algorithms in general, even more powerful: they're like a special algorithm that is trained for your specific graph, your use case, and the thing you want to do.

So this is cool; when can I get them? That's almost always the next question. Graph embeddings in Neo4j: everyone wants them. What we have right now are two implementations created by Neo4j Labs: prototype implementations of DeepWalk and DeepGL. I didn't really talk about DeepGL today; it's more similar to those handcrafted embeddings we talked about at the beginning. It uses graph algorithms to generate features, then a diffusion of values across edges and some simple dimensionality reduction. These are really proofs of concept: can we create an embedding, and what does it look like? Neither one is really ready for production use, but we've learned a lot. From my first day on the job I was getting emails saying "help make this scalable, I want to use this". These prototype implementations were really put together to ask whether we could do this and what it might look like, so they're not tuned for performance, they're quite memory intensive, and another limitation is that deep learning isn't easy off the shelf in Java.

An alternate approach, and one thing we've seen (I know we have a graph hack entrant with PYMBO, and I know there's a talk from another company using embeddings also going on at NODES), is to use Python to pull your data out of Neo4j and use an off-the-shelf library like PyTorch, Keras, or Gensim to train an embedding.
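A minimal sketch of that workflow, assuming the official Neo4j Python driver, gensim 4.x, and made-up connection details and a hypothetical KNOWS relationship type; this is just one way to wire it up, not a recommended production pattern:

```python
import random
from neo4j import GraphDatabase          # official Neo4j Python driver
from gensim.models import Word2Vec       # assumes gensim 4.x

# Connection details and the KNOWS relationship type are placeholders.
driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

# Pull an edge list out of the graph.
with driver.session() as session:
    result = session.run("MATCH (a)-[:KNOWS]->(b) RETURN id(a) AS src, id(b) AS dst")
    edges = [(r["src"], r["dst"]) for r in result]
driver.close()

# Build an adjacency list and reuse the DeepWalk-style recipe from earlier.
adj = {}
for s, t in edges:
    adj.setdefault(s, []).append(t)
    adj.setdefault(t, []).append(s)

def walk(start, length=8):
    w = [start]
    while len(w) < length and adj.get(w[-1]):
        w.append(random.choice(adj[w[-1]]))
    return [str(n) for n in w]

walks = [walk(node) for node in adj for _ in range(4)]
model = Word2Vec(walks, vector_size=64, window=3, min_count=1, sg=1)
```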
We found from our experimentation that this is really good to get started: it's an easy way to pull my data out and use something off the shelf. But it often doesn't perform at scale, due to the IO limitations of pulling data in and out of your graph, and these libraries aren't really optimized to take advantage of everything we have inside Neo4j.

The question you're probably having is: what's next? What I can say is that my team is actively looking at the best ways to implement graph embeddings at scale, so please stay tuned. This is a forward-looking talk; graph embeddings are clearly on our roadmap, we're developing them, and we're excited to hear what people want. So with that I'm going to stop and pause for the Hunger Games questions, and let everyone take some time to answer those while I pull up the Q&A. I think for this one I don't have Jennifer on the line, so I will have to manage.

- [Jennifer] I am actually here.

- Oh, you're here, so maybe I can leave the Hunger Games up while you do the Q&A?

- [Jennifer] Yes.

- Awesome.

- [Jennifer] So someone asked: are the embeddings of any use in climate analysis?

- Generically speaking, climate analysis uses a lot of high-dimensional data: spatial data, image data, weather patterns, time series. An embedding is a way of representing that in a lower-dimensional space. If you're asking specifically about something like DeepWalk, it would depend on the question you're trying to answer and how you formulated it as a graph. So generically speaking, embeddings are useful for almost anything where you have something complicated; whether a graph embedding specifically is useful depends on your data architecture.

- [Jennifer] Okay, someone asked about Dask to parallelize? I don't know.

- I don't know the details there, but I would be happy to connect afterwards.

- [Jennifer] Okay, another one asked: any ETA on a production-ready release?

- I knew I would get that question. I would say look at a six-month to one-year time horizon.

- [Jennifer] Okay, then: when you mentioned Python performance, is it Python with Neo4j, or general Python ML tools?

- It depends. You can use libraries that are, let's say, GPU-compute supported; a lot of the Python libraries are wrappers for C libraries and are highly optimized. It's really reshaping your data in and out where they fall over, and no matter where your data is coming from, you still have to do that IO step. Python itself is incredibly powerful; I've used libraries like PyTorch and Keras in the past. I'm not saying never use Python; I'm saying Python with Neo4j is great to get started, but once you get into millions or billions of nodes it doesn't scale.

- [Jennifer] And then someone asked: with the embedding performance issue, is it in extracting the data? They'd be interested in talking further.

- I would be happy to talk further. It depends on the approach you're taking: whether the cost is in extracting the data or, with the example of DeepWalk, in creating the walks across your graph. There are a lot of different points where you can optimize.

- [Jennifer] Okay, and I think that's all the questions. If there were any that were not answered or were missed, please do remember you can sign up on the community site at community.neo4j.com. Feel free to ask your questions there and any one of us will be happy to take a look. I think if that's it, Alicia, we're good.

- Awesome, thank you everyone for joining.
Info
Channel: Neo4j
Views: 16,735
Rating: 4.93 out of 5
Id: oQPCxwmBiWo
Length: 31min 39sec (1899 seconds)
Published: Thu Nov 07 2019