sequence.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains _sequence_like and helpers for sequence data structures."""
  16. import collections
  17. from collections import abc as collections_abc
  18. import types
  19. from tree import _tree
  20. # pylint: disable=g-import-not-at-top
  21. try:
  22. import wrapt
  23. ObjectProxy = wrapt.ObjectProxy
  24. except ImportError:
  25. class ObjectProxy(object):
  26. """Stub-class for `wrapt.ObjectProxy``."""
  27. def _sorted(dictionary):
  28. """Returns a sorted list of the dict keys, with error if keys not sortable."""
  29. try:
  30. return sorted(dictionary)
  31. except TypeError:
  32. raise TypeError("tree only supports dicts with sortable keys.")
  33. def _is_attrs(instance):
  34. return _tree.is_attrs(instance)
  35. def _is_namedtuple(instance, strict=False):
  36. """Returns True iff `instance` is a `namedtuple`.
  37. Args:
  38. instance: An instance of a Python object.
  39. strict: If True, `instance` is considered to be a `namedtuple` only if
  40. it is a "plain" namedtuple. For instance, a class inheriting
  41. from a `namedtuple` will be considered to be a `namedtuple`
  42. iff `strict=False`.
  43. Returns:
  44. True if `instance` is a `namedtuple`.
  45. """
  46. return _tree.is_namedtuple(instance, strict)
  47. def _sequence_like(instance, args):
  48. """Converts the sequence `args` to the same type as `instance`.
  49. Args:
  50. instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or
  51. `collections.OrderedDict`.
  52. args: elements to be converted to the `instance` type.
  53. Returns:
  54. `args` with the type of `instance`.
  55. """
  56. if isinstance(instance, (dict, collections_abc.Mapping)):
  57. # Pack dictionaries in a deterministic order by sorting the keys.
  58. # Notice this means that we ignore the original order of `OrderedDict`
  59. # instances. This is intentional, to avoid potential bugs caused by mixing
  60. # ordered and plain dicts (e.g., flattening a dict but using a
  61. # corresponding `OrderedDict` to pack it back).
  62. result = dict(zip(_sorted(instance), args))
  63. keys_and_values = ((key, result[key]) for key in instance)
  64. if isinstance(instance, collections.defaultdict):
  65. # `defaultdict` requires a default factory as the first argument.
  66. return type(instance)(instance.default_factory, keys_and_values)
  67. elif isinstance(instance, types.MappingProxyType):
  68. # MappingProxyType requires a dict to proxy to.
  69. return type(instance)(dict(keys_and_values))
  70. else:
  71. return type(instance)(keys_and_values)
  72. elif isinstance(instance, collections_abc.MappingView):
  73. # We can't directly construct mapping views, so we create a list instead
  74. return list(args)
  75. elif _is_namedtuple(instance) or _is_attrs(instance):
  76. if isinstance(instance, ObjectProxy):
  77. instance_type = type(instance.__wrapped__)
  78. else:
  79. instance_type = type(instance)
  80. try:
  81. if _is_attrs(instance):
  82. return instance_type(
  83. **{
  84. attr.name: arg
  85. for attr, arg in zip(instance_type.__attrs_attrs__, args)
  86. })
  87. else:
  88. return instance_type(*args)
  89. except Exception as e:
  90. raise TypeError(
  91. f"Couldn't traverse {instance!r} with arguments {args}") from e
  92. elif isinstance(instance, ObjectProxy):
  93. # For object proxies, first create the underlying type and then re-wrap it
  94. # in the proxy type.
  95. return type(instance)(_sequence_like(instance.__wrapped__, args))
  96. else:
  97. # Not a namedtuple
  98. return type(instance)(args)