tree-math

Version 0.2.1 | verified Sat May 09 | auth: no | Python

Provides mathematical operations (addition, multiplication, dot products, etc.) on JAX pytrees by treating each pytree as a single vector. Version 0.2.1; releases are infrequent, but the project is currently active.
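To make "treating a pytree as a vector" concrete, here is a rough pure-JAX sketch of the semantics. It assumes a pytree behaves like the concatenation of its flattened leaves. The helpers tree_add, tree_scale and tree_dot are hypothetical and are not part of tree-math's API. They use only jax.tree_util.tree_map, jax.tree_util.tree_leaves and jnp.vdot, and they are not a description of how tree-math is implemented internally.

```python
import jax
import jax.numpy as jnp


def tree_add(x, y):
    # Hypothetical helper: leaf-by-leaf addition; x and y must share the same pytree structure.
    return jax.tree_util.tree_map(jnp.add, x, y)


def tree_scale(c, x):
    # Hypothetical helper: multiply every leaf by the scalar c.
    return jax.tree_util.tree_map(lambda leaf: c * leaf, x)


def tree_dot(x, y):
    # Hypothetical helper: inner product of the two trees viewed as flat vectors.
    # jnp.vdot flattens each pair of leaves; the per-leaf results are then summed.
    per_leaf = jax.tree_util.tree_map(jnp.vdot, x, y)
    return sum(jax.tree_util.tree_leaves(per_leaf))


tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
print(tree_add(tree, tree))   # 'a' -> [2., 4.], 'b' -> [6.]
print(tree_scale(2.0, tree))  # same values as tree_add(tree, tree)
print(tree_dot(tree, tree))   # 1 + 4 + 9 = 14
```

tree-math's Vector hides this bookkeeping behind ordinary operators such as +, * and @.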

pip install tree-math
error ModuleNotFoundError: No module named 'tree_math'
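cause The tree-math package is not installed in the active Python environment.
fix Run 'pip install tree-math' inside the virtual environment that runs your code. The package is installed as tree-math but imported as tree_math.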
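A common source of this error is installing into one interpreter and running another. The minimal check below reports which interpreter is active and whether it can find tree_math. is_installed is a hypothetical helper built on the standard-library importlib.util.find_spec. Installing with "python -m pip" targets that same interpreter.

```python
import importlib.util
import sys


def is_installed(module_name: str) -> bool:
    # Hypothetical helper: find_spec returns None when the module is not on this interpreter's path.
    return importlib.util.find_spec(module_name) is not None


print("Interpreter:", sys.executable)
# The import name uses an underscore (tree_math), unlike the pip package name (tree-math).
if not is_installed("tree_math"):
    print(f"tree_math not found; try: {sys.executable} -m pip install tree-math")
```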
error ValueError: All leaves must have the same dtype.
cause The pytree contains arrays with different dtypes (e.g., float32 and float64).
fix Convert all arrays to the same dtype, e.g., tree = jax.tree_util.tree_map(lambda x: x.astype(jnp.float32), tree).
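The fix casts every leaf to float32 explicitly. As an alternative sketch, the common dtype can be derived from the leaves themselves with jnp.result_type, which applies JAX's type-promotion rules. The sample tree (an integer leaf next to a float leaf) is made up for illustration.

```python
import jax
import jax.numpy as jnp

# Illustrative tree with mixed dtypes: 'a' holds integers, 'b' holds floats.
tree = {'a': jnp.array([1, 2]), 'b': jnp.array([3.0])}

# Derive one dtype for all leaves using JAX's promotion rules (int + float -> float).
common = jnp.result_type(*jax.tree_util.tree_leaves(tree))
tree = jax.tree_util.tree_map(lambda x: x.astype(common), tree)

print({k: leaf.dtype for k, leaf in tree.items()})
```

Casting to a fixed dtype, as in the fix, is more predictable. Deriving it avoids silently truncating a wider type.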
gotcha All leaf arrays must share a single dtype. Mixing, e.g., float32 and float64 leaves can cause unexpected type-promotion errors. JAX creates float32 arrays by default unless jax_enable_x64 is set, so float64 leaves usually come from NumPy inputs.
fix Ensure all arrays in the pytree have consistent dtypes before constructing a Vector.
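One way to check dtypes before constructing a Vector is a small guard that inspects every leaf first. has_uniform_dtype is a hypothetical helper, not a tree-math function. It relies only on jax.tree_util.tree_leaves, and the sample trees are illustrative.

```python
import jax
import jax.numpy as jnp


def has_uniform_dtype(tree) -> bool:
    # Hypothetical helper: collect each leaf's dtype; a uniform tree yields at most one.
    dtypes = {jnp.asarray(leaf).dtype for leaf in jax.tree_util.tree_leaves(tree)}
    return len(dtypes) <= 1


ok = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
bad = {'a': jnp.array([1, 2]), 'b': jnp.array([3.0])}  # integer leaf next to a float leaf

for name, candidate in [('ok', ok), ('bad', bad)]:
    print(name, has_uniform_dtype(candidate))
```

Calling a check like this right before Vector(tree) keeps the failure close to the code that built the tree.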
gotcha Vector.__init__ does not deep-copy the pytree; it keeps a reference, so later mutations of the original tree (for example, reassigning a dict entry) may affect the Vector.
fix Use jax.tree_util.tree_map(lambda x: x.copy(), tree) if you need independent copies.
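The copy recipe works because tree_map builds a new container (here a new dict) and copies each leaf. JAX arrays are immutable, so in practice the protection comes mostly from the fresh container, and from copying any mutable NumPy leaves. The sketch below uses only jax.tree_util.tree_map and an illustrative tree.

```python
import jax
import jax.numpy as jnp

tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}

# tree_map returns a new dict whose leaves are copies of the original arrays.
independent = jax.tree_util.tree_map(lambda x: x.copy(), tree)

# Reassigning an entry in the original dict does not touch the copy.
tree['a'] = jnp.zeros(2)
print(independent['a'])  # still holds the values 1.0 and 2.0
```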

Creates a Vector from a pytree of arrays and performs basic operations.

import jax.numpy as jnp
from tree_math import Vector

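# A pytree (here a dict) whose leaves together form a 3-element vector.
tree = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([3.0])}
v = Vector(tree)
print(v + v)  # elementwise addition: 'a' -> [2., 4.], 'b' -> [6.]
print(2 * v)  # scalar multiplication; same values as v + v
print(v @ v)  # dot product over all leaves: 1 + 4 + 9 = 14.0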