diff --git a/tensorflow_gnn/models/gcn/gcn_conv.py b/tensorflow_gnn/models/gcn/gcn_conv.py index 3a7db542..527f2c0f 100644 --- a/tensorflow_gnn/models/gcn/gcn_conv.py +++ b/tensorflow_gnn/models/gcn/gcn_conv.py @@ -47,10 +47,14 @@ class GCNConv(tf.keras.layers.Layer): paper doesn't use a bias, but this defaults to True to be consistent with Keras and other implementations. add_self_loops: Whether to compute the result as if a loop from each node - to itself had been added to the edge set. + to itself had been added to the edge set. The self-loop edges are added + with an edge weight of one. normalize: Whether to normalize the node features by in-degree. kernel_initializer: initializer of type tf.keras.initializers . node_feature: Name of the node feature to transform. + edge_weight_feature_name: Can be set to the name of a feature on the edge + set that supplies a scalar weight for each edge. The GCN computation uses + it as the edge's entry in the adjacency matrix, instead of the default 1. **kwargs: additional arguments for the Layer. Call arguments: @@ -99,6 +103,7 @@ def __init__(self, kernel_initializer: bool = None, node_feature: Optional[str] = tfgnn.HIDDEN_STATE, kernel_regularizer: Optional[_RegularizerType] = None, + edge_weight_feature_name: Optional[tfgnn.FieldName] = None, **kwargs): super().__init__(**kwargs) @@ -113,6 +118,7 @@ def __init__(self, self._node_feature = node_feature self._receiver = receiver_tag self._sender = tfgnn.reverse_tag(receiver_tag) + self._edge_weight_feature_name = edge_weight_feature_name def get_config(self): filter_config = self._filter.get_config() @@ -126,6 +132,7 @@ def get_config(self): use_bias=filter_config['use_bias'], kernel_initializer=filter_config['kernel_initializer'], kernel_regularizer=filter_config['kernel_regularizer'], + edge_weight_feature_name=self._edge_weight_feature_name, **super().get_config()) def call( @@ -148,13 +155,29 @@ def call( if self._normalize: edge_set = graph.edge_sets[edge_set_name] - edge_ones = tf.ones([edge_set.total_size, 1]) - in_degree = tf.squeeze(tfgnn.pool_edges_to_node( - graph, - edge_set_name, - self._receiver, - 'sum', - feature_value=edge_ones), -1) + if self._edge_weight_feature_name is not None: + try: + edge_weights = graph.edge_sets[edge_set_name][ + self._edge_weight_feature_name] + except KeyError as e: + raise ValueError(f'{self._edge_weight_feature_name} is not given ' + f'for edge set {edge_set_name} ') from e + if edge_weights.shape.rank != 1: + # GraphTensor guarantees it is not None. + raise ValueError('Expecting vector for edge weights. Received rank ' + f'{tf.rank(edge_weights)}.') + edge_weights = tf.expand_dims( + edge_weights, axis=1) # Align with state feature. + else: + edge_weights = tf.ones([edge_set.total_size, 1]) + + in_degree = tf.squeeze( + tfgnn.pool_edges_to_node( + graph, + edge_set_name, + self._receiver, + 'sum', + feature_value=edge_weights), -1) # Degree matrix is the sum of rows of adjacency # Adding self-loops adds an identity matrix to the adjacency # This adds 1 to each diagonal element of the degree matrix @@ -176,6 +199,8 @@ def call( self._sender, feature_value=normalized_values, ) + if self._edge_weight_feature_name is not None: + source_bcast = source_bcast * edge_weights pooled = tfgnn.pool_edges_to_node( graph, edge_set_name, self._receiver, 'sum', feature_value=source_bcast) diff --git a/tensorflow_gnn/models/gcn/gcn_conv_test.py b/tensorflow_gnn/models/gcn/gcn_conv_test.py index 732565b2..2f95d4cc 100644 --- a/tensorflow_gnn/models/gcn/gcn_conv_test.py +++ b/tensorflow_gnn/models/gcn/gcn_conv_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Tests for gcn_conv.""" import enum +import math import os from absl.testing import parameterized @@ -193,9 +194,179 @@ def test_gcnconv_heterogeneous(self): lambda: conv(graph, edge_set_name=tfgnn.EDGES)) @parameterized.named_parameters( - ('', ReloadModel.SKIP), - ('Restored', ReloadModel.SAVED_MODEL), - ('RestoredKeras', ReloadModel.KERAS)) + dict( + testcase_name='noSelfLoops_noBias', + use_bias=False, + add_self_loops=False, + ),) + def test_gcnconv_with_edge_weights_ones(self, use_bias, add_self_loops): + """Tests that gcn_conv returns the correct values with edge weights.""" + graph = tfgnn.GraphTensor.from_pieces( + node_sets={ + tfgnn.NODES: + tfgnn.NodeSet.from_fields( + sizes=[2], + features={ + tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]]) + }, + ) + }, + edge_sets={ + tfgnn.EDGES: + tfgnn.EdgeSet.from_fields( + sizes=[2], + features={ + 'weights': tf.constant([1.0, 1.0], dtype=tf.float32) + }, + adjacency=tfgnn.Adjacency.from_indices( + source=(tfgnn.NODES, tf.constant([0, 1], + dtype=tf.int64)), + target=(tfgnn.NODES, tf.constant([1, 0], + dtype=tf.int64)), + )) + }) + conv_with_edge_weights = gcn_conv.GCNConv( + units=2, + use_bias=use_bias, + add_self_loops=add_self_loops, + kernel_initializer=tf.keras.initializers.Constant(tf.eye(2)), + edge_weight_feature_name='weights') + conv_without_edge_weights = gcn_conv.GCNConv( + units=2, + use_bias=use_bias, + add_self_loops=add_self_loops, + kernel_initializer=tf.keras.initializers.Constant(tf.eye(2))) + self.assertAllClose( + conv_with_edge_weights(graph, edge_set_name=tfgnn.EDGES), + conv_without_edge_weights(graph, edge_set_name=tfgnn.EDGES), + rtol=1e-06, + atol=1e-06) + + @parameterized.named_parameters( + dict( + testcase_name='noSelfLoops_noBias', + use_bias=False, + add_self_loops=False, + expected_result=tf.constant([[0., 4. / (2. * 3.)], + [9. / (2. * 3.), 0.]])), + dict( + testcase_name='selfLoops_noBias', + use_bias=False, + add_self_loops=True, + expected_result=tf.constant( + [[ + 1. / (math.sqrt(5.) * math.sqrt(5.)), + 4. / (math.sqrt(10.) * math.sqrt(5.)) + ], + [ + 9. / (math.sqrt(10.) * math.sqrt(5.)), + 1. / (math.sqrt(10.) * math.sqrt(10.)) + ]])), + ) + def test_gcnconv_with_edge_weights(self, use_bias, add_self_loops, + expected_result): + """Tests that gcn_conv returns the correct values with edge weights.""" + graph = tfgnn.GraphTensor.from_pieces( + node_sets={ + tfgnn.NODES: + tfgnn.NodeSet.from_fields( + sizes=[2], + features={ + tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]]) + }, + ) + }, + edge_sets={ + tfgnn.EDGES: + tfgnn.EdgeSet.from_fields( + sizes=[2], + features={ + 'weights': tf.constant([9.0, 4.0], dtype=tf.float32) + }, + adjacency=tfgnn.Adjacency.from_indices( + source=(tfgnn.NODES, tf.constant([0, 1], + dtype=tf.int64)), + target=(tfgnn.NODES, tf.constant([1, 0], + dtype=tf.int64)), + )) + }) + conv = gcn_conv.GCNConv( + units=2, + use_bias=use_bias, + add_self_loops=add_self_loops, + kernel_initializer=tf.keras.initializers.Constant(tf.eye(2)), + edge_weight_feature_name='weights') + + self.assertAllClose( + expected_result, + conv(graph, edge_set_name=tfgnn.EDGES), + rtol=1e-06, + atol=1e-06) + + def test_gcnconv_with_edge_weights_missing(self): + """Tests that missing given edge weights feature name in the graph tensor throws an error.""" + graph = tfgnn.GraphTensor.from_pieces( + node_sets={ + tfgnn.NODES: + tfgnn.NodeSet.from_fields( + sizes=[2], + features={ + tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]]) + }, + ) + }, + edge_sets={ + tfgnn.EDGES: + tfgnn.EdgeSet.from_fields( + sizes=[2], + adjacency=tfgnn.Adjacency.from_indices( + source=(tfgnn.NODES, tf.constant([0, 1], + dtype=tf.int64)), + target=(tfgnn.NODES, tf.constant([1, 0], + dtype=tf.int64)), + )) + }) + + conv = gcn_conv.GCNConv(units=2, edge_weight_feature_name='weights') + self.assertRaisesRegex(ValueError, + 'weights is not given for edge set edges ', + lambda: conv(graph, edge_set_name=tfgnn.EDGES)) + + def test_gcnconv_with_bad_shaped_edge_weights(self): + """Tests that given edge weights feature with bad shape throws an error.""" + graph = tfgnn.GraphTensor.from_pieces( + node_sets={ + tfgnn.NODES: + tfgnn.NodeSet.from_fields( + sizes=[2], + features={ + tfgnn.HIDDEN_STATE: tf.constant([[1., 0.], [0., 1.]]) + }, + ) + }, + edge_sets={ + tfgnn.EDGES: + tfgnn.EdgeSet.from_fields( + sizes=[2], + features={ + 'weights': tf.constant([[9.0], [4.0]], dtype=tf.float32) + }, + adjacency=tfgnn.Adjacency.from_indices( + source=(tfgnn.NODES, tf.constant([0, 1], + dtype=tf.int64)), + target=(tfgnn.NODES, tf.constant([1, 0], + dtype=tf.int64)), + )) + }) + + conv = gcn_conv.GCNConv(units=2, edge_weight_feature_name='weights') + self.assertRaisesRegex( + ValueError, 'Expecting vector for edge weights. Received rank 2.', + lambda: conv(graph, edge_set_name=tfgnn.EDGES)) + + @parameterized.named_parameters(('', ReloadModel.SKIP), + ('Restored', ReloadModel.SAVED_MODEL), + ('RestoredKeras', ReloadModel.KERAS)) def test_full_model(self, reload_model): """Tests GCNGraphUpdate in a full Model (incl. saving) with edge input.""" gt_input = tfgnn.GraphTensor.from_pieces( diff --git a/tensorflow_gnn/runner/BUILD b/tensorflow_gnn/runner/BUILD index 2e9bc8dc..36d41100 100644 --- a/tensorflow_gnn/runner/BUILD +++ b/tensorflow_gnn/runner/BUILD @@ -47,6 +47,7 @@ pytype_strict_library( "//tensorflow_gnn", "//tensorflow_gnn/runner/utils:model", "//tensorflow_gnn/runner/utils:model_export", + "//tensorflow_gnn/runner/utils:parsing", ], ) diff --git a/tensorflow_gnn/runner/orchestration.py b/tensorflow_gnn/runner/orchestration.py index 4ab8d481..eba14077 100644 --- a/tensorflow_gnn/runner/orchestration.py +++ b/tensorflow_gnn/runner/orchestration.py @@ -18,13 +18,14 @@ import functools import itertools import os -from typing import Any, Callable, Optional, Sequence, Tuple, Union +from typing import Callable, Optional, Sequence, Tuple, Union import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.runner import interfaces from tensorflow_gnn.runner.utils import model as model_utils from tensorflow_gnn.runner.utils import model_export +from tensorflow_gnn.runner.utils import parsing as parsing_utils GraphTensor = tfgnn.GraphTensor GraphTensorAndField = Tuple[GraphTensor, tfgnn.Field] @@ -172,33 +173,6 @@ def _make_preprocessing_model( num_parallel_calls=tf.data.experimental.AUTOTUNE) -def _maybe_parse(gtspec: GraphTensorSpec) -> Callable[[Any], GraphTensor]: - """Returns a callable to parse (or assert the spec of) dataset elements.""" - parse_example = tfgnn.keras.layers.ParseExample(gtspec) - # Relax the spec for potential comparisons. - relaxed = gtspec.relax(num_components=True, num_nodes=True, num_edges=True) - def fn(element): - # Use `getattr` to account for types without a `dtype` (e.g. `GraphTensor`). - if getattr(element, "dtype", None) == tf.string: - gt = parse_example(element) - elif not tfgnn.is_graph_tensor(element): - raise ValueError(f"Expected `GraphTensor` (got {element})") - else: - # Access protected member `_unbatch` to avoid any potential - # `merge_batch_to_components` work. - actual = element.spec._unbatch().relax( # pylint: disable=protected-access - num_components=True, - num_nodes=True, - num_edges=True) - if actual != relaxed: - raise ValueError( - f"Expected a `GraphTensor` of spec {relaxed} (got {actual})") - else: - gt = element - return gt - return fn - - def _per_replica_batch_size(global_batch_size: int, num_replicas: int) -> int: if global_batch_size % num_replicas != 0: raise ValueError(f"The `global_batch_size` {global_batch_size} is not " @@ -287,7 +261,7 @@ def apply_fn(ds, *, filter_fn: Optional[Callable[..., bool]] = None, size_constraints: Optional[SizeConstraints] = None): - ds = _map_over_dataset(ds, _maybe_parse(gtspec)) + ds = parsing_utils.maybe_parse_graph_tensor_dataset(ds, gtspec) if filter_fn is not None: ds = ds.filter(filter_fn) if size_constraints is not None: diff --git a/tensorflow_gnn/runner/orchestration_test.py b/tensorflow_gnn/runner/orchestration_test.py index 5cb40c60..df516f99 100644 --- a/tensorflow_gnn/runner/orchestration_test.py +++ b/tensorflow_gnn/runner/orchestration_test.py @@ -181,60 +181,6 @@ def node_sets_fn(node_set, node_set_name): # (len(examples), len(_LABELS)) self.assertAllEqual(actual.shape, (examples.shape[0], len(_LABELS))) - @parameterized.named_parameters([ - dict( - testcase_name="NoGraphTensor", - ds_provider=DatasetProvider(8191.), - expected_error=r"Expected `GraphTensor` \(got .*\)", - ), - dict( - testcase_name="MismatchedGraphTensorSpec", - ds_provider=DatasetProvider(random_graph_tensor(_SCHEMA_B)), - expected_error=r"Expected a `GraphTensor` of spec .*\ \(got .*\)", - ), - ]) - def test_run_fails( - self, - ds_provider: orchestration.DatasetProvider, - expected_error: str): - def extract_labels(gt): - return gt, gt.context["label"] - - def node_sets_fn(node_set, node_set_name): - del node_set_name - return node_set["features"] - - def model_fn(gtspec): - inputs = tf.keras.Input(type_spec=gtspec) - graph = tfgnn.keras.layers.MapFeatures(node_sets_fn=node_sets_fn)(inputs) - return tf.keras.Model(inputs, graph) - - task = classification.RootNodeMulticlassClassification( - node_set_name="node", - num_classes=8191) - - model_dir = self.create_tempdir() - - trainer = keras_fit.KerasTrainer( - strategy=tf.distribute.get_strategy(), - model_dir=model_dir, - steps_per_epoch=1, - validation_steps=1, - restore_best_weights=False) - - with self.assertRaisesRegex(ValueError, expected_error): - orchestration.run( - train_ds_provider=ds_provider, - model_fn=model_fn, - optimizer_fn=tf.keras.optimizers.Adam, - epochs=1, - trainer=trainer, - task=task, - gtspec=graph_spec(), - global_batch_size=2, - feature_processors=(extract_labels,), - valid_ds_provider=ds_provider) - if __name__ == "__main__": tf.test.main() diff --git a/tensorflow_gnn/runner/utils/BUILD b/tensorflow_gnn/runner/utils/BUILD index 298061a6..dfba4359 100644 --- a/tensorflow_gnn/runner/utils/BUILD +++ b/tensorflow_gnn/runner/utils/BUILD @@ -66,12 +66,24 @@ pytype_strict_library( srcs_version = "PY3", visibility = ["//tensorflow_gnn/runner:__pkg__"], deps = [ + ":parsing", "//:expect_tensorflow_installed", "//tensorflow_gnn", "//tensorflow_gnn/runner:interfaces", ], ) +pytype_strict_library( + name = "parsing", + srcs = ["parsing.py"], + srcs_version = "PY3", + visibility = ["//tensorflow_gnn/runner:__pkg__"], + deps = [ + "//:expect_tensorflow_installed", + "//tensorflow_gnn", + ], +) + pytype_strict_library( name = "strategies", srcs = ["strategies.py"], @@ -115,12 +127,13 @@ py_strict_test( ], ) -pytype_strict_library( - name = "padding_test", - srcs = ["padding_test.py"], +py_strict_test( + name = "parsing_test", + srcs = ["parsing_test.py"], srcs_version = "PY3", deps = [ - ":padding", + ":parsing", + "//:expect_absl_installed", "//:expect_tensorflow_installed", "//tensorflow_gnn", ], diff --git a/tensorflow_gnn/runner/utils/padding.py b/tensorflow_gnn/runner/utils/padding.py index 629a33af..5d9046bf 100644 --- a/tensorflow_gnn/runner/utils/padding.py +++ b/tensorflow_gnn/runner/utils/padding.py @@ -20,19 +20,11 @@ import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.runner import interfaces +from tensorflow_gnn.runner.utils import parsing as parsing_utils SizeConstraints = tfgnn.SizeConstraints -def _parse_dataset( - gtspec: tfgnn.GraphTensorSpec, - dataset: tf.data.Dataset) -> tf.data.Dataset: - return dataset.map( - functools.partial(tfgnn.parse_single_example, gtspec), - deterministic=False, - num_parallel_calls=tf.data.experimental.AUTOTUNE) - - def one_node_per_component(gtspec: tfgnn.GraphTensorSpec) -> Mapping[str, int]: return {k: 1 for k in gtspec.node_sets_spec.keys()} @@ -89,11 +81,9 @@ def get_filter_fn(self, total_sizes=size_constraints) def get_size_constraints(self, target_batch_size: int) -> SizeConstraints: - dataset = _parse_dataset( - self._gtspec, - self._dataset_provider.get_dataset(tf.distribute.InputContext())) + dataset = self._dataset_provider.get_dataset(tf.distribute.InputContext()) return tfgnn.learn_fit_or_skip_size_constraints( # pytype: disable=bad-return-type - dataset, + parsing_utils.maybe_parse_graph_tensor_dataset(dataset, self._gtspec), target_batch_size, min_nodes_per_component=self._min_nodes_per_component, sample_size=self._fit_or_skip_sample_sample_size, @@ -111,10 +101,8 @@ def get_filter_fn(self, return lambda *args, **kwargs: True def get_size_constraints(self, target_batch_size: int) -> SizeConstraints: - dataset = _parse_dataset( - self._gtspec, - self._dataset_provider.get_dataset(tf.distribute.InputContext())) + dataset = self._dataset_provider.get_dataset(tf.distribute.InputContext()) return tfgnn.find_tight_size_constraints( - dataset, + parsing_utils.maybe_parse_graph_tensor_dataset(dataset, self._gtspec), min_nodes_per_component=self._min_nodes_per_component, target_batch_size=target_batch_size) diff --git a/tensorflow_gnn/runner/utils/padding_test.py b/tensorflow_gnn/runner/utils/padding_test.py deleted file mode 100644 index dd8f013c..00000000 --- a/tensorflow_gnn/runner/utils/padding_test.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for padding.""" -import tensorflow as tf -import tensorflow_gnn as tfgnn -from tensorflow_gnn.runner.utils import padding - -SCHEMA = """ - node_sets { - key: "node" - value { - features { - key: "features" - value { - dtype: DT_FLOAT - shape { dim { size: 4 } } - } - } - } - } - edge_sets { - key: "edge" - value { - source: "node" - target: "node" - } - }""" - - -class PaddingTest(tf.test.TestCase): - - def _assert_fields_equal(self, a: tfgnn.Fields, b: tfgnn.Fields): - self.assertCountEqual(a.keys(), b.keys()) - for k, v in a.items(): - self.assertAllEqual(v, b[k]) - - def test_parse_dataset(self): - schema = tfgnn.parse_schema(SCHEMA) - gtspec = tfgnn.create_graph_spec_from_schema_pb(schema) - expected = tfgnn.random_graph_tensor(gtspec) - example = tfgnn.write_example(expected) - dataset = tf.data.Dataset.from_tensors([example.SerializeToString()]) - - actual = next(iter(padding._parse_dataset(gtspec, dataset))) - - self.assertCountEqual(actual.node_set.keys(), expected.node_set.keys()) - self.assertCountEqual(actual.edge_set.keys(), expected.edge_set.keys()) - - self._assert_fields_equal( - actual.context.features, - expected.context.features) - - for k, v in actual.node_set.items(): - self._assert_fields_equal(v, expected.node_set[k].features) - - for k, v in actual.edge_set.items(): - self._assert_fields_equal(v, expected.edge_set[k].features) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow_gnn/runner/utils/parsing.py b/tensorflow_gnn/runner/utils/parsing.py new file mode 100644 index 00000000..1215a325 --- /dev/null +++ b/tensorflow_gnn/runner/utils/parsing.py @@ -0,0 +1,64 @@ +# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers for `GraphTensor` parsing.""" +import tensorflow as tf +import tensorflow_gnn as tfgnn + +GraphTensor = tfgnn.GraphTensor +GraphTensorSpec = tfgnn.GraphTensorSpec + + +def maybe_parse_graph_tensor_dataset( + ds: tf.data.Dataset, + gtspec: GraphTensorSpec) -> tf.data.Dataset: + """Parse (or check the compatability of) a dataset with `GraphTensorSpec`. + + * If `ds` contains `tf.string` elements, the dataset is parsed using `gtspec` + and returned. + * If `ds` contains `GraphTensor` elements, the dataset is checked + (by `tfgnn.create_schema_pb_from_graph_spec(...)`) to be compatible with + `gtspec` and returned. + * Otherwise, a `ValueError` is raised. + + Args: + ds: A `tf.data.Dataset` to parse or check. + gtspec: A `GraphTensorSpec` for parsing or checking. + + Returns: + A `tf.data.Dataset` that has been parsed by, or checked for compatibility + with, `gtspec`. + + Raises: + ValueError: If `ds` does contain `tf.string` or `GraphTensor` elements. + """ + if ds.element_spec.is_compatible_with(tf.TensorSpec((), tf.string)): + ds = ds.map( + tfgnn.keras.layers.ParseSingleExample(gtspec), + deterministic=False, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + elif ds.element_spec.is_compatible_with(tf.TensorSpec((None,), tf.string)): + ds = ds.map( + tfgnn.keras.layers.ParseExample(gtspec), + deterministic=False, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + elif not isinstance(ds.element_spec, tfgnn.GraphTensorSpec): + raise ValueError(f"Expected `GraphTensorSpec` (got {ds.element_spec})") + else: + element_spec = ds.element_spec + while element_spec.rank > 0: + element_spec = element_spec._unbatch() # pylint: disable=protected-access + schema = tfgnn.create_schema_pb_from_graph_spec(gtspec) + tfgnn.check_compatible_with_schema_pb(element_spec, schema) + return ds diff --git a/tensorflow_gnn/runner/utils/parsing_test.py b/tensorflow_gnn/runner/utils/parsing_test.py new file mode 100644 index 00000000..7d132907 --- /dev/null +++ b/tensorflow_gnn/runner/utils/parsing_test.py @@ -0,0 +1,157 @@ +# Copyright 2021 The TensorFlow GNN Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for parsing.""" +import functools + +from absl.testing import parameterized +import tensorflow as tf +import tensorflow_gnn as tfgnn +from tensorflow_gnn.runner.utils import parsing as parsing_utils + +SCHEMA = """ + node_sets { + key: "node" + value { + features { + key: "features" + value { + dtype: DT_FLOAT + shape { dim { size: 4 } } + } + } + } + } + edge_sets { + key: "edge" + value { + source: "node" + target: "node" + } + } +""" + +Fields = tfgnn.Fields +GraphTensor = tfgnn.GraphTensor + +ds_from_tensor = tf.data.Dataset.from_tensors + + +def gtspec() -> tfgnn.GraphTensorSpec: + return tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA)) + + +@functools.lru_cache(None) +def random_graph_tensor() -> tfgnn.GraphTensor: + return tfgnn.random_graph_tensor(gtspec()) + + +@functools.lru_cache(None) +def random_serialized_graph_tensor() -> str: + return tfgnn.write_example(random_graph_tensor()).SerializeToString() + + +class ParsingTest(tf.test.TestCase, parameterized.TestCase): + + def _assert_fields_equal(self, a: Fields, b: Fields): + self.assertCountEqual(a.keys(), b.keys()) + for k, v in a.items(): + self.assertAllEqual(v, b[k]) + + def _assert_graph_tensors_equal(self, a: GraphTensor, b: GraphTensor): + self.assertCountEqual(a.node_sets.keys(), b.node_sets.keys()) + self.assertCountEqual(a.edge_sets.keys(), b.edge_sets.keys()) + + self._assert_fields_equal(a.context.features, b.context.features) + + for k, v in a.node_sets.items(): + self._assert_fields_equal(v.features, b.node_sets[k].features) + + for k, v in a.edge_sets.items(): + self._assert_fields_equal(v.features, b.edge_sets[k].features) + + @parameterized.named_parameters([ + dict( + testcase_name="SerializedGraphTensorElement", + ds=ds_from_tensor(random_serialized_graph_tensor()), + spec=random_graph_tensor().spec, + expected=random_graph_tensor(), + ), + dict( + testcase_name="GraphTensorElement", + ds=ds_from_tensor(random_graph_tensor()), + spec=random_graph_tensor().spec, + expected=random_graph_tensor(), + ), + dict( + testcase_name="SerializedGraphTensorElements", + ds=ds_from_tensor(random_serialized_graph_tensor()).repeat().batch(4), + spec=random_graph_tensor().spec, + expected=next( + iter( + ds_from_tensor(random_graph_tensor()).repeat().batch(4) + ) + ), + ), + dict( + testcase_name="GraphTensorElements", + ds=ds_from_tensor(random_graph_tensor()).repeat().batch(4), + spec=random_graph_tensor().spec, + expected=next( + iter( + ds_from_tensor(random_graph_tensor()).repeat().batch(4) + ) + ), + ), + ]) + def test_maybe_parse_graph_tensor_dataset( + self, + ds: tf.data.Dataset, + spec: tfgnn.GraphTensorSpec, + expected: tfgnn.GraphTensor): + ds = parsing_utils.maybe_parse_graph_tensor_dataset(ds, spec) + self._assert_graph_tensors_equal( + expected, + next(iter(ds))) + + @parameterized.named_parameters([ + dict( + testcase_name="FloatElement", + ds=ds_from_tensor(tf.constant(8191.)), + spec=random_graph_tensor().spec, + expected_failure=r"Expected `GraphTensorSpec` \(got .*\)", + ), + dict( + testcase_name="IncompatibleGraphTensorElement", + ds=ds_from_tensor( + tfgnn.homogeneous( + source=tf.constant((0, 3)), + target=tf.constant((1, 2)), + node_set_sizes=tf.constant((4,)) + ), + ), + spec=random_graph_tensor().spec, + expected_failure=r"Graph is not compatible with the graph schema.*", + ), + ]) + def test_maybe_parse_graph_tensor_dataset_fails( + self, + ds: tf.data.Dataset, + spec: tfgnn.GraphTensorSpec, + expected_failure: str): + with self.assertRaisesRegex(ValueError, expected_failure): + _ = parsing_utils.maybe_parse_graph_tensor_dataset(ds, spec) + +if __name__ == "__main__": + tf.test.main()