xarray_jax_test.py 20 KB


  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests for xarray_jax."""
  15. from absl.testing import absltest
  16. import chex
  17. from graphcast import xarray_jax
  18. import jax
  19. import jax.numpy as jnp
  20. import numpy as np
  21. import xarray
  22. class XarrayJaxTest(absltest.TestCase):
  23. def test_jax_array_wrapper_with_numpy_api(self):
  24. # This is just a side benefit of making things work with xarray, but the
  25. # JaxArrayWrapper does allow you to manipulate JAX arrays using the
  26. # standard numpy API, without converting them to numpy in the process:
  27. ones = jnp.ones((3, 4), dtype=np.float32)
  28. x = xarray_jax.JaxArrayWrapper(ones)
  29. x = np.abs((x + 2) * (x - 3))
  30. x = x[:-1, 1:3]
  31. x = np.concatenate([x, x + 1], axis=0)
  32. x = np.transpose(x, (1, 0))
  33. x = np.reshape(x, (-1,))
  34. x = x.astype(np.int32)
  35. self.assertIsInstance(x, xarray_jax.JaxArrayWrapper)
  36. # An explicit conversion gets us out of JAX-land however:
  37. self.assertIsInstance(np.asarray(x), np.ndarray)
  38. def test_jax_xarray_variable(self):
  39. def ops_via_xarray(inputs):
  40. x = xarray_jax.Variable(('lat', 'lon'), inputs)
  41. # We'll apply a sequence of operations just to test that the end result is
  42. # still a JAX array, i.e. we haven't converted to numpy at any point.
  43. x = np.abs((x + 2) * (x - 3))
  44. x = x.isel({'lat': slice(0, -1), 'lon': slice(1, 3)})
  45. x = xarray.Variable.concat([x, x + 1], dim='lat')
  46. x = x.transpose('lon', 'lat')
  47. x = x.stack(channels=('lon', 'lat'))
  48. x = x.sum()
  49. return xarray_jax.jax_data(x)
  50. # Check it doesn't leave jax-land when passed concrete values:
  51. ones = jnp.ones((3, 4), dtype=np.float32)
  52. result = ops_via_xarray(ones)
  53. self.assertIsInstance(result, jax.Array)
  54. # And that you can JIT it and compute gradients through it. These will
  55. # involve passing jax tracers through the xarray computation:
  56. jax.jit(ops_via_xarray)(ones)
  57. jax.grad(ops_via_xarray)(ones)
  58. def test_jax_xarray_data_array(self):
  59. def ops_via_xarray(inputs):
  60. x = xarray_jax.DataArray(dims=('lat', 'lon'),
  61. data=inputs,
  62. coords={'lat': np.arange(3) * 10,
  63. 'lon': np.arange(4) * 10})
  64. x = np.abs((x + 2) * (x - 3))
  65. x = x.sel({'lat': slice(0, 20)})
  66. y = xarray_jax.DataArray(dims=('lat', 'lon'),
  67. data=ones,
  68. coords={'lat': np.arange(3, 6) * 10,
  69. 'lon': np.arange(4) * 10})
  70. x = xarray.concat([x, y], dim='lat')
  71. x = x.transpose('lon', 'lat')
  72. x = x.stack(channels=('lon', 'lat'))
  73. x = x.unstack()
  74. x = x.sum()
  75. return xarray_jax.jax_data(x)
  76. ones = jnp.ones((3, 4), dtype=np.float32)
  77. result = ops_via_xarray(ones)
  78. self.assertIsInstance(result, jax.Array)
  79. jax.jit(ops_via_xarray)(ones)
  80. jax.grad(ops_via_xarray)(ones)
  81. def test_jax_xarray_dataset(self):
  82. def ops_via_xarray(foo, bar):
  83. x = xarray_jax.Dataset(
  84. data_vars={'foo': (('lat', 'lon'), foo),
  85. 'bar': (('time', 'lat', 'lon'), bar)},
  86. coords={
  87. 'time': np.arange(2),
  88. 'lat': np.arange(3) * 10,
  89. 'lon': np.arange(4) * 10})
  90. x = np.abs((x + 2) * (x - 3))
  91. x = x.sel({'lat': slice(0, 20)})
  92. y = xarray_jax.Dataset(
  93. data_vars={'foo': (('lat', 'lon'), foo),
  94. 'bar': (('time', 'lat', 'lon'), bar)},
  95. coords={
  96. 'time': np.arange(2),
  97. 'lat': np.arange(3, 6) * 10,
  98. 'lon': np.arange(4) * 10})
  99. x = xarray.concat([x, y], dim='lat')
  100. x = x.transpose('lon', 'lat', 'time')
  101. x = x.stack(channels=('lon', 'lat'))
  102. x = (x.foo + x.bar).sum()
  103. return xarray_jax.jax_data(x)
  104. foo = jnp.ones((3, 4), dtype=np.float32)
  105. bar = jnp.ones((2, 3, 4), dtype=np.float32)
  106. result = ops_via_xarray(foo, bar)
  107. self.assertIsInstance(result, jax.Array)
  108. jax.jit(ops_via_xarray)(foo, bar)
  109. jax.grad(ops_via_xarray)(foo, bar)
  110. def test_jit_function_with_xarray_variable_arguments_and_return(self):
  111. function = jax.jit(lambda v: v + 1)
  112. with self.subTest('jax input'):
  113. inputs = xarray_jax.Variable(
  114. ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
  115. _ = function(inputs)
  116. # We test running the jitted function a second time, to exercise logic in
  117. # jax which checks if the structure of the inputs (including dimension
  118. # names and coordinates) is the same as it was for the previous call and
  119. # so whether it needs to re-trace-and-compile a new version of the
  120. # function or not. This can run into problems if the 'aux' structure
  121. # returned by the registered flatten function is not hashable/comparable.
  122. outputs = function(inputs)
  123. self.assertEqual(outputs.dims, inputs.dims)
  124. with self.subTest('numpy input'):
  125. inputs = xarray.Variable(
  126. ('lat', 'lon'), np.ones((3, 4), dtype=np.float32))
  127. _ = function(inputs)
  128. outputs = function(inputs)
  129. self.assertEqual(outputs.dims, inputs.dims)
  130. def test_jit_problem_if_convert_to_plain_numpy_array(self):
  131. inputs = xarray_jax.DataArray(
  132. data=jnp.ones((2,), dtype=np.float32), dims=('foo',))
  133. with self.assertRaises(jax.errors.TracerArrayConversionError):
  134. # Calling .values on a DataArray converts its values to numpy:
  135. jax.jit(lambda data_array: data_array.values)(inputs)
  136. def test_grad_function_with_xarray_variable_arguments(self):
  137. x = xarray_jax.Variable(('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
  138. # For grad we still need a JAX scalar as the output:
  139. jax.grad(lambda v: xarray_jax.jax_data(v.sum()))(x)
  140. def test_jit_function_with_xarray_data_array_arguments_and_return(self):
  141. inputs = xarray_jax.DataArray(
  142. data=jnp.ones((3, 4), dtype=np.float32),
  143. dims=('lat', 'lon'),
  144. coords={'lat': np.arange(3),
  145. 'lon': np.arange(4) * 10})
  146. fn = jax.jit(lambda v: v + 1)
  147. _ = fn(inputs)
  148. outputs = fn(inputs)
  149. self.assertEqual(outputs.dims, inputs.dims)
  150. chex.assert_trees_all_equal(outputs.coords, inputs.coords)
  151. def test_jit_function_with_data_array_and_jax_coords(self):
  152. inputs = xarray_jax.DataArray(
  153. data=jnp.ones((3, 4), dtype=np.float32),
  154. dims=('lat', 'lon'),
  155. coords={'lat': np.arange(3)},
  156. jax_coords={'lon': jnp.arange(4) * 10})
  157. # Verify the jax_coord 'lon' retains jax data, and has not been created
  158. # as an index coordinate:
  159. self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
  160. self.assertNotIn('lon', inputs.indexes)
  161. @jax.jit
  162. def fn(v):
  163. # The non-JAX coord is passed with numpy array data and an index:
  164. self.assertIsInstance(v.coords['lat'].data, np.ndarray)
  165. self.assertIn('lat', v.indexes)
  166. # The jax_coord is passed with JAX array data:
  167. self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
  168. self.assertNotIn('lon', v.indexes)
  169. # Use the jax coord in the computation:
  170. v = v + v.coords['lon']
  171. # Return with an updated jax coord:
  172. return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)
  173. _ = fn(inputs)
  174. outputs = fn(inputs)
  175. # Verify the jax_coord 'lon' has jax data in the output too:
  176. self.assertIsInstance(
  177. outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
  178. self.assertNotIn('lon', outputs.indexes)
  179. self.assertEqual(outputs.dims, inputs.dims)
  180. chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
  181. # Check our computations with the coordinate values worked:
  182. chex.assert_trees_all_equal(
  183. outputs.coords['lon'].data, (inputs.coords['lon']+1).data)
  184. chex.assert_trees_all_equal(
  185. outputs.data, (inputs + inputs.coords['lon']).data)
  186. def test_jit_function_with_xarray_dataset_arguments_and_return(self):
  187. foo = jnp.ones((3, 4), dtype=np.float32)
  188. bar = jnp.ones((2, 3, 4), dtype=np.float32)
  189. inputs = xarray_jax.Dataset(
  190. data_vars={'foo': (('lat', 'lon'), foo),
  191. 'bar': (('time', 'lat', 'lon'), bar)},
  192. coords={
  193. 'time': np.arange(2),
  194. 'lat': np.arange(3) * 10,
  195. 'lon': np.arange(4) * 10})
  196. fn = jax.jit(lambda v: v + 1)
  197. _ = fn(inputs)
  198. outputs = fn(inputs)
  199. self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys())
  200. self.assertEqual(inputs.foo.dims, outputs.foo.dims)
  201. self.assertEqual(inputs.bar.dims, outputs.bar.dims)
  202. chex.assert_trees_all_equal(outputs.coords, inputs.coords)
  203. def test_jit_function_with_dataset_and_jax_coords(self):
  204. foo = jnp.ones((3, 4), dtype=np.float32)
  205. bar = jnp.ones((2, 3, 4), dtype=np.float32)
  206. inputs = xarray_jax.Dataset(
  207. data_vars={'foo': (('lat', 'lon'), foo),
  208. 'bar': (('time', 'lat', 'lon'), bar)},
  209. coords={
  210. 'time': np.arange(2),
  211. 'lat': np.arange(3) * 10,
  212. },
  213. jax_coords={'lon': jnp.arange(4) * 10}
  214. )
  215. # Verify the jax_coord 'lon' retains jax data, and has not been created
  216. # as an index coordinate:
  217. self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
  218. self.assertNotIn('lon', inputs.indexes)
  219. @jax.jit
  220. def fn(v):
  221. # The non-JAX coords are passed with numpy array data and an index:
  222. self.assertIsInstance(v.coords['lat'].data, np.ndarray)
  223. self.assertIn('lat', v.indexes)
  224. # The jax_coord is passed with JAX array data:
  225. self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
  226. self.assertNotIn('lon', v.indexes)
  227. # Use the jax coord in the computation:
  228. v = v + v.coords['lon']
  229. # Return with an updated jax coord:
  230. return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)
  231. _ = fn(inputs)
  232. outputs = fn(inputs)
  233. # Verify the jax_coord 'lon' has jax data in the output too:
  234. self.assertIsInstance(
  235. outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
  236. self.assertNotIn('lon', outputs.indexes)
  237. self.assertEqual(outputs.dims, inputs.dims)
  238. chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
  239. # Check our computations with the coordinate values worked:
  240. chex.assert_trees_all_equal(
  241. (outputs.coords['lon']).data,
  242. (inputs.coords['lon']+1).data,
  243. )
  244. outputs_dict = {key: outputs[key].data for key in outputs}
  245. inputs_and_inputs_coords_dict = {
  246. key: (inputs + inputs.coords['lon'])[key].data
  247. for key in inputs + inputs.coords['lon']
  248. }
  249. chex.assert_trees_all_equal(outputs_dict, inputs_and_inputs_coords_dict)
  250. def test_flatten_unflatten_variable(self):
  251. variable = xarray_jax.Variable(
  252. ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
  253. children, aux = xarray_jax._flatten_variable(variable)
  254. # Check auxiliary info is hashable/comparable (important for jax.jit):
  255. hash(aux)
  256. self.assertEqual(aux, aux)
  257. roundtrip = xarray_jax._unflatten_variable(aux, children)
  258. self.assertTrue(variable.equals(roundtrip))
  259. def test_flatten_unflatten_data_array(self):
  260. data_array = xarray_jax.DataArray(
  261. data=jnp.ones((3, 4), dtype=np.float32),
  262. dims=('lat', 'lon'),
  263. coords={'lat': np.arange(3)},
  264. jax_coords={'lon': np.arange(4) * 10},
  265. )
  266. children, aux = xarray_jax._flatten_data_array(data_array)
  267. # Check auxiliary info is hashable/comparable (important for jax.jit):
  268. hash(aux)
  269. self.assertEqual(aux, aux)
  270. roundtrip = xarray_jax._unflatten_data_array(aux, children)
  271. self.assertTrue(data_array.equals(roundtrip))
  272. def test_flatten_unflatten_dataset(self):
  273. foo = jnp.ones((3, 4), dtype=np.float32)
  274. bar = jnp.ones((2, 3, 4), dtype=np.float32)
  275. dataset = xarray_jax.Dataset(
  276. data_vars={'foo': (('lat', 'lon'), foo),
  277. 'bar': (('time', 'lat', 'lon'), bar)},
  278. coords={
  279. 'time': np.arange(2),
  280. 'lat': np.arange(3) * 10},
  281. jax_coords={
  282. 'lon': np.arange(4) * 10})
  283. children, aux = xarray_jax._flatten_dataset(dataset)
  284. # Check auxiliary info is hashable/comparable (important for jax.jit):
  285. hash(aux)
  286. self.assertEqual(aux, aux)
  287. roundtrip = xarray_jax._unflatten_dataset(aux, children)
  288. self.assertTrue(dataset.equals(roundtrip))
  289. def test_flatten_unflatten_added_dim(self):
  290. data_array = xarray_jax.DataArray(
  291. data=jnp.ones((3, 4), dtype=np.float32),
  292. dims=('lat', 'lon'),
  293. coords={'lat': np.arange(3),
  294. 'lon': np.arange(4) * 10})
  295. leaves, treedef = jax.tree_util.tree_flatten(data_array)
  296. leaves = [jnp.expand_dims(x, 0) for x in leaves]
  297. with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
  298. with_new_dim = jax.tree_util.tree_unflatten(treedef, leaves)
  299. self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
  300. xarray.testing.assert_identical(
  301. jax.device_get(data_array),
  302. jax.device_get(with_new_dim.isel(new=0)))
  303. def test_map_added_dim(self):
  304. data_array = xarray_jax.DataArray(
  305. data=jnp.ones((3, 4), dtype=np.float32),
  306. dims=('lat', 'lon'),
  307. coords={'lat': np.arange(3),
  308. 'lon': np.arange(4) * 10})
  309. with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
  310. with_new_dim = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0),
  311. data_array)
  312. self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
  313. xarray.testing.assert_identical(
  314. jax.device_get(data_array),
  315. jax.device_get(with_new_dim.isel(new=0)))
  316. def test_map_remove_dim(self):
  317. foo = jnp.ones((1, 3, 4), dtype=np.float32)
  318. bar = jnp.ones((1, 2, 3, 4), dtype=np.float32)
  319. dataset = xarray_jax.Dataset(
  320. data_vars={'foo': (('batch', 'lat', 'lon'), foo),
  321. 'bar': (('batch', 'time', 'lat', 'lon'), bar)},
  322. coords={
  323. 'batch': np.array([123]),
  324. 'time': np.arange(2),
  325. 'lat': np.arange(3) * 10,
  326. 'lon': np.arange(4) * 10})
  327. with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]):
  328. with_removed_dim = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, 0),
  329. dataset)
  330. self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims)
  331. self.assertEqual(('time', 'lat', 'lon'), with_removed_dim['bar'].dims)
  332. self.assertNotIn('batch', with_removed_dim.dims)
  333. self.assertNotIn('batch', with_removed_dim.coords)
  334. xarray.testing.assert_identical(
  335. jax.device_get(dataset.isel(batch=0, drop=True)),
  336. jax.device_get(with_removed_dim))
  337. def test_pmap(self):
  338. devices = jax.local_device_count()
  339. foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
  340. bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
  341. dataset = xarray_jax.Dataset({
  342. 'foo': (('device', 'lat', 'lon'), foo),
  343. 'bar': (('device', 'time', 'lat', 'lon'), bar)})
  344. def func(d):
  345. self.assertNotIn('device', d.dims)
  346. return d + 1
  347. func = xarray_jax.pmap(func, dim='device')
  348. result = func(dataset)
  349. xarray.testing.assert_identical(
  350. jax.device_get(dataset + 1),
  351. jax.device_get(result))
  352. # Can call it again with a different argument structure (it will recompile
  353. # under the hood but should work):
  354. dataset = dataset.drop_vars('foo')
  355. result = func(dataset)
  356. xarray.testing.assert_identical(
  357. jax.device_get(dataset + 1),
  358. jax.device_get(result))
  359. def test_pmap_with_jax_coords(self):
  360. devices = jax.local_device_count()
  361. foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
  362. bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
  363. time = jnp.zeros((devices, 2), dtype=np.float32)
  364. dataset = xarray_jax.Dataset(
  365. {'foo': (('device', 'lat', 'lon'), foo),
  366. 'bar': (('device', 'time', 'lat', 'lon'), bar)},
  367. coords={
  368. 'lat': np.arange(3),
  369. 'lon': np.arange(4),
  370. },
  371. jax_coords={
  372. # Currently any jax_coords need a leading device dimension to use
  373. # with pmap, same as for data_vars.
  374. # TODO(matthjw): have pmap automatically broadcast to all devices
  375. # where the device dimension not present.
  376. 'time': xarray_jax.Variable(('device', 'time'), time),
  377. }
  378. )
  379. def func(d):
  380. self.assertNotIn('device', d.dims)
  381. self.assertNotIn('device', d.coords['time'].dims)
  382. # The jax_coord 'time' should be passed in backed by a JAX array, but
  383. # not as an index coordinate.
  384. self.assertIsInstance(d.coords['time'].data, xarray_jax.JaxArrayWrapper)
  385. self.assertNotIn('time', d.indexes)
  386. return d + 1
  387. func = xarray_jax.pmap(func, dim='device')
  388. result = func(dataset)
  389. xarray.testing.assert_identical(
  390. jax.device_get(dataset + 1),
  391. jax.device_get(result))
  392. # Can call it again with a different argument structure (it will recompile
  393. # under the hood but should work):
  394. dataset = dataset.drop_vars('foo')
  395. result = func(dataset)
  396. xarray.testing.assert_identical(
  397. jax.device_get(dataset + 1),
  398. jax.device_get(result))
  399. def test_pmap_with_tree_mix_of_xarray_and_jax_array(self):
  400. devices = jax.local_device_count()
  401. data_array = xarray_jax.DataArray(
  402. data=jnp.ones((devices, 3, 4), dtype=np.float32),
  403. dims=('device', 'lat', 'lon'))
  404. plain_array = jnp.ones((devices, 2), dtype=np.float32)
  405. inputs = {'foo': data_array,
  406. 'bar': plain_array}
  407. def func(x):
  408. return x['foo'] + 1, x['bar'] + 1
  409. func = xarray_jax.pmap(func, dim='device')
  410. result_foo, result_bar = func(inputs)
  411. xarray.testing.assert_identical(
  412. jax.device_get(inputs['foo'] + 1),
  413. jax.device_get(result_foo))
  414. np.testing.assert_array_equal(
  415. jax.device_get(inputs['bar'] + 1),
  416. jax.device_get(result_bar))
  417. def test_pmap_complains_when_dim_not_first(self):
  418. devices = jax.local_device_count()
  419. data_array = xarray_jax.DataArray(
  420. data=jnp.ones((3, devices, 4), dtype=np.float32),
  421. dims=('lat', 'device', 'lon'))
  422. func = xarray_jax.pmap(lambda x: x+1, dim='device')
  423. with self.assertRaisesRegex(
  424. ValueError, 'Expected dim device at index 0, found at 1'):
  425. func(data_array)
  426. def test_apply_ufunc(self):
  427. inputs = xarray_jax.DataArray(
  428. data=jnp.asarray([[1, 2], [3, 4]]),
  429. dims=('x', 'y'),
  430. coords={'x': [0, 1],
  431. 'y': [2, 3]})
  432. result = xarray_jax.apply_ufunc(
  433. lambda x: jnp.sum(x, axis=-1),
  434. inputs,
  435. input_core_dims=[['x']])
  436. expected_result = xarray_jax.DataArray(
  437. data=[4, 6],
  438. dims=('y',),
  439. coords={'y': [2, 3]})
  440. xarray.testing.assert_identical(expected_result, jax.device_get(result))
  441. def test_apply_ufunc_multiple_return_values(self):
  442. def ufunc(array):
  443. return jnp.min(array, axis=-1), jnp.max(array, axis=-1)
  444. inputs = xarray_jax.DataArray(
  445. data=jnp.asarray([[1, 4], [3, 2]]),
  446. dims=('x', 'y'),
  447. coords={'x': [0, 1],
  448. 'y': [2, 3]})
  449. result = xarray_jax.apply_ufunc(
  450. ufunc, inputs, input_core_dims=[['x']], output_core_dims=[[], []])
  451. expected = (
  452. # Mins:
  453. xarray_jax.DataArray(
  454. data=[1, 2],
  455. dims=('y',),
  456. coords={'y': [2, 3]}
  457. ),
  458. # Maxes:
  459. xarray_jax.DataArray(
  460. data=[3, 4],
  461. dims=('y',),
  462. coords={'y': [2, 3]}
  463. )
  464. )
  465. xarray.testing.assert_identical(expected[0], jax.device_get(result[0]))
  466. xarray.testing.assert_identical(expected[1], jax.device_get(result[1]))
# Run the tests through absl's test runner when this file is executed as a
# script (as opposed to being imported).
if __name__ == '__main__':
  absltest.main()