2025-04-18 16:57:38 +00:00

886 lines
30 KiB
Python

# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Registry for custom pytree node types."""
# pylint: disable=too-many-lines
from __future__ import annotations
import contextlib
import dataclasses
import inspect
import sys
from collections import OrderedDict, defaultdict, deque, namedtuple
from operator import itemgetter, methodcaller
from threading import Lock
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Collection,
Generator,
Generic,
Iterable,
NamedTuple,
TypeVar,
overload,
)
import optree._C as _C
from optree.accessor import (
AutoEntry,
MappingEntry,
NamedTupleEntry,
PyTreeEntry,
SequenceEntry,
StructSequenceEntry,
)
from optree.typing import (
CustomTreeNode,
PyTreeKind,
T,
is_namedtuple_class,
is_structseq_class,
structseq,
)
from optree.utils import safe_zip, total_order_sorted, unzip2
if TYPE_CHECKING:
import builtins
from optree.typing import KT, VT, FlattenFunc, UnflattenFunc
# Names exported by ``from optree.registry import *``. ``PyTreeNodeRegistryEntry`` and the
# ``_``-prefixed helpers defined in this module are not listed here.
__all__ = [
    'register_pytree_node',
    'register_pytree_node_class',
    'unregister_pytree_node',
    'dict_insertion_ordered',
]
# ``dataclasses.dataclass(slots=...)`` is only accepted on Python 3.10+, so the keyword is
# passed conditionally.
SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {}  # Python 3.10+


@dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, **SLOTS)
class PyTreeNodeRegistryEntry(Generic[T]):
    """A dataclass that stores the information of a pytree node type.

    Instances are immutable (``frozen=True``). ``type``, ``flatten_func`` and ``unflatten_func``
    may be given positionally; on Python 3.10+ the remaining fields are keyword-only.
    """

    # The Python type handled by this entry.
    type: builtins.type[Collection[T]]
    # Splits an instance into (children, metadata[, path entries]).
    flatten_func: FlattenFunc[T]
    # Rebuilds an instance from (metadata, children).
    unflatten_func: UnflattenFunc[T]
    if sys.version_info >= (3, 10):  # pragma: >=3.10 cover
        # Pseudo-field: every field declared after it becomes keyword-only.
        _: dataclasses.KW_ONLY  # Python 3.10+
    # The path-entry class used to address the children of this node.
    path_entry_type: builtins.type[PyTreeEntry] = AutoEntry
    # The node kind; entries created by ``register_pytree_node`` keep ``PyTreeKind.CUSTOM``.
    kind: PyTreeKind = PyTreeKind.CUSTOM
    # The registry namespace; ``''`` denotes the global namespace.
    namespace: str = ''


del SLOTS
# Lock held while the Python-side registry and the matching ``_C`` state are updated or
# snapshotted.
__REGISTRY_LOCK: Lock = Lock()


# pylint: disable-next=too-few-public-methods
class GlobalNamespace:  # pragma: no cover
    """Private sentinel type whose single instance stands for the global namespace."""

    __slots__: ClassVar[tuple[()]] = ()

    def __repr__(self, /) -> str:
        return '<GLOBAL NAMESPACE>'


# Typed as ``str`` so the sentinel can be passed wherever a namespace string is accepted.
__GLOBAL_NAMESPACE: str = GlobalNamespace()  # type: ignore[assignment]
# Keep the helper class name out of the module namespace.
del GlobalNamespace
if TYPE_CHECKING:
    # Static-typing-only helpers. This branch never executes at runtime, so
    # ``typing_extensions`` is not needed as a runtime dependency.
    from typing_extensions import ParamSpec  # Python 3.10+

    _P = ParamSpec('_P')
    _T = TypeVar('_T')
    _GetP = ParamSpec('_GetP')
    _GetT = TypeVar('_GetT')

    class _CallableWithGet(Generic[_P, _T, _GetP, _GetT]):
        """Static type of a callable ``(*_P) -> _T`` that also has a ``.get(*_GetP) -> _GetT``."""

        def __call__(self, /, *args: _P.args, **kwargs: _P.kwargs) -> _T:
            raise NotImplementedError

        # pylint: disable-next=missing-function-docstring
        def get(self, /, *args: _GetP.args, **kwargs: _GetP.kwargs) -> _GetT:
            raise NotImplementedError
