This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit 01b8c31

CHECKPOINT BREAKING: make T2TModel a subclass of Layer so it can be called; all variables are now in model-name scope.

PiperOrigin-RevId: 176407831
Lukasz Kaiser authored and Ryan Sepassi committed Nov 29, 2017
1 parent c0ce3dd commit 01b8c31
Showing 18 changed files with 155 additions and 91 deletions.
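
Before the per-file diffs, a minimal sketch of the pattern the commit message describes: a tf.layers.Layer subclass can be called like a function, and the variables it creates land under a scope derived from the layer's name, which is what makes old checkpoints incompatible. The class, variable, and scope names below are illustrative stand-ins, not T2TModel itself.

import tensorflow as tf  # TF 1.x

class ToyModel(tf.layers.Layer):
  # Illustrative Layer subclass; NOT the actual T2TModel.

  def build(self, input_shape):
    # Variables created here live under this layer's name scope,
    # e.g. "toy_model/w", which is why the commit is checkpoint-breaking.
    self._w = self.add_variable("w", shape=[int(input_shape[-1]), 8])
    super(ToyModel, self).build(input_shape)

  def call(self, inputs):
    return tf.matmul(inputs, self._w)

model = ToyModel(name="toy_model")
outputs = model(tf.zeros([2, 4]))          # callable, like the new T2TModel
print([v.name for v in model.variables])   # expect something like ['toy_model/w:0']
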
3 changes: 1 addition & 2 deletions tensor2tensor/models/bluenet_test.py
@@ -45,8 +45,7 @@ def testBlueNet(self):
       }
       model = bluenet.BlueNet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (3, 5, 1, 1, vocab_size))
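
The same substitution repeats in most of the test files below: model_fn used to return a list of per-datashard logits that each test concatenated, while calling the model now returns a single logits tensor. A rough stand-in for the before/after shapes (using the (3, 5, 1, 1, vocab_size) shape asserted above; these tensors are placeholders for the real model outputs):

import tensorflow as tf  # TF 1.x

# Old interface (stand-in): per-shard logits concatenated along the batch axis.
sharded_logits = [tf.zeros([1, 5, 1, 1, 9]) for _ in range(3)]
old_logits = tf.concat(sharded_logits, 0)   # shape (3, 5, 1, 1, 9)

# New interface (stand-in): one logits tensor straight from model(features),
# so the tf.concat step disappears from each test.
new_logits = tf.zeros([3, 5, 1, 1, 9])

assert old_logits.shape.as_list() == new_logits.shape.as_list()
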
3 changes: 1 addition & 2 deletions tensor2tensor/models/bytenet_test.py
@@ -44,8 +44,7 @@ def testByteNet(self):
       }
       model = bytenet.ByteNet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (3, 50, 1, 1, vocab_size))
5 changes: 2 additions & 3 deletions tensor2tensor/models/gene_expression_test.py
@@ -55,9 +55,8 @@ def _testModel(self, hparams, model_cls):
         "targets": tf.constant(targets, dtype=tf.float32),
     }
     p_hparams, = hparams.problems
-    sharded_logits, _ = model_cls(hparams, tf.estimator.ModeKeys.TRAIN,
-                                  p_hparams).model_fn(features)
-    logits = tf.concat(sharded_logits, 0)
+    logits, _ = model_cls(
+        hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)(features)

     with self.test_session() as sess:
       sess.run(tf.global_variables_initializer())
6 changes: 2 additions & 4 deletions tensor2tensor/models/lstm_test.py
@@ -44,8 +44,7 @@ def testLSTMSeq2Seq(self):
       }
       model = lstm.LSTMSeq2seq(hparams, tf.estimator.ModeKeys.TRAIN,
                                p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
@@ -67,8 +66,7 @@ def testLSTMSeq2SeqAttention(self):
       }
       model = lstm.LSTMSeq2seqAttention(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (3, 6, 1, 1, vocab_size))
3 changes: 1 addition & 2 deletions tensor2tensor/models/multimodel_test.py
@@ -48,8 +48,7 @@ def testMultiModel(self):
       }
       model = multimodel.MultiModel(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (3, 1, 1, 1, 10))
3 changes: 1 addition & 2 deletions tensor2tensor/models/neural_gpu_test.py
@@ -52,8 +52,7 @@ def testNeuralGPU(self):
       }
       model = neural_gpu.NeuralGPU(hparams, tf.estimator.ModeKeys.TRAIN,
                                    p_hparams)
-      shadred_logits, _ = model.model_fn(features)
-      logits = tf.concat(shadred_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (batch_size, target_length, 1, 1,
3 changes: 1 addition & 2 deletions tensor2tensor/models/resnet_test.py
@@ -56,8 +56,7 @@ def _testResnet(self, img_size, output_size):
           "targets": tf.constant(y, dtype=tf.int32),
       }
       model = resnet.Resnet50(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (batch_size,) + output_size + (1, vocab_size))
3 changes: 1 addition & 2 deletions tensor2tensor/models/slicenet_test.py
@@ -49,8 +49,7 @@ def testSliceNet(self):
       }
       model = slicenet.SliceNet(hparams, tf.estimator.ModeKeys.TRAIN,
                                 p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (3, 1, 1, 1, 10))
8 changes: 5 additions & 3 deletions tensor2tensor/models/transformer.py
@@ -158,7 +158,8 @@ def _greedy_infer(self, features, decode_length):
     Raises:
       NotImplementedError: If there are multiple data shards.
     """
-    decoded_ids, _ = self._fast_decode(features, decode_length)
+    with tf.variable_scope(self.name):
+      decoded_ids, _ = self._fast_decode(features, decode_length)
     return decoded_ids, None, None

   def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
@@ -175,8 +176,9 @@ def _beam_decode(self, features, decode_length, beam_size, top_beams, alpha):
     Returns:
       samples: an integer `Tensor`. Top samples from the beam search
     """
-    decoded_ids, scores = self._fast_decode(features, decode_length, beam_size,
-                                            top_beams, alpha)
+    with tf.variable_scope(self.name):
+      decoded_ids, scores = self._fast_decode(
+          features, decode_length, beam_size, top_beams, alpha)
     return {"outputs": decoded_ids, "scores": scores}

   def _fast_decode(self,
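
A reading of the new tf.variable_scope(self.name) wrappers: training builds all variables under the model-name scope via __call__, so the fast-decode paths have to re-enter that same scope to find them again. A minimal sketch of the scoping mechanism (the scope and variable names are illustrative, not the ones Transformer actually creates):

import tensorflow as tf  # TF 1.x

def decode_projection():
  # Stand-in for a variable the model creates while building its graph.
  return tf.get_variable("proj", shape=[4, 4])

# Training-style construction: variables land under the model-name scope.
with tf.variable_scope("transformer"):
  w_train = decode_projection()

# Decode-time code must enter the same scope with reuse, so it picks up
# "transformer/proj" instead of creating a fresh top-level "proj".
with tf.variable_scope("transformer", reuse=True):
  w_decode = decode_projection()

assert w_train.name == w_decode.name == "transformer/proj:0"
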
3 changes: 1 addition & 2 deletions tensor2tensor/models/transformer_revnet_test.py
@@ -59,8 +59,7 @@ def testTransformer(self):
       }
       model = transformer_revnet.TransformerRevnet(
           hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       grads = tf.gradients(
           tf.reduce_mean(logits), [features["inputs"]] + tf.global_variables())
       grads = [g for g in grads if g is not None]
26 changes: 11 additions & 15 deletions tensor2tensor/models/transformer_test.py
@@ -51,26 +51,24 @@ def getModel(self, hparams, mode=tf.estimator.ModeKeys.TRAIN):
     targets = -1 + np.random.random_integers(
         VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))
     features = {
-        "inputs": tf.constant(inputs, dtype=tf.int32),
-        "targets": tf.constant(targets, dtype=tf.int32),
-        "target_space_id": tf.constant(1, dtype=tf.int32),
+        "inputs": tf.constant(inputs, dtype=tf.int32, name="inputs"),
+        "targets": tf.constant(targets, dtype=tf.int32, name="targets"),
+        "target_space_id": tf.constant(1, dtype=tf.int32)
     }

     return transformer.Transformer(hparams, mode, p_hparams), features

   def testTransformer(self):
     model, features = self.getModel(transformer.transformer_small())
-    shadred_logits, _ = model.model_fn(features)
-    logits = tf.concat(shadred_logits, 0)
+    logits, _ = model(features)
     with self.test_session() as session:
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE))

   def testTransformerRelative(self):
     model, features = self.getModel(transformer.transformer_relative_tiny())
-    shadred_logits, _ = model.model_fn(features)
-    logits = tf.concat(shadred_logits, 0)
+    logits, _ = model(features)
     with self.test_session() as session:
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
@@ -81,8 +79,8 @@ def testGreedyVsFast(self):

     decode_length = 2

-    out_logits, _ = model.model_fn(features)
-    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    out_logits, _ = model(features)
+    out_logits = tf.squeeze(out_logits, axis=[2, 3])
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
         labels=tf.reshape(features["targets"], [-1]))
@@ -94,8 +92,7 @@ def testGreedyVsFast(self):
     for _ in range(100):
       apply_grad.run()

-    model, _ = self.getModel(transformer.transformer_small(),
-                             mode=tf.estimator.ModeKeys.PREDICT)
+    model.set_mode(tf.estimator.ModeKeys.PREDICT)

     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
       greedy_result, _, _ = model._slow_greedy_infer(features, decode_length)
@@ -115,8 +112,8 @@ def testBeamVsFast(self):

     decode_length = 2

-    out_logits, _ = model.model_fn(features)
-    out_logits = tf.squeeze(out_logits[0], axis=[2, 3])
+    out_logits, _ = model(features)
+    out_logits = tf.squeeze(out_logits, axis=[2, 3])
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         logits=tf.reshape(out_logits, [-1, VOCAB_SIZE]),
         labels=tf.reshape(features["targets"], [-1]))
@@ -128,8 +125,7 @@ def testBeamVsFast(self):
     for _ in range(100):
       apply_grad.run()

-    model, _ = self.getModel(transformer.transformer_small(),
-                             mode=tf.estimator.ModeKeys.PREDICT)
+    model.set_mode(tf.estimator.ModeKeys.PREDICT)

     with tf.variable_scope(tf.get_variable_scope(), reuse=True):
       beam_result = model._beam_decode_slow(
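
The tests also stop building a second Transformer for prediction and instead flip the existing instance with set_mode, which keeps the already-trained variables attached to the same object. The toy class below only illustrates that idea; it is not T2TModel.set_mode, whose actual behavior lives in t2t_model.py:

import tensorflow as tf  # TF 1.x

class ToyModel(object):
  # Toy illustration of keeping one instance and switching modes; NOT T2TModel.

  def __init__(self, mode):
    self._mode = mode
    self._w = tf.get_variable("toy_w", shape=[4, 4])

  def set_mode(self, mode):
    self._mode = mode   # variables are untouched, only behavior changes

  def __call__(self, x):
    h = tf.matmul(x, self._w)
    if self._mode == tf.estimator.ModeKeys.TRAIN:
      h = tf.nn.dropout(h, keep_prob=0.9)
    return h

x = tf.zeros([2, 4])
model = ToyModel(tf.estimator.ModeKeys.TRAIN)
train_out = model(x)
model.set_mode(tf.estimator.ModeKeys.PREDICT)   # same weights, no dropout
predict_out = model(x)
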
6 changes: 3 additions & 3 deletions tensor2tensor/models/transformer_vae.py
@@ -654,9 +654,9 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1,
         dtype=tf.int64)

     features["targets"] = initial_output
-    sharded_logits, _ = self.model_fn(features, False, force_full_predict=True)
-    sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4)
-    samples = tf.concat(sharded_samples, 0)
+    logits, _ = self.__call__(
+        features, skip=False, force_full_predict=True)
+    samples = tf.argmax(logits, axis=-1)

     if inputs_old is not None:  # Restore to not confuse Estimator.
       features["inputs"] = inputs_old
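
With a single logits tensor, the sharded argmax-and-concat in infer collapses to one tf.argmax over the vocabulary axis. A small self-contained sketch with illustrative shapes (batch 3, length 5, vocab 9, mirroring the test shapes above):

import numpy as np
import tensorflow as tf  # TF 1.x

# Logits shaped [batch, length, 1, 1, vocab_size].
logits = tf.constant(np.random.randn(3, 5, 1, 1, 9), dtype=tf.float32)

# New path: one argmax over the last (vocab) axis yields the sampled ids.
samples = tf.argmax(logits, axis=-1)   # shape (3, 5, 1, 1), dtype int64

with tf.Session() as sess:
  print(sess.run(samples).shape)       # (3, 5, 1, 1)
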
3 changes: 1 addition & 2 deletions tensor2tensor/models/xception_test.py
@@ -48,8 +48,7 @@ def _testXception(self, img_size, output_size):
           "targets": tf.constant(y, dtype=tf.int32),
       }
       model = xception.Xception(hparams, tf.estimator.ModeKeys.TRAIN, p_hparams)
-      sharded_logits, _ = model.model_fn(features)
-      logits = tf.concat(sharded_logits, 0)
+      logits, _ = model(features)
       session.run(tf.global_variables_initializer())
       res = session.run(logits)
       self.assertEqual(res.shape, output_size + (1, vocab_size))
4 changes: 1 addition & 3 deletions tensor2tensor/tpu/tpu_trainer_lib.py
@@ -209,7 +209,6 @@ def t2t_model_fn(model_name,
     EstimatorSpec or TPUEstimatorSpec
   """
   _create_dummy_vars()
-
   hparams = copy.deepcopy(hparams)
   problem = hparams.problem_instances[0]
   problem_hp = hparams.problems[0]
@@ -224,10 +223,9 @@
       if use_tpu else create_data_parallelism(**config.t2t_device_info))
   model = registry.model(model_name)(
       hparams, mode, problem_hp, data_parallelism=data_parallelism)
-  sharded_logits, losses_dict = model.model_fn(features)
+  logits, losses_dict = model(features)

   # Set known shapes
-  logits = tf.concat(sharded_logits, 0)
   shape = logits.get_shape().as_list()
   if shape[0] is None:
     shape[0] = _get_batch_size(params, hparams, config)
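
The TPU path keeps its "set known shapes" step, now applied to the single logits tensor: TPU graphs need fully static shapes, so an unknown batch dimension is filled in before set_shape. A reduced sketch of that idiom (the literal 8 stands in for _get_batch_size(params, hparams, config), and the other dimensions are made up):

import tensorflow as tf  # TF 1.x

# A logits tensor whose batch dimension is unknown at graph-build time.
logits = tf.placeholder(tf.float32, shape=[None, 50, 1, 1, 256])

shape = logits.get_shape().as_list()
if shape[0] is None:
  shape[0] = 8            # stand-in for _get_batch_size(params, hparams, config)
logits.set_shape(shape)   # now fully static, as TPU compilation requires

print(logits.get_shape())  # (8, 50, 1, 1, 256)
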
2 changes: 1 addition & 1 deletion tensor2tensor/utils/model_builder.py
@@ -127,7 +127,7 @@ def nth_model(n):
     if eval_run_autoregressive and mode == tf.estimator.ModeKeys.EVAL:
       sharded_logits, losses_dict = model_class.eval_autoregressive(features)
     else:
-      sharded_logits, losses_dict = model_class.model_fn(
+      sharded_logits, losses_dict = model_class(
           features, skip=(skipping_is_on and skip_this_one))
     with tf.variable_scope("losses_avg"):
       total_loss, ops = 0.0, []
27 changes: 13 additions & 14 deletions tensor2tensor/utils/registry.py
@@ -90,7 +90,7 @@ def _reset():
     ctr.clear()


-def _default_name(obj_class):
+def default_name(obj_class):
   """Convert a class name to the registry's default name for the class.

   Args:
@@ -99,7 +99,6 @@ def _default_name(obj_class):
   Returns:
     The registry's default name for the class.
   """
-
   return _convert_camel_to_snake(obj_class.__name__)


@@ -112,25 +111,25 @@ def default_object_name(obj):
   Returns:
     The registry's default name for the class of the object.
   """
-
-  return _default_name(obj.__class__)
+  return default_name(obj.__class__)


 def register_model(name=None):
   """Register a model. name defaults to class name snake-cased."""

   def decorator(model_cls, registration_name=None):
     """Registers & returns model_cls with registration_name or default name."""
-    model_name = registration_name or _default_name(model_cls)
+    model_name = registration_name or default_name(model_cls)
     if model_name in _MODELS:
       raise LookupError("Model %s already registered." % model_name)
+    model_cls.REGISTERED_NAME = property(lambda _: model_name)
     _MODELS[model_name] = model_cls
     return model_cls

   # Handle if decorator was used without parens
   if callable(name):
     model_cls = name
-    return decorator(model_cls, registration_name=_default_name(model_cls))
+    return decorator(model_cls, registration_name=default_name(model_cls))

   return lambda model_cls: decorator(model_cls, name)

@@ -150,7 +149,7 @@ def register_hparams(name=None):

   def decorator(hp_fn, registration_name=None):
     """Registers & returns hp_fn with registration_name or default name."""
-    hp_name = registration_name or _default_name(hp_fn)
+    hp_name = registration_name or default_name(hp_fn)
     if hp_name in _HPARAMS:
       raise LookupError("HParams set %s already registered." % hp_name)
     _HPARAMS[hp_name] = hp_fn
@@ -159,7 +158,7 @@ def decorator(hp_fn, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     hp_fn = name
-    return decorator(hp_fn, registration_name=_default_name(hp_fn))
+    return decorator(hp_fn, registration_name=default_name(hp_fn))

   return lambda hp_fn: decorator(hp_fn, name)

@@ -182,7 +181,7 @@ def register_ranged_hparams(name=None):

   def decorator(rhp_fn, registration_name=None):
     """Registers & returns hp_fn with registration_name or default name."""
-    rhp_name = registration_name or _default_name(rhp_fn)
+    rhp_name = registration_name or default_name(rhp_fn)
     if rhp_name in _RANGED_HPARAMS:
       raise LookupError("RangedHParams set %s already registered." % rhp_name)
     # Check that the fn takes a single argument
@@ -197,7 +196,7 @@ def decorator(rhp_fn, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     rhp_fn = name
-    return decorator(rhp_fn, registration_name=_default_name(rhp_fn))
+    return decorator(rhp_fn, registration_name=default_name(rhp_fn))

   return lambda rhp_fn: decorator(rhp_fn, name)

@@ -217,7 +216,7 @@ def register_problem(name=None):

   def decorator(p_cls, registration_name=None):
     """Registers & returns p_cls with registration_name or default name."""
-    p_name = registration_name or _default_name(p_cls)
+    p_name = registration_name or default_name(p_cls)
     if p_name in _PROBLEMS:
       raise LookupError("Problem %s already registered." % p_name)

@@ -228,7 +227,7 @@ def decorator(p_cls, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     p_cls = name
-    return decorator(p_cls, registration_name=_default_name(p_cls))
+    return decorator(p_cls, registration_name=default_name(p_cls))

   return lambda p_cls: decorator(p_cls, name)

@@ -313,7 +312,7 @@ def _internal_register_modality(name, mod_collection, collection_str):

   def decorator(mod_cls, registration_name=None):
     """Registers & returns mod_cls with registration_name or default name."""
-    mod_name = registration_name or _default_name(mod_cls)
+    mod_name = registration_name or default_name(mod_cls)
     if mod_name in mod_collection:
       raise LookupError("%s modality %s already registered." % (collection_str,
                                                                 mod_name))
@@ -323,7 +322,7 @@ def decorator(mod_cls, registration_name=None):
   # Handle if decorator was used without parens
   if callable(name):
     mod_cls = name
-    return decorator(mod_cls, registration_name=_default_name(mod_cls))
+    return decorator(mod_cls, registration_name=default_name(mod_cls))

   return lambda mod_cls: decorator(mod_cls, name)

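Dropping the leading underscore makes default_name part of the registry's public surface: it turns a class name into the snake_cased key a model is registered under, and with this commit that name also becomes the model's variable scope. The conversion itself still lives in _convert_camel_to_snake; the function below is only an illustrative re-implementation of that kind of camel-to-snake conversion, and the library's exact regexes may differ:

import re

def camel_to_snake(name):
  # Insert "_" before each capital that starts a new word, then lowercase.
  s1 = re.sub(r"(.)([A-Z][a-z0-9]+)", r"\1_\2", name)
  return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s1).lower()

print(camel_to_snake("BlueNet"))      # blue_net
print(camel_to_snake("Transformer"))  # transformer
print(camel_to_snake("Resnet50"))     # resnet50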

