1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- # Copyright 2023 DeepMind Technologies Limited.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS-IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Utilities for working with trees of xarray.DataArray (including Datasets).
- Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
- it won't work as a leaf node since it implements Mapping, but also won't work
- as an internal node since tree doesn't know how to re-create it properly.
- To fix this, we reimplement a subset of `map_structure`, exposing its
- constituent DataArrays as leaf nodes. This means it can be mapped over as a
- generic container of DataArrays, while still preserving the result as a Dataset
- where possible.
- This is useful because in a few places we need to handle a general
- Mapping[str, DataArray] (where the coordinates might not be compatible across
- the constituent DataArrays) but also the special case of a Dataset nicely.
- For example, in the result of `xarray_tree.map_structure(fn, dataset)`, if fn
- returns None for some of the child DataArrays, they will be omitted from the
- returned Dataset. If any values other than DataArrays or None are returned,
- then we don't attempt to return a Dataset and just return a plain dict of the
- results. Similarly, if DataArrays are returned but with non-matching
- coordinates, it will just return a plain dict of DataArrays.
- Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
- but `jax.tree_util.tree_map` is distinct from `xarray_tree.map_structure`,
- as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
- latter exposes DataArrays as leaf nodes.
- """
- from typing import Any, Callable
- import xarray
def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
  """Maps func through given structures with xarrays. See tree.map_structure.

  Containers are traversed recursively, guided by the first structure, and
  `func` is applied to corresponding leaves. Datasets are treated as
  containers of their DataArrays:

  * `xarray.Dataset`: `func` is called once per key `k` of the first
    structure, with `s[k]` taken from every structure. If every result is a
    DataArray or None, the non-None results are renamed to their key and
    merged back into a Dataset with `join='exact'`. If any result is something
    else, or the merge raises ValueError (e.g. non-matching coordinates), the
    plain dict of results (None values included) is returned instead.
  * `dict`: mapped key-by-key (keys from the first structure) into a new plain
    dict.
  * `list`, `tuple` (including namedtuples) and `set`: mapped element-wise and
    rebuilt with the type of the first structure.
  * Anything else is a leaf, and `func(*structures)` is returned.

  Args:
    func: Callable applied to corresponding leaves of `structures`.
    *structures: One or more structures with matching layout. The first one
      determines how the structures are traversed.

  Returns:
    A structure mirroring the first one, holding the results of `func`.

  Raises:
    TypeError: If `func` is not callable.
    ValueError: If no structures are given.
  """
  if not callable(func):
    raise TypeError(f'func must be callable, got: {func}')
  if not structures:
    raise ValueError('Must provide at least one structure')

  first = structures[0]
  if isinstance(first, xarray.Dataset):
    data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
    if all(isinstance(a, (type(None), xarray.DataArray))
           for a in data.values()):
      # Name each DataArray after its key so merge() recreates the variables.
      data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
      try:
        return xarray.merge(data_arrays, join='exact')
      except ValueError:  # Exact join not possible; fall back to a dict.
        pass
    return data
  if isinstance(first, dict):
    return {k: map_structure(func, *[s[k] for s in structures])
            for k in first.keys()}
  if isinstance(first, (list, tuple, set)):
    # NOTE(review): zip() stops at the shortest structure, so sequences of
    # different lengths are silently truncated rather than rejected. For sets,
    # elements are paired by iteration order, which is arbitrary.
    mapped = [map_structure(func, *s) for s in zip(*structures)]
    if isinstance(first, tuple) and hasattr(first, '_fields'):
      # Namedtuple constructors take one positional argument per field rather
      # than a single iterable, so the results must be unpacked.
      return type(first)(*mapped)
    return type(first)(mapped)
  return func(*structures)
|