123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- # Copyright 2023 DeepMind Technologies Limited.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS-IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Tests for icosahedral_mesh."""
- from absl.testing import absltest
- from absl.testing import parameterized
- import chex
- from graphcast import icosahedral_mesh
- import numpy as np
def _get_mesh_spec(splits: int):
  """Returns the expected (num_vertices, num_faces) after `splits` refinements.

  Starts from the base icosahedron (12 vertices, 20 triangular faces). Every
  refinement splits each triangle into four. It also adds one new vertex per
  edge of the previous mesh. A closed triangulated surface has 3 * faces / 2
  edges, because each edge is shared by exactly two faces.

  Args:
    splits: Number of refinement steps applied to the icosahedron.

  Returns:
    A `(num_vertices, num_faces)` tuple of ints.
  """
  vertex_count, face_count = 12, 20
  for _ in range(splits):
    # Both updates read the counts from the *previous* level.
    vertex_count, face_count = (
        vertex_count + face_count * 3 // 2, face_count * 4)
  return vertex_count, face_count
class IcosahedralMeshTest(parameterized.TestCase):
  """Tests for the mesh-construction helpers in `icosahedral_mesh`."""

  def test_icosahedron(self):
    mesh = icosahedral_mesh.get_icosahedron()
    _assert_valid_mesh(
        mesh, num_expected_vertices=12, num_expected_faces=20)

  @parameterized.parameters(list(range(5)))
  def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
    meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
        splits=splits)
    # The loop below treats element `i` as the mesh after `i` splits. The
    # hierarchy must therefore hold levels 0..splits. Without this check an
    # empty or truncated result would make the loop a no-op, and the test
    # would pass vacuously.
    self.assertLen(meshes, splits + 1)
    prev_vertices = None
    for mesh_i, mesh in enumerate(meshes):
      # Check that `mesh` is valid.
      num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
      _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)
      # Check that the first N vertices from this mesh match all of the
      # vertices from the previous mesh.
      if prev_vertices is not None:
        leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
        np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)
      # Remember this level's vertices for comparison with the next level.
      prev_vertices = mesh.vertices

  @parameterized.parameters(list(range(4)))
  def test_merge_meshes(self, splits):
    mesh_hierarchy = (
        icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
            splits=splits))
    mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)
    # The merged mesh keeps the finest level's vertices. Its faces are the
    # concatenation of every level's faces, in hierarchy order.
    expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
    np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
    np.testing.assert_array_equal(mesh.faces, expected_faces)

  def test_faces_to_edges(self):
    faces = np.array([[0, 1, 2],
                      [3, 4, 5]])
    # This also documents the order of the edges returned by the method.
    expected_edges = np.array(
        [[0, 1],
         [3, 4],
         [1, 2],
         [4, 5],
         [2, 0],
         [5, 3]])
    expected_senders = expected_edges[:, 0]
    expected_receivers = expected_edges[:, 1]
    senders, receivers = icosahedral_mesh.faces_to_edges(faces)
    np.testing.assert_array_equal(senders, expected_senders)
    np.testing.assert_array_equal(receivers, expected_receivers)
def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
  """Asserts that `mesh` has the expected size and lies on the unit sphere.

  Args:
    mesh: Object exposing `vertices` (one 3-D point per row) and `faces`
      (three vertex indices per row).
    num_expected_vertices: Required number of rows in `mesh.vertices`.
    num_expected_faces: Required number of rows in `mesh.faces`.

  Raises:
    AssertionError: If shapes mismatch, any vertex is off the unit sphere, or
      any face is not outward-oriented.
  """
  verts, tris = mesh.vertices, mesh.faces
  chex.assert_shape(verts, [num_expected_vertices, 3])
  chex.assert_shape(tris, [num_expected_faces, 3])
  # Every vertex must have unit Euclidean norm.
  np.testing.assert_allclose(np.linalg.norm(verts, axis=-1), 1., rtol=1e-6)
  _assert_positive_face_orientation(verts, tris)
def _assert_positive_face_orientation(vertices, faces):
  """Asserts every face's normal points away from the origin.

  The normal of face (a, b, c) is taken as cross(b - a, c - b). Both the
  normal and the face centroid are normalized to unit length. For a positively
  oriented face the two are (nearly) parallel, so their dot product is ~1. For
  a reversed face the dot product is ~-1.

  Args:
    vertices: Array of 3-D vertex positions, one per row.
    faces: Integer array of vertex indices, three per row.

  Raises:
    AssertionError: If any face's normal/centroid dot product deviates from 1
      by more than the tolerance.
  """
  corner_a, corner_b, corner_c = (vertices[faces[:, k]] for k in range(3))

  # Unit normal of each triangle, oriented by its vertex winding.
  normals = np.cross(corner_b - corner_a, corner_c - corner_b)
  normals /= np.linalg.norm(normals, axis=-1, keepdims=True)

  # Unit direction from the origin toward each triangle's centroid.
  centroid_dirs = vertices[faces].mean(1)
  centroid_dirs /= np.linalg.norm(centroid_dirs, axis=-1, keepdims=True)

  # Row-wise dot product between the two unit vectors.
  alignment = np.einsum("ik,ik->i", normals, centroid_dirs)

  # Not all faces of a refined mesh are perfectly regular, so the normal is
  # only approximately radial. A small absolute tolerance absorbs that.
  np.testing.assert_allclose(alignment, 1., atol=6e-4)
if __name__ == "__main__":
  # Script entry point: hand control to absltest's test runner.
  absltest.main()
|