data_utils_test.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests for `data_utils.py`."""
  15. import datetime
  16. from absl.testing import absltest
  17. from absl.testing import parameterized
  18. from graphcast import data_utils
  19. import numpy as np
  20. import xarray
  21. class DataUtilsTest(parameterized.TestCase):
  22. def setUp(self):
  23. super().setUp()
  24. # Fix the seed for reproducibility.
  25. np.random.seed(0)
  26. def test_year_progress_is_zero_at_year_start_or_end(self):
  27. year_progress = data_utils.get_year_progress(
  28. np.array([
  29. 0,
  30. data_utils.AVG_SEC_PER_YEAR,
  31. data_utils.AVG_SEC_PER_YEAR * 42, # 42 years.
  32. ])
  33. )
  34. np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape))
  35. def test_year_progress_is_almost_one_before_year_ends(self):
  36. year_progress = data_utils.get_year_progress(
  37. np.array([
  38. data_utils.AVG_SEC_PER_YEAR - 1,
  39. (data_utils.AVG_SEC_PER_YEAR - 1) * 42, # ~42 years
  40. ])
  41. )
  42. with self.subTest("Year progress values are close to 1"):
  43. self.assertTrue(np.all(year_progress > 0.999))
  44. with self.subTest("Year progress values != 1"):
  45. self.assertTrue(np.all(year_progress < 1.0))
  46. def test_day_progress_computes_for_all_times_and_longitudes(self):
  47. times = np.random.randint(low=0, high=1e10, size=10)
  48. longitudes = np.arange(0, 360.0, 1.0)
  49. day_progress = data_utils.get_day_progress(times, longitudes)
  50. with self.subTest("Day progress is computed for all times and longinutes"):
  51. self.assertSequenceEqual(
  52. day_progress.shape, (len(times), len(longitudes))
  53. )
  54. @parameterized.named_parameters(
  55. dict(
  56. testcase_name="random_date_1",
  57. year=1988,
  58. month=11,
  59. day=7,
  60. hour=2,
  61. minute=45,
  62. second=34,
  63. ),
  64. dict(
  65. testcase_name="random_date_2",
  66. year=2022,
  67. month=3,
  68. day=12,
  69. hour=7,
  70. minute=1,
  71. second=0,
  72. ),
  73. )
  74. def test_day_progress_is_in_between_zero_and_one(
  75. self, year, month, day, hour, minute, second
  76. ):
  77. # Datetime from a timestamp.
  78. dt = datetime.datetime(year, month, day, hour, minute, second)
  79. # Epoch time.
  80. epoch_time = datetime.datetime(1970, 1, 1)
  81. # Seconds since epoch.
  82. seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])
  83. # Longitudes with 1 degree resolution.
  84. longitudes = np.arange(0, 360.0, 1.0)
  85. day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
  86. with self.subTest("Day progress >= 0"):
  87. self.assertTrue(np.all(day_progress >= 0.0))
  88. with self.subTest("Day progress < 1"):
  89. self.assertTrue(np.all(day_progress < 1.0))
  90. def test_day_progress_is_zero_at_day_start_or_end(self):
  91. day_progress = data_utils.get_day_progress(
  92. seconds_since_epoch=np.array([
  93. 0,
  94. data_utils.SEC_PER_DAY,
  95. data_utils.SEC_PER_DAY * 42, # 42 days.
  96. ]),
  97. longitude=np.array([0.0]),
  98. )
  99. np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape))
  100. def test_day_progress_specific_value(self):
  101. day_progress = data_utils.get_day_progress(
  102. seconds_since_epoch=np.array([123]),
  103. longitude=np.array([0.0]),
  104. )
  105. np.testing.assert_array_almost_equal(
  106. day_progress, np.array([[0.00142361]]), decimal=6
  107. )
  108. def test_featurize_progress_valid_values_and_dimensions(self):
  109. day_progress = np.array([0.0, 0.45, 0.213])
  110. feature_dimensions = ("time",)
  111. progress_features = data_utils.featurize_progress(
  112. name="day_progress", dims=feature_dimensions, progress=day_progress
  113. )
  114. for feature in progress_features.values():
  115. with self.subTest(f"Valid dimensions for {feature}"):
  116. self.assertSequenceEqual(feature.dims, feature_dimensions)
  117. with self.subTest("Valid values for day_progress"):
  118. np.testing.assert_array_equal(
  119. day_progress, progress_features["day_progress"].values
  120. )
  121. with self.subTest("Valid values for day_progress_sin"):
  122. np.testing.assert_array_almost_equal(
  123. np.array([0.0, 0.30901699, 0.97309851]),
  124. progress_features["day_progress_sin"].values,
  125. decimal=6,
  126. )
  127. with self.subTest("Valid values for day_progress_cos"):
  128. np.testing.assert_array_almost_equal(
  129. np.array([1.0, -0.95105652, 0.23038943]),
  130. progress_features["day_progress_cos"].values,
  131. decimal=6,
  132. )
  133. def test_featurize_progress_invalid_dimensions(self):
  134. year_progress = np.array([0.0, 0.45, 0.213])
  135. feature_dimensions = ("time", "longitude")
  136. with self.assertRaises(ValueError):
  137. data_utils.featurize_progress(
  138. name="year_progress", dims=feature_dimensions, progress=year_progress
  139. )
  140. def test_add_derived_vars_variables_added(self):
  141. data = xarray.Dataset(
  142. data_vars={
  143. "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
  144. },
  145. coords={
  146. "lon": np.array([0.0, 0.5]),
  147. "datetime": np.array([
  148. datetime.datetime(2021, 1, 1),
  149. datetime.datetime(2023, 1, 1),
  150. datetime.datetime(2023, 1, 3),
  151. ]),
  152. },
  153. )
  154. data_utils.add_derived_vars(data)
  155. all_variables = set(data.variables)
  156. with self.subTest("Original value was not removed"):
  157. self.assertIn("var1", all_variables)
  158. with self.subTest("Year progress feature was added"):
  159. self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
  160. with self.subTest("Day progress feature was added"):
  161. self.assertIn(data_utils.DAY_PROGRESS, all_variables)
  162. @parameterized.named_parameters(
  163. dict(testcase_name="missing_datetime", coord_name="lon"),
  164. dict(testcase_name="missing_lon", coord_name="datetime"),
  165. )
  166. def test_add_derived_vars_missing_coordinate_raises_value_error(
  167. self, coord_name
  168. ):
  169. with self.subTest(f"Missing {coord_name} coordinate"):
  170. data = xarray.Dataset(
  171. data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
  172. coords={
  173. coord_name: np.array([0.0, 0.5]),
  174. },
  175. )
  176. with self.assertRaises(ValueError):
  177. data_utils.add_derived_vars(data)
  178. if __name__ == "__main__":
  179. absltest.main()