-
Notifications
You must be signed in to change notification settings - Fork 187
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add MultiHeadAttention model option to xm_launcher
PiperOrigin-RevId: 493797031
- Loading branch information
1 parent
dd0d1a2
commit ed9fa45
Showing 8 changed files with 341 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""ConfigDict for Multi-Head Attention.""" | ||
|
||
from ml_collections import config_dict | ||
import tensorflow as tf | ||
from tensorflow_gnn.models.multi_head_attention import layers | ||
|
||
|
||
def graph_update_get_config_dict() -> config_dict.ConfigDict:
  """Returns a ConfigDict of defaults for graph_update_from_config_dict().

  Field defaults are kept in sync with the keyword arguments of
  `MultiHeadAttentionMPNNGraphUpdate.__init__`. Required arguments start out
  as typed `None` placeholders that callers must fill in before use.
  """
  cfg = config_dict.ConfigDict()
  # Required arguments: typed placeholders, i.e. Optional[int] set to None.
  for required_field in ("units", "message_dim", "num_heads", "receiver_tag"):
    setattr(cfg, required_field, config_dict.placeholder(int))
  # Optional arguments, with the library's default values.
  cfg.l2_regularization = 0.0
  cfg.edge_dropout_rate = 0.0
  cfg.state_dropout_rate = 0.0
  cfg.conv_activation = "relu"
  cfg.activation = "relu"
  cfg.lock()  # Reject assignments to misspelled field names.
  return cfg
|
||
|
||
def graph_update_from_config_dict(
    cfg: config_dict.ConfigDict) -> tf.keras.layers.Layer:
  """Returns a MultiHeadAttentionMPNNGraphUpdate initialized from `cfg`.

  Args:
    cfg: A `ConfigDict` with the fields defined by
      `graph_update_get_config_dict()`. All fields with non-`None` values are
      used as keyword arguments for initializing and returning a
      `MultiHeadAttentionMPNNGraphUpdate` object. For the required arguments
      of `MultiHeadAttentionMPNNGraphUpdate.__init__`, users must set a value
      in `cfg` before passing it here.

  Returns:
    A new `MultiHeadAttentionMPNNGraphUpdate` object.

  Raises:
    TypeError: if `cfg` fails to supply a required argument for
      `MultiHeadAttentionMPNNGraphUpdate.__init__`.
  """
  init_kwargs = {}
  for field_name, field_value in cfg.items():
    # Placeholders left at None are omitted, so __init__ applies its own
    # defaults (or raises TypeError for required arguments).
    if field_value is not None:
      init_kwargs[field_name] = field_value
  return layers.MultiHeadAttentionMPNNGraphUpdate(**init_kwargs)
87 changes: 87 additions & 0 deletions
87
tensorflow_gnn/models/multi_head_attention/config_dict_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for config_dict.""" | ||
|
||
import json | ||
from typing import Mapping | ||
|
||
import tensorflow as tf | ||
import tensorflow_gnn as tfgnn | ||
from tensorflow_gnn.models.multi_head_attention import config_dict as multi_head_attention_config_dict | ||
from tensorflow_gnn.models.multi_head_attention import layers | ||
|
||
|
||
class ConfigDictTest(tf.test.TestCase):

  def test_graph_update_defaults(self):
    """A config with only the required fields set matches direct construction."""
    hparams = dict(
        units=32,
        message_dim=16,
        num_heads=4,
        receiver_tag=tfgnn.SOURCE)

    cfg = multi_head_attention_config_dict.graph_update_get_config_dict()
    for field_name, field_value in hparams.items():
      setattr(cfg, field_name, field_value)
    actual = multi_head_attention_config_dict.graph_update_from_config_dict(cfg)

    expected = layers.MultiHeadAttentionMPNNGraphUpdate(**hparams)

    self.assertEqual(to_model_config(actual), to_model_config(expected))
|
||
|
||
def to_model_config(layer: tf.keras.layers.Layer):
  """Returns a parsed model config for `layer`, without `"name"` fields."""
  # Serialization is recursive only for a full model, so wrap the layer.
  model = tf.keras.Sequential([layer])
  # Sub-objects are only built in the first call, so call the model once.
  _ = model(_make_test_graph_loop())
  parsed_config = json.loads(model.to_json())
  # Layer names are uniquified and would make the hparam comparison fail,
  # so strip them before returning.
  return _remove_names(parsed_config)
