-
Notifications
You must be signed in to change notification settings - Fork 180
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
9 changed files
with
355 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""ConfigDict for Multi-Head Attention.""" | ||
|
||
from ml_collections import config_dict | ||
import tensorflow as tf | ||
from tensorflow_gnn.models.hgt import layers | ||
|
||
|
||
def graph_update_get_config_dict() -> config_dict.ConfigDict:
  """Returns a locked ConfigDict of defaults for graph_update_from_config_dict().

  Every field starts out as `None` (an optional placeholder of the listed
  type); callers fill in values before passing the result to
  `graph_update_from_config_dict()`.
  """
  cfg = config_dict.ConfigDict()
  # LINT.IfChange(graph_update_get_config_dict)
  # Each placeholder sets the field's type to Optional[<type>] with value None.
  field_types = dict(
      num_heads=int,
      per_head_channels=int,
      receiver_tag=int,
      use_weighted_skip=bool,
      dropout_rate=float,
      use_layer_norm=bool,
      use_bias=bool,
      activation=str)
  for field_name, field_type in field_types.items():
    setattr(cfg, field_name, config_dict.placeholder(field_type))
  # LINT.ThenChange(./layers.py:HGTGraphUpdate_args)
  cfg.lock()
  return cfg
|
||
|
||
def graph_update_from_config_dict(
    cfg: config_dict.ConfigDict) -> tf.keras.layers.Layer:
  """Returns a HGTGraphUpdate initialized from `cfg`.

  Args:
    cfg: A `ConfigDict` with the fields defined by
      `graph_update_get_config_dict()`. All fields with non-`None` values are
      used as keyword arguments for initializing and returning a
      `HGTGraphUpdate` object. For the required arguments of
      `HGTGraphUpdate.__init__`, users must set a value in
      `cfg` before passing it here.

  Returns:
    A new `HGTGraphUpdate` object.

  Raises:
    TypeError: if `cfg` fails to supply a required argument for
      `HGTGraphUpdate.__init__`.
  """
  # Fields left at their `None` placeholder default are simply omitted, so
  # the layer's own defaults (or its required-argument errors) apply.
  non_none_fields = {
      name: value for name, value in cfg.items() if value is not None}
  return layers.HGTGraphUpdate(**non_none_fields)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for config_dict.""" | ||
|
||
import json | ||
from typing import Mapping | ||
|
||
import tensorflow as tf | ||
import tensorflow_gnn as tfgnn | ||
from tensorflow_gnn.models.hgt import config_dict as hgt_config_dict | ||
from tensorflow_gnn.models.hgt import layers | ||
|
||
|
||
class ConfigDictTest(tf.test.TestCase):

  def test_graph_update_defaults(self):
    """Checks graph_update_from_config_dict(cfg) matches a direct init."""
    hparams = dict(
        num_heads=1,
        per_head_channels=1,
        receiver_tag=tfgnn.SOURCE,
        use_weighted_skip=True,
        dropout_rate=0.1,
        use_layer_norm=True,
        use_bias=True,
        activation="gelu")

    cfg = hgt_config_dict.graph_update_get_config_dict()
    for field_name, value in hparams.items():
      setattr(cfg, field_name, value)

    if tf.__version__.startswith("2.9."):
      self.skipTest(f"HGTGraphUpdate requires TF 2.10+, got {tf.__version__}")

    actual = hgt_config_dict.graph_update_from_config_dict(cfg)
    expected = layers.HGTGraphUpdate(**hparams)
    # Compare serialized model configs (with names stripped) instead of
    # object identity, since two separately built layers are never equal.
    self.assertEqual(to_model_config(actual), to_model_config(expected))
