Exhaustively test Tasks cross distribution strategy.
PiperOrigin-RevId: 495074565
dzelle authored and tensorflower-gardener committed Dec 13, 2022
1 parent 2cb91a8 commit 0f78af8
Showing 3 changed files with 88 additions and 29 deletions.
tensorflow_gnn/runner/BUILD: 6 changes (4 additions, 2 deletions)
@@ -52,9 +52,9 @@ pytype_strict_library(

distribute_py_test(
name = "distribute_test",
size = "medium",
size = "large",
srcs = ["distribute_test.py"],
shard_count = 4,
shard_count = 8,
tags = [
"no_oss", # TODO(b/238827505)
"nomultivm", # TODO(b/170502145)
@@ -67,6 +67,8 @@ distribute_py_test(
"//tensorflow_gnn",
"//tensorflow_gnn/models/vanilla_mpnn",
"//tensorflow_gnn/runner/tasks:classification",
"//tensorflow_gnn/runner/tasks:dgi",
"//tensorflow_gnn/runner/tasks:regression",
"//tensorflow_gnn/runner/trainers:keras_fit",
"//tensorflow_gnn/runner/utils:model_templates",
"//tensorflow_gnn/runner/utils:padding",
tensorflow_gnn/runner/__init__.py: 5 changes (5 additions, 0 deletions)
@@ -68,6 +68,11 @@
ParameterServerStrategy = strategies.ParameterServerStrategy
TPUStrategy = strategies.TPUStrategy

# NOTE: Tasks cross TensorFlow distribute strategies are tested end to end in
# `distribute_test.py`. If adding a new Task, please also add it to the test
# combinations found there. (See `_all_task_and_processors_combinations`
# in `distribute_test.py`.)
#
# Tasks (Unsupervised)
DeepGraphInfomax = dgi.DeepGraphInfomax
# Tasks (Classification)
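
The NOTE above boils down to adding one more Task-to-processor entry to `_all_task_and_processors_combinations` in `distribute_test.py` (diffed below). A minimal sketch of such an entry, reusing the multiclass classification Task and label extractor from the test; `NUM_CLASSES` here stands in for `len(_LABELS)`:

# Sketch only: one Task -> processor entry of the kind the NOTE asks new
# Tasks to add to `_all_task_and_processors_combinations`.
from tensorflow_gnn.runner.tasks import classification

NUM_CLASSES = 32  # stands in for len(_LABELS) in distribute_test.py


def extract_multiclass_labels(gt):
  # A GraphTensorProcessorFn: split a GraphTensor into (inputs, labels).
  return gt, gt.context["label"]


# Every Task registered like this is exercised under every distribution
# strategy by the cross-product test in distribute_test.py.
task_and_processor_entry = {
    classification.RootNodeMulticlassClassification(
        node_set_name="node",
        num_classes=NUM_CLASSES): extract_multiclass_labels,
}
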
tensorflow_gnn/runner/distribute_test.py: 106 changes (79 additions, 27 deletions)
@@ -23,16 +23,17 @@
import tensorflow.__internal__.test as tftest
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models import vanilla_mpnn
from tensorflow_gnn.runner import interfaces
from tensorflow_gnn.runner import orchestration
from tensorflow_gnn.runner.tasks import classification
from tensorflow_gnn.runner.tasks import dgi
from tensorflow_gnn.runner.tasks import regression
from tensorflow_gnn.runner.trainers import keras_fit
from tensorflow_gnn.runner.utils import model_templates
from tensorflow_gnn.runner.utils import padding

_LABELS = tuple(range(32))

_SAMPLE_DICT = immutabledict({(tfgnn.CONTEXT, None, "label"): _LABELS})

_SCHEMA = """
context {
features {
Expand Down Expand Up @@ -63,6 +64,8 @@
}
"""

TaskAndProcessor = tuple[interfaces.Task, interfaces.GraphTensorProcessorFn]


def _all_eager_strategy_combinations():
strategies = [
@@ -88,7 +91,64 @@ def _all_eager_strategy_combinations():
return tftest.combinations.combine(distribution=strategies)


class DatasetProvider(orchestration.DatasetProvider):
def _all_task_and_processors_combinations():

def identity(gt):
return gt

def extract_binary_labels(gt):
return gt, gt.context["label"] % 2

def extract_multiclass_labels(gt):
return gt, gt.context["label"]

def extract_regression_labels(gt):
return gt, tf.ones_like(gt.context["label"], dtype=tf.float32)

task_and_processor = {
# Root node classification
classification.RootNodeBinaryClassification(node_set_name="node"):
extract_binary_labels,
classification.RootNodeMulticlassClassification(
node_set_name="node",
num_classes=len(_LABELS)): extract_multiclass_labels,
# Graph classification
classification.GraphBinaryClassification(node_set_name="node"):
extract_binary_labels,
classification.GraphMulticlassClassification(
node_set_name="node",
num_classes=len(_LABELS)): extract_multiclass_labels,
# Root node regression
regression.RootNodeMeanAbsoluteError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanAbsolutePercentageError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanSquaredError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanSquaredLogarithmicError(node_set_name="node"):
extract_regression_labels,
regression.RootNodeMeanSquaredLogScaledError(node_set_name="node"):
extract_regression_labels,
# Graph regression
regression.GraphMeanAbsoluteError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanAbsolutePercentageError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanSquaredError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanSquaredLogarithmicError(node_set_name="node"):
extract_regression_labels,
regression.GraphMeanSquaredLogScaledError(node_set_name="node"):
extract_regression_labels,
# Unsupervised
dgi.DeepGraphInfomax(node_set_name="node"):
identity,
}
items = list(task_and_processor.items())
return tftest.combinations.combine(task_and_processor=items)


class DatasetProvider(interfaces.DatasetProvider):

def __init__(self, examples: Sequence[bytes]):
self._examples = list(examples)
@@ -99,22 +159,24 @@ def get_dataset(self, _: tf.distribute.InputContext) -> tf.data.Dataset:

class OrchestrationTests(tf.test.TestCase, parameterized.TestCase):

@tfdistribute.combinations.generate(_all_eager_strategy_combinations())
def test_run(self, distribution: tf.distribute.Strategy):
@tfdistribute.combinations.generate(
tftest.combinations.times(
_all_eager_strategy_combinations(),
_all_task_and_processors_combinations()
)
)
def test_run(
self,
distribution: tf.distribute.Strategy,
task_and_processor: TaskAndProcessor):
schema = tfgnn.parse_schema(_SCHEMA)
gtspec = tfgnn.create_graph_spec_from_schema_pb(schema)
gt = tfgnn.write_example(tfgnn.random_graph_tensor(
gtspec,
sample_dict=_SAMPLE_DICT))
ds_provider = DatasetProvider((gt.SerializeToString(),) * 4)

def extract_labels(gt):
return gt, gt.context["label"]

def node_sets_fn(node_set, node_set_name):
del node_set_name
return node_set["features"]

node_sets_fn = lambda node_set, node_set_name: node_set["features"]
model_fn = model_templates.ModelFromInitAndUpdates(
init=tfgnn.keras.layers.MapFeatures(node_sets_fn=node_sets_fn),
updates=[vanilla_mpnn.VanillaMPNNGraphUpdate(
@@ -124,10 +186,6 @@ def node_sets_fn(node_set, node_set_name):
l2_regularization=5e-4,
dropout_rate=0.1)])

task = classification.RootNodeMulticlassClassification(
node_set_name="node",
num_classes=len(_LABELS))

model_dir = self.create_tempdir()

trainer = keras_fit.KerasTrainer(
@@ -148,6 +206,8 @@ def node_sets_fn(node_set, node_set_name):
train_padding = None
valid_padding = None

task, processor = task_and_processor

orchestration.run(
train_ds_provider=ds_provider,
train_padding=train_padding,
@@ -157,24 +217,16 @@ def node_sets_fn(node_set, node_set_name):
trainer=trainer,
task=task,
gtspec=gtspec,
drop_remainder=False,
global_batch_size=4,
feature_processors=[extract_labels],
global_batch_size=2,
feature_processors=(processor,),
valid_ds_provider=ds_provider,
valid_padding=valid_padding)

dataset = ds_provider.get_dataset(tf.distribute.InputContext())
kwargs = {"examples": next(iter(dataset.batch(2)))}

saved_model = tf.saved_model.load(os.path.join(model_dir, "export"))
output = saved_model.signatures["serving_default"](**kwargs)
actual = next(iter(output.values()))

# The model has one output
self.assertLen(output, 1)

# The expected shape is (batch size, num classes) or (2, 10)
self.assertShapeEqual(actual, tf.random.uniform((2, len(_LABELS))))
saved_model.signatures["serving_default"](**kwargs)


if __name__ == "__main__":
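
The heavy lifting in the test above is done by TensorFlow's internal test-combination helpers: `tftest.combinations.times` takes the Cartesian product of two `combine(...)` results, so every (Task, processor) pair from `_all_task_and_processors_combinations` runs under every strategy from `_all_eager_strategy_combinations`. A self-contained sketch of that mechanism (illustration only, not part of this commit; it assumes `default_strategy` is exposed by the internal combinations module and uses placeholder task values):

# Illustrative sketch of the strategy x task cross product, not commit code.
from absl.testing import parameterized
import tensorflow as tf
import tensorflow.__internal__.distribute as tfdistribute
import tensorflow.__internal__.test as tftest

# Assumption: the named default strategy is available here, like the strategies
# listed in _all_eager_strategy_combinations().
_STRATEGIES = tftest.combinations.combine(
    distribution=[tfdistribute.combinations.default_strategy])
# Placeholder stand-ins for real (Task, processor) pairs.
_TASKS = tftest.combinations.combine(
    task_and_processor=[("fake_task", "fake_processor")])


class CrossProductTest(tf.test.TestCase, parameterized.TestCase):

  # `times` yields one generated test case per (strategy, task_and_processor)
  # pair, which is how distribute_test.py covers Tasks exhaustively.
  @tfdistribute.combinations.generate(
      tftest.combinations.times(_STRATEGIES, _TASKS))
  def test_pair(self, distribution, task_and_processor):
    self.assertIsInstance(distribution, tf.distribute.Strategy)
    self.assertLen(task_and_processor, 2)


if __name__ == "__main__":
  tfdistribute.multi_process_runner.test_main()
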
