xarray_tree.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Utilities for working with trees of xarray.DataArray (including Datasets).
  15. Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
  16. it won't work as a leaf node since it implements Mapping, but also won't work
  17. as an internal node since tree doesn't know how to re-create it properly.
  18. To fix this, we reimplement a subset of `map_structure`, exposing its
  19. constituent DataArrays as leaf nodes. This means it can be mapped over as a
  20. generic container of DataArrays, while still preserving the result as a Dataset
  21. where possible.
  22. This is useful because in a few places we need to handle a general
  23. Mapping[str, DataArray] (where the coordinates might not be compatible across
  24. the constituent DataArrays) but also the special case of a Dataset nicely.
  25. For example, in the result of `xarray_tree.map_structure(fn, dataset)`, if fn returns None for
  26. some of the child DataArrays, they will be omitted from the returned dataset. If
  27. any values other than DataArrays or None are returned, then we don't attempt to
  28. return a Dataset and just return a plain dict of the results. Similarly if
  29. DataArrays are returned but with non-matching coordinates, it will just return a
  30. plain dict of DataArrays.
  31. Note that xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
  32. but `jax.tree_util.tree_map` is distinct from `xarray_tree.map_structure`,
  33. as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
  34. latter exposes DataArrays as leaf nodes.
  35. """
  36. from typing import Any, Callable
  37. import xarray
  38. def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
  39. """Maps func through given structures with xarrays. See tree.map_structure."""
  40. if not callable(func):
  41. raise TypeError(f'func must be callable, got: {func}')
  42. if not structures:
  43. raise ValueError('Must provide at least one structure')
  44. first = structures[0]
  45. if isinstance(first, xarray.Dataset):
  46. data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
  47. if all(isinstance(a, (type(None), xarray.DataArray))
  48. for a in data.values()):
  49. data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
  50. try:
  51. return xarray.merge(data_arrays, join='exact')
  52. except ValueError: # Exact join not possible.
  53. pass
  54. return data
  55. if isinstance(first, dict):
  56. return {k: map_structure(func, *[s[k] for s in structures])
  57. for k in first.keys()}
  58. if isinstance(first, (list, tuple, set)):
  59. return type(first)(map_structure(func, *s) for s in zip(*structures))
  60. return func(*structures)