|
||
|
||
# TODO(b/265776928): De-duplicate the multiple copies of this test helper. | ||
def to_model_config(layer: tf.keras.layers.Layer): | ||
"""Returns a parsed model config for `layer`, without `"name"` fields.""" | ||
# Need a full model to serialize *recursively*. | ||
model = tf.keras.Sequential([layer]) | ||
# Subobjects are only built in the first call. | ||
_ = model(_make_test_graph_loop()) | ||
model_config = json.loads(model.to_json()) | ||
# The names of layers are uniquified and impede the hparam comparison. | ||
return _remove_names(model_config) | ||
|
||
|
||
def _remove_names(obj): | ||
"""Returns parsed JSON `obj` without dict entries called "name".""" | ||
if isinstance(obj, Mapping): | ||
return {k: _remove_names(v) for k, v in obj.items() if k != "name"} | ||
elif isinstance(obj, (list, tuple)): | ||
return type(obj)([_remove_names(v) for v in obj]) | ||
else: | ||
return obj | ||
|
||
|
||
def _make_test_graph_loop():
  """Returns a scalar GraphTensor with one node and one edge (a self-loop)."""
  node_sets = {
      "nodes": tfgnn.NodeSet.from_fields(
          sizes=tf.constant([1]),
          features={tfgnn.HIDDEN_STATE: tf.constant([[1.]])}),
  }
  edge_sets = {
      "edges": tfgnn.EdgeSet.from_fields(
          sizes=tf.constant([1]),
          # Single edge from node 0 back to node 0.
          adjacency=tfgnn.Adjacency.from_indices(
              ("nodes", tf.constant([0])),
              ("nodes", tf.constant([0])))),
  }
  return tfgnn.GraphTensor.from_pieces(
      node_sets=node_sets, edge_sets=edge_sets)
|
||
|
||
if __name__ == "__main__": | ||
tf.test.main() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Hyperparameter search spaces for Vizier studies. | ||
This file defines search spaces for hyperparameter tuning of the | ||
MultiHeadAttention model architecture with https://github.com/google/vizier. | ||
End-to-end models built with MultiHeadAttention can use this to configure and | ||
launch a Vizier study and the training runs for its trials. It's up to them how | ||
to forward Vizier params to the training script and its use of | ||
MultiHeadAttention. The parameter names set here for Vizier match the keyword | ||
arguments in the Python modeling code. | ||
For each search space definition, this file has a function | ||
``` | ||
add_params_<name>(search_space) | ||
``` | ||
that modifies `search_space` in-place by adding parameters and returns `None`. | ||
""" | ||
|
||
from vizier.service import pyvizier as vz | ||
|
||
|
||
def add_params_regularization(search_space: vz.SearchSpace,
                              *, prefix: str = "") -> None:
  """Adds params for a study of regularization strength.

  Args:
    search_space: a `pyvizier.SearchSpace` that is changed in-place by adding
      `dropout_rate`.
    prefix: a prefix added to param names.
  """
  # The params in `root` apply to all trials in the Vizier study.
  # go/pyvizier also lets you add conditional params.
  root = search_space.root
  root.add_discrete_param(
      prefix + "dropout_rate", [.1, .2, .3],
      scale_type=vz.ScaleType.LINEAR)
|
||
|
||
def add_params_attention(search_space: vz.SearchSpace,
                         *, prefix: str = "") -> None:
  """Adds params for a study of attention configurations.

  Args:
    search_space: a `pyvizier.SearchSpace` that is changed in-place by adding
      `num_heads`.
    prefix: a prefix added to param names.
  """
  # Parameters added at the root apply to every trial of the Vizier study;
  # go/pyvizier also supports conditional parameters.
  space_root = search_space.root
  space_root.add_discrete_param(
      prefix + "num_heads", [4, 8, 16, 32],
      scale_type=vz.ScaleType.LINEAR)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright 2023 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for hparams_vizier.""" | ||
|
||
from absl.testing import absltest | ||
from tensorflow_gnn.models.hgt import hparams_vizier | ||
|
||
from vizier.service import pyvizier as vz | ||
|
||
|
||
class HparamsVizierTest(absltest.TestCase):

  def test_regularization(self):
    """add_params_regularization adds exactly the prefixed dropout param."""
    problem = vz.ProblemStatement()
    hparams_vizier.add_params_regularization(
        problem.search_space, prefix="foo.")
    param_names = [p.name for p in problem.search_space.parameters]
    self.assertCountEqual(param_names, ["foo.dropout_rate"])

  def test_hgt_attention(self):
    """add_params_attention adds exactly the prefixed num_heads param."""
    problem = vz.ProblemStatement()
    hparams_vizier.add_params_attention(
        problem.search_space, prefix="foo.")
    param_names = [p.name for p in problem.search_space.parameters]
    self.assertCountEqual(param_names, ["foo.num_heads"])
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.