def _add_get(
    get: Callable[_GetP, _GetT],
    /,
) -> Callable[
    [Callable[_P, _T]],
    _CallableWithGet[_P, _T, _GetP, _GetT],
]:
    """Build a decorator that exposes ``get`` as the ``.get`` attribute of the decorated function.

    No wrapper is created: the decorated function object itself is returned, with only its static
    type widened to ``_CallableWithGet``.
    """

    def attach(func: Callable[_P, _T], /) -> _CallableWithGet[_P, _T, _GetP, _GetT]:
        setattr(func, 'get', get)  # noqa: B010
        return func  # type: ignore[return-value]

    return attach
@overload
def pytree_node_registry_get(
    cls: type,
    /,
    *,
    namespace: str = '',
) -> PyTreeNodeRegistryEntry | None: ...


@overload
def pytree_node_registry_get(
    cls: None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry]: ...


# pylint: disable-next=too-many-return-statements,too-many-branches
def pytree_node_registry_get(  # noqa: C901
    cls: type | None = None,
    /,
    *,
    namespace: str = '',
) -> dict[type, PyTreeNodeRegistryEntry] | PyTreeNodeRegistryEntry | None:
    """Look up the pytree node registry.

    >>> register_pytree_node.get()  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    {
        <class 'NoneType'>: PyTreeNodeRegistryEntry(
            type=<class 'NoneType'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.accessor.PyTreeEntry'>,
            kind=<PyTreeKind.NONE: 2>,
            namespace=''
        ),
        <class 'tuple'>: PyTreeNodeRegistryEntry(
            type=<class 'tuple'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.accessor.SequenceEntry'>,
            kind=<PyTreeKind.TUPLE: 3>,
            namespace=''
        ),
        <class 'list'>: PyTreeNodeRegistryEntry(
            type=<class 'list'>,
            flatten_func=<function ...>,
            unflatten_func=<function ...>,
            path_entry_type=<class 'optree.accessor.SequenceEntry'>,
            kind=<PyTreeKind.LIST: 4>,
            namespace=''
        ),
        ...
    }
    >>> register_pytree_node.get(defaultdict)  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
    PyTreeNodeRegistryEntry(
        type=<class 'collections.defaultdict'>,
        flatten_func=<function ...>,
        unflatten_func=<function ...>,
        path_entry_type=<class 'optree.accessor.MappingEntry'>,
        kind=<PyTreeKind.DEFAULTDICT: 8>,
        namespace=''
    )
    >>> print(register_pytree_node.get(frozenset))  # frozenset is considered as a leaf node
    None

    When ``cls`` is given, its entry is resolved in this order:

    1. the entry registered for ``cls`` in ``namespace`` (only for a non-empty ``namespace``);
    2. the insertion-ordered ``dict`` / ``defaultdict`` entry, if insertion-ordered mode is
       enabled for ``namespace``;
    3. the entry registered for ``cls`` in the global namespace;
    4. the generic structseq or namedtuple entry, if ``cls`` is such a class.

    Args:
        cls (type or None, optional): The class of the pytree node to retrieve. If not provided, all
            the registered pytree nodes in the namespace (together with those of the global
            namespace) are returned.
        namespace (str, optional): The namespace of the registry to retrieve. If not provided, the
            global namespace is used.

    Returns:
        If the ``cls`` is not provided, a dictionary of all the registered pytree nodes in the
        namespace is returned. If the ``cls`` is provided, the corresponding registry entry is
        returned if the ``cls`` is registered as a pytree node. Otherwise, :data:`None` is returned,
        i.e., the ``cls`` is represented as a leaf node.

    Raises:
        TypeError: If ``cls`` is neither a class nor :data:`None`.
        TypeError: If the namespace is not a string.
    """
    # The global-namespace sentinel is stored under the empty-string name.
    if namespace is __GLOBAL_NAMESPACE:
        namespace = ''

    # ``namedtuple`` is a factory function rather than a class, but it keys the generic
    # namedtuple entry, so it is accepted here.
    if (
        cls is not None
        and cls is not namedtuple  # noqa: PYI024
        and not inspect.isclass(cls)
    ):
        raise TypeError(f'Expected a class or None, got {cls!r}.')  # pragma: !=3.9 cover
    if not isinstance(namespace, str):
        raise TypeError(  # pragma: !=3.9 cover
            f'The namespace must be a string, got {namespace!r}.',
        )

    if cls is None:
        # Snapshot the global entries plus the entries of ``namespace`` under the lock.
        namespaces = frozenset({namespace, ''})
        with __REGISTRY_LOCK:
            registry = {
                handler.type: handler
                for handler in _NODETYPE_REGISTRY.values()
                if handler.namespace in namespaces
            }
        # In insertion-ordered mode, report the non-sorting dict/defaultdict handlers instead.
        if _C.is_dict_insertion_ordered(namespace):
            registry[dict] = _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
            registry[defaultdict] = _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY
        return registry

    # Namespaced registrations are keyed by ``(namespace, cls)`` and take precedence.
    if namespace != '':
        handler = _NODETYPE_REGISTRY.get((namespace, cls))
        if handler is not None:
            return handler

    if _C.is_dict_insertion_ordered(namespace):
        if cls is dict:
            return _DICT_INSERTION_ORDERED_REGISTRY_ENTRY
        if cls is defaultdict:
            return _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY

    # Global registrations (including the built-in ones) are keyed by the bare class.
    handler = _NODETYPE_REGISTRY.get(cls)
    if handler is not None:
        return handler
    # structseq / namedtuple classes without their own entry fall back to the generic entries.
    if is_structseq_class(cls):
        return _NODETYPE_REGISTRY.get(structseq)
    if is_namedtuple_class(cls):
        return _NODETYPE_REGISTRY.get(namedtuple)  # type: ignore[call-overload] # noqa: PYI024
    return None
