DrJAX
DrJAX is a Python library built on JAX that provides scalable and differentiable MapReduce primitives. It enables users to express complex computations over distributed data in a functional, JAX-compatible manner, allowing for automatic differentiation and XLA compilation across various hardware accelerators. The current version is 0.1.4, and it sees regular, minor updates focusing on stability and JAX compatibility.
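The MapReduce pattern that DrJAX builds on can be sketched in plain Python. This shows only the abstract map-then-fold idea, not DrJAX's API:

```python
from functools import reduce

# The MapReduce pattern: transform each item independently (map),
# then combine the results pairwise (reduce).
items = [1, 2, 3, 4]
mapped = [x ** 2 for x in items]            # map step: square each item
total = reduce(lambda a, b: a + b, mapped)  # reduce step: pairwise sum
print(total)  # 30
```

DrJAX expresses the same two steps as JAX-compatible primitives, which is what makes the whole computation differentiable and XLA-compilable.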
Common errors
- ImportError: cannot import name 'map_reduce' from 'drjax'
  cause: The `drjax` library is either not installed, or the import path is incorrect. `map_reduce` is a top-level symbol.
  fix: First, ensure `drjax` is installed: `pip install drjax`. Then use the correct import statement: `from drjax import map_reduce`.
- ValueError: Inputs to `map_reduce` must have a batch dimension.
  cause: The `inputs` argument to `map_reduce` must be an array with at least one dimension that can be mapped over. A scalar, or an array that is not shaped appropriately, raises this error.
  fix: Ensure your input data is a JAX array (or NumPy array) with at least one dimension. For example, pass `jnp.array([1, 2, 3])` instead of a bare scalar, or reshape as necessary.
- TypeError: 'numpy.ndarray' object is not callable (or similar for 'jax.Array')
  cause: This error typically occurs when an array (such as `jnp.array([1, 2, 3])`) is accidentally passed as `mapper_fn` or `reducer_fn` to `map_reduce` instead of a function.
  fix: Verify that `mapper_fn` and `reducer_fn` are Python functions or JAX-compatible callables, not data structures: `mapper_fn=my_function_name`, not `mapper_fn=my_array`.
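The shape and callable requirements behind the last two errors can be checked up front. A minimal NumPy sketch follows; `validate_map_reduce_args` is a hypothetical helper that mirrors the errors above, not DrJAX's actual internals:

```python
import numpy as np

def validate_map_reduce_args(mapper_fn, reducer_fn, inputs):
    # Hypothetical pre-flight checks mirroring the errors above.
    if not callable(mapper_fn) or not callable(reducer_fn):
        raise TypeError("mapper_fn and reducer_fn must be callables, not arrays")
    arr = np.asarray(inputs)
    if arr.ndim < 1:
        raise ValueError("Inputs to `map_reduce` must have a batch dimension.")
    return arr

# OK: callables plus a 1-D array
validate_map_reduce_args(lambda x: x ** 2, lambda a, b: a + b, np.array([1, 2, 3]))
# Raises ValueError: a bare scalar has no dimension to map over
# validate_map_reduce_args(lambda x: x, lambda a, b: a + b, 1)
```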
Warnings
- gotcha DrJAX heavily relies on a correctly configured JAX environment. Issues with JAX installation, especially for specific hardware (GPU/TPU), will manifest as errors in DrJAX.
- gotcha As of v0.1.4, `drjax.broadcast` can now accept a `mesh` argument for explicit sharding control. While not strictly a breaking change for existing code, ignoring or misconfiguring sharding in distributed setups can lead to unexpected data placement or performance bottlenecks.
- gotcha Dependency pinning for `absl-py` and `chex` changed from 'compatible release' (`~=`) to 'minimum version' (`>=`) in v0.1.2. Users with strict, older dependency pins might encounter conflicts during installation.
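The v0.1.2 pinning change matters because `~=` (compatible release) and `>=` (minimum version) admit different version ranges. A simplified stdlib sketch of the two semantics for `X.Y.Z` pins follows; real resolvers apply the full PEP 440 rules, and the version numbers here are made up for illustration:

```python
def parse(version):
    # Naive X.Y.Z parser for illustration only (no pre-releases, epochs, etc.)
    return tuple(int(part) for part in version.split("."))

def compatible_release(version, pin):
    # '~=X.Y.Z' means: >= X.Y.Z and still within the X.Y.* series
    v, p = parse(version), parse(pin)
    return v >= p and v[:2] == p[:2]

def minimum_version(version, pin):
    # '>=X.Y.Z' means: X.Y.Z or anything newer, including new majors
    return parse(version) >= parse(pin)

print(compatible_release("1.4.2", "1.4.0"), minimum_version("1.4.2", "1.4.0"))  # True True
print(compatible_release("2.0.0", "1.4.0"), minimum_version("2.0.0", "1.4.0"))  # False True
```

So a project that pinned `absl-py~=X.Y` resolves differently from one that pins `absl-py>=X.Y`, which is why strict older pins can conflict with DrJAX's looser requirements.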
Install
-
pip install drjax
Imports
- map_reduce
import drjax  # then call drjax.map_reduce
from drjax import map_reduce
- map
from drjax import map  # note: shadows Python's built-in map in this scope
- broadcast
from drjax import broadcast
Quickstart
import jax
import jax.numpy as jnp
from drjax import map_reduce
# Define the 'map' function: square an item
def mapper_fn(item):
    return item ** 2
# Define the 'reduce' function: sum two items
def reducer_fn(a, b):
    return a + b
# Generate some input data across 'shards'
# For a single device, this is just a batch dimension
num_shards = 4
num_items_per_shard = 10
data = jnp.arange(num_shards * num_items_per_shard).reshape(num_shards, num_items_per_shard)
# Use map_reduce to apply mapper_fn to each item, then reducer_fn across results
result = map_reduce(
    mapper_fn=mapper_fn,
    reducer_fn=reducer_fn,
    inputs=data,
)
print(f"Input data shape: {data.shape}")
print(f"Input data (first shard): {data[0]}")
print(f"Result: {result}")
# Expected result: jnp.sum(jnp.arange(40)**2) == 20540
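For reference, the same computation in plain NumPy, assuming `map_reduce` squares every element and then folds everything with the reducer (a sum), as the comments above describe:

```python
import numpy as np

data = np.arange(40).reshape(4, 10)
# Map step: square every item; reduce step: fold with addition (a total sum)
expected = (data ** 2).sum()
print(expected)  # 20540
```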