DrJAX

0.1.4 · active · verified Fri Apr 17

DrJAX is a Python library built on JAX that provides scalable and differentiable MapReduce primitives. It enables users to express complex computations over distributed data in a functional, JAX-compatible manner, allowing for automatic differentiation and XLA compilation across various hardware accelerators. The current version is 0.1.4, and it sees regular, minor updates focusing on stability and JAX compatibility.
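The map and reduce steps are ordinary JAX functions, so the compiled, differentiable behavior described above can be shown without DrJAX itself. The sketch below uses only core JAX (`jax.vmap`, `jax.jit`, `jax.grad`). `map_then_reduce` is a hypothetical helper for illustration and is not part of the DrJAX API.

```python
import jax
import jax.numpy as jnp


def map_then_reduce(x):
    """Hypothetical helper (not DrJAX): square every element (map), then sum (reduce)."""
    mapped = jax.vmap(lambda item: item ** 2)(x)  # map step, vectorized over axis 0
    return jnp.sum(mapped)                         # reduce step


# XLA-compile the whole map-then-reduce computation.
compiled = jax.jit(map_then_reduce)

# Differentiate through both the map and the reduce: d/dx sum(x**2) = 2x.
grad_fn = jax.jit(jax.grad(map_then_reduce))

x = jnp.arange(4.0)  # float input, required by jax.grad
print(compiled(x))   # value 14.0 (0 + 1 + 4 + 9)
print(grad_fn(x))    # values 0, 2, 4, 6
```

Because the reduction is a plain sum, its gradient flows back through the map step unchanged. This is the property that makes a MapReduce written in JAX differentiable end to end.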

Quickstart

This example demonstrates the `drjax.map_reduce` primitive. `mapper_fn` is applied to each element of a batched input, and the mapped values are then combined with the binary `reducer_fn`. The leading axis of the input stands in for shards, so the example mimics a distributed MapReduce even when it runs on a single device, and it shows the library's functional interface.

import jax
import jax.numpy as jnp
from drjax import map_reduce

# Define the 'map' function: square an item
def mapper_fn(item):
    return item ** 2

# Define the 'reduce' function: sum two items
def reducer_fn(a, b):
    return a + b

# Generate some input data across 'shards'
# For a single device, this is just a batch dimension
num_shards = 4
num_items_per_shard = 10
data = jnp.arange(num_shards * num_items_per_shard).reshape(num_shards, num_items_per_shard)

# Use map_reduce to apply mapper_fn to each item, then reducer_fn across results
result = map_reduce(
    mapper_fn=mapper_fn,
    reducer_fn=reducer_fn,
    inputs=data
)

print(f"Input data shape: {data.shape}")
print(f"Input data (first shard): {data[0]}")
print(f"Result: {result}")
# Expected result: jnp.sum(jnp.arange(40)**2) == 20540
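The Quickstart's result can be cross-checked in plain JAX. The sketch below makes two assumptions: the map is applied element-wise and the reduction is a sum. It reduces within each shard (row) first and then combines the per-shard partial sums, mirroring the two-level structure of a MapReduce over `num_shards` groups. It does not call DrJAX, and `shard_partial` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp

num_shards = 4
num_items_per_shard = 10
data = jnp.arange(num_shards * num_items_per_shard).reshape(num_shards, num_items_per_shard)


def shard_partial(shard):
    """Hypothetical helper: map (square) every item in one shard, then sum within the shard."""
    return jnp.sum(shard ** 2)


# One partial result per shard (vmap over the shard axis) ...
partials = jax.vmap(shard_partial)(data)
# ... then the final reduction across shards.
total = jnp.sum(partials)

print(partials)  # values 285, 2185, 6085, 11985
print(total)     # value 20540
```

Because addition is associative, summing the per-shard partials gives the same total as summing all 40 squared values at once. That total is 39 · 40 · 79 / 6 = 20540.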