@_add_get(pytree_node_registry_get)
def register_pytree_node(
    cls: type[Collection[T]],
    /,
    flatten_func: FlattenFunc[T],
    unflatten_func: UnflattenFunc[T],
    *,
    path_entry_type: type[PyTreeEntry] = AutoEntry,
    namespace: str,
) -> type[Collection[T]]:
    """Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node_class` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type): A Python type to treat as an internal pytree node.
        flatten_func (callable): A function to be used during flattening, taking an instance of ``cls``
            and returning a triple or optionally a pair, with (1) an iterable for the children to be
            flattened recursively, and (2) some hashable metadata to be stored in the treespec and
            to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path
            entries to the corresponding children. If the entries are not provided or given by
            :data:`None`, then ``range(len(children))`` will be used.
        unflatten_func (callable): A function taking two arguments: the metadata that was returned
            by ``flatten_func`` and stored in the treespec, and the unflattened children. The
            function should return an instance of ``cls``.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: :class:`AutoEntry`)
        namespace (str): A non-empty string that uniquely identifies the namespace of the type registry.
            This is used to isolate the registry from other modules that might register a different
            custom behavior for the same type.

    Returns:
        The same type as the input ``cls``.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is already registered in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='set',
        ... )
        <class 'set'>

        >>> # Register a Python type into a namespace
        >>> import torch
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=lambda tensor: (
        ...         (tensor.cpu().detach().numpy(),),
        ...         {'dtype': tensor.dtype, 'device': tensor.device, 'requires_grad': tensor.requires_grad},
        ...     ),
        ...     unflatten_func=lambda metadata, children: torch.tensor(children[0], **metadata),
        ...     namespace='torch2numpy',
        ... )
        <class 'torch.Tensor'>

        >>> # doctest: +SKIP
        >>> tree = {'weight': torch.ones(size=(1, 2)).cuda(), 'bias': torch.zeros(size=(2,))}
        >>> tree
        {'weight': tensor([[1., 1.]], device='cuda:0'), 'bias': tensor([0., 0.])}

        >>> # Flatten without specifying the namespace
        >>> tree_flatten(tree)  # `torch.Tensor`s are leaf nodes
        ([tensor([0., 0.]), tensor([[1., 1.]], device='cuda:0')], PyTreeSpec({'bias': *, 'weight': *}))

        >>> # Flatten with the namespace
        >>> tree_flatten(tree, namespace='torch2numpy')
        (
            [array([0., 0.], dtype=float32), array([[1., 1.]], dtype=float32)],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cpu'), 'requires_grad': False}], [*]),
                    'weight': CustomTreeNode(Tensor[{'dtype': torch.float32, 'device': device(type='cuda', index=0), 'requires_grad': False}], [*])
                },
                namespace='torch2numpy'
            )
        )

        >>> # Register the same type with a different namespace for different behaviors
        >>> def tensor2flatparam(tensor):
        ...     return [torch.nn.Parameter(tensor.reshape(-1))], tensor.shape, None
        ...
        >>> def flatparam2tensor(metadata, children):
        ...     return children[0].reshape(metadata)
        ...
        >>> register_pytree_node(
        ...     torch.Tensor,
        ...     flatten_func=tensor2flatparam,
        ...     unflatten_func=flatparam2tensor,
        ...     namespace='tensor2flatparam',
        ... )
        <class 'torch.Tensor'>

        >>> # Flatten with the new namespace
        >>> tree_flatten(tree, namespace='tensor2flatparam')
        (
            [
                Parameter containing: tensor([0., 0.], requires_grad=True),
                Parameter containing: tensor([1., 1.], device='cuda:0', requires_grad=True)
            ],
            PyTreeSpec(
                {
                    'bias': CustomTreeNode(Tensor[torch.Size([2])], [*]),
                    'weight': CustomTreeNode(Tensor[torch.Size([1, 2])], [*])
                },
                namespace='tensor2flatparam'
            )
        )
    """  # pylint: disable=line-too-long
    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
        raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
    # Only the private global-namespace sentinel may bypass the string check.
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Global registrations are keyed by the bare class, namespaced ones by ``(namespace, cls)``;
    # ``pytree_node_registry_get`` looks entries up with the same two key shapes.
    registration_key: type | tuple[str, type]
    if namespace is __GLOBAL_NAMESPACE:
        registration_key = cls
        namespace = ''
    else:
        registration_key = (namespace, cls)

    with __REGISTRY_LOCK:
        # Register with the C extension first: if it raises (e.g. a duplicate registration), the
        # Python-side mirror below is left untouched.
        _C.register_node(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type,
            namespace,
        )
        _NODETYPE_REGISTRY[registration_key] = PyTreeNodeRegistryEntry(
            cls,
            flatten_func,
            unflatten_func,
            path_entry_type=path_entry_type,
            namespace=namespace,
        )
    return cls
