grid_mesh_connectivity_test.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests for graphcast.grid_mesh_connectivity."""
  15. from absl.testing import absltest
  16. from graphcast import grid_mesh_connectivity
  17. from graphcast import icosahedral_mesh
  18. import numpy as np
  19. class GridMeshConnectivityTest(absltest.TestCase):
  20. def test_grid_lat_lon_to_coordinates(self):
  21. # Intervals of 30 degrees.
  22. grid_latitude = np.array([-45., 0., 45])
  23. grid_longitude = np.array([0., 90., 180., 270.])
  24. inv_sqrt2 = 1 / np.sqrt(2)
  25. expected_coordinates = np.array([
  26. [[inv_sqrt2, 0., -inv_sqrt2],
  27. [0., inv_sqrt2, -inv_sqrt2],
  28. [-inv_sqrt2, 0., -inv_sqrt2],
  29. [0., -inv_sqrt2, -inv_sqrt2]],
  30. [[1., 0., 0.],
  31. [0., 1., 0.],
  32. [-1., 0., 0.],
  33. [0., -1., 0.]],
  34. [[inv_sqrt2, 0., inv_sqrt2],
  35. [0., inv_sqrt2, inv_sqrt2],
  36. [-inv_sqrt2, 0., inv_sqrt2],
  37. [0., -inv_sqrt2, inv_sqrt2]],
  38. ])
  39. coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates(
  40. grid_latitude, grid_longitude)
  41. np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15)
  42. def test_radius_query_indices_smoke(self):
  43. # TODO(alvarosg): Add non-smoke test?
  44. grid_latitude = np.linspace(-75, 75, 6)
  45. grid_longitude = np.arange(12) * 30.
  46. mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
  47. splits=3)[-1]
  48. grid_mesh_connectivity.radius_query_indices(
  49. grid_latitude=grid_latitude,
  50. grid_longitude=grid_longitude,
  51. mesh=mesh, radius=0.2)
  52. def test_in_mesh_triangle_indices_smoke(self):
  53. # TODO(alvarosg): Add non-smoke test?
  54. grid_latitude = np.linspace(-75, 75, 6)
  55. grid_longitude = np.arange(12) * 30.
  56. mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
  57. splits=3)[-1]
  58. grid_mesh_connectivity.in_mesh_triangle_indices(
  59. grid_latitude=grid_latitude,
  60. grid_longitude=grid_longitude,
  61. mesh=mesh)
  62. if __name__ == "__main__":
  63. absltest.main()