xarray_tree_test.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests for xarray_tree."""
  15. from absl.testing import absltest
  16. from graphcast import xarray_tree
  17. import numpy as np
  18. import xarray
  19. TEST_DATASET = xarray.Dataset(
  20. data_vars={
  21. "foo": (("x", "y"), np.zeros((2, 3))),
  22. "bar": (("x",), np.zeros((2,))),
  23. },
  24. coords={
  25. "x": [1, 2],
  26. "y": [10, 20, 30],
  27. }
  28. )
  29. class XarrayTreeTest(absltest.TestCase):
  30. def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self):
  31. def fn(leaf):
  32. self.assertIsInstance(leaf, xarray.DataArray)
  33. result = leaf + 1
  34. # Removing the name from the returned DataArray to test that we don't rely
  35. # on it being present to restore the correct names in the result:
  36. result = result.rename(None)
  37. return result
  38. result = xarray_tree.map_structure(fn, TEST_DATASET)
  39. self.assertIsInstance(result, xarray.Dataset)
  40. self.assertSameElements({"foo", "bar"}, result.keys())
  41. def test_map_structure_on_data_arrays(self):
  42. data_arrays = dict(TEST_DATASET)
  43. result = xarray_tree.map_structure(lambda x: x+1, data_arrays)
  44. self.assertIsInstance(result, dict)
  45. self.assertSameElements({"foo", "bar"}, result.keys())
  46. def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self):
  47. def fn(leaf):
  48. # Returns DataArrays that can't be exactly merged back into a Dataset
  49. # due to the coordinates not matching:
  50. if leaf.name == "foo":
  51. return xarray.DataArray(
  52. data=np.zeros(2), dims=("x",), coords={"x": [1, 2]})
  53. else:
  54. return xarray.DataArray(
  55. data=np.zeros(2), dims=("x",), coords={"x": [3, 4]})
  56. result = xarray_tree.map_structure(fn, TEST_DATASET)
  57. self.assertIsInstance(result, dict)
  58. self.assertSameElements({"foo", "bar"}, result.keys())
  59. def test_map_structure_on_dataset_drops_vars_with_none_return_values(self):
  60. def fn(leaf):
  61. return leaf if leaf.name == "foo" else None
  62. result = xarray_tree.map_structure(fn, TEST_DATASET)
  63. self.assertIsInstance(result, xarray.Dataset)
  64. self.assertSameElements({"foo"}, result.keys())
  65. def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self):
  66. def fn(leaf):
  67. self.assertIsInstance(leaf, xarray.DataArray)
  68. return "not a DataArray"
  69. result = xarray_tree.map_structure(fn, TEST_DATASET)
  70. self.assertEqual({"foo": "not a DataArray",
  71. "bar": "not a DataArray"}, result)
  72. def test_map_structure_two_args_different_variable_orders(self):
  73. dataset_different_order = TEST_DATASET[["bar", "foo"]]
  74. def fn(arg1, arg2):
  75. self.assertEqual(arg1.name, arg2.name)
  76. xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order)
if __name__ == "__main__":
  # Run the test cases defined in this module through absltest's runner.
  absltest.main()