Optimized PyTree Utilities

0.19.0 · active · verified Mon Apr 06

OpTree is an optimized Python library for working with PyTrees. A PyTree is an arbitrarily nested Python container, such as a list, tuple, or dict, and its non-container elements are called leaves. OpTree provides efficient utilities for flattening a PyTree into its leaves, unflattening leaves back into a PyTree, and mapping functions over every leaf. The current version is 0.19.0, and the project is under active development with frequent releases.

Install

pip install optree

Imports

import optree

Quickstart

This quickstart uses `tree_map` to apply a function to every leaf of a PyTree. It then flattens the same PyTree into a list of leaves plus a structure descriptor (`treespec`) and unflattens it back, this time with modified leaves.

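from optree import tree_flatten, tree_map, tree_unflatten

def add_one(x):
    return x + 1

tree = {'a': 1, 'b': [2, 3], 'c': {'d': 4}}

# Apply add_one to every leaf; the container structure is preserved.
mapped_tree = tree_map(add_one, tree)
print(mapped_tree)  # {'a': 2, 'b': [3, 4], 'c': {'d': 5}}

# Split the tree into a flat list of leaves and a treespec describing its structure.
leaves, treespec = tree_flatten(tree)
print(f"Leaves: {leaves}")  # Leaves: [1, 2, 3, 4]
print(f"TreeSpec: {treespec}")

# Rebuild a tree with the same structure from a new list of leaves.
reconstructed_tree = tree_unflatten(treespec, [leaf * 10 for leaf in leaves])
print(f"Reconstructed tree with modified leaves: {reconstructed_tree}")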
