typed_graph_net.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """A library of typed Graph Neural Networks."""
  15. from typing import Callable, Mapping, Optional, Union
  16. from graphcast import typed_graph
  17. import jax.numpy as jnp
  18. import jax.tree_util as tree
  19. import jraph
  20. # All features will be an ArrayTree.
  21. NodeFeatures = EdgeFeatures = SenderFeatures = ReceiverFeatures = Globals = (
  22. jraph.ArrayTree)
  23. # Signature:
  24. # (node features, outgoing edge features, incoming edge features,
  25. # globals) -> updated node features
  26. GNUpdateNodeFn = Callable[
  27. [NodeFeatures, Mapping[str, SenderFeatures], Mapping[str, ReceiverFeatures],
  28. Globals],
  29. NodeFeatures]
  30. GNUpdateGlobalFn = Callable[
  31. [Mapping[str, NodeFeatures], Mapping[str, EdgeFeatures], Globals],
  32. Globals]
  33. def GraphNetwork( # pylint: disable=invalid-name
  34. update_edge_fn: Mapping[str, jraph.GNUpdateEdgeFn],
  35. update_node_fn: Mapping[str, GNUpdateNodeFn],
  36. update_global_fn: Optional[GNUpdateGlobalFn] = None,
  37. aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
  38. .segment_sum,
  39. aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jraph
  40. .segment_sum,
  41. aggregate_edges_for_globals_fn: jraph.AggregateEdgesToGlobalsFn = jraph
  42. .segment_sum,
  43. ):
  44. """Returns a method that applies a configured GraphNetwork.
  45. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
  46. extended to Typed Graphs with multiple edge sets and node sets and extended to
  47. allow aggregating not only edges received by the nodes, but also edges sent by
  48. the nodes.
  49. Example usage::
  50. gn = GraphNetwork(update_edge_function,
  51. update_node_function, **kwargs)
  52. # Conduct multiple rounds of message passing with the same parameters:
  53. for _ in range(num_message_passing_steps):
  54. graph = gn(graph)
  55. Args:
  56. update_edge_fn: mapping of functions used to update a subset of the edge
  57. types, indexed by edge type name.
  58. update_node_fn: mapping of functions used to update a subset of the node
  59. types, indexed by node type name.
  60. update_global_fn: function used to update the globals or None to deactivate
  61. globals updates.
  62. aggregate_edges_for_nodes_fn: function used to aggregate messages to each
  63. node.
  64. aggregate_nodes_for_globals_fn: function used to aggregate the nodes for the
  65. globals.
  66. aggregate_edges_for_globals_fn: function used to aggregate the edges for the
  67. globals.
  68. Returns:
  69. A method that applies the configured GraphNetwork.
  70. """
  71. def _apply_graph_net(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
  72. """Applies a configured GraphNetwork to a graph.
  73. This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
  74. extended to Typed Graphs with multiple edge sets and node sets and extended
  75. to allow aggregating not only edges received by the nodes, but also edges
  76. sent by the nodes.
  77. Args:
  78. graph: a `TypedGraph` containing the graph.
  79. Returns:
  80. Updated `TypedGraph`.
  81. """
  82. updated_graph = graph
  83. # Edge update.
  84. updated_edges = dict(updated_graph.edges)
  85. for edge_set_name, edge_fn in update_edge_fn.items():
  86. edge_set_key = graph.edge_key_by_name(edge_set_name)
  87. updated_edges[edge_set_key] = _edge_update(
  88. updated_graph, edge_fn, edge_set_key)
  89. updated_graph = updated_graph._replace(edges=updated_edges)
  90. # Node update.
  91. updated_nodes = dict(updated_graph.nodes)
  92. for node_set_key, node_fn in update_node_fn.items():
  93. updated_nodes[node_set_key] = _node_update(
  94. updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
  95. updated_graph = updated_graph._replace(nodes=updated_nodes)
  96. # Global update.
  97. if update_global_fn:
  98. updated_context = _global_update(
  99. updated_graph, update_global_fn,
  100. aggregate_edges_for_globals_fn,
  101. aggregate_nodes_for_globals_fn)
  102. updated_graph = updated_graph._replace(context=updated_context)
  103. return updated_graph
  104. return _apply_graph_net
  105. def _edge_update(graph, edge_fn, edge_set_key): # pylint: disable=invalid-name
  106. """Updates an edge set of a given key."""
  107. sender_nodes = graph.nodes[edge_set_key.node_sets[0]]
  108. receiver_nodes = graph.nodes[edge_set_key.node_sets[1]]
  109. edge_set = graph.edges[edge_set_key]
  110. senders = edge_set.indices.senders # pytype: disable=attribute-error
  111. receivers = edge_set.indices.receivers # pytype: disable=attribute-error
  112. sent_attributes = tree.tree_map(
  113. lambda n: n[senders], sender_nodes.features)
  114. received_attributes = tree.tree_map(
  115. lambda n: n[receivers], receiver_nodes.features)
  116. n_edge = edge_set.n_edge
  117. sum_n_edge = senders.shape[0]
  118. global_features = tree.tree_map(
  119. lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge),
  120. graph.context.features)
  121. new_features = edge_fn(
  122. edge_set.features, sent_attributes, received_attributes,
  123. global_features)
  124. return edge_set._replace(features=new_features)
  125. def _node_update(graph, node_fn, node_set_key, aggregation_fn): # pylint: disable=invalid-name
  126. """Updates an edge set of a given key."""
  127. node_set = graph.nodes[node_set_key]
  128. sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
  129. sent_features = {}
  130. for edge_set_key, edge_set in graph.edges.items():
  131. sender_node_set_key = edge_set_key.node_sets[0]
  132. if sender_node_set_key == node_set_key:
  133. assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
  134. senders = edge_set.indices.senders
  135. sent_features[edge_set_key.name] = tree.tree_map(
  136. lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
  137. received_features = {}
  138. for edge_set_key, edge_set in graph.edges.items():
  139. receiver_node_set_key = edge_set_key.node_sets[1]
  140. if receiver_node_set_key == node_set_key:
  141. assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
  142. receivers = edge_set.indices.receivers
  143. received_features[edge_set_key.name] = tree.tree_map(
  144. lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
  145. n_node = node_set.n_node
  146. global_features = tree.tree_map(
  147. lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node),
  148. graph.context.features)
  149. new_features = node_fn(
  150. node_set.features, sent_features, received_features, global_features)
  151. return node_set._replace(features=new_features)
  152. def _global_update(graph, global_fn, edge_aggregation_fn, node_aggregation_fn): # pylint: disable=invalid-name
  153. """Updates an edge set of a given key."""
  154. n_graph = graph.context.n_graph.shape[0]
  155. graph_idx = jnp.arange(n_graph)
  156. edge_features = {}
  157. for edge_set_key, edge_set in graph.edges.items():
  158. assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
  159. sum_n_edge = edge_set.indices.senders.shape[0]
  160. edge_gr_idx = jnp.repeat(
  161. graph_idx, edge_set.n_edge, axis=0, total_repeat_length=sum_n_edge)
  162. edge_features[edge_set_key.name] = tree.tree_map(
  163. lambda e: edge_aggregation_fn(e, edge_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
  164. edge_set.features)
  165. node_features = {}
  166. for node_set_key, node_set in graph.nodes.items():
  167. sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
  168. node_gr_idx = jnp.repeat(
  169. graph_idx, node_set.n_node, axis=0, total_repeat_length=sum_n_node)
  170. node_features[node_set_key] = tree.tree_map(
  171. lambda n: node_aggregation_fn(n, node_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
  172. node_set.features)
  173. new_features = global_fn(node_features, edge_features, graph.context.features)
  174. return graph.context._replace(features=new_features)
  175. InteractionUpdateNodeFn = Callable[
  176. [jraph.NodeFeatures,
  177. Mapping[str, SenderFeatures],
  178. Mapping[str, ReceiverFeatures]],
  179. jraph.NodeFeatures]
  180. InteractionUpdateNodeFnNoSentEdges = Callable[
  181. [jraph.NodeFeatures,
  182. Mapping[str, ReceiverFeatures]],
  183. jraph.NodeFeatures]
  184. def InteractionNetwork( # pylint: disable=invalid-name
  185. update_edge_fn: Mapping[str, jraph.InteractionUpdateEdgeFn],
  186. update_node_fn: Mapping[str, Union[InteractionUpdateNodeFn,
  187. InteractionUpdateNodeFnNoSentEdges]],
  188. aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
  189. .segment_sum,
  190. include_sent_messages_in_node_update: bool = False):
  191. """Returns a method that applies a configured InteractionNetwork.
  192. An interaction network computes interactions on the edges based on the
  193. previous edges features, and on the features of the nodes sending into those
  194. edges. It then updates the nodes based on the incoming updated edges.
  195. See https://arxiv.org/abs/1612.00222 for more details.
  196. This implementation extends the behavior to `TypedGraphs` adding an option
  197. to include edge features for which a node is a sender in the arguments to
  198. the node update function.
  199. Args:
  200. update_edge_fn: mapping of functions used to update a subset of the edge
  201. types, indexed by edge type name.
  202. update_node_fn: mapping of functions used to update a subset of the node
  203. types, indexed by node type name.
  204. aggregate_edges_for_nodes_fn: function used to aggregate messages to each
  205. node.
  206. include_sent_messages_in_node_update: pass edge features for which a node is
  207. a sender to the node update function.
  208. """
  209. # An InteractionNetwork is a GraphNetwork without globals features,
  210. # so we implement the InteractionNetwork as a configured GraphNetwork.
  211. # An InteractionNetwork edge function does not have global feature inputs,
  212. # so we filter the passed global argument in the GraphNetwork.
  213. wrapped_update_edge_fn = tree.tree_map(
  214. lambda fn: lambda e, s, r, g: fn(e, s, r), update_edge_fn)
  215. # Similarly, we wrap the update_node_fn to ensure only the expected
  216. # arguments are passed to the Interaction net.
  217. if include_sent_messages_in_node_update:
  218. wrapped_update_node_fn = tree.tree_map(
  219. lambda fn: lambda n, s, r, g: fn(n, s, r), update_node_fn)
  220. else:
  221. wrapped_update_node_fn = tree.tree_map(
  222. lambda fn: lambda n, s, r, g: fn(n, r), update_node_fn)
  223. return GraphNetwork(
  224. update_edge_fn=wrapped_update_edge_fn,
  225. update_node_fn=wrapped_update_node_fn,
  226. aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn)
  227. def GraphMapFeatures( # pylint: disable=invalid-name
  228. embed_edge_fn: Optional[Mapping[str, jraph.EmbedEdgeFn]] = None,
  229. embed_node_fn: Optional[Mapping[str, jraph.EmbedNodeFn]] = None,
  230. embed_global_fn: Optional[jraph.EmbedGlobalFn] = None):
  231. """Returns function which embeds the components of a graph independently.
  232. Args:
  233. embed_edge_fn: mapping of functions used to embed each edge type,
  234. indexed by edge type name.
  235. embed_node_fn: mapping of functions used to embed each node type,
  236. indexed by node type name.
  237. embed_global_fn: function used to embed the globals.
  238. """
  239. def _embed(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
  240. updated_edges = dict(graph.edges)
  241. if embed_edge_fn:
  242. for edge_set_name, embed_fn in embed_edge_fn.items():
  243. edge_set_key = graph.edge_key_by_name(edge_set_name)
  244. edge_set = graph.edges[edge_set_key]
  245. updated_edges[edge_set_key] = edge_set._replace(
  246. features=embed_fn(edge_set.features))
  247. updated_nodes = dict(graph.nodes)
  248. if embed_node_fn:
  249. for node_set_key, embed_fn in embed_node_fn.items():
  250. node_set = graph.nodes[node_set_key]
  251. updated_nodes[node_set_key] = node_set._replace(
  252. features=embed_fn(node_set.features))
  253. updated_context = graph.context
  254. if embed_global_fn:
  255. updated_context = updated_context._replace(
  256. features=embed_global_fn(updated_context.features))
  257. return graph._replace(edges=updated_edges, nodes=updated_nodes,
  258. context=updated_context)
  259. return _embed