2025-04-18 16:57:38 +00:00

112 lines
4.0 KiB
Python

# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains _sequence_like and helpers for sequence data structures."""
import collections
from collections import abc as collections_abc
import types
from tree import _tree
# pylint: disable=g-import-not-at-top
try:
  import wrapt
  # Expose the proxy base class under a module-level name so the rest of this
  # module can refer to `ObjectProxy` whether or not `wrapt` is installed.
  ObjectProxy = wrapt.ObjectProxy
except ImportError:

  class ObjectProxy:
    """Placeholder standing in for `wrapt.ObjectProxy` when `wrapt` is missing.

    The class is intentionally empty: it only gives `isinstance(x, ObjectProxy)`
    checks in this module something to test against.
    """
def _sorted(dictionary):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
return sorted(dictionary)
except TypeError:
raise TypeError("tree only supports dicts with sortable keys.")
def _is_attrs(instance):
  """Returns the result of `_tree.is_attrs` for `instance`.

  This is a thin wrapper; the actual check is implemented in the `_tree`
  extension module.

  NOTE(review): presumably True exactly for instances of `attrs`-decorated
  classes (types exposing `__attrs_attrs__`) -- confirm against `_tree`.

  Args:
    instance: An instance of a Python object.

  Returns:
    Whatever `_tree.is_attrs(instance)` returns.
  """
  return _tree.is_attrs(instance)
def _is_namedtuple(instance, strict=False):
  """Reports whether `instance` is a `namedtuple`.

  The check itself is delegated to `_tree.is_namedtuple`.

  Args:
    instance: An instance of a Python object.
    strict: When True, only a "plain" namedtuple counts. An instance of a
      class that merely inherits from a namedtuple is accepted only when
      `strict=False`.

  Returns:
    True if `instance` is a `namedtuple` under the chosen strictness.
  """
  return _tree.is_namedtuple(instance, strict)
def _sequence_like(instance, args):
  """Packs `args` into a new container of the same type as `instance`.

  Args:
    instance: The structure whose type should be reproduced. Handled cases
      are mappings (including `collections.defaultdict` and
      `types.MappingProxyType`), mapping views, namedtuples, `attrs`
      instances, `ObjectProxy` wrappers, and any other type whose constructor
      accepts a single iterable (e.g. `tuple`, `list`).
    args: The elements to place in the new container. For mappings they must
      be ordered to match the *sorted* keys of `instance`.

  Returns:
    `args` packed into a container of `instance`'s type. Mapping views are
    the exception: they cannot be constructed directly, so a plain `list` is
    returned for them.

  Raises:
    TypeError: If the keys of a mapping are not sortable, or if constructing
      a namedtuple / `attrs` instance from `args` fails.
  """
  if isinstance(instance, (dict, collections_abc.Mapping)):
    # Values are matched to keys in sorted-key order, which makes packing
    # deterministic and independent of insertion order. This is deliberate:
    # it avoids bugs from mixing ordered and plain dicts (e.g. flattening a
    # dict and packing back through a corresponding `OrderedDict`).
    value_by_key = dict(zip(_sorted(instance), args))
    # The rebuilt mapping is then populated following `instance`'s own
    # iteration order.
    ordered_items = ((key, value_by_key[key]) for key in instance)
    mapping_cls = type(instance)
    if isinstance(instance, collections.defaultdict):
      # `defaultdict` takes its default factory as the first argument.
      return mapping_cls(instance.default_factory, ordered_items)
    if isinstance(instance, types.MappingProxyType):
      # `MappingProxyType` must wrap an actual mapping, not an iterable.
      return mapping_cls(dict(ordered_items))
    return mapping_cls(ordered_items)

  if isinstance(instance, collections_abc.MappingView):
    # Mapping views cannot be constructed directly; fall back to a list.
    return list(args)

  if _is_namedtuple(instance) or _is_attrs(instance):
    # Build the underlying type, not the proxy, when the instance is wrapped.
    unwrapped = (
        instance.__wrapped__ if isinstance(instance, ObjectProxy) else instance)
    target_cls = type(unwrapped)
    try:
      if _is_attrs(instance):
        # `attrs` classes are constructed by keyword, pairing each declared
        # attribute name with the corresponding element of `args`.
        field_names = (field.name for field in target_cls.__attrs_attrs__)
        return target_cls(**dict(zip(field_names, args)))
      return target_cls(*args)
    except Exception as error:
      raise TypeError(
          f"Couldn't traverse {instance!r} with arguments {args}") from error

  if isinstance(instance, ObjectProxy):
    # For other proxied objects, rebuild the wrapped value first and then
    # re-wrap it in the same proxy type.
    return type(instance)(_sequence_like(instance.__wrapped__, args))

  # Everything else (e.g. `tuple`, `list`): the constructor takes an iterable.
  return type(instance)(args)