From 01b8c31da30a7e1109451df2b4b4698946c6c35c Mon Sep 17 00:00:00 2001
From: Lukasz Kaiser
Date: Mon, 20 Nov 2017 13:31:27 -0800
Subject: [PATCH] CHECKPOINT BREAKING: make T2TModel a subclass of Layer so it
 can be called; all variables are now in model-name scope.

PiperOrigin-RevId: 176407831
---
 tensor2tensor/models/bluenet_test.py         |  3 +-
 tensor2tensor/models/bytenet_test.py         |  3 +-
 tensor2tensor/models/gene_expression_test.py |  5 +-
 tensor2tensor/models/lstm_test.py            |  6 +-
 tensor2tensor/models/multimodel_test.py      |  3 +-
 tensor2tensor/models/neural_gpu_test.py      |  3 +-
 tensor2tensor/models/resnet_test.py          |  3 +-
 tensor2tensor/models/slicenet_test.py        |  3 +-
 tensor2tensor/models/transformer.py          |  8 +-
 .../models/transformer_revnet_test.py        |  3 +-
 tensor2tensor/models/transformer_test.py     | 26 +++---
 tensor2tensor/models/transformer_vae.py      |  6 +-
 tensor2tensor/models/xception_test.py        |  3 +-
 tensor2tensor/tpu/tpu_trainer_lib.py         |  4 +-
 tensor2tensor/utils/model_builder.py         |  2 +-
 tensor2tensor/utils/registry.py              | 27 +++---
 tensor2tensor/utils/t2t_model.py             | 82 +++++++++++++------
 tensor2tensor/utils/trainer_utils_test.py    | 56 ++++++++++++-
 18 files changed, 155 insertions(+), 91 deletions(-)

diff --git a/tensor2tensor/models/bluenet_test.py b/tensor2tensor/models/bluenet_test.py
index daf87529e..15f1f46e6 100644
--- a/tensor2tensor/models/bluenet_test.py
+++ b/tensor2tensor/models/bluenet_test.py
@@ -45,8 +45,7 @@ def testBlueNet(self):
       }
       model = bluenet.BlueNet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))
diff --git a/tensor2tensor/models/bytenet_test.py b/tensor2tensor/models/bytenet_test.py
index f96d3b999..8a19ae905 100644
--- a/tensor2tensor/models/bytenet_test.py
+++ b/tensor2tensor/models/bytenet_test.py
@@ -44,8 +44,7 @@ def testByteNet(self):
       }
       model = bytenet.ByteNet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 50, 1, 1, vocab_size))
diff --git a/tensor2tensor/models/gene_expression_test.py b/tensor2tensor/models/gene_expression_test.py
index ea02572d0..94cf20ff3 100644
--- a/tensor2tensor/models/gene_expression_test.py
+++ b/tensor2tensor/models/gene_expression_test.py
@@ -55,9 +55,8 @@ def _testModel(self, hparams, model_cls):
         "targets": tf.constant(targets, dtype=tf.float32),
     }
     p_hparams, = hparams.problems
-    sharded_logits, _ = model_cls(hparams, tf.estimator.ModeKeys.TRAIN,
-                                  p_hparams).model_fn(features)
-    logits = tf.concat(sharded_logits, 0)
+    logits, _ = model_cls(
+        hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)(features)
 
     with self.test_session() as sess:
       sess.run(tf.global_variables_initializer())
diff --git a/tensor2tensor/models/lstm_test.py b/tensor2tensor/models/lstm_test.py
index b8be74f23..863518fa1 100644
--- a/tensor2tensor/models/lstm_test.py
+++ b/tensor2tensor/models/lstm_test.py
@@ -44,8 +44,7 @@ def testLSTMSeq2Seq(self):
       }
       model = lstm.LSTMSeq2seq(hparams, tf.estimator.ModeKeys.TRAIN,
                                p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
@@ -67,8 +66,7 @@ def testLSTMSeq2SeqAttention(self):
       }
       model = lstm.LSTMSeq2seqAttention(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
diff --git a/tensor2tensor/models/multimodel_test.py b/tensor2tensor/models/multimodel_test.py
index 3aff41029..86f92ced6 100644
--- a/tensor2tensor/models/multimodel_test.py
+++ b/tensor2tensor/models/multimodel_test.py
@@ -48,8 +48,7 @@ def testMultiModel(self):
       }
       model = multimodel.MultiModel(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 1, 1, 1, 10))
diff --git a/tensor2tensor/models/neural_gpu_test.py b/tensor2tensor/models/neural_gpu_test.py
index 75149ddd5..99b7f1062 100644
--- a/tensor2tensor/models/neural_gpu_test.py
+++ b/tensor2tensor/models/neural_gpu_test.py
@@ -52,8 +52,7 @@ def testNeuralGPU(self):
       }
       model = neural_gpu.NeuralGPU(hparams, tf.estimator.ModeKeys.TRAIN,
                                    p_hparams)
-      shadred_logits, _ = model.model_fn(features)
-      logits = tf.concat(shadred_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (batch_size, target_length, 1, 1,
diff --git a/tensor2tensor/models/resnet_test.py b/tensor2tensor/models/resnet_test.py
index 9db4cb85f..d911dcbd7 100644
--- a/tensor2tensor/models/resnet_test.py
+++ b/tensor2tensor/models/resnet_test.py
@@ -56,8 +56,7 @@ def _testResnet(self, img_size, output_size):
         "targets": tf.constant(y, dtype=tf.int32),
       }
       model = resnet.Resnet50(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (batch_size,) + output_size + (1, vocab_size))
diff --git a/tensor2tensor/models/slicenet_test.py b/tensor2tensor/models/slicenet_test.py
index faf028737..7efdf7a33 100644
--- a/tensor2tensor/models/slicenet_test.py
+++ b/tensor2tensor/models/slicenet_test.py
@@ -49,8 +49,7 @@ def testSliceNet(self):
       }
       model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                 p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, (3, 1, 1, 1, 10))
diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index 588b6154c..8745dc00b 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -158,7 +158,8 @@ def _greedy_infer(self, features, decode_length):
     Raises:
       NotImplementedError: If there are multiple data shards.
     """
-    decoded_ids, _ = self._fast_decode(features, decode_length)
+    with tf.variable_scope(self.name):
+      decoded_ids, _ = self._fast_decode(features, decode_length)
     return decoded_ids, None, None
 
   def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
@@ -175,8 +176,9 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
     Returns:
       samples: an integer `Tensor`. Top samples from the beam search
     """
-    decoded_ids, scores = self._fast_decode(features, decode_length, beam_size,
-                                            top_beams, alpha)
+    with tf.variable_scope(self.name):
+      decoded_ids, scores = self._fast_decode(
+          features, decode_length, beam_size, top_beams, alpha)
     return {"outputs": decoded_ids, "scores": scores}
 
   def _fast_decode(self,
diff --git a/tensor2tensor/models/transformer_revnet_test.py b/tensor2tensor/models/transformer_revnet_test.py
index f61b88b5b..79f8eb1e0 100644
--- a/tensor2tensor/models/transformer_revnet_test.py
+++ b/tensor2tensor/models/transformer_revnet_test.py
@@ -59,8 +59,7 @@ def testTransformer(self):
       }
       model = transformer_revnet.TransformerRevnet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       grads = tf.gradients(
           tf.reduce_mean(logits), [features["inputs"]] + tf.global_variables())
       grads = [g for g in grads if g is not None]
diff --git a/tensor2tensor/models/transformer_test.py b/tensor2tensor/models/transformer_test.py
index ae254a42d..a0c21e2c0 100644
--- a/tensor2tensor/models/transformer_test.py
+++ b/tensor2tensor/models/transformer_test.py
@@ -51,17 +51,16 @@ def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN):
     targets = -1 + np.random.random_integers(
         VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))
     features = {
-        "inputs": tf.constant(inputs, dtype=tf.int32),
-        "targets": tf.constant(targets, dtype=tf.int32),
-        "target_space_id": tf.constant(1, dtype=tf.int32),
+        "inputs": tf.constant(inputs, dtype=tf.int32, name="inputs"),
+        "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
+        "target_space_id": tf.constant(1, dtype=tf.int32)
     }
     return transformer.Transformer(hparams, mode, p_hparams), features
 
   def testTransformer(self):
     model, features = self.getModel(transformer.transformer_small())
-    shadred_logits, _ = model.model_fn(features)
-    logits = tf.concat(shadred_logits, 0)
+    logits, _ = model(features)
     with self.test_session() as session:
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
@@ -69,8 +68,7 @@ def testTransformerRelative(self):
     model, features = self.getModel(transformer.transformer_relative_tiny())
-    shadred_logits, _ = model.model_fn(features)
-    logits = tf.concat(shadred_logits, 0)
+    logits, _ = model(features)
     with self.test_session() as session:
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
@@ -81,8 +79,8 @@ def testGreedyVsFast(self):
 
     decode_length = 2
 
-    out_logits, _ = model.model_fn(features)
-    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    out_logits, _ = model(features)
+    out_logits = tf.squeeze(out_logits, axis=[2, 3])
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
         labels=tf.reshape(features["targets"], [-1]))
@@ -94,8 +92,7 @@ def testGreedyVsFast(self):
       for _ in range(100):
        apply_grad.run()
 
-    model, _ = self.getModel(transformer.transformer_small(),
-                             mode=tf.estimator.ModeKeys.PREDICT)
+    model.set_mode(tf.estimator.ModeKeys.PREDICT)
 
     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
       greedy_result, _, _ = model._slow_greedy_infer(features, decode_length)
@@ -115,8 +112,8 @@ def testBeamVsFast(self):
 
     decode_length = 2
 
-    out_logits, _ = model.model_fn(features)
-    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    out_logits, _ = model(features)
+    out_logits = tf.squeeze(out_logits, axis=[2, 3])
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
         labels=tf.reshape(features["targets"], [-1]))
@@ -128,8 +125,7 @@ def testBeamVsFast(self):
      for _ in range(100):
        apply_grad.run()
 
-    model, _ = self.getModel(transformer.transformer_small(),
-                             mode=tf.estimator.ModeKeys.PREDICT)
+    model.set_mode(tf.estimator.ModeKeys.PREDICT)
 
     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
       beam_result = model._beam_decode_slow(
diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py
index ad5143095..caea3ff59 100644
--- a/tensor2tensor/models/transformer_vae.py
+++ b/tensor2tensor/models/transformer_vae.py
@@ -654,9 +654,9 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
                                 dtype=tf.int64)
       features["targets"] = initial_output
-    sharded_logits, _ = self.model_fn(features, False, force_full_predict=True)
-    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
-    samples = tf.concat(sharded_samples, 0)
+    logits, _ = self.__call__(
+        features, skip=False, force_full_predict=True)
+    samples = tf.argmax(logits, axis=-1)
     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
diff --git a/tensor2tensor/models/xception_test.py b/tensor2tensor/models/xception_test.py
index e02057c10..cb4e3544e 100644
--- a/tensor2tensor/models/xception_test.py
+++ b/tensor2tensor/models/xception_test.py
@@ -48,8 +48,7 @@ def _testXception(self, img_size, output_size):
         "targets": tf.constant(y, dtype=tf.int32),
       }
      model = xception.Xception(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
     self.assertEqual(res.shape, output_size + (1, vocab_size))
diff --git a/tensor2tensor/tpu/tpu_trainer_lib.py b/tensor2tensor/tpu/tpu_trainer_lib.py
index b2267319c..65618fc1b 100644
--- a/tensor2tensor/tpu/tpu_trainer_lib.py
+++ b/tensor2tensor/tpu/tpu_trainer_lib.py
@@ -209,7 +209,6 @@ def t2t_model_fn(model_name,
     EstimatorSpec or TPUEstimatorSpec
   """
   _create_dummy_vars()
-  hparams = copy.deepcopy(hparams)
   problem = hparams.problem_instances[0]
   problem_hp = hparams.problems[0]
 
@@ -224,10 +223,9 @@ def t2t_model_fn(model_name,
                       if use_tpu else create_data_parallelism(**config.t2t_device_info))
   model = registry.model(model_name)(
       hparams, mode, problem_hp, data_parallelism=data_parallelism)
-  sharded_logits, losses_dict = model.model_fn(features)
+  logits, losses_dict = model(features)
 
   # Set known shapes
-  logits = tf.concat(sharded_logits, 0)
   shape = logits.get_shape().as_list()
   if shape[0] is None:
     shape[0] = _get_batch_size(params, hparams, config)
diff --git a/tensor2tensor/utils/model_builder.py b/tensor2tensor/utils/model_builder.py
index 6bef72b0c..13ebaa91e 100644
--- a/tensor2tensor/utils/model_builder.py
+++ b/tensor2tensor/utils/model_builder.py
@@ -127,7 +127,7 @@ def nth_model(n):
     if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
       sharded_logits, losses_dict = model_class.eval_autoregressive(features)
     else:
-      sharded_logits, losses_dict = model_class.model_fn(
+      sharded_logits, losses_dict = model_class(
           features, skip=(skipping_is_on and skip_this_one))
     with tf.variable_scope("losses_avg"):
       total_loss, ops = 0.0, []
diff --git a/tensor2tensor/utils/registry.py b/tensor2tensor/utils/registry.py
index e3f3787f6..e21702251 100644
--- a/tensor2tensor/utils/registry.py
+++ b/tensor2tensor/utils/registry.py
@@ -90,7 +90,7 @@ def _reset():
     ctr.clear()
 
 
-def _default_name(obj_class):
+def default_name(obj_class):
   """Convert a class name to the registry's default name for the class.
 
   Args:
@@ -99,7 +99,6 @@ def _default_name(obj_class):
   Returns:
     The registry's default name for the class.
   """
-
   return _convert_camel_to_snake(obj_class.__name__)
 
 
@@ -112,8 +111,7 @@ def default_object_name(obj):
   Returns:
     The registry's default name for the class of the object.
   """
-
-  return _default_name(obj.__class__)
+  return default_name(obj.__class__)
 
 
 def register_model(name=None):
@@ -121,16 +119,17 @@ def register_model(name=None):
 
   def decorator(model_cls, registration_name=None):
     """Registers & returns model_cls with registration_name or default name."""
-    model_name = registration_name or _default_name(model_cls)
+    model_name = registration_name or default_name(model_cls)
     if model_name in _MODELS:
       raise LookupError("Model %s already registered." % model_name)
+    model_cls.REGISTERED_NAME = property(lambda _: model_name)
     _MODELS[model_name] = model_cls
     return model_cls
 
   # Handle if decorator was used without parens
   if callable(name):
     model_cls = name
-    return decorator(model_cls, registration_name=_default_name(model_cls))
+    return decorator(model_cls, registration_name=default_name(model_cls))
 
   return lambda model_cls: decorator(model_cls, name)
@@ -150,7 +149,7 @@ def register_hparams(name=None):
 
   def decorator(hp_fn, registration_name=None):
     """Registers & returns hp_fn with registration_name or default name."""
-    hp_name = registration_name or _default_name(hp_fn)
+    hp_name = registration_name or default_name(hp_fn)
     if hp_name in _HPARAMS:
       raise LookupError("HParams set %s already registered." % hp_name)
     _HPARAMS[hp_name] = hp_fn
@@ -159,7 +158,7 @@ def decorator(hp_fn, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     hp_fn = name
-    return decorator(hp_fn, registration_name=_default_name(hp_fn))
+    return decorator(hp_fn, registration_name=default_name(hp_fn))
 
   return lambda hp_fn: decorator(hp_fn, name)
@@ -182,7 +181,7 @@ def register_ranged_hparams(name=None):
 
   def decorator(rhp_fn, registration_name=None):
     """Registers & returns hp_fn with registration_name or default name."""
-    rhp_name = registration_name or _default_name(rhp_fn)
+    rhp_name = registration_name or default_name(rhp_fn)
     if rhp_name in _RANGED_HPARAMS:
       raise LookupError("RangedHParams set %s already registered."
                         % rhp_name)
     # Check that the fn takes a single argument
@@ -197,7 +196,7 @@ def decorator(rhp_fn, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     rhp_fn = name
-    return decorator(rhp_fn, registration_name=_default_name(rhp_fn))
+    return decorator(rhp_fn, registration_name=default_name(rhp_fn))
 
   return lambda rhp_fn: decorator(rhp_fn, name)
@@ -217,7 +216,7 @@ def register_problem(name=None):
 
   def decorator(p_cls, registration_name=None):
     """Registers & returns p_cls with registration_name or default name."""
-    p_name = registration_name or _default_name(p_cls)
+    p_name = registration_name or default_name(p_cls)
     if p_name in _PROBLEMS:
       raise LookupError("Problem %s already registered." % p_name)
 
@@ -228,7 +227,7 @@ def decorator(p_cls, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     p_cls = name
-    return decorator(p_cls, registration_name=_default_name(p_cls))
+    return decorator(p_cls, registration_name=default_name(p_cls))
 
   return lambda p_cls: decorator(p_cls, name)
@@ -313,7 +312,7 @@ def _internal_register_modality(name, mod_collection, collection_str):
 
   def decorator(mod_cls, registration_name=None):
     """Registers & returns mod_cls with registration_name or default name."""
-    mod_name = registration_name or _default_name(mod_cls)
+    mod_name = registration_name or default_name(mod_cls)
     if mod_name in mod_collection:
       raise LookupError("%s modality %s already registered." %
                         (collection_str, mod_name))
@@ -323,7 +322,7 @@ def decorator(mod_cls, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     mod_cls = name
-    return decorator(mod_cls, registration_name=_default_name(mod_cls))
+    return decorator(mod_cls, registration_name=default_name(mod_cls))
 
   return lambda mod_cls: decorator(mod_cls, name)
diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py
index 02c2b8a7d..186b4348f 100644
--- a/tensor2tensor/utils/t2t_model.py
+++ b/tensor2tensor/utils/t2t_model.py
@@ -34,6 +34,8 @@
 
 import tensorflow as tf
 
+from tensorflow.python.layers import base
+
 
 def _with_timing(fn, msg):
@@ -54,16 +56,17 @@ def is_class_modality(mod):
   return mod.name[:len(prefix)] == prefix
 
 
-class T2TModel(object):
+class T2TModel(base.Layer):
   """Abstract base class for models.
 
   Subclassess generally only need to override `build_model`.
   """
+  REGISTERED_NAME = None  # Updated on registration.
 
   def __init__(self,
                hparams,
                mode,
-               problem_hparams,
+               problem_hparams=None,
                problem_idx=0,
                data_parallelism=None,
                ps_devices=None,
@@ -83,18 +86,20 @@ def __init__(self,
     Returns:
       a T2TModel
     """
+    # Determine name first: use registered name if possible, class name else.
+    default_name = registry.default_name(type(self))
+    name = self.REGISTERED_NAME or default_name
+    super(T2TModel, self).__init__(
+        trainable=mode == tf.estimator.ModeKeys.TRAIN, name=name)
     if data_parallelism is None:
       data_parallelism = eu.Parallelism([""])
     if ps_devices is None:
       ps_devices = [""]
-    hparams = copy.copy(hparams)
-    hparams.add_hparam("mode", mode)
-    # When not in training mode, set all forms of dropout to zero.
-    if mode != tf.estimator.ModeKeys.TRAIN:
-      for key in hparams.values():
-        if key[-len("dropout"):] == "dropout":
-          setattr(hparams, key, 0.0)
+    if problem_hparams is None:
+      problem_hparams = hparams.problems[0]
+
+    # If vocabularies differ, unset shared_embedding_and_softmax_weights.
+    hparams = copy.copy(hparams)
     if hparams.shared_embedding_and_softmax_weights:
       same_vocab_sizes = True
       for problem in hparams.problems:
@@ -104,7 +109,8 @@ def __init__(self,
       if not same_vocab_sizes:
         tf.logging.info("Unsetting shared_embedding_and_softmax_weights.")
         hparams.shared_embedding_and_softmax_weights = 0
-    self._hparams = hparams
+    self._original_hparams = hparams
+    self.set_mode(mode)
     self._decode_hparams = copy.copy(decode_hparams)
     self._data_parallelism = data_parallelism
     self._num_datashards = data_parallelism.n
@@ -113,6 +119,17 @@ def __init__(self,
     self._problem_idx = problem_idx
     self._create_modalities(problem_hparams, hparams)
 
+  def set_mode(self, mode):
+    """Set hparams with the given mode."""
+    hparams = copy.copy(self._original_hparams)
+    hparams.add_hparam("mode", mode)
+    # When not in training mode, set all forms of dropout to zero.
+    if mode != tf.estimator.ModeKeys.TRAIN:
+      for key in hparams.values():
+        if key[-len("dropout"):] == "dropout":
+          setattr(hparams, key, 0.0)
+    self._hparams = hparams
+
   def _create_modalities(self, problem_hparams, hparams):
     """Construct modalities in problem_hparams."""
@@ -207,8 +224,8 @@ def infer(self,
         samples, _, _ = self._greedy_infer(features, decode_length)
     else:
       tf.logging.info("Beam Decoding with beam size %d" % beam_size)
-      samples = self._beam_decode(features, decode_length, beam_size, top_beams,
-                                  alpha)
+      samples = self._beam_decode(
+          features, decode_length, beam_size, top_beams, alpha)
     return samples
 
   def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
@@ -263,11 +280,10 @@ def symbols_to_logits_fn(ids):
       features["targets"] = ids
       self._coverage = None
-      sharded_logits, _ = self.model_fn(features, False)
+      logits, _ = self.__call__(features)
       # now self._coverage is a coverage tensor for the first datashard.
       # it has shape [batch_size] and contains floats between 0 and
       # source_length.
-      logits = sharded_logits[0]  # Assuming we have one shard.
       modality = self._hparams.problems[self._problem_idx].target_modality
       if modality.top_is_pointwise:
         return tf.squeeze(logits, axis=[1, 2, 3])
@@ -384,7 +400,7 @@ def infer_step(recent_output, recent_logits, unused_loss):
       samples.set_shape([None, None, None, 1])
 
       # Assuming we have one shard for logits.
-      logits = tf.concat([recent_logits, logits[0][:, -1:]], 1)
+      logits = tf.concat([recent_logits, logits[:, -1:]], 1)
       loss = sum([l for l in losses.values() if l is not None])
       return samples, logits, loss
@@ -477,13 +493,13 @@ def sample(self, features):
       logits: a list of `Tensor`s, one per datashard.
       losses: a dictionary: {loss-name (string): floating point `Scalar`}.
     """
-    sharded_logits, losses = self.model_fn(features, False)
+    logits, losses = self.__call__(features)
     if self._hparams.sampling_method == "argmax":
-      sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
+      samples = tf.argmax(logits, axis=-1)
     else:
       assert self._hparams.sampling_method == "random"
 
-      def _multinomial_squeeze(logits, temperature=1.0):
+      def multinomial_squeeze(logits, temperature=1.0):
         logits_shape = common_layers.shape_list(logits)
         reshaped_logits = (
             tf.reshape(logits, [-1, logits_shape[-1]]) / temperature)
@@ -491,9 +507,9 @@ def sample(self, features):
         choices = tf.reshape(choices, logits_shape[:-1])
         return choices
 
-      sharded_samples = self._data_parallelism(
-          _multinomial_squeeze, sharded_logits, self._hparams.sampling_temp)
-    return tf.concat(sharded_samples, 0), sharded_logits, losses
+      samples = multinomial_squeeze(logits, self._hparams.sampling_temp)
+
+    return samples, logits, losses
 
   def _shard_features(self, features):  # pylint: disable=missing-docstring
     sharded_features = dict()
@@ -502,13 +518,12 @@ def _shard_features(self, features):  # pylint: disable=missing-docstring
       if not v.shape.as_list():
         v = tf.expand_dims(v, axis=-1)
         v = tf.tile(v, [self._num_datashards])
-      sharded_features[k] = self._data_parallelism(tf.identity,
-                                                   tf.split(
-                                                       v, self._num_datashards,
-                                                       0))
+      sharded_features[k] = self._data_parallelism(
+          tf.identity,
+          tf.split(v, self._num_datashards, 0))
     return sharded_features
 
-  def model_fn(self, features, skip=False, force_full_predict=False):
+  def _model_fn(self, features, skip=False, force_full_predict=False):
     """Computes the entire model and produces sharded logits and losses.
 
     Args:
@@ -662,6 +677,21 @@ def sampled_results():
     tf.logging.info("This model_fn took %.3f sec." % (time.time() - start_time))
     return sharded_logits, losses
 
+  def call(self, inputs_dict, skip=False, force_full_predict=False):
+    problem_hparams = self._problem_hparams
+    if "problem_choice" not in inputs_dict:
+      inputs_dict["problem_choice"] = tf.constant(
+          self._problem_idx, name="problem_choice")
+    if "input_space_id" not in inputs_dict:
+      inputs_dict["input_space_id"] = tf.constant(
+          problem_hparams.input_space_id, name="input_space_id")
+    if "target_space_id" not in inputs_dict:
+      inputs_dict["target_space_id"] = tf.constant(
+          problem_hparams.target_space_id, name="target_space_id")
+    sharded_logits, losses = self._model_fn(
+        inputs_dict, skip=skip, force_full_predict=force_full_predict)
+    return tf.concat(sharded_logits, 0), losses
+
   def model_fn_body_sharded(self, sharded_features):
     """Mixture-of-experts models will override this function.
diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py
index d8dee3986..bd7367766 100644
--- a/tensor2tensor/utils/trainer_utils_test.py
+++ b/tensor2tensor/utils/trainer_utils_test.py
@@ -124,9 +124,9 @@ def testSingleEvalStepRawSession(self):
     features = {
         "inputs": batch_inputs,
         "targets": batch_targets,
-        "problem_choice": 0,  # We run on the first problem here.
-        "input_space_id": hparams.problems[0].input_space_id,
-        "target_space_id": hparams.problems[0].target_space_id
+        "problem_choice": tf.constant(0),  # We run on the first problem here.
+        "input_space_id": tf.constant(hparams.problems[0].input_space_id),
+        "target_space_id": tf.constant(hparams.problems[0].target_space_id)
     }
 
     # Now set a mode and create the graph by invoking model_fn.
@@ -153,6 +153,56 @@ def testSingleEvalStepRawSession(self):
       # where, for us, batch = 1, length = 3, vocab_size = 4.
       self.assertEqual(np_predictions.shape, (1, 3, 4))
 
+  def testSingleTrainStepCall(self):
+    """Illustrate how to run a T2T model in a raw session."""
+
+    # Set model name, hparams, problems as would be set on command line.
+    model_name = "transformer"
+    FLAGS.hparams_set = "transformer_test"
+    FLAGS.problems = "tiny_algo"
+    data_dir = "/tmp"  # Used only when a vocab file or such like is needed.
+
+    # Create the problem object, hparams, placeholders, features dict.
+    encoders = registry.problem(FLAGS.problems).feature_encoders(data_dir)
+    hparams = trainer_utils.create_hparams(FLAGS.hparams_set, data_dir)
+    trainer_utils.add_problem_hparams(hparams, FLAGS.problems)
+
+    # Now set a mode and create the model.
+    mode = tf.estimator.ModeKeys.TRAIN
+    model = registry.model(model_name)(hparams, mode)
+
+    # Create placeholder for features and make them batch-sized.
+    inputs_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
+    batch_inputs = tf.reshape(inputs_ph, [1, -1, 1, 1])  # Make it 4D.
+    targets_ph = tf.placeholder(dtype=tf.int32)  # Just length dimension.
+    batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])  # Make it 4D.
+    features = {
+        "inputs": batch_inputs,
+        "targets": batch_targets,
+        "target_space_id": tf.constant(hparams.problems[0].target_space_id)
+    }
+
+    # Call the model.
+    predictions, _ = model(features)
+    nvars = len(tf.trainable_variables())
+    model(features)  # Call again and check that reuse works.
+    self.assertEqual(nvars, len(tf.trainable_variables()))
+
+    # Having the graph, let's run it on some data.
+    with self.test_session() as sess:
+      sess.run(tf.global_variables_initializer())
+      inputs = "0 1 0"
+      targets = "0 1 0"
+      # Encode from raw string to numpy input array using problem encoders.
+      inputs_numpy = encoders["inputs"].encode(inputs)
+      targets_numpy = encoders["targets"].encode(targets)
+      # Feed the encoded inputs and targets and run session.
+      feed = {inputs_ph: inputs_numpy, targets_ph: targets_numpy}
+      np_predictions = sess.run(predictions, feed)
+      # Check that the result has the correct shape: batch x length x vocab_size
+      # where, for us, batch = 1, length = 3, vocab_size = 4.
+      self.assertEqual(np_predictions.shape, (1, 3, 1, 1, 4))
+
 
 if __name__ == "__main__":
   tf.test.main()
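
The new calling convention, distilled from transformer_test.py and the
testSingleTrainStepCall test above. This is a minimal sketch, not part of the
diff: the names "transformer", "transformer_test" and "tiny_algo" are taken
from those tests, and everything else follows the patched API.

    # Minimal sketch of post-patch T2TModel usage (assumptions noted above).
    import tensorflow as tf

    from tensor2tensor.utils import registry
    from tensor2tensor.utils import trainer_utils

    hparams = trainer_utils.create_hparams("transformer_test", "/tmp")
    trainer_utils.add_problem_hparams(hparams, "tiny_algo")

    # T2TModel is now a Layer subclass: construct it once, then call it.
    # problem_hparams may be omitted; it defaults to hparams.problems[0].
    model = registry.model("transformer")(hparams, tf.estimator.ModeKeys.TRAIN)

    features = {
        "inputs": tf.placeholder(tf.int32, shape=[1, None, 1, 1]),
        "targets": tf.placeholder(tf.int32, shape=[1, None, 1, 1]),
    }
    # call() fills in problem_choice / input_space_id / target_space_id when
    # they are missing and returns concatenated logits plus a losses dict;
    # the old model_fn returned per-shard logits that callers had to concat.
    logits, losses = model(features)
    logits_again, _ = model(features)  # Second call reuses the variables.

    # Switch modes in place (e.g. before decoding) instead of rebuilding:
    model.set_mode(tf.estimator.ModeKeys.PREDICT)

Because the model is now a Layer, every variable it creates lives under the
model-name scope ("transformer/..." in this sketch), which is why the commit
is marked CHECKPOINT BREAKING: variable names in checkpoints written before
this change no longer match.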