123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527 |
- # Copyright 2023 DeepMind Technologies Limited.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS-IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Tests for xarray_jax."""
- from absl.testing import absltest
- import chex
- from graphcast import xarray_jax
- import jax
- import jax.numpy as jnp
- import numpy as np
- import xarray
class XarrayJaxTest(absltest.TestCase):
    """Tests that xarray_jax keeps data in JAX through xarray operations.

    Covers the numpy-API array wrapper, Variable / DataArray / Dataset
    construction, use of these types as arguments and return values of
    jax.jit, jax.grad and xarray_jax.pmap, the flatten / unflatten helpers,
    and xarray_jax.apply_ufunc.
    """

    # ------------------------------------------------------------------
    # Shared helpers.
    # ------------------------------------------------------------------

    def _assert_identical_on_host(self, expected, actual):
        """Fetches both arguments with jax.device_get and compares them."""
        xarray.testing.assert_identical(
            jax.device_get(expected), jax.device_get(actual))

    def _jitted_add_lon_and_bump_coord(self):
        """Returns a fresh jitted function used by the jax_coords jit tests.

        Inside the jitted body it asserts how the 'lat' and 'lon' coordinates
        are presented, adds the 'lon' coordinate to the argument, and returns
        the sum with its 'lon' jax coordinate incremented by one.
        """

        @jax.jit
        def fn(v):
            # The non-JAX coord is passed with numpy array data and an index:
            self.assertIsInstance(v.coords['lat'].data, np.ndarray)
            self.assertIn('lat', v.indexes)
            # The jax_coord is passed with JAX array data and no index:
            self.assertIsInstance(
                v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
            self.assertNotIn('lon', v.indexes)
            # Use the jax coord in the computation:
            summed = v + v.coords['lon']
            # Return with an updated jax coord:
            return xarray_jax.assign_jax_coords(
                summed, lon=summed.coords['lon'] + 1)

        return fn

    # ------------------------------------------------------------------
    # Staying in JAX-land through numpy / xarray operations.
    # ------------------------------------------------------------------

    def test_jax_array_wrapper_with_numpy_api(self):
        """A JaxArrayWrapper stays a wrapper through numpy-API calls."""
        # This is just a side benefit of making things work with xarray, but
        # the JaxArrayWrapper does allow you to manipulate JAX arrays using
        # the standard numpy API, without converting them to numpy:
        wrapped = xarray_jax.JaxArrayWrapper(
            jnp.ones((3, 4), dtype=np.float32))
        wrapped = np.abs((wrapped + 2) * (wrapped - 3))
        wrapped = wrapped[:-1, 1:3]
        wrapped = np.concatenate([wrapped, wrapped + 1], axis=0)
        wrapped = np.transpose(wrapped, (1, 0))
        wrapped = np.reshape(wrapped, (-1,))
        wrapped = wrapped.astype(np.int32)
        self.assertIsInstance(wrapped, xarray_jax.JaxArrayWrapper)

        # An explicit conversion gets us out of JAX-land however:
        converted = np.asarray(wrapped)
        self.assertIsInstance(converted, np.ndarray)

    def test_jax_xarray_variable(self):
        """Variable ops return a jax.Array and work under jit and grad."""

        def ops_via_xarray(inputs):
            var = xarray_jax.Variable(('lat', 'lon'), inputs)
            # Apply a sequence of operations just to test that the end result
            # is still a JAX array, i.e. nothing converted to numpy en route.
            var = np.abs((var + 2) * (var - 3))
            var = var.isel({'lat': slice(0, -1), 'lon': slice(1, 3)})
            var = xarray.Variable.concat([var, var + 1], dim='lat')
            var = var.transpose('lon', 'lat')
            var = var.stack(channels=('lon', 'lat'))
            return xarray_jax.jax_data(var.sum())

        # Check it doesn't leave jax-land when passed concrete values:
        ones = jnp.ones((3, 4), dtype=np.float32)
        self.assertIsInstance(ops_via_xarray(ones), jax.Array)

        # And that you can JIT it and compute gradients through it. These
        # will involve passing jax tracers through the xarray computation:
        jax.jit(ops_via_xarray)(ones)
        jax.grad(ops_via_xarray)(ones)

    def test_jax_xarray_data_array(self):
        """DataArray ops return a jax.Array and work under jit and grad."""
        # `ones` is the test input and is also read, via closure, as the data
        # of the second array concatenated inside ops_via_xarray.
        ones = jnp.ones((3, 4), dtype=np.float32)

        def ops_via_xarray(inputs):
            lower = xarray_jax.DataArray(
                dims=('lat', 'lon'),
                data=inputs,
                coords={'lat': np.arange(3) * 10,
                        'lon': np.arange(4) * 10})
            lower = np.abs((lower + 2) * (lower - 3))
            lower = lower.sel({'lat': slice(0, 20)})
            upper = xarray_jax.DataArray(
                dims=('lat', 'lon'),
                data=ones,
                coords={'lat': np.arange(3, 6) * 10,
                        'lon': np.arange(4) * 10})
            combined = xarray.concat([lower, upper], dim='lat')
            combined = combined.transpose('lon', 'lat')
            combined = combined.stack(channels=('lon', 'lat'))
            combined = combined.unstack()
            return xarray_jax.jax_data(combined.sum())

        self.assertIsInstance(ops_via_xarray(ones), jax.Array)
        jax.jit(ops_via_xarray)(ones)
        jax.grad(ops_via_xarray)(ones)

    def test_jax_xarray_dataset(self):
        """Dataset ops return a jax.Array and work under jit and grad."""

        def ops_via_xarray(foo, bar):
            def make_dataset(lat):
                # Same variables and time/lon coords; only 'lat' varies.
                return xarray_jax.Dataset(
                    data_vars={'foo': (('lat', 'lon'), foo),
                               'bar': (('time', 'lat', 'lon'), bar)},
                    coords={'time': np.arange(2),
                            'lat': lat,
                            'lon': np.arange(4) * 10})

            first = make_dataset(np.arange(3) * 10)
            first = np.abs((first + 2) * (first - 3))
            first = first.sel({'lat': slice(0, 20)})
            second = make_dataset(np.arange(3, 6) * 10)
            merged = xarray.concat([first, second], dim='lat')
            merged = merged.transpose('lon', 'lat', 'time')
            merged = merged.stack(channels=('lon', 'lat'))
            total = (merged.foo + merged.bar).sum()
            return xarray_jax.jax_data(total)

        foo = jnp.ones((3, 4), dtype=np.float32)
        bar = jnp.ones((2, 3, 4), dtype=np.float32)
        self.assertIsInstance(ops_via_xarray(foo, bar), jax.Array)
        jax.jit(ops_via_xarray)(foo, bar)
        jax.grad(ops_via_xarray)(foo, bar)

    # ------------------------------------------------------------------
    # jit / grad with xarray types as arguments and return values.
    # ------------------------------------------------------------------

    def test_jit_function_with_xarray_variable_arguments_and_return(self):
        """A jitted function accepts and returns Variables, twice over."""
        add_one = jax.jit(lambda v: v + 1)

        # Inputs are built lazily so construction happens inside the subTest.
        input_builders = {
            'jax input': lambda: xarray_jax.Variable(
                ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)),
            'numpy input': lambda: xarray.Variable(
                ('lat', 'lon'), np.ones((3, 4), dtype=np.float32)),
        }
        for label, build_inputs in input_builders.items():
            with self.subTest(label):
                inputs = build_inputs()
                _ = add_one(inputs)
                # We run the jitted function a second time, to exercise logic
                # in jax which checks if the structure of the inputs
                # (including dimension names and coordinates) is the same as
                # it was for the previous call and so whether it needs to
                # re-trace-and-compile a new version of the function or not.
                # This can run into problems if the 'aux' structure returned
                # by the registered flatten function is not
                # hashable/comparable.
                outputs = add_one(inputs)
                self.assertEqual(outputs.dims, inputs.dims)

    def test_jit_problem_if_convert_to_plain_numpy_array(self):
        """Reading .values inside jit raises TracerArrayConversionError."""
        inputs = xarray_jax.DataArray(
            data=jnp.ones((2,), dtype=np.float32), dims=('foo',))
        # Calling .values on a DataArray converts its values to numpy:
        read_values = jax.jit(lambda data_array: data_array.values)
        with self.assertRaises(jax.errors.TracerArrayConversionError):
            read_values(inputs)

    def test_grad_function_with_xarray_variable_arguments(self):
        """jax.grad runs on a function that takes a Variable argument."""
        variable = xarray_jax.Variable(
            ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))

        # For grad we still need a JAX scalar as the output:
        def scalar_sum(v):
            return xarray_jax.jax_data(v.sum())

        jax.grad(scalar_sum)(variable)

    def test_jit_function_with_xarray_data_array_arguments_and_return(self):
        """A jitted function preserves DataArray dims and coords."""
        inputs = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=('lat', 'lon'),
            coords={'lat': np.arange(3),
                    'lon': np.arange(4) * 10})
        add_one = jax.jit(lambda v: v + 1)
        _ = add_one(inputs)
        outputs = add_one(inputs)
        self.assertEqual(outputs.dims, inputs.dims)
        chex.assert_trees_all_equal(outputs.coords, inputs.coords)

    def test_jit_function_with_data_array_and_jax_coords(self):
        """jax_coords on a DataArray pass through jit as JAX data."""
        inputs = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=('lat', 'lon'),
            coords={'lat': np.arange(3)},
            jax_coords={'lon': jnp.arange(4) * 10})
        # Verify the jax_coord 'lon' retains jax data, and has not been
        # created as an index coordinate:
        self.assertIsInstance(
            inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn('lon', inputs.indexes)

        fn = self._jitted_add_lon_and_bump_coord()
        _ = fn(inputs)
        outputs = fn(inputs)

        # Verify the jax_coord 'lon' has jax data in the output too:
        self.assertIsInstance(
            outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn('lon', outputs.indexes)

        self.assertEqual(outputs.dims, inputs.dims)
        chex.assert_trees_all_equal(
            outputs.coords['lat'], inputs.coords['lat'])
        # Check our computations with the coordinate values worked:
        chex.assert_trees_all_equal(
            outputs.coords['lon'].data, (inputs.coords['lon'] + 1).data)
        chex.assert_trees_all_equal(
            outputs.data, (inputs + inputs.coords['lon']).data)

    def test_jit_function_with_xarray_dataset_arguments_and_return(self):
        """A jitted function preserves Dataset variables, dims and coords."""
        inputs = xarray_jax.Dataset(
            data_vars={
                'foo': (('lat', 'lon'),
                        jnp.ones((3, 4), dtype=np.float32)),
                'bar': (('time', 'lat', 'lon'),
                        jnp.ones((2, 3, 4), dtype=np.float32)),
            },
            coords={'time': np.arange(2),
                    'lat': np.arange(3) * 10,
                    'lon': np.arange(4) * 10})
        add_one = jax.jit(lambda v: v + 1)
        _ = add_one(inputs)
        outputs = add_one(inputs)
        self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys())
        self.assertEqual(inputs.foo.dims, outputs.foo.dims)
        self.assertEqual(inputs.bar.dims, outputs.bar.dims)
        chex.assert_trees_all_equal(outputs.coords, inputs.coords)

    def test_jit_function_with_dataset_and_jax_coords(self):
        """jax_coords on a Dataset pass through jit as JAX data."""
        inputs = xarray_jax.Dataset(
            data_vars={
                'foo': (('lat', 'lon'),
                        jnp.ones((3, 4), dtype=np.float32)),
                'bar': (('time', 'lat', 'lon'),
                        jnp.ones((2, 3, 4), dtype=np.float32)),
            },
            coords={'time': np.arange(2),
                    'lat': np.arange(3) * 10},
            jax_coords={'lon': jnp.arange(4) * 10})
        # Verify the jax_coord 'lon' retains jax data, and has not been
        # created as an index coordinate:
        self.assertIsInstance(
            inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn('lon', inputs.indexes)

        fn = self._jitted_add_lon_and_bump_coord()
        _ = fn(inputs)
        outputs = fn(inputs)

        # Verify the jax_coord 'lon' has jax data in the output too:
        self.assertIsInstance(
            outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
        self.assertNotIn('lon', outputs.indexes)

        self.assertEqual(outputs.dims, inputs.dims)
        chex.assert_trees_all_equal(
            outputs.coords['lat'], inputs.coords['lat'])
        # Check our computations with the coordinate values worked:
        chex.assert_trees_all_equal(
            (outputs.coords['lon']).data,
            (inputs.coords['lon'] + 1).data,
        )
        expected = inputs + inputs.coords['lon']
        actual_by_name = {key: outputs[key].data for key in outputs}
        expected_by_name = {key: expected[key].data for key in expected}
        chex.assert_trees_all_equal(actual_by_name, expected_by_name)

    # ------------------------------------------------------------------
    # Flatten / unflatten and tree_map behaviour.
    # ------------------------------------------------------------------

    def test_flatten_unflatten_variable(self):
        """Variable flatten aux is hashable and unflatten round-trips."""
        variable = xarray_jax.Variable(
            ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
        children, aux = xarray_jax._flatten_variable(variable)
        # Check auxiliary info is hashable/comparable (important for jit):
        hash(aux)
        self.assertEqual(aux, aux)
        restored = xarray_jax._unflatten_variable(aux, children)
        self.assertTrue(variable.equals(restored))

    def test_flatten_unflatten_data_array(self):
        """DataArray flatten aux is hashable and unflatten round-trips."""
        data_array = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=('lat', 'lon'),
            coords={'lat': np.arange(3)},
            jax_coords={'lon': np.arange(4) * 10},
        )
        children, aux = xarray_jax._flatten_data_array(data_array)
        # Check auxiliary info is hashable/comparable (important for jit):
        hash(aux)
        self.assertEqual(aux, aux)
        restored = xarray_jax._unflatten_data_array(aux, children)
        self.assertTrue(data_array.equals(restored))

    def test_flatten_unflatten_dataset(self):
        """Dataset flatten aux is hashable and unflatten round-trips."""
        dataset = xarray_jax.Dataset(
            data_vars={
                'foo': (('lat', 'lon'),
                        jnp.ones((3, 4), dtype=np.float32)),
                'bar': (('time', 'lat', 'lon'),
                        jnp.ones((2, 3, 4), dtype=np.float32)),
            },
            coords={'time': np.arange(2),
                    'lat': np.arange(3) * 10},
            jax_coords={'lon': np.arange(4) * 10})
        children, aux = xarray_jax._flatten_dataset(dataset)
        # Check auxiliary info is hashable/comparable (important for jit):
        hash(aux)
        self.assertEqual(aux, aux)
        restored = xarray_jax._unflatten_dataset(aux, children)
        self.assertTrue(dataset.equals(restored))

    def test_flatten_unflatten_added_dim(self):
        """dims_change_on_unflatten can prepend a dim on tree_unflatten."""
        data_array = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=('lat', 'lon'),
            coords={'lat': np.arange(3),
                    'lon': np.arange(4) * 10})
        leaves, treedef = jax.tree_util.tree_flatten(data_array)
        leaves = [jnp.expand_dims(leaf, 0) for leaf in leaves]
        with xarray_jax.dims_change_on_unflatten(
                lambda dims: ('new',) + dims):
            with_new_dim = jax.tree_util.tree_unflatten(treedef, leaves)
        self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
        self._assert_identical_on_host(
            data_array, with_new_dim.isel(new=0))

    def test_map_added_dim(self):
        """dims_change_on_unflatten can prepend a dim during tree_map."""
        data_array = xarray_jax.DataArray(
            data=jnp.ones((3, 4), dtype=np.float32),
            dims=('lat', 'lon'),
            coords={'lat': np.arange(3),
                    'lon': np.arange(4) * 10})
        with xarray_jax.dims_change_on_unflatten(
                lambda dims: ('new',) + dims):
            with_new_dim = jax.tree_util.tree_map(
                lambda x: jnp.expand_dims(x, 0), data_array)
        self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
        self._assert_identical_on_host(
            data_array, with_new_dim.isel(new=0))

    def test_map_remove_dim(self):
        """dims_change_on_unflatten can drop a leading dim during tree_map."""
        dataset = xarray_jax.Dataset(
            data_vars={
                'foo': (('batch', 'lat', 'lon'),
                        jnp.ones((1, 3, 4), dtype=np.float32)),
                'bar': (('batch', 'time', 'lat', 'lon'),
                        jnp.ones((1, 2, 3, 4), dtype=np.float32)),
            },
            coords={'batch': np.array([123]),
                    'time': np.arange(2),
                    'lat': np.arange(3) * 10,
                    'lon': np.arange(4) * 10})
        with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]):
            with_removed_dim = jax.tree_util.tree_map(
                lambda x: jnp.squeeze(x, 0), dataset)
        self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims)
        self.assertEqual(
            ('time', 'lat', 'lon'), with_removed_dim['bar'].dims)
        self.assertNotIn('batch', with_removed_dim.dims)
        self.assertNotIn('batch', with_removed_dim.coords)
        self._assert_identical_on_host(
            dataset.isel(batch=0, drop=True), with_removed_dim)

    # ------------------------------------------------------------------
    # xarray_jax.pmap.
    # ------------------------------------------------------------------

    def test_pmap(self):
        """pmap over 'device' hides that dim inside the mapped function."""
        num_devices = jax.local_device_count()
        dataset = xarray_jax.Dataset({
            'foo': (('device', 'lat', 'lon'),
                    jnp.zeros((num_devices, 3, 4), dtype=np.float32)),
            'bar': (('device', 'time', 'lat', 'lon'),
                    jnp.zeros((num_devices, 2, 3, 4), dtype=np.float32)),
        })

        def add_one(d):
            self.assertNotIn('device', d.dims)
            return d + 1

        pmapped = xarray_jax.pmap(add_one, dim='device')

        result = pmapped(dataset)
        self._assert_identical_on_host(dataset + 1, result)

        # Can call it again with a different argument structure (it will
        # recompile under the hood but should work):
        dataset = dataset.drop_vars('foo')
        result = pmapped(dataset)
        self._assert_identical_on_host(dataset + 1, result)

    def test_pmap_with_jax_coords(self):
        """pmap also maps over jax_coords carrying the device dim."""
        num_devices = jax.local_device_count()
        foo = jnp.zeros((num_devices, 3, 4), dtype=np.float32)
        bar = jnp.zeros((num_devices, 2, 3, 4), dtype=np.float32)
        time = jnp.zeros((num_devices, 2), dtype=np.float32)
        dataset = xarray_jax.Dataset(
            {'foo': (('device', 'lat', 'lon'), foo),
             'bar': (('device', 'time', 'lat', 'lon'), bar)},
            coords={
                'lat': np.arange(3),
                'lon': np.arange(4),
            },
            jax_coords={
                # Currently any jax_coords need a leading device dimension
                # to use with pmap, same as for data_vars.
                # TODO(matthjw): have pmap automatically broadcast to all
                # devices where the device dimension not present.
                'time': xarray_jax.Variable(('device', 'time'), time),
            },
        )

        def add_one(d):
            self.assertNotIn('device', d.dims)
            self.assertNotIn('device', d.coords['time'].dims)
            # The jax_coord 'time' should be passed in backed by a JAX
            # array, but not as an index coordinate.
            self.assertIsInstance(
                d.coords['time'].data, xarray_jax.JaxArrayWrapper)
            self.assertNotIn('time', d.indexes)
            return d + 1

        pmapped = xarray_jax.pmap(add_one, dim='device')

        result = pmapped(dataset)
        self._assert_identical_on_host(dataset + 1, result)

        # Can call it again with a different argument structure (it will
        # recompile under the hood but should work):
        dataset = dataset.drop_vars('foo')
        result = pmapped(dataset)
        self._assert_identical_on_host(dataset + 1, result)

    def test_pmap_with_tree_mix_of_xarray_and_jax_array(self):
        """pmap handles a dict mixing a DataArray and a plain JAX array."""
        num_devices = jax.local_device_count()
        inputs = {
            'foo': xarray_jax.DataArray(
                data=jnp.ones((num_devices, 3, 4), dtype=np.float32),
                dims=('device', 'lat', 'lon')),
            'bar': jnp.ones((num_devices, 2), dtype=np.float32),
        }

        def add_one_to_each(x):
            return x['foo'] + 1, x['bar'] + 1

        pmapped = xarray_jax.pmap(add_one_to_each, dim='device')
        result_foo, result_bar = pmapped(inputs)
        self._assert_identical_on_host(inputs['foo'] + 1, result_foo)
        np.testing.assert_array_equal(
            jax.device_get(inputs['bar'] + 1),
            jax.device_get(result_bar))

    def test_pmap_complains_when_dim_not_first(self):
        """pmap raises ValueError when the mapped dim is not at index 0."""
        num_devices = jax.local_device_count()
        data_array = xarray_jax.DataArray(
            data=jnp.ones((3, num_devices, 4), dtype=np.float32),
            dims=('lat', 'device', 'lon'))
        pmapped = xarray_jax.pmap(lambda x: x + 1, dim='device')
        with self.assertRaisesRegex(
                ValueError, 'Expected dim device at index 0, found at 1'):
            pmapped(data_array)

    # ------------------------------------------------------------------
    # xarray_jax.apply_ufunc.
    # ------------------------------------------------------------------

    def test_apply_ufunc(self):
        """apply_ufunc reduces over the core dim 'x' and keeps 'y'."""
        inputs = xarray_jax.DataArray(
            data=jnp.asarray([[1, 2], [3, 4]]),
            dims=('x', 'y'),
            coords={'x': [0, 1],
                    'y': [2, 3]})

        def sum_last_axis(x):
            return jnp.sum(x, axis=-1)

        result = xarray_jax.apply_ufunc(
            sum_last_axis, inputs, input_core_dims=[['x']])
        expected_result = xarray_jax.DataArray(
            data=[4, 6],
            dims=('y',),
            coords={'y': [2, 3]})
        xarray.testing.assert_identical(
            expected_result, jax.device_get(result))

    def test_apply_ufunc_multiple_return_values(self):
        """apply_ufunc supports a ufunc returning a (min, max) pair."""

        def min_and_max(array):
            return jnp.min(array, axis=-1), jnp.max(array, axis=-1)

        inputs = xarray_jax.DataArray(
            data=jnp.asarray([[1, 4], [3, 2]]),
            dims=('x', 'y'),
            coords={'x': [0, 1],
                    'y': [2, 3]})
        result = xarray_jax.apply_ufunc(
            min_and_max,
            inputs,
            input_core_dims=[['x']],
            output_core_dims=[[], []])

        expected_mins = xarray_jax.DataArray(
            data=[1, 2],
            dims=('y',),
            coords={'y': [2, 3]})
        expected_maxes = xarray_jax.DataArray(
            data=[3, 4],
            dims=('y',),
            coords={'y': [2, 3]})
        xarray.testing.assert_identical(
            expected_mins, jax.device_get(result[0]))
        xarray.testing.assert_identical(
            expected_maxes, jax.device_get(result[1]))
if __name__ == '__main__':
    # Run the tests through absl's runner when executed as a script.
    absltest.main()
|