{"id":28410,"library":"tree-math","title":"tree-math","description":"Provides vector arithmetic (addition, scalar multiplication, dot products, etc.) on JAX pytrees by treating each pytree as a single flat vector. Version 0.2.1; releases are infrequent, but the project is active.","status":"active","version":"0.2.1","language":"python","source_language":"en","source_url":"https://github.com/google/tree-math","tags":["jax","pytree","math","vector"],"install":[{"cmd":"pip install tree-math","lang":"bash","label":"PyPI"}],"dependencies":[{"reason":"core dependency for pytree support","package":"jax","optional":false}],"imports":[{"note":"Star imports can cause name collisions; import the module itself (import tree_math) or explicitly import only the classes you need.","wrong":"from tree_math import *","symbol":"tree_math","correct":"import tree_math"},{"note":"Vector is a class exported by tree_math, not a submodule, so it cannot be imported with a plain import statement.","wrong":"import tree_math.Vector","symbol":"Vector","correct":"from tree_math import Vector"}],"quickstart":{"code":"import jax.numpy as jnp\nfrom tree_math import Vector\n\ntree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}\nv = Vector(tree)\nprint(v + v)\nprint(2 * v)\nprint(v @ v)","lang":"python","description":"Creates a Vector from a pytree of arrays, then adds it to itself, multiplies it by a scalar, and computes its dot product with itself (@)."},"warnings":[{"fix":"Cast all arrays in the pytree to a single consistent dtype before constructing a Vector.","message":"All leaf arrays must share a single dtype. Mixing dtypes (e.g. float32 and float64, the latter available only when jax_enable_x64 is set) can cause unexpected type promotion or errors.","severity":"gotcha","affected_versions":"all"},{"fix":"Use jax.tree_util.tree_map(lambda x: x.copy(), tree) to build an independent copy of the tree before wrapping it in a Vector.","message":"Vector.__init__ does not deep-copy the pytree it is given; mutating the original container afterwards (e.g. reassigning a dict key) may change the Vector's contents.","severity":"gotcha","affected_versions":"all"}],"env_vars":null,"last_verified":"2026-05-09T00:00:00.000Z","next_check":"2026-08-07T00:00:00.000Z","problems":[{"fix":"Run 'pip install tree-math' in your virtual environment.","cause":"tree-math package not installed.","error":"ModuleNotFoundError: No module named 'tree_math'"},{"fix":"Convert all arrays to the same dtype, e.g., tree = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), tree).","cause":"The pytree contains arrays with different dtypes (e.g., float32 and float64).","error":"ValueError: All leaves must have the same dtype."}],"ecosystem":"pypi","meta_description":null,"install_score":null,"install_tag":null,"quickstart_score":null,"quickstart_tag":null}