{"id":8246,"library":"jraph","title":"Jraph","description":"Jraph (pronounced \"giraffe\") is a lightweight library for building Graph Neural Networks (GNNs) in JAX. It provides a fundamental data structure for graphs (`GraphsTuple`), a set of utilities for working with graphs (e.g., batching, padding, message passing), and a 'zoo' of forkable GNN models. Leveraging JAX's automatic differentiation and XLA compilation, Jraph aims for high performance and flexibility in GNN research. It is currently in a pre-release development state (v0.0.6.dev0) with frequent updates.","status":"active","version":"0.0.6.dev0","language":"en","source_language":"en","source_url":"https://github.com/deepmind/jraph","tags":["Graph Neural Networks","GNN","JAX","Deep Learning","Graphs"],"install":[{"cmd":"pip install jraph","lang":"bash","label":"Install Jraph from PyPI"},{"cmd":"pip install git+https://github.com/deepmind/jraph.git","lang":"bash","label":"Install latest from GitHub"}],"dependencies":[{"reason":"Jraph is built on JAX for high-performance numerical computation and automatic differentiation.","package":"jax","optional":false},{"reason":"jaxlib provides the compiled XLA backend that JAX runs on.","package":"jaxlib","optional":false},{"reason":"Required for graph data structures, added as a dependency in v0.0.3.dev0.","package":"frozendict","optional":false}],"imports":[{"symbol":"GraphsTuple","correct":"from jraph import GraphsTuple"},{"note":"Model constructors such as `GraphNetwork` are exported from the top-level `jraph` namespace; each one is built from user-supplied update functions.","symbol":"GraphNetwork","correct":"from jraph import GraphNetwork"},{"note":"Utilities like `batch`, `unbatch`, `pad_with_graphs` are directly available under the `jraph` namespace.","symbol":"batch","correct":"from jraph import batch"}],"quickstart":{"code":"import jraph\nimport jax.numpy as jnp\n\n# Define node features: 3 nodes, each with a scalar feature\nnode_features = jnp.array([[0.], [1.], 
[2.]])\n\n# Define edges: 0 -> 1, 1 -> 2\nsenders = jnp.array([0, 1])\nreceivers = jnp.array([1, 2])\n\n# Edge features (optional): 2 edges, each with a scalar feature\nedge_features = jnp.array([[10.], [20.]])\n\n# Global features (optional): 1 graph with a scalar feature\nglobal_features = jnp.array([[100.]])\n\n# Number of nodes and edges per graph (one entry, since there is a single graph)\nn_node = jnp.array([len(node_features)])\nn_edge = jnp.array([len(senders)])\n\n# Create a GraphsTuple\ngraph = jraph.GraphsTuple(\n    nodes=node_features,\n    edges=edge_features,\n    receivers=receivers,\n    senders=senders,\n    globals=global_features,\n    n_node=n_node,\n    n_edge=n_edge\n)\n\nprint(graph)\nprint(f\"Nodes: {graph.nodes.shape}, Edges: {graph.edges.shape}\")","lang":"python","description":"This quickstart shows how to construct a basic `GraphsTuple`, the core data structure for representing graphs in Jraph. It defines node features plus optional edge and global features as `jax.numpy` arrays, along with the `senders`, `receivers`, `n_node`, and `n_edge` arrays that describe the graph structure."},"warnings":[{"fix":"Pin Jraph to a specific version (`pip install jraph==0.0.6.dev0`) and review the release notes on GitHub before upgrading.","message":"Jraph is in early development (0.0.x.dev0 versions), so API changes can occur frequently and breaking changes do not follow semantic versioning. Code written against one release may not work with the next.","severity":"breaking","affected_versions":"All 0.0.x.dev0 versions"},{"fix":"Use Haiku or Flax to define and manage model parameters in Jraph-based GNNs. The examples in the Jraph GitHub repository show Haiku-based integration patterns.","message":"Jraph provides graph data structures and message passing, but it does not manage parameters for graph neural networks. 
Users need to integrate with external JAX-native neural network libraries like Haiku or Flax for parameter management and model construction.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Avoid calling `jraph.unbatch` inside `jax.jit`-decorated functions. Unbatch graphs outside jitted code, or, for variable-sized graphs inside jitted code, pad to fixed sizes with `jraph.pad_with_graphs` and mask out the padding (e.g., with `jraph.get_graph_padding_mask`).","message":"The `jraph.unbatch` utility does not support `jax.jit` compilation because the shapes of the unbatched output graphs are data-dependent, preventing JAX from tracing a static computation graph.","severity":"gotcha","affected_versions":"All versions"},{"fix":"Use the `n_node` and `n_edge` fields of the `GraphsTuple` to compute absolute indices when accessing or manipulating features of a specific graph within a batch. For example, the nodes of the i-th graph start at `sum(graph.n_node[:i])` and its edges start at `sum(graph.n_edge[:i])`.","message":"Indexing for nodes and edges within a batched `GraphsTuple` is absolute (cumulative) across all graphs in the batch, not relative to each individual graph: `jraph.batch` offsets each graph's `senders` and `receivers` by the number of nodes in the preceding graphs. Code that assumes per-graph indices will read or write the wrong features.","severity":"gotcha","affected_versions":"All versions"}],"env_vars":null,"last_verified":"2026-04-16T00:00:00.000Z","next_check":"2026-07-15T00:00:00.000Z","problems":[{"fix":"Ensure `jraph` is correctly installed (`pip install --upgrade jraph`), check that no local file or directory named `jraph` shadows the package, and verify the import statement is `from jraph import GraphsTuple`.","cause":"This usually indicates a broken or incomplete `jraph` installation, or a local file or directory named `jraph` (e.g., `jraph.py`) shadowing the installed package.","error":"AttributeError: module 'jraph' has no attribute 'GraphsTuple'"},{"fix":"Review the `n_node`, `n_edge`, and `n_graph` budgets passed to `jraph.dynamically_batch`. 
Increase them so that every input graph fits, or filter out oversized graphs before batching. Because one node and one graph are reserved for padding, `n_node` must exceed the largest graph's node count and `n_edge` must be at least its edge count.","cause":"This error is raised by `jraph.dynamically_batch` when a single input graph has more nodes or edges than the batch budget can hold after space is reserved for padding.","error":"ValueError: Found graph bigger than batch size. Valid Batch Size: {...}, Graph Size: {...}"},{"fix":"If you use dictionary features, make sure the functions that process them handle `ArrayTrees` (e.g., via `jax.tree_util` functions). In simpler cases, concatenate dictionary features into a single array when an operation needs only one feature vector per node, edge, or graph.","cause":"While `GraphsTuple` supports `ArrayTrees` (including nested dictionaries) for `nodes`, `edges`, and `globals`, some internal Jraph functions or external JAX operations might expect flat arrays or specific structures. Passing complex dictionary structures where a flat array is expected can cause errors.","error":"TypeError: 'dict' object is not callable (or similar error when passing dicts as features)"}]}