- So what I'm going to be
talking about in this session is graph embeddings. Some of you might have
joined my earlier talk about machine learning with graphs. This is kind of a much
more technical deep dive into one aspect of machine
learning with graphs, which is graph embeddings
which is basically a way of representing a graph
so you can leverage that in a machine learning
model, deep learning, calculate similarity. I'm Alicia Frame, I'm
the lead data scientist at Neo4j. I work on the product
team and my role is really to help build out our
graph algorithms library and our future data
science kind of features and roadmap. I also work with our early
adopters to help them get up and running in
production with graph algorithms and actually doing data
science with Neo4j. I'm super excited to share kind of a topic that's really close to my heart
which is graph embeddings. So in this talk what I'm
going to be talking about it really starting off
with what is an embedding? I think often this word
is used without people actually knowing what is
this, what does it represent, and what can I use it for? I think the easiest way
to understand embedding is to really start from
motivating examples. What I'm gonna do is I'm
going to start actually by explaining the word
embedding because I think we all speak, read, use words. It's easier to kind of wrap
your head around an embedding when you talk about it
with something familiar. Then once we kind of
have that understanding of what is a word embedding,
build on that to talk through a simple graph embedding. Once we kind of have that foundation, give a quick overview of graph embeddings. Like what are they, what
do you use them for? Talk about different
techniques to calculate them. And then kind of finally
end with graph embeddings in Neo4j. So this talk is much more
kind of a state of the science forward-looking talk. Where are we going next instead
of let me sell you something that's already built. I like my talks to have this one slide where, if you're not gonna listen to anything else, this slide is the one that matters. What on earth is an embedding? This is something that I
honestly struggled with when I first started in this space. What is an embedding? If you Google it you get
some very kind of confusing and opaque answers, right? Google says it's a low-dimensional space where you can translate
high-dimensional vectors. That doesn't help me. Wikipedia is worse, right? It's an instance of some
mathematical structure contained within another instance, such as a group. Huh? The way I think of it is
an embedding is just a way of mapping something
complicated, a document, an image, a graph, into something simple, a fixed-length vector
so a bunch of numbers or a matrix that captures the key features of that complicated thing
while making it lower dimensionality or fewer things. So if you think about a word embedding, a word comes out of a book. I have a big thick book,
but my word embedding is actually just you know, 20 digits. I'm taking something
complicated and making it into something simple. Graph embeddings are a
specific type of embedding where you're trying to translate a graph or part of a graph into
a fixed length vector. This is from a DeepWalk paper. You have your graph,
you could represent it as an adjacency matrix, right? You've one hot encoded
which node is attached to which other node? It's a large matrix and
it's gonna be very sparse. So your embedding just translates that into a much smaller matrix, lower dimensional that you can then use for some kind of task. What we're doing is we're taking, we're learning the important
features of your graph and distilling that giant adjacency matrix into something easier to work with. And you may be thinking well why bother? You just said okay we're
gonna do a lot of math to take something complicated and make it into something simpler. A question that came up in
my other talk is why would I take tabular data, put it into a graph, and then try to make it into a vector? It's this idea of you want to
translate something complex into something simple that you can use. The embedding captures
the important features of your input object for the
task you want to use it for in a compact way. And that representation
then can basically represent your graph for some
kind of machine learning or deep learning approach. Generally speaking, not always the case, but they learn what's
important in an unsupervised, generalizable way. So when we talk about how do
you train a word embedding or how do you train
DeepWalk, you can apply that technique to any context
window for the word or any graph for DeepWalk. And you end up with a
custom representation for your use case that's generalizable. Motivating examples, I
feel like this is still maybe a little hand-wavy
so I like to really get down into the weeds. I think word embeddings
are the easiest way to wrap your head around what this is. What I want to do here
is I want to represent a word in a way that I can
use it mathematically, right? So how similar are two words? Maybe I want to know can I use
a representation of this word in a model? So I want a way of representing my word so I could predict what
word should come next. Or what words are around this word? And if you ask someone like
what's the simplest way to do this, you'd say
well maybe I could just one hot encode all the letters
of the alphabet, right? So cat is 1-0-1 and then
some more zeros and a one. The problem here is this
doesn't tell me anything about what the word means and anagrams would be the same. Maybe that's probably
not what I want to do. Or maybe I want to do
a hand-engineered rule. So I have every word and I know something about the word, right? I speak English, I know
what these words mean so I could make rules
like categories where I could go in and assign weights. Where I say well you know,
I want to encode king, queen, woman, and princess. Well you know, are they royalty? I can assign a number there. They're men, I can put a number there. Are they women, I can put a number there. But this doesn't work at scale, right, because it requires that
I know something about it. Then I hand-encode that for every word. So what could I do instead? Well I could start by
saying well my words exist in documents, right? So I want to encode cat, right? Maybe I have a bunch of
documents that contain a bunch of words. I can say well I bet the
documents my words are in tell me something about those words. So I could say well you
know, I have each term that I'm interested in and how many times does it show up in a document? And then my word vector
could just be how many times was this word in each of the
documents I've looked at? The problem there is that maybe
that's informative, right? It might help me tell the difference between different documents
but I'm gonna end up with certain words like is, and, and the. Those are going to be way overrepresented and maybe words that are
actually really important just don't show up that often. So what you can do instead
is you can talk about something like TF-IDF or
weighted term frequency. Where instead of just saying
this is the raw count, I normalize by how often a word shows up.
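As a rough sketch of what that looks like in code (not something shown in the talk), scikit-learn's TfidfVectorizer does exactly this kind of weighting; the documents here are toy placeholders:

```python
# A minimal sketch, assuming scikit-learn is available; the documents are toy placeholders.
from sklearn.feature_extraction.text import TfidfVectorizer

documents = [
    "the cat sat on the mat",
    "the dog chased the cat",
    "paracetamol is a pain reliever",
]

vectorizer = TfidfVectorizer()
tfidf = vectorizer.fit_transform(documents)  # rows = documents, columns = words

# Common words like "the" get down-weighted; rarer, more informative words score higher.
print(dict(zip(vectorizer.get_feature_names_out(), tfidf.toarray()[0])))
```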
But these still aren't really telling me anything about what the words themselves mean. The first one just tells me, I don't know, which words are in which
documents that might help me tell the difference
between documents. The second one maybe is
telling me which words help distinguish documents. But maybe what I really
want to do is I want to say you know, what is the relationship between woman and princess? And if I see the word
princess what word goes next? So what I can do instead
is I can start thinking about context windows, right? So what this is, is if you think about it words exist in sentences, right? And the context around a
word helps you understand what it means. This is just like when
you're learning to speak another language or if
you run across a word you've never seen before
you use the words around it to guess what it means. A great example is if you say well Tylenol is a pain reliever. Then you have another
sentence that says Paracetamol is a pain reliever. Even if you don't know what Paracetamol is you see that they're in the same context and they probably mean the same thing. So what you can do from here
is you can start looking at co-occurrence or how
often do they show up in the same context window and specifying a context window. So how far away in my sentence or document do I want to go that I
think still has meaning for knowing what this word means? A simple example on the right-hand side is I have three sentences. He is not lazy, he is
intelligent, he is smart. Maybe I want to look
at "is" and I want to say well I have a context window of one behind and one forward. My focal word is "is". And then what's the context for "is"? So I can make a matrix where
I represent for every word in all of this corpus of
three sentences I've created, what word comes next? And I can say well
look, hey, smart and lazy never show up together. I bet they don't mean the same thing. And that's kind of useful.
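As a rough sketch (not from the talk), building that co-occurrence matrix with a window of one word on each side takes only a few lines of Python:

```python
# A minimal sketch: count co-occurrences within a context window of 1 word on each side.
from collections import defaultdict

sentences = [
    "he is not lazy".split(),
    "he is intelligent".split(),
    "he is smart".split(),
]

window = 1
cooc = defaultdict(int)  # (focal word, context word) -> count

for sentence in sentences:
    for i, word in enumerate(sentence):
        lo, hi = max(0, i - window), min(len(sentence), i + window + 1)
        for j in range(lo, hi):
            if j != i:
                cooc[(word, sentence[j])] += 1

print(cooc[("is", "he")])       # "he" appears right before "is" in all three sentences
print(cooc[("smart", "lazy")])  # 0 -- they never share a context window
```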
Maybe I could just do this for a lot of words and I'd have this context window. Maybe I could use a context
window to predict something. Why not just stop here, right? A context window seems useful. It tells me what comes
before and what's after. The problem is you need
lots and lots of documents to understand context. But the more documents you
have the bigger your matrix is. So if we go back, every document I add that was three very short sentences. Every document I add is
going to add to this matrix. And you end up with a very sparse matrix. A lot of words don't show up together. How do I handle this? What you really want to
do is you want to take that giant matrix you've
created of context and squish it down into
something you can work with. So if you're listening to
this and you've ever taken a linear algebra class
you're probably like okay, when are we going to get to
singular value decomposition? Linear algebra knows how to fix this. You know how to take a
giant matrix and reduce the dimensionality. So you can do something like SVD. It's great, preserves your relationships. It's accurate, we know how to do it.
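As a rough sketch (not from the talk), scikit-learn's truncated SVD will reduce a stand-in sparse co-occurrence matrix like this:

```python
# A minimal sketch: compress a sparse co-occurrence matrix to dense vectors with truncated SVD.
from scipy.sparse import random as sparse_random
from sklearn.decomposition import TruncatedSVD

# Stand-in for a big, sparse word-by-word co-occurrence matrix.
cooccurrence = sparse_random(10_000, 10_000, density=0.001, random_state=42)

svd = TruncatedSVD(n_components=50, random_state=42)
word_vectors = svd.fit_transform(cooccurrence)  # shape: (10_000, 50)

print(word_vectors.shape)
```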
The problem here is that it requires a lot of memory. And it's not for a specific task, right? Singular value decomposition
is how do I decompose a matrix? Maybe I want something
specific for what word is next? Maybe I can do some
optimization so I don't have to pull that entire
adjacency matrix into memory. That's when you get to
these predictive methods. So you say I actually
want to say for every word that I am interested in I want to pull out its context window and I
want to predict something about the relationship between that word and its context window. So those two kind of
classic models in this space that we talk about
there's CBOW and SkipGram. CBOW is given a context
window, what is the word? So if I go back if I see the quick blank, or blank brown fox, what word is missing? SkipGram is I have a word,
I want to know what's next or what was behind it, right? So they're pretty similar. Generally speaking a lot of our embeddings are based off of SkipGram just
because it performs better. So I'm going to talk
about the SkipGram model which is what's behind word2vec. That is the most common word embedding. Everyone here has probably heard of it. Maybe you've used a
pre-calculated word2vec embedding. This is where it comes from. So SkipGram you learn
the vector representation for each word that maximizes the probability of predicting, given that word, what the next word is. Your input vector is your
one hot encoded vector for that word. How often, what is my word? How often does it show up? Then I have a hidden layer which is where the weights are assigned. Then what I want to be
able to do is predict, for every word in my corpus (so my set of documents), the probability that it's the next word. In this example my focal word is ants, and I have every word I have
ever seen and I want to predict what is the next most likely word. And you're probably like
this sounds very specific and maybe not terribly useful. The cool thing here is
that we don't actually care about this output layer. What we care about is the hidden layer. The hidden layer is your
weight matrix where you take each of those values in your input vector and you're assigning a weight to calculate the output layer. What you're doing is you're using forward and back propagation and
gradient descent to learn what the right weights are. So the hidden layer is a weight matrix with one row per word,
one column per neuron. This is what your embedding actually is. You train SkipGram to learn given a word what's the context? And what you take away from
that is the hidden layer and that hidden layer
is your word embedding.
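As a rough sketch (not part of the talk), this is roughly what training a SkipGram model looks like with the Gensim library; the corpus and parameters are toy placeholders:

```python
# A minimal sketch, assuming Gensim is available; corpus and hyperparameters are placeholders.
from gensim.models import Word2Vec

corpus = [
    ["he", "is", "not", "lazy"],
    ["he", "is", "intelligent"],
    ["he", "is", "smart"],
]  # in practice: many documents, tokenized into lists of words

model = Word2Vec(
    sentences=corpus,
    vector_size=50,   # dimensionality of the hidden layer, i.e. the embedding
    window=2,         # context window size
    sg=1,             # 1 = SkipGram (0 would be CBOW)
    min_count=1,
    epochs=50,
)

print(model.wv["smart"])  # the learned 50-dimensional embedding for "smart"
```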
And once you have your word embedding, the exciting thing is this is your condensed representation of your word that preserves context. So instead of having
king, queen, man, woman and I've done all of this
work to hand engineer things where I have this giant matrix, I have a very concise
low-dimensional representation of each word as a bunch of numbers. What's cool about that is
that it still preserves the context. So the relationship between the embeddings of king and queen, and
man and woman is the same as the relationship between
them in the full context. You can do cool things
like look at the distance between different word
embeddings and understand how those words are related to each other. So you can look at things like verb tense. The distance between walking
and walked is the same as the distance between swimming and swam. Or gender, so king and
queen, king is to queen as man is to woman.
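As a rough sketch (not from the talk), that vector arithmetic looks something like this with Gensim's downloadable pretrained GloVe vectors; any pretrained word vectors would do, this particular dataset name is just my example:

```python
# A minimal sketch: analogies fall out of vector arithmetic, king - man + woman ≈ queen.
import gensim.downloader as api

vectors = api.load("glove-wiki-gigaword-50")  # small pretrained word vectors

print(vectors.most_similar(positive=["king", "woman"], negative=["man"], topn=1))
# expect something like [('queen', 0.85...)]
```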
You've probably heard these examples before. They're super powerful. And you're probably like, cool, now I know what a word embedding is. But what does this have to do with graphs? So the reason I start
off with word embeddings is that they're really intuitive, right? I know what a word is,
I want to represent it. With a graph embedding
you can kind of think of all of the nodes in
your graph like words. And your graph is like
this corpus of text. And you want to learn
what does this node mean? Instead of having a giant adjacency matrix for my graph what I actually want to have is a short embedding just
like I had for my word. So one of the first and kind of most widespread embeddings, a very simple one for graphs, is DeepWalk. The idea of DeepWalk
is just kind of saying how do you represent a node
in a graph mathematically? And so what this is is it's
basically a simple adaptation of word2vec. So we just walked through
the derivation of that. Where does it come from? DeepWalk basically says
every node in my graph is like a word, and the
neighborhood around every node in my graph is like its context window. And I want to use that
SkipGram model just like before to predict given this node
and its context window what should be after it? So stepping through
this, what you want to do is you start by taking
every node in your graph and you want to make sentences for it. So you don't have those
natively so what you do is you take every node
and you do a random walk across your graph and you
do a fixed number of these. So in this example I've got four. So I take every node in my
graph, I do four random walks of a fixed length. Then I've generated my sentences from which I get my context window. Once I have those sentences I can extract the context windows and then
I use the same SkipGram model and I learn the weights. So just like before the objective
is given my input vector I want to predict the neighboring nodes. But all I'm taking out of
that is the hidden layer. So the embeddings from DeepWalk are the hidden layer weights from your SkipGram model. And I basically have computed embeddings that now represent every node in my graph.
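As a rough sketch (not the Neo4j implementation, just an illustration with networkx and Gensim, and arbitrary walk counts and lengths), the whole DeepWalk pipeline fits in a few lines:

```python
# A minimal sketch of DeepWalk: random walks become "sentences" for SkipGram word2vec.
import random
import networkx as nx
from gensim.models import Word2Vec

graph = nx.karate_club_graph()  # stand-in for your graph

def random_walk(g, start, length):
    walk = [start]
    for _ in range(length - 1):
        neighbors = list(g.neighbors(walk[-1]))
        if not neighbors:
            break
        walk.append(random.choice(neighbors))
    return [str(n) for n in walk]  # word2vec expects string "words"

walks = [random_walk(graph, node, length=10)
         for node in graph.nodes()
         for _ in range(4)]  # four walks per node, like the example above

model = Word2Vec(walks, vector_size=32, window=5, sg=1, min_count=1, epochs=20)
embedding_of_node_0 = model.wv["0"]  # the learned embedding for node 0
```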
This is just one example, so I like to walk through two concrete examples to make this kind of intuitive and hands-on. So you now know how to
calculate two embeddings. You could go off and
implement these yourselves. There are a lot of other
methodologies out there. There are matrix factorization approaches, hand engineered approaches
for graphs just like there's a lot of different
ways to do word embeddings. So that's what this next section is about. So we understand one simple example, what are they in general? When I talk about graph
embeddings I think it's really helpful to break it down
into different types. I think you can break it
out by what type of graph are you trying to create an embedding for? The easiest way to say
this is are you using a monopartite graph? So all the nodes in your
graph are the same type. DeepWalk is for a monopartite graph. What this means is DeepWalk is
for a social network, right? People know each other
so Alicia knows Jake, Jake knows Phillip. We're nodes in a social graph. But you could also have
a multipartite graph where you have a bunch of
different types of nodes. You have Alicia listens to this song. This song is on this record,
this record was produced by this company. You need to treat each class of node, each node label, differently. You need different mathematics. This is really applicable
to knowledge graphs. Generally speaking you
want to use the embedding that is derived for the
graph you are trying to run it on. The other piece that you want to consider is what aspect of the graph
are you trying to represent as an embedding? So you could have a vertex
embedding which is basically I want an embedding for every node. If we think back to
walking through DeepWalk, DeepWalk was a vertex embedding, right? We're figuring out how do
we represent every node in the graph? Path embeddings instead
are basically traversals across the graph. If you were at my previous
talk I talked about eBay's kind of recommendation path embeddings. One thing I've worked quite a bit with is looking at journey embeddings. So patient journeys, say you have a graph of patients and their contacts
with the hospital system. Every patient's contact with
doctors and prescriptions and physical therapists is
a path across that graph. Every patient has a series of encounters. I actually want to pull
out that path and I want to represent that as an
embedding to use it to say how similar two patients'
journeys across my graph are. Then kind of the final
category is the graph embedding where I want to take my
entire graph and encode it into a single vector. Maybe you want to do this
with something like molecules where every molecule can be represented as an individual graph. Or maybe you have time
series data and you want to say here is my graph at time one. Here is my graph at time
two and embed each one. So the figure on the
bottom basically shows you can take a single graph as input and you can find all different kinds of, oh, I don't want to see that. So I can find all different
kind of embeddings from it. If I start off with my node embedding I create an embedding for
every node in that input graph. My edge embedding is an
embedding for every edge. You can actually combine
things like graph algorithms in your embeddings. Maybe you want to do a
substructure embedding. Maybe I go in and I use label
propagation to break my graph into subsets. Then I create an embedding
for each of those subsets. Then you have your whole graph
embedding where you end up with only one data point. Now generally speaking, when
we talk about node embeddings and most embeddings you
have kind of four things that you need. You need a similarity
function that measures how similar are any two
nodes in your graph? That helps you tell is it okay
to just use the neighborhood? I have an encoder function that generates the node embedding and then
I have my decoder function that reconstructs the pairwise similarity. And I have a loss function
that measures how good my reconstruction is. So if we think about that
idea of we're training our SkipGram model to learn
a vector that represents every word so I can
predict what word is next, I need a loss function
to know if my prediction of what word is next
is good or not, right? So every node embedding
kind of generally speaking will have these steps. You need some way to embed your node. You need some way to decode your node. Then you need some way of
measuring is this embedding any good or not? Because an embedding that
doesn't represent anything about my graph is not useful. Generally speaking you
can talk about shallow or deep graph embedding techniques. Shallow embedding techniques are basically where your encoder function is a look up. So what we've talked about so
far with matrix factorizations and random walks, these
are just look up methods. These are techniques that rely
on an adjacency matrix input or kind of neighborhoods. Matrix factorization you're
looking at adjacency matrix. I want to apply some
kind of transformation directly to that adjacency matrix or some transformation of it. Or the other category in
here is these random walks which is what we talked about before. So the random walk is I have every node. I want to take a random
walk from that node and then I want to learn
some weight to optimize the similarity measures. These are useful, why not stop here? Just like when we talked about
with the word embeddings, the problem with matrix factorization is that you have a
massive memory footprint and it's computationally intense. So you're basically pulling
your entire graph into memory. You're doing very intensive
operations and it's not always the optimal way to
go for a real large graph. So cool so maybe we just use random walks. The limitations of random
walks is you're getting a local only perspective, right? So we talked about how do I
generate my context window? Well I'm looking at a
random walk from each node and I'm looking at six hops out, right? When I do this I'm
assuming that similar nodes are closer together. What do I do instead? Let's say these aren't good enough. I want to get more complicated. Why would I ever want to move on? Shallow embeddings generally
speaking are inefficient. You don't have parameters
shared between nodes. This is a question someone
asked me during my last talk of can you use node attributes? And in this case no, you can't. And you're only looking
at generated embeddings for nodes that are present
when your embedding was trained. If you're constantly
adding data to your graph you're going to have to
constantly have to retrain your embedding. So newer methodologies
actually switch over to compressing your
information using things like neighborhood autoencoders,
neighborhood aggregation and convolutional autoencoders. So this is basically
instead of saying okay either I have my full adjacency matrix or I have these specified random walks, I want to have a high
dimensional neighborhood vector. So remember what we were coming up with for every node before,
you could think of that as a one dimensional vector. I have a series of numbers. You have a high dimensional
vector that represents your full neighborhood around your node, your proximity to all the other nodes. And then you're compressing
that high dimensional neighborhood vector into
something lower dimensional.
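As a rough sketch (not from the talk and not any particular published architecture), a neighborhood autoencoder in PyTorch might compress each node's adjacency row like this:

```python
# A minimal sketch: compress each node's high-dimensional neighborhood vector
# (its row of the adjacency matrix) into a small embedding. Sizes are placeholders.
import torch
import torch.nn as nn

num_nodes, embedding_dim = 1000, 32
adjacency = (torch.rand(num_nodes, num_nodes) < 0.01).float()  # fake sparse graph

encoder = nn.Sequential(nn.Linear(num_nodes, 256), nn.ReLU(), nn.Linear(256, embedding_dim))
decoder = nn.Sequential(nn.Linear(embedding_dim, 256), nn.ReLU(), nn.Linear(256, num_nodes))

optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()  # how well can we reconstruct each neighborhood?

for epoch in range(100):
    embeddings = encoder(adjacency)        # one embedding per node
    reconstruction = decoder(embeddings)   # try to rebuild the neighborhood
    loss = loss_fn(reconstruction, adjacency)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```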
This is something like, if you've looked at any of the VAEs, or if you've ever used DeepChem. GraphSAGE uses neighborhood aggregation. These are kind of the more complicated methods. This is basically a
preview of where the state of the science is. But a lot of the current
work and kind of often when you see in production
is based on simpler methods like DeepWalk. So cool, we've gone to all
this work to understand what these are, understand
how to get them. What do I use them for? When we're talking about
a graph you know one of the easiest things
to do is you can talk about visualization and pattern discovery. When you have a giant
graph that's really useful, it's really cool to explore locally but how do I visualize
my billion node graph all at once? You can leverage lots of
existing methodologies and you can do things
like a t-SNE plot or a PCA where you basically
project those embeddings. And instead of trying to look at a billion nodes at once you can say
hey, in embedded space all of these nodes are
closer together, right?
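As a rough sketch (not from the talk), projecting placeholder embeddings into 2D with scikit-learn's t-SNE looks like this:

```python
# A minimal sketch: squash node embeddings to 2D so a huge graph can be eyeballed at once.
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

embeddings = np.random.rand(2000, 64)  # stand-in for your node embeddings

coords = TSNE(n_components=2, random_state=42).fit_transform(embeddings)
plt.scatter(coords[:, 0], coords[:, 1], s=2)
plt.show()
```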
Or you can use those embeddings to capture kind of high dimensional information
about relationships with something like a
standard clustering method like k-means that lets
you look at both kind of the functional and structural
data from your graph. Or you can do one of my
favorite things is once you have an embedding
you can take all of those embeddings and calculate
pairwise similarity and take your graph or
take your graph data and come up with a k-nearest neighbors graph. So I want to say I had this social graph or I had my knowledge
graph of Alicia listened to these songs. I want to come up with
a graph of who is most similar to me based on
what songs they listen to? And I can use an embedding to represent that similarity and quickly come up with the nearest neighbors graph.
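A rough sketch (not from the talk) of that step with scikit-learn, using placeholder embeddings:

```python
# A minimal sketch: build a k-nearest-neighbors graph from embeddings via cosine similarity.
import numpy as np
from sklearn.neighbors import NearestNeighbors

embeddings = np.random.rand(1000, 64)  # stand-in for your node embeddings

knn = NearestNeighbors(n_neighbors=6, metric="cosine").fit(embeddings)
distances, neighbors = knn.kneighbors(embeddings)

# neighbors[i] lists node i itself plus its 5 most similar nodes;
# those pairs become the edges of your similarity graph.
print(neighbors[0])
```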
You can also use these as inputs to your machine learning models. Typically speaking you
see node classification. So I have these embeddings
to represent my nodes. I want to predict is this
person male or female? Or I want to say is this node going to buy my product or not? You can use embeddings to
predict missing node attributes. Maybe I have a graph, some data's missing. I want to infer it. I can use the most similar
embeddings to my node of interest to infer
that node's attribute.
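A rough sketch (not from the talk) of that node classification setup, with placeholder embeddings, placeholder labels, and a plain logistic regression:

```python
# A minimal sketch: each node's embedding is the feature vector, a known label is the target.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

embeddings = np.random.rand(1000, 64)        # one embedding per node
labels = np.random.randint(0, 2, size=1000)  # e.g. "will buy" yes/no for labeled nodes

X_train, X_test, y_train, y_test = train_test_split(embeddings, labels, test_size=0.2)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print(clf.score(X_test, y_test))
```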
You can also use embeddings, and this is actually what I've spent a lot of time on historically, is for link prediction. So you can have your embeddings
and you want to predict edges that are currently
not present in your graph. You can either use similarity measures or heuristics like we talk about in the graph algorithm space. Or you can use the
machine learning pipeline to actually build a model to say yes, there should be a link here
based on the input feature of my embedding.
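A rough sketch of that pipeline (not from the talk): combine the two endpoint embeddings into an edge feature, here using the Hadamard product, which is one common choice and my own assumption, and train a classifier on known edges and non-edges; everything below is placeholder data:

```python
# A minimal sketch of link prediction from node embeddings.
import numpy as np
from sklearn.linear_model import LogisticRegression

embeddings = np.random.rand(1000, 64)               # one embedding per node
edges = np.random.randint(0, 1000, size=(5000, 2))  # candidate node pairs
has_edge = np.random.randint(0, 2, size=5000)       # 1 if the edge exists today

pair_features = embeddings[edges[:, 0]] * embeddings[edges[:, 1]]  # Hadamard product
clf = LogisticRegression(max_iter=1000).fit(pair_features, has_edge)

# Score a brand-new candidate pair, e.g. nodes 3 and 7:
print(clf.predict_proba((embeddings[3] * embeddings[7]).reshape(1, -1)))
```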
You can think of embeddings as a way of making the graph algorithm
library or graph algorithms in general even more powerful. They're kind of like a special algorithm that is trained for your specific graph for your use case and for
the thing you want to do. So this is cool, when
can I get them, right? This is almost always the next question. Graph embeddings in Neo4j,
everyone wants them. So what we have right now is two implementations that were created by Neo4j Labs: a prototype implementation
of DeepWalk and DeepGL. I didn't really talk about DeepGL today. DeepGL is kind of more
similar to those handcrafted embeddings we talked
about in the beginning. It uses graph algorithms
to generate the features. And then you kind of
have a diffusion process of values across edges and some simple dimensionality reduction. These are really proof of concept. So it's can we create an embedding and what does it look like? Neither one is really
ready for production use but we've learned a lot. The first thing we have
learned is from my first day on the job I was getting
emails of help make this scalable, I want to use this. These prototype implementations
were really put together to say can we do this and
what might it look like? So they're not tuned for performance. They're quite memory intensive. And also the limitation
here is that deep learning isn't easy off the shelf in Java. So an alternate approach and
one thing that we've seen, I know we have a graph
hack entrant with PYMBO. I know there's a talk from another company who's using embeddings also going on at NODES. You can use Python to pull your data out from Neo4j and use an off-the-shelf library like PyTorch or Keras or Gensim to train an embedding. We found from our experimentation there that it's really good to get started. It's an easy way of pulling my data out, using something off the shelf.
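A rough sketch of that workflow (not from the talk); the connection details, Cypher query, and labels are hypothetical placeholders for your own graph:

```python
# A minimal sketch: pull an edge list out of Neo4j with the official Python driver,
# then hand it to an off-the-shelf library. All names here are placeholders.
from neo4j import GraphDatabase

driver = GraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))

with driver.session() as session:
    result = session.run("MATCH (a:Person)-[:KNOWS]->(b:Person) "
                         "RETURN id(a) AS source, id(b) AS target")
    edges = [(record["source"], record["target"]) for record in result]

driver.close()
# From here you could build a networkx graph from `edges` and train DeepWalk
# with Gensim, exactly as in the earlier sketch.
```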
But often it doesn't perform at scale due to IO limitations where you're pulling data in and out of your graph. Or these aren't really
optimized to take advantage of all of the stuff that
we have inside Neo4j. Probably the question you're
having is what's next? What I can say is that my
team is actively looking at the best ways to
implement graph embeddings at scale and please stay tuned. So this is a forward looking talk. I would say they are
clearly on our road map, we're developing them. And we're excited to
hear what people want. So with this I'm going to stop and pause for the Hunger Games
questions and let everyone take some time to answer
those while I pull up the Q and A. I think for this one I don't
have Jennifer on the line so I will have to manage. - [Jennifer] I am actually here. - Oh you're here so maybe I
can leave the Hunger Games up while you do Q and A? - Yes.
- Awesome. - [Jennifer] So someone
asked are the embeddings of any use in climate analysis? - I mean generically
speaking climate analysis would use a lot of high
dimensional data, right? You have spatial data, image
data, weather patterns, time series. So an embedding is a
way of representing that in a lower dimensional space. If you're talking about
specifically something like DeepWalk it would
depend on the question you were trying to answer and kind of how you formulated that as a graph. So generically speaking
embeddings are useful for almost anything where you
have something complicated. Whether or not it's useful in the context of a graph embedding depends on kind of your data architecture. - [Jennifer] Okay, someone
asked about Dask to parallelize? I don't know.
is, but I would be happy to connect afterwards. - [Jennifer] Okay,
another one asked any ETA on production ready release? - I knew I would get that question. I would say look at six months
to one year time horizon. - [Jennifer] Okay then when you
mentioned Python performance is it Python with Neo4j or
general Python ML tools? - It depends, you can use
libraries that are let's say, GPU compute supported, right? So a lot of the Python
libraries are wrappers for C-libraries. They are highly optimized,
it's really the reshaping your data in and out where
they kind of fall over. And no matter where
your data is coming from you're still having to do that IO step. Python itself is incredibly powerful. I've used libraries like
PyTorch and Keras in the past. Not saying never use
Python, I'm saying Python with Neo4j is great to get started. But once you get into the
millions, billions of nodes it doesn't scale. - [Jennifer] And then someone
asked with the embedding performance issue is it
in extracting the data? Would be interesting to talk further. - I would be happy to talk further. It depends on the approach you're taking. Whether it's extracting the data, whether it's with the
example of DeepWalk creating the walks in your graph. You can think about it like
there's a lot of different points where you can optimize. - [Jennifer] Okay and I think
that's all the questions. If there were any that were not answered or were missed, please do
remember you can sign up on the community site
at community.neo4j.com. Feel free to ask your questions there and any one of us will be
happy to take a look at those. I think if that's it's
Alicia, I think we're good. - Awesome, thank you everyone for joining.