Optimized PyTree Utilities
OpTree is an optimized Python library for working with PyTrees, which are arbitrarily nested Python containers. It provides efficient utilities for flattening, unflattening, and mapping functions over tree structures. The current version is 0.19.0, and the project is under active development with frequent releases.
Warnings
- breaking Python 3.7 support was dropped in optree v0.14.0. Users on Python 3.7 must upgrade their Python version or use optree versions prior to 0.14.0.
- breaking Deprecated key path APIs and `optree.Partial` were removed in optree v0.15.0. Any code relying on these older APIs will break.
- gotcha When registering a custom PyTree node type with `optree.register_pytree_node` or `optree.register_pytree_node_class`, you must pass the `namespace` argument explicitly. Namespacing prevents collisions when different libraries in the same Python interpreter register the same type with different behaviors.
- gotcha By default, `None` is treated as a non-leaf node with zero children. This means it's part of the tree structure (`treespec`), not the list of leaves. To treat `None` as a leaf node, you must explicitly pass `none_is_leaf=True` to functions like `tree_flatten` or `tree_map`.
- gotcha A custom `flatten_func` passed to `register_pytree_node` must guarantee that flattening terminates. If its children can be of the same registered type as the node itself (in the worst case, the node returned unchanged as its own child), flattening recurses indefinitely and fails with a `RecursionError`.
- breaking Module naming conventions changed in v0.16.0, affecting direct imports. Specifically, `optree.accessor` became `optree.accessors`, `optree.integration` became `optree.integrations`, etc.
Install
pip install optree
Imports
- tree_flatten
from optree import tree_flatten
- tree_unflatten
from optree import tree_unflatten
- tree_map
from optree import tree_map
- register_pytree_node
from optree import register_pytree_node
- pytree
import optree.pytree as pt
Quickstart
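from optree import tree_flatten, tree_map, tree_unflatten

def add_one(x):
    return x + 1

tree = {'a': 1, 'b': [2, 3], 'c': {'d': 4}}

# Apply add_one to every leaf; the container structure is preserved.
mapped_tree = tree_map(add_one, tree)
print(mapped_tree)  # {'a': 2, 'b': [3, 4], 'c': {'d': 5}}

# Split the tree into its leaves and a treespec describing its structure.
leaves, treespec = tree_flatten(tree)
print(f"Leaves: {leaves}")  # Leaves: [1, 2, 3, 4]
print(f"TreeSpec: {treespec}")

# Rebuild a tree with the same structure from a new list of leaves.
reconstructed_tree = tree_unflatten(treespec, [leaf * 10 for leaf in leaves])
print(f"Reconstructed tree with modified leaves: {reconstructed_tree}")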