# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """A Predictor wrapping a one-step Predictor to make autoregressive predictions.
- """
- from typing import Optional, cast
- from absl import logging
- from graphcast import predictor_base
- from graphcast import xarray_jax
- from graphcast import xarray_tree
- import haiku as hk
- import jax
- import xarray
def _unflatten_and_expand_time(flat_variables, tree_def, time_coords):
  variables = jax.tree_util.tree_unflatten(tree_def, flat_variables)
  return variables.expand_dims(time=time_coords, axis=0)


def _get_flat_arrays_and_single_timestep_treedef(variables):
  flat_arrays = jax.tree_util.tree_leaves(variables.transpose('time', ...))
  _, treedef = jax.tree_util.tree_flatten(variables.isel(time=0, drop=True))
  return flat_arrays, treedef
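
# A minimal sketch of how the two helpers above are used together with hk.scan
# further below (illustrative only; `forcings` here stands for any
# time-dependent xarray.Dataset):
#
#   flat, treedef = _get_flat_arrays_and_single_timestep_treedef(forcings)
#   # Each leaf in `flat` is a jax array with `time` as its leading axis, so
#   # hk.scan can slice one timestep off every leaf per step. Inside the
#   # scanned function a single-timestep Dataset is rebuilt with, e.g.:
#   step_0 = _unflatten_and_expand_time([a[0] for a in flat], treedef,
#                                       forcings.coords['time'][:1])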


class Predictor(predictor_base.Predictor):
  """Wraps a one-step Predictor to make multi-step predictions autoregressively.

  The wrapped Predictor will be used to predict a single timestep conditional
  on the inputs passed to the outer Predictor. Its predictions are then
  passed back in as inputs at the next timestep, for as many timesteps as are
  requested in the targets_template. (When multiple timesteps of input are
  used, a rolling window of inputs is maintained, with new predictions
  concatenated onto the end.)

  You may ask for additional variables to be predicted as targets which aren't
  used as inputs. These will be predicted as output variables only and not fed
  back in autoregressively. However, all target variables must be
  time-dependent.

  You may also specify static (non-time-dependent) inputs which will be passed
  in at each timestep but are not predicted.

  At present, any time-dependent inputs must also be present as targets so they
  can be passed in autoregressively.

  The loss of the wrapped one-step Predictor is averaged over all timesteps to
  give a loss for the autoregressive Predictor.
  """

  def __init__(
      self,
      predictor: predictor_base.Predictor,
      noise_level: Optional[float] = None,
      gradient_checkpointing: bool = False,
      ):
    """Initializes an autoregressive predictor wrapper.

    Args:
      predictor: A predictor to wrap in an auto-regressive way.
      noise_level: Optional value that multiplies the standard normal noise
        added to the time-dependent variables of the predictor inputs. In
        particular, no noise is added to the predictions that are fed back
        auto-regressively. Defaults to not adding noise.
      gradient_checkpointing: If True, gradient checkpointing will be
        used at each step of the computation to save on memory. Roughly this
        should make the backwards pass about twice as expensive, while the
        time per step, counting the forward pass, should only increase by
        about 50%. Note this parameter will be ignored with a warning if the
        scan sequence length is 1.
    """
    self._predictor = predictor
    self._noise_level = noise_level
    self._gradient_checkpointing = gradient_checkpointing

  def _get_and_validate_constant_inputs(self, inputs, targets, forcings):
    constant_inputs = inputs.drop_vars(targets.keys(), errors='ignore')
    constant_inputs = constant_inputs.drop_vars(
        forcings.keys(), errors='ignore')
    for name, var in constant_inputs.items():
      if 'time' in var.dims:
        raise ValueError(
            f'Time-dependent input variable {name} must either be a forcing '
            'variable, or a target variable to allow for auto-regressive '
            'feedback.')
    return constant_inputs

  def _validate_targets_and_forcings(self, targets, forcings):
    for name, var in targets.items():
      if 'time' not in var.dims:
        raise ValueError(f'Target variable {name} must be time-dependent.')

    for name, var in forcings.items():
      if 'time' not in var.dims:
        raise ValueError(f'Forcing variable {name} must be time-dependent.')

    overlap = forcings.keys() & targets.keys()
    if overlap:
      raise ValueError('The following were specified as both targets and '
                       f'forcings, which isn\'t allowed: {overlap}')

  def _update_inputs(self, inputs, next_frame):
    num_inputs = inputs.dims['time']

    predicted_or_forced_inputs = next_frame[list(inputs.keys())]

    # Combining datasets with inputs and target time stamps aligns them.
    # Only keep the num_inputs trailing frames for use as next inputs.
    return (xarray.concat([inputs, predicted_or_forced_inputs], dim='time')
            .tail(time=num_inputs)
            # Update the time coordinate to reset the lead times for
            # next AR iteration.
            .assign_coords(time=inputs.coords['time']))
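
  # Worked sketch of the update above (hypothetical lead times): with two
  # input frames at (-6h, 0h) and a predicted/forced frame at +6h, the
  # concatenation holds (-6h, 0h, +6h); `.tail(time=2)` keeps (0h, +6h), and
  # `assign_coords` relabels those frames to (-6h, 0h), so the next step sees
  # an inputs window with the same relative time coordinates.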

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs) -> xarray.Dataset:
    """Calls the Predictor.

    Args:
      inputs: Input variables used to make predictions. Inputs can include
        both time-dependent and time-independent variables. Any time-dependent
        input variables must also be present in the targets_template or the
        forcings.
      targets_template: A target template containing information about which
        variables should be predicted and the time alignment of the
        predictions. All target variables must be time-dependent. The number
        of time frames determines how many autoregressive steps are unrolled
        (multiple unrolls of the inner predictor per target timestep are not
        supported yet).
      forcings: Variables that will be fed to the model. These variables
        should not overlap with the target ones, and their time coordinates
        should match the target ones. Forcing variables which are also present
        in the inputs will be used to supply ground-truth values for those
        inputs when they are passed to the underlying predictor at timesteps
        beyond the first timestep.
      **kwargs: Additional arguments passed along to the inner Predictor.

    Returns:
      predictions: The model predictions matching the target template.

    Raises:
      ValueError: if the time coordinates of the inputs and targets are not
        separated by a constant time step.
    """
    constant_inputs = self._get_and_validate_constant_inputs(
        inputs, targets_template, forcings)
    self._validate_targets_and_forcings(targets_template, forcings)

    # After the above checks, the remaining inputs must be time-dependent:
    inputs = inputs.drop_vars(constant_inputs.keys())

    # A predictions template only including the next time to predict.
    target_template = targets_template.isel(time=[0])

    flat_forcings, forcings_treedef = (
        _get_flat_arrays_and_single_timestep_treedef(forcings))

    scan_variables = flat_forcings

    def one_step_prediction(inputs, scan_variables):

      flat_forcings = scan_variables
      forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
                                            target_template.coords['time'])

      # Add constant inputs:
      all_inputs = xarray.merge([constant_inputs, inputs])
      predictions: xarray.Dataset = self._predictor(
          all_inputs, target_template,
          forcings=forcings,
          **kwargs)

      next_frame = xarray.merge([predictions, forcings])
      next_inputs = self._update_inputs(inputs, next_frame)

      # Drop the length-1 time dimension, since scan will concat all the
      # outputs for different times along a new leading time dimension:
      predictions = predictions.squeeze('time', drop=True)
      # We return the prediction flattened into plain jax arrays, because the
      # extra leading dimension added by scan prevents the tree_util
      # registrations in xarray_jax from unflattening them back into an
      # xarray.Dataset automatically:
      flat_pred = jax.tree_util.tree_leaves(predictions)
      return next_inputs, flat_pred

    if self._gradient_checkpointing:
      scan_length = targets_template.dims['time']
      if scan_length <= 1:
        logging.warning(
            'Skipping gradient checkpointing for sequence length of 1')
      else:
        # Just in case we take gradients (e.g. for control), although
        # in most cases this will just be for a forward pass.
        one_step_prediction = hk.remat(one_step_prediction)

    # Loop (without unroll) with hk states in cell (jax.lax.scan won't do).
    _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)

    # The result of scan will have an extra leading axis on all arrays,
    # corresponding to the target times in this case. We need to be prepared
    # for it when unflattening the arrays back into a Dataset:
    scan_result_template = (
        target_template.squeeze('time', drop=True)
        .expand_dims(time=targets_template.coords['time'], axis=0))
    _, scan_result_treedef = jax.tree_util.tree_flatten(scan_result_template)
    predictions = jax.tree_util.tree_unflatten(scan_result_treedef, flat_preds)
    return predictions

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs
           ) -> predictor_base.LossAndDiagnostics:
    """The mean of the per-timestep losses of the underlying predictor."""
    if targets.sizes['time'] == 1:
      # If there is only a single target timestep then we don't need any
      # autoregressive feedback and can delegate the loss directly to the
      # underlying single-step predictor. This means the underlying predictor
      # doesn't need to implement .loss_and_predictions.
      return self._predictor.loss(inputs, targets, forcings, **kwargs)

    constant_inputs = self._get_and_validate_constant_inputs(
        inputs, targets, forcings)
    self._validate_targets_and_forcings(targets, forcings)
    # After the above checks, the remaining inputs must be time-dependent:
    inputs = inputs.drop_vars(constant_inputs.keys())

    if self._noise_level:
      def add_noise(x):
        return x + self._noise_level * jax.random.normal(
            hk.next_rng_key(), shape=x.shape)
      # Add noise to time-dependent variables of the inputs.
      inputs = jax.tree_util.tree_map(add_noise, inputs)

    # The per-timestep targets passed by scan to one_step_loss below will have
    # no leading time axis. We need a treedef without the time axis to use
    # inside one_step_loss to unflatten it back into a dataset:
    flat_targets, target_treedef = _get_flat_arrays_and_single_timestep_treedef(
        targets)
    scan_variables = flat_targets

    flat_forcings, forcings_treedef = (
        _get_flat_arrays_and_single_timestep_treedef(forcings))
    scan_variables = (flat_targets, flat_forcings)

    def one_step_loss(inputs, scan_variables):
      flat_target, flat_forcings = scan_variables
      forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
                                            targets.coords['time'][:1])

      target = _unflatten_and_expand_time(flat_target, target_treedef,
                                          targets.coords['time'][:1])

      # Add constant inputs:
      all_inputs = xarray.merge([constant_inputs, inputs])

      (loss, diagnostics), predictions = self._predictor.loss_and_predictions(
          all_inputs,
          target,
          forcings=forcings,
          **kwargs)

      # Unwrap to jax arrays shape (batch,):
      loss, diagnostics = xarray_tree.map_structure(
          xarray_jax.unwrap_data, (loss, diagnostics))

      predictions = cast(xarray.Dataset, predictions)  # Keeps pytype happy.
      next_frame = xarray.merge([predictions, forcings])
      next_inputs = self._update_inputs(inputs, next_frame)

      return next_inputs, (loss, diagnostics)

    if self._gradient_checkpointing:
      scan_length = targets.dims['time']
      if scan_length <= 1:
        logging.warning(
            'Skipping gradient checkpointing for sequence length of 1')
      else:
        one_step_loss = hk.remat(one_step_loss)

    # We can pass inputs (the initial state of the loop) in directly as a
    # Dataset because the shape we pass in to scan is the same as the shape
    # scan passes to the inner function. But, for scan_variables, we must
    # flatten the targets (and unflatten them inside the inner function)
    # because they are passed to the inner function per-timestep without the
    # original time axis. The same applies to the optional forcings.
    _, (per_timestep_losses, per_timestep_diagnostics) = hk.scan(
        one_step_loss, inputs, scan_variables)

    # Re-wrap loss and diagnostics as DataArrays and average them over time:
    (loss, diagnostics) = jax.tree_util.tree_map(
        lambda x: xarray_jax.DataArray(x, dims=('time', 'batch')).mean(  # pylint: disable=g-long-lambda
            'time', skipna=False),
        (per_timestep_losses, per_timestep_diagnostics))

    return loss, diagnostics
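
# A minimal sketch of computing the autoregressive loss under Haiku
# (illustrative only; `one_step_predictor` is a hypothetical
# predictor_base.Predictor and is not defined in this module):
#
#   def _loss_fn(inputs, targets, forcings):
#     predictor = Predictor(one_step_predictor, noise_level=1e-3)
#     loss, diagnostics = predictor.loss(inputs, targets, forcings)
#     return loss, diagnostics
#
#   loss_fn = hk.transform_with_state(_loss_fn)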