{"id":23911,"library":"jax-dataclasses","title":"jax-dataclasses","description":"A library that provides a dataclass-like decorator for use with JAX, enabling mutable-style syntax with functional transformations, static fields, and support for pytree nodes. Current version is 1.6.3, actively maintained, with releases every few months.","status":"active","version":"1.6.3","language":"python","source_language":"en","source_url":"https://github.com/brentyi/jax_dataclasses","tags":["jax","dataclasses","pytree","functional"],"install":[{"cmd":"pip install jax-dataclasses","lang":"bash","label":"Install from PyPI"}],"dependencies":[{"reason":"Required for pytree registration and JIT compilation.","package":"jax","optional":false},{"reason":"Used for testing; optional for runtime.","package":"chex","optional":true}],"imports":[{"note":"jdc is the conventional alias for the top-level module, not a submodule or attribute. 'from jax_dataclasses import jdc' raises ImportError.","wrong":"from jax_dataclasses import jdc","symbol":"jdc","correct":"import jax_dataclasses as jdc"},{"note":"jit is a public attribute of the module, so 'from jax_dataclasses import jit' also works; referring to it as jdc.jit is preferred because a bare 'jit' is easily confused with jax.jit.","wrong":"","symbol":"jdc.jit","correct":"import jax_dataclasses as jdc; @jdc.jit"},{"note":"","wrong":"","symbol":"Static","correct":"from jax_dataclasses import Static"}],"quickstart":{"code":"import dataclasses\n\nimport jax\nimport jax.numpy as jnp\nimport jax_dataclasses as jdc\n\n@jdc.pytree_dataclass\nclass MyModel:\n    a: jax.Array\n    b: jax.Array\n\nmodel = MyModel(a=jnp.array(1.0), b=jnp.array(2.0))\n\n# Functional update: returns a new instance; `model` is unchanged.\nnew_model = dataclasses.replace(model, a=jnp.array(3.0))\nprint(new_model.a, new_model.b)  # 3.0 2.0\n\n# Mutable-style syntax applied to a copy of `model`.\nwith jdc.copy_and_mutate(model) as updated:\n    updated.b = jnp.array(4.0)\nprint(updated.a, updated.b)  # 1.0 4.0","lang":"python","description":"Creates a simple pytree dataclass and demonstrates functional updates with dataclasses.replace and mutable-style updates with jdc.copy_and_mutate."},"warnings":[{"fix":"Stop relying on the shape/dtype annotation API. Note that jdc.Static[] only marks a field as static (non-pytree) data; it does not validate array shapes or dtypes.","message":"The shape / datatype annotation API (e.g., @jdc.pytree_dataclass(shape_dtype=...)) is deprecated since v1.6.0.","severity":"deprecated","affected_versions":">=1.6.0"},{"fix":"Use @jdc.pytree_dataclass or @jdc.pytree_dataclass(frozen=True).","message":"Do not decorate a class that holds JAX arrays with the standard library @dataclasses.dataclass: the class is not registered as a JAX pytree, so JAX transformations and jax.tree_util will not traverse its fields. Always use @jdc.pytree_dataclass.","severity":"gotcha","affected_versions":"all"},{"fix":"Ensure Python >= 3.9.","message":"Python 3.8 support was dropped in v1.6.2; v1.6.2 and later require Python >= 3.9.","severity":"breaking","affected_versions":">=1.6.2"},{"fix":"Use jdc.Static[type] for static fields.","message":"Static fields must be annotated with jdc.Static[] (e.g., a: jdc.Static[int]) to be handled correctly. typing.ClassVar declares a class-level attribute rather than a dataclass field, so it cannot be used to declare a per-instance static field.","severity":"gotcha","affected_versions":">=1.6.0"}],"env_vars":null,"last_verified":"2026-05-01T00:00:00.000Z","next_check":"2026-07-30T00:00:00.000Z","problems":[{"fix":"Use import jax_dataclasses as jdc","cause":"Attempting to import jdc as a name from the package: from jax_dataclasses import jdc","error":"ImportError: cannot import name 'jdc' from 'jax_dataclasses'"},{"fix":"Ensure the class is decorated with @jdc.pytree_dataclass.","cause":"Calling dataclasses.replace on an instance of a class that is not a dataclass (e.g., the @jdc.pytree_dataclass decorator was omitted).","error":"TypeError: replace() should be called on dataclass instances"}],"ecosystem":"pypi","meta_description":null,"install_score":null,"install_tag":null,"quickstart_score":null,"quickstart_tag":null}