|
||
|
||
def _remove_names(obj): | ||
"""Returns parsed JSON `obj` without dict entries called "name".""" | ||
if isinstance(obj, Mapping): | ||
return {k: _remove_names(v) for k, v in obj.items() if k != "name"} | ||
elif isinstance(obj, (list, tuple)): | ||
return type(obj)([_remove_names(v) for v in obj]) | ||
else: | ||
return obj | ||
|
||
|
||
def _make_test_graph_loop():
  """Returns a scalar GraphTensor with one node and one self-loop edge."""
  node_sets = {
      "nodes": tfgnn.NodeSet.from_fields(
          sizes=tf.constant([1]),
          features={tfgnn.HIDDEN_STATE: tf.constant([[1.]])}),
  }
  # A single edge from node 0 back to node 0.
  edge_sets = {
      "edges": tfgnn.EdgeSet.from_fields(
          sizes=tf.constant([1]),
          adjacency=tfgnn.Adjacency.from_indices(
              ("nodes", tf.constant([0])),
              ("nodes", tf.constant([0])))),
  }
  return tfgnn.GraphTensor.from_pieces(
      node_sets=node_sets, edge_sets=edge_sets)
|
||
|
||
# Allow running this test file directly as a script.
if __name__ == "__main__":
  tf.test.main()
74 changes: 74 additions & 0 deletions
74
tensorflow_gnn/models/multi_head_attention/hparams_vizier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Hyperparameter search spaces for Vizier studies. | ||
This file defines search spaces for hyperparameter tuning of the | ||
MultiHeadAttention model architecture with https://github.com/google/vizier. | ||
End-to-end models built with MultiHeadAttention can use this to configure and | ||
launch a Vizier study and the training runs for its trials. It's up to them how | ||
to forward Vizier params to the training script and its use of | ||
MultiHeadAttention. The parameter names set here for Vizier match the keyword | ||
arguments in the Python modeling code. | ||
For each search space definition, this file has a function | ||
``` | ||
add_params_<name>(search_space) | ||
``` | ||
that modifies `search_space` in-place by adding parameters and returns `None`. | ||
""" | ||
|
||
from vizier.service import pyvizier as vz | ||
|
||
|
||
def add_params_regularization(search_space: vz.SearchSpace,
                              *, prefix: str = "") -> None:
  """Adds params for a study of regularization strength.

  Args:
    search_space: a `pyvizier.SearchSpace` that is changed in-place by adding
      `state_dropout_rate`, `edge_dropout_rate` and `l2_regularization`.
    prefix: a prefix added to param names.
  """
  # Params added at the root apply to all trials in the Vizier study.
  # go/pyvizier also lets you add conditional params.
  root = search_space.root
  dropout_values = [.1, .2, .3]
  for dropout_name in ("state_dropout_rate", "edge_dropout_rate"):
    root.add_discrete_param(
        prefix + dropout_name, dropout_values,
        scale_type=vz.ScaleType.LINEAR)
  root.add_float_param(
      prefix + "l2_regularization", 1e-6, 1e-4,
      scale_type=vz.ScaleType.LOG)
|
||
|
||
def add_params_attention(search_space: vz.SearchSpace,
                         *, prefix: str = "") -> None:
  """Adds params for a study of attention configurations.

  Args:
    search_space: a `pyvizier.SearchSpace` that is changed in-place by adding
      `num_heads`.
    prefix: a prefix added to param names.
  """
  # Params added at the root apply to all trials in the Vizier study.
  # go/pyvizier also lets you add conditional params.
  head_counts = [2, 4, 8]
  search_space.root.add_discrete_param(
      prefix + "num_heads", head_counts,
      scale_type=vz.ScaleType.LINEAR)
44 changes: 44 additions & 0 deletions
44
tensorflow_gnn/models/multi_head_attention/hparams_vizier_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright 2022 The TensorFlow GNN Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for hparams_vizier.""" | ||
|
||
from absl.testing import absltest | ||
from tensorflow_gnn.models.multi_head_attention import hparams_vizier | ||
|
||
from vizier.service import pyvizier as vz | ||
|
||
|
||
class HparamsVizierTest(absltest.TestCase):

  def _param_names(self, problem):
    """Returns the names of all params in the problem's search space."""
    return [p.name for p in problem.search_space.parameters]

  def test_regularization(self):
    problem = vz.ProblemStatement()
    hparams_vizier.add_params_regularization(
        problem.search_space, prefix="foo.")
    self.assertCountEqual(
        self._param_names(problem),
        ["foo.state_dropout_rate", "foo.edge_dropout_rate",
         "foo.l2_regularization"])

  def test_attention(self):
    problem = vz.ProblemStatement()
    hparams_vizier.add_params_attention(problem.search_space, prefix="foo.")
    self.assertCountEqual(self._param_names(problem), ["foo.num_heads"])
|
||
|
||
# Allow running this test file directly as a script.
if __name__ == "__main__":
  absltest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.