- # Copyright 2023 DeepMind Technologies Limited.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Utils for rolling out models."""
- from typing import Iterator
- from absl import logging
- import chex
- import dask.array
- from graphcast import xarray_tree
- import jax
- import numpy as np
- import typing_extensions
- import xarray
- class PredictorFn(typing_extensions.Protocol):
- """Functional version of base.Predictor.__call__ with explicit rng."""
- def __call__(
- self, rng: chex.PRNGKey, inputs: xarray.Dataset,
- targets_template: xarray.Dataset,
- forcings: xarray.Dataset,
- **optional_kwargs,
- ) -> xarray.Dataset:
- ...
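- # A minimal sketch of a conforming `PredictorFn`: a toy persistence "model"
- # that repeats the last input frame for every target step. The name
- # `persistence_predictor_fn` is illustrative only and assumes every target
- # variable is also present in `inputs`; a real predictor would typically wrap
- # a jitted model apply function.
- #
- #   def persistence_predictor_fn(
- #       rng: chex.PRNGKey, inputs: xarray.Dataset,
- #       targets_template: xarray.Dataset, forcings: xarray.Dataset,
- #       **optional_kwargs) -> xarray.Dataset:
- #     del rng, forcings, optional_kwargs  # Unused by this toy predictor.
- #     last_frame = inputs.isel(time=-1, drop=True)
- #     predictions = xarray.zeros_like(targets_template)
- #     for name in list(predictions.data_vars):
- #       # Broadcasting over "time" repeats the last input frame at each step.
- #       predictions[name] = predictions[name] + last_frame[name]
- #     return predictions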
- def chunked_prediction(
- predictor_fn: PredictorFn,
- rng: chex.PRNGKey,
- inputs: xarray.Dataset,
- targets_template: xarray.Dataset,
- forcings: xarray.Dataset,
- num_steps_per_chunk: int = 1,
- verbose: bool = False,
- ) -> xarray.Dataset:
- """Outputs a long trajectory by iteratively concatenating chunked predictions.
- Args:
- predictor_fn: Function to use to make predictions for each chunk.
- rng: Random key.
- inputs: Inputs for the model.
- targets_template: Template for the target prediction, requires targets
- equispaced in time.
- forcings: Optional forcing for the model.
- num_steps_per_chunk: How many of the steps in `targets_template` to predict
- at each call of `predictor_fn`. It must evenly divide the number of
- steps in `targets_template`.
- verbose: Whether to log the current chunk being predicted.
- Returns:
- Predictions for the targets template.
- """
- chunks_list = []
- for prediction_chunk in chunked_prediction_generator(
- predictor_fn=predictor_fn,
- rng=rng,
- inputs=inputs,
- targets_template=targets_template,
- forcings=forcings,
- num_steps_per_chunk=num_steps_per_chunk,
- verbose=verbose):
- chunks_list.append(jax.device_get(prediction_chunk))
- return xarray.concat(chunks_list, dim="time")
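- # Hypothetical usage sketch for `chunked_prediction`. The names
- # `my_predictor_fn`, `eval_inputs`, `eval_targets_template` and
- # `eval_forcings` are placeholders for objects built elsewhere, not part of
- # this module:
- #
- #   rng = jax.random.PRNGKey(0)
- #   predictions = chunked_prediction(
- #       predictor_fn=my_predictor_fn,
- #       rng=rng,
- #       inputs=eval_inputs,
- #       targets_template=eval_targets_template,
- #       forcings=eval_forcings,
- #       num_steps_per_chunk=1,
- #       verbose=True)
- #   # `predictions` has the same structure as `eval_targets_template`, with
- #   # all chunks copied to host memory and concatenated along "time".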
- def chunked_prediction_generator(
- predictor_fn: PredictorFn,
- rng: chex.PRNGKey,
- inputs: xarray.Dataset,
- targets_template: xarray.Dataset,
- forcings: xarray.Dataset,
- num_steps_per_chunk: int = 1,
- verbose: bool = False,
- ) -> Iterator[xarray.Dataset]:
- """Outputs a long trajectory by yielding chunked predictions.
- Args:
- predictor_fn: Function to use to make predictions for each chunk.
- rng: Random key.
- inputs: Inputs for the model.
- targets_template: Template for the target prediction, requires targets
- equispaced in time.
- forcings: Optional forcing for the model.
- num_steps_per_chunk: How many of the steps in `targets_template` to predict
- at each call of `predictor_fn`. It must evenly divide the number of
- steps in `targets_template`.
- verbose: Whether to log the current chunk being predicted.
- Yields:
- The predictions for each chunk of the rollout, such that if all
- predictions are concatenated in time, the result matches the targets
- template in structure.
- """
- # Create copies to avoid mutating inputs.
- inputs = xarray.Dataset(inputs)
- targets_template = xarray.Dataset(targets_template)
- forcings = xarray.Dataset(forcings)
- if "datetime" in inputs.coords:
- del inputs.coords["datetime"]
- if "datetime" in targets_template.coords:
- output_datetime = targets_template.coords["datetime"]
- del targets_template.coords["datetime"]
- else:
- output_datetime = None
- if "datetime" in forcings.coords:
- del forcings.coords["datetime"]
- num_target_steps = targets_template.dims["time"]
- num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
- if remainder != 0:
- raise ValueError(
- f"The number of steps per chunk {num_steps_per_chunk} must "
- f"evenly divide the number of target steps {num_target_steps}")
- if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1:
- raise ValueError("The targets time coordinates must be evenly spaced")
- # Our template targets will always have a time axis corresponding to the
- # timedeltas of the first chunk.
- targets_chunk_time = targets_template.time.isel(
- time=slice(0, num_steps_per_chunk))
- current_inputs = inputs
- for chunk_index in range(num_chunks):
- if verbose:
- logging.info("Chunk %d/%d", chunk_index, num_chunks)
- logging.flush()
- # Select targets for the time period that we are predicting for this chunk.
- target_offset = num_steps_per_chunk * chunk_index
- target_slice = slice(target_offset, target_offset + num_steps_per_chunk)
- current_targets_template = targets_template.isel(time=target_slice)
- # Replace the time coordinate with the one corresponding to the first
- # chunk, so we don't recompile at every iteration, keeping the actual
- # target times aside to restore them on the predictions below.
- actual_target_time = current_targets_template.coords["time"]
- current_targets_template = current_targets_template.assign_coords(
- time=targets_chunk_time).compute()
- current_forcings = forcings.isel(time=target_slice)
- current_forcings = current_forcings.assign_coords(time=targets_chunk_time)
- current_forcings = current_forcings.compute()
- # Make predictions for the chunk.
- rng, this_rng = jax.random.split(rng)
- predictions = predictor_fn(
- rng=this_rng,
- inputs=current_inputs,
- targets_template=current_targets_template,
- forcings=current_forcings)
- next_frame = xarray.merge([predictions, current_forcings])
- current_inputs = _get_next_inputs(current_inputs, next_frame)
- # At this point we can assign the actual targets time coordinates.
- predictions = predictions.assign_coords(time=actual_target_time)
- if output_datetime is not None:
- predictions.coords["datetime"] = output_datetime.isel(
- time=target_slice)
- yield predictions
- del predictions
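- # Hypothetical usage sketch for `chunked_prediction_generator`: consume
- # chunks one at a time instead of materialising the full trajectory in
- # memory. The names `my_predictor_fn`, the `eval_*` datasets and
- # `write_chunk_to_disk` are placeholders, not part of this module:
- #
- #   for chunk in chunked_prediction_generator(
- #       predictor_fn=my_predictor_fn,
- #       rng=jax.random.PRNGKey(0),
- #       inputs=eval_inputs,
- #       targets_template=eval_targets_template,
- #       forcings=eval_forcings,
- #       num_steps_per_chunk=1):
- #     # Each chunk covers `num_steps_per_chunk` target steps; copy it off the
- #     # device before persisting it.
- #     write_chunk_to_disk(jax.device_get(chunk))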
- def _get_next_inputs(
- prev_inputs: xarray.Dataset, next_frame: xarray.Dataset,
- ) -> xarray.Dataset:
- """Computes next inputs, from previous inputs and predictions."""
- # Make sure we are predicting all inputs that have a time axis.
- non_predicted_or_forced_inputs = list(
- set(prev_inputs.keys()) - set(next_frame.keys()))
- if "time" in prev_inputs[non_predicted_or_forced_inputs].dims:
- raise ValueError(
- "Found an input with a time index that is not predicted or forced.")
- # Keys we need to copy from predictions to inputs.
- next_inputs_keys = list(
- set(next_frame.keys()).intersection(set(prev_inputs.keys())))
- next_inputs = next_frame[next_inputs_keys]
- # Concatenate the next frame with the previous inputs, crop what we don't
- # need, and shift the timedelta coordinates so we don't recompile at every
- # iteration.
- num_inputs = prev_inputs.dims["time"]
- return (
- xarray.concat(
- [prev_inputs, next_inputs], dim="time", data_vars="different")
- .tail(time=num_inputs)
- .assign_coords(time=prev_inputs.coords["time"]))
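- # A small self-contained sketch of the rolling window implemented by
- # `_get_next_inputs`, using toy data (all values and names below are
- # illustrative only):
- #
- #   prev = xarray.Dataset(
- #       {"t2m": ("time", np.array([1.0, 2.0]))},
- #       coords={"time": np.array([-6, 0], dtype="timedelta64[h]")})
- #   nxt = xarray.Dataset(
- #       {"t2m": ("time", np.array([3.0]))},
- #       coords={"time": np.array([6], dtype="timedelta64[h]")})
- #   rolled = _get_next_inputs(prev, nxt)
- #   # rolled["t2m"].data == [2.0, 3.0], but with the original time coords
- #   # [-6h, 0h] reassigned, so each predictor call sees identical coordinates
- #   # and shapes and therefore does not trigger a recompile.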
- def extend_targets_template(
- targets_template: xarray.Dataset,
- required_num_steps: int) -> xarray.Dataset:
- """Extends `targets_template` to `required_num_steps` with lazy arrays.
- It uses lazy dask arrays of zeros, so it does not require instantiating the
- array in memory.
- Args:
- targets_template: Input template to extend.
- required_num_steps: Number of steps required in the returned template.
- Returns:
- `xarray.Dataset` identical in variables and timestep to `targets_template`,
- filled with `dask.array.zeros`, such that the time axis has
- `required_num_steps`.
- """
- # Extend the "time" and "datetime" coordinates
- time = targets_template.coords["time"]
- # The timestep is taken from the first target time; target times are
- # assumed to start at one timestep and be evenly spaced (asserted below).
- timestep = time[0].data
- if time.shape[0] > 1:
- assert np.all(timestep == time[1:] - time[:-1])
- extended_time = (np.arange(required_num_steps) + 1) * timestep
- if "datetime" in targets_template.coords:
- datetime = targets_template.coords["datetime"]
- extended_datetime = (datetime[0].data - timestep) + extended_time
- else:
- extended_datetime = None
- # Replace the values with lazy dask arrays of zeros, extending the time
- # coordinates.
- def extend_time(data_array: xarray.DataArray) -> xarray.DataArray:
- dims = data_array.dims
- shape = list(data_array.shape)
- shape[dims.index("time")] = required_num_steps
- dask_data = dask.array.zeros(
- shape=tuple(shape),
- chunks=-1,  # Will give chunk info directly to `ChunksToZarr`.
- dtype=data_array.dtype)
- coords = dict(data_array.coords)
- coords["time"] = extended_time
- if extended_datetime is not None:
- coords["datetime"] = ("time", extended_datetime)
- return xarray.DataArray(
- dims=dims,
- data=dask_data,
- coords=coords)
- return xarray_tree.map_structure(extend_time, targets_template)
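- # Hypothetical usage sketch for `extend_targets_template`: lazily extend a
- # short targets template to a long rollout without allocating the full array
- # in memory. `one_step_template`, `my_predictor_fn` and the `eval_*` datasets
- # are placeholders, and the forcings are assumed to already cover all 40
- # steps:
- #
- #   extended_template = extend_targets_template(
- #       one_step_template, required_num_steps=40)
- #   predictions = chunked_prediction(
- #       predictor_fn=my_predictor_fn,
- #       rng=jax.random.PRNGKey(0),
- #       inputs=eval_inputs,
- #       targets_template=extended_template,
- #       forcings=eval_forcings)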