# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """Wrappers for Predictors which allow them to work with normalized data.
- The Predictor which is wrapped sees normalized inputs and targets, and makes
- normalized predictions. The wrapper handles translating the predictions back
- to the original domain.
- """
- import logging
- from typing import Optional, Tuple
- from graphcast import predictor_base
- from graphcast import xarray_tree
- import xarray
- def normalize(values: xarray.Dataset,
- scales: xarray.Dataset,
- locations: Optional[xarray.Dataset],
- ) -> xarray.Dataset:
- """Normalize variables using the given scales and (optionally) locations."""
- def normalize_array(array):
- if array.name is None:
- raise ValueError(
- "Can't look up normalization constants because array has no name.")
- if locations is not None:
- if array.name in locations:
- array = array - locations[array.name].astype(array.dtype)
- else:
- logging.warning('No normalization location found for %s', array.name)
- if array.name in scales:
- array = array / scales[array.name].astype(array.dtype)
- else:
- logging.warning('No normalization scale found for %s', array.name)
- return array
- return xarray_tree.map_structure(normalize_array, values)
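
# A minimal usage sketch for `normalize` (hypothetical data; `temperature` is
# an assumed variable name, not part of this module). Scales and locations are
# looked up per variable name:
#
#   import numpy as np
#   values = xarray.Dataset(
#       {'temperature': ('x', np.array([280.0, 290.0]))})
#   scales = xarray.Dataset({'temperature': 10.0})
#   locations = xarray.Dataset({'temperature': 285.0})
#   normalize(values, scales, locations)
#   # -> temperature == [-0.5, 0.5]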


def unnormalize(values: xarray.Dataset,
                scales: xarray.Dataset,
                locations: Optional[xarray.Dataset],
                ) -> xarray.Dataset:
  """Unnormalize variables using the given scales and (optionally) locations."""
  def unnormalize_array(array):
    if array.name is None:
      raise ValueError(
          "Can't look up normalization constants because array has no name.")
    if array.name in scales:
      array = array * scales[array.name].astype(array.dtype)
    else:
      logging.warning('No normalization scale found for %s', array.name)
    if locations is not None:
      if array.name in locations:
        array = array + locations[array.name].astype(array.dtype)
      else:
        logging.warning('No normalization location found for %s', array.name)
    return array
  return xarray_tree.map_structure(unnormalize_array, values)
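
# Note: `unnormalize` applies the same affine transform as `normalize` in
# reverse, so for any variable with entries in both `scales` and `locations`:
#
#   unnormalize(normalize(values, scales, locations), scales, locations)
#   # -> recovers `values` (up to floating-point round-off)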


class InputsAndResiduals(predictor_base.Predictor):
  """Wraps a predictor with a residual connection, normalizing inputs and target residuals.

  The inner predictor is given inputs that are normalized using `locations`
  and `scales` to roughly zero-mean unit variance.

  For target variables that are present in the inputs, the inner predictor is
  trained to predict residuals (target - last_frame_of_input) that have been
  normalized using `residual_scales` (and optionally `residual_locations`) to
  roughly unit variance / zero mean.

  This replaces `residual.Predictor` in the case where you want normalization
  that's based on the scales of the residuals.

  Since we return the underlying predictor's loss on the normalized residuals,
  if the underlying predictor is a sum of per-variable losses, the
  normalization will affect the relative weighting of the per-variable loss
  terms (hopefully in a good way).

  For target variables *not* present in the inputs, the inner predictor is
  trained to predict the targets directly; these targets are normalized in the
  same way as the inputs.

  The transforms applied to the targets (the residual connection and the
  normalization) are applied in reverse to the predictions before returning
  them.
  """

  def __init__(
      self,
      predictor: predictor_base.Predictor,
      stddev_by_level: xarray.Dataset,
      mean_by_level: xarray.Dataset,
      diffs_stddev_by_level: xarray.Dataset):
    self._predictor = predictor
    self._scales = stddev_by_level
    self._locations = mean_by_level
    self._residual_scales = diffs_stddev_by_level
    self._residual_locations = None

  def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
    if norm_prediction.sizes.get('time') != 1:
      raise ValueError(
          'normalization.InputsAndResiduals only supports predicting a '
          'single timestep.')
    if norm_prediction.name in inputs:
      # A prediction for which we have a corresponding input -- we are
      # predicting the residual. Residuals are assumed to be predicted as
      # normalized (unit variance), but the scale and location they need
      # mapping to are those of the residuals, not of the values themselves.
      prediction = unnormalize(
          norm_prediction, self._residual_scales, self._residual_locations)
      last_input = inputs[norm_prediction.name].isel(time=-1)
      prediction = prediction + last_input
      return prediction
    else:
      # A predicted variable which is not an input variable. We are predicting
      # it directly, so unnormalize it directly to the target scale/location:
      return unnormalize(norm_prediction, self._scales, self._locations)

  def _subtract_input_and_normalize_target(self, inputs, target):
    if target.sizes.get('time') != 1:
      raise ValueError(
          'normalization.InputsAndResiduals only supports wrapping predictors '
          'that predict a single timestep.')
    if target.name in inputs:
      target_residual = target
      last_input = inputs[target.name].isel(time=-1)
      target_residual = target_residual - last_input
      return normalize(
          target_residual, self._residual_scales, self._residual_locations)
    else:
      return normalize(target, self._scales, self._locations)
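
  # Note: for variables present in the inputs, this method is the exact
  # inverse (up to floating-point error) of
  # `_unnormalize_prediction_and_add_input` above, so predictions returned
  # from `__call__` land back in the targets' original domain.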

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs
               ) -> xarray.Dataset:
    norm_inputs = normalize(inputs, self._scales, self._locations)
    norm_forcings = normalize(forcings, self._scales, self._locations)
    norm_predictions = self._predictor(
        norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
    return xarray_tree.map_structure(
        lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
        norm_predictions)

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs,
           ) -> predictor_base.LossAndDiagnostics:
    """Returns the loss computed on normalized inputs and targets."""
    norm_inputs = normalize(inputs, self._scales, self._locations)
    norm_forcings = normalize(forcings, self._scales, self._locations)
    norm_target_residuals = xarray_tree.map_structure(
        lambda t: self._subtract_input_and_normalize_target(inputs, t),
        targets)
    return self._predictor.loss(
        norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)

  def loss_and_predictions(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **kwargs,
      ) -> Tuple[predictor_base.LossAndDiagnostics,
                 xarray.Dataset]:
    """The loss computed on normalized data, with unnormalized predictions."""
    norm_inputs = normalize(inputs, self._scales, self._locations)
    norm_forcings = normalize(forcings, self._scales, self._locations)
    norm_target_residuals = xarray_tree.map_structure(
        lambda t: self._subtract_input_and_normalize_target(inputs, t),
        targets)
    (loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
        norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
    predictions = xarray_tree.map_structure(
        lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
        norm_predictions)
    return (loss, scalars), predictions
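
# A minimal construction sketch (hypothetical names; `my_inner_predictor` and
# the three statistics datasets are assumed to exist, keyed by variable name
# as in `normalize` above):
#
#   predictor = InputsAndResiduals(
#       predictor=my_inner_predictor,
#       stddev_by_level=stddev_by_level,
#       mean_by_level=mean_by_level,
#       diffs_stddev_by_level=diffs_stddev_by_level)
#   predictions = predictor(inputs, targets_template, forcings=forcings)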