# ``pytree_node_registry_get`` remains reachable as ``register_pytree_node.get`` (attached by the
# ``_add_get`` decorator), so both helper names are removed from the module namespace.
del pytree_node_registry_get, _add_get

if TYPE_CHECKING:
    # Type variable for classes bounded by ``type[CustomTreeNode]``; it is used only in the
    # signatures of ``register_pytree_node_class``.
    # pylint: disable-next=invalid-name
    CustomTreeNodeType = TypeVar('CustomTreeNodeType', bound=type[CustomTreeNode])
@overload
def register_pytree_node_class(
    cls: str | None = None,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str | None = None,
) -> Callable[[CustomTreeNodeType], CustomTreeNodeType]: ...


@overload
def register_pytree_node_class(
    cls: CustomTreeNodeType,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None,
    namespace: str,
) -> CustomTreeNodeType: ...


def register_pytree_node_class(  # noqa: C901
    cls: CustomTreeNodeType | str | None = None,
    /,
    *,
    path_entry_type: type[PyTreeEntry] | None = None,
    namespace: str | None = None,
) -> CustomTreeNodeType | Callable[[CustomTreeNodeType], CustomTreeNodeType]:
    """Extend the set of types that are considered internal nodes in pytrees.

    See also :func:`register_pytree_node` and :func:`unregister_pytree_node`.

    The ``namespace`` argument is used to avoid collisions that occur when different libraries
    register the same Python type with different behaviors. It is recommended to add a unique prefix
    to the namespace to avoid conflicts with other libraries. Namespaces can also be used to specify
    the same class in different namespaces for different use cases.

    .. warning::
        For safety reasons, a ``namespace`` must be specified while registering a custom type. It is
        used to isolate the behavior of flattening and unflattening a pytree node type. This is to
        prevent accidental collisions between different libraries that may register the same type.

    Args:
        cls (type, optional): A Python type to treat as an internal pytree node. A string passed
            here is interpreted as the ``namespace`` and a decorator is returned.
        path_entry_type (type, optional): The type of the path entry to be used in the treespec.
            (default: ``cls.TREE_PATH_ENTRY_TYPE`` if the class defines it, otherwise
            :class:`AutoEntry`)
        namespace (str, optional): A non-empty string that uniquely identifies the namespace of the
            type registry. This is used to isolate the registry from other modules that might
            register a different custom behavior for the same type.

    Returns:
        The same type as the input ``cls`` if the argument presents. Otherwise, return a decorator
        function that registers the class as a pytree node.

    Raises:
        ValueError: If ``namespace`` is given while the first argument is a string.
        ValueError: If no namespace is specified.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        TypeError: If the input type is not a class.
        TypeError: If the path entry class is not a subclass of :class:`PyTreeEntry`.
        ValueError: If the type is already registered in the registry.

    This function is a thin wrapper around :func:`register_pytree_node`, and provides a
    class-oriented interface::

        @register_pytree_node_class(namespace='foo')
        class Special:
            TREE_PATH_ENTRY_TYPE = GetAttrEntry

            def __init__(self, x, y):
                self.x = x
                self.y = y

            def tree_flatten(self):
                return ((self.x, self.y), None, ('x', 'y'))

            @classmethod
            def tree_unflatten(cls, metadata, children):
                return cls(*children)


        @register_pytree_node_class('mylist')
        class MyList(UserList):
            TREE_PATH_ENTRY_TYPE = SequenceEntry

            def tree_flatten(self):
                return self.data, None, None

            @classmethod
            def tree_unflatten(cls, metadata, children):
                # ``UserList`` takes a single iterable, so the children are not unpacked.
                return cls(children)
    """
    # ``@register_pytree_node_class('ns')`` form: the first positional argument is the namespace.
    if cls is __GLOBAL_NAMESPACE or isinstance(cls, str):
        if namespace is not None:
            raise ValueError('Cannot specify `namespace` when the first argument is a string.')
        if cls == '':
            raise ValueError('The namespace cannot be an empty string.')
        cls, namespace = None, cls

    if namespace is None:
        raise ValueError('Must specify `namespace` when the first argument is a class.')
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    if cls is None:
        # Decorator-factory form: defer registration until the class is supplied.
        def decorator(cls: CustomTreeNodeType, /) -> CustomTreeNodeType:
            return register_pytree_node_class(
                cls,
                path_entry_type=path_entry_type,
                namespace=namespace,
            )

        return decorator

    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    # An explicit argument wins over the class-level ``TREE_PATH_ENTRY_TYPE`` attribute.
    if path_entry_type is None:
        path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry)
    if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
        raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
    # ``methodcaller`` invokes the instance's own ``tree_flatten``, so subclass overrides apply.
    register_pytree_node(
        cls,
        methodcaller('tree_flatten'),
        cls.tree_unflatten,
        path_entry_type=path_entry_type,
        namespace=namespace,
    )
    return cls
