123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- # Copyright 2023 DeepMind Technologies Limited.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS-IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Serialize and deserialize trees."""
- import dataclasses
- import io
- import types
- from typing import Any, BinaryIO, Optional, TypeVar
- import numpy as np
- _T = TypeVar("_T")
def dump(dest: BinaryIO, value: Any) -> None:
    """Serialize a tree of dicts/dataclasses and write it to a file object.

    The tree is first flattened into a single-level mapping, saved with
    `np.savez` into an in-memory buffer, and the buffer's bytes are then
    written to `dest` in one call.

    Args:
      dest: a file object to write to.
      value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
        other basic types. Unions are not supported, other than Optional/None
        which is only supported in dataclasses, not in dicts, lists or tuples.
        All leaves must be coercible to a numpy array, and recoverable as a
        single arg to a type.
    """
    flat_entries = _flatten(value)
    # Stage the archive in memory in case the destination doesn't support
    # seeking.
    staging = io.BytesIO()
    np.savez(staging, **flat_entries)
    dest.write(staging.getvalue())
def load(source: BinaryIO, typ: type[_T]) -> _T:
    """Read a serialized tree from a file object and rebuild it as `typ`.

    Args:
      source: a file object to read from.
      typ: a type object that acts as a schema for deserialization. It must
        match what was serialized. If a type is Any, it will be returned
        however numpy serialized it, which is what you want for a tree of
        numpy arrays.

    Returns:
      the deserialized value as the specified type.
    """
    archive = np.load(source)
    # Rebuild the nested dict structure from the flat archive keys, then let
    # the schema drive conversion into the requested types.
    nested = _unflatten(archive)
    return _convert_types(typ, nested)
# Separator placed between nested keys when a tree is flattened into a single
# dict, and split on when unflattening. Individual keys must not contain it.
_SEP = ":"
def _flatten(tree: Any) -> dict[str, Any]:
    """Collapse a tree of dicts/dataclasses/lists/tuples into one flat dict.

    Nested keys are stringified and joined with `_SEP`. Dataclass fields whose
    value is None are omitted; lists and tuples are keyed by element index.

    Args:
      tree: a dict, dataclass, list or tuple whose children are either further
        containers of those kinds or non-None leaves.

    Returns:
      A single-level dict mapping joined key paths to leaf values.
    """
    if dataclasses.is_dataclass(tree):
        # Don't use dataclasses.asdict: it recurses, which would prevent
        # dropping None-valued fields at each level.
        children = {}
        for field in dataclasses.fields(tree):
            attr = getattr(tree, field.name)
            if attr is not None:
                children[field.name] = attr
    elif isinstance(tree, (list, tuple)):
        children = {index: item for index, item in enumerate(tree)}
    else:
        children = tree
    assert isinstance(children, dict)

    result = {}
    for key, child in children.items():
        name = str(key)
        assert _SEP not in name
        is_container = (
            dataclasses.is_dataclass(child)
            or isinstance(child, (dict, list, tuple)))
        if not is_container:
            # None leaves are only allowed as dataclass fields, where they
            # have already been dropped above.
            assert child is not None
            result[name] = child
            continue
        for sub_name, leaf in _flatten(child).items():
            result[f"{name}{_SEP}{sub_name}"] = leaf
    return result
def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
    """Rebuild a tree of nested dicts from a flat dict of joined key paths.

    Args:
      flat: a mapping whose keys are path components joined with `_SEP`.

    Returns:
      A nested dict in which each path component is one level of nesting and
      the final component maps to the original value.
    """
    root = {}
    for path, leaf in flat.items():
        *parents, last = path.split(_SEP)
        cursor = root
        # Walk down, creating intermediate dicts as needed.
        for part in parents:
            if part not in cursor:
                cursor[part] = {}
            cursor = cursor[part]
        cursor[last] = leaf
    return root
- def _convert_types(typ: type[_T], value: Any) -> _T:
- """Convert some structure into the given type. The structures must match."""
- if typ in (Any, ...):
- return value
- if typ in (int, float, str, bool):
- return typ(value)
- if typ is np.ndarray:
- assert isinstance(value, np.ndarray)
- return value
- if dataclasses.is_dataclass(typ):
- kwargs = {}
- for f in dataclasses.fields(typ):
- # Only support Optional for dataclasses, as numpy can't serialize it
- # directly (without pickle), and dataclasses are the only case where we
- # can know the full set of values and types and therefore know the
- # non-existence must mean None.
- if isinstance(f.type, (types.UnionType, type(Optional[int]))):
- constructors = [t for t in f.type.__args__ if t is not types.NoneType]
- if len(constructors) != 1:
- raise TypeError(
- "Optional works, Union with anything except None doesn't")
- if f.name not in value:
- kwargs[f.name] = None
- continue
- constructor = constructors[0]
- else:
- constructor = f.type
- if f.name in value:
- kwargs[f.name] = _convert_types(constructor, value[f.name])
- else:
- raise ValueError(f"Missing value: {f.name}")
- return typ(**kwargs)
- base_type = getattr(typ, "__origin__", None)
- if base_type is dict:
- assert len(typ.__args__) == 2
- key_type, value_type = typ.__args__
- return {_convert_types(key_type, k): _convert_types(value_type, v)
- for k, v in value.items()}
- if base_type is list:
- assert len(typ.__args__) == 1
- value_type = typ.__args__[0]
- return [_convert_types(value_type, v)
- for _, v in sorted(value.items(), key=lambda x: int(x[0]))]
- if base_type is tuple:
- if len(typ.__args__) == 2 and typ.__args__[1] == ...:
- # An arbitrary length tuple of a single type, eg: tuple[int, ...]
- value_type = typ.__args__[0]
- return tuple(_convert_types(value_type, v)
- for _, v in sorted(value.items(), key=lambda x: int(x[0])))
- else:
- # A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
- assert len(typ.__args__) == len(value)
- return tuple(
- _convert_types(t, v)
- for t, (_, v) in zip(
- typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))
- # This is probably unreachable with reasonable serializable inputs.
- try:
- return typ(value)
- except TypeError as e:
- raise TypeError(
- "_convert_types expects the type argument to be a dataclass defined "
- "with types that are valid constructors (eg tuple is fine, Tuple "
- "isn't), and accept a numpy array as the sole argument.") from e
|