rollout.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Utils for rolling out models."""
  15. from typing import Iterator
  16. from absl import logging
  17. import chex
  18. import dask
  19. from graphcast import xarray_tree
  20. import jax
  21. import numpy as np
  22. import typing_extensions
  23. import xarray
class PredictorFn(typing_extensions.Protocol):
  """Functional version of base.Predictor.__call__ with explicit rng.

  Structural (duck-typed) interface: any callable taking an explicit JAX
  PRNG key plus inputs/targets_template/forcings Datasets and returning a
  predictions Dataset satisfies it — e.g. a jitted, params-bound model.
  """

  def __call__(
      self, rng: chex.PRNGKey, inputs: xarray.Dataset,
      targets_template: xarray.Dataset,
      forcings: xarray.Dataset,
      **optional_kwargs,
      ) -> xarray.Dataset:
    ...
  33. def chunked_prediction(
  34. predictor_fn: PredictorFn,
  35. rng: chex.PRNGKey,
  36. inputs: xarray.Dataset,
  37. targets_template: xarray.Dataset,
  38. forcings: xarray.Dataset,
  39. num_steps_per_chunk: int = 1,
  40. verbose: bool = False,
  41. ) -> xarray.Dataset:
  42. """Outputs a long trajectory by iteratively concatenating chunked predictions.
  43. Args:
  44. predictor_fn: Function to use to make predictions for each chunk.
  45. rng: Random key.
  46. inputs: Inputs for the model.
  47. targets_template: Template for the target prediction, requires targets
  48. equispaced in time.
  49. forcings: Optional forcing for the model.
  50. num_steps_per_chunk: How many of the steps in `targets_template` to predict
  51. at each call of `predictor_fn`. It must evenly divide the number of
  52. steps in `targets_template`.
  53. verbose: Whether to log the current chunk being predicted.
  54. Returns:
  55. Predictions for the targets template.
  56. """
  57. chunks_list = []
  58. for prediction_chunk in chunked_prediction_generator(
  59. predictor_fn=predictor_fn,
  60. rng=rng,
  61. inputs=inputs,
  62. targets_template=targets_template,
  63. forcings=forcings,
  64. num_steps_per_chunk=num_steps_per_chunk,
  65. verbose=verbose):
  66. chunks_list.append(jax.device_get(prediction_chunk))
  67. return xarray.concat(chunks_list, dim="time")
def chunked_prediction_generator(
    predictor_fn: PredictorFn,
    rng: chex.PRNGKey,
    inputs: xarray.Dataset,
    targets_template: xarray.Dataset,
    forcings: xarray.Dataset,
    num_steps_per_chunk: int = 1,
    verbose: bool = False,
) -> Iterator[xarray.Dataset]:
  """Outputs a long trajectory by yielding chunked predictions.

  Autoregressive rollout: each chunk is predicted from the current input
  window, then that window is advanced with the chunk's own predictions (and
  forcings) via `_get_next_inputs` before the next chunk is predicted.

  Args:
    predictor_fn: Function to use to make predictions for each chunk.
    rng: Random key.
    inputs: Inputs for the model.
    targets_template: Template for the target prediction, requires targets
      equispaced in time.
    forcings: Optional forcing for the model.
    num_steps_per_chunk: How many of the steps in `targets_template` to predict
      at each call of `predictor_fn`. It must evenly divide the number of
      steps in `targets_template`.
    verbose: Whether to log the current chunk being predicted.

  Yields:
    The predictions for each chunked step of the chunked rollout, such as
    if all predictions are concatenated in time this would match the targets
    template in structure.
  """
  # Create copies to avoid mutating inputs.
  inputs = xarray.Dataset(inputs)
  targets_template = xarray.Dataset(targets_template)
  forcings = xarray.Dataset(forcings)

  # Drop "datetime" coords before prediction: absolute datetimes differ per
  # chunk and would defeat the fixed-coordinate trick below. The targets'
  # datetime is kept aside so it can be re-attached to each yielded chunk.
  if "datetime" in inputs.coords:
    del inputs.coords["datetime"]

  if "datetime" in targets_template.coords:
    output_datetime = targets_template.coords["datetime"]
    del targets_template.coords["datetime"]
  else:
    output_datetime = None

  if "datetime" in forcings.coords:
    del forcings.coords["datetime"]

  num_target_steps = targets_template.dims["time"]
  num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
  if remainder != 0:
    raise ValueError(
        f"The number of steps per chunk {num_steps_per_chunk} must "
        f"evenly divide the number of target steps {num_target_steps} ")

  # Even spacing is required because every chunk reuses the first chunk's
  # time coordinates (see `targets_chunk_time` below).
  if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1:
    raise ValueError("The targets time coordinates must be evenly spaced")

  # Our template targets will always have a time axis corresponding for the
  # timedeltas for the first chunk.
  targets_chunk_time = targets_template.time.isel(
      time=slice(0, num_steps_per_chunk))

  current_inputs = inputs
  for chunk_index in range(num_chunks):
    if verbose:
      logging.info("Chunk %d/%d", chunk_index, num_chunks)
      logging.flush()

    # Select targets for the time period that we are predicting for this chunk.
    target_offset = num_steps_per_chunk * chunk_index
    target_slice = slice(target_offset, target_offset + num_steps_per_chunk)
    current_targets_template = targets_template.isel(time=target_slice)

    # Replace the timedelta, by the one corresponding to the first chunk, so we
    # don't recompile at every iteration, keeping the coordinates (and hence
    # the traced input structure) identical across chunks. The real time
    # coordinates are stashed first so they can be restored on the output.
    actual_target_time = current_targets_template.coords["time"]
    current_targets_template = current_targets_template.assign_coords(
        time=targets_chunk_time).compute()

    # Forcings get the same fixed time coordinates, for the same reason.
    current_forcings = forcings.isel(time=target_slice)
    current_forcings = current_forcings.assign_coords(time=targets_chunk_time)
    current_forcings = current_forcings.compute()
    # Make predictions for the chunk. Split the key so each chunk gets an
    # independent rng while the rollout stays deterministic overall.
    rng, this_rng = jax.random.split(rng)
    predictions = predictor_fn(
        rng=this_rng,
        inputs=current_inputs,
        targets_template=current_targets_template,
        forcings=current_forcings)

    # Predictions plus forcings together supply every time-dependent input
    # variable needed to build the next step's input window.
    next_frame = xarray.merge([predictions, current_forcings])

    current_inputs = _get_next_inputs(current_inputs, next_frame)

    # At this point we can assign the actual targets time coordinates.
    predictions = predictions.assign_coords(time=actual_target_time)
    if output_datetime is not None:
      predictions.coords["datetime"] = output_datetime.isel(
          time=target_slice)
    yield predictions
    # Drop the reference so device buffers can be freed before the next chunk.
    del predictions
  152. def _get_next_inputs(
  153. prev_inputs: xarray.Dataset, next_frame: xarray.Dataset,
  154. ) -> xarray.Dataset:
  155. """Computes next inputs, from previous inputs and predictions."""
  156. # Make sure are are predicting all inputs with a time axis.
  157. non_predicted_or_forced_inputs = list(
  158. set(prev_inputs.keys()) - set(next_frame.keys()))
  159. if "time" in prev_inputs[non_predicted_or_forced_inputs].dims:
  160. raise ValueError(
  161. "Found an input with a time index that is not predicted or forced.")
  162. # Keys we need to copy from predictions to inputs.
  163. next_inputs_keys = list(
  164. set(next_frame.keys()).intersection(set(prev_inputs.keys())))
  165. next_inputs = next_frame[next_inputs_keys]
  166. # Apply concatenate next frame with inputs, crop what we don't need and
  167. # shift timedelta coordinates, so we don't recompile at every iteration.
  168. num_inputs = prev_inputs.dims["time"]
  169. return (
  170. xarray.concat(
  171. [prev_inputs, next_inputs], dim="time", data_vars="different")
  172. .tail(time=num_inputs)
  173. .assign_coords(time=prev_inputs.coords["time"]))
  174. def extend_targets_template(
  175. targets_template: xarray.Dataset,
  176. required_num_steps: int) -> xarray.Dataset:
  177. """Extends `targets_template` to `required_num_steps` with lazy arrays.
  178. It uses lazy dask arrays of zeros, so it does not require instantiating the
  179. array in memory.
  180. Args:
  181. targets_template: Input template to extend.
  182. required_num_steps: Number of steps required in the returned template.
  183. Returns:
  184. `xarray.Dataset` identical in variables and timestep to `targets_template`
  185. full of `dask.array.zeros` such that the time axis has `required_num_steps`.
  186. """
  187. # Extend the "time" and "datetime" coordinates
  188. time = targets_template.coords["time"]
  189. # Assert the first target time corresponds to the timestep.
  190. timestep = time[0].data
  191. if time.shape[0] > 1:
  192. assert np.all(timestep == time[1:] - time[:-1])
  193. extended_time = (np.arange(required_num_steps) + 1) * timestep
  194. if "datetime" in targets_template.coords:
  195. datetime = targets_template.coords["datetime"]
  196. extended_datetime = (datetime[0].data - timestep) + extended_time
  197. else:
  198. extended_datetime = None
  199. # Replace the values with empty dask arrays extending the time coordinates.
  200. datetime = targets_template.coords["time"]
  201. def extend_time(data_array: xarray.DataArray) -> xarray.DataArray:
  202. dims = data_array.dims
  203. shape = list(data_array.shape)
  204. shape[dims.index("time")] = required_num_steps
  205. dask_data = dask.array.zeros(
  206. shape=tuple(shape),
  207. chunks=-1, # Will give chunk info directly to `ChunksToZarr``.
  208. dtype=data_array.dtype)
  209. coords = dict(data_array.coords)
  210. coords["time"] = extended_time
  211. if extended_datetime is not None:
  212. coords["datetime"] = ("time", extended_datetime)
  213. return xarray.DataArray(
  214. dims=dims,
  215. data=dask_data,
  216. coords=coords)
  217. return xarray_tree.map_structure(extend_time, targets_template)