NodeFormer: Scalable Graph Transformers for Million Nodes

All-Pair Message Passing with O(N)

Image: Unsplash.

Recently, building Transformer models for graph-structured data has attracted wide interest in the machine learning research community. One critical challenge stems from the quadratic complexity of global attention, which hinders Transformers from scaling to large graphs. This blog briefly introduces a recent work from NeurIPS 2022:

NodeFormer: A Scalable Graph Structure Learning Transformer for Node Classification with its public implementation available.

This work proposes a scalable graph Transformer for node classification on large graphs, where the number of nodes can range from thousands to millions (or even more). The key module is a kernelized Gumbel-Softmax-based message passing scheme that achieves all-pair feature propagation with O(N) complexity (N is the number of nodes).

The following content summarizes the main idea and results of this work.

Graph Transformers vs. Graph Neural Networks

Unlike graph neural networks (GNNs) that resort to message passing over a fixed input graph topology, graph Transformers can more flexibly aggregate global information from all the nodes through adaptive topology in each propagation layer. Specifically, graph Transformers have several advantages:

  • Handling Imperfect Structures. For graph data with heterophily, long-range dependence and spurious edges, GNNs often show insufficient power due to their local feature aggregation designs. In contrast, Transformers adopt global attention that aggregates information from all the nodes in one layer, which can overcome the limitations of input structures.
  • Avoiding Over-Squashing. GNNs may lose information exponentially when aggregating it into a fixed-length vector, while graph Transformers leverage global attentive aggregation that can adaptively attend to the dominant nodes that are informative for the target node’s predictive task.
  • Flexibility for No-Graph Tasks. Beyond graph problems, there also exists a wide variety of tasks where there is no graph structure. Examples include image and text classification (each image can be seen as a node, but there is no graph connecting them) and physics simulation (each particle is a node, but there is no explicitly observed graph). While a common practice is to use k-NN over input features to construct a similarity graph for message passing, such an artificially created graph is often independent of the downstream predictive task and can lead to sub-optimal performance. Transformers resolve this issue by automatically learning adaptive graph structures for message passing.
For node classification, Transformers can aggregate information from all other nodes in one layer. The layer-wise updating rule given by Transformers can be seen as a composition of one-step node embedding updating and graph structure estimation (we can treat the attention matrix as a graph adjacency matrix)
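To make the "attention matrix as adjacency matrix" view concrete, here is a minimal NumPy sketch of one dense all-pair attention layer. The function name, the weight matrices `W_q`, `W_k`, `W_v` and the 1/sqrt(d) scaling are assumptions taken from the standard Transformer recipe rather than from the NodeFormer code.

```python
import numpy as np

def softmax_attention_layer(Z, W_q, W_k, W_v):
    """One dense all-pair attention layer over node embeddings Z of shape (N, d)."""
    Q, K, V = Z @ W_q, Z @ W_k, Z @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[1])        # (N, N) pairwise scores
    scores -= scores.max(axis=1, keepdims=True)   # subtract row max for stability
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)             # each row sums to 1
    return A @ V, A                               # updated embeddings, learned "graph"

rng = np.random.default_rng(0)
N, d = 6, 4
Z = rng.normal(size=(N, d))
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
Z_next, A = softmax_attention_layer(Z, W_q, W_k, W_v)
```

The matrix `A` plays the role of a dense, weighted adjacency matrix that is re-estimated in every layer, and `A @ V` is one step of message passing over it. Note that `A` has N × N entries, which is the bottleneck discussed next.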

Challenges of Building Transformers on Large Graphs

Several challenges make building Transformers on large graphs a non-trivial problem, in particular for graphs with more than thousands of nodes.

  • Quadratic Complexity for Global Attention: The attention computation for all-pair feature aggregation requires O(N²) complexity which is prohibitive for large graphs where N can be arbitrarily large, e.g., from thousands to millions. Concretely speaking, a common GPU with 16GB memory would fail to run such global attention over all nodes if N is more than 10K.
  • Accommodation of Graph Sparsity: Real-world graphs are often sparse in comparison with the attentive graph (we can treat the attention matrix as a weighted graph adjacency matrix), which densely connects all node pairs. The problem is that when N grows large, feature propagation over such a dense graph may cause what we call the over-normalizing issue, meaning that the information from different nodes is diluted by others. A plausible remedy is to sparsify the learnable structures before the propagation.
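A quick back-of-the-envelope calculation illustrates the quadratic memory cost. The sketch below counts only the storage of a single float32 N × N attention matrix; the helper name and the 4-bytes-per-entry choice are assumptions for illustration.

```python
def attention_matrix_gb(num_nodes, bytes_per_entry=4):
    """Memory in GB (1e9 bytes) of one dense num_nodes x num_nodes matrix."""
    return num_nodes ** 2 * bytes_per_entry / 1e9

for n in (10_000, 100_000, 1_000_000):
    print(f"N = {n:>9,}: {attention_matrix_gb(n):,.1f} GB")
```

A single matrix already takes 0.4 GB at N = 10K, 40 GB at N = 100K and 4,000 GB at N = 1M. In training, one such matrix is typically kept per layer and per attention head, together with intermediate results for back-propagation, so the real footprint is a multiple of these numbers.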

Kernelized Gumbel-Softmax-based Message Passing

Our work NodeFormer combines the random feature map and Gumbel-Softmax into a unified model for addressing the above-mentioned problems. Specifically, the Gumbel-Softmax is first used to replace the original Softmax-based attentive feature aggregation:

The update computes the next-layer representation of node u from all the node representations at the current layer. The Gumbel-Softmax can be seen as a continuous relaxation of sampling one neighboring node for the target node u from all the nodes. In practice, one can sample K times, which gives rise to a set of sampled neighboring nodes. Here q, k and v are transformed features of the node representations.
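As a rough illustration of this update, the sketch below adds Gumbel(0, 1) noise to the attention logits and applies a temperature-scaled Softmax. It assumes one noise value per source node, shared by all target nodes, and a single sample (K = 1); the function name and the default temperature `tau` are placeholders, not values from the paper.

```python
import numpy as np

def gumbel_softmax_attention(Q, K, V, tau=0.25, rng=None):
    """Dense Gumbel-Softmax attention; costs O(N^2) because of the (N, N) logits."""
    rng = np.random.default_rng() if rng is None else rng
    g = rng.gumbel(size=K.shape[0])               # one Gumbel(0, 1) draw per source node
    logits = (Q @ K.T + g[None, :]) / tau         # perturbed, temperature-scaled scores
    logits -= logits.max(axis=1, keepdims=True)   # subtract row max for stability
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)             # relaxed one-hot neighbor choice per row
    return P @ V, P

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(5, 4)) for _ in range(3))
# Same noise seed, two temperatures: a lower tau gives rows closer to one-hot.
Z_soft, P_soft = gumbel_softmax_attention(Q, K, V, tau=1.0, rng=np.random.default_rng(1))
Z_sharp, P_sharp = gumbel_softmax_attention(Q, K, V, tau=0.1, rng=np.random.default_rng(1))
```

As the temperature goes to zero, each row of `P` approaches a one-hot vector, i.e., a hard sample of a single neighbor, which is how the relaxation encourages sparse learned structures.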

The above equation defines the computation for node u, which requires O(N). Computing the representations for all N nodes therefore requires O(N²) complexity, since one has to compute the all-pair attention scores independently. To resolve this difficulty, we borrow the main idea of Performer and adopt the random feature map (RFM) to approximate the Gumbel-Softmax. (The original adoption of RFM in Performer aims to approximate the deterministic Softmax attention; here we extend the technique to Gumbel-Softmax with stochastic noise.)

The new updating rule using the proposed kernelized Gumbel-Softmax. The derivation from LHS to RHS follows from the associativity of matrix multiplication.

Critically, in the new computation, i.e., the RHS of the above equation, the two summation terms (over N nodes) are shared by all the nodes and can be computed only once in one layer. Therefore, this gives rise to O(N) complexity for updating N nodes in one layer.
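The reordering can be checked numerically. The sketch below assumes Performer-style positive random features and folds the Gumbel noise into the key features; the function names, the number of random features `m` and the projection matrix `W` are illustrative assumptions, not the public implementation. It computes the same quantity twice: once through the two shared sums (linear in N) and once through an explicit N × N weight matrix.

```python
import numpy as np

def positive_random_features(X, W):
    """Positive random features with E[phi(x) . phi(y)] = exp(x . y) for Gaussian W."""
    sq_norm = 0.5 * np.sum(X ** 2, axis=1, keepdims=True)
    return np.exp(X @ W.T - sq_norm) / np.sqrt(W.shape[0])

def feature_maps(Q, K, W, g, tau):
    phi_q = positive_random_features(Q / np.sqrt(tau), W)     # (N, m)
    phi_k = positive_random_features(K / np.sqrt(tau), W)     # (N, m)
    return phi_q, phi_k * np.exp(g / tau)[:, None]            # noise folded into keys

def kernelized_gumbel_attention(Q, K, V, W, g, tau=1.0):
    """Linear-cost form: two sums over nodes, computed once and shared by all queries."""
    phi_q, phi_k = feature_maps(Q, K, W, g, tau)
    kv = phi_k.T @ V               # (m, d): sum over nodes of phi(k_v) v_v^T
    k_sum = phi_k.sum(axis=0)      # (m,):   sum over nodes of phi(k_v)
    return (phi_q @ kv) / (phi_q @ k_sum)[:, None]

def quadratic_reference(Q, K, V, W, g, tau=1.0):
    """Same quantity computed through an explicit (N, N) weight matrix."""
    phi_q, phi_k = feature_maps(Q, K, W, g, tau)
    weights = phi_q @ phi_k.T
    return (weights @ V) / weights.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
N, d, m = 200, 8, 32
Q, K, V = (rng.normal(scale=0.5, size=(N, d)) for _ in range(3))
W = rng.normal(size=(m, d))        # random projections for the feature map
g = rng.gumbel(size=N)             # Gumbel noise, one value per source node
fast = kernelized_gumbel_attention(Q, K, V, W, g, tau=0.5)
slow = quadratic_reference(Q, K, V, W, g, tau=0.5)
print(np.allclose(fast, slow))
```

Both functions return the same values up to floating-point error, but the first never materializes an N × N matrix: its cost grows as O(N · m · d), i.e., linearly in the number of nodes for fixed `m` and `d`.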

How to Leverage Input Graphs

Another important question is how to make use of input structures (if available) since the above all-pair message passing ignores the input graph. We additionally propose two simple strategies:

  • Adding Relational Bias: we additionally assume a learnable scalar term that is added to the attention score between node u and v if there is an edge between them in the input graph.
  • Edge Regularization Loss: use the attention score for edge (u, v) as an estimated probability and define a maximum likelihood estimation (MLE) loss for all the observed edges. Intuitively, this design maximizes the attention weight for observed edges.
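The two strategies can be sketched on a tiny graph. For readability the example uses dense Softmax attention rather than the kernelized form, follows the description above literally (a scalar added to the score of each observed edge), and averages the negative log-likelihood uniformly over edges. These choices, as well as the function names and the fixed bias value, are assumptions and need not match the public implementation.

```python
import numpy as np

def attention_with_relational_bias(Q, K, adj, bias):
    """Softmax attention where observed edges get an extra scalar added to their score."""
    scores = Q @ K.T + bias * adj
    scores -= scores.max(axis=1, keepdims=True)   # subtract row max for stability
    P = np.exp(scores)
    return P / P.sum(axis=1, keepdims=True)

def edge_regularization_loss(P, edges):
    """Negative log-likelihood of observed edges under the attention probabilities."""
    return -np.mean(np.log(P[edges[:, 0], edges[:, 1]]))

rng = np.random.default_rng(0)
N, d = 5, 3
Q, K = rng.normal(size=(N, d)), rng.normal(size=(N, d))
edges = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])   # directed toy edges (u, v)
adj = np.zeros((N, N))
adj[edges[:, 0], edges[:, 1]] = 1.0

loss_no_bias = edge_regularization_loss(attention_with_relational_bias(Q, K, adj, 0.0), edges)
loss_with_bias = edge_regularization_loss(attention_with_relational_bias(Q, K, adj, 2.0), edges)
print(loss_no_bias, loss_with_bias)
```

A positive bias raises the attention on observed edges and therefore lowers the edge loss. In training, the bias would be a learnable parameter and the edge loss would be added to the classification loss with a tunable weight.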

However, the importance (or informativeness) of the input graph varies across datasets. So in practice, one needs to tune the weight (as a hyper-parameter) that determines how much emphasis is placed on the input graph. The following figure shows the overall data flow of NodeFormer.

Data flow of NodeFormer, whose inputs contain node features and graph adjacency, similar to common GNNs. The red part is the all-pair message passing by kernelized Gumbel-Softmax, the green part is the relational bias and the blue part is the edge regularization loss. The latter two components can be omitted if the input graph is unimportant or unavailable.

Experiment Results

We apply NodeFormer to node classification tasks and achieve very competitive results on eight datasets compared to common GNNs and state-of-the-art graph structure learning models LDS and IDGL.

Comparative experiment results for NodeFormer and common GNN models

Beyond node classification, we also consider image and text classification tasks where input graphs are missing. We use k-NN with different values of k (5, 10, 15, 20) to construct a graph, and also consider running NodeFormer without any input graph. Intriguingly, the latter case does not lead to an obvious performance drop and sometimes even yields better performance.
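For reference, a k-NN graph over instance features can be built in a few lines. The sketch below assumes Euclidean distance and a brute-force distance matrix, which is only practical for small datasets; the function name and the random features are placeholders, and the distance metric used in the experiments is not stated here.

```python
import numpy as np

def knn_graph(X, k):
    """Return directed edges (i, j) where j is one of the k nearest neighbors of i."""
    sq = np.sum(X ** 2, axis=1)
    dist = sq[:, None] + sq[None, :] - 2.0 * X @ X.T    # squared Euclidean distances
    np.fill_diagonal(dist, np.inf)                      # exclude self-loops
    nbrs = np.argpartition(dist, k - 1, axis=1)[:, :k]  # k smallest per row, unordered
    src = np.repeat(np.arange(X.shape[0]), k)
    return np.stack([src, nbrs.ravel()], axis=1)        # shape (N * k, 2)

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 16))                           # stand-in for image or text features
graphs = {k: knn_graph(X, k) for k in (5, 10, 15, 20)}
```

Each resulting edge list can then serve as the input graph for the relational bias and edge regularization, or be dropped entirely to run with all-pair attention only.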

Visualization of node embeddings and estimated attention scores (the ones with low weights are filtered out). We mark the nodes of the same label class with one color. The global attention tends to connect nodes within the same class and also increases the global connectivity of the graph.

As a new model class, we highlight some advantages of NodeFormer:

  • Capacity: NodeFormer adaptively learns graph structure through sparse attention in each layer and can potentially aggregate information from all nodes.
  • Scalability: NodeFormer enables O(N) complexity and mini-batch partition training. In practice, it successfully scales to large graphs with millions of nodes using only 4GB memory.
  • Efficiency: The training of NodeFormer can be done in an efficient end-to-end manner with gradient-based optimization. For example, training and evaluation on Cora for 1000 epochs takes only 1–2 minutes.
  • Flexibility: NodeFormer is flexible for inductive learning and handling no-graph cases.
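Mini-batch partition training can be sketched as follows, assuming the partition is a random split of node indices and that each batch keeps only the input edges whose endpoints both fall inside it. The helper names and the toy ring graph are illustrative assumptions.

```python
import numpy as np

def random_node_partition(num_nodes, batch_size, rng):
    """Shuffle node indices and cut them into consecutive mini-batches."""
    perm = rng.permutation(num_nodes)
    return [perm[i:i + batch_size] for i in range(0, num_nodes, batch_size)]

def induced_edges(edges, batch, num_nodes):
    """Keep edges with both endpoints in the batch, relabeled to 0..len(batch)-1."""
    local = np.full(num_nodes, -1)
    local[batch] = np.arange(len(batch))
    mapped = local[edges]                      # (E, 2), -1 marks nodes outside the batch
    return mapped[(mapped >= 0).all(axis=1)]

rng = np.random.default_rng(0)
num_nodes, batch_size = 10, 4
edges = np.stack([np.arange(num_nodes), (np.arange(num_nodes) + 1) % num_nodes], axis=1)
batches = random_node_partition(num_nodes, batch_size, rng)
sub_edges = [induced_edges(edges, b, num_nodes) for b in batches]
```

All-pair attention is then applied among the nodes of each batch, so memory depends on the batch size rather than on the total number of nodes.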

We also briefly discuss the potential applications of NodeFormer. In a general sense, NodeFormer can be used as a general-purpose encoder for graph-structured data or for handling inter-dependence among instances in standard predictive tasks. Specifically, it can be readily applied to two classes of problems:

  • Predictive tasks on (large) graphs, where the goal is to predict each node’s (or edge’s) label based on the node features and graph structures, e.g., predicting user activities in large social networks.
  • Standard predictive tasks (classification or regression) without input graphs, where NodeFormer can be used to exploit the potential inter-dependence among instances in a dataset, e.g., image classification.

This blog introduces a new graph Transformer that successfully scales to large graphs and shows promising performance compared with common GNNs. Transformer-style models possess some inherent advantages that can overcome the limitations of GNNs in handling long-range dependence, heterophily, over-squashing, graph noise and the absence of graphs altogether. Despite this, building powerful graph Transformers that can serve as the next-generation graph representation model is still an open and under-explored problem, and we hope this work can shed some light on this direction.

If interested in this work, one could read the paper for more details.

paper: https://openreview.net/pdf?id=sMezXGG5So

code: https://github.com/qitianwu/NodeFormer

All images unless otherwise noted are by the author.

