# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions (and terms for use in loss functions) used for weather."""
from typing import Mapping

from graphcast import xarray_tree
import numpy as np
from typing_extensions import Protocol
import xarray

LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]


class LossFunction(Protocol):
  """A loss function.

  This is a protocol so it's fine to use a plain function which 'quacks like'
  this. This is just to document the interface.
  """

  def __call__(self,
               predictions: xarray.Dataset,
               targets: xarray.Dataset,
               **optional_kwargs) -> LossAndDiagnostics:
    """Computes a loss function.

    Args:
      predictions: Dataset of predictions.
      targets: Dataset of targets.
      **optional_kwargs: Implementations may support extra optional kwargs.

    Returns:
      loss: A DataArray with dimensions ('batch',) containing losses for each
        element of the batch. These will be averaged to give the final
        loss, locally and across replicas.
      diagnostics: Mapping of additional quantities to log by name alongside
        the loss. These will typically correspond to terms in the loss. They
        should also have dimensions ('batch',) and will be averaged over the
        batch before logging.
    """


def weighted_mse_per_level(
    predictions: xarray.Dataset,
    targets: xarray.Dataset,
    per_variable_weights: Mapping[str, float],
) -> LossAndDiagnostics:
  """Latitude- and pressure-level-weighted MSE loss."""
  def loss(prediction, target):
    loss = (prediction - target)**2
    loss *= normalized_latitude_weights(target).astype(loss.dtype)
    if 'level' in target.dims:
      loss *= normalized_level_weights(target).astype(loss.dtype)
    return _mean_preserving_batch(loss)

  losses = xarray_tree.map_structure(loss, predictions, targets)
  return sum_per_variable_losses(losses, per_variable_weights)
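
# Hypothetical usage sketch (variable names invented for illustration):
# assuming `predictions` and `targets` are Datasets with dims
# ('batch', 'lat', 'lon') and optionally 'level', each variable's squared
# error is latitude- and level-weighted, reduced to shape ('batch',), and
# combined with the given per-variable weights:
#
#   loss, diagnostics = weighted_mse_per_level(
#       predictions, targets,
#       per_variable_weights={'2m_temperature': 1.0,
#                             'total_precipitation': 0.1})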


def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
  return x.mean([d for d in x.dims if d != 'batch'], skipna=False)


def sum_per_variable_losses(
    per_variable_losses: Mapping[str, xarray.DataArray],
    weights: Mapping[str, float],
) -> LossAndDiagnostics:
  """Weighted sum of per-variable losses."""
  if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
    raise ValueError(
        'Passing a weight that does not correspond to any variable '
        f'{set(weights.keys()) - set(per_variable_losses.keys())}')
  weighted_per_variable_losses = {
      name: loss * weights.get(name, 1)
      for name, loss in per_variable_losses.items()
  }
  total = xarray.concat(
      weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
          'variable', skipna=False)
  return total, per_variable_losses
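
# Illustrative sketch (inputs invented): given per-variable losses of shape
# ('batch',), e.g. {'temperature': t_loss, 'wind': w_loss}, and
# weights={'temperature': 2.0}, the total is 2.0 * t_loss + 1.0 * w_loss.
# Using join='exact' makes the concat raise if the DataArrays' coordinates
# disagree, rather than silently aligning them.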


def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
  """Weights proportional to pressure at each level."""
  level = data.coords['level']
  return level / level.mean(skipna=False)
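
# Worked example (level values assumed for illustration): for pressure levels
# [50, 500, 850] hPa the mean is ~466.7 hPa, so the unit-mean weights are
# roughly [0.107, 1.071, 1.821]; higher-pressure (lower-altitude) levels
# contribute proportionally more to the loss.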


def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
  """Weights based on latitude, roughly proportional to grid cell area.

  This method supports two use cases only (both for equispaced values):

  * Latitude values such that the closest value to the pole is at latitude
    (90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
    For example: [-89, -87, -85, ..., 85, 87, 89] (d_lat = 2).
    In this case each point with `lat` value represents a sphere slice between
    `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
    proportional to:
    `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`,
    and we can simply omit the term `2 * sin(d_lat/2)`, which is just a
    constant that cancels during normalization.

  * Latitude values that fall exactly at the poles.
    For example: [-90, -88, -86, ..., 86, 88, 90] (d_lat = 2).
    In this case each point with `lat` value also represents a sphere slice
    between `lat - d_lat/2` and `lat + d_lat/2`, except for the points at the
    poles, which represent a slice between `90 - d_lat/2` and `90`, or `-90`
    and `-90 + d_lat/2`. The areas of the first type of point are still
    proportional to:
    * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
    but for the points at the poles it is now:
    * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
    and we will be using these weights, depending on whether we are looking at
    pole cells or non-pole cells (omitting the common factor of 2 which will
    be absorbed by the normalization).

  It can be shown via a limit, or simple geometry, that in the small-angle
  regime the proportion of area per pole point is equal to 1/8th of the
  proportion of area covered by each of the nearest non-pole points, and we
  test for this in the test.

  Args:
    data: `DataArray` with latitude coordinates.

  Returns:
    Unit mean latitude weights.
  """
  latitude = data.coords['lat']

  if np.any(np.isclose(np.abs(latitude), 90.)):
    weights = _weight_for_latitude_vector_with_poles(latitude)
  else:
    weights = _weight_for_latitude_vector_without_poles(latitude)

  return weights / weights.mean(skipna=False)
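
# A quick numeric sanity check of the 1/8 claim above (a sketch, not part of
# the original module): for small d_lat, the pole weight sin(d_lat/4)^2
# should approach 1/8 of the nearest non-pole weight, which sits at latitude
# (90 - d_lat) and weighs cos(90 - d_lat) * sin(d_lat/2):
#
#   d = np.deg2rad(0.25)  # d_lat = 0.25 degrees
#   pole = np.sin(d / 4)**2
#   neighbour = np.cos(np.pi / 2 - d) * np.sin(d / 2)
#   print(pole / neighbour)  # ~0.125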


def _weight_for_latitude_vector_without_poles(latitude):
  """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
  delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
  if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
      not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
    raise ValueError(
        f'Latitude vector {latitude} does not start/end at '
        '+- (90 - delta_latitude/2) degrees.')
  return np.cos(np.deg2rad(latitude))


def _weight_for_latitude_vector_with_poles(latitude):
  """Weights for uniform latitudes of the form [+-90, ..., -+90]."""
  delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
  if (not np.isclose(np.max(latitude), 90.) or
      not np.isclose(np.min(latitude), -90.)):
    raise ValueError(
        f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
  weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
  # The two checks above are enough to guarantee that the latitudes are
  # sorted, so the extremes are the poles.
  weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
  return weights


def _check_uniform_spacing_and_get_delta(vector):
  diff = np.diff(vector)
  if not np.all(np.isclose(diff[0], diff)):
    raise ValueError(f'Vector {vector} is not uniformly spaced.')
  return diff[0]