123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- # Copyright 2023 DeepMind Technologies Limited.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS-IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Tools for converting from regular grids on a sphere, to triangular meshes."""
- from graphcast import icosahedral_mesh
- import numpy as np
- import scipy
- import trimesh
def _grid_lat_lon_to_coordinates(
    grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
  """Maps a regular lat/lon grid onto Cartesian points on the unit sphere.

  Args:
    grid_latitude: Latitudes in degrees, shape [num_lat].
    grid_longitude: Longitudes in degrees, shape [num_lon].

  Returns:
    Array of shape [num_lat, num_lon, 3] with the (x, y, z) position of every
    grid point on a sphere of radius 1.
  """
  # Azimuth comes from longitude; the polar angle (colatitude) from latitude.
  azimuth = np.deg2rad(grid_longitude)
  colatitude = np.deg2rad(90 - grid_latitude)
  # np.meshgrid's default "xy" indexing puts the second argument on axis 0,
  # so both grids come out as [num_lat, num_lon].
  azimuth_grid, colatitude_grid = np.meshgrid(azimuth, colatitude)

  sin_colatitude = np.sin(colatitude_grid)
  x = np.cos(azimuth_grid) * sin_colatitude
  y = np.sin(azimuth_grid) * sin_colatitude
  z = np.cos(colatitude_grid)
  # Unit radius: the earth is modelled as a unit sphere, and any vertical
  # dimension is kept as a separate regular grid.
  return np.stack([x, y, z], axis=-1)
def radius_query_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh,
    radius: float) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for radius query.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points], in degrees.
    grid_longitude: Longitude values for the grid [num_lon_points], in degrees.
    mesh: Mesh object.
    radius: Radius of connectivity in R3, for a sphere of unit radius.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh such that the distances in a straight line (not geodesic)
    are smaller than or equal to `radius`.
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
    Edges are ordered by grid index. Both arrays are integer-typed, and both
    are empty (rather than raising) when no edges are found, e.g. for an
    empty grid.
  """
  # [num_grid_points=num_lat_points * num_lon_points, 3]
  grid_positions = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])

  # Build the KD-tree over the mesh vertices, [num_mesh_points, 3].
  kd_tree = scipy.spatial.cKDTree(mesh.vertices)

  # [num_grid_points, num_mesh_points_per_grid_point]
  # `num_mesh_points_per_grid_point` is not constant, so this is a sequence of
  # lists rather than a 2d array.
  query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)
  num_grid_points = len(query_indices)

  # [num_grid_points] Number of mesh neighbours found for each grid point.
  num_neighbors = np.fromiter(
      (len(neighbors) for neighbors in query_indices),
      dtype=int, count=num_grid_points)

  # [num_edges] Each grid index repeated once per neighbour, in query order.
  grid_edge_indices = np.repeat(
      np.arange(num_grid_points, dtype=int), num_neighbors)
  # [num_edges] The neighbour lists flattened in the same order. An explicit
  # dtype keeps the result an int array even when it is empty; np.concatenate
  # would raise ValueError on an empty list of arrays.
  mesh_edge_indices = np.fromiter(
      (mesh_index
       for neighbors in query_indices
       for mesh_index in neighbors),
      dtype=int, count=int(num_neighbors.sum()))
  return grid_edge_indices, mesh_edge_indices
def in_mesh_triangle_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for grid points contained in mesh triangles.

  Each grid point is connected to all three vertices of the mesh face that
  `trimesh.proximity.closest_point` reports for it. Every grid point therefore
  contributes exactly three edges.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points], in degrees.
    grid_longitude: Longitude values for the grid [num_lon_points], in degrees.
    mesh: Mesh object.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh vertices of the triangle that contains each grid point.
    The number of edges is always num_lat_points * num_lon_points * 3.
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  """
  # [num_grid_points=num_lat_points * num_lon_points, 3]
  points = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])
  num_points = points.shape[0]

  surface = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)
  # closest_point yields (closest points, distances, face ids); only the
  # [num_grid_points] face id per grid point is used here.
  _, _, face_ids = trimesh.proximity.closest_point(surface, points)

  # [num_edges=num_grid_points*3] Each grid index three times in a row,
  # one per triangle vertex: [0, 0, 0, 1, 1, 1, ...].
  grid_edge_indices = np.repeat(np.arange(num_points), 3)
  # [num_edges] Vertex ids of each point's face, flattened row by row so they
  # line up with grid_edge_indices.
  mesh_edge_indices = mesh.faces[face_ids].reshape([-1])
  return grid_edge_indices, mesh_edge_indices
|