def unregister_pytree_node(cls: type, /, *, namespace: str) -> PyTreeNodeRegistryEntry:
    """Remove a type from the pytree node registry.

    See also :func:`register_pytree_node` and :func:`register_pytree_node_class`.

    This function is the inverse operation of function :func:`register_pytree_node`.

    Args:
        cls (type): A Python type to remove from the pytree node registry.
        namespace (str): The namespace of the pytree node registry to remove the type from.

    Returns:
        The removed registry entry.

    Raises:
        TypeError: If the input type is not a class.
        TypeError: If the namespace is not a string.
        ValueError: If the namespace is an empty string.
        ValueError: If the type is a built-in type that cannot be unregistered.
        ValueError: If the type is not found in the registry.

    Examples:
        >>> # Register a Python type with lambda functions
        >>> register_pytree_node(
        ...     set,
        ...     lambda s: (sorted(s), None, None),
        ...     lambda _, children: set(children),
        ...     namespace='temp',
        ... )
        <class 'set'>

        >>> # Unregister the Python type; the removed entry is returned
        >>> unregister_pytree_node(set, namespace='temp')  # doctest: +ELLIPSIS
        PyTreeNodeRegistryEntry(type=<class 'set'>, ..., namespace='temp')
    """
    if not inspect.isclass(cls):
        raise TypeError(f'Expected a class, got {cls!r}.')
    # Only the private global-namespace sentinel may bypass the string check.
    if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Same key shapes as ``register_pytree_node``: bare class for global, tuple otherwise.
    registration_key: type | tuple[str, type]
    if namespace is __GLOBAL_NAMESPACE:
        registration_key = cls
        namespace = ''
    else:
        registration_key = (namespace, cls)

    with __REGISTRY_LOCK:
        # The C extension validates the request first; the Python mirror is popped only after it
        # succeeds.
        _C.unregister_node(cls, namespace)
        return _NODETYPE_REGISTRY.pop(registration_key)
