icosahedral_mesh_test.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests for icosahedral_mesh."""
  15. from absl.testing import absltest
  16. from absl.testing import parameterized
  17. import chex
  18. from graphcast import icosahedral_mesh
  19. import numpy as np
  20. def _get_mesh_spec(splits: int):
  21. """Returns size of the final icosahedral mesh resulting from the splitting."""
  22. num_vertices = 12
  23. num_faces = 20
  24. for _ in range(splits):
  25. # Each previous face adds three new vertices, but each vertex is shared
  26. # by two faces.
  27. num_vertices += num_faces * 3 // 2
  28. num_faces *= 4
  29. return num_vertices, num_faces
  30. class IcosahedralMeshTest(parameterized.TestCase):
  31. def test_icosahedron(self):
  32. mesh = icosahedral_mesh.get_icosahedron()
  33. _assert_valid_mesh(
  34. mesh, num_expected_vertices=12, num_expected_faces=20)
  35. @parameterized.parameters(list(range(5)))
  36. def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
  37. meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
  38. splits=splits)
  39. prev_vertices = None
  40. for mesh_i, mesh in enumerate(meshes):
  41. # Check that `mesh` is valid.
  42. num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
  43. _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)
  44. # Check that the first N vertices from this mesh match all of the
  45. # vertices from the previous mesh.
  46. if prev_vertices is not None:
  47. leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
  48. np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)
  49. # Increase the expected/previous values for the next iteration.
  50. if mesh_i < len(meshes) - 1:
  51. prev_vertices = mesh.vertices
  52. @parameterized.parameters(list(range(4)))
  53. def test_merge_meshes(self, splits):
  54. mesh_hierarchy = (
  55. icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
  56. splits=splits))
  57. mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)
  58. expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
  59. np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
  60. np.testing.assert_array_equal(mesh.faces, expected_faces)
  61. def test_faces_to_edges(self):
  62. faces = np.array([[0, 1, 2],
  63. [3, 4, 5]])
  64. # This also documents the order of the edges returned by the method.
  65. expected_edges = np.array(
  66. [[0, 1],
  67. [3, 4],
  68. [1, 2],
  69. [4, 5],
  70. [2, 0],
  71. [5, 3]])
  72. expected_senders = expected_edges[:, 0]
  73. expected_receivers = expected_edges[:, 1]
  74. senders, receivers = icosahedral_mesh.faces_to_edges(faces)
  75. np.testing.assert_array_equal(senders, expected_senders)
  76. np.testing.assert_array_equal(receivers, expected_receivers)
  77. def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
  78. vertices = mesh.vertices
  79. faces = mesh.faces
  80. chex.assert_shape(vertices, [num_expected_vertices, 3])
  81. chex.assert_shape(faces, [num_expected_faces, 3])
  82. # Vertices norm should be 1.
  83. vertices_norm = np.linalg.norm(vertices, axis=-1)
  84. np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6)
  85. _assert_positive_face_orientation(vertices, faces)
  86. def _assert_positive_face_orientation(vertices, faces):
  87. # Obtain a unit vector that points, in the direction of the face.
  88. face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
  89. vertices[faces[:, 2]] - vertices[faces[:, 1]])
  90. face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)
  91. # And a unit vector pointing from the origin to the center of the face.
  92. face_centers = vertices[faces].mean(1)
  93. face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)
  94. # Positive orientation means those two vectors should be parallel
  95. # (dot product, 1), and not anti-parallel (dot product, -1).
  96. dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)
  97. # Check that the face normal is parallel to the vector that joins the center
  98. # of the face to the center of the sphere. Note we need a small tolerance
  99. # because some discretizations are not exactly uniform, so it will not be
  100. # exactly parallel.
  101. np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)
if __name__ == "__main__":
  # Run the tests in this module with absl's test runner.
  absltest.main()