|
- # Copyright 2023 DeepMind Technologies Limited.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS-IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Utilities for building models."""
- from typing import Mapping, Optional, Tuple
- import numpy as np
- from scipy.spatial import transform
- import xarray
def get_graph_spatial_features(
    *,
    node_lat: np.ndarray,
    node_lon: np.ndarray,
    senders: np.ndarray,
    receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
    sine_cosine_encoding: bool = False,
    encoding_num_freqs: int = 10,
    encoding_multiplicative_factor: float = 1.2,
    ) -> Tuple[np.ndarray, np.ndarray]:
  """Builds spatial node and edge features for a graph on the sphere.

  Args:
    node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes].
    node_lon: Longitudes in the [0, 360] interval of shape [num_nodes].
    senders: Sender indices of shape [num_edges].
    receivers: Receiver indices of shape [num_edges].
    add_node_positions: Add the unit-norm absolute (x, y, z) positions.
    add_node_latitude: Add a latitude feature, cos(90 - lat). Note that even
      when this is False the model may still be able to infer the latitude
      from the relative features, unless
      `relative_latitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different latitudes.
    add_node_longitude: Add longitude features, (cos(lon), sin(lon)). Note
      that even when this is False the model may still be able to infer the
      longitude from the relative features, unless
      `relative_longitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R^3 (and
      their lengths) as edge features.
    relative_longitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 latitude.
    sine_cosine_encoding: If True, every node/edge feature is expanded into
      sines and cosines at several frequencies, similar to NeRF.
    encoding_num_freqs: Number of frequencies used by the encoding.
    encoding_multiplicative_factor: Base of the geometric progression of
      frequencies, i.e. frequency k is `factor ** k`.

  Returns:
    A `(node_features, edge_features)` tuple with arrays of shape
    [num_nodes, num_features] and [num_edges, num_features].
  """
  num_nodes = node_lat.shape[0]
  num_edges = senders.shape[0]
  feature_dtype = node_lat.dtype
  phi, theta = lat_lon_deg_to_spherical(node_lat, node_lon)

  # Node features: every entry is one column of shape [num_nodes].
  node_columns = []
  if add_node_positions:
    # Cartesian coordinates on the unit sphere, already within [-1, 1].
    node_columns.extend(spherical_to_cartesian(phi, theta))
  if add_node_latitude:
    # cos(theta) goes from 1 at the north pole to -1 at the south pole.
    node_columns.append(np.cos(theta))
  if add_node_longitude:
    # The (cos, sin) pair is already normalized.
    node_columns.append(np.cos(phi))
    node_columns.append(np.sin(phi))

  if node_columns:
    node_features = np.stack(node_columns, axis=-1)
  else:
    node_features = np.zeros([num_nodes, 0], dtype=feature_dtype)

  # Edge features.
  if add_relative_positions:
    offsets = get_relative_position_in_receiver_local_coordinates(
        node_phi=phi,
        node_theta=theta,
        senders=senders,
        receivers=receivers,
        latitude_local_coordinates=relative_latitude_local_coordinates,
        longitude_local_coordinates=relative_longitude_local_coordinates,
    )
    # This is the L2 (chord) distance in 3d space, not the geodesic distance.
    lengths = np.linalg.norm(offsets, axis=-1, keepdims=True)
    # Scale by the longest edge. Every edge is expected to have a counterpart
    # pointing the opposite way, so offsets should be symmetric around zero:
    # after scaling they land in [-1, 1] and the lengths in [0, 1].
    longest = lengths.max()
    edge_features = np.concatenate(
        [lengths / longest, offsets / longest], axis=-1)
  else:
    edge_features = np.zeros([num_edges, 0], dtype=feature_dtype)

  if sine_cosine_encoding:
    frequencies = encoding_multiplicative_factor ** np.arange(
        encoding_num_freqs)

    def _encode(features: np.ndarray) -> np.ndarray:
      # [n, f] -> [n, f, num_freqs] -> [n, f, 2 * num_freqs] -> [n, -1].
      phases = frequencies * features[..., None]
      encoded = np.concatenate([np.sin(phases), np.cos(phases)], axis=-1)
      return encoded.reshape([features.shape[0], -1])

    node_features = _encode(node_features)
    edge_features = _encode(edge_features)

  return node_features, edge_features
