autoregressive.py
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Predictor wrapping a one-step Predictor to make autoregressive predictions.
"""

from typing import Optional, cast

from absl import logging
from graphcast import predictor_base
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import xarray


def _unflatten_and_expand_time(flat_variables, tree_def, time_coords):
  """Unflattens leaves into a Dataset and prepends a time axis with the given coords."""
  variables = jax.tree_util.tree_unflatten(tree_def, flat_variables)
  return variables.expand_dims(time=time_coords, axis=0)


def _get_flat_arrays_and_single_timestep_treedef(variables):
  """Returns time-major flat leaves, plus the treedef of a single timestep."""
  flat_arrays = jax.tree_util.tree_leaves(variables.transpose('time', ...))
  _, treedef = jax.tree_util.tree_flatten(variables.isel(time=0, drop=True))
  return flat_arrays, treedef
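

# Illustrative sketch (not part of the original module; `ds` and `np` are
# assumptions) of how the two helpers above round-trip a Dataset through flat
# jax leaves, mirroring what hk.scan does per timestep:
#
#   ds = xarray.Dataset(
#       {'temp': (('time', 'lat'), np.zeros((2, 3)))},
#       coords={'time': [0, 1], 'lat': [10., 20., 30.]})
#   flat, treedef = _get_flat_arrays_and_single_timestep_treedef(ds)
#   # `flat` holds time-major arrays; `treedef` describes a single timestep.
#   # Taking the t=0 slice of each leaf and unflattening reconstructs a
#   # Dataset with a length-1 time axis, much like ds.isel(time=[0]):
#   first = _unflatten_and_expand_time(
#       [leaf[0] for leaf in flat], treedef, ds.coords['time'][:1])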


class Predictor(predictor_base.Predictor):
  """Wraps a one-step Predictor to make multi-step predictions autoregressively.

  The wrapped Predictor will be used to predict a single timestep conditional
  on the inputs passed to the outer Predictor. Its predictions are then
  passed back in as inputs at the next timestep, for as many timesteps as are
  requested in the targets_template. (When multiple timesteps of input are
  used, a rolling window of inputs is maintained, with new predictions
  concatenated onto the end.)

  You may ask for additional variables to be predicted as targets which aren't
  used as inputs. These will be predicted as output variables only and not fed
  back in autoregressively. All target variables must be time-dependent,
  however.

  You may also specify static (non-time-dependent) inputs, which will be
  passed in at each timestep but are not predicted.

  At present, any time-dependent inputs must also be present as targets so
  they can be passed in autoregressively.

  The loss of the wrapped one-step Predictor is averaged over all timesteps
  to give a loss for the autoregressive Predictor.
  """

  def __init__(
      self,
      predictor: predictor_base.Predictor,
      noise_level: Optional[float] = None,
      gradient_checkpointing: bool = False,
      ):
    """Initializes an autoregressive predictor wrapper.

    Args:
      predictor: A predictor to wrap in an auto-regressive way.
      noise_level: Optional value that multiplies the standard normal noise
        added to the time-dependent variables of the predictor inputs. In
        particular, no noise is added to the predictions that are fed back
        auto-regressively. Defaults to not adding noise.
      gradient_checkpointing: If True, gradient checkpointing will be used
        at each step of the computation to save on memory. Roughly, this
        should make the backwards pass about twice as expensive, while the
        time per step, counting the forward pass, should only increase by
        about 50%. Note this parameter will be ignored with a warning if the
        scan sequence length is 1.
    """
    self._predictor = predictor
    self._noise_level = noise_level
    self._gradient_checkpointing = gradient_checkpointing

  def _get_and_validate_constant_inputs(self, inputs, targets, forcings):
    constant_inputs = inputs.drop_vars(targets.keys(), errors='ignore')
    constant_inputs = constant_inputs.drop_vars(
        forcings.keys(), errors='ignore')
    for name, var in constant_inputs.items():
      if 'time' in var.dims:
        raise ValueError(
            f'Time-dependent input variable {name} must either be a forcing '
            'variable, or a target variable to allow for auto-regressive '
            'feedback.')
    return constant_inputs

  def _validate_targets_and_forcings(self, targets, forcings):
    for name, var in targets.items():
      if 'time' not in var.dims:
        raise ValueError(f'Target variable {name} must be time-dependent.')
    for name, var in forcings.items():
      if 'time' not in var.dims:
        raise ValueError(f'Forcing variable {name} must be time-dependent.')
    overlap = forcings.keys() & targets.keys()
    if overlap:
      raise ValueError('The following were specified as both targets and '
                       f'forcings, which isn\'t allowed: {overlap}')

  def _update_inputs(self, inputs, next_frame):
    num_inputs = inputs.dims['time']
    predicted_or_forced_inputs = next_frame[list(inputs.keys())]

    # Combining datasets with inputs and target time stamps aligns them.
    # Only keep the num_inputs trailing frames for use as next inputs.
    return (xarray.concat([inputs, predicted_or_forced_inputs], dim='time')
            .tail(time=num_inputs)
            # Update the time coordinate to reset the lead times for
            # the next AR iteration.
            .assign_coords(time=inputs.coords['time']))
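
  # A worked micro-example (hypothetical values) of the rolling window above,
  # with two input frames and a 6h step:
  #
  #   inputs.time      == [-6h, 0h]        # current input window
  #   next_frame.time  == [+6h]            # newly predicted/forced frame
  #   concat           -> [-6h, 0h, +6h]
  #   tail(time=2)     -> [0h, +6h]        # trailing num_inputs frames
  #   assign_coords    -> relabelled to [-6h, 0h], resetting the lead times
  #                       for the next autoregressive step.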

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs) -> xarray.Dataset:
    """Calls the Predictor.

    Args:
      inputs: Input variables used to make predictions. Inputs can include
        both time-dependent and time-independent variables. Any
        time-dependent input variables must also be present in the
        targets_template or the forcings.
      targets_template: A target template containing information about which
        variables should be predicted and the time alignment of the
        predictions. All target variables must be time-dependent. The number
        of time frames is used to set the number of unroll steps of the AR
        predictor (multiple unrolls of the inner predictor for one time step
        in the targets are not supported yet).
      forcings: Variables that will be fed to the model. These variables
        should not overlap with the target ones, and their time coordinates
        should match the target ones. Forcing variables which are also
        present in the inputs will be used to supply ground-truth values for
        those inputs when they are passed to the underlying predictor at
        timesteps beyond the first.
      **kwargs: Additional arguments passed along to the inner Predictor.

    Returns:
      predictions: The model predictions matching the target template.

    Raises:
      ValueError: if the time coordinates of the inputs and targets do not
        differ by a constant time step.
    """
    constant_inputs = self._get_and_validate_constant_inputs(
        inputs, targets_template, forcings)
    self._validate_targets_and_forcings(targets_template, forcings)

    # After the above checks, the remaining inputs must be time-dependent:
    inputs = inputs.drop_vars(constant_inputs.keys())

    # A predictions template only including the next time to predict.
    target_template = targets_template.isel(time=[0])

    flat_forcings, forcings_treedef = (
        _get_flat_arrays_and_single_timestep_treedef(forcings))
    scan_variables = flat_forcings

    def one_step_prediction(inputs, scan_variables):
      flat_forcings = scan_variables
      forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
                                            target_template.coords['time'])

      # Add constant inputs:
      all_inputs = xarray.merge([constant_inputs, inputs])
      predictions: xarray.Dataset = self._predictor(
          all_inputs, target_template,
          forcings=forcings,
          **kwargs)

      next_frame = xarray.merge([predictions, forcings])
      next_inputs = self._update_inputs(inputs, next_frame)

      # Drop the length-1 time dimension, since scan will concat all the
      # outputs for different times along a new leading time dimension:
      predictions = predictions.squeeze('time', drop=True)
      # We return the prediction flattened into plain jax arrays, because the
      # extra leading dimension added by scan prevents the tree_util
      # registrations in xarray_jax from unflattening them back into an
      # xarray.Dataset automatically:
      flat_pred = jax.tree_util.tree_leaves(predictions)
      return next_inputs, flat_pred

    if self._gradient_checkpointing:
      scan_length = targets_template.dims['time']
      if scan_length <= 1:
        logging.warning(
            'Skipping gradient checkpointing for sequence length of 1')
      else:
        # Just in case we take gradients (e.g. for control), although
        # in most cases this will just be for a forward pass.
        one_step_prediction = hk.remat(one_step_prediction)

    # Loop (without unroll), carrying Haiku state in the scan cell
    # (plain jax.lax.scan won't do).
    _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)

    # The result of scan will have an extra leading axis on all arrays,
    # corresponding to the target times in this case. We need to be prepared
    # for it when unflattening the arrays back into a Dataset:
    scan_result_template = (
        target_template.squeeze('time', drop=True)
        .expand_dims(time=targets_template.coords['time'], axis=0))
    _, scan_result_treedef = jax.tree_util.tree_flatten(scan_result_template)
    predictions = jax.tree_util.tree_unflatten(scan_result_treedef, flat_preds)
    return predictions
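
  # Shape note (illustrative sizes, an assumption): with 4 target timesteps
  # and per-step prediction leaves of shape (batch, lat, lon), hk.scan stacks
  # the per-step outputs so each leaf of flat_preds has shape
  # (4, batch, lat, lon), matching scan_result_template's leading time axis.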

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs
           ) -> predictor_base.LossAndDiagnostics:
    """The mean of the per-timestep losses of the underlying predictor."""
    if targets.sizes['time'] == 1:
      # If there is only a single target timestep then we don't need any
      # autoregressive feedback and can delegate the loss directly to the
      # underlying single-step predictor. This means the underlying predictor
      # doesn't need to implement .loss_and_predictions.
      return self._predictor.loss(inputs, targets, forcings, **kwargs)

    constant_inputs = self._get_and_validate_constant_inputs(
        inputs, targets, forcings)
    self._validate_targets_and_forcings(targets, forcings)
    # After the above checks, the remaining inputs must be time-dependent:
    inputs = inputs.drop_vars(constant_inputs.keys())

    if self._noise_level:
      def add_noise(x):
        return x + self._noise_level * jax.random.normal(
            hk.next_rng_key(), shape=x.shape)
      # Add noise to time-dependent variables of the inputs.
      inputs = jax.tree_map(add_noise, inputs)

    # The per-timestep targets passed by scan to one_step_loss below will have
    # no leading time axis. We need a treedef without the time axis to use
    # inside one_step_loss to unflatten it back into a dataset:
    flat_targets, target_treedef = _get_flat_arrays_and_single_timestep_treedef(
        targets)
    flat_forcings, forcings_treedef = (
        _get_flat_arrays_and_single_timestep_treedef(forcings))
    scan_variables = (flat_targets, flat_forcings)

    def one_step_loss(inputs, scan_variables):
      flat_target, flat_forcings = scan_variables
      forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
                                            targets.coords['time'][:1])
      target = _unflatten_and_expand_time(flat_target, target_treedef,
                                          targets.coords['time'][:1])

      # Add constant inputs:
      all_inputs = xarray.merge([constant_inputs, inputs])

      (loss, diagnostics), predictions = self._predictor.loss_and_predictions(
          all_inputs,
          target,
          forcings=forcings,
          **kwargs)

      # Unwrap to jax arrays of shape (batch,):
      loss, diagnostics = xarray_tree.map_structure(
          xarray_jax.unwrap_data, (loss, diagnostics))

      predictions = cast(xarray.Dataset, predictions)  # Keeps pytype happy.
      next_frame = xarray.merge([predictions, forcings])
      next_inputs = self._update_inputs(inputs, next_frame)
      return next_inputs, (loss, diagnostics)

    if self._gradient_checkpointing:
      scan_length = targets.dims['time']
      if scan_length <= 1:
        logging.warning(
            'Skipping gradient checkpointing for sequence length of 1')
      else:
        one_step_loss = hk.remat(one_step_loss)

    # We can pass inputs (the initial state of the loop) in directly as a
    # Dataset because the shape we pass in to scan is the same as the shape
    # scan passes to the inner function. But, for scan_variables, we must
    # flatten the targets (and unflatten them inside the inner function)
    # because they are passed to the inner function per-timestep without the
    # original time axis. The same applies to the optional forcings.
    _, (per_timestep_losses, per_timestep_diagnostics) = hk.scan(
        one_step_loss, inputs, scan_variables)

    # Re-wrap loss and diagnostics as DataArrays and average them over time:
    (loss, diagnostics) = jax.tree_util.tree_map(
        lambda x: xarray_jax.DataArray(x, dims=('time', 'batch')).mean(  # pylint: disable=g-long-lambda
            'time', skipna=False),
        (per_timestep_losses, per_timestep_diagnostics))
    return loss, diagnostics
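

# A hedged end-to-end sketch (hypothetical names, not part of this module):
# Predictor relies on hk.scan and hk.next_rng_key, so it is used inside a
# Haiku-transformed function, along the lines of:
#
#   def loss_fn(inputs, targets, forcings):
#     ar_predictor = Predictor(make_one_step_predictor(), noise_level=1e-3)
#     loss, _ = ar_predictor.loss(inputs, targets, forcings)
#     return xarray_jax.unwrap_data(loss.mean(), require_jax=True)
#
#   loss_transform = hk.transform(loss_fn)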