886 lines
30 KiB
Python
886 lines
30 KiB
Python
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Registry for custom pytree node types."""
|
|
|
|
# pylint: disable=too-many-lines
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import dataclasses
|
|
import inspect
|
|
import sys
|
|
from collections import OrderedDict, defaultdict, deque, namedtuple
|
|
from operator import itemgetter, methodcaller
|
|
from threading import Lock
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
ClassVar,
|
|
Collection,
|
|
Generator,
|
|
Generic,
|
|
Iterable,
|
|
NamedTuple,
|
|
TypeVar,
|
|
overload,
|
|
)
|
|
|
|
import optree._C as _C
|
|
from optree.accessor import (
|
|
AutoEntry,
|
|
MappingEntry,
|
|
NamedTupleEntry,
|
|
PyTreeEntry,
|
|
SequenceEntry,
|
|
StructSequenceEntry,
|
|
)
|
|
from optree.typing import (
|
|
CustomTreeNode,
|
|
PyTreeKind,
|
|
T,
|
|
is_namedtuple_class,
|
|
is_structseq_class,
|
|
structseq,
|
|
)
|
|
from optree.utils import safe_zip, total_order_sorted, unzip2
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
import builtins
|
|
|
|
from optree.typing import KT, VT, FlattenFunc, UnflattenFunc
|
|
|
|
|
|
__all__ = [
|
|
'register_pytree_node',
|
|
'register_pytree_node_class',
|
|
'unregister_pytree_node',
|
|
'dict_insertion_ordered',
|
|
]
|
|
|
|
|
|
SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {} # Python 3.10+
|
|
|
|
|
|
@dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, **SLOTS)
|
|
class PyTreeNodeRegistryEntry(Generic[T]):
|
|
"""A dataclass that stores the information of a pytree node type."""
|
|
|
|
type: builtins.type[Collection[T]]
|
|
flatten_func: FlattenFunc[T]
|
|
unflatten_func: UnflattenFunc[T]
|
|
|
|
if sys.version_info >= (3, 10): # pragma: >=3.10 cover
|
|
_: dataclasses.KW_ONLY # Python 3.10+
|
|
|
|
path_entry_type: builtins.type[PyTreeEntry] = AutoEntry
|
|
kind: PyTreeKind = PyTreeKind.CUSTOM
|
|
namespace: str = ''
|
|
|
|
|
|
del SLOTS
|
|
|
|
|
|
# pylint: disable-next=missing-class-docstring,too-few-public-methods
|
|
class GlobalNamespace: # pragma: no cover
|
|
__slots__: ClassVar[tuple[()]] = ()
|
|
|
|
def __repr__(self, /) -> str:
|
|
return '<GLOBAL NAMESPACE>'
|
|
|
|
|
|
__GLOBAL_NAMESPACE: str = GlobalNamespace() # type: ignore[assignment]
|
|
__REGISTRY_LOCK: Lock = Lock()
|
|
del GlobalNamespace
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from typing_extensions import ParamSpec # Python 3.10+
|
|
|
|
_P = ParamSpec('_P')
|
|
_T = TypeVar('_T')
|
|
_GetP = ParamSpec('_GetP')
|
|
_GetT = TypeVar('_GetT')
|
|
|
|
class _CallableWithGet(Generic[_P, _T, _GetP, _GetT]):
|
|
def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
|
|
raise NotImplementedError
|
|
|
|
# pylint: disable-next=missing-function-docstring
|
|
def get(self, /, *args: _GetP.args, **kwargs: _GetP.kwargs) -> _GetT:
|
|
raise NotImplementedError
|
|
|
|
|
|
def _add_get(
|
|
get: Callable[_GetP, _GetT],
|
|
/,
|
|
) -> Callable[
|
|
[Callable[_P, _T]],
|
|
_CallableWithGet[_P, _T, _GetP, _GetT],
|
|
]:
|
|
def decorator(func: Callable[_P, _T], /) -> _CallableWithGet[_P, _T, _GetP, _GetT]:
|
|
func.get = get # type: ignore[attr-defined]
|
|
return func # type: ignore[return-value]
|
|
|
|
return decorator
|
|
|
|
|
|
@overload
|
|
def pytree_node_registry_get(
|
|
cls: type,
|
|
/,
|
|
*,
|
|
namespace: str = '',
|
|
) -> PyTreeNodeRegistryEntry | None: ...
|
|
|
|
|
|
@overload
|
|
def pytree_node_registry_get(
|
|
cls: None = None,
|
|
/,
|
|
*,
|
|
namespace: str = '',
|
|
) -> dict[type, PyTreeNodeRegistryEntry]: ...
|
|
|
|
|
|
# pylint: disable-next=too-many-return-statements,too-many-branches
|
|
def pytree_node_registry_get( # noqa: C901
|
|
cls: type | None = None,
|
|
/,
|
|
*,
|
|
namespace: str = '',
|
|
) -> dict[type, PyTreeNodeRegistryEntry] | PyTreeNodeRegistryEntry | None:
|
|
"""Lookup the pytree node registry.
|
|
|
|
>>> register_pytree_node.get() # doctest: +IGNORE_WHITESPACE,ELLIPSIS
|
|
{
|
|
<class 'NoneType'>: PyTreeNodeRegistryEntry(
|
|
type=<class 'NoneType'>,
|
|
flatten_func=<function ...>,
|
|
unflatten_func=<function ...>,
|
|
path_entry_type=<class 'optree.accessor.PyTreeEntry'>,
|
|
kind=<PyTreeKind.NONE: 2>,
|
|
namespace=''
|
|
),
|
|
<class 'tuple'>: PyTreeNodeRegistryEntry(
|
|
type=<class 'tuple'>,
|
|
flatten_func=<function ...>,
|
|
unflatten_func=<function ...>,
|
|
path_entry_type=<class 'optree.accessor.SequenceEntry'>,
|
|
kind=<PyTreeKind.TUPLE: 3>,
|
|
namespace=''
|
|
),
|
|
<class 'list'>: PyTreeNodeRegistryEntry(
|
|
type=<class 'list'>,
|
|
flatten_func=<function ...>,
|
|
unflatten_func=<function ...>,
|
|
path_entry_type=<class 'optree.accessor.SequenceEntry'>,
|
|
kind=<PyTreeKind.LIST: 4>,
|
|
namespace=''
|
|
),
|
|
...
|
|
}
|
|
>>> register_pytree_node.get(defaultdict) # doctest: +IGNORE_WHITESPACE,ELLIPSIS
|
|
PyTreeNodeRegistryEntry(
|
|
type=<class 'collections.defaultdict'>,
|
|
flatten_func=<function ...>,
|
|
unflatten_func=<function ...>,
|
|
path_entry_type=<class 'optree.accessor.MappingEntry'>,
|
|
kind=<PyTreeKind.DEFAULTDICT: 8>,
|
|
namespace=''
|
|
)
|
|
>>> register_pytree_node.get(frozenset) # frozenset is considered as a leaf node
|
|
None
|
|
|
|
Args:
|
|
cls (type or None, optional): The class of the pytree node to retrieve. If not provided, all
|
|
the registered pytree nodes in the namespace are returned.
|
|
namespace (str, optional): The namespace of the registry to retrieve. If not provided, the
|
|
global namespace is used.
|
|
|
|
Returns:
|
|
If the ``cls`` is not provided, a dictionary of all the registered pytree nodes in the
|
|
namespace is returned. If the ``cls`` is provided, the corresponding registry entry is
|
|
returned if the ``cls`` is registered as a pytree node. Otherwise, :data:`None` is returned,
|
|
i.e., the ``cls`` is represented as a leaf node.
|
|
"""
|
|
if namespace is __GLOBAL_NAMESPACE:
|
|
namespace = ''
|
|
if (
|
|
cls is not None
|
|
and cls is not namedtuple # noqa: PYI024
|
|
and not inspect.isclass(cls)
|
|
):
|
|
raise TypeError(f'Expected a class or None, got {cls!r}.') # pragma: !=3.9 cover
|
|
if not isinstance(namespace, str):
|
|
raise TypeError( # pragma: !=3.9 cover
|
|
f'The namespace must be a string, got {namespace!r}.',
|
|
)
|
|
|
|
if cls is None:
|
|
namespaces = frozenset({namespace, ''})
|
|
with __REGISTRY_LOCK:
|
|
registry = {
|
|
handler.type: handler
|
|
for handler in _NODETYPE_REGISTRY.values()
|
|
if handler.namespace in namespaces
|
|
}
|
|
if _C.is_dict_insertion_ordered(namespace):
|
|
registry[dict] = _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
|
|
registry[defaultdict] = _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
|
|
return registry
|
|
|
|
if namespace != '':
|
|
handler = _NODETYPE_REGISTRY.get((namespace, cls))
|
|
if handler is not None:
|
|
return handler
|
|
|
|
if _C.is_dict_insertion_ordered(namespace):
|
|
if cls is dict:
|
|
return _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
|
|
if cls is defaultdict:
|
|
return _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
|
|
|
|
handler = _NODETYPE_REGISTRY.get(cls)
|
|
if handler is not None:
|
|
return handler
|
|
if is_structseq_class(cls):
|
|
return _NODETYPE_REGISTRY.get(structseq)
|
|
if is_namedtuple_class(cls):
|
|
return _NODETYPE_REGISTRY.get(namedtuple) # type: ignore[call-overload] # noqa: PYI024
|
|
return None
|
|
|
|
|
|
@_add_get(pytree_node_registry_get)
|
|
def register_pytree_node(
|
|
cls: type[Collection[T]],
|
|
/,
|
|
flatten_func: FlattenFunc[T],
|
|
unflatten_func: UnflattenFunc[T],
|
|
*,
|
|
path_entry_type: type[PyTreeEntry] = AutoEntry,
|
|
namespace: str,
|
|
) -> type[Collection[T]]:
|
|
"""Extend the set of types that are considered internal nodes in pytrees.
|
|
|
|
See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.
|
|
|
|
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
|
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
|
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
|
|
the same class in different namespaces for different use cases.
|
|
|
|
.. warning::
|
|
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
|
|
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
|
|
prevent accidental collisions between different libraries that may register the same type.
|
|
|
|
Args:
|
|
cls (type): A Python type to treat as an internal pytree node.
|
|
flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
|
|
and returning a triple or optionally a pair, with (1) an iterable for the children to be
|
|
flattened recursively, and (2) some hashable metadata to be stored in the treespec and
|
|
to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path
|
|
entries to the corresponding children. If the entries are not provided or given by
|
|
:data:`None`, then `range(len(children))` will be used.
|
|
unflatten_func (callable): A function taking two arguments: the metadata that was returned
|
|
by ``flatten_func`` and stored in the treespec, and the unflattened children. The
|
|
function should return an instance of ``cls``.
|
|
path_entry_type (type, optional): The type of the path entry to be used in the treespec.
|
|
(default: :class:`AutoEntry`)
|
|
namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
|
|
This is used to isolate the registry from other modules that might register a different
|
|
custom behavior for the same type.
|
|
|
|
Returns:
|
|
The same type as the input ``cls``.
|
|
|
|
Raises:
|
|
TypeError: If the input type is not a class.
|
|
TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
|
|
TypeError: If the namespace is not a string.
|
|
ValueError: If the namespace is an empty string.
|
|
ValueError: If the type is already registered in the registry.
|
|
|
|
Examples:
|
|
>>> # Registry a Python type with lambda functions
|
|
>>> register_pytree_node(
|
|
... set,
|
|
... lambda s: (sorted(s), None, None),
|
|
... lambda _, children: set(children),
|
|
... namespace='set',
|
|
... )
|
|
<class 'set'>
|
|
|
|
>>> # Register a Python type into a namespace
|
|
>>> import torch
|
|
>>> register_pytree_node(
|
|
... torch.Tensor,
|
|
... flatten_func=lambda tensor: (
|
|
... (tensor.cpu().detach().numpy(),),
|
|
... {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
|
|
... ),
|
|
... unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
|
|
... namespace='torch2numpy',
|
|
... )
|
|
<class 'torch.Tensor'>
|
|
|
|
>>> # doctest: +SKIP
|
|
>>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
|
|
>>> tree
|
|
{'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}
|
|
|
|
>>> # Flatten without specifying the namespace
|
|
>>> tree_flatten(tree) # `torch.Tensor`s are leaf nodes
|
|
([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))
|
|
|
|
>>> # Flatten with the namespace
|
|
>>> tree_flatten(tree, namespace='torch2numpy')
|
|
(
|
|
[array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
|
|
PyTreeSpec(
|
|
{
|
|
'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
|
|
'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
|
|
},
|
|
namespace='torch2numpy'
|
|
)
|
|
)
|
|
|
|
>>> # Register the same type with a different namespace for different behaviors
|
|
>>> def tensor2flatparam(tensor):
|
|
... return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
|
|
...
|
|
... def flatparam2tensor(metadata, children):
|
|
... return children[0].reshape(metadata)
|
|
...
|
|
... register_pytree_node(
|
|
... torch.Tensor,
|
|
... flatten_func=tensor2flatparam,
|
|
... unflatten_func=flatparam2tensor,
|
|
... namespace='tensor2flatparam',
|
|
... )
|
|
<class 'torch.Tensor'>
|
|
|
|
>>> # Flatten with the new namespace
|
|
>>> tree_flatten(tree, namespace='tensor2flatparam')
|
|
(
|
|
[
|
|
Parameter containing: tensor([0., 0.], requires_grad=True),
|
|
Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
|
|
],
|
|
PyTreeSpec(
|
|
{
|
|
'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
|
|
'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
|
|
},
|
|
namespace='tensor2flatparam'
|
|
)
|
|
)
|
|
""" # pylint: disable=line-too-long
|
|
if not inspect.isclass(cls):
|
|
raise TypeError(f'Expected a class, got {cls!r}.')
|
|
if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
|
|
raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
|
|
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
|
|
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
|
|
if namespace == '':
|
|
raise ValueError('The namespace cannot be an empty string.')
|
|
|
|
registration_key: type | tuple[str, type]
|
|
if namespace is __GLOBAL_NAMESPACE:
|
|
registration_key = cls
|
|
namespace = ''
|
|
else:
|
|
registration_key = (namespace, cls)
|
|
|
|
with __REGISTRY_LOCK:
|
|
_C.register_node(
|
|
cls,
|
|
flatten_func,
|
|
unflatten_func,
|
|
path_entry_type,
|
|
namespace,
|
|
)
|
|
_NODETYPE_REGISTRY[registration_key] = PyTreeNodeRegistryEntry(
|
|
cls,
|
|
flatten_func,
|
|
unflatten_func,
|
|
path_entry_type=path_entry_type,
|
|
namespace=namespace,
|
|
)
|
|
return cls
|
|
|
|
|
|
del pytree_node_registry_get, _add_get
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
# pylint: disable-next=invalid-name
|
|
CustomTreeNodeType = TypeVar('CustomTreeNodeType', bound=type[CustomTreeNode])
|
|
|
|
|
|
@overload
|
|
def register_pytree_node_class(
|
|
cls: str | None = None,
|
|
/,
|
|
*,
|
|
path_entry_type: type[PyTreeEntry] | None = None,
|
|
namespace: str | None = None,
|
|
) -> Callable[[CustomTreeNodeType], CustomTreeNodeType]: ...
|
|
|
|
|
|
@overload
|
|
def register_pytree_node_class(
|
|
cls: CustomTreeNodeType,
|
|
/,
|
|
*,
|
|
path_entry_type: type[PyTreeEntry] | None,
|
|
namespace: str,
|
|
) -> CustomTreeNodeType: ...
|
|
|
|
|
|
def register_pytree_node_class( # noqa: C901
|
|
cls: CustomTreeNodeType | str | None = None,
|
|
/,
|
|
*,
|
|
path_entry_type: type[PyTreeEntry] | None = None,
|
|
namespace: str | None = None,
|
|
) -> CustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType]:
|
|
"""Extend the set of types that are considered internal nodes in pytrees.
|
|
|
|
See also :func:`register_pytree_node` and :func:`unregister_pytree_node`.
|
|
|
|
The ``namespace`` argument is used to avoid collisions that occur when different libraries
|
|
register the same Python type with different behaviors. It is recommended to add a unique prefix
|
|
to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
|
|
the same class in different namespaces for different use cases.
|
|
|
|
.. warning::
|
|
For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
|
|
used to isolate the behavior of flattening and unflattening a pytree node type. This is to
|
|
prevent accidental collisions between different libraries that may register the same type.
|
|
|
|
Args:
|
|
cls (type, optional): A Python type to treat as an internal pytree node.
|
|
path_entry_type (type, optional): The type of the path entry to be used in the treespec.
|
|
(default: :class:`AutoEntry`)
|
|
namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
|
|
type registry. This is used to isolate the registry from other modules that might
|
|
register a different custom behavior for the same type.
|
|
|
|
Returns:
|
|
The same type as the input ``cls`` if the argument presents. Otherwise, return a decorator
|
|
function that registers the class as a pytree node.
|
|
|
|
Raises:
|
|
TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
|
|
TypeError: If the namespace is not a string.
|
|
ValueError: If the namespace is an empty string.
|
|
ValueError: If the type is already registered in the registry.
|
|
|
|
This function is a thin wrapper around :func:`register_pytree_node`, and provides a
|
|
class-oriented interface::
|
|
|
|
@register_pytree_node_class(namespace='foo')
|
|
class Special:
|
|
TREE_PATH_ENTRY_TYPE = GetAttrEntry
|
|
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
def tree_flatten(self):
|
|
return ((self.x, self.y), None, ('x', 'y'))
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, metadata, children):
|
|
return cls(*children)
|
|
|
|
@register_pytree_node_class('mylist')
|
|
class MyList(UserList):
|
|
TREE_PATH_ENTRY_TYPE = SequenceEntry
|
|
|
|
def tree_flatten(self):
|
|
return self.data, None, None
|
|
|
|
@classmethod
|
|
def tree_unflatten(cls, metadata, children):
|
|
return cls(*children)
|
|
"""
|
|
if cls is __GLOBAL_NAMESPACE or isinstance(cls, str):
|
|
if namespace is not None:
|
|
raise ValueError('Cannot specify `namespace` when the first argument is a string.')
|
|
if cls == '':
|
|
raise ValueError('The namespace cannot be an empty string.')
|
|
cls, namespace = None, cls
|
|
|
|
if namespace is None:
|
|
raise ValueError('Must specify `namespace` when the first argument is a class.')
|
|
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
|
|
raise TypeError(f'The namespace must be a string, got {namespace!r}')
|
|
if namespace == '':
|
|
raise ValueError('The namespace cannot be an empty string.')
|
|
|
|
if cls is None:
|
|
|
|
def decorator(cls: CustomTreeNodeType, /) -> CustomTreeNodeType:
|
|
return register_pytree_node_class(
|
|
cls,
|
|
path_entry_type=path_entry_type,
|
|
namespace=namespace,
|
|
)
|
|
|
|
return decorator
|
|
|
|
if not inspect.isclass(cls):
|
|
raise TypeError(f'Expected a class, got {cls!r}.')
|
|
if path_entry_type is None:
|
|
path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry)
|
|
if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
|
|
raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
|
|
|
|
register_pytree_node(
|
|
cls,
|
|
methodcaller('tree_flatten'),
|
|
cls.tree_unflatten,
|
|
path_entry_type=path_entry_type,
|
|
namespace=namespace,
|
|
)
|
|
return cls
|
|
|
|
|
|
def unregister_pytree_node(cls: type, /, *, namespace: str) -> PyTreeNodeRegistryEntry:
|
|
"""Remove a type from the pytree node registry.
|
|
|
|
See also :func:`register_pytree_node` and :func:`register_pytree_node_class`.
|
|
|
|
This function is the inverse operation of function :func:`register_pytree_node`.
|
|
|
|
Args:
|
|
cls (type): A Python type to remove from the pytree node registry.
|
|
namespace (str): The namespace of the pytree node registry to remove the type from.
|
|
|
|
Returns:
|
|
The removed registry entry.
|
|
|
|
Raises:
|
|
TypeError: If the input type is not a class.
|
|
TypeError: If the namespace is not a string.
|
|
ValueError: If the namespace is an empty string.
|
|
ValueError: If the type is a built-in type that cannot be unregistered.
|
|
ValueError: If the type is not found in the registry.
|
|
|
|
Examples:
|
|
>>> # Register a Python type with lambda functions
|
|
>>> register_pytree_node(
|
|
... set,
|
|
... lambda s: (sorted(s), None, None),
|
|
... lambda _, children: set(children),
|
|
... namespace='temp',
|
|
... )
|
|
<class 'set'>
|
|
|
|
>>> # Unregister the Python type
|
|
>>> unregister_pytree_node(set, namespace='temp')
|
|
"""
|
|
if not inspect.isclass(cls):
|
|
raise TypeError(f'Expected a class, got {cls!r}.')
|
|
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
|
|
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
|
|
if namespace == '':
|
|
raise ValueError('The namespace cannot be an empty string.')
|
|
|
|
registration_key: type | tuple[str, type]
|
|
if namespace is __GLOBAL_NAMESPACE:
|
|
registration_key = cls
|
|
namespace = ''
|
|
else:
|
|
registration_key = (namespace, cls)
|
|
|
|
with __REGISTRY_LOCK:
|
|
_C.unregister_node(cls, namespace)
|
|
return _NODETYPE_REGISTRY.pop(registration_key)
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def dict_insertion_ordered(mode: bool, /, *, namespace: str) -> Generator[None]:
|
|
"""Context manager to temporarily set the dictionary sorting mode.
|
|
|
|
This context manager is used to temporarily set the dictionary sorting mode for a specific
|
|
namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary
|
|
should be sorted or keeping the insertion order when flattening a pytree.
|
|
|
|
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
|
|
>>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE
|
|
(
|
|
[1, 2, 3, 4, 5],
|
|
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
|
|
)
|
|
>>> with dict_insertion_ordered(True, namespace='some-namespace'): # doctest: +IGNORE_WHITESPACE
|
|
... tree_flatten(tree, namespace='some-namespace')
|
|
(
|
|
[2, 3, 4, 1, 5],
|
|
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
|
|
)
|
|
|
|
.. warning::
|
|
The dictionary sorting mode is a global setting and is **not thread-safe**. It is
|
|
recommended to use this context manager in a single-threaded environment.
|
|
|
|
Args:
|
|
mode (bool): The dictionary sorting mode to set.
|
|
namespace (str): The namespace to set the dictionary sorting mode for.
|
|
"""
|
|
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
|
|
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
|
|
if namespace == '':
|
|
raise ValueError('The namespace cannot be an empty string.')
|
|
if namespace is __GLOBAL_NAMESPACE:
|
|
namespace = ''
|
|
|
|
with __REGISTRY_LOCK:
|
|
prev = _C.is_dict_insertion_ordered(namespace, inherit_global_namespace=False)
|
|
_C.set_dict_insertion_ordered(bool(mode), namespace)
|
|
|
|
try:
|
|
yield
|
|
finally:
|
|
with __REGISTRY_LOCK:
|
|
_C.set_dict_insertion_ordered(prev, namespace)
|
|
|
|
|
|
def _sorted_items(items: Iterable[tuple[KT, VT]], /) -> list[tuple[KT, VT]]:
|
|
return total_order_sorted(items, key=itemgetter(0))
|
|
|
|
|
|
def _none_flatten(_: None, /) -> tuple[tuple[()], None]:
|
|
return (), None
|
|
|
|
|
|
def _none_unflatten(_: None, children: Iterable[Any], /) -> None:
|
|
sentinel = object()
|
|
if next(iter(children), sentinel) is not sentinel:
|
|
raise ValueError('Expected no children.')
|
|
|
|
|
|
def _tuple_flatten(tup: tuple[T, ...], /) -> tuple[tuple[T, ...], None]:
|
|
return tup, None
|
|
|
|
|
|
def _tuple_unflatten(_: None, children: Iterable[T], /) -> tuple[T, ...]:
|
|
return tuple(children)
|
|
|
|
|
|
def _list_flatten(lst: list[T], /) -> tuple[list[T], None]:
|
|
return lst, None
|
|
|
|
|
|
def _list_unflatten(_: None, children: Iterable[T], /) -> list[T]:
|
|
return list(children)
|
|
|
|
|
|
def _dict_flatten(dct: dict[KT, VT], /) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
|
|
keys, values = unzip2(_sorted_items(dct.items()))
|
|
return values, list(keys), keys
|
|
|
|
|
|
def _dict_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
|
|
return dict(safe_zip(keys, values))
|
|
|
|
|
|
def _dict_insertion_ordered_flatten(
|
|
dct: dict[KT, VT],
|
|
/,
|
|
) -> tuple[
|
|
tuple[VT, ...],
|
|
list[KT],
|
|
tuple[KT, ...],
|
|
]:
|
|
keys, values = unzip2(dct.items())
|
|
return values, list(keys), keys
|
|
|
|
|
|
def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
|
|
return dict(safe_zip(keys, values))
|
|
|
|
|
|
def _ordereddict_flatten(
|
|
dct: OrderedDict[KT, VT],
|
|
/,
|
|
) -> tuple[
|
|
tuple[VT, ...],
|
|
list[KT],
|
|
tuple[KT, ...],
|
|
]:
|
|
keys, values = unzip2(dct.items())
|
|
return values, list(keys), keys
|
|
|
|
|
|
def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT], /) -> OrderedDict[KT, VT]:
|
|
return OrderedDict(safe_zip(keys, values))
|
|
|
|
|
|
def _defaultdict_flatten(
|
|
dct: defaultdict[KT, VT],
|
|
/,
|
|
) -> tuple[
|
|
tuple[VT, ...],
|
|
tuple[Callable[[], VT] | None, list[KT]],
|
|
tuple[KT, ...],
|
|
]:
|
|
values, keys, entries = _dict_flatten(dct)
|
|
return values, (dct.default_factory, keys), entries
|
|
|
|
|
|
def _defaultdict_unflatten(
|
|
metadata: tuple[Callable[[], VT], list[KT]],
|
|
values: Iterable[VT],
|
|
/,
|
|
) -> defaultdict[KT, VT]:
|
|
default_factory, keys = metadata
|
|
return defaultdict(default_factory, _dict_unflatten(keys, values))
|
|
|
|
|
|
def _defaultdict_insertion_ordered_flatten(
|
|
dct: defaultdict[KT, VT],
|
|
/,
|
|
) -> tuple[
|
|
tuple[VT, ...],
|
|
tuple[Callable[[], VT] | None, list[KT]],
|
|
tuple[KT, ...],
|
|
]:
|
|
values, keys, entries = _dict_insertion_ordered_flatten(dct)
|
|
return values, (dct.default_factory, keys), entries
|
|
|
|
|
|
def _defaultdict_insertion_ordered_unflatten(
|
|
metadata: tuple[Callable[[], VT], list[KT]],
|
|
values: Iterable[VT],
|
|
/,
|
|
) -> defaultdict[KT, VT]:
|
|
default_factory, keys = metadata
|
|
return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))
|
|
|
|
|
|
def _deque_flatten(deq: deque[T], /) -> tuple[deque[T], int | None]:
|
|
return deq, deq.maxlen
|
|
|
|
|
|
def _deque_unflatten(maxlen: int | None, children: Iterable[T], /) -> deque[T]:
|
|
return deque(children, maxlen=maxlen)
|
|
|
|
|
|
def _namedtuple_flatten(tup: NamedTuple[T], /) -> tuple[tuple[T, ...], type[NamedTuple[T]]]: # type: ignore[type-arg]
|
|
return tup, type(tup)
|
|
|
|
|
|
# pylint: disable-next=line-too-long
|
|
def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T], /) -> NamedTuple[T]: # type: ignore[type-arg]
|
|
return cls(*children) # type: ignore[call-overload]
|
|
|
|
|
|
def _structseq_flatten(seq: structseq[T], /) -> tuple[tuple[T, ...], type[structseq[T]]]:
|
|
return seq, type(seq)
|
|
|
|
|
|
def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T], /) -> structseq[T]:
|
|
return cls(children)
|
|
|
|
|
|
_NODETYPE_REGISTRY: dict[type | tuple[str, type], PyTreeNodeRegistryEntry] = {
|
|
type(None): PyTreeNodeRegistryEntry(
|
|
type(None), # type: ignore[arg-type]
|
|
_none_flatten,
|
|
_none_unflatten,
|
|
path_entry_type=PyTreeEntry,
|
|
kind=PyTreeKind.NONE,
|
|
),
|
|
tuple: PyTreeNodeRegistryEntry(
|
|
tuple,
|
|
_tuple_flatten,
|
|
_tuple_unflatten,
|
|
path_entry_type=SequenceEntry,
|
|
kind=PyTreeKind.TUPLE,
|
|
),
|
|
list: PyTreeNodeRegistryEntry(
|
|
list,
|
|
_list_flatten,
|
|
_list_unflatten,
|
|
path_entry_type=SequenceEntry,
|
|
kind=PyTreeKind.LIST,
|
|
),
|
|
dict: PyTreeNodeRegistryEntry(
|
|
dict,
|
|
_dict_flatten,
|
|
_dict_unflatten,
|
|
path_entry_type=MappingEntry,
|
|
kind=PyTreeKind.DICT,
|
|
),
|
|
namedtuple: PyTreeNodeRegistryEntry( # type: ignore[dict-item] # noqa: PYI024
|
|
namedtuple, # type: ignore[arg-type] # noqa: PYI024
|
|
_namedtuple_flatten,
|
|
_namedtuple_unflatten,
|
|
path_entry_type=NamedTupleEntry,
|
|
kind=PyTreeKind.NAMEDTUPLE,
|
|
),
|
|
OrderedDict: PyTreeNodeRegistryEntry(
|
|
OrderedDict,
|
|
_ordereddict_flatten,
|
|
_ordereddict_unflatten,
|
|
path_entry_type=MappingEntry,
|
|
kind=PyTreeKind.ORDEREDDICT,
|
|
),
|
|
defaultdict: PyTreeNodeRegistryEntry(
|
|
defaultdict,
|
|
_defaultdict_flatten,
|
|
_defaultdict_unflatten,
|
|
path_entry_type=MappingEntry,
|
|
kind=PyTreeKind.DEFAULTDICT,
|
|
),
|
|
deque: PyTreeNodeRegistryEntry(
|
|
deque,
|
|
_deque_flatten,
|
|
_deque_unflatten,
|
|
path_entry_type=SequenceEntry,
|
|
kind=PyTreeKind.DEQUE,
|
|
),
|
|
structseq: PyTreeNodeRegistryEntry(
|
|
structseq,
|
|
_structseq_flatten,
|
|
_structseq_unflatten,
|
|
path_entry_type=StructSequenceEntry,
|
|
kind=PyTreeKind.STRUCTSEQUENCE,
|
|
),
|
|
}
|
|
|
|
|
|
_DICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
|
|
dict,
|
|
_dict_insertion_ordered_flatten,
|
|
_dict_insertion_ordered_unflatten,
|
|
path_entry_type=MappingEntry,
|
|
kind=PyTreeKind.DICT,
|
|
)
|
|
_DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
|
|
defaultdict,
|
|
_defaultdict_insertion_ordered_flatten,
|
|
_defaultdict_insertion_ordered_unflatten,
|
|
path_entry_type=MappingEntry,
|
|
kind=PyTreeKind.DEFAULTDICT,
|
|
)
|