{"id":9708,
"library":"drjax",
"title":"DrJAX",
"description":"DrJAX is a Python library built on JAX that provides scalable, differentiable MapReduce primitives. It lets users express computations over distributed data in a functional, JAX-compatible style, enabling automatic differentiation and XLA compilation across hardware accelerators. The current version is 0.1.4; releases are regular minor updates focused on stability and JAX compatibility.",
"status":"active",
"version":"0.1.4",
"language":"en",
"source_language":"en",
"source_url":"https://github.com/google/drjax",
"tags":["jax","mapreduce","distributed-computing","machine-learning","differentiation","google"],
"install":[{"cmd":"pip install drjax","lang":"bash","label":"Install stable release"}],
"dependencies":[{"reason":"Core dependency for numerical computation and automatic differentiation. Specific JAX versions are required for compatibility.","package":"jax","optional":false},{"reason":"Abseil Python common libraries, used for flags, logging, and other utilities.","package":"absl-py","optional":false},{"reason":"Collection of utilities for JAX, often used in JAX ecosystem libraries for testing and assertions.","package":"chex","optional":false}],
"imports":[{"note":"The primary MapReduce primitive.","wrong":"import drjax.map_reduce","symbol":"map_reduce","correct":"from drjax import map_reduce"},{"note":"A simplified map primitive. Note that importing it unaliased shadows Python's builtin `map`; consider `from drjax import map as drjax_map`.","symbol":"map","correct":"from drjax import map"},{"note":"A primitive for broadcasting data across devices/shards.","symbol":"broadcast","correct":"from drjax import broadcast"}],
"quickstart":{"code":"import jax\nimport jax.numpy as jnp\nfrom drjax import map_reduce\n\n# Define the 'map' function: square an item\ndef mapper_fn(item):\n    return item ** 2\n\n# Define the 'reduce' function: sum two items\ndef reducer_fn(a, b):\n    return a + b\n\n# Generate some input data across 'shards'.\n# On a single device, the shard axis is just a batch dimension.\nnum_shards = 4\nnum_items_per_shard = 10\ndata = jnp.arange(num_shards * num_items_per_shard).reshape(num_shards, num_items_per_shard)\n\n# Apply mapper_fn to each item, then combine the results with reducer_fn\nresult = map_reduce(\n    mapper_fn=mapper_fn,\n    reducer_fn=reducer_fn,\n    inputs=data\n)\n\nprint(f\"Input data shape: {data.shape}\")\nprint(f\"Input data (first shard): {data[0]}\")\nprint(f\"Result: {result}\")\n# Expected result: jnp.sum(jnp.arange(40)**2) == 20540","lang":"python","description":"This example demonstrates the core `drjax.map_reduce` primitive: a `mapper_fn` is applied to each element of a batched input, and the results are combined with a `reducer_fn`. This mimics a distributed MapReduce pattern even when run on a single device, showcasing the functional interface."},
"warnings":[{"fix":"Ensure JAX is installed correctly for your target hardware by following the official JAX installation guide. For CUDA, this often involves `pip install jax[cuda11_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`.","message":"DrJAX relies on a correctly configured JAX environment. Problems with the JAX installation, especially for specific hardware (GPU/TPU), will surface as errors in DrJAX.","severity":"gotcha","affected_versions":"All versions"},{"fix":"When using `drjax.broadcast` in a distributed JAX setup (e.g., with `jax.experimental.pjit`), explicitly define and pass a `jax.sharding.Mesh` object if you need fine-grained control over data placement on your devices. Review the documentation for `drjax.broadcast` and JAX sharding.","message":"As of v0.1.4, `drjax.broadcast` accepts a `mesh` argument for explicit sharding control. While not a breaking change for existing code, ignoring or misconfiguring sharding in distributed setups can lead to unexpected data placement or performance bottlenecks.","severity":"gotcha","affected_versions":">=0.1.4"},{"fix":"If you hit dependency conflicts, ensure `absl-py>=1.2.0` and `chex>=0.1.5`. Consider rebuilding your virtual environment and updating related JAX ecosystem libraries to their latest compatible versions.","message":"Dependency pinning for `absl-py` and `chex` changed from 'compatible release' (`~=`) to 'minimum version' (`>=`) in v0.1.2. Users with strict, older dependency pins may encounter conflicts during installation.","severity":"gotcha","affected_versions":">=0.1.2"}],
"env_vars":null,
"last_verified":"2026-04-17T00:00:00.000Z",
"next_check":"2026-07-16T00:00:00.000Z",
"problems":[{"fix":"First, ensure `drjax` is installed: `pip install drjax`. Then use the correct import statement: `from drjax import map_reduce`.","cause":"Either `drjax` is not installed, or the import path is incorrect; `map_reduce` is a top-level symbol.","error":"ImportError: cannot import name 'map_reduce' from 'drjax'"},{"fix":"Ensure the input is a JAX (or NumPy) array with at least one dimension, e.g. `jnp.array([1, 2, 3])` rather than the scalar `1`; reshape if necessary.","cause":"The `inputs` argument to `map_reduce` must be an array with at least one dimension to map over. A scalar, or an array without an appropriate shape, raises this error.","error":"ValueError: Inputs to `map_reduce` must have a batch dimension."},{"fix":"Verify that `mapper_fn` and `reducer_fn` are Python functions or JAX-compatible callables, not data structures, e.g. `mapper_fn=my_function_name`, not `mapper_fn=my_array`.","cause":"This error typically occurs when an array (like `jnp.array([1, 2, 3])`) is accidentally passed instead of a function as `mapper_fn` or `reducer_fn`.","error":"TypeError: 'numpy.ndarray' object is not callable (or similar for 'jax.Array')"}]}