Hooks for HGT hyperparameter tuning
PiperOrigin-RevId: 514398136
mihirparadkar authored and tensorflower-gardener committed Mar 6, 2023
1 parent e653541 commit 6e0222d
Showing 9 changed files with 355 additions and 5 deletions.
58 changes: 57 additions & 1 deletion tensorflow_gnn/models/hgt/BUILD
@@ -1,4 +1,4 @@
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_contrib_test", "pytype_strict_library")
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "tf_py_test")

licenses(["notice"])
@@ -18,11 +18,67 @@ pytype_strict_library(
":users",
],
deps = [
":config_dict",
":layers",
":softmax",
],
)

# A special BUILD target with a declaration of the model's hyperparameter search space.
# Unlike the model itself, this does not depend on TF/TF-GNN, but does depend on Vizier.
pytype_strict_library(
name = "hparams_vizier",
srcs = ["hparams_vizier.py"],
visibility = [
":__subpackages__",
":users",
],
deps = [
"//:expect_vizier_service_pyvizier_installed",
],
)

exports_files(
srcs = ["hparams_vizier.py"],
visibility = [
":__subpackages__",
],
)

pytype_strict_contrib_test(
name = "hparams_vizier_test",
srcs = ["hparams_vizier_test.py"],
python_version = "PY3",
deps = [
":hparams_vizier",
"//:expect_vizier_service_pyvizier_installed",
"//third_party/py/absl/testing:absltest",
],
)

pytype_strict_library(
name = "config_dict",
srcs = ["config_dict.py"],
deps = [
":layers",
"//third_party/py/ml_collections/config_dict",
"//:expect_tensorflow_installed",
],
)

tf_py_test(
name = "config_dict_test",
srcs = ["config_dict_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":config_dict",
":layers",
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
],
)

pytype_strict_library(
name = "layers",
srcs = ["layers.py"],
5 changes: 4 additions & 1 deletion tensorflow_gnn/models/hgt/__init__.py
@@ -21,14 +21,17 @@
from tensorflow_gnn.models import hgt
```
"""

from tensorflow_gnn.models.hgt import config_dict
from tensorflow_gnn.models.hgt import layers
from tensorflow_gnn.models.hgt import softmax

global_segmented_softmax_edges_per_node = (
softmax.global_segmented_softmax_edges_per_node
)
HGTGraphUpdate = layers.HGTGraphUpdate
graph_update_get_config_dict = config_dict.graph_update_get_config_dict
graph_update_from_config_dict = config_dict.graph_update_from_config_dict

del config_dict
del layers
del softmax
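
A quick orientation note (not part of this commit): the re-exports above make the new ConfigDict helpers reachable from the model package itself, e.g.:

```python
from tensorflow_gnn.models import hgt

cfg = hgt.graph_update_get_config_dict()  # same helper as in config_dict.py below
```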
59 changes: 59 additions & 0 deletions tensorflow_gnn/models/hgt/config_dict.py
@@ -0,0 +1,59 @@
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ConfigDict for Multi-Head Attention."""

from ml_collections import config_dict
import tensorflow as tf
from tensorflow_gnn.models.hgt import layers


def graph_update_get_config_dict() -> config_dict.ConfigDict:
"""Returns ConfigDict for graph_update_from_config_dict() with defaults."""
cfg = config_dict.ConfigDict()
# LINT.IfChange(graph_update_get_config_dict)
cfg.num_heads = config_dict.placeholder(int) # Sets type to Optional[int].
cfg.per_head_channels = config_dict.placeholder(int)
cfg.receiver_tag = config_dict.placeholder(int)
cfg.use_weighted_skip = config_dict.placeholder(bool)
cfg.dropout_rate = config_dict.placeholder(float)
cfg.use_layer_norm = config_dict.placeholder(bool)
cfg.use_bias = config_dict.placeholder(bool)
cfg.activation = config_dict.placeholder(str)
# LINT.ThenChange(./layers.py:HGTGraphUpdate_args)
cfg.lock()
return cfg


def graph_update_from_config_dict(
cfg: config_dict.ConfigDict) -> tf.keras.layers.Layer:
"""Returns a HGTGraphUpdate initialized from `cfg`.
Args:
cfg: A `ConfigDict` with the fields defined by
`graph_update_get_config_dict()`. All fields with non-`None` values are
used as keyword arguments for initializing and returning a
`HGTGraphUpdate` object. For the required arguments of
`HGTGraphUpdate.__init__`, users must set a value in
`cfg` before passing it here.
Returns:
A new `HGTGraphUpdate` object.
Raises:
TypeError: if `cfg` fails to supply a required argument for
`HGTGraphUpdate.__init__`.
"""
kwargs = {k: v for k, v in cfg.items() if v is not None}
return layers.HGTGraphUpdate(**kwargs)
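
For orientation, here is a minimal usage sketch of the two helpers above (not part of this commit). It assumes `num_heads`, `per_head_channels`, and `receiver_tag` are the required `HGTGraphUpdate.__init__` arguments; the values are purely illustrative.

```python
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models.hgt import config_dict as hgt_config_dict

cfg = hgt_config_dict.graph_update_get_config_dict()
cfg.num_heads = 8                # assumed required by HGTGraphUpdate.__init__
cfg.per_head_channels = 64       # assumed required
cfg.receiver_tag = tfgnn.TARGET  # assumed required
cfg.dropout_rate = 0.2           # optional; fields left as None are omitted
layer = hgt_config_dict.graph_update_from_config_dict(cfg)
```

Fields left as `None` are dropped before the constructor call, so the layer's own defaults apply to anything not set in `cfg`.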
103 changes: 103 additions & 0 deletions tensorflow_gnn/models/hgt/config_dict_test.py
@@ -0,0 +1,103 @@
# Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for config_dict."""

import json
from typing import Mapping

import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.models.hgt import config_dict as hgt_config_dict
from tensorflow_gnn.models.hgt import layers


class ConfigDictTest(tf.test.TestCase):

def test_graph_update_defaults(self):
num_heads = 1
per_head_channels = 1
receiver_tag = tfgnn.SOURCE
use_weighted_skip = True
dropout_rate = 0.1
use_layer_norm = True
use_bias = True
activation = "gelu"

cfg = hgt_config_dict.graph_update_get_config_dict()
cfg.num_heads = num_heads
cfg.per_head_channels = per_head_channels
cfg.receiver_tag = receiver_tag
cfg.use_weighted_skip = use_weighted_skip
cfg.dropout_rate = dropout_rate
cfg.use_layer_norm = use_layer_norm
cfg.use_bias = use_bias
cfg.activation = activation

if tf.__version__.startswith("2.9."):
self.skipTest(f"HGTGraphUpdate requires TF 2.10+, got {tf.__version__}")

actual = hgt_config_dict.graph_update_from_config_dict(cfg)
expected = layers.HGTGraphUpdate(
num_heads=num_heads,
per_head_channels=per_head_channels,
receiver_tag=receiver_tag,
use_weighted_skip=use_weighted_skip,
dropout_rate=dropout_rate,
use_layer_norm=use_layer_norm,
use_bias=use_bias,
activation=activation)
self.assertEqual(to_model_config(actual), to_model_config(expected))


# TODO(b/265776928): De-duplicate the multiple copies of this test helper.
def to_model_config(layer: tf.keras.layers.Layer):
"""Returns a parsed model config for `layer`, without `"name"` fields."""
# Need a full model to serialize *recursively*.
model = tf.keras.Sequential([layer])
# Subobjects are only built in the first call.
_ = model(_make_test_graph_loop())
model_config = json.loads(model.to_json())
# The names of layers are uniquified and impede the hparam comparison.
return _remove_names(model_config)


def _remove_names(obj):
"""Returns parsed JSON `obj` without dict entries called "name"."""
if isinstance(obj, Mapping):
return {k: _remove_names(v) for k, v in obj.items() if k != "name"}
elif isinstance(obj, (list, tuple)):
return type(obj)([_remove_names(v) for v in obj])
else:
return obj


def _make_test_graph_loop():
"""Returns a scalar GraphTensor with one node and one egde."""
return tfgnn.GraphTensor.from_pieces(
node_sets={
"nodes": tfgnn.NodeSet.from_fields(
sizes=tf.constant([1]),
features={tfgnn.HIDDEN_STATE: tf.constant([[1.]])})},
edge_sets={
"edges": tfgnn.EdgeSet.from_fields(
sizes=tf.constant([1]),
adjacency=tfgnn.Adjacency.from_indices(
("nodes", tf.constant([0])),
("nodes", tf.constant([0]))))})


if __name__ == "__main__":
tf.test.main()

68 changes: 68 additions & 0 deletions tensorflow_gnn/models/hgt/hparams_vizier.py
@@ -0,0 +1,68 @@
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparameter search spaces for Vizier studies.
This file defines search spaces for hyperparameter tuning of the
MultiHeadAttention model architecture with https://github.com/google/vizier.
End-to-end models built with MultiHeadAttention can use this to configure and
launch a Vizier study and the training runs for its trials. It's up to them how
to forward Vizier params to the training script and its use of
MultiHeadAttention. The parameter names set here for Vizier match the keyword
arguments in the Python modeling code.
For each search space definition, this file has a function
```
add_params_<name>(search_space)
```
that modifies `search_space` in-place by adding parameters and returns `None`.
"""

from vizier.service import pyvizier as vz


def add_params_regularization(search_space: vz.SearchSpace,
*, prefix: str = "") -> None:
"""Adds params for a study of regularization strength.

Args:
search_space: a `pyvizier.SearchSpace` that is changed in-place by adding
`dropout_rate`.
prefix: a prefix added to param names.
"""
# The params in `root` apply to all trials in the Vizier study.
# go/pyvizier also lets you add conditional params.
root = search_space.root
root.add_discrete_param(
prefix + "dropout_rate", [.1, .2, .3],
scale_type=vz.ScaleType.LINEAR)


def add_params_attention(search_space: vz.SearchSpace,
*, prefix: str = "") -> None:
"""Adds params for a study of attention configurations.

Args:
search_space: a `pyvizier.SearchSpace` that is changed in-place by adding
`num_heads`.
prefix: a prefix added to param names.
"""
# The params in `root` apply to all trials in the Vizier study.
# go/pyvizier also lets you add conditional params.
root = search_space.root
root.add_discrete_param(
prefix + "num_heads", [4, 8, 16, 32],
scale_type=vz.ScaleType.LINEAR)
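
A minimal sketch of wiring these helpers into a Vizier study definition (not part of this commit; the `"hgt."` prefix is just an example):

```python
from vizier.service import pyvizier as vz
from tensorflow_gnn.models.hgt import hparams_vizier

problem = vz.ProblemStatement()
hparams_vizier.add_params_regularization(problem.search_space, prefix="hgt.")
hparams_vizier.add_params_attention(problem.search_space, prefix="hgt.")
# The search space now contains "hgt.dropout_rate" and "hgt.num_heads";
# forwarding trial parameters to the training script is left to the caller.
print([p.name for p in problem.search_space.parameters])
```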
45 changes: 45 additions & 0 deletions tensorflow_gnn/models/hgt/hparams_vizier_test.py
@@ -0,0 +1,45 @@
# Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for hparams_vizier."""

from absl.testing import absltest
from tensorflow_gnn.models.hgt import hparams_vizier

from vizier.service import pyvizier as vz


class HparamsVizierTest(absltest.TestCase):

def test_regularization(self):
problem = vz.ProblemStatement()
hparams_vizier.add_params_regularization(
problem.search_space, prefix="foo."
)
self.assertCountEqual(
[p.name for p in problem.search_space.parameters], ["foo.dropout_rate"]
)

def test_hgt_attention(self):
problem = vz.ProblemStatement()
hparams_vizier.add_params_attention(
problem.search_space, prefix="foo."
)
self.assertCountEqual(
[p.name for p in problem.search_space.parameters], ["foo.num_heads"]
)


if __name__ == "__main__":
absltest.main()
2 changes: 2 additions & 0 deletions tensorflow_gnn/models/hgt/layers.py
@@ -66,6 +66,7 @@ class HGTGraphUpdate(tf.keras.layers.Layer):
"""

def __init__(
# LINT.IfChange(HGTGraphUpdate_args)
self,
*,
num_heads: int,
@@ -80,6 +81,7 @@ def __init__(
activation: Union[str, Callable[..., Any]] = 'gelu',
feature_name: str = tfgnn.HIDDEN_STATE,
**kwargs,
# LINT.ThenChange(./config_dict.py:graph_update_get_config_dict)
):
kwargs.setdefault('name', 'hgt_graph_update')
super().__init__(**kwargs)
1 change: 1 addition & 0 deletions tensorflow_gnn/runner/examples/ogbn/mag/BUILD
@@ -18,6 +18,7 @@ pytype_strict_library(
"//third_party/py/ml_collections/config_flags",
"//:expect_tensorflow_installed",
"//tensorflow_gnn",
"//tensorflow_gnn/models/hgt",
"//tensorflow_gnn/models/mt_albis",
"//tensorflow_gnn/models/multi_head_attention",
"//tensorflow_gnn/models/vanilla_mpnn",
