Optimized PyTree Utilities
OpTree is an optimized Python library for working with PyTrees, which are arbitrarily nested Python containers. It provides efficient utilities for flattening, unflattening, and mapping functions over tree structures. The current version is 0.19.0, and the library maintains an active development cycle with frequent releases.
Common errors
- ImportError: cannot import name 'NamedTuple' from 'typing_extensions'
  cause: The `typing_extensions` package installed in your environment is an outdated version that does not provide `NamedTuple`, which `optree` or libraries that depend on it (such as Keras) require.
  fix: Upgrade `typing_extensions` to a recent version: `pip install --upgrade typing-extensions`
- ModuleNotFoundError: No module named 'optree'
  cause: The `optree` library is not installed in the active Python environment.
  fix: Install `optree` with pip: `pip install optree`
- AttributeError: module 'optree' has no attribute 'dict_insertion_ordered'
  cause: The installed `optree` version is too old and lacks newer APIs such as `dict_insertion_ordered`, which dependent libraries (e.g., PyTorch) may expect.
  fix: Upgrade `optree` to the latest version: `pip install --upgrade optree`
- RecursionError: Maximum recursion depth exceeded during flattening the tree.
  cause: A custom PyTree node's `__tree_flatten__` (or `tree_flatten`) method, or the `is_leaf` predicate, lets the traversal descend without ever reaching a leaf, for example by returning the node itself (or an equivalent node) as one of its own children.
  fix: Check that each custom flatten implementation returns only the node's immediate children, and that `is_leaf` returns True for the nodes where the traversal should stop.
- TypeError: '<' not supported between instances of 'int' and 'str'
  cause: `optree` sorts dictionary keys by default so that flattening is deterministic. A dictionary whose keys cannot be compared with each other (e.g., a mix of integers and strings) can make that sort fail.
  fix: Use mutually comparable keys, or use `collections.OrderedDict`, which is flattened in insertion order without sorting its keys. If the dictionary does not need to be traversed, an `is_leaf` predicate that returns True for it keeps the whole mapping as a single leaf.
Warnings
- breaking Python 3.7 support was dropped in optree v0.14.0. Users on Python 3.7 must upgrade their Python version or use optree versions prior to 0.14.0.
- breaking Deprecated key path APIs and `optree.Partial` were removed in optree v0.15.0. Any code relying on these older APIs will break.
- gotcha When registering a custom PyTree node type with `optree.register_pytree_node` or `optree.register_pytree_node_class`, the `namespace` argument is required. This prevents collisions between different libraries that register the same type with different behaviors in the same Python interpreter. The registration is only used by tree functions called with the same `namespace=` argument; calls that omit it treat instances of that type as leaves.
- gotcha By default, `None` is treated as a non-leaf node with zero children. This means it's part of the tree structure (`treespec`), not the list of leaves. To treat `None` as a leaf node, you must explicitly pass `none_is_leaf=True` to functions like `tree_flatten` or `tree_map`.
- gotcha A custom `flatten_func` for `register_pytree_node` should return only the node's immediate children (plus metadata); optree recurses into those children itself. If the function returns the node itself, or an equivalent node of the same type, as a child, the traversal never reaches a leaf and fails with a `RecursionError`.
- breaking Module naming conventions changed in v0.16.0, affecting direct imports. Specifically, `optree.accessor` became `optree.accessors`, `optree.integration` became `optree.integrations`, etc.
- gotcha Building `optree` from source requires a C++ compiler (such as `g++` or `clang++`) and `cmake` in the environment. pip falls back to a source build when no pre-built wheel matches the Python version or platform, which is common on minimal distributions such as Alpine Linux.
Install
-
pip install optree
Imports
- tree_flatten
from optree import tree_flatten
- tree_unflatten
from optree import tree_unflatten
- tree_map
from optree import tree_map
- register_pytree_node
from optree import register_pytree_node
- pytree
from optree import pytree
import optree.pytree as pt
Quickstart
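from optree import tree_flatten, tree_map, tree_unflatten

def add_one(x):
    return x + 1

# tree_map applies add_one to every leaf and keeps the container structure.
tree = {'a': 1, 'b': [2, 3], 'c': {'d': 4}}
mapped_tree = tree_map(add_one, tree)
print(mapped_tree)  # {'a': 2, 'b': [3, 4], 'c': {'d': 5}}

# tree_flatten splits the tree into a list of leaves and a PyTreeSpec.
leaves, treespec = tree_flatten(tree)
print(f"Leaves: {leaves}")  # Leaves: [1, 2, 3, 4]
print(f"TreeSpec: {treespec}")

# tree_unflatten rebuilds the same structure from a new list of leaves.
reconstructed_tree = tree_unflatten(treespec, [leaf * 10 for leaf in leaves])
print(f"Reconstructed tree with modified leaves: {reconstructed_tree}")
# Reconstructed tree with modified leaves: {'a': 10, 'b': [20, 30], 'c': {'d': 40}}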