model_utils.py

# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for building models."""

from typing import Mapping, Optional, Tuple

import numpy as np
from scipy.spatial import transform
import xarray


def get_graph_spatial_features(
    *, node_lat: np.ndarray, node_lon: np.ndarray,
    senders: np.ndarray, receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
    sine_cosine_encoding: bool = False,
    encoding_num_freqs: int = 10,
    encoding_multiplicative_factor: float = 1.2,
    ) -> Tuple[np.ndarray, np.ndarray]:
  """Computes spatial features for the nodes.

  Args:
    node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes].
    node_lon: Longitudes in the [0, 360] interval of shape [num_nodes].
    senders: Sender indices of shape [num_edges].
    receivers: Receiver indices of shape [num_edges].
    add_node_positions: Add unit norm absolute positions.
    add_node_latitude: Add a feature for latitude (cos(90 - lat)).
        Note even if this is set to False, the model may be able to infer the
        latitude from relative features, unless
        `relative_latitude_local_coordinates` is also True, or if there is any
        bias on the relative edge sizes for different latitudes.
    add_node_longitude: Add features for longitude (cos(lon), sin(lon)).
        Note even if this is set to False, the model may be able to infer the
        longitude from relative features, unless
        `relative_longitude_local_coordinates` is also True, or if there is any
        bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R3 to the
        edges.
    relative_longitude_local_coordinates: If True, relative positions are
        computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
        computed in a local space where the receiver is at 0 latitude.
    sine_cosine_encoding: If True, we will transform the node/edge features
        with sine and cosine functions, similar to NeRF.
    encoding_num_freqs: Number of frequencies used for the sine/cosine
        encoding.
    encoding_multiplicative_factor: Multiplicative factor used to compute the
        frequencies.

  Returns:
    Arrays of shape [num_nodes, num_features] and [num_edges, num_features]
    with node and edge features.
  """

  num_nodes = node_lat.shape[0]
  num_edges = senders.shape[0]
  dtype = node_lat.dtype
  node_phi, node_theta = lat_lon_deg_to_spherical(node_lat, node_lon)

  # Computing some node features.
  node_features = []
  if add_node_positions:
    # Already in [-1, 1.] range.
    node_features.extend(spherical_to_cartesian(node_phi, node_theta))

  if add_node_latitude:
    # Using the cos of theta.
    # From 1. (north pole) to -1 (south pole).
    node_features.append(np.cos(node_theta))

  if add_node_longitude:
    # Using the cos and sin, which is already normalized.
    node_features.append(np.cos(node_phi))
    node_features.append(np.sin(node_phi))

  if not node_features:
    node_features = np.zeros([num_nodes, 0], dtype=dtype)
  else:
    node_features = np.stack(node_features, axis=-1)

  # Computing some edge features.
  edge_features = []

  if add_relative_positions:

    relative_position = get_relative_position_in_receiver_local_coordinates(
        node_phi=node_phi,
        node_theta=node_theta,
        senders=senders,
        receivers=receivers,
        latitude_local_coordinates=relative_latitude_local_coordinates,
        longitude_local_coordinates=relative_longitude_local_coordinates
        )

    # Note this is L2 distance in 3d space, rather than geodesic distance.
    relative_edge_distances = np.linalg.norm(
        relative_position, axis=-1, keepdims=True)

    # Normalize to the maximum edge distance. Note that we expect to always
    # have an edge that goes in the opposite direction of any given edge
    # so the distribution of relative positions should be symmetric around
    # zero. So by scaling by the maximum length, we expect all relative
    # positions to fall in the [-1., 1.] interval, and all relative distances
    # to fall in the [0., 1.] interval.
    max_edge_distance = relative_edge_distances.max()
    edge_features.append(relative_edge_distances / max_edge_distance)
    edge_features.append(relative_position / max_edge_distance)

  if not edge_features:
    edge_features = np.zeros([num_edges, 0], dtype=dtype)
  else:
    edge_features = np.concatenate(edge_features, axis=-1)

  if sine_cosine_encoding:
    def sine_cosine_transform(x: np.ndarray) -> np.ndarray:
      freqs = encoding_multiplicative_factor**np.arange(encoding_num_freqs)
      phases = freqs * x[..., None]
      x_sin = np.sin(phases)
      x_cos = np.cos(phases)
      x_cat = np.concatenate([x_sin, x_cos], axis=-1)
      return x_cat.reshape([x.shape[0], -1])

    node_features = sine_cosine_transform(node_features)
    edge_features = sine_cosine_transform(edge_features)

  return node_features, edge_features
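

# Illustrative usage sketch (not part of the original module): a toy 4-node
# graph with two symmetric edge pairs, showing the feature shapes produced by
# `get_graph_spatial_features`. All coordinates and edge indices here are
# made-up example values.
def _example_graph_spatial_features():
  node_lat = np.array([0.0, 0.0, 45.0, -45.0])
  node_lon = np.array([0.0, 90.0, 180.0, 270.0])
  senders = np.array([0, 1, 2, 3])
  receivers = np.array([1, 0, 3, 2])
  node_features, edge_features = get_graph_spatial_features(
      node_lat=node_lat, node_lon=node_lon,
      senders=senders, receivers=receivers,
      add_node_positions=False,
      add_node_latitude=True,
      add_node_longitude=True,
      add_relative_positions=True,
      relative_longitude_local_coordinates=True,
      relative_latitude_local_coordinates=True)
  # With these flags: 1 latitude + 2 longitude features per node, and
  # 1 normalized distance + 3 relative-position features per edge.
  assert node_features.shape == (4, 3)
  assert edge_features.shape == (4, 4)
  return node_features, edge_features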


def lat_lon_to_leading_axes(
    grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Reorders xarray so lat/lon axes come first."""
  # leading + ["lat", "lon"] + trailing
  # to
  # ["lat", "lon"] + leading + trailing
  return grid_xarray.transpose("lat", "lon", ...)


def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Reorders xarray so batch/time/level axes come first (if present)."""

  # ["lat", "lon"] + [(batch,) (time,) (level,)] + trailing
  # to
  # [(batch,) (time,) (level,)] + ["lat", "lon"] + trailing

  input_dims = list(grid_xarray.dims)
  output_dims = list(input_dims)
  for leading_key in ["level", "time", "batch"]:  # reverse order for insert
    if leading_key in input_dims:
      output_dims.remove(leading_key)
      output_dims.insert(0, leading_key)
  return grid_xarray.transpose(*output_dims)
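

# Illustrative sketch (not part of the original module): round-tripping the
# two helpers above on a DataArray with arbitrary example dimension sizes.
def _example_leading_axes_round_trip():
  grid = xarray.DataArray(
      np.zeros((1, 3, 4, 8)), dims=("batch", "time", "lat", "lon"))
  lat_lon_first = lat_lon_to_leading_axes(grid)
  assert lat_lon_first.dims == ("lat", "lon", "batch", "time")
  restored = restore_leading_axes(lat_lon_first)
  assert restored.dims == ("batch", "time", "lat", "lon")
  return restored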


def lat_lon_deg_to_spherical(node_lat: np.ndarray,
                             node_lon: np.ndarray,
                             ) -> Tuple[np.ndarray, np.ndarray]:
  phi = np.deg2rad(node_lon)
  theta = np.deg2rad(90 - node_lat)
  return phi, theta


def spherical_to_lat_lon(phi: np.ndarray,
                         theta: np.ndarray,
                         ) -> Tuple[np.ndarray, np.ndarray]:
  lon = np.mod(np.rad2deg(phi), 360)
  lat = 90 - np.rad2deg(theta)
  return lat, lon


def cartesian_to_spherical(x: np.ndarray,
                           y: np.ndarray,
                           z: np.ndarray,
                           ) -> Tuple[np.ndarray, np.ndarray]:
  phi = np.arctan2(y, x)
  with np.errstate(invalid="ignore"):  # circumventing b/253179568
    theta = np.arccos(z)  # Assuming unit radius.
  return phi, theta


def spherical_to_cartesian(
    phi: np.ndarray, theta: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  # Assuming unit radius.
  return (np.cos(phi)*np.sin(theta),
          np.sin(phi)*np.sin(theta),
          np.cos(theta))
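

# Quick consistency sketch (illustrative, not part of the original module):
# the conventions above map latitude/longitude in degrees to spherical angles
# (phi = longitude in radians, theta = colatitude in radians) and back, and
# `spherical_to_cartesian` / `cartesian_to_spherical` invert each other on the
# unit sphere. The test points below are arbitrary.
def _example_coordinate_round_trip():
  lat = np.array([90.0, 0.0, -45.0])
  lon = np.array([0.0, 90.0, 315.0])
  phi, theta = lat_lon_deg_to_spherical(lat, lon)
  x, y, z = spherical_to_cartesian(phi, theta)
  phi2, theta2 = cartesian_to_spherical(x, y, z)
  lat2, lon2 = spherical_to_lat_lon(phi2, theta2)
  np.testing.assert_allclose(lat2, lat, atol=1e-12)
  # Longitude is not well defined at the poles, so only compare away from them.
  np.testing.assert_allclose(lon2[1:], lon[1:], atol=1e-12)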


def get_relative_position_in_receiver_local_coordinates(
    node_phi: np.ndarray,
    node_theta: np.ndarray,
    senders: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool
    ) -> np.ndarray:
  """Returns relative position features for the edges.

  The relative positions will be computed in a rotated space for a local
  coordinate system as defined by the receiver. The relative positions are
  simply obtained by subtracting the receiver position from the sender
  position in that local coordinate system, after the rotation in R^3.

  Args:
    node_phi: [num_nodes] with azimuthal angles.
    node_theta: [num_nodes] with polar angles.
    senders: [num_edges] with indices.
    receivers: [num_edges] with indices.
    latitude_local_coordinates: Whether to rotate edges such that positions
        are computed with the receiver always at latitude 0.
    longitude_local_coordinates: Whether to rotate edges such that positions
        are computed with the receiver always at longitude 0.

  Returns:
    Array of relative positions in R3 [num_edges, 3].
  """

  node_pos = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1)

  # No rotation in this case.
  if not (latitude_local_coordinates or longitude_local_coordinates):
    return node_pos[senders] - node_pos[receivers]

  # Get rotation matrices for the local space for every node.
  rotation_matrices = get_rotation_matrices_to_local_coordinates(
      reference_phi=node_phi,
      reference_theta=node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)

  # Each edge will be rotated according to the rotation matrix of its receiver
  # node.
  edge_rotation_matrices = rotation_matrices[receivers]

  # Rotate all nodes to the rotated space of the corresponding edge.
  # Note for receivers we can also do the matmul first and the gather second:
  # ```
  # receiver_pos_in_rotated_space = rotate_with_matrices(
  #     rotation_matrices, node_pos)[receivers]
  # ```
  # which is more efficient, however, we do gather first to keep it more
  # symmetric with the sender computation.
  receiver_pos_in_rotated_space = rotate_with_matrices(
      edge_rotation_matrices, node_pos[receivers])
  sender_pos_in_rotated_space = rotate_with_matrices(
      edge_rotation_matrices, node_pos[senders])
  # Note, here, that because the rotated space is chosen according to the
  # receiver, if:
  # * latitude_local_coordinates = True: latitude for the receivers will be
  #   0, that is the z coordinate will always be 0.
  # * longitude_local_coordinates = True: longitude for the receivers will be
  #   0, that is the y coordinate will be 0.

  # Now we can just subtract.
  # Note we are rotating to a local coordinate system, where the y-z axes are
  # parallel to a tangent plane to the sphere, but still remain in a 3d space.
  # Note that if both `latitude_local_coordinates` and
  # `longitude_local_coordinates` are True, and edges are short,
  # then the difference in the x coordinate between sender and receiver
  # should be small, so we could consider dropping the new x coordinate if
  # we wanted to work in the tangent plane, however in doing so
  # we would lose information about the curvature of the mesh, which may be
  # important for very coarse meshes.
  return sender_pos_in_rotated_space - receiver_pos_in_rotated_space


def get_rotation_matrices_to_local_coordinates(
    reference_phi: np.ndarray,
    reference_theta: np.ndarray,
    rotate_latitude: bool,
    rotate_longitude: bool) -> np.ndarray:
  """Returns rotation matrices to local coordinates defined by reference points.

  The rotation matrix is built such that a vector in the same coordinate
  system at the reference point that points towards the pole before the
  rotation continues to point towards the pole after the rotation.

  Args:
    reference_phi: [leading_axis] Azimuthal angles of the reference.
    reference_theta: [leading_axis] Polar angles of the reference.
    rotate_latitude: Whether to produce a rotation matrix that would rotate
        R^3 vectors to zero latitude.
    rotate_longitude: Whether to produce a rotation matrix that would rotate
        R^3 vectors to zero longitude.

  Returns:
    Matrices of shape [leading_axis, 3, 3] such that when applied to the
    reference position with
    `rotate_with_matrices(rotation_matrices, reference_pos)`:

    * phi goes to 0. if "rotate_longitude" is True.
    * theta goes to np.pi / 2 if "rotate_latitude" is True.

    The rotation consists of:
    * rotate_latitude = False, rotate_longitude = True:
        Latitude preserving rotation.
    * rotate_latitude = True, rotate_longitude = True:
        Latitude preserving rotation, followed by longitude preserving
        rotation.
    * rotate_latitude = True, rotate_longitude = False:
        Latitude preserving rotation, followed by longitude preserving
        rotation, and the inverse of the latitude preserving rotation. Note
        this is computationally different from applying only the longitude
        preserving rotation. We do it like this so that the polar geodesic
        curve through the reference point continues to be aligned with one of
        the axes after the rotation.
  """

  if rotate_longitude and rotate_latitude:

    # We first rotate around the z axis by "minus the azimuthal angle", to get
    # the point to zero longitude.
    azimuthal_rotation = - reference_phi

    # Then we do a polar rotation (which can be done around the y axis now
    # that we are at longitude 0.), by "minus the polar angle plus pi/2", to
    # get the point to zero latitude.
    polar_rotation = - reference_theta + np.pi/2

    return transform.Rotation.from_euler(
        "zy", np.stack([azimuthal_rotation, polar_rotation],
                       axis=1)).as_matrix()
  elif rotate_longitude:
    # Just like the previous case, but applying only the azimuthal rotation.
    azimuthal_rotation = - reference_phi
    return transform.Rotation.from_euler("z", azimuthal_rotation).as_matrix()
  elif rotate_latitude:
    # Just like the first case, but after doing the polar rotation, undoing
    # the azimuthal rotation.
    azimuthal_rotation = - reference_phi
    polar_rotation = - reference_theta + np.pi/2

    return transform.Rotation.from_euler(
        "zyz", np.stack(
            [azimuthal_rotation, polar_rotation, -azimuthal_rotation],
            axis=1)).as_matrix()
  else:
    raise ValueError(
        "At least one of longitude and latitude should be rotated.")
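

# Illustrative check (not part of the original module): applying a node's own
# rotation matrix to its position should place it at longitude 0 and
# latitude 0 (i.e. phi = 0 and theta = pi/2) when both rotations are enabled.
# The reference angles below are arbitrary test values.
def _example_rotation_to_local_coordinates():
  reference_phi = np.array([0.3, 2.1, 4.0])
  reference_theta = np.array([0.4, 1.2, 2.5])
  rotation_matrices = get_rotation_matrices_to_local_coordinates(
      reference_phi=reference_phi,
      reference_theta=reference_theta,
      rotate_latitude=True,
      rotate_longitude=True)
  reference_pos = np.stack(
      spherical_to_cartesian(reference_phi, reference_theta), axis=-1)
  rotated_pos = rotate_with_matrices(rotation_matrices, reference_pos)
  rotated_phi, rotated_theta = cartesian_to_spherical(*rotated_pos.T)
  np.testing.assert_allclose(rotated_phi, 0.0, atol=1e-12)
  np.testing.assert_allclose(rotated_theta, np.pi / 2, atol=1e-12)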


def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray
                         ) -> np.ndarray:
  """Applies a batch of 3x3 rotation matrices to a batch of position vectors."""
  return np.einsum("bji,bi->bj", rotation_matrices, positions)


def get_bipartite_graph_spatial_features(
    *,
    senders_node_lat: np.ndarray,
    senders_node_lon: np.ndarray,
    senders: np.ndarray,
    receivers_node_lat: np.ndarray,
    receivers_node_lon: np.ndarray,
    receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    edge_normalization_factor: Optional[float] = None,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Computes spatial features for the nodes.

  This function is almost identical to `get_graph_spatial_features`. The only
  difference is that sender nodes and receiver nodes can be in different
  arrays. This is necessary to enable combination with typed graphs.

  Args:
    senders_node_lat: Latitudes in the [-90, 90] interval of shape
        [num_sender_nodes].
    senders_node_lon: Longitudes in the [0, 360] interval of shape
        [num_sender_nodes].
    senders: Sender indices of shape [num_edges], indices in
        [0, num_sender_nodes).
    receivers_node_lat: Latitudes in the [-90, 90] interval of shape
        [num_receiver_nodes].
    receivers_node_lon: Longitudes in the [0, 360] interval of shape
        [num_receiver_nodes].
    receivers: Receiver indices of shape [num_edges], indices in
        [0, num_receiver_nodes).
    add_node_positions: Add unit norm absolute positions.
    add_node_latitude: Add a feature for latitude (cos(90 - lat)). Note even if
        this is set to False, the model may be able to infer the latitude from
        relative features, unless `relative_latitude_local_coordinates` is also
        True, or if there is any bias on the relative edge sizes for different
        latitudes.
    add_node_longitude: Add features for longitude (cos(lon), sin(lon)). Note
        even if this is set to False, the model may be able to infer the
        longitude from relative features, unless
        `relative_longitude_local_coordinates` is also True, or if there is any
        bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R3 to the
        edges.
    edge_normalization_factor: Allows explicitly controlling edge
        normalization. If None, defaults to max edge length. This supports
        using pre-trained model weights with a different graph structure to
        what it was trained on.
    relative_longitude_local_coordinates: If True, relative positions are
        computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
        computed in a local space where the receiver is at 0 latitude.

  Returns:
    Arrays of shape [num_sender_nodes, num_features],
    [num_receiver_nodes, num_features] and [num_edges, num_features] with
    sender node, receiver node and edge features.
  """

  num_senders = senders_node_lat.shape[0]
  num_receivers = receivers_node_lat.shape[0]
  num_edges = senders.shape[0]
  dtype = senders_node_lat.dtype
  assert receivers_node_lat.dtype == dtype
  senders_node_phi, senders_node_theta = lat_lon_deg_to_spherical(
      senders_node_lat, senders_node_lon)
  receivers_node_phi, receivers_node_theta = lat_lon_deg_to_spherical(
      receivers_node_lat, receivers_node_lon)

  # Computing some node features.
  senders_node_features = []
  receivers_node_features = []
  if add_node_positions:
    # Already in [-1, 1.] range.
    senders_node_features.extend(
        spherical_to_cartesian(senders_node_phi, senders_node_theta))
    receivers_node_features.extend(
        spherical_to_cartesian(receivers_node_phi, receivers_node_theta))

  if add_node_latitude:
    # Using the cos of theta.
    # From 1. (north pole) to -1 (south pole).
    senders_node_features.append(np.cos(senders_node_theta))
    receivers_node_features.append(np.cos(receivers_node_theta))

  if add_node_longitude:
    # Using the cos and sin, which is already normalized.
    senders_node_features.append(np.cos(senders_node_phi))
    senders_node_features.append(np.sin(senders_node_phi))

    receivers_node_features.append(np.cos(receivers_node_phi))
    receivers_node_features.append(np.sin(receivers_node_phi))

  if not senders_node_features:
    senders_node_features = np.zeros([num_senders, 0], dtype=dtype)
    receivers_node_features = np.zeros([num_receivers, 0], dtype=dtype)
  else:
    senders_node_features = np.stack(senders_node_features, axis=-1)
    receivers_node_features = np.stack(receivers_node_features, axis=-1)

  # Computing some edge features.
  edge_features = []

  if add_relative_positions:

    relative_position = get_bipartite_relative_position_in_receiver_local_coordinates(  # pylint: disable=line-too-long
        senders_node_phi=senders_node_phi,
        senders_node_theta=senders_node_theta,
        receivers_node_phi=receivers_node_phi,
        receivers_node_theta=receivers_node_theta,
        senders=senders,
        receivers=receivers,
        latitude_local_coordinates=relative_latitude_local_coordinates,
        longitude_local_coordinates=relative_longitude_local_coordinates)

    # Note this is L2 distance in 3d space, rather than geodesic distance.
    relative_edge_distances = np.linalg.norm(
        relative_position, axis=-1, keepdims=True)

    if edge_normalization_factor is None:
      # Normalize to the maximum edge distance. Note that we expect to always
      # have an edge that goes in the opposite direction of any given edge
      # so the distribution of relative positions should be symmetric around
      # zero. So by scaling by the maximum length, we expect all relative
      # positions to fall in the [-1., 1.] interval, and all relative
      # distances to fall in the [0., 1.] interval.
      edge_normalization_factor = relative_edge_distances.max()

    edge_features.append(relative_edge_distances / edge_normalization_factor)
    edge_features.append(relative_position / edge_normalization_factor)

  if not edge_features:
    edge_features = np.zeros([num_edges, 0], dtype=dtype)
  else:
    edge_features = np.concatenate(edge_features, axis=-1)

  return senders_node_features, receivers_node_features, edge_features
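

# Illustrative usage sketch (not part of the original module): a toy bipartite
# graph with 2 sender nodes and 3 receiver nodes. All coordinates and edge
# indices are made up; the point is the three returned feature arrays and
# their shapes.
def _example_bipartite_graph_spatial_features():
  senders_features, receivers_features, edge_features = (
      get_bipartite_graph_spatial_features(
          senders_node_lat=np.array([10.0, -10.0]),
          senders_node_lon=np.array([0.0, 180.0]),
          senders=np.array([0, 1, 1]),
          receivers_node_lat=np.array([0.0, 30.0, -30.0]),
          receivers_node_lon=np.array([90.0, 180.0, 270.0]),
          receivers=np.array([0, 1, 2]),
          add_node_positions=False,
          add_node_latitude=True,
          add_node_longitude=True,
          add_relative_positions=True,
          relative_longitude_local_coordinates=True,
          relative_latitude_local_coordinates=True))
  assert senders_features.shape == (2, 3)
  assert receivers_features.shape == (3, 3)
  assert edge_features.shape == (3, 4)
  return senders_features, receivers_features, edge_features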


def get_bipartite_relative_position_in_receiver_local_coordinates(
    senders_node_phi: np.ndarray,
    senders_node_theta: np.ndarray,
    senders: np.ndarray,
    receivers_node_phi: np.ndarray,
    receivers_node_theta: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool) -> np.ndarray:
  """Returns relative position features for the edges.

  This function is equivalent to
  `get_relative_position_in_receiver_local_coordinates`, but adapted to work
  with bipartite typed graphs.

  The relative positions will be computed in a rotated space for a local
  coordinate system as defined by the receiver. The relative positions are
  simply obtained by subtracting the receiver position from the sender
  position in that local coordinate system, after the rotation in R^3.

  Args:
    senders_node_phi: [num_sender_nodes] with azimuthal angles.
    senders_node_theta: [num_sender_nodes] with polar angles.
    senders: [num_edges] with indices into sender nodes.
    receivers_node_phi: [num_receiver_nodes] with azimuthal angles.
    receivers_node_theta: [num_receiver_nodes] with polar angles.
    receivers: [num_edges] with indices into receiver nodes.
    latitude_local_coordinates: Whether to rotate edges such that positions
        are computed with the receiver always at latitude 0.
    longitude_local_coordinates: Whether to rotate edges such that positions
        are computed with the receiver always at longitude 0.

  Returns:
    Array of relative positions in R3 [num_edges, 3].
  """

  senders_node_pos = np.stack(
      spherical_to_cartesian(senders_node_phi, senders_node_theta), axis=-1)

  receivers_node_pos = np.stack(
      spherical_to_cartesian(receivers_node_phi, receivers_node_theta), axis=-1)

  # No rotation in this case.
  if not (latitude_local_coordinates or longitude_local_coordinates):
    return senders_node_pos[senders] - receivers_node_pos[receivers]

  # Get rotation matrices for the local space for every receiver node.
  receiver_rotation_matrices = get_rotation_matrices_to_local_coordinates(
      reference_phi=receivers_node_phi,
      reference_theta=receivers_node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)

  # Each edge will be rotated according to the rotation matrix of its receiver
  # node.
  edge_rotation_matrices = receiver_rotation_matrices[receivers]

  # Rotate all nodes to the rotated space of the corresponding edge.
  # Note for receivers we can also do the matmul first and the gather second:
  # ```
  # receiver_pos_in_rotated_space = rotate_with_matrices(
  #     rotation_matrices, node_pos)[receivers]
  # ```
  # which is more efficient, however, we do gather first to keep it more
  # symmetric with the sender computation.
  receiver_pos_in_rotated_space = rotate_with_matrices(
      edge_rotation_matrices, receivers_node_pos[receivers])
  sender_pos_in_rotated_space = rotate_with_matrices(
      edge_rotation_matrices, senders_node_pos[senders])
  # Note, here, that because the rotated space is chosen according to the
  # receiver, if:
  # * latitude_local_coordinates = True: latitude for the receivers will be
  #   0, that is the z coordinate will always be 0.
  # * longitude_local_coordinates = True: longitude for the receivers will be
  #   0, that is the y coordinate will be 0.

  # Now we can just subtract.
  # Note we are rotating to a local coordinate system, where the y-z axes are
  # parallel to a tangent plane to the sphere, but still remain in a 3d space.
  # Note that if both `latitude_local_coordinates` and
  # `longitude_local_coordinates` are True, and edges are short,
  # then the difference in the x coordinate between sender and receiver
  # should be small, so we could consider dropping the new x coordinate if
  # we wanted to work in the tangent plane, however in doing so
  # we would lose information about the curvature of the mesh, which may be
  # important for very coarse meshes.
  return sender_pos_in_rotated_space - receiver_pos_in_rotated_space


def variable_to_stacked(
    variable: xarray.Variable,
    sizes: Mapping[str, int],
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Variable:
  """Converts an xarray.Variable to preserved_dims + ("channels",).

  Any dimensions other than those included in preserved_dims get stacked into a
  final "channels" dimension. If any of the preserved_dims are missing then
  they are added, with the data broadcast/tiled to match the sizes specified in
  `sizes`.

  Args:
    variable: An xarray.Variable.
    sizes: Mapping including sizes for any dimensions which are not present in
        `variable` but are needed for the output. This may be needed for
        example for a static variable with only ("lat", "lon") dims, or if you
        want to encode just the latitude coordinates (a variable with dims
        ("lat",)).
    preserved_dims: dimensions of variable not to be folded into channels.

  Returns:
    An xarray.Variable with dimensions preserved_dims + ("channels",).
  """
  stack_to_channels_dims = [
      d for d in variable.dims if d not in preserved_dims]
  if stack_to_channels_dims:
    variable = variable.stack(channels=stack_to_channels_dims)
  dims = {dim: variable.sizes.get(dim) or sizes[dim] for dim in preserved_dims}
  dims["channels"] = variable.sizes.get("channels", 1)
  return variable.set_dims(dims)
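

# Illustrative sketch (not part of the original module): stacking a variable
# with an extra "level" dimension into channels, and broadcasting a missing
# "batch" dimension using the provided `sizes`. The dimension sizes here are
# arbitrary.
def _example_variable_to_stacked():
  variable = xarray.Variable(
      ("level", "lat", "lon"), np.zeros((2, 4, 8)))
  stacked = variable_to_stacked(variable, sizes={"batch": 3})
  # ("batch", "lat", "lon", "channels"), with the 2 levels folded into channels.
  assert stacked.dims == ("batch", "lat", "lon", "channels")
  assert stacked.shape == (3, 4, 8, 2)
  return stacked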


def dataset_to_stacked(
    dataset: xarray.Dataset,
    sizes: Optional[Mapping[str, int]] = None,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
) -> xarray.DataArray:
  """Converts an xarray.Dataset to a single stacked array.

  This takes each constituent data_var, converts it into BHWC layout
  using `variable_to_stacked`, then concats them all along the channels axis.

  Args:
    dataset: An xarray.Dataset.
    sizes: Mapping including sizes for any dimensions which are not present in
        the `dataset` but are needed for the output. See variable_to_stacked.
    preserved_dims: dimensions from the dataset that should not be folded into
        the channels dimension.

  Returns:
    An xarray.DataArray with dimensions preserved_dims + ("channels",).
    Existing coordinates for preserved_dims axes will be preserved, however
    there will be no coordinates for "channels".
  """
  data_vars = [
      variable_to_stacked(dataset.variables[name], sizes or dataset.sizes,
                          preserved_dims)
      for name in sorted(dataset.data_vars.keys())
  ]
  coords = {
      dim: coord
      for dim, coord in dataset.coords.items()
      if dim in preserved_dims
  }
  return xarray.DataArray(
      data=xarray.Variable.concat(data_vars, dim="channels"), coords=coords)


def stacked_to_dataset(
    stacked_array: xarray.Variable,
    template_dataset: xarray.Dataset,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Dataset:
  """The inverse of dataset_to_stacked.

  Requires a template dataset to demonstrate the variables/shapes/coordinates
  required.
  All variables must have preserved_dims dimensions.

  Args:
    stacked_array: Data in BHWC layout, encoded the same as dataset_to_stacked
        would if it was asked to encode `template_dataset`.
    template_dataset: A template Dataset (or other mapping of DataArrays)
        demonstrating the shape of output required (variables, shapes,
        coordinates etc).
    preserved_dims: dimensions from the template_dataset that were not folded
        into the channels dimension. The preserved_dims need to be a subset of
        the dims of all the variables of template_dataset.

  Returns:
    An xarray.Dataset (or other mapping of DataArrays) with the same shape and
    type as template_dataset.
  """
  unstack_from_channels_sizes = {}
  var_names = sorted(template_dataset.keys())
  for name in var_names:
    template_var = template_dataset[name]
    if not all(dim in template_var.dims for dim in preserved_dims):
      raise ValueError(
          f"stacked_to_dataset requires all Variables to have {preserved_dims} "
          f"dimensions, but found only {template_var.dims}.")
    unstack_from_channels_sizes[name] = {
        dim: size for dim, size in template_var.sizes.items()
        if dim not in preserved_dims}

  channels = {name: np.prod(list(unstack_sizes.values()), dtype=np.int64)
              for name, unstack_sizes in unstack_from_channels_sizes.items()}
  total_expected_channels = sum(channels.values())
  found_channels = stacked_array.sizes["channels"]
  if total_expected_channels != found_channels:
    raise ValueError(
        f"Expected {total_expected_channels} channels but found "
        f"{found_channels}, when trying to convert a stacked array of shape "
        f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")

  data_vars = {}
  index = 0
  for name in var_names:
    template_var = template_dataset[name]
    var = stacked_array.isel({"channels": slice(index, index + channels[name])})
    index += channels[name]
    var = var.unstack({"channels": unstack_from_channels_sizes[name]})
    var = var.transpose(*template_var.dims)
    data_vars[name] = xarray.DataArray(
        data=var,
        coords=template_var.coords,
        # This might not always be the same as the name it's keyed under; it
        # will refer to the original variable name, whereas the key might be
        # some alias e.g. temperature_850 under which it should be logged:
        name=template_var.name,
    )
  return type(template_dataset)(data_vars)  # pytype:disable=not-callable,wrong-arg-count
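

# Illustrative round-trip sketch (not part of the original module): encodes a
# small dataset with `dataset_to_stacked` and decodes it back with
# `stacked_to_dataset`, using the dataset itself as the template. Variable
# names and sizes are arbitrary example values.
def _example_stack_round_trip():
  dataset = xarray.Dataset({
      "temperature": (("batch", "level", "lat", "lon"),
                      np.random.rand(1, 2, 4, 8)),
      "surface_pressure": (("batch", "lat", "lon"),
                           np.random.rand(1, 4, 8)),
  })
  stacked = dataset_to_stacked(dataset)
  # 2 channels for the levels of "temperature" + 1 for "surface_pressure".
  assert stacked.sizes["channels"] == 3
  round_tripped = stacked_to_dataset(stacked.variable, dataset)
  xarray.testing.assert_allclose(round_tripped, dataset)
  return round_tripped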