casting.py

# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers that take care of casting."""

import contextlib
from typing import Any, Mapping, Tuple

import chex
from graphcast import predictor_base
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import xarray


PyTree = Any


class Bfloat16Cast(predictor_base.Predictor):
  """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype."""

  def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True):
    """Inits the wrapper.

    Args:
      predictor: predictor being wrapped.
      enabled: disables the wrapper if False, for simpler hyperparameter scans.
    """
    self._enabled = enabled
    self._predictor = predictor

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs
               ) -> xarray.Dataset:
    if not self._enabled:
      return self._predictor(inputs, targets_template, forcings, **kwargs)

    with bfloat16_variable_view():
      predictions = self._predictor(
          *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
          **kwargs,)

    predictions_dtype = infer_floating_dtype(predictions)
    if predictions_dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')

    targets_dtype = infer_floating_dtype(targets_template)
    return tree_map_cast(
        predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype)

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs,
           ) -> predictor_base.LossAndDiagnostics:
    if not self._enabled:
      return self._predictor.loss(inputs, targets, forcings, **kwargs)

    with bfloat16_variable_view():
      loss, scalars = self._predictor.loss(
          *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)

    if loss.dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')

    targets_dtype = infer_floating_dtype(targets)
    # Note that casting the loss back to e.g. float32 should not affect the
    # data types of the backwards pass, because the first thing the backwards
    # pass should do is go backwards through the casting op and cast back to
    # bfloat16 (and xprofs seem to confirm this).
    return tree_map_cast((loss, scalars),
                         input_dtype=jnp.bfloat16, output_dtype=targets_dtype)

  def loss_and_predictions(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **kwargs,
      ) -> Tuple[predictor_base.LossAndDiagnostics,
                 xarray.Dataset]:
    if not self._enabled:
      return self._predictor.loss_and_predictions(inputs, targets, forcings,  # pytype: disable=bad-return-type  # jax-ndarray
                                                  **kwargs)

    with bfloat16_variable_view():
      (loss, scalars), predictions = self._predictor.loss_and_predictions(
          *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)

    if loss.dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')

    predictions_dtype = infer_floating_dtype(predictions)
    if predictions_dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')

    targets_dtype = infer_floating_dtype(targets)
    return tree_map_cast(((loss, scalars), predictions),
                         input_dtype=jnp.bfloat16, output_dtype=targets_dtype)
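

# Usage sketch for `Bfloat16Cast` (illustrative only: `MyPredictor`, `inputs`,
# `targets_template` and `forcings` below are hypothetical stand-ins, not part
# of this module, and the call would typically happen inside an
# `hk.transform`-ed function, since the wrapper relies on Haiku custom
# creators/getters):
#
#   predictor = Bfloat16Cast(MyPredictor(), enabled=True)
#   predictions = predictor(inputs, targets_template, forcings)
#
# The wrapped predictor sees bfloat16 inputs and activations, while the
# returned predictions are cast back to the floating dtype of
# `targets_template` (typically float32).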


def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype:
  """Infers a floating dtype from an input mapping of data."""
  dtypes = {
      v.dtype
      for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
  if len(dtypes) != 1:
    dtypes_and_shapes = {
        k: (v.dtype, v.shape)
        for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
    raise ValueError(
        f'Did not find exactly one floating dtype {dtypes} in input variables:'
        f'{dtypes_and_shapes}')
  return list(dtypes)[0]
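

# Example (illustrative sketch, hypothetical data): non-floating variables are
# ignored, and mixed floating dtypes raise a ValueError.
#
#   infer_floating_dtype({'t2m': jnp.zeros((4,), jnp.float32),
#                         'mask': jnp.zeros((4,), jnp.int32)})
#   # -> dtype('float32')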


def _all_inputs_to_bfloat16(
    inputs: xarray.Dataset,
    targets: xarray.Dataset,
    forcings: xarray.Dataset,
    ) -> Tuple[xarray.Dataset,
               xarray.Dataset,
               xarray.Dataset]:
  return (inputs.astype(jnp.bfloat16),
          jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
          forcings.astype(jnp.bfloat16))


def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
                  ) -> PyTree:
  def cast_fn(x):
    if x.dtype == input_dtype:
      return x.astype(output_dtype)
    # Leaves with any other dtype are passed through unchanged.
    return x
  return jax.tree_map(cast_fn, inputs)
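

# Example (illustrative sketch): cast every bfloat16 leaf of a pytree back to
# float32, leaving other leaves untouched.
#
#   scalars = {'loss': jnp.zeros((), jnp.bfloat16),
#              'step': jnp.array(0, jnp.int32)}
#   tree_map_cast(scalars, input_dtype=jnp.bfloat16, output_dtype=jnp.float32)
#   # -> the 'loss' leaf becomes float32; the int32 'step' leaf is unchanged.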


@contextlib.contextmanager
def bfloat16_variable_view(enabled: bool = True):
  """Context for Haiku modules with float32 params, but bfloat16 activations.

  It works as follows:
  * Every time a variable is requested to be created/set as np.bfloat16,
    it will create an underlying float32 variable, instead.
  * Every time a variable is requested as bfloat16, it will check the
    variable is of float32 type, and cast the variable to bfloat16.

  Note the gradients are still computed and accumulated as float32, because
  the params returned by init are float32, so the gradient function with
  respect to the params will already include an implicit casting to float32.

  Args:
    enabled: Only enables bfloat16 behavior if True.

  Yields:
    None
  """
  if enabled:
    with hk.custom_creator(
        _bfloat16_creator, state=True), hk.custom_getter(
            _bfloat16_getter, state=True), hk.custom_setter(
                _bfloat16_setter):
      yield
  else:
    yield
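

# Usage sketch (illustrative only; the `forward` function, layer size and
# array shapes below are hypothetical, not part of this module):
#
#   def forward(x):
#     with bfloat16_variable_view():
#       return hk.Linear(8)(x)
#
#   fn = hk.transform(forward)
#   params = fn.init(jax.random.PRNGKey(0), jnp.ones((2, 4), jnp.bfloat16))
#
# `params` come back as float32, but because the input is bfloat16 the Linear
# layer requests bfloat16 parameters, which the custom getter serves by
# casting the stored float32 values down to bfloat16 at apply time.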


def _bfloat16_creator(next_creator, shape, dtype, init, context):
  """Creates float32 variables when bfloat16 is requested."""
  if context.original_dtype == jnp.bfloat16:
    dtype = jnp.float32
  return next_creator(shape, dtype, init)


def _bfloat16_getter(next_getter, value, context):
  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
  if context.original_dtype == jnp.bfloat16:
    assert value.dtype == jnp.float32
    value = value.astype(jnp.bfloat16)
  return next_getter(value)


def _bfloat16_setter(next_setter, value, context):
  """Casts bfloat16 to float32 when bfloat16 was originally set."""
  if context.original_dtype == jnp.bfloat16:
    value = value.astype(jnp.float32)
  return next_setter(value)