def lat_lon_to_leading_axes(
    grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Moves the "lat" and "lon" dimensions to the front of the array.

  The remaining dimensions keep their relative order, i.e.
  leading + ["lat", "lon"] + trailing becomes
  ["lat", "lon"] + leading + trailing.

  Args:
    grid_xarray: Array that has both a "lat" and a "lon" dimension.

  Returns:
    The transposed array.
  """
  spatial_dims = ("lat", "lon")
  # The trailing Ellipsis stands for "all other dims, in their current order".
  return grid_xarray.transpose(*spatial_dims, ...)
def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Moves the batch/time/level dimensions (when present) to the front.

  Turns ["lat", "lon"] + [(batch,) (time,) (level,)] + trailing into
  [(batch,) (time,) (level,)] + ["lat", "lon"] + trailing. All other
  dimensions keep their relative order.

  Args:
    grid_xarray: Array to reorder.

  Returns:
    The transposed array.
  """
  current_dims = list(grid_xarray.dims)
  # Final order of the leading dims, restricted to those actually present.
  front = [key for key in ("batch", "time", "level") if key in current_dims]
  rest = [dim for dim in current_dims if dim not in front]
  return grid_xarray.transpose(*(front + rest))
def lat_lon_deg_to_spherical(node_lat: np.ndarray,
                             node_lon: np.ndarray,
                            ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts latitude/longitude in degrees to spherical angles in radians.

  Args:
    node_lat: Latitudes in degrees.
    node_lon: Longitudes in degrees.

  Returns:
    A `(phi, theta)` tuple where `phi` is the longitude in radians and
    `theta` is the colatitude (90 - lat) in radians.
  """
  colatitude_deg = 90 - node_lat
  return np.deg2rad(node_lon), np.deg2rad(colatitude_deg)
def spherical_to_lat_lon(phi: np.ndarray,
                         theta: np.ndarray,
                        ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts spherical angles in radians to latitude/longitude in degrees.

  Args:
    phi: Longitude angles in radians.
    theta: Colatitude angles in radians.

  Returns:
    A `(lat, lon)` tuple in degrees, with `lat = 90 - theta` and `lon`
    wrapped into [0, 360).
  """
  lat = 90 - np.rad2deg(theta)
  lon_unwrapped = np.rad2deg(phi)
  return lat, np.mod(lon_unwrapped, 360)
def cartesian_to_spherical(x: np.ndarray,
                           y: np.ndarray,
                           z: np.ndarray,
                          ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts Cartesian coordinates on the unit sphere to spherical angles.

  Args:
    x: X coordinates.
    y: Y coordinates.
    z: Z coordinates; the points are assumed to have unit radius.

  Returns:
    A `(phi, theta)` tuple in radians, with `phi = arctan2(y, x)` and
    `theta = arccos(z)`.
  """
  # Silence "invalid value" floating point warnings raised by arccos
  # (circumventing b/253179568).
  with np.errstate(invalid="ignore"):
    theta = np.arccos(z)  # Assuming unit radius.
  return np.arctan2(y, x), theta
def spherical_to_cartesian(
    phi: np.ndarray, theta: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Converts spherical angles to Cartesian coordinates on the unit sphere.

  Args:
    phi: Longitude angles in radians.
    theta: Colatitude angles in radians.

  Returns:
    An `(x, y, z)` tuple of coordinates, assuming unit radius.
  """
  x = np.cos(phi) * np.sin(theta)
  y = np.sin(phi) * np.sin(theta)
  z = np.cos(theta)
  return x, y, z
def get_relative_position_in_receiver_local_coordinates(
    node_phi: np.ndarray,
    node_theta: np.ndarray,
    senders: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool
    ) -> np.ndarray:
  """Returns relative position features for the edges.

  Positions are first rotated into a local coordinate system defined by the
  receiver of each edge; the relative position is then the sender position
  minus the receiver position in that rotated space, still expressed in R^3.

  Args:
    node_phi: [num_nodes] longitude angles in radians.
    node_theta: [num_nodes] colatitude angles in radians.
    senders: [num_edges] with indices.
    receivers: [num_edges] with indices.
    latitude_local_coordinates: Whether to rotate each edge so that its
      receiver sits at latitude 0.
    longitude_local_coordinates: Whether to rotate each edge so that its
      receiver sits at longitude 0.

  Returns:
    Array of relative positions in R^3 of shape [num_edges, 3].
  """
  positions = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1)
  sender_positions = positions[senders]
  receiver_positions = positions[receivers]

  needs_rotation = latitude_local_coordinates or longitude_local_coordinates
  if not needs_rotation:
    # Plain difference in the global frame.
    return sender_positions - receiver_positions

  # One rotation matrix per node, describing that node's local space.
  per_node_rotations = get_rotation_matrices_to_local_coordinates(
      reference_phi=node_phi,
      reference_theta=node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)

  # Every edge uses the rotation of its receiver node.
  per_edge_rotations = per_node_rotations[receivers]

  # For the receivers, rotating every node once and gathering afterwards
  # would be cheaper; gathering first keeps both sides symmetric.
  rotated_senders = rotate_with_matrices(per_edge_rotations, sender_positions)
  rotated_receivers = rotate_with_matrices(
      per_edge_rotations, receiver_positions)

  # Because the rotated space is chosen by the receiver:
  #   * latitude_local_coordinates = True puts receivers at latitude 0, so
  #     their z coordinate is always 0.
  #   * longitude_local_coordinates = True puts receivers at longitude 0, so
  #     their y coordinate is always 0.
  # In the local system the y-z axes are parallel to a plane tangent to the
  # sphere, but the result is still three-dimensional. When both flags are
  # True and edges are short, the x difference is small and could be dropped
  # to project onto the tangent plane; that would discard information about
  # the curvature of the mesh, which may matter for very coarse meshes.
  return rotated_senders - rotated_receivers
def get_rotation_matrices_to_local_coordinates(
    reference_phi: np.ndarray,
    reference_theta: np.ndarray,
    rotate_latitude: bool,
    rotate_longitude: bool) -> np.ndarray:
  """Returns rotation matrices that move reference points to a local frame.

  The rotation is built such that a vector at the reference point that points
  towards the pole before the rotation keeps pointing towards the pole after
  the rotation.

  Args:
    reference_phi: [leading_axis] Azimuthal angles (longitude, in radians) of
      the reference points.
    reference_theta: [leading_axis] Polar angles (colatitude, in radians) of
      the reference points.
    rotate_latitude: Whether to produce a rotation matrix that would rotate
      R^3 vectors to zero latitude.
    rotate_longitude: Whether to produce a rotation matrix that would rotate
      R^3 vectors to zero longitude.

  Returns:
    Matrices of shape [leading_axis, 3, 3] such that when applied to the
    reference position with
    `rotate_with_matrices(rotation_matrices, reference_pos)`:

    * phi goes to 0. if "rotate_longitude" is True.
    * theta goes to np.pi / 2 if "rotate_latitude" is True.

    The rotation consists of:

    * rotate_latitude = False, rotate_longitude = True:
        Latitude preserving rotation.
    * rotate_latitude = True, rotate_longitude = True:
        Latitude preserving rotation, followed by longitude preserving
        rotation.
    * rotate_latitude = True, rotate_longitude = False:
        Latitude preserving rotation, followed by longitude preserving
        rotation, and the inverse of the latitude preserving rotation. This
        is different from a rotation that only changes the latitude; it is
        done this way so that the polar geodesic curve stays aligned with one
        of the axes after the rotation.

  Raises:
    ValueError: If both `rotate_latitude` and `rotate_longitude` are False.
  """
  if not (rotate_latitude or rotate_longitude):
    raise ValueError(
        "At least one of longitude and latitude should be rotated.")

  # Rotating around the z axis by minus the azimuthal angle brings the point
  # to zero longitude.
  azimuthal_rotation = -reference_phi

  if not rotate_latitude:
    # Longitude only: the azimuthal rotation on its own.
    return transform.Rotation.from_euler("z", azimuthal_rotation).as_matrix()

  # Once at longitude 0 the polar rotation can be done around the y axis:
  # rotating by (pi/2 - polar angle) brings the point to zero latitude.
  polar_rotation = -reference_theta + np.pi / 2

  if rotate_longitude:
    axes = "zy"
    angles = [azimuthal_rotation, polar_rotation]
  else:
    # Latitude only: same as above, then undo the azimuthal rotation so the
    # longitude of the reference point is preserved.
    axes = "zyz"
    angles = [azimuthal_rotation, polar_rotation, -azimuthal_rotation]

  return transform.Rotation.from_euler(
      axes, np.stack(angles, axis=1)).as_matrix()
def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray
                         ) -> np.ndarray:
  """Applies one rotation matrix per row to the matching position vector.

  Args:
    rotation_matrices: Array indexed as [batch, row, column].
    positions: Array indexed as [batch, column].

  Returns:
    Array indexed as [batch, row], where entry `b` is the matrix-vector
    product `rotation_matrices[b] @ positions[b]`.
  """
  batched_matvec = "brc,bc->br"
  return np.einsum(batched_matvec, rotation_matrices, positions)
def get_bipartite_graph_spatial_features(
    *,
    senders_node_lat: np.ndarray,
    senders_node_lon: np.ndarray,
    senders: np.ndarray,
    receivers_node_lat: np.ndarray,
    receivers_node_lon: np.ndarray,
    receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    edge_normalization_factor: Optional[float] = None,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Builds spatial node and edge features for a bipartite graph.

  This function is almost identical to `get_graph_spatial_features`. The only
  difference is that sender nodes and receiver nodes can be in different
  arrays, which is necessary to enable combination with typed Graph.

  Args:
    senders_node_lat: Latitudes in the [-90, 90] interval of shape
      [num_sender_nodes].
    senders_node_lon: Longitudes in the [0, 360] interval of shape
      [num_sender_nodes].
    senders: Sender indices of shape [num_edges], indices in
      [0, num_sender_nodes).
    receivers_node_lat: Latitudes in the [-90, 90] interval of shape
      [num_receiver_nodes]. Must have the same dtype as `senders_node_lat`.
    receivers_node_lon: Longitudes in the [0, 360] interval of shape
      [num_receiver_nodes].
    receivers: Receiver indices of shape [num_edges], indices in
      [0, num_receiver_nodes).
    add_node_positions: Add the unit-norm absolute (x, y, z) positions.
    add_node_latitude: Add a latitude feature, cos(90 - lat). Note that even
      when this is False the model may still be able to infer the latitude
      from the relative features, unless
      `relative_latitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different latitudes.
    add_node_longitude: Add longitude features, (cos(lon), sin(lon)). Note
      that even when this is False the model may still be able to infer the
      longitude from the relative features, unless
      `relative_longitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R^3 (and
      their lengths) as edge features.
    edge_normalization_factor: Allows explicitly controlling edge
      normalization. If None, defaults to max edge length. This supports using
      pre-trained model weights with a different graph structure to what it
      was trained on.
    relative_longitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 latitude.

  Returns:
    A `(senders_node_features, receivers_node_features, edge_features)` tuple
    with arrays of shape [num_sender_nodes, num_features],
    [num_receiver_nodes, num_features] and [num_edges, num_features].
  """
  num_senders = senders_node_lat.shape[0]
  num_receivers = receivers_node_lat.shape[0]
  num_edges = senders.shape[0]
  dtype = senders_node_lat.dtype
  assert receivers_node_lat.dtype == dtype

  sender_phi, sender_theta = lat_lon_deg_to_spherical(
      senders_node_lat, senders_node_lon)
  receiver_phi, receiver_theta = lat_lon_deg_to_spherical(
      receivers_node_lat, receivers_node_lon)

  def _node_columns(phi: np.ndarray, theta: np.ndarray) -> list:
    """Collects the requested per-node feature columns for one node set."""
    columns = []
    if add_node_positions:
      # Cartesian coordinates on the unit sphere, already within [-1, 1].
      columns.extend(spherical_to_cartesian(phi, theta))
    if add_node_latitude:
      # cos(theta) goes from 1 at the north pole to -1 at the south pole.
      columns.append(np.cos(theta))
    if add_node_longitude:
      # The (cos, sin) pair is already normalized.
      columns.append(np.cos(phi))
      columns.append(np.sin(phi))
    return columns

  sender_columns = _node_columns(sender_phi, sender_theta)
  receiver_columns = _node_columns(receiver_phi, receiver_theta)

  # Both lists are filled by the same flags, so checking one is enough.
  if sender_columns:
    senders_node_features = np.stack(sender_columns, axis=-1)
    receivers_node_features = np.stack(receiver_columns, axis=-1)
  else:
    senders_node_features = np.zeros([num_senders, 0], dtype=dtype)
    receivers_node_features = np.zeros([num_receivers, 0], dtype=dtype)

  # Edge features.
  if add_relative_positions:
    offsets = get_bipartite_relative_position_in_receiver_local_coordinates(
        senders_node_phi=sender_phi,
        senders_node_theta=sender_theta,
        receivers_node_phi=receiver_phi,
        receivers_node_theta=receiver_theta,
        senders=senders,
        receivers=receivers,
        latitude_local_coordinates=relative_latitude_local_coordinates,
        longitude_local_coordinates=relative_longitude_local_coordinates)
    # This is the L2 (chord) distance in 3d space, not the geodesic distance.
    lengths = np.linalg.norm(offsets, axis=-1, keepdims=True)

    scale = edge_normalization_factor
    if scale is None:
      # Scale by the longest edge. Every edge is expected to have a
      # counterpart pointing the opposite way, so offsets should be symmetric
      # around zero: after scaling they land in [-1, 1] and the lengths in
      # [0, 1].
      scale = lengths.max()
    edge_features = np.concatenate(
        [lengths / scale, offsets / scale], axis=-1)
  else:
    edge_features = np.zeros([num_edges, 0], dtype=dtype)

  return senders_node_features, receivers_node_features, edge_features
def get_bipartite_relative_position_in_receiver_local_coordinates(
    senders_node_phi: np.ndarray,
    senders_node_theta: np.ndarray,
    senders: np.ndarray,
    receivers_node_phi: np.ndarray,
    receivers_node_theta: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool) -> np.ndarray:
  """Returns relative position features for the edges of a bipartite graph.

  This function is equivalent to
  `get_relative_position_in_receiver_local_coordinates`, but adapted to work
  with bipartite typed graphs.

  Positions are first rotated into a local coordinate system defined by the
  receiver of each edge; the relative position is then the sender position
  minus the receiver position in that rotated space, still expressed in R^3.

  Args:
    senders_node_phi: [num_sender_nodes] longitude angles in radians.
    senders_node_theta: [num_sender_nodes] colatitude angles in radians.
    senders: [num_edges] with indices into sender nodes.
    receivers_node_phi: [num_receiver_nodes] longitude angles in radians.
    receivers_node_theta: [num_receiver_nodes] colatitude angles in radians.
    receivers: [num_edges] with indices into receiver nodes.
    latitude_local_coordinates: Whether to rotate each edge so that its
      receiver sits at latitude 0.
    longitude_local_coordinates: Whether to rotate each edge so that its
      receiver sits at longitude 0.

  Returns:
    Array of relative positions in R^3 of shape [num_edges, 3].
  """
  sender_node_positions = np.stack(
      spherical_to_cartesian(senders_node_phi, senders_node_theta), axis=-1)
  receiver_node_positions = np.stack(
      spherical_to_cartesian(receivers_node_phi, receivers_node_theta),
      axis=-1)
  edge_sender_positions = sender_node_positions[senders]
  edge_receiver_positions = receiver_node_positions[receivers]

  needs_rotation = latitude_local_coordinates or longitude_local_coordinates
  if not needs_rotation:
    # Plain difference in the global frame.
    return edge_sender_positions - edge_receiver_positions

  # One rotation matrix per receiver node, describing its local space.
  per_receiver_rotations = get_rotation_matrices_to_local_coordinates(
      reference_phi=receivers_node_phi,
      reference_theta=receivers_node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)

  # Every edge uses the rotation of its receiver node.
  per_edge_rotations = per_receiver_rotations[receivers]

  # For the receivers, rotating every node once and gathering afterwards
  # would be cheaper; gathering first keeps both sides symmetric.
  rotated_senders = rotate_with_matrices(
      per_edge_rotations, edge_sender_positions)
  rotated_receivers = rotate_with_matrices(
      per_edge_rotations, edge_receiver_positions)

  # Because the rotated space is chosen by the receiver:
  #   * latitude_local_coordinates = True puts receivers at latitude 0, so
  #     their z coordinate is always 0.
  #   * longitude_local_coordinates = True puts receivers at longitude 0, so
  #     their y coordinate is always 0.
  # In the local system the y-z axes are parallel to a plane tangent to the
  # sphere, but the result is still three-dimensional. When both flags are
  # True and edges are short, the x difference is small and could be dropped
  # to project onto the tangent plane; that would discard information about
  # the curvature of the mesh, which may matter for very coarse meshes.
  return rotated_senders - rotated_receivers
def variable_to_stacked(
    variable: xarray.Variable,
    sizes: Mapping[str, int],
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Variable:
  """Converts an xarray.Variable to preserved_dims + ("channels",).

  Every dimension that is not listed in `preserved_dims` is stacked into a
  final "channels" dimension. Preserved dimensions that the variable lacks are
  added, with the data broadcast to the sizes given in `sizes`.

  Args:
    variable: An xarray.Variable.
    sizes: Mapping including sizes for any dimensions which are not present in
      `variable` but are needed for the output. This may be needed for example
      for a static variable with only ("lat", "lon") dims, or if you want to
      encode just the latitude coordinates (a variable with dims ("lat",)).
    preserved_dims: Dimensions of `variable` that must not be folded into
      channels.

  Returns:
    An xarray.Variable with dimensions preserved_dims + ("channels",).
  """
  folded_dims = [dim for dim in variable.dims if dim not in preserved_dims]
  if folded_dims:
    variable = variable.stack(channels=folded_dims)

  target_sizes = {}
  for dim in preserved_dims:
    # Prefer the variable's own size; fall back to `sizes` when the dim is
    # absent (or its size is falsy).
    own_size = variable.sizes.get(dim)
    target_sizes[dim] = own_size if own_size else sizes[dim]
  # A variable with nothing to fold still gets a singleton channels axis.
  target_sizes["channels"] = variable.sizes.get("channels", 1)
  return variable.set_dims(target_sizes)
def dataset_to_stacked(
    dataset: xarray.Dataset,
    sizes: Optional[Mapping[str, int]] = None,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.DataArray:
  """Converts an xarray.Dataset to a single stacked array.

  Each constituent data_var is converted to preserved_dims + ("channels",)
  layout (BHWC with the default `preserved_dims`) using
  `variable_to_stacked`; the results are then concatenated along the channels
  axis, ordered by sorted variable name.

  Args:
    dataset: An xarray.Dataset.
    sizes: Mapping including sizes for any dimensions which are not present in
      the `dataset` but are needed for the output. See variable_to_stacked.
      When None (or empty), `dataset.sizes` is used.
    preserved_dims: Dimensions from the dataset that should not be folded in
      the predictions channels.

  Returns:
    An xarray.DataArray with dimensions preserved_dims + ("channels",).
    Existing coordinates for preserved_dims axes will be preserved, however
    there will be no coordinates for "channels".
  """
  size_lookup = sizes or dataset.sizes

  stacked_variables = []
  for name in sorted(dataset.data_vars.keys()):
    stacked_variables.append(
        variable_to_stacked(
            dataset.variables[name], size_lookup, preserved_dims))

  kept_coords = {}
  for dim, coord in dataset.coords.items():
    if dim in preserved_dims:
      kept_coords[dim] = coord

  stacked = xarray.Variable.concat(stacked_variables, dim="channels")
  return xarray.DataArray(data=stacked, coords=kept_coords)
def stacked_to_dataset(
    stacked_array: xarray.Variable,
    template_dataset: xarray.Dataset,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Dataset:
  """The inverse of dataset_to_stacked.

  Requires a template dataset to demonstrate the variables/shapes/coordinates
  required. All variables must have preserved_dims dimensions.

  Args:
    stacked_array: Data in BHWC layout, encoded the same as dataset_to_stacked
      would if it was asked to encode `template_dataset`.
    template_dataset: A template Dataset (or other mapping of DataArrays)
      demonstrating the shape of output required (variables, shapes,
      coordinates etc).
    preserved_dims: Dimensions from the template that were not folded in the
      predictions channels. The preserved_dims need to be a subset of the dims
      of all the variables of template_dataset.

  Returns:
    An xarray.Dataset (or other mapping of DataArrays) with the same shape and
    type as template_dataset.

  Raises:
    ValueError: If a template variable lacks one of `preserved_dims`, or if
      the number of channels in `stacked_array` does not match the number
      implied by `template_dataset`.
  """
  var_names = sorted(template_dataset.keys())

  # For every variable, the sizes of the dims that were folded into channels.
  unstack_from_channels_sizes = {}
  for name in var_names:
    template_var = template_dataset[name]
    if any(dim not in template_var.dims for dim in preserved_dims):
      raise ValueError(
          f"stacked_to_dataset requires all Variables to have {preserved_dims} "
          f"dimensions, but found only {template_var.dims}.")
    folded_sizes = {}
    for dim, size in template_var.sizes.items():
      if dim not in preserved_dims:
        folded_sizes[dim] = size
    unstack_from_channels_sizes[name] = folded_sizes

  # Channels taken by each variable: the product of its folded sizes (1 when
  # nothing was folded).
  channels = {}
  for name, folded_sizes in unstack_from_channels_sizes.items():
    channels[name] = np.prod(list(folded_sizes.values()), dtype=np.int64)

  total_expected_channels = sum(channels.values())
  found_channels = stacked_array.sizes["channels"]
  if total_expected_channels != found_channels:
    raise ValueError(
        f"Expected {total_expected_channels} channels but found "
        f"{found_channels}, when trying to convert a stacked array of shape "
        f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")

  # Walk the channels axis in sorted-name order, the same order in which
  # dataset_to_stacked concatenated the variables.
  data_vars = {}
  start = 0
  for name in var_names:
    template_var = template_dataset[name]
    stop = start + channels[name]
    var = stacked_array.isel({"channels": slice(start, stop)})
    start = stop
    var = var.unstack({"channels": unstack_from_channels_sizes[name]})
    var = var.transpose(*template_var.dims)
    data_vars[name] = xarray.DataArray(
        data=var,
        coords=template_var.coords,
        # This might not always be the same as the name it's keyed under; it
        # will refer to the original variable name, whereas the key might be
        # some alias e.g. temperature_850 under which it should be logged:
        name=template_var.name,
    )
  return type(template_dataset)(data_vars)  # pytype:disable=not-callable,wrong-arg-count
|