464 lines
18 KiB
Python
464 lines
18 KiB
Python
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""PyTree integration with :mod:`dataclasses`.
|
|
|
|
This module implements PyTree integration with :mod:`dataclasses` by redefining the :func:`field`,
:func:`dataclass`, and :func:`make_dataclass` functions. Other APIs are re-exported from the
original :mod:`dataclasses` module.

The PyTree integration allows dataclasses to be flattened and unflattened recursively. The fields
are stored in a special attribute named ``__optree_dataclass_fields__`` in the dataclass.

>>> import math
... import optree
...
>>> @optree.dataclasses.dataclass(namespace='my_module')
... class Point:
...     x: float
...     y: float
...     z: float = 0.0
...     norm: float = optree.dataclasses.field(init=False, pytree_node=False)
...
...     def __post_init__(self) -> None:
...         self.norm = math.hypot(self.x, self.y, self.z)
...
>>> point = Point(2.0, 6.0, 3.0)
>>> point
Point(x=2.0, y=6.0, z=3.0, norm=7.0)
>>> # Flatten without specifying the namespace
>>> optree.tree_flatten(point)  # `Point`s are leaf nodes
([Point(x=2.0, y=6.0, z=3.0, norm=7.0)], PyTreeSpec(*))
>>> # Flatten with the namespace
>>> accessors, leaves, treespec = optree.tree_flatten_with_accessor(point, namespace='my_module')
>>> accessors, leaves, treespec  # doctest: +IGNORE_WHITESPACE,ELLIPSIS
(
    [
        PyTreeAccessor(*.x, (DataclassEntry(field='x', type=<class '...Point'>),)),
        PyTreeAccessor(*.y, (DataclassEntry(field='y', type=<class '...Point'>),)),
        PyTreeAccessor(*.z, (DataclassEntry(field='z', type=<class '...Point'>),))
    ],
    [2.0, 6.0, 3.0],
    PyTreeSpec(CustomTreeNode(Point[()], [*, *, *]), namespace='my_module')
)
>>> point == optree.tree_unflatten(treespec, leaves)
True
|
|
"""
|
|
|
|
# pylint: disable=too-many-arguments
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import dataclasses
|
|
import inspect
|
|
import sys
|
|
import types
|
|
from dataclasses import * # noqa: F401,F403,RUF100 # pylint: disable=wildcard-import,unused-wildcard-import
|
|
from typing import Any, Callable, Iterable, Literal, TypeVar, overload
|
|
from typing_extensions import dataclass_transform # Python 3.11+
|
|
|
|
|
|
# Redefine `field`, `dataclass`, and `make_dataclass`.
# The remaining APIs are re-exported from the original package.
|
|
# Mirror the public API of the standard `dataclasses` module: the redefined
# `field`, `dataclass`, and `make_dataclass` keep their original names.
__all__ = [*dataclasses.__all__]


# Name of the class attribute that stores the `(children_fields, metadata_fields)` pair.
_FIELDS = '__optree_dataclass_fields__'
# Value used for `pytree_node` when a field does not specify it.
_PYTREE_NODE_DEFAULT: bool = True


_T = TypeVar('_T')
_U = TypeVar('_U')
_TypeT = TypeVar('_TypeT', bound=type)
|
|
|
|
|
|
@overload  # type: ignore[no-redef]
def field(
    *,
    default: _T,
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    hash: bool | None = None,  # pylint: disable=redefined-builtin
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,  # type: ignore[valid-type] # Python 3.10+
    pytree_node: bool | None = None,
) -> _T: ...


@overload
def field(
    *,
    default_factory: Callable[[], _T],
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    hash: bool | None = None,  # pylint: disable=redefined-builtin
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,  # type: ignore[valid-type] # Python 3.10+
    pytree_node: bool | None = None,
) -> _T: ...


@overload
def field(
    *,
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    hash: bool | None = None,  # pylint: disable=redefined-builtin
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,  # type: ignore[valid-type] # Python 3.10+
    pytree_node: bool | None = None,
) -> Any: ...


def field(  # noqa: D417 # pylint: disable=function-redefined
    *,
    default: Any = dataclasses.MISSING,
    default_factory: Any = dataclasses.MISSING,
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    hash: bool | None = None,  # pylint: disable=redefined-builtin
    compare: bool = True,
    metadata: dict[Any, Any] | None = None,
    kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING,  # type: ignore[valid-type] # Python 3.10+
    pytree_node: bool | None = None,
) -> Any:
    """Field factory for :func:`dataclass`.

    This factory function is used to define the fields in a dataclass. It is similar to the field
    factory :func:`dataclasses.field`, but with an additional ``pytree_node`` parameter. If
    ``pytree_node`` is :data:`True` (default), the field will be considered a child node in the
    PyTree structure which can be recursively flattened and unflattened. Otherwise, the field will
    be considered as PyTree metadata.

    Setting ``pytree_node`` in the field factory is equivalent to setting a key ``'pytree_node'`` in
    ``metadata`` in the original field factory. The ``pytree_node`` value can be accessed using
    ``field.metadata['pytree_node']``. If ``pytree_node`` is :data:`None`, the value
    ``metadata.get('pytree_node', True)`` will be used.

    .. note::
        If a field is considered a child node, it must be included in the argument list of the
        :meth:`__init__` method, i.e., passes ``init=True`` in the field factory.

    Args:
        pytree_node (bool or None, optional): Whether the field is a PyTree node.
        **kwargs (optional): Optional keyword arguments passed to :func:`dataclasses.field`
            (``default``, ``default_factory``, ``init``, ``repr``, ``hash``, ``compare``,
            ``metadata``, and, on Python 3.10+, ``kw_only``).

    Returns:
        dataclasses.Field: The field defined using the provided arguments with
        ``field.metadata['pytree_node']`` set.

    Raises:
        TypeError: If ``kw_only`` is given on Python older than 3.10.
        TypeError: If the field is a PyTree node (``pytree_node`` resolves to a true value) but
            is excluded from ``__init__()`` (``init=False``).
    """
    # Work on a copy so that the mapping passed by the caller is never mutated.
    metadata = (metadata or {}).copy()
    if pytree_node is None:
        pytree_node = metadata.get('pytree_node', _PYTREE_NODE_DEFAULT)
    # An explicit `pytree_node` argument takes precedence over the value in `metadata`.
    metadata['pytree_node'] = pytree_node

    kwargs = {
        'default': default,
        'default_factory': default_factory,
        'init': init,
        'repr': repr,
        'hash': hash,
        'compare': compare,
        'metadata': metadata,
    }

    # Only forward `kw_only` on Python versions whose `dataclasses.field()` accepts it.
    if sys.version_info >= (3, 10):  # pragma: >=3.10 cover
        kwargs['kw_only'] = kw_only
    elif kw_only is not dataclasses.MISSING:  # pragma: <3.10 cover
        raise TypeError("field() got an unexpected keyword argument 'kw_only'")

    if not init and pytree_node:
        raise TypeError(
            '`pytree_node=True` is not allowed for non-init fields. '
            f'Please explicitly set `{__name__}.field(init=False, pytree_node=False)`.',
        )

    return dataclasses.field(**kwargs)  # pylint: disable=invalid-field-call
|
|
|
|
|
|
@overload  # type: ignore[no-redef]
def dataclass(
    *,
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,  # Python 3.10+
    kw_only: bool = False,  # Python 3.10+
    slots: bool = False,  # Python 3.10+
    weakref_slot: bool = False,  # Python 3.11+
    namespace: str,
) -> Callable[[_TypeT], _TypeT]: ...


@overload
def dataclass(
    cls: _TypeT,
    /,
    *,
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,  # Python 3.10+
    kw_only: bool = False,  # Python 3.10+
    slots: bool = False,  # Python 3.10+
    weakref_slot: bool = False,  # Python 3.11+
    namespace: str,
) -> _TypeT: ...


@dataclass_transform(field_specifiers=(field,))
def dataclass(  # noqa: C901,D417 # pylint: disable=function-redefined,too-many-locals,too-many-branches
    cls: _TypeT | None = None,
    /,
    *,
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,  # Python 3.10+
    kw_only: bool = False,  # Python 3.10+
    slots: bool = False,  # Python 3.10+
    weakref_slot: bool = False,  # Python 3.11+
    namespace: str,
) -> _TypeT | Callable[[_TypeT], _TypeT]:
    """Dataclass decorator with PyTree integration.

    Fields whose ``metadata['pytree_node']`` is true (the default) become the children of the
    PyTree node. The remaining fields that are accepted by ``__init__()`` are kept as PyTree
    metadata. Fields with ``init=False`` and ``pytree_node=False`` are not part of either group.

    Args:
        cls (type or None, optional): The class to decorate. If :data:`None`, return a decorator.
        namespace (str): The registry namespace used for the PyTree registration.
        **kwargs (optional): Optional keyword arguments passed to :func:`dataclasses.dataclass`.

    Returns:
        type or callable: The decorated class with PyTree integration or decorator function.

    Raises:
        TypeError: If a keyword argument is not supported by the running Python version
            (``match_args``, ``kw_only``, ``slots`` before 3.10; ``weakref_slot`` before 3.11).
        TypeError: If ``cls`` is not a class, if the decorator is applied to the same class more
            than once, or if ``namespace`` is not a string.
        TypeError: If a PyTree node field is excluded from ``__init__()`` (``init=False``).
        ValueError: If ``namespace`` is an empty string.
    """
    # pylint: disable-next=import-outside-toplevel
    from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

    kwargs = {
        'init': init,
        'repr': repr,
        'eq': eq,
        'order': order,
        'unsafe_hash': unsafe_hash,
        'frozen': frozen,
    }

    # Only forward the keyword arguments supported by the running Python version. On older
    # versions, a non-default value is rejected the same way an unknown keyword would be.
    if sys.version_info >= (3, 10):  # pragma: >=3.10 cover
        kwargs['match_args'] = match_args
        kwargs['kw_only'] = kw_only
        kwargs['slots'] = slots
    elif match_args is not True:  # pragma: <3.10 cover
        raise TypeError("dataclass() got an unexpected keyword argument 'match_args'")
    elif kw_only is not False:  # pragma: <3.10 cover
        raise TypeError("dataclass() got an unexpected keyword argument 'kw_only'")
    elif slots is not False:  # pragma: <3.10 cover
        raise TypeError("dataclass() got an unexpected keyword argument 'slots'")

    if sys.version_info >= (3, 11):  # pragma: >=3.11 cover
        kwargs['weakref_slot'] = weakref_slot
    elif weakref_slot is not False:  # pragma: <3.11 cover
        raise TypeError("dataclass() got an unexpected keyword argument 'weakref_slot'")

    if cls is None:
        # Used as `@dataclass(namespace=...)`: defer the work until the class is available.

        def decorator(cls: _TypeT) -> _TypeT:
            return dataclass(cls, namespace=namespace, **kwargs)  # type: ignore[call-overload]

        return decorator

    if not inspect.isclass(cls):
        raise TypeError(f'@{__name__}.dataclass() can only be used with classes, not {cls!r}.')
    # Check the class's own `__dict__` so that subclasses of a decorated class can be decorated.
    if _FIELDS in cls.__dict__:
        raise TypeError(
            f'@{__name__}.dataclass() cannot be applied to {cls.__name__} more than once.',
        )
    if namespace is not GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Rebind `cls`: `dataclasses.dataclass()` may return a new class object (e.g., `slots=True`).
    cls = dataclasses.dataclass(cls, **kwargs)  # type: ignore[assignment]

    # Split the fields into PyTree children and PyTree metadata.
    children_fields = {}
    metadata_fields = {}
    for f in dataclasses.fields(cls):
        if f.metadata.get('pytree_node', _PYTREE_NODE_DEFAULT):
            if not f.init:
                raise TypeError(
                    f'PyTree node field {f.name!r} must be included in `__init__()`. '
                    f'Or you can explicitly set `{__name__}.field(init=False, pytree_node=False)`.',
                )
            children_fields[f.name] = f
        elif f.init:
            # Only init fields can be passed back to the constructor on unflatten.
            metadata_fields[f.name] = f

    children_field_names = tuple(children_fields)
    # Expose read-only views of the field mappings on the class.
    children_fields = types.MappingProxyType(children_fields)
    metadata_fields = types.MappingProxyType(metadata_fields)
    setattr(cls, _FIELDS, (children_fields, metadata_fields))

    def flatten_func(
        obj: _T,
        /,
    ) -> tuple[
        tuple[_U, ...],
        tuple[tuple[str, Any], ...],
        tuple[str, ...],
    ]:
        # Return `(children, metadata, entries)` where the entries are the children field names.
        children = tuple(getattr(obj, name) for name in children_field_names)
        metadata = tuple((name, getattr(obj, name)) for name in metadata_fields)
        return children, metadata, children_field_names

    # pylint: disable-next=line-too-long
    def unflatten_func(metadata: tuple[tuple[str, Any], ...], children: tuple[_U, ...], /) -> _T:  # type: ignore[type-var]
        # Rebuild the instance by passing children and metadata as keyword arguments.
        kwargs = dict(zip(children_field_names, children))
        kwargs.update(metadata)
        return cls(**kwargs)

    from optree.accessor import DataclassEntry  # pylint: disable=import-outside-toplevel
    from optree.registry import register_pytree_node  # pylint: disable=import-outside-toplevel

    # The return value of `register_pytree_node()` is returned as the decorated class.
    return register_pytree_node(  # type: ignore[return-value]
        cls,
        flatten_func,
        unflatten_func,  # type: ignore[arg-type]
        path_entry_type=DataclassEntry,
        namespace=namespace,
    )
|
|
|
|
|
|
# pylint: disable-next=function-redefined,too-many-locals,too-many-branches
def make_dataclass(  # type: ignore[no-redef] # noqa: C901,D417
    cls_name: str,
    # pylint: disable-next=redefined-outer-name
    fields: Iterable[str | tuple[str, Any] | tuple[str, Any, Any]],
    *,
    bases: tuple[type, ...] = (),
    ns: dict[str, Any] | None = None,  # redirect to `namespace` to `dataclasses.make_dataclass()`
    init: bool = True,
    repr: bool = True,  # pylint: disable=redefined-builtin
    eq: bool = True,
    order: bool = False,
    unsafe_hash: bool = False,
    frozen: bool = False,
    match_args: bool = True,  # Python 3.10+
    kw_only: bool = False,  # Python 3.10+
    slots: bool = False,  # Python 3.10+
    weakref_slot: bool = False,  # Python 3.11+
    module: str | None = None,  # Python 3.12+
    namespace: str,  # the PyTree registration namespace
) -> type:
    """Make a new dynamically created dataclass with PyTree integration.

    The dataclass name will be ``cls_name``. ``fields`` is an iterable of either (name), (name, type),
    or (name, type, Field) objects. If type is omitted, use the string :data:`typing.Any`. Field
    objects are created by the equivalent of calling :func:`field` (name, type [, Field-info]).

    The ``namespace`` parameter is the PyTree registration namespace which should be a string. The
    ``namespace`` in the original :func:`dataclasses.make_dataclass` function is renamed to ``ns``
    to avoid conflicts.

    The remaining parameters are passed to :func:`dataclasses.make_dataclass`.
    See :func:`dataclasses.make_dataclass` for more information.

    Args:
        cls_name: The name of the dataclass.
        fields (Iterable[str | tuple[str, Any] | tuple[str, Any, Any]]): An iterable of either
            (name), (name, type), or (name, type, Field) objects.
        namespace (str): The registry namespace used for the PyTree registration.
        ns (dict or None, optional): The namespace used in dynamic type creation.
            See :func:`dataclasses.make_dataclass` and the builtin :func:`type` function for more
            information.
        **kwargs (optional): Optional keyword arguments passed to :func:`dataclasses.make_dataclass`.

    Returns:
        type: The dynamically created dataclass with PyTree integration.

    Raises:
        TypeError: If ``namespace`` is not a string, or if a keyword argument is not supported by
            the running Python version (``match_args``, ``kw_only``, ``slots`` before 3.10;
            ``weakref_slot`` before 3.11; ``module`` before 3.12).
        TypeError: If ``namespace`` is a dict or :data:`None` and ``ns`` is :data:`None`.
        ValueError: If ``namespace`` is an empty string.
    """
    # pylint: disable-next=import-outside-toplevel
    from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

    # Tolerate the argument convention of `dataclasses.make_dataclass()`, where `namespace` is
    # the class dict: if the two arguments look swapped, swap them back.
    if isinstance(namespace, dict) or namespace is None:  # type: ignore[unreachable]
        if ns is GLOBAL_NAMESPACE or isinstance(ns, str):  # type: ignore[unreachable]
            ns, namespace = namespace, ns
        elif ns is None:
            raise TypeError("make_dataclass() missing 1 required keyword-only argument: 'ns'")
    if namespace is not GLOBAL_NAMESPACE and not isinstance(namespace, str):
        raise TypeError(f'The namespace must be a string, got {namespace!r}.')
    if namespace == '':
        raise ValueError('The namespace cannot be an empty string.')

    # Keyword arguments shared by `dataclasses.make_dataclass()` and `dataclass()`.
    dataclass_kwargs = {
        'init': init,
        'repr': repr,
        'eq': eq,
        'order': order,
        'unsafe_hash': unsafe_hash,
        'frozen': frozen,
    }
    # Keyword arguments only understood by `dataclasses.make_dataclass()`.
    make_dataclass_kwargs = {
        'bases': bases,
        'namespace': ns,
    }

    # Only forward the keyword arguments supported by the running Python version. On older
    # versions, a non-default value is rejected the same way an unknown keyword would be.
    if sys.version_info >= (3, 10):  # pragma: >=3.10 cover
        dataclass_kwargs['match_args'] = match_args
        dataclass_kwargs['kw_only'] = kw_only
        dataclass_kwargs['slots'] = slots
    elif match_args is not True:  # pragma: <3.10 cover
        raise TypeError("make_dataclass() got an unexpected keyword argument 'match_args'")
    elif kw_only is not False:  # pragma: <3.10 cover
        raise TypeError("make_dataclass() got an unexpected keyword argument 'kw_only'")
    elif slots is not False:  # pragma: <3.10 cover
        raise TypeError("make_dataclass() got an unexpected keyword argument 'slots'")

    if sys.version_info >= (3, 11):  # pragma: >=3.11 cover
        dataclass_kwargs['weakref_slot'] = weakref_slot
    elif weakref_slot is not False:  # pragma: <3.11 cover
        raise TypeError("make_dataclass() got an unexpected keyword argument 'weakref_slot'")

    if sys.version_info >= (3, 12):  # pragma: >=3.12 cover
        if module is None:
            # Resolve the caller's module name here, using frame depth 1 (the caller of this
            # function), and pass it explicitly to `dataclasses.make_dataclass()`.
            try:
                # pylint: disable-next=protected-access
                module = sys._getframemodulename(1) or '__main__'  # type: ignore[attr-defined]
            except AttributeError:  # pragma: no cover
                # Best effort: leave `module` as `None` if the frame cannot be inspected.
                with contextlib.suppress(AttributeError, ValueError):
                    # pylint: disable-next=protected-access
                    module = sys._getframe(1).f_globals.get('__name__', '__main__')
        make_dataclass_kwargs['module'] = module
    elif module is not None:  # pragma: <3.12 cover
        raise TypeError("make_dataclass() got an unexpected keyword argument 'module'")

    cls = dataclasses.make_dataclass(
        cls_name,
        fields=fields,
        **dataclass_kwargs,  # type: ignore[arg-type]
        **make_dataclass_kwargs,  # type: ignore[arg-type]
    )
    dataclass_kwargs.pop('slots', None)  # already defined in `make_dataclass()`
    dataclass_kwargs.pop('weakref_slot', None)  # already used in `make_dataclass()`
    # Apply the PyTree-aware decorator to register the freshly created class.
    return dataclass(cls, **dataclass_kwargs, namespace=namespace)  # type: ignore[call-overload]
|