Jraph
Jraph (pronounced "giraffe") is a lightweight library for building Graph Neural Networks (GNNs) in JAX. It provides a fundamental data structure for graphs (`GraphsTuple`), a set of utilities for working with graphs (e.g., batching, padding, message passing), and a 'zoo' of forkable GNN models. Leveraging JAX's automatic differentiation and XLA compilation, Jraph aims for high performance and flexibility in GNN research. It is currently in a pre-release development state (v0.0.6.dev0) with frequent updates.
Common errors
- AttributeError: module 'jraph' has no attribute 'GraphsTuple'
  cause: Usually an outdated or broken `jraph` installation, code following older API patterns, or a local file or directory named `jraph` (for example `jraph.py`) that shadows the installed package.
  fix: Upgrade with `pip install --upgrade jraph`, check which module Python actually imports with `print(jraph.__file__)`, and import the class as `from jraph import GraphsTuple` (or use `jraph.GraphsTuple`).
- ValueError: Found graph bigger than batch size. Valid Batch Size: {...}, Graph Size: {...}
  cause: Raised by batching and padding utilities (for example `jraph.dynamically_batch`) when a single input graph is larger than the fixed capacity configured for the padded batch.
  fix: Check the `n_node`, `n_edge`, and `n_graph` arguments passed to `jraph.dynamically_batch` or `jraph.pad_with_graphs`. These totals must cover the largest input graph and still leave room for the padding graph, which needs at least one node and its own graph slot. Either raise the limits or filter out graphs that exceed them.
- TypeError: 'dict' object is not callable (or a similar error when passing dicts as features)
  cause: `GraphsTuple` accepts `ArrayTree`s (including nested dictionaries) for `nodes`, `edges`, and `globals`, but some Jraph functions, user-defined update functions, or JAX operations expect a single flat array. Passing a dictionary where an array is expected raises errors like this one.
  fix: Make sure every function that receives dictionary features handles `ArrayTree`s, for example by applying operations leaf by leaf with `jax.tree_util.tree_map`. If an operation needs only one feature vector per node, edge, or graph, concatenate the dictionary entries into a single array first.
Warnings
- breaking: Jraph is in early development (0.0.x.dev0 releases), so the API can change between releases without a deprecation period, and semantic versioning gives no compatibility guarantee for 0.0.x versions. Code written against one release may not work with the next; pin the version you develop against.
- gotcha: Jraph provides graph data structures and message-passing utilities but does not manage neural-network parameters. Use a JAX-native library such as Haiku or Flax for parameter management and model construction.
- gotcha: `jraph.unbatch` cannot be used inside `jax.jit`: the shapes of the unbatched graphs depend on the values in `n_node` and `n_edge`, so JAX cannot trace them as static shapes. Call it on concrete arrays outside jitted code.
- gotcha: Node and edge positions in a batched `GraphsTuple` are global (cumulative) across the batch, not relative to each graph. In particular, `jraph.batch` shifts each graph's `senders` and `receivers` by the total number of nodes in the graphs before it. Treating them as per-graph indices leads to index-offset bugs and incorrect feature access.
Install
- pip install jraph
- pip install git+https://github.com/deepmind/jraph.git
Imports
- GraphsTuple
from jraph import GraphsTuple
- GraphNetwork
from jraph import GraphNetwork
- batch
from jraph import batch
Quickstart
import jraph
import jax.numpy as jnp
# Define node features, 3 nodes, each with a scalar feature
node_features = jnp.array([[0.], [1.], [2.]])
# Define edges: 0 -> 1, 1 -> 2
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])
# Edge features (optional), 2 edges, each with a scalar feature
edge_features = jnp.array([[10.], [20.]])
# Global features (optional), 1 graph, with a scalar feature
global_features = jnp.array([[100.]])
# Number of nodes and edges per graph (for a single graph)
n_node = jnp.array([len(node_features)])
n_edge = jnp.array([len(senders)])
# Create a GraphsTuple
graph = jraph.GraphsTuple(
    nodes=node_features,
    edges=edge_features,
    receivers=receivers,
    senders=senders,
    globals=global_features,
    n_node=n_node,
    n_edge=n_edge
)
print(graph)
print(f"Nodes: {graph.nodes.shape}, Edges: {graph.edges.shape}")