checkpoint.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Serialize and deserialize trees."""
  15. import dataclasses
  16. import io
  17. import types
  18. from typing import Any, BinaryIO, Optional, TypeVar
  19. import numpy as np
  20. _T = TypeVar("_T")
  21. def dump(dest: BinaryIO, value: Any) -> None:
  22. """Dump a tree of dicts/dataclasses to a file object.
  23. Args:
  24. dest: a file object to write to.
  25. value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
  26. other basic types. Unions are not supported, other than Optional/None
  27. which is only supported in dataclasses, not in dicts, lists or tuples.
  28. All leaves must be coercible to a numpy array, and recoverable as a single
  29. arg to a type.
  30. """
  31. buffer = io.BytesIO() # In case the destination doesn't support seeking.
  32. np.savez(buffer, **_flatten(value))
  33. dest.write(buffer.getvalue())
  34. def load(source: BinaryIO, typ: type[_T]) -> _T:
  35. """Load from a file object and convert it to the specified type.
  36. Args:
  37. source: a file object to read from.
  38. typ: a type object that acts as a schema for deserialization. It must match
  39. what was serialized. If a type is Any, it will be returned however numpy
  40. serialized it, which is what you want for a tree of numpy arrays.
  41. Returns:
  42. the deserialized value as the specified type.
  43. """
  44. return _convert_types(typ, _unflatten(np.load(source)))
  45. _SEP = ":"
  46. def _flatten(tree: Any) -> dict[str, Any]:
  47. """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict."""
  48. if dataclasses.is_dataclass(tree):
  49. # Don't use dataclasses.asdict as it is recursive so skips dropping None.
  50. tree = {f.name: v for f in dataclasses.fields(tree)
  51. if (v := getattr(tree, f.name)) is not None}
  52. elif isinstance(tree, (list, tuple)):
  53. tree = dict(enumerate(tree))
  54. assert isinstance(tree, dict)
  55. flat = {}
  56. for k, v in tree.items():
  57. k = str(k)
  58. assert _SEP not in k
  59. if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
  60. for a, b in _flatten(v).items():
  61. flat[f"{k}{_SEP}{a}"] = b
  62. else:
  63. assert v is not None
  64. flat[k] = v
  65. return flat
  66. def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
  67. """Unflatten a dict to a tree of dicts."""
  68. tree = {}
  69. for flat_key, v in flat.items():
  70. node = tree
  71. keys = flat_key.split(_SEP)
  72. for k in keys[:-1]:
  73. if k not in node:
  74. node[k] = {}
  75. node = node[k]
  76. node[keys[-1]] = v
  77. return tree
  78. def _convert_types(typ: type[_T], value: Any) -> _T:
  79. """Convert some structure into the given type. The structures must match."""
  80. if typ in (Any, ...):
  81. return value
  82. if typ in (int, float, str, bool):
  83. return typ(value)
  84. if typ is np.ndarray:
  85. assert isinstance(value, np.ndarray)
  86. return value
  87. if dataclasses.is_dataclass(typ):
  88. kwargs = {}
  89. for f in dataclasses.fields(typ):
  90. # Only support Optional for dataclasses, as numpy can't serialize it
  91. # directly (without pickle), and dataclasses are the only case where we
  92. # can know the full set of values and types and therefore know the
  93. # non-existence must mean None.
  94. if isinstance(f.type, (types.UnionType, type(Optional[int]))):
  95. constructors = [t for t in f.type.__args__ if t is not types.NoneType]
  96. if len(constructors) != 1:
  97. raise TypeError(
  98. "Optional works, Union with anything except None doesn't")
  99. if f.name not in value:
  100. kwargs[f.name] = None
  101. continue
  102. constructor = constructors[0]
  103. else:
  104. constructor = f.type
  105. if f.name in value:
  106. kwargs[f.name] = _convert_types(constructor, value[f.name])
  107. else:
  108. raise ValueError(f"Missing value: {f.name}")
  109. return typ(**kwargs)
  110. base_type = getattr(typ, "__origin__", None)
  111. if base_type is dict:
  112. assert len(typ.__args__) == 2
  113. key_type, value_type = typ.__args__
  114. return {_convert_types(key_type, k): _convert_types(value_type, v)
  115. for k, v in value.items()}
  116. if base_type is list:
  117. assert len(typ.__args__) == 1
  118. value_type = typ.__args__[0]
  119. return [_convert_types(value_type, v)
  120. for _, v in sorted(value.items(), key=lambda x: int(x[0]))]
  121. if base_type is tuple:
  122. if len(typ.__args__) == 2 and typ.__args__[1] == ...:
  123. # An arbitrary length tuple of a single type, eg: tuple[int, ...]
  124. value_type = typ.__args__[0]
  125. return tuple(_convert_types(value_type, v)
  126. for _, v in sorted(value.items(), key=lambda x: int(x[0])))
  127. else:
  128. # A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
  129. assert len(typ.__args__) == len(value)
  130. return tuple(
  131. _convert_types(t, v)
  132. for t, (_, v) in zip(
  133. typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))
  134. # This is probably unreachable with reasonable serializable inputs.
  135. try:
  136. return typ(value)
  137. except TypeError as e:
  138. raise TypeError(
  139. "_convert_types expects the type argument to be a dataclass defined "
  140. "with types that are valid constructors (eg tuple is fine, Tuple "
  141. "isn't), and accept a numpy array as the sole argument.") from e