
Commit b3cad0c
This change breaks previous checkpoints. Make Transformer fast on TPU.
PiperOrigin-RevId: 176784764
nshazeer authored and Ryan Sepassi committed Nov 29, 2017
1 parent b104292 commit b3cad0c
Showing 7 changed files with 27 additions and 85 deletions.

tensor2tensor/layers/common_attention.py (15 additions, 13 deletions)

@@ -801,7 +801,7 @@ def combine_first_two_dimensions(x):
 
 @expert_utils.add_name_scope()
 def split_heads(x, num_heads):
-  """Split channels (dimension 2) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, length, channels]
@@ -815,7 +815,7 @@ def split_heads(x, num_heads):
 
 @expert_utils.add_name_scope()
 def split_heads_2d(x, num_heads):
-  """Split channels (dimension 3) into multiple heads (becomes dimension 1).
+  """Split channels (dimension 4) into multiple heads (becomes dimension 1).
 
   Args:
     x: a Tensor with shape [batch, height, width, channels]
@@ -2191,10 +2191,10 @@ def compute_qkv(query_antecedent,
   """
   if memory_antecedent is None and q_filter_width == kv_filter_width == 1:
     # self attention with single position q, k, and v
-    combined = tf.layers.dense(
+    combined = common_layers.conv1d(
         query_antecedent,
         total_key_depth * 2 + total_value_depth,
-        use_bias=False,
+        1,
         name="qkv_transform")
     q, k, v = tf.split(
         combined, [total_key_depth, total_key_depth, total_value_depth], axis=2)
@@ -2250,19 +2250,22 @@ def compute_qkv_2d(query_antecedent, memory_antecedent, total_key_depth,
   """
   # self attention with single position q, k, and v
   if memory_antecedent is None:
-    combined = tf.layers.dense(
-        query_antecedent, total_key_depth * 2 + total_value_depth,
-        use_bias=False, name="qkv_transform")
+    combined = tf.layers.conv2d(
+        query_antecedent,
+        total_key_depth * 2 + total_value_depth, (1, 1),
+        name="qkv_transform")
     q, k, v = tf.split(
         combined, [total_key_depth, total_key_depth, total_value_depth],
         axis=-1)
     return q, k, v
 
   # Encoder decoder attention
-  q = tf.layers.dense(
-      query_antecedent, total_key_depth, use_bias=False, name="q_transform")
-  combined = tf.layers.dense(
-      memory_antecedent, total_key_depth + total_value_depth, use_bias=False,
+  q = common_layers.conv1d(
+      query_antecedent, total_key_depth, 1, name="q_transform")
+  combined = common_layers.conv1d(
+      memory_antecedent,
+      total_key_depth + total_value_depth,
+      1,
      name="kv_transform")
   k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2)
 
@@ -2407,8 +2410,7 @@ def multihead_attention(query_antecedent,
       x = dilated_self_attention_1d(q, k, v, block_length, block_width,
                                     gap_size, num_memory_blocks)
     x = combine_heads(x)
-    x = tf.layers.dense(
-        x, output_depth, use_bias=False, name="output_transform")
+    x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     if additional_returned_value is not None:
       return x, additional_returned_value
     return x
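
Note: the swap above replaces bias-free tf.layers.dense projections with kernel-size-1 convolutions (common_layers.conv1d, and tf.layers.conv2d in the 2-D case). Both apply the same per-position linear map, but they create variables with different shapes, which is presumably why the commit message warns that previous checkpoints break. A minimal NumPy sketch of the equivalence (illustrative only, not part of the diff):

import numpy as np

# A width-1 conv over [batch, length, channels] is just a per-position matmul,
# so conv1d(x, depth, 1) and a bias-free dense layer compute the same projection.
batch, length, channels, depth = 2, 5, 64, 128
x = np.random.randn(batch, length, channels)
w = np.random.randn(channels, depth)                    # dense kernel: [channels, depth]

dense_out = x @ w                                       # dense: matmul on the last axis
conv_out = np.einsum("blc,kcd->bld", x, w[None, ...])   # conv1d kernel: [1, channels, depth]

print(np.allclose(dense_out, conv_out))                 # True
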

tensor2tensor/layers/common_hparams.py (0 additions, 3 deletions)

@@ -179,9 +179,6 @@ def basic_params1():
       # This is the actual batch size, *not* tokens per batch (i.e. for
       # language models this is the number of sentences in the batch)
       tpu_batch_size_per_shard=24,
-      # Set by tpu_trainer to let the model know whether we are on TPU.
-      # Switching on/off tpu should not invalidate checkpoints.
-      use_tpu=False,
       # Things not compatible with eager mode use this flag to implement
       # alternative functionality. We expect this to go away soon.
       use_eager_mode=False,

tensor2tensor/layers/common_layers.py (4 additions, 21 deletions)

@@ -1231,15 +1231,6 @@ def relu_density_logit(x, reduce_dims):
   return scaled
 
 
-def conv_hidden_relu_simple(inputs, hidden_size, output_size, dropout=0.0):
-  h = tf.layers.dense(
-      inputs, hidden_size, use_bias=False, activation=tf.nn.relu, name="conv1")
-  if dropout != 0.0:
-    h = tf.nn.dropout(h, 1.0 - dropout)
-  o = tf.layers.dense(h, output_size, use_bias=False, name="conv2")
-  return o
-
-
 def conv_hidden_relu(inputs,
                      hidden_size,
                      output_size,
@@ -1250,9 +1241,6 @@ def conv_hidden_relu(inputs,
   """Hidden layer with RELU activation followed by linear projection."""
   name = kwargs.pop("name") if "name" in kwargs else None
   with tf.variable_scope(name, "conv_hidden_relu", [inputs]):
-    if kernel_size == (1, 1) and second_kernel_size == (1, 1):
-      return conv_hidden_relu_simple(
-          inputs, hidden_size, output_size, dropout=dropout)
     if inputs.get_shape().ndims == 3:
       is_3d = True
       inputs = tf.expand_dims(inputs, 2)
@@ -1501,15 +1489,10 @@ def padded_cross_entropy(logits,
   confidence = 1.0 - label_smoothing
   vocab_size = shape_list(logits)[-1]
   with tf.name_scope("padded_cross_entropy", [logits, labels]):
-    if len(logits.get_shape().as_list()) == 2:
-      # Deal with the case where we did not insert extra dimensions due to
-      # TPU issues. No pad-to-same-length happens in this case.
-      # TODO(noam): remove this logic once TPU can handle extra dimensions.
-      labels = tf.reshape(labels, [-1])
-    else:
-      logits, labels = pad_with_zeros(logits, labels)
-    xent = smoothing_cross_entropy(logits, labels, vocab_size, confidence)
-    weights = weights_fn(labels)
+    pad_logits, pad_labels = pad_with_zeros(logits, labels)
+    xent = smoothing_cross_entropy(pad_logits, pad_labels, vocab_size,
+                                   confidence)
+    weights = weights_fn(pad_labels)
     if not reduce_sum:
       return xent * weights, weights
     return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
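
Note: with the TPU special case gone, padded_cross_entropy always pads logits and labels to a common length, applies label-smoothed cross entropy, and zeroes out padding positions via the weights function. A rough NumPy sketch of that flow (my own illustration, using a simplified smoothing formula rather than the library's smoothing_cross_entropy):

import numpy as np

def smoothed_xent(logits, labels, vocab_size, confidence):
  # Soft targets: `confidence` on the true class, the remainder spread uniformly.
  low = (1.0 - confidence) / (vocab_size - 1)
  soft = np.full(logits.shape, low)
  soft[np.arange(len(labels)), labels] = confidence
  log_probs = logits - np.log(np.sum(np.exp(logits), -1, keepdims=True))
  return -np.sum(soft * log_probs, -1)

vocab_size, label_smoothing = 8, 0.1
logits = np.random.randn(5, vocab_size)
labels = np.array([3, 1, 4, 0, 0])           # trailing zeros stand in for padding
xent = smoothed_xent(logits, labels, vocab_size, 1.0 - label_smoothing)
weights = (labels != 0).astype(np.float32)   # padding gets zero weight
loss = np.sum(xent * weights) / np.sum(weights)
print(loss)
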

tensor2tensor/layers/modalities.py (6 additions, 20 deletions)

@@ -30,15 +30,6 @@
 import tensorflow as tf
 
 
-# TODO(noam): remove this function after TPUs do gather faster.
-def tpu_gather(params, indices):
-  vocab_size = params.get_shape().as_list()[0]
-  indices_flat = tf.reshape(indices, [-1])
-  out = tf.matmul(tf.one_hot(indices_flat, vocab_size), params)
-  out = eu.reshape_like(out, tf.expand_dims(indices, -1))
-  return out
-
-
 @registry.register_symbol_modality("default")
 class SymbolModality(modality.Modality):
   """Modality for sets of discrete symbols.
@@ -105,8 +96,7 @@ def bottom_simple(self, x, name, reuse):
       # Squeeze out the channels dimension.
       x = tf.squeeze(x, axis=3)
       var = self._get_weights()
-      ret = (tpu_gather(var, x) if self._model_hparams.use_tpu
-             else tf.gather(var, x))
+      ret = tf.gather(var, x)
       if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
         ret *= self._body_input_depth**0.5
       ret *= tf.expand_dims(tf.to_float(tf.not_equal(x, 0)), -1)
@@ -154,18 +144,14 @@ def top(self, body_output, _):
           self._model_hparams.mode == tf.estimator.ModeKeys.TRAIN):
         # insert channels dimension
         body_output = tf.expand_dims(body_output, 3)
-        return common_layers.FactoredTensor(body_output, var)
+        logits = common_layers.FactoredTensor(body_output, var)
       else:
         body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
         logits = tf.matmul(body_output, var, transpose_b=True)
-        if (self._model_hparams.use_tpu and
-            self._model_hparams.mode == tf.estimator.ModeKeys.TRAIN):
-          # TPU does not react kindly to extra dimensions.
-          # TODO(noam): remove this once TPU is more forgiving of extra dims.
-          return logits
-        else:
-          return tf.reshape(
-              logits, body_output_shape[:-1] + [1, self._vocab_size])
+
+      out_shape = body_output_shape[:-1] + [1, self._vocab_size]
+      logits = tf.reshape(logits, out_shape)
+      return logits
 
 
 @registry.register_symbol_modality("ctc")
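
Note: the deleted tpu_gather emulated an embedding lookup with a one-hot matmul, a workaround added while TPU gathers were slow (per its own TODO); bottom_simple now always uses tf.gather. A small NumPy sketch (my illustration, not library code) of why the two are equivalent:

import numpy as np

vocab_size, depth = 10, 4
params = np.random.randn(vocab_size, depth)        # embedding table
indices = np.array([[1, 7, 3], [0, 2, 9]])         # [batch, length]

gathered = params[indices]                         # what tf.gather(params, indices) returns
one_hot = np.eye(vocab_size)[indices.reshape(-1)]  # tf.one_hot(indices_flat, vocab_size)
via_matmul = (one_hot @ params).reshape(indices.shape + (depth,))

print(np.allclose(gathered, via_matmul))           # True
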

tensor2tensor/layers/modalities_test.py (0 additions, 3 deletions)

@@ -43,7 +43,6 @@ def testSymbolModalityInputs(self):
         symbol_modality_skip_top=0,
         shared_embedding_and_softmax_weights=0,
         prepend_mode="none",
-        use_tpu=False,
         use_eager_mode=False)
     x = -1 + np.random.random_integers(
         vocab_size, size=(batch_size, length, 1, 1))
@@ -74,7 +73,6 @@ def testSymbolModalityTargets(self):
         factored_logits=0,
         mode=tf.estimator.ModeKeys.TRAIN,
         prepend_mode="none",
-        use_tpu=False,
         use_eager_mode=False)
     body_output = -1 + np.random.random_integers(
         100, size=(batch_size, length, height, hidden_size))
@@ -112,7 +110,6 @@ def testSymbolModalityTargetsFactored(self):
         factored_logits=1,
         mode=tf.estimator.ModeKeys.TRAIN,
         prepend_mode="none",
-        use_tpu=False,
         use_eager_mode=False)
     body_output = -1 + np.random.random_integers(
         100, size=(batch_size, length, height, hidden_size))

tensor2tensor/models/transformer.py (2 additions, 24 deletions)

@@ -108,13 +108,8 @@ def decode(self,
         hparams,
         cache=cache)
 
-    if hparams.use_tpu and hparams.mode == tf.estimator.ModeKeys.TRAIN:
-      # TPU does not react kindly to extra dimensions.
-      # TODO(noam): remove this once TPU is more forgiving of extra dims.
-      return decoder_output
-    else:
-      # Expand since t2t expects 4d tensors.
-      return tf.expand_dims(decoder_output, axis=2)
+    # Expand since t2t expects 4d tensors.
+    return tf.expand_dims(decoder_output, axis=2)
 
   def model_fn_body(self, features):
     """Transformer main model_fn.
@@ -1119,20 +1114,3 @@ def transformer_clean_big():
   hparams.hidden_size = 1024
   hparams.filter_size = 4096
   return hparams
-
-
-@registry.register_hparams
-def transformer_tpu_lm1b():
-  """Hparams for training languagemodel_lm1b8k_concat on tpu."""
-  hparams = transformer_clean()
-  update_hparams_for_tpu(hparams)
-  hparams.max_length = 512
-  hparams.tpu_batch_size_per_shard = 8
-  hparams.hidden_size = 1024
-  hparams.filter_size = 4096
-  hparams.num_heads = 4
-  hparams.label_smoothing = 0.0
-  hparams.layer_prepostprocess_dropout = 0.0
-  hparams.attention_dropout = 0.0
-  hparams.relu_dropout = 0.0
-  return hparams
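
Note: the decode change drops the TPU-only branch, so the decoder output is always expanded back to the 4-D layout the rest of T2T expects (an extra size-1 axis at position 2). A trivial shape sketch (my illustration of the layout, not code from the diff):

import numpy as np

decoder_output = np.zeros((2, 7, 512))             # [batch, length, hidden_size]
expanded = np.expand_dims(decoder_output, axis=2)  # insert the dummy axis t2t expects
print(expanded.shape)                              # (2, 7, 1, 512)
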

tensor2tensor/tpu/tpu_trainer_lib.py (0 additions, 1 deletion)

@@ -212,7 +212,6 @@ def t2t_model_fn(model_name,
   hparams = copy.deepcopy(hparams)
   problem = hparams.problem_instances[0]
   problem_hp = hparams.problems[0]
-  hparams.use_tpu = use_tpu
 
   features["problem_choice"] = tf.constant(0)
   features["input_space_id"] = tf.constant(problem_hp.input_space_id)

0 comments on commit b3cad0c
