
Make daisy_chain_variables an hparam instead of a flag and unset it to allow LSTM to train in distributed mode.

PiperOrigin-RevId: 177193238
Lukasz Kaiser authored and Ryan Sepassi committed Nov 29, 2017
1 parent a3d0ffe commit 5adacd0
Showing 7 changed files with 17 additions and 11 deletions.
5 changes: 5 additions & 0 deletions tensor2tensor/layers/common_hparams.py
@@ -176,6 +176,11 @@ def basic_params1():
scheduled_sampling_prob=0.0,
scheduled_sampling_warmup_steps=50000,
scheduled_sampling_gold_mixin_prob=0.5,
+ # This setting controls whether to copy variables around in a daisy chain
+ # (if true) or leave their placement to Tensorflow. It only affects multi
+ # device training and mostly should be turned on for performance. One
+ # exception are recurrent models: with dynamic loops it must be off.
+ daisy_chain_variables=True,
# This is the actual batch size, *not* tokens per batch (i.e. for
# language models this is the number of sentences in the batch)
tpu_batch_size_per_shard=24,
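
For reference, a minimal sketch, assuming the registry and common_hparams APIs at this revision, of how a model with dynamic loops can now opt out of daisy-chaining in its own hparams set rather than via a command-line flag (my_recurrent_hparams is a hypothetical name; the lstm.py change below does exactly this):

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import registry


@registry.register_hparams
def my_recurrent_hparams():
  """Hypothetical hparams set for a model that uses dynamic loops."""
  hparams = common_hparams.basic_params1()
  # basic_params1() now defaults daisy_chain_variables to True, which is the
  # right choice for multi-device performance; recurrent models with dynamic
  # loops must switch it off.
  hparams.daisy_chain_variables = False
  return hparams
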
1 change: 1 addition & 0 deletions tensor2tensor/models/lstm.py
@@ -159,6 +159,7 @@ def model_fn_body(self, features):
def lstm_seq2seq():
"""hparams for LSTM."""
hparams = common_hparams.basic_params1()
+ hparams.daisy_chain_variables = False
hparams.batch_size = 1024
hparams.hidden_size = 128
hparams.num_hidden_layers = 2
9 changes: 5 additions & 4 deletions tensor2tensor/models/neural_gpu.py
@@ -31,7 +31,7 @@
import tensorflow as tf


- def neural_gpu(inputs, hparams, name=None):
+ def neural_gpu_body(inputs, hparams, name=None):
"""The core Neural GPU."""
with tf.variable_scope(name, "neural_gpu"):

@@ -59,7 +59,7 @@ def step(state, inp): # pylint: disable=missing-docstring
class NeuralGPU(t2t_model.T2TModel):

def model_fn_body(self, features):
- return neural_gpu(features["inputs"], self._hparams)
+ return neural_gpu_body(features["inputs"], self._hparams)


def diagonal_neural_gpu(inputs, hparams, name=None):
@@ -97,10 +97,11 @@ def model_fn_body(self, features):
return diagonal_neural_gpu(features["inputs"], self._hparams)


- @registry.register_hparams("neuralgpu_1")
- def neural_gpu_params1():
+ @registry.register_hparams
+ def neural_gpu():
"""Set of hyperparameters."""
hparams = common_hparams.basic_params1()
+ hparams.daisy_chain_variables = False
hparams.batch_size = 1024
hparams.num_hidden_layers = 1
hparams.hidden_size = 256
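
A note on the registration change above: used bare, @registry.register_hparams keys the hparams set by the function's own name, so this set is now looked up as "neural_gpu" instead of the old "neuralgpu_1". A minimal sketch, assuming the registry's hparams() lookup and that importing tensor2tensor.models triggers registration:

from tensor2tensor import models  # noqa: F401, importing registers the models' hparams sets
from tensor2tensor.utils import registry

# The bare decorator uses the function name as the registry key.
hparams = registry.hparams("neural_gpu")
assert not hparams.daisy_chain_variables  # set to False by neural_gpu() above
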
2 changes: 1 addition & 1 deletion tensor2tensor/utils/decoding.py
@@ -114,7 +114,7 @@ def decode_from_dataset(estimator,
mode=tf.estimator.ModeKeys.PREDICT,
hparams=hparams,
data_dir=hparams.data_dir,
- num_datashards=devices.data_parallelism().n,
+ num_datashards=devices.data_parallelism(hparams).n,
fixed_problem=problem_idx,
batch_size=decode_hp.batch_size,
dataset_split=dataset_split,
5 changes: 3 additions & 2 deletions tensor2tensor/utils/devices.py
@@ -81,7 +81,7 @@ def ps_devices(all_workers=False):
return [""]


- def data_parallelism(all_workers=False):
+ def data_parallelism(hparams, all_workers=False):
"""Over which devices do we split each training batch.
In old-fashioned async mode, we split the batch over all GPUs on the
@@ -95,6 +95,7 @@ def data_parallelism(all_workers=False):
between datashards.
Args:
+ hparams: model hyperparameters (an HParams object).
all_workers: whether the devices are all async workers or just this one.
Returns:
@@ -148,4 +149,4 @@ def _replica_device_setter(worker_device):
datashard_devices,
reuse=True,
caching_devices=caching_devices,
- daisy_chain_variables=FLAGS.daisy_chain_variables)
+ daisy_chain_variables=hparams.daisy_chain_variables)
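
In short, data_parallelism() now requires the hparams object and reads hparams.daisy_chain_variables in place of the removed flag, so every call site must pass hparams through. A minimal sketch of the new caller contract, assuming the trainer's flags have been defined and parsed as usual (the call sites in decoding.py above and in model_builder.py and trainer_utils.py below follow the same pattern):

from tensor2tensor.layers import common_hparams
from tensor2tensor.utils import devices

hparams = common_hparams.basic_params1()
# hparams is now a required argument; its daisy_chain_variables field replaces
# the old FLAGS.daisy_chain_variables when the parallelism object is built.
dp = devices.data_parallelism(hparams)
num_datashards = dp.n  # number of shards each training batch is split over
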
2 changes: 1 addition & 1 deletion tensor2tensor/utils/model_builder.py
@@ -78,7 +78,7 @@ def model_fn(model,
decode_hp = decode_hparams

# TODO(rsepassi): This still depends on FLAGS. Rm eventually.
- dp = devices.data_parallelism()
+ dp = devices.data_parallelism(hparams)

tf.get_variable_scope().set_initializer(_get_variable_initializer(hparams))
is_training = mode == tf.estimator.ModeKeys.TRAIN
4 changes: 1 addition & 3 deletions tensor2tensor/utils/trainer_utils.py
@@ -109,8 +109,6 @@
flags.DEFINE_bool("locally_shard_to_cpu", False,
"Use CPU as a sharding device running locally. This allows "
"to test sharded model construction on a machine with 1 GPU.")
- flags.DEFINE_bool("daisy_chain_variables", True,
- "copy variables around in a daisy chain")
flags.DEFINE_bool("sync", False, "Sync compute on PS.")
flags.DEFINE_string("worker_job", "/job:localhost", "name of worker job")
flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.")
@@ -219,7 +217,7 @@ def create_experiment_components(data_dir, model_name, hparams, run_config):

# hparams batch_size is used as minibatch size instead of tokens in batch
batch_size = (hparams.use_fixed_batch_size and hparams.batch_size) or None
- num_datashards = devices.data_parallelism().n
+ num_datashards = devices.data_parallelism(hparams).n
train_input_fn = input_fn_builder.build_input_fn(
mode=tf.estimator.ModeKeys.TRAIN,
hparams=hparams,
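
With the flag removed, the setting travels with the hparams object end to end, so it can still be overridden per run through the usual hparams-override string (e.g. --hparams='daisy_chain_variables=False'), which is applied with HParams.parse. A minimal sketch, assuming standard HParams parsing:

from tensor2tensor.layers import common_hparams

hparams = common_hparams.basic_params1()
# The override string from the command line reaches the value via parse(),
# flipping the new default off for runs that need it.
hparams.parse("daisy_chain_variables=False")
assert not hparams.daisy_chain_variables
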
