Jraph

0.0.6.dev0 · active · verified Thu Apr 16

Jraph (pronounced "giraffe") is a lightweight library for building Graph Neural Networks (GNNs) in JAX. It provides a core graph data structure (`GraphsTuple`), utilities for working with graphs (e.g., batching, padding, and message passing), and a 'zoo' of GNN models meant to be forked and modified. Building on JAX's automatic differentiation and XLA compilation, Jraph aims for high performance and flexibility in GNN research. It is still a pre-release (v0.0.6.dev0), so its API may change between versions.


Quickstart

This quickstart shows how to construct a basic `GraphsTuple`, the core data structure for representing graphs in Jraph. Node features, optional edge features, and optional global features are stored as `jax.numpy` arrays, and the graph structure is described by the `senders` and `receivers` index arrays together with `n_node` (nodes per graph) and `n_edge` (edges per graph).

```python
import jraph
import jax.numpy as jnp

# Define node features: 3 nodes, each with a scalar feature
node_features = jnp.array([[0.], [1.], [2.]])

# Define edges: 0 -> 1, 1 -> 2
senders = jnp.array([0, 1])
receivers = jnp.array([1, 2])

# Edge features (optional): 2 edges, each with a scalar feature
edge_features = jnp.array([[10.], [20.]])

# Global features (optional): 1 graph, with a scalar feature
global_features = jnp.array([[100.]])

# Number of nodes and edges per graph (here, a single graph)
n_node = jnp.array([len(node_features)])
n_edge = jnp.array([len(senders)])

# Create a GraphsTuple
graph = jraph.GraphsTuple(
    nodes=node_features,
    edges=edge_features,
    receivers=receivers,
    senders=senders,
    globals=global_features,
    n_node=n_node,
    n_edge=n_edge
)

print(graph)
print(f"Nodes: {graph.nodes.shape}, Edges: {graph.edges.shape}")  # Nodes: (3, 1), Edges: (2, 1)
```