@contextlib.contextmanager
def dict_insertion_ordered(mode: bool, /, *, namespace: str) -> Generator[None]:
    """Context manager to temporarily set the dictionary sorting mode.

    Within the ``with`` block, flattening a pytree in ``namespace`` either sorts dictionary keys
    (``mode=False``) or keeps them in insertion order (``mode=True``). The namespace's previous
    setting is restored on exit, even if the block raises.

    >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
    >>> tree_flatten(tree)  # doctest: +IGNORE_WHITESPACE
    (
        [1, 2, 3, 4, 5],
        PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
    )
    >>> with dict_insertion_ordered(True, namespace='some-namespace'):  # doctest: +IGNORE_WHITESPACE
    ...     tree_flatten(tree, namespace='some-namespace')
    (
        [2, 3, 4, 1, 5],
        PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
    )

    .. warning::
        The dictionary sorting mode is a global setting and is **not thread-safe**. It is
        recommended to use this context manager in a single-threaded environment.

    Args:
        mode (bool): The dictionary sorting mode to set.
        namespace (str): The namespace to set the dictionary sorting mode for.
    """
    # Normalize the global-namespace sentinel first; any other value must be a non-empty string.
    if namespace is __GLOBAL_NAMESPACE:
        namespace = ''
    elif not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    elif namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    with __REGISTRY_LOCK:
        # NOTE(review): ``inherit_global_namespace=False`` presumably reads the namespace's own
        # flag rather than the effective (inherited) one, so exactly that flag is restored below.
        saved_mode = _C.is_dict_insertion_ordered(namespace, inherit_global_namespace=False)
        _C.set_dict_insertion_ordered(bool(mode), namespace)

    try:
        yield
    finally:
        with __REGISTRY_LOCK:
            _C.set_dict_insertion_ordered(saved_mode, namespace)
def _sorted_items(items: Iterable[tuple[KT, VT]], /) -> list[tuple[KT, VT]]:
    """Return the ``(key, value)`` pairs ordered by key via :func:`total_order_sorted`."""
    by_key = itemgetter(0)
    return total_order_sorted(items, key=by_key)
def _none_flatten(_: None, /) -> tuple[tuple[()], None]:
    """Flatten ``None``: there are no children and no metadata."""
    no_children: tuple[()] = ()
    return no_children, None


def _none_unflatten(_: None, children: Iterable[Any], /) -> None:
    """Rebuild ``None``, rejecting any children (at most one child is consumed to check)."""
    if any(True for _child in children):
        raise ValueError('Expected no children.')
def _tuple_flatten(tup: tuple[T, ...], /) -> tuple[tuple[T, ...], None]:
    """Flatten a ``tuple``: the tuple itself is the children sequence; there is no metadata."""
    metadata = None
    return tup, metadata


def _tuple_unflatten(_: None, children: Iterable[T], /) -> tuple[T, ...]:
    """Rebuild a ``tuple`` from its children, ignoring the (absent) metadata."""
    return tuple(children)


