
Commit

Add Transformer api.
test=develop
guoshengCS committed Aug 20, 2020
1 parent 87e1106 commit c50ad43
Showing 1 changed file with 156 additions and 22 deletions: python/paddle/nn/layer/transformer.py
@@ -75,7 +75,7 @@ class MultiheadAttention(Layer):
weights to drop some attention targets. 0 for no dropout. Default 0
kdim (int, optional): The feature size in key. If None, assumed equal to
`embed_dim`. Default None.
vdim (int, optional): The feature size in key. If None, assumed equal to
vdim (int, optional): The feature size in value. If None, assumed equal to
`embed_dim`. Default None.
need_weights (bool, optional): Indicate whether to return the attention
weights. Default False.
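Since this hunk corrects the `kdim`/`vdim` documentation, a minimal usage sketch may help; it assumes `MultiheadAttention` is importable from the module being changed here and that `kdim`/`vdim` are accepted as keyword arguments, as the docstring describes:

import paddle
from paddle.nn.layer.transformer import MultiheadAttention

# query carries the model feature size; key/value may use their own feature
# sizes via kdim/vdim and are projected back to embed_dim inside the layer.
query = paddle.rand((2, 4, 128))   # [batch_size, query_len, embed_dim]
key = paddle.rand((2, 6, 64))      # [batch_size, key_len, kdim]
value = paddle.rand((2, 6, 32))    # [batch_size, key_len, vdim]

attn = MultiheadAttention(128, 2, kdim=64, vdim=32)
out = attn(query, key, value)      # [2, 4, 128]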
@@ -410,8 +410,11 @@ class TransformerEncoderLayer(Layer):
`dropout`. Default None
act_dropout (float, optional): The dropout probability used after FFN
activation. If None, use the value of `dropout`. Default None
act_dropout (float, optional): The dropout probability used after FFN
activation. If None, use the value of `dropout`. Default None
normalize_before (bool, optional): Indicate whether to put layer normalization
into the preprocessing of MHA and FFN sub-layers. If True, the pre-process is layer
normalization and the post-process includes dropout and residual connection.
Otherwise, there is no pre-process, and the post-process includes dropout, residual
connection, and layer normalization. Default False
param_attr(ParamAttr|tuple, optional): To specify the weight parameter property.
If it is a tuple, `param_attr[0]` would be used as `param_attr` for
MHA, and `param_attr[1]` would be used as `param_attr` for linear in FFN.
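To make the `normalize_before` switch added above concrete, a small sketch contrasting the two layouts follows; it relies on the positional constructor order `(d_model, nhead, dim_feedforward, ...)` used later in this diff and assumes `forward` accepts the encoder input with an optional mask:

import paddle
from paddle.nn.layer.transformer import TransformerEncoderLayer

enc_input = paddle.rand((2, 4, 128))  # [batch_size, src_len, d_model]

# Default (normalize_before=False): sub-layer -> dropout -> residual -> layer norm.
post_norm_layer = TransformerEncoderLayer(128, 2, 512)

# Pre-norm (normalize_before=True): layer norm -> sub-layer -> dropout -> residual.
pre_norm_layer = TransformerEncoderLayer(128, 2, 512, normalize_before=True)

out_post = post_norm_layer(enc_input)  # [2, 4, 128]
out_pre = pre_norm_layer(enc_input)    # [2, 4, 128]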
@@ -621,8 +624,11 @@ class TransformerDecoderLayer(Layer):
`dropout`. Default None
act_dropout (float, optional): The dropout probability used after FFN
activation. If None, use the value of `dropout`. Default None
act_dropout (float, optional): The dropout probability used after FFN
activation. If None, use the value of `dropout`. Default None
normalize_before (bool, optional): Indicate whether to put layer normalization
into the preprocessing of MHA and FFN sub-layers. If True, the pre-process is layer
normalization and the post-process includes dropout and residual connection.
Otherwise, there is no pre-process, and the post-process includes dropout, residual
connection, and layer normalization. Default False
param_attr(ParamAttr|tuple, optional): To specify the weight parameter property.
If it is a tuple, `param_attr[0]` would be used as `param_attr` for
self attention, `param_attr[1]` would be used as `param_attr` for
@@ -843,13 +849,13 @@ class TransformerDecoder(Layer):
import paddle
from paddle import TransformerDecoderLayer, TransformerDecoder
# decoder input: [batch_size, trg_len, d_model]
# decoder input: [batch_size, tgt_len, d_model]
dec_input = paddle.rand((2, 4, 128))
# encoder output: [batch_size, src_len, d_model]
enc_output = paddle.rand((2, 6, 128))
# self attention mask: [batch_size, n_head, trg_len, trg_len]
# self attention mask: [batch_size, n_head, tgt_len, tgt_len]
self_attn_mask = paddle.rand((2, 2, 4, 4))
# cross attention mask: [batch_size, n_head, trg_len, src_len]
# cross attention mask: [batch_size, n_head, tgt_len, src_len]
cross_attn_mask = paddle.rand((2, 2, 4, 6))
decoder_layer = TransformerDecoderLayer(128, 2, 512)
decoder = TransformerDecoder(decoder_layer, 2)
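The diff view truncates the example at this point; a hedged continuation of the forward call, inferred from the `forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None)` signature shown in the next hunk, would be:

output = decoder(dec_input,
                 enc_output,
                 tgt_mask=self_attn_mask,
                 memory_mask=cross_attn_mask)  # [2, 4, 128]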
@@ -923,8 +929,7 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
cache=cache[i]
if cache is not None else None)
cache=cache[i])
new_caches.append(new_cache)

if self.norm is not None:
@@ -947,13 +952,96 @@ def gen_cache(self, memory):
Returns:
list: It is a list, and each element in the list is a tuple produced \
by `TransformerDecoderLayer.gen_cache`. See `TransformerDecoderLayer.gen_cache` \
by `TransformerDecoderLayer.gen_cache(memory)`. See `TransformerDecoderLayer.gen_cache` \
for more details.
"""
return [layer.gen_cache(memory) for layer in self.layers]
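As a sketch of how `gen_cache` supports incremental decoding — assuming that, when `cache` is passed, `forward` returns an `(output, new_caches)` pair as the loop above suggests:

import paddle
from paddle.nn.layer.transformer import TransformerDecoderLayer, TransformerDecoder

decoder = TransformerDecoder(TransformerDecoderLayer(128, 2, 512), 2)
memory = paddle.rand((2, 6, 128))    # encoder output: [batch_size, src_len, d_model]

# One cache tuple per decoder layer, produced from the encoder output.
cache = decoder.gen_cache(memory)
tgt_step = paddle.rand((2, 1, 128))  # decode one target step at a time
output, cache = decoder(tgt_step, memory, cache=cache)  # output: [2, 1, 128]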


class Transformer(Layer):
"""
A Transformer model composed of an instance of `TransformerEncoder` and an
instance of `TransformerDecoder`. Note that the embedding layer and output
layer are not included.
Please refer to `Attention is all you need <http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf>`_ ,
and see `TransformerEncoder` and `TransformerDecoder` for more details.
Users can configure the model architecture with the corresponding parameters.
Note that `normalize_before` determines where layer normalization is applied
(in the pre-process or post-process of multi-head attention and FFN);
Transformer-like models such as
`BERT <https://arxiv.org/abs/1810.04805>`_ and `GPT2 <https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf>`_ differ on this.
The default architecture here places layer normalization in post-process and
applies another layer normalization on the output of the last encoder/decoder layer.
Parameters:
d_model (int): The expected feature size in the encoder/decoder input
and output.
nhead (int): The number of heads in multi-head attention(MHA).
num_encoder_layers (int): The number of layers in encoder.
num_decoder_layers (int): The number of layers in decoder.
dim_feedforward (int): The hidden layer size in the feedforward network(FFN).
dropout (float, optional): The dropout probability used in the pre-process
and post-process of the MHA and FFN sub-layers. Default 0.1
activation (str, optional): The activation function in the feedforward
network. Default relu.
attn_dropout (float, optional): The dropout probability used
in MHA to drop some attention targets. If None, use the value of
`dropout`. Default None
act_dropout (float, optional): The dropout probability used after FFN
activation. If None, use the value of `dropout`. Default None
normalize_before (bool, optional): Indicate whether to put layer normalization
into the preprocessing of MHA and FFN sub-layers. If True, the pre-process is layer
normalization and the post-process includes dropout and residual connection.
Otherwise, there is no pre-process, and the post-process includes dropout, residual
connection, and layer normalization. Default False
param_attr(ParamAttr|tuple, optional): To specify the weight parameter property.
If it is a tuple, `param_attr[0]` would be used as `param_attr` for
self attention, `param_attr[1]` would be used as `param_attr` for
cross attention, and `param_attr[2]` would be used as `param_attr`
for linear in FFN. Otherwise, the three sub-layers all use it as
`param_attr` to create parameters. Default: None, which means the
default weight parameter property is used. See usage for details
in :ref:`api_fluid_ParamAttr` .
bias_attr (ParamAttr|tuple, optional): To specify the bias parameter property.
If it is a tuple, `bias_attr[0]` would be used as `bias_attr` for
self attention, `bias_attr[1]` would be used as `bias_attr` for
cross attention, and `bias_attr[2]` would be used as `bias_attr`
for linear in FFN. Otherwise, the three sub-layers all use it as
`bias_attr` to create parameters. Default: None, which means the
default bias parameter property is used. See usage for details
in :ref:`api_fluid_ParamAttr` .
custom_encoder (Layer): If custom encoder is provided, use it as the encoder.
Default None
custom_decoder (Layer): If custom decoder is provided, use it as the decoder.
Default None
Examples:
.. code-block:: python
import paddle
from paddle import Transformer
# src: [batch_size, src_len, d_model]
enc_input = paddle.rand((2, 4, 128))
# tgt: [batch_size, tgt_len, d_model]
dec_input = paddle.rand((2, 6, 128))
# src_mask: [batch_size, n_head, src_len, src_len]
enc_self_attn_mask = paddle.rand((2, 2, 4, 4))
# tgt_mask: [batch_size, n_head, tgt_len, tgt_len]
dec_self_attn_mask = paddle.rand((2, 2, 6, 6))
# memory_mask: [batch_size, n_head, tgt_len, src_len]
cross_attn_mask = paddle.rand((2, 2, 6, 4))
transformer = Transformer(128, 2, 4, 4, 512)
output = transformer(enc_input,
dec_input,
enc_self_attn_mask,
dec_self_attn_mask,
cross_attn_mask) # [2, 6, 128]
"""

def __init__(self,
d_model=512,
nhead=8,
@@ -962,6 +1050,11 @@ def __init__(self,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
attn_dropout=None,
act_dropout=None,
normalize_before=False,
param_attr=None,
bias_attr=None,
custom_encoder=None,
custom_decoder=None):
super(Transformer, self).__init__()
@@ -970,7 +1063,9 @@ def __init__(self,
self.encoder = custom_encoder
else:
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation)
d_model, nhead, dim_feedforward, dropout, activation,
attn_dropout, act_dropout, normalize_before, param_attr,
bias_attr)
encoder_norm = LayerNorm(d_model)
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers,
encoder_norm)
@@ -979,27 +1074,66 @@ def __init__(self,
self.decoder = custom_decoder
else:
decoder_layer = TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation)
d_model, nhead, dim_feedforward, dropout, activation,
attn_dropout, act_dropout, normalize_before, param_attr,
bias_attr)
decoder_norm = LayerNorm(d_model)
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers,
decoder_norm)

self._reset_parameters()

self.d_model = d_model
self.nhead = nhead

def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
"""
Applies a Transformer model on the inputs.
Parameters:
src (Variable): The input of Transformer encoder. It is a tensor
with shape `[batch_size, source_length, d_model]`. The data type
should be float32 or float64.
tgt (Variable): The input of Transformer decoder. It is a tensor
with shape `[batch_size, target_length, d_model]`. The data type
should be float32 or float64.
src_mask (Variable, optional): A tensor used in multi-head attention
to prevent attention to some unwanted positions, usually the paddings.
It is a tensor with shape broadcasted to
`[batch_size, n_head, source_length, source_length]`, where the
unwanted positions have `-INF` values and the others have 0 values.
The data type should be float32 or float64. Default None
tgt_mask (Variable, optional): A tensor used in self attention
to prevent attention to some unwanted positions, usually the
subsequent positions. It is a tensor with shape broadcasted
to `[batch_size, n_head, target_length, target_length]`,
where the unwanted positions have `-INF` values and the others
have 0 values. The data type should be float32 or float64. It can
be None when nothing needs to be prevented from being attended to.
Default None
memory_mask (Variable, optional): A tensor used in decoder-encoder
cross attention to prevent attention to some unwanted positions,
usually the paddings. It is a tensor with shape broadcasted to
`[batch_size, n_head, target_length, source_length]`, where the
unwanted positions have `-INF` values and the others have 0 values.
The data type should be float32 or float64. It can be None when
nothing needs to be prevented from being attended to. Default None
Returns:
Variable: It is a tensor that has the same shape and data type \
as `tgt`, representing the output of Transformer decoder.
"""
memory = self.encoder(src, mask=src_mask)
output = self.decoder(
tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
return output
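Since the mask arguments are described as additive tensors with `-INF` at disallowed positions, here is a small sketch of building a causal `tgt_mask`, assuming `paddle.full` and `paddle.triu` are available in this version:

import paddle

# 0 for visible positions, -inf above the diagonal; broadcastable to
# [batch_size, n_head, target_length, target_length].
tgt_len = 6
tgt_mask = paddle.triu(
    paddle.full([tgt_len, tgt_len], float("-inf"), dtype="float32"),
    diagonal=1)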


class TransformerCell(Layer):
class TransformerDecoderCell(Layer):
"""
TransformerCell wraps a Transformer decoder producing logits from `inputs`
composed by ids and position.
TransformerDecoderCell wraps a Transformer decoder, producing logits from
`inputs` composed of ids and positions.
Parameters:
decoder(callable): A TransformerDecoder instance, or a wrapper of it that
includes an embedding layer accepting ids and positions instead of embeddings
and an output layer transforming decoder output features to logits.
embedding_fn(function, optional): A callable that accepts ids and position
embedding_fn(callable, optional): A callable that accepts ids and position
as arguments and returns embeddings as input of `decoder`. It can be
None if `decoder` includes an embedding layer. Default None.
output_fn(callable, optional): A callable applied on `decoder` output to
@@ -1045,7 +1179,7 @@ def forward(self, word, position):
is_test=True)
enc_output = paddle.rand((2, 4, 128))
# cross attention bias: [batch_size, n_head, trg_len, src_len]
# cross attention bias: [batch_size, n_head, tgt_len, src_len]
trg_src_attn_bias = paddle.rand((2, 2, 1, 4))
# inputs for beam search on Transformer
caches = transformer_cell.get_initial_states(enc_output)
Expand All @@ -1062,7 +1196,7 @@ def forward(self, word, position):
"""

def __init__(self, decoder, embedding_fn=None, output_fn=None):
super(TransformerCell, self).__init__()
super(TransformerDecoderCell, self).__init__()
self.decoder = decoder
self.embedding_fn = embedding_fn
self.output_fn = output_fn
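A sketch of wiring `embedding_fn` and `output_fn` into the renamed cell; the `paddle.nn.Embedding` and `paddle.nn.Linear` constructor signatures used for the helpers are assumptions about the 2.0-style API, and the vocabulary and position sizes are arbitrary:

import paddle
from paddle.nn.layer.transformer import (TransformerDecoderLayer,
                                         TransformerDecoder,
                                         TransformerDecoderCell)

decoder = TransformerDecoder(TransformerDecoderLayer(128, 2, 512), 2)

# Hypothetical helpers mapping ids/positions to embeddings and features to logits.
word_emb = paddle.nn.Embedding(1000, 128)  # vocab_size, d_model
pos_emb = paddle.nn.Embedding(256, 128)    # max_position, d_model

def embedding_fn(word, position):
    return word_emb(word) + pos_emb(position)

output_fn = paddle.nn.Linear(128, 1000)    # project decoder features to vocab logits

cell = TransformerDecoderCell(decoder, embedding_fn, output_fn)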
@@ -1212,7 +1346,7 @@ def forward(self, word, position):
is_test=True)
enc_output = paddle.rand((2, 4, 128))
# cross attention bias: [batch_size, n_head, trg_len, src_len]
# cross attention bias: [batch_size, n_head, tgt_len, src_len]
trg_src_attn_bias = paddle.rand((2, 2, 1, 4))
# inputs for beam search on Transformer
caches = transformer_cell.get_initial_states(enc_output)

1 comment on commit c50ad43

@paddle-bot-old


🕵️ CI failures summary

🔍 Commit ID: c50ad43 contains failed CI.
