normalization.py

# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers for Predictors which allow them to work with normalized data.

The Predictor which is wrapped sees normalized inputs and targets, and makes
normalized predictions. The wrapper handles translating the predictions back
to the original domain.
"""

import logging
from typing import Optional, Tuple

from graphcast import predictor_base
from graphcast import xarray_tree
import xarray


def normalize(values: xarray.Dataset,
              scales: xarray.Dataset,
              locations: Optional[xarray.Dataset],
              ) -> xarray.Dataset:
  """Normalize variables using the given scales and (optionally) locations."""
  def normalize_array(array):
    if array.name is None:
      raise ValueError(
          "Can't look up normalization constants because array has no name.")
    if locations is not None:
      if array.name in locations:
        array = array - locations[array.name].astype(array.dtype)
      else:
        logging.warning('No normalization location found for %s', array.name)
    if array.name in scales:
      array = array / scales[array.name].astype(array.dtype)
    else:
      logging.warning('No normalization scale found for %s', array.name)
    return array
  return xarray_tree.map_structure(normalize_array, values)
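

# A minimal usage sketch (not part of the original module): normalizing a toy
# Dataset with per-variable scale/location statistics. The variable name and
# values below are illustrative assumptions, not GraphCast's real statistics.
#
#   import numpy as np
#   data = xarray.Dataset(
#       {'2m_temperature': ('time', np.array([280.0, 290.0]))})
#   scales = xarray.Dataset({'2m_temperature': 10.0})
#   locations = xarray.Dataset({'2m_temperature': 285.0})
#   normed = normalize(data, scales, locations)
#   # normed['2m_temperature'].values -> [-0.5, 0.5]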


def unnormalize(values: xarray.Dataset,
                scales: xarray.Dataset,
                locations: Optional[xarray.Dataset],
                ) -> xarray.Dataset:
  """Unnormalize variables using the given scales and (optionally) locations."""
  def unnormalize_array(array):
    if array.name is None:
      raise ValueError(
          "Can't look up normalization constants because array has no name.")
    if array.name in scales:
      array = array * scales[array.name].astype(array.dtype)
    else:
      logging.warning('No normalization scale found for %s', array.name)
    if locations is not None:
      if array.name in locations:
        array = array + locations[array.name].astype(array.dtype)
      else:
        logging.warning('No normalization location found for %s', array.name)
    return array
  return xarray_tree.map_structure(unnormalize_array, values)
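

# Round-trip sketch (an illustrative note, not original code): for variables
# that have both a scale and a location entry, `unnormalize` inverts
# `normalize` up to floating-point error:
#
#   roundtrip = unnormalize(normalize(data, scales, locations),
#                           scales, locations)
#   xarray.testing.assert_allclose(roundtrip, data)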


class InputsAndResiduals(predictor_base.Predictor):
  """Wraps with a residual connection, normalizing inputs and target residuals.

  The inner predictor is given inputs that are normalized using `locations`
  and `scales` to roughly zero-mean unit variance.

  For target variables that are present in the inputs, the inner predictor is
  trained to predict residuals (target - last_frame_of_input) that have been
  normalized using `residual_scales` (and optionally `residual_locations`) to
  roughly unit variance / zero mean.

  This replaces `residual.Predictor` in the case where you want normalization
  that's based on the scales of the residuals.

  Since we return the underlying predictor's loss on the normalized residuals,
  if the underlying predictor is a sum of per-variable losses, the
  normalization will affect the relative weighting of the per-variable loss
  terms (hopefully in a good way).

  For target variables *not* present in the inputs, the inner predictor is
  trained to predict targets directly, which have been normalized in the same
  way as the inputs.

  The transforms applied to the targets (the residual connection and the
  normalization) are applied in reverse to the predictions before returning
  them.
  """

  def __init__(
      self,
      predictor: predictor_base.Predictor,
      stddev_by_level: xarray.Dataset,
      mean_by_level: xarray.Dataset,
      diffs_stddev_by_level: xarray.Dataset):
    self._predictor = predictor
    self._scales = stddev_by_level
    self._locations = mean_by_level
    self._residual_scales = diffs_stddev_by_level
    self._residual_locations = None

  def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
    if norm_prediction.sizes.get('time') != 1:
      raise ValueError(
          'normalization.InputsAndResiduals only supports predicting a '
          'single timestep.')
    if norm_prediction.name in inputs:
      # Residuals are assumed to be predicted as normalized (unit variance),
      # but the scale and location they need mapping to is that of the
      # residuals not of the values themselves.
      prediction = unnormalize(
          norm_prediction, self._residual_scales, self._residual_locations)
      # A prediction for which we have a corresponding input -- we are
      # predicting the residual:
      last_input = inputs[norm_prediction.name].isel(time=-1)
      prediction += last_input
      return prediction
    else:
      # A predicted variable which is not an input variable. We are predicting
      # it directly, so unnormalize it directly to the target scale/location:
      return unnormalize(norm_prediction, self._scales, self._locations)
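
  # To make the residual transform concrete (an illustrative note, not part of
  # the original source): for a variable present in both inputs and targets,
  # with residual scale `s_diff` and last input frame `x_last`, the prediction
  # returned above is roughly
  #
  #   prediction = norm_prediction * s_diff + x_last
  #
  # while `_subtract_input_and_normalize_target` below applies the inverse,
  #
  #   norm_target = (target - x_last) / s_diff
  #
  # (plus/minus an optional residual location, which is None by default here).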

  def _subtract_input_and_normalize_target(self, inputs, target):
    if target.sizes.get('time') != 1:
      raise ValueError(
          'normalization.InputsAndResiduals only supports wrapping predictors '
          'that predict a single timestep.')
    if target.name in inputs:
      target_residual = target
      last_input = inputs[target.name].isel(time=-1)
      target_residual -= last_input
      return normalize(
          target_residual, self._residual_scales, self._residual_locations)
    else:
      return normalize(target, self._scales, self._locations)

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs
               ) -> xarray.Dataset:
    norm_inputs = normalize(inputs, self._scales, self._locations)
    norm_forcings = normalize(forcings, self._scales, self._locations)
    norm_predictions = self._predictor(
        norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
    return xarray_tree.map_structure(
        lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
        norm_predictions)

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs,
           ) -> predictor_base.LossAndDiagnostics:
    """Returns the loss computed on normalized inputs and targets."""
    norm_inputs = normalize(inputs, self._scales, self._locations)
    norm_forcings = normalize(forcings, self._scales, self._locations)
    norm_target_residuals = xarray_tree.map_structure(
        lambda t: self._subtract_input_and_normalize_target(inputs, t),
        targets)
    return self._predictor.loss(
        norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)

  def loss_and_predictions(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **kwargs,
      ) -> Tuple[predictor_base.LossAndDiagnostics,
                 xarray.Dataset]:
    """The loss computed on normalized data, with unnormalized predictions."""
    norm_inputs = normalize(inputs, self._scales, self._locations)
    norm_forcings = normalize(forcings, self._scales, self._locations)
    norm_target_residuals = xarray_tree.map_structure(
        lambda t: self._subtract_input_and_normalize_target(inputs, t),
        targets)
    (loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
        norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
    predictions = xarray_tree.map_structure(
        lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
        norm_predictions)
    return (loss, scalars), predictions
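

# A hedged end-to-end usage sketch (illustrative, not part of the original
# module). `my_inner_predictor` and the three statistics Datasets are
# placeholders the caller is assumed to provide, e.g. loaded from the
# per-level normalization stats distributed alongside GraphCast model weights:
#
#   wrapped = InputsAndResiduals(
#       my_inner_predictor,
#       stddev_by_level=stddev_by_level,
#       mean_by_level=mean_by_level,
#       diffs_stddev_by_level=diffs_stddev_by_level)
#   predictions = wrapped(inputs, targets_template, forcings=forcings)
#   (loss, diagnostics), predictions = wrapped.loss_and_predictions(
#       inputs, targets, forcings=forcings)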