GraphWise is a graph neural network (GNN) algorithm based on the popular GraphSAGE paper [1]. In this blog post, we illustrate the general ideas and functionality behind the algorithm. To motivate the post, let's consider some common use cases for graph convolutional networks.
Recommender systems are some of the most essential pieces of software behind online marketplaces; the quality of recommendations often translates very directly into sales metrics.
A performant and effective way of looking at a recommender dataset is to treat user actions as edges in a graph consisting of users and items. When a user buys an item, you can represent that by adding an edge between the node for that user and the node for the item that the user bought.
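As a minimal sketch of this representation, the snippet below turns a purchase log into an undirected user-item graph stored as an adjacency map. The user and item names, and the `user:`/`item:` prefixes, are invented for illustration.

```python
from collections import defaultdict

# Hypothetical purchase log: (user, item) pairs, invented for illustration.
purchases = [
    ("alice", "book"),
    ("alice", "lamp"),
    ("bob", "book"),
]

# Undirected bipartite graph stored as an adjacency map.
# Prefixes keep user nodes and item nodes from colliding.
adjacency = defaultdict(set)
for user, item in purchases:
    user_node = f"user:{user}"
    item_node = f"item:{item}"
    adjacency[user_node].add(item_node)
    adjacency[item_node].add(user_node)

# Users who bought the same item are now two hops apart in the graph.
print(sorted(adjacency["item:book"]))
```

Two users who bought the same item end up two hops apart, which is the kind of structure a graph-based recommender can exploit.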
At Pinterest [2], a GraphSAGE-based algorithm is used to make all recommendations on their huge, web-scale (18 billion edges!) graph. GraphSAGE is interesting both because it is fast enough for such a large-scale website, and because it can provide very high-quality recommendations by leveraging the information in the graph. This graph convolutional neural network, known as PinSage, consistently outperforms existing, non-graph machine learning models in A/B tests.
There are thousands of applications in which you would like to detect anomalous behavior in a graph. We've seen examples ranging from large transaction graphs in the banking sector all the way to cyber-security, where the graphs represent processes and other components in an operating system. In such cases, the graph structure is often invaluable to the task, and graph neural networks allow you to exploit this data effectively.
While we presented only two examples, these are far from the only ones; graphs are very often the most natural way to represent a wide spectrum of data, from knowledge graphs, to maps, to social graphs (like Facebook), to web networks (e.g. Google's PageRank), and even geometric (3D) objects!
As a running example of graph convolutional networks, we'll consider trying to identify fraudulent actors in a graph of bank transactions.
Let's assume that we're given a graph G whose vertices are customers of the bank, with an edge between two customers if a transfer between them has taken place. The following is an example of what such a network could look like.
Identifying fraud in the graph depends heavily on the graph topology, i.e. the structure of the graph. Through GNN algorithms, we will be able to exploit this data to make more accurate predictions.
More standard machine learning algorithms cannot ingest graph data directly, since they can usually only work with matrices. Therefore, a common approach is to make predictions based on some vertex features, such as the size of the account. These would be concatenated with hand-crafted aggregate features based on the graph and its attributes, such as the average transaction time, or the vertex degree. Thus, the dataflow looks as follows:
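The manual feature step can be sketched in a few lines of Python. This is only an illustration under assumed data: the account sizes, the transfer list, and the choice of mean transfer amount as the aggregate feature are all made up here.

```python
# Hypothetical data: account size per customer and a list of transfers
# (sender, receiver, amount). All values are invented for illustration.
account_size = {"v": 1200.0, "a": 300.0, "b": 5000.0}
transfers = [("v", "a", 100.0), ("a", "v", 50.0), ("v", "b", 250.0)]

def handcrafted_features(node):
    """Concatenate a raw vertex feature with hand-crafted graph aggregates."""
    incident = [t for t in transfers if node in (t[0], t[1])]
    # Distinct counterparties of this customer, i.e. the vertex degree.
    neighbors = {t[0] if t[1] == node else t[1] for t in incident}
    degree = len(neighbors)
    mean_amount = sum(t[2] for t in incident) / len(incident) if incident else 0.0
    return [account_size[node], float(degree), mean_amount]

# One row per customer: this matrix is what a standard model would consume.
feature_matrix = [handcrafted_features(n) for n in ("v", "a", "b")]
```

Every aggregate in such a pipeline has to be chosen and coded by hand, which is the step the GNN approach aims to replace.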
The GNN approach replaces the manual feature creation step with graph convolutional layers that can learn features which are suited to the problem. By doing this, we can ingest the graph data directly:
To do this, GraphSAGE defines a forward pass that takes in graph data and outputs a vector, which can then be fed to more standard neural network layers. Interesting features of the algorithm are:
To perform this transformation, the approach proceeds in two stages: sampling and aggregation. To illustrate how it works, we will step through the forward pass for node v of the above example graph.
We will consider a two-layer model, where the first layer samples three neighbors, and the second layer samples two neighbors.
First, a set of neighbors is sampled. Since we have a two-layer model, we will need to sample first- and second-hop neighbors. We will do this randomly here, but in practice, it is often helpful to sample neighbors based on auxiliary metrics like personalized PageRank (PPR), as in [2]. For the first hop, we sample the nodes {a, h, b}. For each of these nodes, we now need to sample 2 of their neighbors. For a, we sample {f, h}, for h we sample {a, d} and for b we sample {e, g}. Note that, in practice, sampling is often performed with replacement for performance reasons.
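The sampling stage can be sketched as follows. The adjacency list is hypothetical: it was chosen only to be consistent with the samples mentioned above, and the real example graph may contain other edges. The function names are likewise illustrative, and sampling is uniform with replacement.

```python
import random

# Hypothetical adjacency list, consistent with the samples in the text.
graph = {
    "v": ["a", "b", "c", "h"],
    "a": ["v", "f", "h"],
    "b": ["v", "e", "g"],
    "c": ["v"],
    "d": ["h"],
    "e": ["b"],
    "f": ["a"],
    "g": ["b"],
    "h": ["v", "a", "d"],
}

def sample_neighbors(node, k, rng):
    """Uniformly sample k neighbors of `node`, with replacement."""
    return [rng.choice(graph[node]) for _ in range(k)]

def sample_two_hops(root, fanouts, rng):
    """Sample first-hop neighbors, then a neighbor sample for each of them."""
    first_hop = sample_neighbors(root, fanouts[0], rng)
    # Keyed by node, so a node drawn twice in the first hop shares one sample.
    second_hop = {u: sample_neighbors(u, fanouts[1], rng) for u in first_hop}
    return first_hop, second_hop

rng = random.Random(0)  # fixed seed so the sketch is reproducible
first_hop, second_hop = sample_two_hops("v", fanouts=(3, 2), rng=rng)
```

Because the fanouts are fixed, the amount of work per node is bounded regardless of how many neighbors a node actually has.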
The next step is to aggregate information from these nodes. The general idea is that, for each set of sampled neighbors, we will aggregate the node features using a column-wise symmetric aggregation, i.e. one whose result does not depend on the order of the neighbors. In practice, commonly used aggregations are mean, max and even LSTM aggregations (even though an LSTM is not symmetric, since its output depends on the input order). In the following figure, we can see how the features flow through the graph.
As an example, consider that the node features are the monthly net balance and the account age in days. Then, if the nodes "f" and "h" have feature vectors [-$200, 40 days] and [+$300, 100 days], then a mean aggregator for "a" would output the mean vector, i.e. [$50, 70 days].
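The mean aggregation in this example can be checked with a short sketch; the helper name `mean_aggregate` is illustrative.

```python
def mean_aggregate(feature_vectors):
    """Column-wise mean of a list of equal-length feature vectors."""
    count = len(feature_vectors)
    return [sum(column) / count for column in zip(*feature_vectors)]

# Features of the two neighbors sampled for node "a":
# [monthly net balance in dollars, account age in days].
sampled = {"f": [-200.0, 40.0], "h": [300.0, 100.0]}
aggregated_a = mean_aggregate(list(sampled.values()))
print(aggregated_a)  # [50.0, 70.0]
```

Swapping the order of "f" and "h" gives the same result, which is what makes the mean a symmetric aggregator.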
After each aggregation, we multiply the values from previous layers with a matrix. The weights of this matrix are trainable, and this learnable transformation allows the algorithm to learn suitable feature aggregations.
However, this does not yet fully describe the forward pass of GraphSAGE. A key step that is still missing is the concatenation operation. In the model we just described, the features of the node "v" itself are ignored, even though they are usually quite important for the task! To avoid this, we will concatenate the aggregated features with the node's own features every time the aggregator is used.
To make this easier to understand, it's helpful to formalize the notion of layer representations of nodes. The 0th layer representation of a node is its input feature vector. Then, the 1st layer representation of a node will be defined as the concatenation of the node's own feature vector with the aggregated features of its 1-hop neighbors.
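Putting aggregation, concatenation and the trainable matrix together, a single layer might look like the sketch below. It assumes a mean aggregator and a ReLU nonlinearity (the text does not name an activation), and the weight values are fixed by hand purely for illustration; in a real model they would be learned.

```python
def mean_aggregate(vectors):
    """Column-wise mean of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def layer(self_repr, neighbor_reprs, weights):
    """Concatenate own representation with aggregated neighbors, then apply
    the weight matrix followed by ReLU (the activation is an assumption)."""
    combined = self_repr + mean_aggregate(neighbor_reprs)  # list concatenation
    return [max(0.0, z) for z in matvec(weights, combined)]

# Hand-picked 2x4 weight matrix: input is 2 own features + 2 aggregated ones.
weights = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, -2.0],
]
output = layer([1.0, 2.0], [[0.0, 4.0], [2.0, 0.0]], weights)
print(output)  # [2.0, 0.0]
```

Note that the concatenated vector is twice as wide as the input, so the weight matrix of each layer has twice as many columns as the previous layer's output dimension.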
In general, the ith layer representation of node v (denoted f_i(v)), can be found using
where N(v) is the sampling function that returns a set of neighbors [3].
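The update described above can be written compactly. The following is a sketch in the spirit of the GraphSAGE formulation [1]; the symbols $x_v$ (the input feature vector of v), $W_i$ (the trainable matrix of layer i), $\sigma$ (a nonlinearity such as ReLU) and $\mathrm{AGG}_i$ (the aggregator of layer i) are notation assumed here rather than taken from the text:

```latex
f_0(v) = x_v, \qquad
f_i(v) = \sigma\!\Big( W_i \cdot \mathrm{CONCAT}\big( f_{i-1}(v),\ \mathrm{AGG}_i\{\, f_{i-1}(u) : u \in N(v) \,\} \big) \Big)
```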
Finally, the output of the forward pass is the L2-normalized last layer representation of the node "v."
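For reference, L2 normalization simply rescales a vector to unit length. A minimal sketch, with a small `eps` guard against division by zero added as an assumption:

```python
import math

def l2_normalize(vector, eps=1e-12):
    """Scale `vector` to unit Euclidean length (eps avoids dividing by zero)."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / max(norm, eps) for x in vector]

embedding = l2_normalize([3.0, 4.0])
```

After this step, the dot product of two embeddings equals their cosine similarity.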
GraphWise is implemented in the PGX package available from version 19.4 onwards. To run a GraphWise model on an input graph, you can use the following Java code:
We can also enable more specific parameters for the individual layers. To recreate the forward pass described above, we can use the following code:
The weights of the model can then be trained in an end-to-end manner, such as by feeding the result of the convolutional layers into some dense layers (this is what is done in GraphWise), and training based on a classification loss on the output.
However, this final layer does not have to be a dense layer, and the loss does not have to be a classification loss; since we have a forward pass and trainable weights, we can add any loss or output we desire.
We can take the activations of our model after the convolutional layers and consider them as embeddings. These embeddings can then be used in downstream tasks like link prediction, recommender systems, or clustering.
We can also learn the weights of our forward pass in an unsupervised manner using a skip-gram-like loss.
We start by generating context pairs, i.e. nodes that should be close in the embedding space. Similar to algorithms such as DeepWalk [4], we can get these context pairs from random walks on the graph: we perform random walks starting from a node "v," and for every node "u" that we reach on these random walks, we add the context pair (v, u).
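Generating context pairs from random walks can be sketched as follows. The graph, the walk parameters and the function name are hypothetical, and skipping pairs where the walk returns to its start node is a design choice made here, not something the text prescribes.

```python
import random

# Small hypothetical undirected graph as an adjacency list.
graph = {
    "v": ["a", "b", "h"],
    "a": ["v", "h"],
    "b": ["v"],
    "h": ["v", "a"],
}

def context_pairs(graph, start, num_walks, walk_length, rng):
    """Run random walks from `start` and pair it with every node reached."""
    pairs = []
    for _ in range(num_walks):
        current = start
        for _ in range(walk_length):
            current = rng.choice(graph[current])
            if current != start:  # a node paired with itself carries no signal
                pairs.append((start, current))
    return pairs

rng = random.Random(0)  # fixed seed so the sketch is reproducible
pairs = context_pairs(graph, "v", num_walks=5, walk_length=3, rng=rng)
```

Nodes that are close to "v" in the graph are reached more often and therefore show up in more context pairs.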
With these context pairs, we can apply the following loss
where s is the sigmoid activation, f is our forward pass and P is a negative sampling distribution. The basic idea is that for a context pair (u, v), this loss enforces that u and v are close in embedding space, and that for some negative sample x, u is far away from x in embedding space.
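One common form of such a loss, as used in the GraphSAGE paper [1], is -log s(f(u)·f(v)) minus a sum of log s(-f(u)·f(x)) over negative samples x drawn from P. The sketch below assumes that form and takes already-computed embeddings as plain lists; the function names are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def skipgram_loss(emb_u, emb_v, negative_embs):
    """Negative-sampling loss for one context pair (u, v).

    emb_u, emb_v: embeddings f(u), f(v) of the context pair.
    negative_embs: embeddings f(x) of negative samples x drawn from P.
    """
    # Pull the context pair together: large dot product -> small loss.
    positive = -math.log(sigmoid(dot(emb_u, emb_v)))
    # Push negatives away: large dot product with a negative -> large loss.
    negative = -sum(math.log(sigmoid(-dot(emb_u, emb_x))) for emb_x in negative_embs)
    return positive + negative

# u close to v and far from x gives a lower loss than the reverse.
good = skipgram_loss([1.0, 0.0], [1.0, 0.0], [[-1.0, 0.0]])
bad = skipgram_loss([1.0, 0.0], [-1.0, 0.0], [[1.0, 0.0]])
```

Minimizing this quantity over many context pairs is what shapes the embedding space without requiring any labels.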
As mentioned above, GraphSAGE is able to infer embeddings for nodes that were not seen during training. To do this, we can simply pass the nodes through the same forward pass, using the weights that were learned during training.
In this post, we have discussed how GraphWise, a graph neural network algorithm, works, along with its use cases, including high-quality recommender systems and fraud detection. We also walked through how a graph neural network can be used to identify fraudulent actors in a graph of bank transactions.
If you'd like to learn more about GraphWise and graph convolutional neural networks, see the References section for papers on this topic.
To learn more about Oracle AI and machine learning, visit the Oracle AI page.
[1] Inductive Representation Learning on Large Graphs. William L. Hamilton et al., 2017.
[2] Graph Convolutional Neural Networks for Web-Scale Recommender Systems. Rex Ying et al., 2018.
[3] We deliberately omit the sampling size for simplicity.
[4] DeepWalk: Online Learning of Social Representations. Bryan Perozzi et al., 2014.