def _list_flatten(lst: list[T], /) -> tuple[list[T], None]:
    """Flatten a ``list``: the list itself is the children sequence; there is no metadata."""
    metadata = None
    return lst, metadata


def _list_unflatten(_: None, children: Iterable[T], /) -> list[T]:
    """Rebuild a new ``list`` from its children, ignoring the (absent) metadata."""
    return [*children]
def _dict_flatten(dct: dict[KT, VT], /) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
    """Flatten a ``dict`` with its keys in sorted order.

    Returns the values (children), the keys as a list (metadata consumed by
    :func:`_dict_unflatten`) and the keys as a tuple (path entries).
    """
    sorted_keys, sorted_values = unzip2(_sorted_items(dct.items()))
    return sorted_values, [*sorted_keys], sorted_keys


def _dict_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a ``dict`` by pairing ``keys`` with ``values`` via ``safe_zip``."""
    pairs = safe_zip(keys, values)
    return dict(pairs)
def _dict_insertion_ordered_flatten(
    dct: dict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten a ``dict`` keeping its keys in insertion order (no sorting).

    Returns the values (children), the keys as a list (metadata) and the keys as a tuple (path
    entries).
    """
    ordered_keys, ordered_values = unzip2(dct.items())
    return ordered_values, [*ordered_keys], ordered_keys


def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]:
    """Rebuild a ``dict`` by pairing ``keys`` with ``values`` via ``safe_zip``."""
    pairs = safe_zip(keys, values)
    return dict(pairs)
def _ordereddict_flatten(
    dct: OrderedDict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    list[KT],
    tuple[KT, ...],
]:
    """Flatten an ``OrderedDict`` keeping its own key order.

    Returns the values (children), the keys as a list (metadata) and the keys as a tuple (path
    entries).
    """
    ordered_keys, ordered_values = unzip2(dct.items())
    return ordered_values, [*ordered_keys], ordered_keys


def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT], /) -> OrderedDict[KT, VT]:
    """Rebuild an ``OrderedDict`` by pairing ``keys`` with ``values`` via ``safe_zip``."""
    pairs = safe_zip(keys, values)
    return OrderedDict(pairs)
def _defaultdict_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a ``defaultdict`` like a ``dict`` (sorted keys), also carrying its factory.

    The metadata is ``(default_factory, keys)``; ``default_factory`` may be :data:`None`.
    """
    values, keys, entries = _dict_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_unflatten(
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a ``defaultdict`` from ``(default_factory, keys)`` metadata and the values.

    ``default_factory`` may be :data:`None`, matching what :func:`_defaultdict_flatten` records.
    """
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_unflatten(keys, values))
def _defaultdict_insertion_ordered_flatten(
    dct: defaultdict[KT, VT],
    /,
) -> tuple[
    tuple[VT, ...],
    tuple[Callable[[], VT] | None, list[KT]],
    tuple[KT, ...],
]:
    """Flatten a ``defaultdict`` in insertion order, also carrying its factory.

    The metadata is ``(default_factory, keys)``; ``default_factory`` may be :data:`None`.
    """
    values, keys, entries = _dict_insertion_ordered_flatten(dct)
    return values, (dct.default_factory, keys), entries


def _defaultdict_insertion_ordered_unflatten(
    metadata: tuple[Callable[[], VT] | None, list[KT]],
    values: Iterable[VT],
    /,
) -> defaultdict[KT, VT]:
    """Rebuild a ``defaultdict`` from ``(default_factory, keys)`` metadata and the values.

    ``default_factory`` may be :data:`None`, matching what the flatten function records.
    """
    default_factory, keys = metadata
    return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))
def _deque_flatten(deq: deque[T], /) -> tuple[deque[T], int | None]:
    """Flatten a ``deque``: the deque is the children sequence; its ``maxlen`` is the metadata."""
    capacity = deq.maxlen
    return deq, capacity


def _deque_unflatten(maxlen: int | None, children: Iterable[T], /) -> deque[T]:
    """Rebuild a ``deque`` with the recorded ``maxlen`` (``None`` means unbounded)."""
    rebuilt = deque(children, maxlen)
    return rebuilt
