jax-dataclasses
Version 1.6.3 · verified Fri May 01 · auth: no · language: python
A library that provides a dataclass-like decorator for use with JAX. Decorated classes are registered as pytree nodes, so they work with JAX's functional transformations, and they support static fields and mutable-style update syntax. The current version is 1.6.3; the project is actively maintained, with releases every few months.
pip install jax-dataclasses
Common errors
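error AttributeError: module 'jax_dataclasses' has no attribute 'jdc'
cause Treating jdc as a submodule or attribute of the package, for example jax_dataclasses.jdc or from jax_dataclasses import jdc (the latter typically surfaces as an ImportError). jdc is only the conventional import alias.
fix Use import jax_dataclasses as jdc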
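error TypeError: replace() got an unexpected keyword argument 'a'
cause Calling jdc.replace on an object whose class was not decorated with @jdc.pytree_dataclass (for example, a standard dataclass was used instead) or does not define a field named a.
fix Ensure the class is decorated with @jdc.pytree_dataclass and declares every field passed to jdc.replace.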
Warnings
deprecated The shape / datatype annotation API (e.g., @jdc.pytree_dataclass(shape_dtype=...)) is deprecated since v1.6.0. Use Static[] annotations instead.
fix Replace shape_dtype annotations with jdc.Static[] for static fields.
gotcha Do not use the standard Python dataclass decorator (from dataclasses import dataclass) on a class that holds JAX arrays; the class will not be registered as a pytree node. Always use @jdc.pytree_dataclass.
fix Use @jdc.pytree_dataclass or @jdc.pytree_dataclass(frozen=True).
breaking In v1.6.2, Python 3.8 support was dropped. Python >= 3.9 is required.
fix Ensure Python >= 3.9.
gotcha Static fields must be annotated with jdc.Static[] (e.g., a: jdc.Static[int]) to be handled properly; typing.ClassVar may not work correctly.
fix Use jdc.Static[type] for static fields.
Imports
- jdc
  wrong: from jax_dataclasses import jdc
  correct: import jax_dataclasses as jdc
- jdc.jit
  wrong: from jax_dataclasses import jit
  correct: import jax_dataclasses as jdc; @jdc.jit
- Static
  from jax_dataclasses import Static
Quickstart
import jax
import jax_dataclasses as jdc


@jdc.pytree_dataclass
class MyModel:
    a: jax.Array
    b: jax.Array


model = MyModel(a=jax.numpy.array(1.0), b=jax.numpy.array(2.0))

# Functional update: returns a new instance, leaving `model` unchanged
new_model = jdc.replace(model, a=jax.numpy.array(3.0))
print(new_model.a, new_model.b)  # 3.0 2.0