grid_mesh_connectivity.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tools for converting from regular grids on a sphere, to triangular meshes."""
  15. from graphcast import icosahedral_mesh
  16. import numpy as np
  17. import scipy
  18. import trimesh
  19. def _grid_lat_lon_to_coordinates(
  20. grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
  21. """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
  22. # Convert to spherical coordinates phi and theta defined in the grid.
  23. # Each [num_latitude_points, num_longitude_points]
  24. phi_grid, theta_grid = np.meshgrid(
  25. np.deg2rad(grid_longitude),
  26. np.deg2rad(90 - grid_latitude))
  27. # [num_latitude_points, num_longitude_points, 3]
  28. # Note this assumes unit radius, since for now we model the earth as a
  29. # sphere of unit radius, and keep any vertical dimension as a regular grid.
  30. return np.stack(
  31. [np.cos(phi_grid)*np.sin(theta_grid),
  32. np.sin(phi_grid)*np.sin(theta_grid),
  33. np.cos(theta_grid)], axis=-1)
  34. def radius_query_indices(
  35. *,
  36. grid_latitude: np.ndarray,
  37. grid_longitude: np.ndarray,
  38. mesh: icosahedral_mesh.TriangularMesh,
  39. radius: float) -> tuple[np.ndarray, np.ndarray]:
  40. """Returns mesh-grid edge indices for radius query.
  41. Args:
  42. grid_latitude: Latitude values for the grid [num_lat_points]
  43. grid_longitude: Longitude values for the grid [num_lon_points]
  44. mesh: Mesh object.
  45. radius: Radius of connectivity in R3. for a sphere of unit radius.
  46. Returns:
  47. tuple with `grid_indices` and `mesh_indices` indicating edges between the
  48. grid and the mesh such that the distances in a straight line (not geodesic)
  49. are smaller than or equal to `radius`.
  50. * grid_indices: Indices of shape [num_edges], that index into a
  51. [num_lat_points, num_lon_points] grid, after flattening the leading axes.
  52. * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  53. """
  54. # [num_grid_points=num_lat_points * num_lon_points, 3]
  55. grid_positions = _grid_lat_lon_to_coordinates(
  56. grid_latitude, grid_longitude).reshape([-1, 3])
  57. # [num_mesh_points, 3]
  58. mesh_positions = mesh.vertices
  59. kd_tree = scipy.spatial.cKDTree(mesh_positions)
  60. # [num_grid_points, num_mesh_points_per_grid_point]
  61. # Note `num_mesh_points_per_grid_point` is not constant, so this is a list
  62. # of arrays, rather than a 2d array.
  63. query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
  64. grid_edge_indices = []
  65. mesh_edge_indices = []
  66. for grid_index, mesh_neighbors in enumerate(query_indices):
  67. grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors)))
  68. mesh_edge_indices.append(mesh_neighbors)
  69. # [num_edges]
  70. grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int)
  71. mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int)
  72. return grid_edge_indices, mesh_edge_indices
  73. def in_mesh_triangle_indices(
  74. *,
  75. grid_latitude: np.ndarray,
  76. grid_longitude: np.ndarray,
  77. mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
  78. """Returns mesh-grid edge indices for grid points contained in mesh triangles.
  79. Args:
  80. grid_latitude: Latitude values for the grid [num_lat_points]
  81. grid_longitude: Longitude values for the grid [num_lon_points]
  82. mesh: Mesh object.
  83. Returns:
  84. tuple with `grid_indices` and `mesh_indices` indicating edges between the
  85. grid and the mesh vertices of the triangle that contain each grid point.
  86. The number of edges is always num_lat_points * num_lon_points * 3
  87. * grid_indices: Indices of shape [num_edges], that index into a
  88. [num_lat_points, num_lon_points] grid, after flattening the leading axes.
  89. * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  90. """
  91. # [num_grid_points=num_lat_points * num_lon_points, 3]
  92. grid_positions = _grid_lat_lon_to_coordinates(
  93. grid_latitude, grid_longitude).reshape([-1, 3])
  94. mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
  95. # [num_grid_points] with mesh face indices for each grid point.
  96. _, _, query_face_indices = trimesh.proximity.closest_point(
  97. mesh_trimesh, grid_positions)
  98. # [num_grid_points, 3] with mesh node indices for each grid point.
  99. mesh_edge_indices = mesh.faces[query_face_indices]
  100. # [num_grid_points, 3] with grid node indices, where every row simply contains
  101. # the row (grid_point) index.
  102. grid_indices = np.arange(grid_positions.shape[0])
  103. grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3])
  104. # Flatten to get a regular list.
  105. # [num_edges=num_grid_points*3]
  106. mesh_edge_indices = mesh_edge_indices.reshape([-1])
  107. grid_edge_indices = grid_edge_indices.reshape([-1])
  108. return grid_edge_indices, mesh_edge_indices