# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """Loss functions (and terms for use in loss functions) used for weather."""
- from typing import Mapping
- from graphcast import xarray_tree
- import numpy as np
- from typing_extensions import Protocol
- import xarray
- LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]


class LossFunction(Protocol):
  """A loss function.

  This is a protocol so it's fine to use a plain function which 'quacks like'
  this. This is just to document the interface.
  """

  def __call__(self,
               predictions: xarray.Dataset,
               targets: xarray.Dataset,
               **optional_kwargs) -> LossAndDiagnostics:
    """Computes a loss function.

    Args:
      predictions: Dataset of predictions.
      targets: Dataset of targets.
      **optional_kwargs: Implementations may support extra optional kwargs.

    Returns:
      loss: A DataArray with dimensions ('batch',) containing losses for each
        element of the batch. These will be averaged to give the final
        loss, locally and across replicas.
      diagnostics: Mapping of additional quantities to log by name alongside
        the loss. These will typically correspond to terms in the loss. They
        should also have dimensions ('batch',) and will be averaged over the
        batch before logging.
    """


def weighted_mse_per_level(
    predictions: xarray.Dataset,
    targets: xarray.Dataset,
    per_variable_weights: Mapping[str, float],
) -> LossAndDiagnostics:
  """Latitude- and pressure-level-weighted MSE loss."""
  def loss(prediction, target):
    loss = (prediction - target)**2
    loss *= normalized_latitude_weights(target).astype(loss.dtype)
    if 'level' in target.dims:
      loss *= normalized_level_weights(target).astype(loss.dtype)
    return _mean_preserving_batch(loss)

  losses = xarray_tree.map_structure(loss, predictions, targets)
  return sum_per_variable_losses(losses, per_variable_weights)
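

# Hedged usage sketch of `weighted_mse_per_level` on tiny synthetic data.
# Illustrative only, not part of the original module; the variable name and
# the toy grid are assumptions.
def _example_weighted_mse_usage() -> LossAndDiagnostics:
  lat = np.array([-90., 0., 90.])  # Includes the poles.
  lon = np.array([0., 120., 240.])
  shape = (1, lat.size, lon.size)  # ('batch', 'lat', 'lon')

  def make_dataset(value):
    return xarray.Dataset(
        {'2m_temperature': (('batch', 'lat', 'lon'),
                            np.full(shape, value, dtype=np.float32))},
        coords={'lat': lat, 'lon': lon})

  # The returned `loss` has dimensions ('batch',); `diagnostics` holds the
  # unweighted per-variable terms.
  return weighted_mse_per_level(
      make_dataset(1.0), make_dataset(0.0),
      per_variable_weights={'2m_temperature': 1.0})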


def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
  return x.mean([d for d in x.dims if d != 'batch'], skipna=False)


def sum_per_variable_losses(
    per_variable_losses: Mapping[str, xarray.DataArray],
    weights: Mapping[str, float],
) -> LossAndDiagnostics:
  """Weighted sum of per-variable losses."""
  if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
    raise ValueError(
        'Passing a weight that does not correspond to any variable '
        f'{set(weights.keys())-set(per_variable_losses.keys())}')
  weighted_per_variable_losses = {
      name: loss * weights.get(name, 1)
      for name, loss in per_variable_losses.items()
  }
  total = xarray.concat(
      weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
          'variable', skipna=False)
  return total, per_variable_losses
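

# Illustrative numeric example (values are made up): with
# per_variable_losses = {'z': DataArray([2.0]), 'q': DataArray([4.0])} over
# dimension ('batch',) and weights = {'q': 0.5}, 'z' keeps its default weight
# of 1, so the returned total is DataArray([2.0 + 0.5 * 4.0]) = DataArray([4.0]),
# while the diagnostics are the unweighted per-variable losses.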


def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
  """Weights proportional to pressure at each level."""
  level = data.coords['level']
  return level / level.mean(skipna=False)
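

# Illustrative example (made-up levels): for pressure levels [100, 500, 1000]
# hPa the mean is ~533.3, giving weights ~[0.19, 0.94, 1.88], so variables at
# lower-pressure (higher-altitude) levels contribute less to the loss.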


def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
  """Weights based on latitude, roughly proportional to grid cell area.

  This method supports two use cases only (both for equispaced values):
  * Latitude values such that the closest value to the pole is at latitude
    (90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
    For example: [-89, -87, -85, ..., 85, 87, 89] (d_lat = 2)
    In this case each point with `lat` value represents a sphere slice between
    `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
    proportional to:
    `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and
    we can simply omit the term `2 * sin(d_lat/2)` which is just a constant
    that cancels during normalization.
  * Latitude values that fall exactly at the poles.
    For example: [-90, -88, -86, ..., 86, 88, 90] (d_lat = 2)
    In this case each point with `lat` value also represents
    a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`,
    except for the points at the poles, that represent a slice between
    `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`.
    The areas of the first type of point are still proportional to:
    * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
    but for the points at the poles it is now:
    * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
    and we will be using these weights, depending on whether we are looking at
    pole cells, or non-pole cells (omitting the common factor of 2 which will
    be absorbed by the normalization).

  It can be shown via a limit, or simple geometry, that in the small angles
  regime, the proportion of area per pole-point is equal to 1/8th
  the proportion of area covered by each of the nearest non-pole points, and
  we test for this in the unit tests.

  Args:
    data: `DataArray` with latitude coordinates.
  Returns:
    Unit mean latitude weights.
  """
  latitude = data.coords['lat']

  if np.any(np.isclose(np.abs(latitude), 90.)):
    weights = _weight_for_latitude_vector_with_poles(latitude)
  else:
    weights = _weight_for_latitude_vector_without_poles(latitude)

  return weights / weights.mean(skipna=False)
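

# Hedged numeric sketch (names and values are illustrative, not part of the
# original module): for an equispaced grid whose extremes sit half a cell away
# from the poles, e.g. lat = [-75, -45, -15, 15, 45, 75] (d_lat = 30), the
# weights are proportional to cos(lat) and normalized to unit mean, so they
# peak nearest the equator:
#
#   data = xarray.DataArray(
#       np.zeros(6), dims=['lat'],
#       coords={'lat': [-75., -45., -15., 15., 45., 75.]})
#   weights = normalized_latitude_weights(data)  # unit mean, largest at +-15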


def _weight_for_latitude_vector_without_poles(latitude):
  """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
  delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
  if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
      not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
    raise ValueError(
        f'Latitude vector {latitude} does not start/end at '
        '+- (90 - delta_latitude/2) degrees.')
  return np.cos(np.deg2rad(latitude))


def _weight_for_latitude_vector_with_poles(latitude):
  """Weights for uniform latitudes of the form [+- 90, ..., -+90]."""
  delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
  if (not np.isclose(np.max(latitude), 90.) or
      not np.isclose(np.min(latitude), -90.)):
    raise ValueError(
        f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
  weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
  # The two checks above are enough to guarantee that the latitudes are sorted,
  # so the extremes are the poles.
  weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
  return weights
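

# Worked small-angle check (reasoning sketch, matching the docstring of
# normalized_latitude_weights above): with d = delta_latitude in radians, the
# pole weight sin(d/4)^2 ~ (d/4)^2 = d^2/16, while the nearest non-pole point
# at latitude (90 - d) has weight cos(90 - d) * sin(d/2) = sin(d) * sin(d/2)
# ~ d * d/2 = d^2/2. Their ratio is therefore 1/8 in the small-angle regime.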


def _check_uniform_spacing_and_get_delta(vector):
  diff = np.diff(vector)
  if not np.all(np.isclose(diff[0], diff)):
    raise ValueError(f'Vector {vector} is not uniformly spaced.')
  return diff[0]