
Key idea

Every node updates itself by listening to its neighbours. Picture a social network of nodes passing notes: each round, every node reads what its neighbours just said, blends it with its own state, and updates. Do this K times and information from nodes up to K hops away has rippled in.

Figure (interactive): pick a source node, then step the propagation round by round and watch how each node's representation evolves as messages spread. After K rounds every node carries information from its K-hop neighbourhood.

Why graph structure can't just be flattened

A graph has no canonical ordering — relabel the nodes and it's the same graph. Flatten it into a vector and you've baked in a fake order; train an MLP on that and the model will treat node 3 differently from node 7 even when they play identical structural roles. On top of that, nodes have variable numbers of neighbours — atom 1 might bond to two others, atom 2 to four — so you can't even use a fixed-size input vector. GNNs solve both: the aggregation step is permutation-invariant (sum / mean / max don't care about order), and it works on any neighbourhood size.
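A quick NumPy check of the permutation argument, using made-up toy features: relabel a node's neighbours and the sum, mean, and max of their feature vectors stay the same, while the flattened (concatenated) vector changes. The same three aggregates also work for any number of neighbours, which a fixed-size flattened input cannot do.

```python
import numpy as np

rng = np.random.default_rng(0)
neighbours = rng.normal(size=(4, 3))   # 4 neighbours, 3 features each (toy data)
perm = np.array([2, 0, 3, 1])          # an arbitrary relabelling of the neighbours
shuffled = neighbours[perm]

def aggregates(h):
    """Order-free summaries of a set of neighbour vectors (any number of rows)."""
    return {"sum": h.sum(axis=0), "mean": h.mean(axis=0), "max": h.max(axis=0)}

for name, value in aggregates(neighbours).items():
    # Same result no matter how the neighbours are ordered.
    print(name, np.allclose(value, aggregates(shuffled)[name]))

# Flattening bakes the order in: the concatenated vectors differ.
print("flatten", np.allclose(neighbours.ravel(), shuffled.ravel()))
```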

How message passing works intuitively

Picture each node holding a small vector — its current "state". One round of message passing is:

  1. Every node sends its state along each outgoing edge.
  2. Every node collects everything it received and combines it (sum, mean, max, or learned attention).
  3. Every node mixes that aggregate with its old state to produce a new state.
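The three steps map onto a few lines of NumPy. This is a minimal sketch assuming a dense 0/1 adjacency matrix, mean aggregation, and a linear-plus-ReLU update; the weight names `W_self` and `W_neigh` are illustrative, not taken from any particular library.

```python
import numpy as np

def message_passing_round(A, H, W_self, W_neigh):
    """One round: send states along edges, mean-aggregate, mix with own state.

    A: (n, n) 0/1 adjacency matrix, A[v, u] = 1 if u sends to v.
    H: (n, d) current node states.
    """
    # Steps 1-2: each node sums the states it receives, then divides by its in-degree.
    deg = A.sum(axis=1, keepdims=True)
    agg = (A @ H) / np.maximum(deg, 1)   # isolated nodes get a zero aggregate
    # Step 3: mix the aggregate with the old state (linear maps + ReLU here).
    return np.maximum(0.0, H @ W_self + agg @ W_neigh)

# Path graph 0-1-2 with one-hot states and identity weights.
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]])
H1 = message_passing_round(A, np.eye(3), np.eye(3), np.eye(3))
print(H1)
```

With identity weights, each row of `H1` is the node's own one-hot state plus the mean of its neighbours' states, so node 1 ends up holding half of node 0 and half of node 2.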

Repeat K times. After one round, a node knows about its immediate neighbours. After two, its neighbours' neighbours. After K, its entire K-hop neighbourhood. Click a node at the periphery of the figure and count how many rounds the signal needs to cross the graph.
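The "K rounds reach K hops" claim can be checked by tracking only whether each node has heard from the source yet. This sketch assumes an undirected graph stored as a symmetric 0/1 adjacency matrix; a real GNN propagates vectors rather than booleans, but the reach pattern is the same.

```python
import numpy as np

def reached_after(A, source, K):
    """Which nodes have received the source's signal after K rounds?"""
    n = A.shape[0]
    reached = np.zeros(n, dtype=bool)
    reached[source] = True
    for _ in range(K):
        # A node is reached if it already was, or if any neighbour was.
        reached = reached | (A @ reached.astype(int) > 0)
    return reached

# Path graph 0-1-2-3-4: the signal moves exactly one hop per round.
A = np.zeros((5, 5), dtype=int)
for u, v in [(0, 1), (1, 2), (2, 3), (3, 4)]:
    A[u, v] = A[v, u] = 1
for K in range(5):
    print(K, np.flatnonzero(reached_after(A, 0, K)))
```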

Why depth matters — but also hurts

More layers means longer-range information, which sounds good. The catch is over-smoothing: keep averaging neighbours into yourself and after enough rounds, every node converges to the same blurry mean of the whole graph. You've lost the very distinctions you were trying to learn. That's why most GNNs use only 2–4 layers, with skip connections or jumping-knowledge tricks if you need more.
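A small experiment makes over-smoothing concrete: apply parameter-free mean aggregation (with self-loops, i.e. "averaging neighbours into yourself") over and over, and measure how far apart the node vectors remain. This sketch leaves out the learned weights and nonlinearities a real GNN would add. On the regular 6-cycle used here every node converges to the plain graph-wide mean; on irregular graphs this row-normalised averaging converges to a degree-weighted mean instead.

```python
import numpy as np

def smooth(A, H, rounds):
    """Repeatedly replace each node's features by the mean over itself and its neighbours."""
    A_hat = A + np.eye(A.shape[0])                 # self-loops
    P = A_hat / A_hat.sum(axis=1, keepdims=True)   # row-normalised mean operator
    for _ in range(rounds):
        H = P @ H
    return H

def spread(H):
    """How different are the node vectors? (largest per-feature range across nodes)"""
    return float((H.max(axis=0) - H.min(axis=0)).max())

# 6-node cycle with random features.
n = 6
A = np.zeros((n, n))
for i in range(n):
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1
H0 = np.random.default_rng(0).normal(size=(n, 4))
for r in [0, 2, 8, 32]:
    print(r, spread(smooth(A, H0, r)))   # shrinks toward 0 as rounds grow
```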

Node, edge, or graph-level tasks

The same backbone serves three task families:

  • Node-level — one prediction per node (which user will churn? what role does this protein play?).
  • Edge-level — predict whether an edge should exist or what kind it is (link prediction, knowledge-graph completion).
  • Graph-level — one prediction for the whole graph, formed by pooling all node embeddings into a single vector (will this molecule bind to the target?).
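Here is how one set of node embeddings can feed all three task families, as a NumPy sketch with random stand-in embeddings. The dot-product edge scorer and the mean-pool readout are common simple choices assumed for illustration, not the only options.

```python
import numpy as np

rng = np.random.default_rng(0)
Z = rng.normal(size=(5, 8))        # node embeddings from any GNN backbone (toy values)
batch = np.array([0, 0, 0, 1, 1])  # nodes 0-2 belong to graph 0, nodes 3-4 to graph 1

# Node-level: a linear classifier applied to every node (3 classes here).
W_node = rng.normal(size=(8, 3))
node_logits = Z @ W_node           # shape (5, 3)

# Edge-level: score a candidate edge (u, v) from its two endpoint embeddings.
def edge_score(u, v):
    return 1.0 / (1.0 + np.exp(-Z[u] @ Z[v]))   # sigmoid gives an "edge exists" probability

# Graph-level: mean-pool each graph's nodes into one vector, then predict.
num_graphs = batch.max() + 1
pooled = np.stack([Z[batch == g].mean(axis=0) for g in range(num_graphs)])  # (2, 8)
w_graph = rng.normal(size=8)
graph_pred = pooled @ w_graph      # one number per graph
```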

Reach for it when

  • The edges carry meaning — bonds in a molecule, friendships, citations, road segments
  • Molecular property prediction or drug discovery
  • Recommendation on bipartite user-item graphs
  • Knowledge-graph completion or reasoning over relations

Skip it when

  • Your data isn't naturally a graph (don't force it)
  • The graph is huge but the structure is incidental — try a baseline that ignores it first
  • You only have node features and no informative edges — an MLP will do
  • Long-range dependencies dominate — a graph transformer may beat a plain GNN

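import torch.nn as nn
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    """Two-layer GCN producing one output vector per node."""
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)   # round 1: aggregate 1-hop neighbours
        self.conv2 = GCNConv(hidden, out_dim)  # round 2: information now spans 2 hops

    def forward(self, x, edge_index):
        # x: [num_nodes, in_dim] node features; edge_index: [2, num_edges] source/target pairs
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)       # per-node logits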

Message passing

$$ \mathbf{h}_v^{(\ell+1)} \;=\; \phi\!\left(\mathbf{h}_v^{(\ell)},\; \bigoplus_{u \in \mathcal{N}(v)} \psi(\mathbf{h}_u^{(\ell)}, \mathbf{h}_v^{(\ell)}, \mathbf{e}_{uv})\right) $$

  • $\mathbf{h}_v^{(\ell)}$: feature vector of node v at layer ℓ
  • $\mathcal{N}(v)$: neighbours of node v
  • $\psi$, $\phi$: learned message and update functions (usually small MLPs)
  • $\bigoplus$: permutation-invariant aggregation (sum, mean, max, or attention)
  • $\mathbf{e}_{uv}$: (optional) feature vector on the edge from u to v

$$ \text{new features of } v \;=\; \text{update}\!\left(\text{old features of } v,\; \text{aggregate}\big[\,\text{message}(u, v, \text{edge})\,\big]_{u \in \text{neighbours}(v)}\right) $$

In words. Every node carries a feature vector. To update node v, do three things in order. First, for each neighbour u, compute a message: a small learned function of u's features, v's features, and the edge between them. Second, aggregate those messages with a permutation-invariant combiner (usually sum, mean, or max) so the result doesn't depend on which order the neighbours showed up in. Third, combine the aggregate with v's previous features using another learned function (the update). Stack a few layers and information ripples outward. The ⊕ symbol is just a placeholder for whichever order-free aggregator you chose.
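The equation translates almost symbol for symbol into code. This NumPy sketch assumes edges are given as a `(2, m)` source/target array (the same layout PyTorch Geometric uses for `edge_index`) and takes ψ and φ as ordinary Python functions; a real layer would make them small MLPs and vectorise the loops. `mp_layer` is a hypothetical name.

```python
import numpy as np

def mp_layer(h, edge_index, edge_attr, psi, phi, aggregate="sum"):
    """Generic message passing: h_v' = phi(h_v, AGG_{u in N(v)} psi(h_u, h_v, e_uv)).

    h:          (n, d) node features
    edge_index: (2, m) array; column k is an edge from src[k] to dst[k]
    edge_attr:  (m, d_e) edge features
    """
    src, dst = edge_index
    # psi: one message per edge, computed from sender, receiver, and edge features.
    msgs = np.stack([psi(h[u], h[v], e) for u, v, e in zip(src, dst, edge_attr)])
    n, d_msg = h.shape[0], msgs.shape[1]
    out = np.zeros((n, d_msg))
    if aggregate in ("sum", "mean"):
        np.add.at(out, dst, msgs)                  # order-free: add messages per target
        if aggregate == "mean":
            counts = np.bincount(dst, minlength=n)[:, None]
            out = out / np.maximum(counts, 1)
    elif aggregate == "max":
        out[:] = -np.inf
        np.maximum.at(out, dst, msgs)              # order-free: elementwise max per target
        out[np.isinf(out)] = 0.0                   # nodes with no incoming edges
    # phi: combine each node's old features with its aggregate.
    return np.stack([phi(h[v], out[v]) for v in range(n)])

# Toy example: edges 0->1, 2->1, 1->0 with scalar edge weights.
h = np.eye(3)
edge_index = np.array([[0, 2, 1], [1, 1, 0]])
edge_attr = np.array([[2.0], [3.0], [1.0]])
psi = lambda h_u, h_v, e: e[0] * h_u               # message = edge weight times sender
phi = lambda h_v, m: h_v + m                       # update = old state plus aggregate
print(mp_layer(h, edge_index, edge_attr, psi, phi, "sum"))
```

Shuffling the columns of `edge_index` (and the rows of `edge_attr` with them) leaves the output unchanged, which is exactly the permutation invariance the ⊕ is there to guarantee.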

The aggregate-then-update recipe. Most GNN papers boil down to picking a message function, an aggregator, and an update function. Everything is differentiable end-to-end, so you train the whole stack with the loss of your downstream task (node classification, link prediction, graph regression).

GCN vs. GraphSAGE vs. GAT. Three classic choices, each making a different trade-off:

  • GCN (Kipf & Welling, 2017) — a single linear transform plus a degree-normalized sum over neighbours. Cheapest baseline; works surprisingly often.
  • GraphSAGE (Hamilton et al., 2017) — samples a fixed number of neighbours per node so you can train on huge graphs without loading the whole adjacency. Aggregator is configurable (mean, LSTM, pool).
  • GAT (Veličković et al., 2018) — replaces the fixed neighbour weights with attention computed on each edge, so the model learns which neighbours matter for which node. Multi-head attention stabilises training.
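To make the GCN and GAT neighbour weightings concrete, here is a dense NumPy sketch of both for small graphs. The GCN function uses the symmetric normalisation with self-loops from Kipf & Welling; the attention function follows the original single-head GAT scoring (a LeakyReLU over a learned vector `a` applied to concatenated transformed features). Multi-head attention, dropout, and sparse storage are omitted.

```python
import numpy as np

def gcn_layer(A, H, W):
    """GCN propagation: D^{-1/2} (A + I) D^{-1/2} H W."""
    A_hat = A + np.eye(A.shape[0])                 # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    A_norm = d_inv_sqrt[:, None] * A_hat * d_inv_sqrt[None, :]
    return A_norm @ H @ W

def gat_attention(A, H, W, a, slope=0.2):
    """Single-head GAT-style weights alpha[v, u] over u in N(v) plus v itself."""
    Wh = H @ W
    d = Wh.shape[1]
    # e[v, u] = LeakyReLU(a^T [W h_v || W h_u]), computed for all pairs at once.
    e = (Wh @ a[:d])[:, None] + (Wh @ a[d:])[None, :]
    e = np.where(e > 0, e, slope * e)
    mask = (A + np.eye(A.shape[0])) > 0
    e = np.where(mask, e, -np.inf)                 # only real edges (and self) compete
    e = e - e.max(axis=1, keepdims=True)           # numerically stable softmax
    alpha = np.exp(e)
    alpha /= alpha.sum(axis=1, keepdims=True)
    return alpha                                   # the layer output would be alpha @ Wh
```

The contrast is the point: in `gcn_layer` the weight on each neighbour depends only on node degrees, while in `gat_attention` it depends on the features, so it can differ from one node to the next.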

Node, edge, and graph features. Edges can carry their own learned features (bond type, friendship strength), which the message function takes as a third argument. Graph-level tasks need a readout / pooling step that collapses all node embeddings into one vector — mean and sum pooling are the workhorses; attention pooling and hierarchical pooling (DiffPool) help on harder tasks.

Batching graphs. Unlike images, graphs in a batch have different node counts and edge counts. Frameworks like PyTorch Geometric handle this by concatenating all graphs into one big disconnected supergraph and keeping a batch vector that records which node belongs to which graph. Pooling layers read that vector to keep predictions per-graph.
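The supergraph trick fits in a few lines. This is a NumPy sketch of the idea, not PyTorch Geometric's actual batching code; `collate` and `mean_pool` are hypothetical helper names, with `mean_pool` playing the role that a mean readout such as PyG's `global_mean_pool` plays.

```python
import numpy as np

def collate(graphs):
    """Merge graphs into one disconnected supergraph.

    Each graph is a tuple (x, edge_index) with x: (n_i, d), edge_index: (2, m_i).
    Returns stacked features, offset edges, and a batch vector (graph id per node).
    """
    xs, edges, batch, offset = [], [], [], 0
    for g, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        edges.append(edge_index + offset)      # shift node ids so graphs don't collide
        batch.append(np.full(x.shape[0], g))   # remember which graph each node came from
        offset += x.shape[0]
    return np.concatenate(xs), np.concatenate(edges, axis=1), np.concatenate(batch)

def mean_pool(x, batch):
    """Average node vectors per graph, using the batch vector."""
    num_graphs = batch.max() + 1
    sums = np.zeros((num_graphs, x.shape[1]))
    np.add.at(sums, batch, x)
    counts = np.bincount(batch, minlength=num_graphs)[:, None]
    return sums / counts

g1 = (np.ones((3, 2)), np.array([[0, 1], [1, 2]]))   # 3 nodes, edges 0->1, 1->2
g2 = (np.full((2, 2), 3.0), np.array([[0], [1]]))    # 2 nodes, edge 0->1
x, edge_index, batch = collate([g1, g2])
print(edge_index, batch, mean_pool(x, batch))
```

Because no edge crosses between the original graphs, message passing on the supergraph never mixes them; only the batch vector is needed to split the result back out.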

Reach for it when

  • Molecule or protein property prediction with explicit chemical structure
  • Recommendation systems with user-item bipartite graphs
  • Citation / co-authorship / social-network analysis
  • Traffic and road-network forecasting where topology matters

Skip it when

  • A graph-blind baseline already wins — never deploy a GNN you didn't sanity-check against an MLP
  • Very long-range dependencies — vanilla message passing over-smooths; reach for a graph transformer
  • The graph changes faster than you can train — use streaming or incremental methods
  • You need calibrated uncertainty — plain GNNs are over-confident, reach for a Bayesian or ensemble variant

import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, global_mean_pool

class MolPropertyGNN(torch.nn.Module):
    def __init__(self, in_dim, hidden, out_dim, heads=4):
        super().__init__()
        self.conv1 = GATv2Conv(in_dim, hidden, heads=heads)
        self.conv2 = GATv2Conv(hidden * heads, hidden, heads=1)
        self.head  = torch.nn.Linear(hidden, out_dim)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        # Pool per graph (batch holds graph index for each node)
        x = global_mean_pool(x, batch)
        return self.head(x)

WL-1 expressiveness ceiling

$$ \text{standard message-passing GNNs} \;\preceq\; \text{1-WL test}\;<\; \text{graph isomorphism} $$

  • Any standard message-passing GNN can only distinguish graphs that the 1-dimensional Weisfeiler-Leman colour-refinement test can distinguish
  • Some non-isomorphic graphs share WL-1 colourings, so no standard GNN can tell them apart
  • GIN (Xu et al., 2019) — sum aggregation + an MLP — achieves the WL-1 bound; weaker aggregators (mean, max) fall below it

$$ \text{power of standard GNNs} \;\le\; \text{power of the 1-WL test} \;<\; \text{telling all non-identical graphs apart} $$

In words. There's a classical algorithm called the 1-WL test (1-dimensional Weisfeiler-Leman) that iteratively recolours each node by hashing its own colour together with the multiset of its neighbours' colours. It's a fast heuristic for telling whether two graphs are the same up to relabelling ("isomorphic"). Xu et al. (2019) proved that any standard message-passing GNN is at most as discriminative as this test, and the test itself fails on some genuinely different graph pairs (e.g. two regular graphs with the same degree sequence). So there are graphs no plain GNN can ever tell apart, no matter how deep or wide. The ⪯ symbol in the first formula means "no more powerful than".
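The refinement just described is short enough to write out in plain Python. This sketch runs it on two graphs at once with a shared colour palette, so their colour histograms can be compared; graphs are given as adjacency dicts, an assumption made for brevity. A 6-cycle and two disjoint triangles are both 2-regular on six nodes, so 1-WL (and therefore any standard message-passing GNN) cannot separate them, while a 6-node path is separated in the first round. Matching histograms only mean the test failed to separate the graphs; they do not prove isomorphism.

```python
from collections import Counter

def wl_histograms(graphs, rounds=3):
    """1-WL colour refinement run jointly on several graphs.

    graphs: list of dicts mapping node -> list of neighbours.
    Returns one Counter of final colours per graph.
    """
    colours = [{v: 0 for v in g} for g in graphs]   # every node starts with the same colour
    for _ in range(rounds):
        palette = {}                                 # shared across graphs so colours are comparable
        refined = []
        for g, col in zip(graphs, colours):
            new = {}
            for v, nbrs in g.items():
                # New colour = id of (own colour, sorted multiset of neighbour colours).
                signature = (col[v], tuple(sorted(col[u] for u in nbrs)))
                new[v] = palette.setdefault(signature, len(palette))
            refined.append(new)
        colours = refined
    return [Counter(col.values()) for col in colours]

cycle6 = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1], 3: [4, 5], 4: [3, 5], 5: [3, 4]}
path6 = {i: [j for j in (i - 1, i + 1) if 0 <= j < 6] for i in range(6)}

h_cycle, h_tri = wl_histograms([cycle6, two_triangles])
print(h_cycle == h_tri)    # True: 1-WL cannot tell these apart
h_cycle, h_path = wl_histograms([cycle6, path6])
print(h_cycle == h_path)   # False: the path's endpoints get a different colour
```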

Over-smoothing & over-squashing. Two failure modes that bite as you stack more layers. Over-smoothing: each round mixes neighbours into self, so after enough rounds every node converges to the same vector, the graph mean. Over-squashing: information from a distant node has to pass through narrow bottlenecks (low-degree cut edges) and gets crushed into a fixed-size message. Modern fixes: skip / jumping-knowledge connections, normalization, graph rewiring (add long-range shortcuts), or switch to global attention (graph transformers).
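One of the listed fixes, a skip connection back to the input, can be shown with a plain averaging experiment. The sketch below uses an APPNP-style "initial residual", a specific variant chosen here for brevity: each round blends the smoothed features with a fraction `alpha` of the original ones. Learned weights are omitted, and the 6-cycle and random features are toy assumptions.

```python
import numpy as np

def propagate(A, H0, rounds, alpha=0.0):
    """Mean aggregation with self-loops; alpha > 0 mixes the initial features back in.

    alpha = 0:  plain repeated averaging, H <- P @ H
    alpha > 0:  initial residual, H <- alpha * H0 + (1 - alpha) * P @ H
    """
    A_hat = A + np.eye(A.shape[0])
    P = A_hat / A_hat.sum(axis=1, keepdims=True)
    H = H0
    for _ in range(rounds):
        H = alpha * H0 + (1 - alpha) * (P @ H)
    return H

def spread(H):
    """Largest per-feature range across nodes."""
    return float((H.max(axis=0) - H.min(axis=0)).max())

n = 6
A = np.zeros((n, n))
for i in range(n):
    A[i, (i + 1) % n] = A[(i + 1) % n, i] = 1
H0 = np.random.default_rng(0).normal(size=(n, 4))
print(spread(propagate(A, H0, 64)))             # collapses toward 0
print(spread(propagate(A, H0, 64, alpha=0.2)))  # stays clearly above 0
```

The residual version converges to a fixed point that still depends on each node's own input, so depth no longer erases the distinctions between nodes.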

Expressive power & the Weisfeiler-Leman test. Standard message-passing GNNs are capped at WL-1 (above). To break the ceiling you need: higher-order GNNs (operate on k-tuples of nodes, gaining expressiveness at $O(N^k)$ cost), subgraph methods (SUN, GNN-AK: extract a local subgraph around each node and process it separately), or equivariant GNNs like EGNN that use 3D coordinates and transform consistently under rotation and translation (scalar predictions stay invariant), which is essential for molecules and physics.

Graph transformers vs. GNNs. Treat the graph as fully connected and let attention figure out the structure, but inject the graph as an inductive bias — distance encodings, edge encodings, or Laplacian-eigenvector positional encodings (GraphGPS, Graphormer). On small / medium graphs with long-range interactions they tend to beat plain GNNs; on huge sparse graphs the quadratic attention cost is the bottleneck and classical GNNs still win.
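Of the positional encodings listed, the Laplacian-eigenvector one is easy to sketch. This NumPy version assumes an undirected, connected graph with a dense adjacency matrix and uses the symmetric normalised Laplacian; models such as GraphGPS add refinements (learned PE encoders, handling of the arbitrary eigenvector signs, often by random sign flips during training) that are omitted here.

```python
import numpy as np

def laplacian_pe(A, k):
    """k smallest non-trivial eigenvectors of L = I - D^{-1/2} A D^{-1/2}."""
    deg = A.sum(axis=1)
    d_inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    L = np.eye(A.shape[0]) - d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]
    eigvals, eigvecs = np.linalg.eigh(L)   # eigenvalues in ascending order
    return eigvecs[:, 1:k + 1]             # drop the trivial eigenvector (eigenvalue 0)

# Path graph on 5 nodes: 2-dimensional positional encoding per node.
A = np.zeros((5, 5))
for i in range(4):
    A[i, i + 1] = A[i + 1, i] = 1
pe = laplacian_pe(A, 2)
X = np.ones((5, 3))                        # stand-in node features
X_with_pe = np.concatenate([X, pe], axis=1)
```

Concatenating `pe` to the node features gives a fully connected attention layer a notion of where each node sits in the graph, which it otherwise would not have.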

Scaling to web-scale graphs. Billions of edges don't fit in memory. The three dominant strategies all sample subgraphs:

  • Neighbour sampling (GraphSAGE) — pick a fixed number of neighbours per node per layer; each batch stays cheap, but the sampled computation graph still grows as fanout^depth and the sampling adds variance to the gradient.
  • Cluster sampling (Cluster-GCN) — partition the graph into dense clusters, train one cluster per batch.
  • Subgraph sampling (GraphSAINT) — sample an induced subgraph (with node, edge, or random-walk samplers) and apply normalization coefficients so the expected minibatch loss matches the full-batch loss.
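Neighbour sampling is the simplest of the three to sketch: starting from a minibatch's target nodes, sample at most a fixed number of neighbours per hop, and only those nodes enter the computation. This plain-Python version is illustrative only (`sample_computation_graph` is a hypothetical name); production loaders such as PyTorch Geometric's `NeighborLoader` do the same kind of thing over compressed sparse adjacency structures.

```python
import random

def sample_computation_graph(adj, seeds, fanouts, seed=0):
    """GraphSAGE-style neighbour sampling.

    adj:     dict node -> list of neighbours
    seeds:   target nodes for this minibatch
    fanouts: max neighbours to sample per node at each hop, e.g. [10, 5]
    Returns one dict per hop mapping node -> its sampled neighbours.
    """
    rng = random.Random(seed)
    frontier, layers = list(seeds), []
    for fanout in fanouts:
        sampled = {}
        for v in frontier:
            nbrs = adj[v]
            # Keep every neighbour if there are few; otherwise sample without replacement.
            sampled[v] = list(nbrs) if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
        layers.append(sampled)
        # The next hop expands from every node sampled in this hop.
        frontier = sorted({u for chosen in sampled.values() for u in chosen})
    return layers

# A hub with 20 leaves: only 5 of them are ever touched for a batch seeded at the hub.
adj = {0: list(range(1, 21)), **{i: [0] for i in range(1, 21)}}
layers = sample_computation_graph(adj, seeds=[0], fanouts=[5, 3])
print(layers)
```

Per seed, the worst case touches fanout_1 × fanout_2 × … nodes, which is why fanouts are kept small and models shallow.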

Real-world wins. AlphaFold 2's Evoformer is often described as a graph transformer over residue pairs: it repeatedly refines a pairwise residue representation with attention and "triangle" updates, a dense form of message passing. Pinterest's PinSage uses GraphSAGE-style sampling over 3 billion nodes for recommendation. Uber Eats uses GNNs for dish and restaurant recommendations, and Google Maps uses them for ETA prediction. In drug discovery, GNN-based property predictors are now widely used for screening molecule libraries before any lab work.

Reach for it when

  • EGNN / equivariant — 3D structure-aware tasks (molecules, proteins, particle physics)
  • Graph transformers — long-range dependencies on small or medium graphs
  • Cluster-GCN / GraphSAGE / SAINT — web-scale graphs that don't fit in memory
  • Higher-order or subgraph methods — when WL-1 isn't enough and you can pay the compute

Skip it when

  • You're chasing the WL hierarchy but the dataset doesn't actually need it
  • Tabular or time-series data dressed up as a graph — use the native model
  • You need bounded inference latency on a constantly growing graph
  • A handful of hand-crafted structural features plus an MLP already solves the task

import torch
import torch.nn as nn
from torch_geometric.utils import scatter

# A bare-bones expressive GNN layer (GIN-style), implemented from scratch.
# Sum-aggregate neighbours, then apply an MLP; with a sufficiently expressive
# (injective) MLP this matches the WL-1 ceiling (Xu et al., 2019).
class GINLayer(nn.Module):
    def __init__(self, d_in, d_out, eps=0.0, learn_eps=True):
        super().__init__()
        # eps weights a node's own features against the sum of its neighbours
        self.eps = nn.Parameter(torch.tensor(eps)) if learn_eps else eps
        self.mlp = nn.Sequential(
            nn.Linear(d_in, d_out), nn.ReLU(),
            nn.Linear(d_out, d_out),
        )

    def forward(self, x, edge_index):
        src, dst = edge_index  # row 0 = source nodes, row 1 = target nodes
        # Sum each source node's features into its target (order-free aggregation)
        agg = scatter(x[src], dst, dim=0, dim_size=x.size(0), reduce="sum")
        return self.mlp((1 + self.eps) * x + agg)