def _namedtuple_flatten(tup: NamedTuple[T], /) -> tuple[tuple[T, ...], type[NamedTuple[T]]]:  # type: ignore[type-arg]
    """Flatten a namedtuple: the tuple is the children sequence; its class is the metadata."""
    tuple_cls = type(tup)
    return tup, tuple_cls


# pylint: disable-next=line-too-long
def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T], /) -> NamedTuple[T]:  # type: ignore[type-arg]
    """Rebuild a namedtuple by passing the children positionally to its class."""
    fields = tuple(children)
    return cls(*fields)  # type: ignore[call-overload]
def _structseq_flatten(seq: structseq[T], /) -> tuple[tuple[T, ...], type[structseq[T]]]:
    """Flatten a structseq: the sequence is the children; its class is the metadata."""
    seq_cls = type(seq)
    return seq, seq_cls


def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T], /) -> structseq[T]:
    """Rebuild a structseq by passing the children as a single sequence to its class."""
    rebuilt = cls(children)
    return rebuilt
# Built-in node types, all in the global namespace and therefore keyed by their bare type
# (``namedtuple`` and ``structseq`` key the generic fallback entries). Custom registrations add
# ``cls`` or ``(namespace, cls)`` keys to this dict at runtime.
_NODETYPE_REGISTRY: dict[type | tuple[str, type], PyTreeNodeRegistryEntry] = {
    entry.type: entry
    for entry in (
        PyTreeNodeRegistryEntry(
            type(None),  # type: ignore[arg-type]
            _none_flatten,
            _none_unflatten,
            path_entry_type=PyTreeEntry,
            kind=PyTreeKind.NONE,
        ),
        PyTreeNodeRegistryEntry(
            tuple,
            _tuple_flatten,
            _tuple_unflatten,
            path_entry_type=SequenceEntry,
            kind=PyTreeKind.TUPLE,
        ),
        PyTreeNodeRegistryEntry(
            list,
            _list_flatten,
            _list_unflatten,
            path_entry_type=SequenceEntry,
            kind=PyTreeKind.LIST,
        ),
        PyTreeNodeRegistryEntry(
            dict,
            _dict_flatten,
            _dict_unflatten,
            path_entry_type=MappingEntry,
            kind=PyTreeKind.DICT,
        ),
        PyTreeNodeRegistryEntry(
            namedtuple,  # type: ignore[arg-type] # noqa: PYI024
            _namedtuple_flatten,
            _namedtuple_unflatten,
            path_entry_type=NamedTupleEntry,
            kind=PyTreeKind.NAMEDTUPLE,
        ),
        PyTreeNodeRegistryEntry(
            OrderedDict,
            _ordereddict_flatten,
            _ordereddict_unflatten,
            path_entry_type=MappingEntry,
            kind=PyTreeKind.ORDEREDDICT,
        ),
        PyTreeNodeRegistryEntry(
            defaultdict,
            _defaultdict_flatten,
            _defaultdict_unflatten,
            path_entry_type=MappingEntry,
            kind=PyTreeKind.DEFAULTDICT,
        ),
        PyTreeNodeRegistryEntry(
            deque,
            _deque_flatten,
            _deque_unflatten,
            path_entry_type=SequenceEntry,
            kind=PyTreeKind.DEQUE,
        ),
        PyTreeNodeRegistryEntry(
            structseq,
            _structseq_flatten,
            _structseq_unflatten,
            path_entry_type=StructSequenceEntry,
            kind=PyTreeKind.STRUCTSEQUENCE,
        ),
    )
}
# Alternative ``dict`` / ``defaultdict`` handlers that keep insertion order instead of sorting keys.
# They are not stored in ``_NODETYPE_REGISTRY``; ``pytree_node_registry_get`` returns them when
# ``_C.is_dict_insertion_ordered`` reports the mode as enabled.
_DICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    type=dict,
    flatten_func=_dict_insertion_ordered_flatten,
    unflatten_func=_dict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DICT,
)
_DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry(
    type=defaultdict,
    flatten_func=_defaultdict_insertion_ordered_flatten,
    unflatten_func=_defaultdict_insertion_ordered_unflatten,
    path_entry_type=MappingEntry,
    kind=PyTreeKind.DEFAULTDICT,
)