graphcast.py 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """A predictor that runs multiple graph neural networks on mesh data.
  15. It learns to interpolate between the grid and the mesh nodes, with the loss
  16. and the rollouts ultimately computed at the grid level.
  17. It uses ideas similar to those in Keisler (2022):
  18. Reference:
  19. https://arxiv.org/pdf/2202.07575.pdf
  20. It assumes data across time and level is stacked, and operates only in
  21. a 2D mesh over latitudes and longitudes.
  22. """
  23. from typing import Any, Callable, Mapping, Optional
  24. import chex
  25. from graphcast import deep_typed_graph_net
  26. from graphcast import grid_mesh_connectivity
  27. from graphcast import icosahedral_mesh
  28. from graphcast import losses
  29. from graphcast import model_utils
  30. from graphcast import predictor_base
  31. from graphcast import typed_graph
  32. from graphcast import xarray_jax
  33. import jax.numpy as jnp
  34. import jraph
  35. import numpy as np
  36. import xarray
  37. Kwargs = Mapping[str, Any]
  38. GNN = Callable[[jraph.GraphsTuple], jraph.GraphsTuple]
  39. # https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5
  40. PRESSURE_LEVELS_ERA5_37 = (
  41. 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300,
  42. 350, 400, 450, 500, 550, 600, 650, 700, 750, 775, 800, 825, 850, 875, 900,
  43. 925, 950, 975, 1000)
  44. # https://www.ecmwf.int/en/forecasts/datasets/set-i
  45. PRESSURE_LEVELS_HRES_25 = (
  46. 1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 400, 500, 600,
  47. 700, 800, 850, 900, 925, 950, 1000)
  48. # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2020MS002203
  49. PRESSURE_LEVELS_WEATHERBENCH_13 = (
  50. 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)
  51. PRESSURE_LEVELS = {
  52. 13: PRESSURE_LEVELS_WEATHERBENCH_13,
  53. 25: PRESSURE_LEVELS_HRES_25,
  54. 37: PRESSURE_LEVELS_ERA5_37,
  55. }
  56. # The list of all possible atmospheric variables. Taken from:
  57. # https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Table9
  58. ALL_ATMOSPHERIC_VARS = (
  59. "potential_vorticity",
  60. "specific_rain_water_content",
  61. "specific_snow_water_content",
  62. "geopotential",
  63. "temperature",
  64. "u_component_of_wind",
  65. "v_component_of_wind",
  66. "specific_humidity",
  67. "vertical_velocity",
  68. "vorticity",
  69. "divergence",
  70. "relative_humidity",
  71. "ozone_mass_mixing_ratio",
  72. "specific_cloud_liquid_water_content",
  73. "specific_cloud_ice_water_content",
  74. "fraction_of_cloud_cover",
  75. )
  76. TARGET_SURFACE_VARS = (
  77. "2m_temperature",
  78. "mean_sea_level_pressure",
  79. "10m_v_component_of_wind",
  80. "10m_u_component_of_wind",
  81. "total_precipitation_6hr",
  82. )
  83. TARGET_SURFACE_NO_PRECIP_VARS = (
  84. "2m_temperature",
  85. "mean_sea_level_pressure",
  86. "10m_v_component_of_wind",
  87. "10m_u_component_of_wind",
  88. )
  89. TARGET_ATMOSPHERIC_VARS = (
  90. "temperature",
  91. "geopotential",
  92. "u_component_of_wind",
  93. "v_component_of_wind",
  94. "vertical_velocity",
  95. "specific_humidity",
  96. )
  97. TARGET_ATMOSPHERIC_NO_W_VARS = (
  98. "temperature",
  99. "geopotential",
  100. "u_component_of_wind",
  101. "v_component_of_wind",
  102. "specific_humidity",
  103. )
  104. EXTERNAL_FORCING_VARS = (
  105. "toa_incident_solar_radiation",
  106. )
  107. GENERATED_FORCING_VARS = (
  108. "year_progress_sin",
  109. "year_progress_cos",
  110. "day_progress_sin",
  111. "day_progress_cos",
  112. )
  113. FORCING_VARS = EXTERNAL_FORCING_VARS + GENERATED_FORCING_VARS
  114. STATIC_VARS = (
  115. "geopotential_at_surface",
  116. "land_sea_mask",
  117. )
  118. @chex.dataclass(frozen=True, eq=True)
  119. class TaskConfig:
  120. """Defines inputs and targets on which a model is trained and/or evaluated."""
  121. input_variables: tuple[str, ...]
  122. # Target variables which the model is expected to predict.
  123. target_variables: tuple[str, ...]
  124. forcing_variables: tuple[str, ...]
  125. pressure_levels: tuple[int, ...]
  126. input_duration: str
  127. TASK = TaskConfig(
  128. input_variables=(
  129. TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
  130. STATIC_VARS),
  131. target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
  132. forcing_variables=FORCING_VARS,
  133. pressure_levels=PRESSURE_LEVELS_ERA5_37,
  134. input_duration="12h",
  135. )
  136. TASK_13 = TaskConfig(
  137. input_variables=(
  138. TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
  139. STATIC_VARS),
  140. target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
  141. forcing_variables=FORCING_VARS,
  142. pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
  143. input_duration="12h",
  144. )
  145. TASK_13_PRECIP_OUT = TaskConfig(
  146. input_variables=(
  147. TARGET_SURFACE_NO_PRECIP_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
  148. STATIC_VARS),
  149. target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
  150. forcing_variables=FORCING_VARS,
  151. pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
  152. input_duration="12h",
  153. )
  154. @chex.dataclass(frozen=True, eq=True)
  155. class ModelConfig:
  156. """Defines the architecture of the GraphCast neural network architecture.
  157. Properties:
  158. resolution: The resolution of the data, in degrees (e.g. 0.25 or 1.0).
  159. mesh_size: How many refinements to do on the multi-mesh.
  160. gnn_msg_steps: How many Graph Network message passing steps to do.
  161. latent_size: How many latent features to include in the various MLPs.
  162. hidden_layers: How many hidden layers for each MLP.
  163. radius_query_fraction_edge_length: Scalar that will be multiplied by the
  164. length of the longest edge of the finest mesh to define the radius of
  165. connectivity to use in the Grid2Mesh graph. Reasonable values are
  166. between 0.6 and 1. 0.6 reduces the number of grid points feeding into
  167. multiple mesh nodes and therefore reduces edge count and memory use, but
  168. 1 gives better predictions.
  169. mesh2grid_edge_normalization_factor: Allows explicitly controlling edge
  170. normalization for mesh2grid edges. If None, defaults to max edge length.
  171. This supports using pre-trained model weights with a different graph
  172. structure to what it was trained on.
  173. """
  174. resolution: float
  175. mesh_size: int
  176. latent_size: int
  177. gnn_msg_steps: int
  178. hidden_layers: int
  179. radius_query_fraction_edge_length: float
  180. mesh2grid_edge_normalization_factor: Optional[float] = None
  181. @chex.dataclass(frozen=True, eq=True)
  182. class CheckPoint:
  183. params: dict[str, Any]
  184. model_config: ModelConfig
  185. task_config: TaskConfig
  186. description: str
  187. license: str
  188. class GraphCast(predictor_base.Predictor):
  189. """GraphCast Predictor.
  190. The model works on graphs that take into account:
  191. * Mesh nodes: nodes for the vertices of the mesh.
  192. * Grid nodes: nodes for the points of the grid.
  193. * Nodes: When referring to just "nodes", this means the joint set of
  194. both mesh nodes, concatenated with grid nodes.
  195. The model works with 3 graphs:
  196. * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly
  197. bipartite with edges going from grid nodes to mesh nodes using a
  198. fixed radius query. The grid2mesh_gnn will operate in this graph. The output
  199. of this stage will be a latent representation for the mesh nodes, and a
  200. latent representation for the grid nodes.
  201. * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will
  202. operate in this graph. It will update the latent state of the mesh nodes
  203. only.
  204. * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly
  205. bipartite with edges going from mesh nodes to grid nodes such that each grid
  206. nodes is connected to 3 nodes of the mesh triangular face that contains
  207. the grid points. The mesh2grid_gnn will operate in this graph. It will
  208. process the updated latent state of the mesh nodes, and the latent state
  209. of the grid nodes, to produce the final output for the grid nodes.
  210. The model is built on top of `TypedGraph`s so the different types of nodes and
  211. edges can be stored and treated separately.
  212. """
  213. def __init__(self, model_config: ModelConfig, task_config: TaskConfig):
  214. """Initializes the predictor."""
  215. self._spatial_features_kwargs = dict(
  216. add_node_positions=False,
  217. add_node_latitude=True,
  218. add_node_longitude=True,
  219. add_relative_positions=True,
  220. relative_longitude_local_coordinates=True,
  221. relative_latitude_local_coordinates=True,
  222. )
  223. # Specification of the multimesh.
  224. self._meshes = (
  225. icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
  226. splits=model_config.mesh_size))
  227. # Encoder, which moves data from the grid to the mesh with a single message
  228. # passing step.
  229. self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
  230. embed_nodes=True, # Embed raw features of the grid and mesh nodes.
  231. embed_edges=True, # Embed raw features of the grid2mesh edges.
  232. edge_latent_size=dict(grid2mesh=model_config.latent_size),
  233. node_latent_size=dict(
  234. mesh_nodes=model_config.latent_size,
  235. grid_nodes=model_config.latent_size),
  236. mlp_hidden_size=model_config.latent_size,
  237. mlp_num_hidden_layers=model_config.hidden_layers,
  238. num_message_passing_steps=1,
  239. use_layer_norm=True,
  240. include_sent_messages_in_node_update=False,
  241. activation="swish",
  242. f32_aggregation=True,
  243. aggregate_normalization=None,
  244. name="grid2mesh_gnn",
  245. )
  246. # Processor, which performs message passing on the multi-mesh.
  247. self._mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
  248. embed_nodes=False, # Node features already embdded by previous layers.
  249. embed_edges=True, # Embed raw features of the multi-mesh edges.
  250. node_latent_size=dict(mesh_nodes=model_config.latent_size),
  251. edge_latent_size=dict(mesh=model_config.latent_size),
  252. mlp_hidden_size=model_config.latent_size,
  253. mlp_num_hidden_layers=model_config.hidden_layers,
  254. num_message_passing_steps=model_config.gnn_msg_steps,
  255. use_layer_norm=True,
  256. include_sent_messages_in_node_update=False,
  257. activation="swish",
  258. f32_aggregation=False,
  259. name="mesh_gnn",
  260. )
  261. num_surface_vars = len(
  262. set(task_config.target_variables) - set(ALL_ATMOSPHERIC_VARS))
  263. num_atmospheric_vars = len(
  264. set(task_config.target_variables) & set(ALL_ATMOSPHERIC_VARS))
  265. num_outputs = (num_surface_vars +
  266. len(task_config.pressure_levels) * num_atmospheric_vars)
  267. # Decoder, which moves data from the mesh back into the grid with a single
  268. # message passing step.
  269. self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
  270. # Require a specific node dimensionaly for the grid node outputs.
  271. node_output_size=dict(grid_nodes=num_outputs),
  272. embed_nodes=False, # Node features already embdded by previous layers.
  273. embed_edges=True, # Embed raw features of the mesh2grid edges.
  274. edge_latent_size=dict(mesh2grid=model_config.latent_size),
  275. node_latent_size=dict(
  276. mesh_nodes=model_config.latent_size,
  277. grid_nodes=model_config.latent_size),
  278. mlp_hidden_size=model_config.latent_size,
  279. mlp_num_hidden_layers=model_config.hidden_layers,
  280. num_message_passing_steps=1,
  281. use_layer_norm=True,
  282. include_sent_messages_in_node_update=False,
  283. activation="swish",
  284. f32_aggregation=False,
  285. name="mesh2grid_gnn",
  286. )
  287. # Obtain the query radius in absolute units for the unit-sphere for the
  288. # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`.
  289. self._query_radius = (_get_max_edge_distance(self._finest_mesh)
  290. * model_config.radius_query_fraction_edge_length)
  291. self._mesh2grid_edge_normalization_factor = (
  292. model_config.mesh2grid_edge_normalization_factor
  293. )
  294. # Other initialization is delayed until the first call (`_maybe_init`)
  295. # when we get some sample data so we know the lat/lon values.
  296. self._initialized = False
  297. # A "_init_mesh_properties":
  298. # This one could be initialized at init but we delay it for consistency too.
  299. self._num_mesh_nodes = None # num_mesh_nodes
  300. self._mesh_nodes_lat = None # [num_mesh_nodes]
  301. self._mesh_nodes_lon = None # [num_mesh_nodes]
  302. # A "_init_grid_properties":
  303. self._grid_lat = None # [num_lat_points]
  304. self._grid_lon = None # [num_lon_points]
  305. self._num_grid_nodes = None # num_lat_points * num_lon_points
  306. self._grid_nodes_lat = None # [num_grid_nodes]
  307. self._grid_nodes_lon = None # [num_grid_nodes]
  308. # A "_init_{grid2mesh,processor,mesh2grid}_graph"
  309. self._grid2mesh_graph_structure = None
  310. self._mesh_graph_structure = None
  311. self._mesh2grid_graph_structure = None
  312. @property
  313. def _finest_mesh(self):
  314. return self._meshes[-1]
  315. def __call__(self,
  316. inputs: xarray.Dataset,
  317. targets_template: xarray.Dataset,
  318. forcings: xarray.Dataset,
  319. is_training: bool = False,
  320. ) -> xarray.Dataset:
  321. self._maybe_init(inputs)
  322. # Convert all input data into flat vectors for each of the grid nodes.
  323. # xarray (batch, time, lat, lon, level, multiple vars, forcings)
  324. # -> [num_grid_nodes, batch, num_channels]
  325. grid_node_features = self._inputs_to_grid_node_features(inputs, forcings)
  326. # Transfer data for the grid to the mesh,
  327. # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size]
  328. (latent_mesh_nodes, latent_grid_nodes
  329. ) = self._run_grid2mesh_gnn(grid_node_features)
  330. # Run message passing in the multimesh.
  331. # [num_mesh_nodes, batch, latent_size]
  332. updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes)
  333. # Transfer data frome the mesh to the grid.
  334. # [num_grid_nodes, batch, output_size]
  335. output_grid_nodes = self._run_mesh2grid_gnn(
  336. updated_latent_mesh_nodes, latent_grid_nodes)
  337. # Conver output flat vectors for the grid nodes to the format of the output.
  338. # [num_grid_nodes, batch, output_size] ->
  339. # xarray (batch, one time step, lat, lon, level, multiple vars)
  340. return self._grid_node_outputs_to_prediction(
  341. output_grid_nodes, targets_template)
  342. def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
  343. self,
  344. inputs: xarray.Dataset,
  345. targets: xarray.Dataset,
  346. forcings: xarray.Dataset,
  347. ) -> tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]:
  348. # Forward pass.
  349. predictions = self(
  350. inputs, targets_template=targets, forcings=forcings, is_training=True)
  351. # Compute loss.
  352. loss = losses.weighted_mse_per_level(
  353. predictions, targets,
  354. per_variable_weights={
  355. # Any variables not specified here are weighted as 1.0.
  356. # A single-level variable, but an important headline variable
  357. # and also one which we have struggled to get good performance
  358. # on at short lead times, so leaving it weighted at 1.0, equal
  359. # to the multi-level variables:
  360. "2m_temperature": 1.0,
  361. # New single-level variables, which we don't weight too highly
  362. # to avoid hurting performance on other variables.
  363. "10m_u_component_of_wind": 0.1,
  364. "10m_v_component_of_wind": 0.1,
  365. "mean_sea_level_pressure": 0.1,
  366. "total_precipitation_6hr": 0.1,
  367. })
  368. return loss, predictions # pytype: disable=bad-return-type # jax-ndarray
  369. def loss( # pytype: disable=signature-mismatch # jax-ndarray
  370. self,
  371. inputs: xarray.Dataset,
  372. targets: xarray.Dataset,
  373. forcings: xarray.Dataset,
  374. ) -> predictor_base.LossAndDiagnostics:
  375. loss, _ = self.loss_and_predictions(inputs, targets, forcings)
  376. return loss # pytype: disable=bad-return-type # jax-ndarray
  377. def _maybe_init(self, sample_inputs: xarray.Dataset):
  378. """Inits everything that has a dependency on the input coordinates."""
  379. if not self._initialized:
  380. self._init_mesh_properties()
  381. self._init_grid_properties(
  382. grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon)
  383. self._grid2mesh_graph_structure = self._init_grid2mesh_graph()
  384. self._mesh_graph_structure = self._init_mesh_graph()
  385. self._mesh2grid_graph_structure = self._init_mesh2grid_graph()
  386. self._initialized = True
  387. def _init_mesh_properties(self):
  388. """Inits static properties that have to do with mesh nodes."""
  389. self._num_mesh_nodes = self._finest_mesh.vertices.shape[0]
  390. mesh_phi, mesh_theta = model_utils.cartesian_to_spherical(
  391. self._finest_mesh.vertices[:, 0],
  392. self._finest_mesh.vertices[:, 1],
  393. self._finest_mesh.vertices[:, 2])
  394. (
  395. mesh_nodes_lat,
  396. mesh_nodes_lon,
  397. ) = model_utils.spherical_to_lat_lon(
  398. phi=mesh_phi, theta=mesh_theta)
  399. # Convert to f32 to ensure the lat/lon features aren't in f64.
  400. self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32)
  401. self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32)
  402. def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray):
  403. """Inits static properties that have to do with grid nodes."""
  404. self._grid_lat = grid_lat.astype(np.float32)
  405. self._grid_lon = grid_lon.astype(np.float32)
  406. # Initialized the counters.
  407. self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0]
  408. # Initialize lat and lon for the grid.
  409. grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat)
  410. self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32)
  411. self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32)
  412. def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph:
  413. """Build Grid2Mesh graph."""
  414. # Create some edges according to distance between mesh and grid nodes.
  415. assert self._grid_lat is not None and self._grid_lon is not None
  416. (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices(
  417. grid_latitude=self._grid_lat,
  418. grid_longitude=self._grid_lon,
  419. mesh=self._finest_mesh,
  420. radius=self._query_radius)
  421. # Edges sending info from grid to mesh.
  422. senders = grid_indices
  423. receivers = mesh_indices
  424. # Precompute structural node and edge features according to config options.
  425. # Structural features are those that depend on the fixed values of the
  426. # latitude and longitudes of the nodes.
  427. (senders_node_features, receivers_node_features,
  428. edge_features) = model_utils.get_bipartite_graph_spatial_features(
  429. senders_node_lat=self._grid_nodes_lat,
  430. senders_node_lon=self._grid_nodes_lon,
  431. receivers_node_lat=self._mesh_nodes_lat,
  432. receivers_node_lon=self._mesh_nodes_lon,
  433. senders=senders,
  434. receivers=receivers,
  435. edge_normalization_factor=None,
  436. **self._spatial_features_kwargs,
  437. )
  438. n_grid_node = np.array([self._num_grid_nodes])
  439. n_mesh_node = np.array([self._num_mesh_nodes])
  440. n_edge = np.array([mesh_indices.shape[0]])
  441. grid_node_set = typed_graph.NodeSet(
  442. n_node=n_grid_node, features=senders_node_features)
  443. mesh_node_set = typed_graph.NodeSet(
  444. n_node=n_mesh_node, features=receivers_node_features)
  445. edge_set = typed_graph.EdgeSet(
  446. n_edge=n_edge,
  447. indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
  448. features=edge_features)
  449. nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
  450. edges = {
  451. typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")):
  452. edge_set
  453. }
  454. grid2mesh_graph = typed_graph.TypedGraph(
  455. context=typed_graph.Context(n_graph=np.array([1]), features=()),
  456. nodes=nodes,
  457. edges=edges)
  458. return grid2mesh_graph
  459. def _init_mesh_graph(self) -> typed_graph.TypedGraph:
  460. """Build Mesh graph."""
  461. merged_mesh = icosahedral_mesh.merge_meshes(self._meshes)
  462. # Work simply on the mesh edges.
  463. senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces)
  464. # Precompute structural node and edge features according to config options.
  465. # Structural features are those that depend on the fixed values of the
  466. # latitude and longitudes of the nodes.
  467. assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
  468. node_features, edge_features = model_utils.get_graph_spatial_features(
  469. node_lat=self._mesh_nodes_lat,
  470. node_lon=self._mesh_nodes_lon,
  471. senders=senders,
  472. receivers=receivers,
  473. **self._spatial_features_kwargs,
  474. )
  475. n_mesh_node = np.array([self._num_mesh_nodes])
  476. n_edge = np.array([senders.shape[0]])
  477. assert n_mesh_node == len(node_features)
  478. mesh_node_set = typed_graph.NodeSet(
  479. n_node=n_mesh_node, features=node_features)
  480. edge_set = typed_graph.EdgeSet(
  481. n_edge=n_edge,
  482. indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
  483. features=edge_features)
  484. nodes = {"mesh_nodes": mesh_node_set}
  485. edges = {
  486. typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set
  487. }
  488. mesh_graph = typed_graph.TypedGraph(
  489. context=typed_graph.Context(n_graph=np.array([1]), features=()),
  490. nodes=nodes,
  491. edges=edges)
  492. return mesh_graph
  493. def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph:
  494. """Build Mesh2Grid graph."""
  495. # Create some edges according to how the grid nodes are contained by
  496. # mesh triangles.
  497. (grid_indices,
  498. mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(
  499. grid_latitude=self._grid_lat,
  500. grid_longitude=self._grid_lon,
  501. mesh=self._finest_mesh)
  502. # Edges sending info from mesh to grid.
  503. senders = mesh_indices
  504. receivers = grid_indices
  505. # Precompute structural node and edge features according to config options.
  506. assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
  507. (senders_node_features, receivers_node_features,
  508. edge_features) = model_utils.get_bipartite_graph_spatial_features(
  509. senders_node_lat=self._mesh_nodes_lat,
  510. senders_node_lon=self._mesh_nodes_lon,
  511. receivers_node_lat=self._grid_nodes_lat,
  512. receivers_node_lon=self._grid_nodes_lon,
  513. senders=senders,
  514. receivers=receivers,
  515. edge_normalization_factor=self._mesh2grid_edge_normalization_factor,
  516. **self._spatial_features_kwargs,
  517. )
  518. n_grid_node = np.array([self._num_grid_nodes])
  519. n_mesh_node = np.array([self._num_mesh_nodes])
  520. n_edge = np.array([senders.shape[0]])
  521. grid_node_set = typed_graph.NodeSet(
  522. n_node=n_grid_node, features=receivers_node_features)
  523. mesh_node_set = typed_graph.NodeSet(
  524. n_node=n_mesh_node, features=senders_node_features)
  525. edge_set = typed_graph.EdgeSet(
  526. n_edge=n_edge,
  527. indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
  528. features=edge_features)
  529. nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
  530. edges = {
  531. typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")):
  532. edge_set
  533. }
  534. mesh2grid_graph = typed_graph.TypedGraph(
  535. context=typed_graph.Context(n_graph=np.array([1]), features=()),
  536. nodes=nodes,
  537. edges=edges)
  538. return mesh2grid_graph
  539. def _run_grid2mesh_gnn(self, grid_node_features: chex.Array,
  540. ) -> tuple[chex.Array, chex.Array]:
  541. """Runs the grid2mesh_gnn, extracting latent mesh and grid nodes."""
  542. # Concatenate node structural features with input features.
  543. batch_size = grid_node_features.shape[1]
  544. grid2mesh_graph = self._grid2mesh_graph_structure
  545. assert grid2mesh_graph is not None
  546. grid_nodes = grid2mesh_graph.nodes["grid_nodes"]
  547. mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"]
  548. new_grid_nodes = grid_nodes._replace(
  549. features=jnp.concatenate([
  550. grid_node_features,
  551. _add_batch_second_axis(
  552. grid_nodes.features.astype(grid_node_features.dtype),
  553. batch_size)
  554. ],
  555. axis=-1))
  556. # To make sure capacity of the embedded is identical for the grid nodes and
  557. # the mesh nodes, we also append some dummy zero input features for the
  558. # mesh nodes.
  559. dummy_mesh_node_features = jnp.zeros(
  560. (self._num_mesh_nodes,) + grid_node_features.shape[1:],
  561. dtype=grid_node_features.dtype)
  562. new_mesh_nodes = mesh_nodes._replace(
  563. features=jnp.concatenate([
  564. dummy_mesh_node_features,
  565. _add_batch_second_axis(
  566. mesh_nodes.features.astype(dummy_mesh_node_features.dtype),
  567. batch_size)
  568. ],
  569. axis=-1))
  570. # Broadcast edge structural features to the required batch size.
  571. grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh")
  572. edges = grid2mesh_graph.edges[grid2mesh_edges_key]
  573. new_edges = edges._replace(
  574. features=_add_batch_second_axis(
  575. edges.features.astype(dummy_mesh_node_features.dtype), batch_size))
  576. input_graph = self._grid2mesh_graph_structure._replace(
  577. edges={grid2mesh_edges_key: new_edges},
  578. nodes={
  579. "grid_nodes": new_grid_nodes,
  580. "mesh_nodes": new_mesh_nodes
  581. })
  582. # Run the GNN.
  583. grid2mesh_out = self._grid2mesh_gnn(input_graph)
  584. latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features
  585. latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features
  586. return latent_mesh_nodes, latent_grid_nodes
  587. def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array) -> chex.Array:
  588. """Runs the mesh_gnn, extracting updated latent mesh nodes."""
  589. # Add the structural edge features of this graph. Note we don't need
  590. # to add the structural node features, because these are already part of
  591. # the latent state, via the original Grid2Mesh gnn, however, we need
  592. # the edge ones, because it is the first time we are seeing this particular
  593. # set of edges.
  594. batch_size = latent_mesh_nodes.shape[1]
  595. mesh_graph = self._mesh_graph_structure
  596. assert mesh_graph is not None
  597. mesh_edges_key = mesh_graph.edge_key_by_name("mesh")
  598. edges = mesh_graph.edges[mesh_edges_key]
  599. # We are assuming here that the mesh gnn uses a single set of edge keys
  600. # named "mesh" for the edges and that it uses a single set of nodes named
  601. # "mesh_nodes"
  602. msg = ("The setup currently requires to only have one kind of edge in the"
  603. " mesh GNN.")
  604. assert len(mesh_graph.edges) == 1, msg
  605. new_edges = edges._replace(
  606. features=_add_batch_second_axis(
  607. edges.features.astype(latent_mesh_nodes.dtype), batch_size))
  608. nodes = mesh_graph.nodes["mesh_nodes"]
  609. nodes = nodes._replace(features=latent_mesh_nodes)
  610. input_graph = mesh_graph._replace(
  611. edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes})
  612. # Run the GNN.
  613. return self._mesh_gnn(input_graph).nodes["mesh_nodes"].features
  614. def _run_mesh2grid_gnn(self,
  615. updated_latent_mesh_nodes: chex.Array,
  616. latent_grid_nodes: chex.Array,
  617. ) -> chex.Array:
  618. """Runs the mesh2grid_gnn, extracting the output grid nodes."""
  619. # Add the structural edge features of this graph. Note we don't need
  620. # to add the structural node features, because these are already part of
  621. # the latent state, via the original Grid2Mesh gnn, however, we need
  622. # the edge ones, because it is the first time we are seeing this particular
  623. # set of edges.
  624. batch_size = updated_latent_mesh_nodes.shape[1]
  625. mesh2grid_graph = self._mesh2grid_graph_structure
  626. assert mesh2grid_graph is not None
  627. mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"]
  628. grid_nodes = mesh2grid_graph.nodes["grid_nodes"]
  629. new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes)
  630. new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes)
  631. mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid")
  632. edges = mesh2grid_graph.edges[mesh2grid_key]
  633. new_edges = edges._replace(
  634. features=_add_batch_second_axis(
  635. edges.features.astype(latent_grid_nodes.dtype), batch_size))
  636. input_graph = mesh2grid_graph._replace(
  637. edges={mesh2grid_key: new_edges},
  638. nodes={
  639. "mesh_nodes": new_mesh_nodes,
  640. "grid_nodes": new_grid_nodes
  641. })
  642. # Run the GNN.
  643. output_graph = self._mesh2grid_gnn(input_graph)
  644. output_grid_nodes = output_graph.nodes["grid_nodes"].features
  645. return output_grid_nodes
  646. def _inputs_to_grid_node_features(
  647. self,
  648. inputs: xarray.Dataset,
  649. forcings: xarray.Dataset,
  650. ) -> chex.Array:
  651. """xarrays -> [num_grid_nodes, batch, num_channels]."""
  652. # xarray `Dataset` (batch, time, lat, lon, level, multiple vars)
  653. # to xarray `DataArray` (batch, lat, lon, channels)
  654. stacked_inputs = model_utils.dataset_to_stacked(inputs)
  655. stacked_forcings = model_utils.dataset_to_stacked(forcings)
  656. stacked_inputs = xarray.concat(
  657. [stacked_inputs, stacked_forcings], dim="channels")
  658. # xarray `DataArray` (batch, lat, lon, channels)
  659. # to single numpy array with shape [lat_lon_node, batch, channels]
  660. grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(
  661. stacked_inputs)
  662. return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
  663. (-1,) + grid_xarray_lat_lon_leading.data.shape[2:])
  664. def _grid_node_outputs_to_prediction(
  665. self,
  666. grid_node_outputs: chex.Array,
  667. targets_template: xarray.Dataset,
  668. ) -> xarray.Dataset:
  669. """[num_grid_nodes, batch, num_outputs] -> xarray."""
  670. # numpy array with shape [lat_lon_node, batch, channels]
  671. # to xarray `DataArray` (batch, lat, lon, channels)
  672. assert self._grid_lat is not None and self._grid_lon is not None
  673. grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0])
  674. grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
  675. grid_shape + grid_node_outputs.shape[1:])
  676. dims = ("lat", "lon", "batch", "channels")
  677. grid_xarray_lat_lon_leading = xarray_jax.DataArray(
  678. data=grid_outputs_lat_lon_leading,
  679. dims=dims)
  680. grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading)
  681. # xarray `DataArray` (batch, lat, lon, channels)
  682. # to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars)
  683. return model_utils.stacked_to_dataset(
  684. grid_xarray.variable, targets_template)
  685. def _add_batch_second_axis(data, batch_size):
  686. # data [leading_dim, trailing_dim]
  687. assert data.ndim == 2
  688. ones = jnp.ones([batch_size, 1], dtype=data.dtype)
  689. return data[:, None] * ones # [leading_dim, batch, trailing_dim]
  690. def _get_max_edge_distance(mesh):
  691. senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
  692. edge_distances = np.linalg.norm(
  693. mesh.vertices[senders] - mesh.vertices[receivers], axis=-1)
  694. return edge_distances.max()