# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """Abstract base classes for an xarray-based Predictor API."""
  15. import abc
  16. from typing import Tuple
  17. from graphcast import losses
  18. from graphcast import xarray_jax
  19. import jax.numpy as jnp
  20. import xarray
  21. LossAndDiagnostics = losses.LossAndDiagnostics
  22. class Predictor(abc.ABC):
  23. """A possibly-trainable predictor of weather, exposing an xarray-based API.
  24. Typically wraps an underlying JAX model and handles translating the xarray
  25. Dataset values to and from plain JAX arrays that are convenient for input to
  26. (and output from) the underlying model.
  27. Different subclasses may exist to wrap different kinds of underlying model,
  28. e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D
  29. inputs/outputs, autoregressive models.
  30. You can also implement a specific model directly as a Predictor if you want,
  31. for example if it has quite specific/unique requirements for its input/output
  32. or loss function, or if it's convenient to implement directly using xarray.
  33. """
  @abc.abstractmethod
  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **optional_kwargs
               ) -> xarray.Dataset:
    """Makes predictions.

    This is only used by the Experiment for inference / evaluation, with
    training going via the .loss method. So it should default to making
    predictions for evaluation, although you can also support making
    predictions for use in the loss via an is_training argument -- see
    LossFunctionPredictor which helps with that.

    Args:
      inputs: An xarray.Dataset of inputs.
      targets_template: An xarray.Dataset or other mapping of
        xarray.DataArrays, with the same shape as the targets, to demonstrate
        what kind of predictions are required. You can use this to determine
        which variables, levels and lead times must be predicted.
        You are free to raise an error if you don't support predicting what is
        requested.
      forcings: An xarray.Dataset of forcing terms. Forcings are variables
        that can be fed to the model but do not need to be predicted. This is
        often because the variable can be computed analytically (e.g. the toa
        radiation of the sun is mostly a function of geometry), or because it
        is considered to be controlled for the experiment (e.g. imposing a
        scenario of CO2 emissions into the atmosphere). Unlike `inputs`, the
        `forcings` can include information "from the future", that is,
        information at the target times specified in the `targets_template`.
      **optional_kwargs: Implementations may support extra optional kwargs,
        provided they set appropriate defaults for them.

    Returns:
      Predictions, as an xarray.Dataset or other mapping of DataArrays which
      is capable of being evaluated against targets with shape given by
      targets_template.
      For probabilistic predictors which can return multiple samples from a
      predictive distribution, these should (by convention) be returned along
      an additional 'sample' dimension.
    """
  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **optional_kwargs,
           ) -> LossAndDiagnostics:
    """Computes a training loss, for predictors that are trainable.

    Why make this the Predictor's responsibility, rather than letting callers
    compute their own loss function using predictions obtained from
    Predictor.__call__?

    Doing it this way gives Predictors more control over their training setup.
    For example, some predictors may wish to train using different targets to
    the ones they predict at evaluation time -- perhaps different lead times
    and variables, perhaps training to predict transformed versions of targets
    where the transform needs to be inverted at evaluation time, etc.

    It's also necessary for generative models (VAEs, GANs, ...) where the
    training loss is more complex and isn't expressible as a parameter-free
    function of predictions and targets.

    Args:
      inputs: An xarray.Dataset.
      targets: An xarray.Dataset or other mapping of xarray.DataArrays. See
        docs on __call__ for an explanation about the targets.
      forcings: xarray.Dataset of forcing terms.
      **optional_kwargs: Implementations may support extra optional kwargs,
        provided they set appropriate defaults for them.

    Returns:
      loss: A DataArray with dimensions ('batch',) containing losses for each
        element of the batch. These will be averaged to give the final loss,
        locally and across replicas.
      diagnostics: Mapping of additional quantities to log by name alongside
        the loss. These will typically correspond to terms in the loss. They
        should also have dimensions ('batch',) and will be averaged over the
        batch before logging.
        You need not include the loss itself in this dict; it will be added
        for you.
    """
    del targets, forcings, optional_kwargs
    batch_size = inputs.sizes['batch']
    dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',))
    return dummy_loss, {}
  def loss_and_predictions(
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **optional_kwargs,
      ) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
    """Like .loss but also returns the corresponding predictions.

    Implementing this is optional as it's not used directly by the Experiment,
    but it is required by autoregressive.Predictor when applying an inner
    Predictor autoregressively at training time; we need a loss at each step
    but also predictions to feed back in for the next step.

    Note the loss itself may not be directly regressing the predictions
    towards targets; the loss may be computed in terms of transformed
    predictions and targets (or in some other way). For this reason we can't
    always cleanly separate this into step 1: get predictions, step 2: compute
    loss from them, hence the need for this combined method.

    Args:
      inputs:
      targets:
      forcings:
      **optional_kwargs:
        As for self.loss.

    Returns:
      (loss, diagnostics):
        As for self.loss.
      predictions:
        The predictions which the loss relates to. These should be of the same
        shape as what you would get from
        `self.__call__(inputs, targets_template=targets)`, and should be in
        the same 'domain' as the inputs (i.e. they shouldn't be transformed
        differently to how the predictor expects its inputs).
    """
    raise NotImplementedError
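

# The following is an illustrative sketch, not part of the GraphCast codebase:
# a minimal concrete Predictor showing the shape contracts of __call__ and
# .loss. `PersistencePredictor` is a hypothetical name. The sketch assumes
# that inputs and targets share 'batch' and 'time' dimensions and that every
# target variable is also present in the inputs; it simply predicts that the
# most recent input state persists at every target time.
class PersistencePredictor(Predictor):
  """Sketch: predicts the last input state at all requested target times."""

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **optional_kwargs
               ) -> xarray.Dataset:
    del forcings, optional_kwargs  # A persistence baseline ignores these.
    # Keep only the variables the template asks for, take the latest input
    # time step, and broadcast it across the template's target times.
    last_state = inputs[list(targets_template.data_vars)].isel(
        time=-1, drop=True)
    return last_state.broadcast_like(targets_template)

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **optional_kwargs,
           ) -> LossAndDiagnostics:
    # Persistence has no trainable parameters; this override just demonstrates
    # the expected return contract: a ('batch',)-shaped loss plus a mapping of
    # diagnostics with the same dimensions.
    predictions = self(inputs, targets, forcings, **optional_kwargs)
    squared_errors = (predictions - targets) ** 2
    # Per-variable mean squared error, reduced over everything except 'batch'.
    per_variable_mse = {
        name: da.mean([d for d in da.dims if d != 'batch'])
        for name, da in squared_errors.data_vars.items()
    }
    loss = sum(per_variable_mse.values()) / len(per_variable_mse)
    return loss, per_variable_mse


# Hypothetical usage, assuming `train_inputs`, `train_targets` and
# `train_forcings` are xarray.Datasets with 'batch' and 'time' dimensions:
#   predictor = PersistencePredictor()
#   loss, diagnostics = predictor.loss(
#       train_inputs, train_targets, train_forcings)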