typed_graph.py 3.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright 2023 DeepMind Technologies Limited.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS-IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Data-structure for storing graphs with typed edges and nodes."""
  15. from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar
  16. ArrayLike = Union[Any] # np.ndarray, jnp.ndarray, tf.tensor
  17. ArrayLikeTree = Union[Any, ArrayLike] # Nest of ArrayLike
  18. _T = TypeVar('_T')
  19. # All tensors have a "flat_batch_axis", which is similar to the leading
  20. # axes of graph_tuples:
  21. # * In the case of nodes this is simply a shared node and flat batch axis, with
  22. # size corresponding to the total number of nodes in the flattened batch.
  23. # * In the case of edges this is simply a shared edge and flat batch axis, with
  24. # size corresponding to the total number of edges in the flattened batch.
  25. # * In the case of globals this is simply the number of graphs in the flattened
  26. # batch.
  27. # All shapes may also have any additional leading shape "batch_shape".
  28. # Options for building batches are:
# * Use a provided "flatten" method that takes a leading `batch_shape` and
#   flattens it into the flat_batch_axis (this will be useful when using
#   `tf.Dataset`, which supports batching into RaggedTensors with a leading
#   batch shape even if graphs have different numbers of nodes and edges), so
#   the RaggedBatches can then be converted into something without ragged
#   dimensions that jax can use.
  35. # * Directly build a "flat batch" using a provided function for batching a list
  36. # of graphs (how it is done in `jraph`).
  37. class NodeSet(NamedTuple):
  38. """Represents a set of nodes."""
  39. n_node: ArrayLike # [num_flat_graphs]
  40. features: ArrayLikeTree # Prev. `nodes`: [num_flat_nodes] + feature_shape
  41. class EdgesIndices(NamedTuple):
  42. """Represents indices to nodes adjacent to the edges."""
  43. senders: ArrayLike # [num_flat_edges]
  44. receivers: ArrayLike # [num_flat_edges]
  45. class EdgeSet(NamedTuple):
  46. """Represents a set of edges."""
  47. n_edge: ArrayLike # [num_flat_graphs]
  48. indices: EdgesIndices
  49. features: ArrayLikeTree # Prev. `edges`: [num_flat_edges] + feature_shape
  50. class Context(NamedTuple):
  51. # `n_graph` always contains ones but it is useful to query the leading shape
  52. # in case of graphs without any nodes or edges sets.
  53. n_graph: ArrayLike # [num_flat_graphs]
  54. features: ArrayLikeTree # Prev. `globals`: [num_flat_graphs] + feature_shape
  55. class EdgeSetKey(NamedTuple):
  56. name: str # Name of the EdgeSet.
  57. # Sender node set name and receiver node set name connected by the edge set.
  58. node_sets: Tuple[str, str]
  59. class TypedGraph(NamedTuple):
  60. """A graph with typed nodes and edges.
  61. A typed graph is made of a context, multiple sets of nodes and multiple
  62. sets of edges connecting those nodes (as indicated by the EdgeSetKey).
  63. """
  64. context: Context
  65. nodes: Mapping[str, NodeSet]
  66. edges: Mapping[EdgeSetKey, EdgeSet]
  67. def edge_key_by_name(self, name: str) -> EdgeSetKey:
  68. found_key = [k for k in self.edges.keys() if k.name == name]
  69. if len(found_key) != 1:
  70. raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
  71. name, ', '.join(x.name for x in self.edges.keys())))
  72. return found_key[0]
  73. def edge_by_name(self, name: str) -> EdgeSet:
  74. return self.edges[self.edge_key_by_name(name)]