dm-tree: Nested Data Structure Utilities
dm-tree (DeepMind Tree) is a lightweight Python library for working with nested data structures built from lists, tuples, and dictionaries. It provides functional tools such as `map_structure`, `flatten`, and `unflatten_as` for applying operations across arbitrary tree-like data. The current stable version is 0.1.10. Releases are infrequent and focus on keeping the core functions stable.
Warnings
- gotcha The PyPI package `dm-tree` should be imported as `import tree`, not `import dm_tree`. This is a common source of 'ModuleNotFoundError'.
- gotcha dm-tree's core functions (e.g., `map_structure`, `flatten`) recurse only into specific container types: sequences such as `list` and `tuple` (including `namedtuple`), mappings such as `dict` and `collections.OrderedDict`, and `attrs`-decorated classes. Everything else is a 'leaf', including strings and custom objects that merely implement `__iter__`. This can be unexpected for custom objects that you think of as iterable or tree-like.
- breaking dm-tree requires Python 3.10 or newer. Attempting to install or run on older Python versions will fail or result in compatibility errors.
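- gotcha `tree.map_structure` is strict about structure matching. All structures passed to it must have the same nesting. If they differ (e.g., one has a list of 3 items and another has a list of 2), it raises a `ValueError`. If they match in shape but differ in container type (e.g., a `list` versus a `tuple`), it raises a `TypeError` unless `check_types=False` is passed.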
Install
pip install dm-tree
Imports
- tree
import tree
- map_structure
from tree import map_structure
- flatten
from tree import flatten
- unflatten_as
from tree import unflatten_as
Quickstart
import tree
# Define a nested data structure
data_tree = {
'a': [1, 2],
'b': {'c': 3, 'd': (4, 5)},
'e': 6
}
# 1. Map a function over all 'leaves' in the structure
def increment(x):
    return x + 1
mapped_tree = tree.map_structure(increment, data_tree)
print(f"Mapped tree: {mapped_tree}")
# Expected: {'a': [2, 3], 'b': {'c': 4, 'd': (5, 6)}, 'e': 7}
# 2. Flatten the structure into a flat list of leaves
# (tree.flatten returns only the leaves; dict keys are visited in sorted order)
leaves = tree.flatten(data_tree)
print(f"Flattened leaves: {leaves}")
# Expected: Flattened leaves: [1, 2, 3, 4, 5, 6]
# 3. Pack new leaves back into the original structure
# (unflatten_as takes the reference structure first, then the flat sequence)
new_leaves = [x * 10 for x in leaves]
unflattened_tree = tree.unflatten_as(data_tree, new_leaves)
print(f"Unflattened tree: {unflattened_tree}")
# Expected: {'a': [10, 20], 'b': {'c': 30, 'd': (40, 50)}, 'e': 60}