xarray_jax.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Helpers to use xarray.{Variable,DataArray,Dataset} with JAX.
  15. Allows them to be based on JAX arrays without converting to numpy arrays under
  16. the hood, so you can start with a JAX array, do some computation with it in
  17. xarray-land, get a JAX array out the other end and (for example) jax.jit
  18. through the whole thing. You can even jax.jit a function which accepts and
  19. returns xarray.Dataset, DataArray and Variable.
  20. ## Creating xarray datatypes from jax arrays, and vice-versa.
  21. You can use the xarray_jax.{Variable,DataArray,Dataset} constructors, which have
  22. the same API as the standard xarray constructors but will accept JAX arrays
  23. without converting them to numpy.
  24. It does this by wrapping the JAX array in a wrapper before passing it to
  25. xarray; you can also do this manually by calling xarray_jax.wrap on your JAX
  26. arrays before passing them to the standard xarray constructors.
  27. To get non-wrapped JAX arrays out the other end, you can use e.g.:
  28. xarray_jax.jax_vars(dataset)
  29. xarray_jax.jax_data(dataset.some_var)
  30. which will complain if the data isn't actually a JAX array. Use this if you need
  31. to make sure the computation has gone via JAX, e.g. if it's the output of code
  32. that you want to JIT or compute gradients through. If this is not the case and
  33. you want to support passing plain numpy arrays through as well as potentially
  34. JAX arrays, you can use:
  35. xarray_jax.unwrap_vars(dataset)
  36. xarray_jax.unwrap_data(dataset.some_var)
  37. which will unwrap the data if it is a wrapped JAX array, but otherwise pass
  38. it through to you without complaint.
  39. The wrapped JAX arrays aim to support all the core operations from the numpy
  40. array API that xarray expects, however there may still be some gaps; if you run
  41. into any problems around this, you may need to add a few more proxy methods onto
  42. the wrapper class below.
  43. In future once JAX and xarray support the new Python array API standard
  44. (https://data-apis.org/array-api/latest/index.html), we hope to avoid the need
  45. for wrapping the JAX arrays like this.
  46. ## jax.jit and pmap of functions taking and returning xarray datatypes
  47. We register xarray datatypes with jax.tree_util, which allows them to be treated
  48. as generic containers of jax arrays by various parts of jax including jax.jit.
  49. This allows for, e.g.:
  50. @jax.jit
  51. def foo(input: xarray.Dataset) -> xarray.Dataset:
  52. ...
  53. It will not work out-of-the-box with shape-modifying transformations like
  54. jax.pmap, or e.g. a jax.tree_util.tree_map with some transform that alters array
  55. shapes or dimension order. That's because we won't know what dimension names
  56. and/or coordinates to use when unflattening, if the results have a different
  57. shape to the data that was originally flattened.
  58. You can work around this using xarray_jax.dims_change_on_unflatten, however,
  59. and in the case of jax.pmap we provide a wrapper xarray_jax.pmap which allows
  60. it to be used with functions taking and returning xarrays.
  61. ## Treatment of coordinates
  62. We don't support passing jax arrays as coordinates when constructing a
  63. DataArray/Dataset. This is because xarray's advanced indexing and slicing is
  64. unlikely to work with jax arrays (at least when a Tracer is used during
  65. jax.jit), and also because some important datatypes used for coordinates, like
  66. timedelta64 and datetime64, are not supported by jax.
  67. For the purposes of tree_util and jax.jit, coordinates are not treated as leaves
  68. of the tree (array data 'contained' by a Dataset/DataArray), they are just a
  69. static part of the structure. That means that if a jit'ed function is called
  70. twice with Dataset inputs that use different coordinates, it will compile a
  71. separate version of the function for each. The coordinates are treated like
  72. static_argnums by jax.jit.
  73. If you want to use dynamic data for coordinates, we recommend making it a
  74. data_var instead of a coord. You won't be able to do indexing and slicing using
  75. the coordinate, but that wasn't going to work with a jax array anyway.
  76. """
  77. import collections
  78. import contextlib
  79. import contextvars
  80. from typing import Any, Callable, Hashable, Iterator, Mapping, Optional, Union, Tuple, TypeVar, cast
  81. import jax
  82. import jax.numpy as jnp
  83. import numpy as np
  84. import tree
  85. import xarray
  def Variable(dims, data, **kwargs) -> xarray.Variable:  # pylint:disable=invalid-name
    """Builds an xarray.Variable whose data may be a JAX array.

    Takes the same arguments as xarray.Variable. If `data` is a jax.Array it is
    wrapped (see `wrap`) before being handed to xarray, so xarray does not
    convert it to numpy; any other `data` is forwarded unchanged.
    """
    wrapped_data = wrap(data)
    return xarray.Variable(dims, wrapped_data, **kwargs)


  # Name of the attr that marks a coordinate Variable as a jax_coord.
  # assign_coords sets it to True on every jax coord it adds, and the
  # coordinate-splitting logic in this module reads it back to tell jax coords
  # apart from static coords.
  _JAX_COORD_ATTR_NAME = '_jax_coord'
  def DataArray(  # pylint:disable=invalid-name
      data,
      coords=None,
      dims=None,
      name=None,
      attrs=None,
      jax_coords=None,
  ) -> xarray.DataArray:
    """Constructs an xarray.DataArray which may be backed by JAX arrays.

    Args:
      data: Anything xarray.DataArray accepts, plus jax arrays (which are
        wrapped rather than converted to numpy).
      coords: Regular coordinates, as for xarray.DataArray. They must be plain
        numpy arrays or convertible to them. To jax.tree_util they are a
        static part of the structure. Inside a JIT'd function they therefore
        remain plain numpy arrays, and a JIT'd function taking this DataArray
        is recompiled whenever these coordinate values change. See
        `jax_coords` if that is a problem.
      dims: See xarray.DataArray.
      name: See xarray.DataArray.
      attrs: See xarray.DataArray.
      jax_coords: Extra coordinates that *may* hold JAX arrays. jax.tree_util
        treats them as data leaves, so under JIT they are traced and any
        computation on them is compiled. The price is that they are never
        made into index coordinates, because xarray's indexing logic cannot
        be JIT'd. A jax coord named after a dimension therefore gets no index,
        unlike the `coords` default, and `.sel` on it will not work.
        These are deliberately a separate argument rather than inferred from
        jax arrays found in `coords`. How jax.tree_util treats a coordinate is
        largely independent of whether its value is numpy or JAX, and should
        be consistent, so we ask for explicit control.

    Returns:
      An xarray.DataArray. Any JAX arrays used for data or coords are wrapped
      in JaxArrayWrapper; use `unwrap` / `unwrap_data` to get them back.
    """
    constructor_kwargs = dict(dims=dims, name=name, attrs=attrs or {})
    array = xarray.DataArray(wrap(data), **constructor_kwargs)
    return assign_coords(array, coords=coords, jax_coords=jax_coords)
  def Dataset(  # pylint:disable=invalid-name
      data_vars,
      coords=None,
      attrs=None,
      jax_coords=None,
  ) -> xarray.Dataset:
    """Like xarray.Dataset, but can wrap JAX arrays.

    Args:
      data_vars: As for xarray.Dataset, except jax arrays are also supported.
        As with xarray.Dataset, None is accepted and means "no data variables".
        Values may be:
          * a jax.Array, which is wrapped;
          * a tuple laid out as (dims, data, ...), whose `data` element is
            wrapped if it is a jax.Array;
          * anything else (numpy array, scalar, xarray.Variable/DataArray),
            which is passed through unchanged. xarray objects are assumed to
            be wrapped already where needed, e.g. if created via
            xarray_jax.{Variable,DataArray}.
      coords: Coordinates for the dataset, see xarray.Dataset. These must be
        plain numpy arrays or convertible to them. To jax.tree_util their
        values are a static part of the data structure. Inside a JIT'd
        function they therefore remain plain numpy arrays, and the JIT'd
        function is recompiled under the hood if they change between calls.
        If this is not convenient for you, see also jax_coords below.
      attrs: See xarray.Dataset.
      jax_coords: Additional coordinates, which *can* use JAX arrays.
        jax.tree_util treats them as data, so when JIT'ing they are passed as
        tracers and computation involving them is JIT'd. The side-effect is
        that they cannot be index coordinates, because xarray's indexing logic
        is not JIT-able. A jax coord with the same name as a dimension is
        therefore not made an index coordinate, unlike the default for
        `coords`, and `.sel` based on it will not work.
        These must be passed explicitly rather than detected from jax arrays
        inside `coords`. How jax.tree_util treats a coordinate is somewhat
        orthogonal to whether its value is numpy or JAX, and should be
        handled consistently.

    Returns:
      An instance of xarray.Dataset. Where JAX arrays are used as data, they
      will be wrapped with JaxArrayWrapper.
    """
    if data_vars is None:
      # Mirror xarray.Dataset, which treats data_vars=None as empty.
      data_vars = {}
    wrapped_data_vars = {}
    for name, var_like in data_vars.items():
      # xarray.Dataset accepts a few different formats for data_vars:
      if isinstance(var_like, jax.Array):
        wrapped_data_vars[name] = wrap(var_like)
      elif isinstance(var_like, tuple):
        # Layout is (dims, data, ...). We wrap data.
        wrapped_data_vars[name] = (
            (var_like[0], wrap(var_like[1])) + var_like[2:])
      else:
        # Could be a plain numpy array or scalar (we don't wrap), or an
        # xarray.Variable, DataArray etc, which we must assume is already
        # wrapped if necessary (e.g. if created via
        # xarray_jax.{Variable,DataArray}).
        wrapped_data_vars[name] = var_like

    result = xarray.Dataset(
        data_vars=wrapped_data_vars,
        attrs=attrs)
    return assign_coords(result, coords=coords, jax_coords=jax_coords)
  DatasetOrDataArray = TypeVar(
      'DatasetOrDataArray', xarray.Dataset, xarray.DataArray)


  def assign_coords(
      x: DatasetOrDataArray,
      *,
      coords: Optional[Mapping[Hashable, Any]] = None,
      jax_coords: Optional[Mapping[Hashable, Any]] = None,
  ) -> DatasetOrDataArray:
    """Replacement for assign_coords which works in presence of jax_coords.

    `jax_coords` allow certain specified coordinates to have their data passed
    as JAX arrays (including through jax.jit boundaries). The compromise in
    return is that they are not created as index coordinates and cannot be used
    for .sel and other coordinate-based indexing operations. See docs for
    `jax_coords` on xarray_jax.Dataset and xarray_jax.DataArray for more
    information.

    This function can be used to set jax_coords on an existing DataArray or
    Dataset, and also to set a mix of jax and non-jax coordinates. It
    implements some workarounds to prevent xarray trying and failing to create
    IndexVariables from jax arrays under the hood.

    If you have any jax_coords with the same name as a dimension, you'll need
    to use this function instead of data_array.assign_coords or
    dataset.assign_coords in general, to avoid an xarray bug where it tries
    (and in our case fails) to create indexes for existing jax coords. See
    https://github.com/pydata/xarray/issues/7885.

    As with xarray's assign_coords, a name given here replaces any existing
    coordinate of the same name. This holds even when the kind changes, e.g.
    an existing jax coord replaced via `coords`, or an existing static coord
    replaced via `jax_coords`.

    Args:
      x: An xarray Dataset or DataArray.
      coords: Dict of (non-JAX) coords, or None if not assigning any.
      jax_coords: Dict of JAX coords, or None if not assigning any. See docs for
        xarray_jax.Dataset / DataArray for more information on jax_coords.

    Returns:
      The Dataset or DataArray with coordinates assigned, similarly to
      Dataset.assign_coords / DataArray.assign_coords.
    """
    coords = {} if coords is None else dict(coords)  # Copy before mutating.
    jax_coords = {} if jax_coords is None else dict(jax_coords)

    # Any existing JAX coords must be dropped and re-added via the workaround
    # below, since otherwise .assign_coords will trigger an xarray bug where
    # it tries to recreate the indexes again for the existing coordinates.
    # Can remove if/when https://github.com/pydata/xarray/issues/7885 fixed.
    existing_coord_vars = x.coords.variables
    existing_jax_coords = {
        name: coord_var for name, coord_var in existing_coord_vars.items()
        if coord_var.attrs.get(_JAX_COORD_ATTR_NAME, False)
        # An existing jax coord being replaced by a static coord must not be
        # re-added, or it would clash with the new one on renaming back.
        and name not in coords
    }
    jax_coords = existing_jax_coords | jax_coords
    # Drop every existing jax coord, plus any existing (static) coord being
    # replaced by a jax coord. The latter would otherwise clash with the
    # rename from __NONINDEX_{name} back to {name} below.
    names_to_drop = [
        name for name, coord_var in existing_coord_vars.items()
        if coord_var.attrs.get(_JAX_COORD_ATTR_NAME, False)
        or name in jax_coords
    ]
    x = x.drop_vars(names_to_drop)

    # We need to ensure that xarray doesn't try to create an index for
    # coordinates with the same name as a dimension, since this will fail if
    # given a wrapped JAX tracer.
    # It appears the only way to avoid this is to name them differently to
    # any dimension name, then rename them back afterwards.
    renamed_jax_coords = {}
    for name, coord in jax_coords.items():
      if isinstance(coord, xarray.DataArray):
        coord = coord.variable
      if isinstance(coord, xarray.Variable):
        coord = coord.copy(deep=False)  # Copy before mutating attrs.
      else:
        # Must wrap as Variable with the correct dims first if this has not
        # already been done, otherwise xarray.Dataset will assume the
        # dimension name is also __NONINDEX_{n}.
        coord = Variable((name,), coord)

      # We set an attr on each jax_coord identifying it as such. These attrs
      # on the coord Variable are reflected on the coord DataArray exposed
      # too. When set on coordinates they generally survive operations under
      # the default keep_attrs setting.
      # These attrs let the jax.tree_util flatten/unflatten functions tell
      # which coords are leaves of the flattened structure and which are
      # static data.
      coord.attrs[_JAX_COORD_ATTR_NAME] = True
      renamed_jax_coords[f'__NONINDEX_{name}'] = coord

    x = x.assign_coords(coords=coords | renamed_jax_coords)

    rename_back_mapping = {f'__NONINDEX_{name}': name for name in jax_coords}
    if isinstance(x, xarray.Dataset):
      # Using 'rename' doesn't work if renaming to the same name as a
      # dimension.
      return x.rename_vars(rename_back_mapping)
    else:  # DataArray
      return x.rename(rename_back_mapping)
  def assign_jax_coords(
      x: DatasetOrDataArray,
      jax_coords: Optional[Mapping[Hashable, Any]] = None,
      **jax_coords_kwargs
  ) -> DatasetOrDataArray:
    """Assigns only jax_coords, with same API as xarray's assign_coords.

    Coordinates may be given either as a mapping or as keyword arguments.

    Args:
      x: An xarray Dataset or DataArray.
      jax_coords: Mapping of coordinate name to (possibly JAX) value.
      **jax_coords_kwargs: Alternative keyword form of `jax_coords`.

    Returns:
      `x` with the jax coords assigned (see `assign_coords`).

    Raises:
      ValueError: If both a non-empty `jax_coords` mapping and keyword
        arguments are given. Previously the keyword arguments were silently
        ignored in that case; xarray's assign_coords rejects it the same way.
    """
    if jax_coords and jax_coords_kwargs:
      raise ValueError(
          'cannot specify both keyword and positional arguments to '
          'assign_jax_coords')
    return assign_coords(x, jax_coords=jax_coords or jax_coords_kwargs)
  def wrap(value):
    """Returns a JaxArrayWrapper around `value` if it is a jax.Array.

    Any other value is returned unchanged, so this is safe to call on numpy
    arrays, scalars, etc.
    """
    if not isinstance(value, jax.Array):
      return value
    return JaxArrayWrapper(value)
  def unwrap(value, require_jax=False):
    """Inverse of `wrap`: strips a JaxArrayWrapper, passing other values through.

    Args:
      value: A JaxArrayWrapper, a raw jax.Array, or any other value.
      require_jax: If true, anything that is neither a JaxArrayWrapper nor a
        jax.Array is rejected.

    Returns:
      The underlying jax array for a wrapper, otherwise `value` itself.

    Raises:
      TypeError: If `require_jax` is set and `value` is not JAX-backed.
    """
    if isinstance(value, JaxArrayWrapper):
      return value.jax_array
    if require_jax and not isinstance(value, jax.Array):
      raise TypeError(f'Expected JAX array, found {type(value)}.')
    return value
  def _wrapped(func):
    """Adapts `func` so it sees raw JAX arrays and returns wrapped ones.

    Every leaf of the positional and keyword arguments is passed through
    `unwrap` before calling `func`. Every leaf of its result is passed through
    `wrap`. Both steps use dm-tree's map_structure.
    """
    def wrapped_func(*args, **kwargs):
      raw_args, raw_kwargs = tree.map_structure(unwrap, (args, kwargs))
      raw_result = func(*raw_args, **raw_kwargs)
      return tree.map_structure(wrap, raw_result)
    return wrapped_func
  def unwrap_data(
      value: Union[xarray.Variable, xarray.DataArray],
      require_jax: bool = False
  ) -> Union[jax.Array, np.ndarray]:
    """Returns the `.data` of an xarray.Variable or DataArray, unwrapped.

    See `unwrap` for what unwrapping does and for the meaning of `require_jax`.
    """
    data = value.data
    return unwrap(data, require_jax=require_jax)
  def unwrap_vars(
      dataset: Mapping[Hashable, xarray.DataArray],
      require_jax: bool = False
  ) -> Mapping[str, Union[jax.Array, np.ndarray]]:
    """Maps each variable name in `dataset` to its unwrapped data.

    See `unwrap_data` for the per-variable behaviour and `require_jax`.
    """
    unwrapped = {}
    for var_name, data_array in dataset.items():
      # xarray types names as Hashable, but they are practically always
      # strings; converting with str() gives callers a more useful key type.
      key = str(var_name)
      unwrapped[key] = unwrap_data(data_array, require_jax=require_jax)
    return unwrapped
  def unwrap_coords(
      dataset: Union[xarray.Dataset, xarray.DataArray],
      require_jax: bool = False
  ) -> Mapping[str, Union[jax.Array, np.ndarray]]:
    """Maps each coordinate name of a Dataset/DataArray to its unwrapped data.

    See `unwrap_data` for the per-coordinate behaviour and `require_jax`.
    """
    return dict(
        (str(coord_name), unwrap_data(coord, require_jax=require_jax))
        for coord_name, coord in dataset.coords.items())
  def jax_data(value: Union[xarray.Variable, xarray.DataArray]) -> jax.Array:
    """Strict form of `unwrap_data`: raises unless the data is a JAX array."""
    # Kept as its own function (rather than a require_jax flag) so that its
    # return type can be the narrower jax.Array.
    unwrapped = unwrap_data(value, require_jax=True)
    return cast(jax.Array, unwrapped)
  def jax_vars(
      dataset: Mapping[Hashable, xarray.DataArray]) -> Mapping[str, jax.Array]:
    """Strict form of `unwrap_vars`: raises unless every var is a JAX array."""
    unwrapped = unwrap_vars(dataset, require_jax=True)
    return cast(Mapping[str, jax.Array], unwrapped)
  class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin):
    """Wraps a JAX array into a duck-typed array suitable for use with xarray.

    This uses an older duck-typed array protocol based on __array_ufunc__ and
    __array_function__ which works with numpy and xarray. This is in the
    process of being superseded by the Python array API standard
    (https://data-apis.org/array-api/latest/index.html), but JAX and xarray
    haven't implemented it yet. Once they have, we should be able to get rid of
    this wrapper and use JAX arrays directly with xarray.

    NumPy functions and ufuncs are dispatched to the jax.numpy function with
    the same name. If there is none, or it is not callable (e.g. it names a
    jax.numpy submodule such as `fft`), NotImplemented is returned, so NumPy
    raises its usual TypeError.
    """

    def __init__(self, jax_array):
      self.jax_array = jax_array

    def __array_ufunc__(self, ufunc, method, *args, **kwargs):
      for x in args:
        if not isinstance(x, (jax.typing.ArrayLike, type(self))):
          return NotImplemented
      if method != '__call__':
        return NotImplemented
      # Get the corresponding jax.numpy function to the NumPy ufunc.
      func = getattr(jnp, ufunc.__name__, None)
      if not callable(func):
        return NotImplemented
      # There may be an 'out' kwarg requesting an in-place operation, e.g.
      # when this is called via __iadd__ (+=), __imul__ (*=) etc. JAX doesn't
      # support in-place operations so we just remove this argument and have
      # the ufunc return a fresh JAX array instead.
      kwargs.pop('out', None)
      return _wrapped(func)(*args, **kwargs)

    def __array_function__(self, func, types, args, kwargs):
      # Get the corresponding jax.numpy function to the NumPy function. The
      # callable() check guards against names that resolve to a jax.numpy
      # submodule (e.g. np.fft.fft -> 'fft' -> jnp.fft), which would otherwise
      # fail with an obscure "'module' object is not callable" error.
      jax_func = getattr(jnp, func.__name__, None)
      if not callable(jax_func):
        return NotImplemented
      return _wrapped(jax_func)(*args, **kwargs)

    def __repr__(self):
      return f'xarray_jax.JaxArrayWrapper({repr(self.jax_array)})'

    # NDArrayOperatorsMixin already proxies most __dunder__ operator methods.
    # We need to proxy through a few more methods in a similar way:

    # Essential array properties:

    @property
    def shape(self):
      return self.jax_array.shape

    @property
    def dtype(self):
      return self.jax_array.dtype

    @property
    def ndim(self):
      return self.jax_array.ndim

    @property
    def size(self):
      return self.jax_array.size

    # Array methods not covered by NDArrayOperatorsMixin:

    # Allows conversion to numpy array using np.asarray etc. Warning: doing
    # this will fail in a jax.jit-ed function.
    def __array__(self, dtype=None, context=None, copy=None):
      # NumPy 2 passes `copy` and warns (DeprecationWarning) if __array__ does
      # not accept it. `copy=True` must always return a fresh array. Otherwise
      # we let np.asarray avoid a copy where it can. `context` is unused.
      del context
      if copy:
        return np.array(self.jax_array, dtype=dtype, copy=True)
      return np.asarray(self.jax_array, dtype=dtype)

    __getitem__ = _wrapped(lambda array, *args: array.__getitem__(*args))
    # We drop the kwargs on this as they are not supported by JAX, but xarray
    # uses at least one of them (the copy arg).
    astype = _wrapped(lambda array, *args, **kwargs: array.astype(*args))

    # There are many more methods which are more canonically available via
    # (j)np functions, e.g. .sum() available via jnp.sum, and also mean, max,
    # min, argmax, argmin etc. We don't attempt to proxy through all of these
    # as methods, since this doesn't appear to be expected from a duck-typed
    # array implementation. But there are a few which xarray calls as methods,
    # so we proxy those:
    transpose = _wrapped(jnp.transpose)
    reshape = _wrapped(jnp.reshape)
    all = _wrapped(jnp.all)
  def apply_ufunc(func, *args, require_jax=False, **apply_ufunc_kwargs):
    """Variant of xarray.apply_ufunc for functions that need raw jax arrays.

    Because JaxArrayWrapper behaves (mostly) like a numpy array and routes
    numpy operations to jax, plain xarray.apply_ufunc already works for many
    numpy ufuncs. Use this helper instead when `func` is jax-specific. It strips
    any JaxArrayWrapper from the inputs before `func` runs and wraps the
    outputs again before they go back to xarray.

    Args:
      func: A function operating on jax arrays (e.g. built from jax.numpy)
        that otherwise satisfies xarray.apply_ufunc's requirements for `func`.
      *args: xarray arguments that are mapped to `func`'s arguments (see
        xarray.apply_ufunc).
      require_jax: If true, inputs must be jax-backed; if false, plain numpy
        inputs are allowed too.
      **apply_ufunc_kwargs: Forwarded to xarray.apply_ufunc.

    Returns:
      Whatever xarray.apply_ufunc returns for these arguments.
    """
    def wrapped_func(*maybe_wrapped_args):
      raw_inputs = tuple(unwrap(arg, require_jax) for arg in maybe_wrapped_args)
      outputs = func(*raw_inputs)
      # `outputs` may be one array or a tuple of arrays; tree_map covers both.
      return jax.tree_util.tree_map(wrap, outputs)

    return xarray.apply_ufunc(wrapped_func, *args, **apply_ufunc_kwargs)
  def pmap(fn: Callable[..., Any],
           dim: str,
           axis_name: Optional[str] = None,
           devices: ... = None,
           backend: ... = None) -> Callable[..., Any]:
    """Wraps a subset of jax.pmap functionality to handle xarray input/output.

    Constraints:
      * Any Dataset or DataArray passed to the function must have `dim` as the
        first dimension. This will be checked. You can ensure this if necessary
        by calling `.transpose(dim, ...)` beforehand.
      * All args and return values will be mapped over the first dimension,
        it will use in_axes=0, out_axes=0.
      * No support for static_broadcasted_argnums, donate_argnums etc.

    The tree structure of the arguments (which includes dims and static
    coordinates) is passed to jax.pmap as a static argument. Calls whose
    arguments differ in structure or static coords therefore trigger a retrace
    and recompile. A later call can never reuse an earlier call's stale output
    structure.

    Args:
      fn: Function to be pmap'd which takes and returns trees which may contain
        xarray Dataset/DataArray. Any Dataset/DataArrays passed as input must
        use `dim` as the first dimension on all arrays.
      dim: The xarray dimension name corresponding to the first dimension that
        is pmapped over (pmap is called with in_axes=0, out_axes=0).
      axis_name: Used by jax to identify the mapped axis so that parallel
        collectives can be applied. Defaults to same as `dim`.
      devices:
      backend:
        See jax.pmap.

    Returns:
      A pmap'd version of `fn`, which takes and returns Dataset/DataArray with
      an extra leading dimension `dim` relative to what the original `fn` sees.
    """
    # Output treedefs recorded while tracing, keyed by the input treedef that
    # produced them. Because the input treedef is a static pmap argument, each
    # distinct input structure is traced (and recorded here) on first use.
    # A later cache hit on that structure then finds its own entry.
    output_treedefs = {}

    def check_and_remove_leading_dim(dims):
      # Inside the pmap the original first dimension will no longer be present.
      try:
        index = dims.index(dim)
      except ValueError:
        index = None
      if index != 0:
        raise ValueError(f'Expected dim {dim} at index 0, found at {index}.')
      return dims[1:]

    def fn_passed_to_pmap(input_treedef, *flat_args):
      with dims_change_on_unflatten(check_and_remove_leading_dim):
        args = jax.tree_util.tree_unflatten(input_treedef, flat_args)
      result = fn(*args)
      flat_result, output_treedef = jax.tree_util.tree_flatten(result)
      output_treedefs[input_treedef] = output_treedef
      return flat_result

    pmapped_fn = jax.pmap(
        fn_passed_to_pmap,
        axis_name=axis_name or dim,
        in_axes=0,
        out_axes=0,
        # The treedef is hashable (static coords are stored hashably for
        # jax.jit), so it can serve as a static argument.
        static_broadcasted_argnums=(0,),
        devices=devices,
        backend=backend)

    def result_fn(*args):
      flat_args, input_treedef = jax.tree_util.tree_flatten(args)
      flat_result = pmapped_fn(input_treedef, *flat_args)
      output_treedef = output_treedefs[input_treedef]
      # After the pmap an extra leading axis will be present, we need to add
      # an xarray dimension for this when unflattening the result:
      with dims_change_on_unflatten(lambda dims: (dim,) + dims):
        return jax.tree_util.tree_unflatten(output_treedef, flat_result)

    return result_fn
  # Register xarray datatypes with jax.tree_util.

  # Signature of a dims change function: original dims tuple -> new dims tuple.
  DimsChangeFn = Callable[[Tuple[Hashable, ...]], Tuple[Hashable, ...]]

  # Holds the dims change function active in the current context, if any. It
  # has no default value, so readers call .get(None) to detect "not set".
  _DIMS_CHANGE_ON_UNFLATTEN_FN: contextvars.ContextVar[DimsChangeFn] = (
      contextvars.ContextVar('dims_change_on_unflatten_fn'))


  @contextlib.contextmanager
  def dims_change_on_unflatten(dims_change_fn: DimsChangeFn):
    """Temporarily alters the dims used when unflattening arrays into xarrays.

    Use this when the underlying jax arrays gained or lost axes after
    jax.tree_util.tree_flatten, and you want to unflatten them again using the
    original treedef adjusted for those axes. It also works with
    jax.tree_util.tree_map called with a function that adds or removes axes,
    or reorders them.

    When dimensions are removed, coordinates that use any of the removed
    dimensions are also removed on unflatten.

    The treedef itself cannot be edited to change its dims/coords (and
    tree_map never exposes it). So this is a context manager that sets
    context-local state (a contextvars.ContextVar) consulted by our unflatten
    functions. Contexts nest: on exit the previously active function, if any,
    is restored.

    Args:
      dims_change_fn: Takes the dims tuple of the Variable/DataArray/Dataset
        as it was flattened, and returns the dims tuple to use when
        unflattening.

    Yields:
      Nothing. Inside the `with` block, jax.tree_util.tree_unflatten and
      jax.tree_util.tree_map apply `dims_change_fn` before rebuilding xarrays
      from jax arrays.
    """
    reset_token = _DIMS_CHANGE_ON_UNFLATTEN_FN.set(dims_change_fn)
    try:
      yield
    finally:
      _DIMS_CHANGE_ON_UNFLATTEN_FN.reset(reset_token)
  def _flatten_variable(v: xarray.Variable) -> Tuple[
      Tuple[jax.typing.ArrayLike], Tuple[Hashable, ...]]:
    """jax.tree_util flatten for Variable: (unwrapped data,) leaves, dims aux."""
    return (unwrap_data(v),), v.dims
  def _unflatten_variable(
      aux: Tuple[Hashable, ...],
      children: Tuple[jax.typing.ArrayLike]) -> xarray.Variable:
    """jax.tree_util unflatten for Variable; inverse of the flatten step.

    `aux` holds the original dims. If a dims change function is active (see
    dims_change_on_unflatten), it is applied to them first.
    """
    active_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
    new_dims = active_change_fn(aux) if active_change_fn else aux
    return Variable(dims=new_dims, data=children[0])
  def _split_static_and_jax_coords(
      coords: xarray.core.coordinates.Coordinates) -> Tuple[
          Mapping[Hashable, xarray.Variable], Mapping[Hashable, xarray.Variable]]:
    """Partitions coordinates into (static coord vars, jax coord vars).

    A coordinate counts as a jax coord when its attrs carry a truthy
    _JAX_COORD_ATTR_NAME marker (as set by assign_coords); every other
    coordinate is static.

    Args:
      coords: The coordinates of a Dataset or DataArray.

    Returns:
      Two dicts mapping coordinate name to its xarray.Variable: the static
      coords and the jax coords, in that order.

    Raises:
      ValueError: If a coordinate without the jax-coord marker holds JAX data.
        Static coords must be plain (numpy-convertible) data. The previous
        assertion tested the DataArray object itself rather than its data, so
        it could never fire.
    """
    static_coord_vars = {}
    jax_coord_vars = {}
    for name, coord in coords.items():
      coord_var = coord.variable
      if coord_var.attrs.get(_JAX_COORD_ATTR_NAME, False):
        jax_coord_vars[name] = coord_var
      elif isinstance(coord_var.data, (jax.Array, JaxArrayWrapper)):
        raise ValueError(
            f'Coordinate {name!r} holds a JAX array but was not assigned as a '
            'jax coord; use xarray_jax.assign_coords(..., jax_coords=...) for '
            'JAX-valued coordinates.')
      else:
        static_coord_vars[name] = coord_var
    return static_coord_vars, jax_coord_vars
  549. def _drop_with_none_of_dims(
  550. coord_vars: Mapping[Hashable, xarray.Variable],
  551. dims: Tuple[Hashable]) -> Mapping[Hashable, xarray.Variable]:
  552. return {name: var for name, var in coord_vars.items()
  553. if set(var.dims) <= set(dims)}
  554. class _HashableCoords(collections.abc.Mapping):
  555. """Wraps a dict of xarray Variables as hashable, used for static coordinates.
  556. This needs to be hashable so that when an xarray.Dataset is passed to a
  557. jax.jit'ed function, jax can check whether it's seen an array with the
  558. same static coordinates(*) before or whether it needs to recompile the
  559. function for the new values of the static coordinates.
  560. (*) note jax_coords are not included in this; their value can be different
  561. on different calls without triggering a recompile.
  562. """
  563. def __init__(self, coord_vars: Mapping[Hashable, xarray.Variable]):
  564. self._variables = coord_vars
  565. def __repr__(self) -> str:
  566. return f'_HashableCoords({repr(self._variables)})'
  567. def __getitem__(self, key: Hashable) -> xarray.Variable:
  568. return self._variables[key]
  569. def __len__(self) -> int:
  570. return len(self._variables)
  571. def __iter__(self) -> Iterator[Hashable]:
  572. return iter(self._variables)
  573. def __hash__(self):
  574. if not hasattr(self, '_hash'):
  575. self._hash = hash(frozenset((name, var.data.tobytes())
  576. for name, var in self._variables.items()))
  577. return self._hash
  578. def __eq__(self, other):
  579. if self is other:
  580. return True
  581. elif not isinstance(other, type(self)):
  582. return NotImplemented
  583. elif self._variables is other._variables:
  584. return True
  585. else:
  586. return self._variables.keys() == other._variables.keys() and all(
  587. variable.equals(other._variables[name])
  588. for name, variable in self._variables.items())
  589. def _flatten_data_array(v: xarray.DataArray) -> Tuple[
  590. # Children (data variable, jax_coord_vars):
  591. Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
  592. # Static auxiliary data (name, static_coord_vars):
  593. Tuple[Optional[Hashable], _HashableCoords]]:
  594. """Flattens a DataArray for jax.tree_util."""
  595. static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(v.coords)
  596. children = (v.variable, jax_coord_vars)
  597. aux = (v.name, _HashableCoords(static_coord_vars))
  598. return children, aux
  599. def _unflatten_data_array(
  600. aux: Tuple[Optional[Hashable], _HashableCoords],
  601. children: Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
  602. ) -> xarray.DataArray:
  603. """Unflattens a DataArray for jax.tree_util."""
  604. variable, jax_coord_vars = children
  605. name, static_coord_vars = aux
  606. # Drop static coords which have dims not present in any of the data_vars.
  607. # These would generally be dims that were dropped by a dims_change_fn, but
  608. # because static coordinates don't go through dims_change_fn on unflatten, we
  609. # just drop them where this causes a problem.
  610. # Since jax_coords go through the dims_change_fn on unflatten we don't need
  611. # to do this for jax_coords.
  612. static_coord_vars = _drop_with_none_of_dims(static_coord_vars, variable.dims)
  613. return DataArray(
  614. variable, name=name, coords=static_coord_vars, jax_coords=jax_coord_vars)
  615. def _flatten_dataset(dataset: xarray.Dataset) -> Tuple[
  616. # Children (data variables, jax_coord_vars):
  617. Tuple[Mapping[Hashable, xarray.Variable],
  618. Mapping[Hashable, xarray.Variable]],
  619. # Static auxiliary data (static_coord_vars):
  620. _HashableCoords]:
  621. """Flattens a Dataset for jax.tree_util."""
  622. variables = {name: data_array.variable
  623. for name, data_array in dataset.data_vars.items()}
  624. static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(
  625. dataset.coords)
  626. children = (variables, jax_coord_vars)
  627. aux = _HashableCoords(static_coord_vars)
  628. return children, aux
  629. def _unflatten_dataset(
  630. aux: _HashableCoords,
  631. children: Tuple[Mapping[Hashable, xarray.Variable],
  632. Mapping[Hashable, xarray.Variable]],
  633. ) -> xarray.Dataset:
  634. """Unflattens a Dataset for jax.tree_util."""
  635. data_vars, jax_coord_vars = children
  636. static_coord_vars = aux
  637. dataset = xarray.Dataset(data_vars)
  638. # Drop static coords which have dims not present in any of the data_vars.
  639. # See corresponding comment in _unflatten_data_array.
  640. static_coord_vars = _drop_with_none_of_dims(static_coord_vars, dataset.dims)
  641. return assign_coords(
  642. dataset, coords=static_coord_vars, jax_coords=jax_coord_vars)
  643. jax.tree_util.register_pytree_node(
  644. xarray.Variable, _flatten_variable, _unflatten_variable)
  645. # This is a subclass of Variable but still needs registering separately.
  646. # Flatten/unflatten for IndexVariable is a bit of a corner case but we do
  647. # need to support it.
  648. jax.tree_util.register_pytree_node(
  649. xarray.IndexVariable, _flatten_variable, _unflatten_variable)
  650. jax.tree_util.register_pytree_node(
  651. xarray.DataArray, _flatten_data_array, _unflatten_data_array)
  652. jax.tree_util.register_pytree_node(
  653. xarray.Dataset, _flatten_dataset, _unflatten_dataset)