
Commit

Fix some bugs in Transformer apis.
test=develop
guoshengCS committed Aug 20, 2020
1 parent 54e9e56 commit 8637eee
Showing 1 changed file with 43 additions and 35 deletions.
78 changes: 43 additions & 35 deletions python/paddle/nn/layer/transformer.py
@@ -19,6 +19,7 @@
import collections

import numpy as np

from ...fluid import layers
from ...fluid.param_attr import ParamAttr
from ...fluid.dygraph import Layer, Linear, Dropout, LayerNorm, LayerList
@@ -230,7 +231,7 @@ def cal_kv(self, key, value):
        v = layers.transpose(x=v, perm=[0, 2, 1, 3])
        return k, v

-    def gen_cache(self, key, value=None, type=Cache):
+    def gen_cache(self, key, value=None, type=MultiheadAttention.Cache):
"""
Generates cache for `forward` usage in inference accroding to arguments.
The generated cache is an instance of `MultiheadAttention.Cache` or an
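A note on the changed default above: Python evaluates default argument values while the class body is still executing, which constrains how a nested class like `Cache` can be referenced in a method signature. A minimal standalone sketch of the rule, outside the diff (hypothetical `Outer`/`Cache` names, not Paddle code):

class Outer(object):
    class Cache(object):
        pass

    # The default is evaluated when `def` runs, i.e. while the class body
    # is still executing. `Cache` is already bound in the class body's
    # namespace at that point, so the bare reference resolves.
    def gen(self, type=Cache):
        return type()

    # By contrast, `def gen(self, type=Outer.Cache)` would raise NameError
    # at import time, because `Outer` is only bound once the class body
    # has finished executing.

print(isinstance(Outer().gen(), Outer.Cache))  # True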
@@ -384,7 +385,7 @@ def forward(self, query, key, value, attn_mask=None, cache=None):
            outs.append(weights)
        if cache is not None:
            outs.append(cache)
-        return out if len(outs) else tuple(outs)
+        return out if len(outs) == 1 else tuple(outs)
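The change above is behavioral, not cosmetic: the old expression returned `out` alone for any non-empty `outs`, silently dropping the attention weights and cache. A schematic mirror of the fixed return logic, outside the diff (`pack_outputs` is a hypothetical helper, not part of the API):

def pack_outputs(out, weights=None, cache=None):
    # Collect whatever forward() produced, then unwrap only when there is
    # exactly one element; otherwise hand back the full tuple.
    outs = [out]
    if weights is not None:
        outs.append(weights)
    if cache is not None:
        outs.append(cache)
    # Old buggy form: `out if len(outs) else tuple(outs)` -- any non-empty
    # list took the first branch, so weights and cache were never returned.
    return outs[0] if len(outs) == 1 else tuple(outs)

assert pack_outputs("attn_out") == "attn_out"
assert pack_outputs("attn_out", weights="w") == ("attn_out", "w")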


class TransformerEncoderLayer(Layer):
@@ -455,6 +456,7 @@ def __init__(self,
                 bias_attr=None):
        self._config = locals()
        self._config.pop("self")
+        self._config.pop("__class__", None)  # py3

        super(TransformerEncoderLayer, self).__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
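The added `pop("__class__", None)` guards the config dict: under Python 3, a method that references zero-argument `super()` (or `__class__`) carries an implicit `__class__` cell, and free variables show up in `locals()`. A self-contained sketch of the capture-and-clone pattern, with a hypothetical `Widget` class:

class Widget(object):
    def __init__(self, size, color="red"):
        self._config = locals()
        self._config.pop("self")
        # Because this method uses zero-argument super(), Python 3 adds an
        # implicit __class__ cell, and free variables appear in locals().
        # Without this pop, **self._config would pass an unexpected kwarg.
        self._config.pop("__class__", None)
        super().__init__()
        self.size, self.color = size, color

w = Widget(3)
clone = type(w)(**w._config)  # rebuild an identical, independent instance
assert (clone.size, clone.color) == (3, "red")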
@@ -553,15 +555,15 @@ class TransformerEncoder(Layer):
            enc_input = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, n_head, src_len, src_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
-            encoder_layer = TransformerEncoderLayer(2, 64, 64, 128, 512)
+            encoder_layer = TransformerEncoderLayer(128, 2, 512)
            encoder = TransformerEncoder(encoder_layer, 2)
            enc_output = encoder(enc_input, attn_mask)  # [2, 4, 128]
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = LayerList([(encoder_layer if i == 0 else
-                                  type(encoder_layer)(encoder_layer._config))
+                                  type(encoder_layer)(**encoder_layer._config))
                                  for i in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm
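The `**` added above (and in `TransformerDecoder` below) is the difference between passing the stored config dict as the first positional argument and splatting it back into the constructor's keyword parameters. A toy illustration outside the diff, with a hypothetical `FakeLayer` standing in for `TransformerEncoderLayer`:

class FakeLayer(object):
    def __init__(self, d_model, nhead, dim_feedforward=2048):
        self._config = dict(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)

layer = FakeLayer(128, 2, 512)
# Old form: type(layer)(layer._config) passed the whole dict as d_model and
# left nhead missing, raising TypeError.
clone = type(layer)(**layer._config)  # splat restores the real signature
assert clone._config == layer._config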
@@ -680,6 +682,7 @@ def __init__(self,
                 bias_attr=None):
        self._config = locals()
        self._config.pop("self")
+        self._config.pop("__class__", None)  # py3

        super(TransformerDecoderLayer, self).__init__()
        attn_dropout = dropout if attn_dropout is None else attn_dropout
@@ -867,7 +870,7 @@ class TransformerDecoder(Layer):
    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = LayerList([(decoder_layer if i == 0 else
-                                  type(decoder_layer)(decoder_layer._config))
+                                  type(decoder_layer)(**decoder_layer._config))
                                  for i in range(num_layers)])
        self.num_layers = num_layers
        self.norm = norm
@@ -1034,8 +1037,8 @@ class Transformer(Layer):
            # memory_mask: [batch_size, n_head, tgt_len, src_len]
            cross_attn_mask = paddle.rand((2, 2, 6, 4))
            transformer = Transformer(128, 2, 4, 4, 512)
-            output = transformer(dec_input,
-                                 enc_output,
+            output = transformer(enc_input,
+                                 dec_input,
                                 enc_self_attn_mask,
                                 dec_self_attn_mask,
                                 cross_attn_mask)  # [2, 6, 128]
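The example above uses random tensors as masks for brevity; the docstrings in this file define masks as having `-INF` at disallowed positions and 0 elsewhere. A sketch of a causal decoder self-attention mask under that convention, using numpy only (shape broadcastable to `[batch_size, n_head, tgt_len, tgt_len]`):

import numpy as np

tgt_len = 6
# -INF strictly above the diagonal blocks attention to future positions;
# zeros on and below the diagonal keep past and current positions visible.
causal = np.triu(np.full((tgt_len, tgt_len), -np.inf), k=1)
dec_self_attn_mask = causal[None, None, :, :]  # broadcast over batch, heads
print(dec_self_attn_mask.shape)  # (1, 1, 6, 6)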
@@ -1125,8 +1128,11 @@ def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):

class TransformerDecoderCell(Layer):
"""
TransformerDecoderCell wraps a Transformer decoder producing logits from
`inputs` composed by ids and position.
TransformerDecoderCell wraps a Transformer decoder combined with an embedding
layer and output layer to produce logits from symbols (ids and position here).
It is analogy to `RNNCell` and `outputs, new_states = cell(inputs, states, *kwargs)`,
where `inputs` is composed of word ids and position, `states` is `cache`,
`kwargs` includes `memory, `tgt_mask`, `memory_mask` and `static_cache`.
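Schematically, the `RNNCell`-style convention described above works as follows; this is a hedged sketch with a toy stand-in, since the exact return signature is only implied by the docstring:

class DummyCell(object):
    # Toy stand-in for TransformerDecoderCell, only to show the calling
    # convention `outputs, new_states = cell(inputs, states, **kwargs)`.
    def __call__(self, inputs, states, memory=None, tgt_mask=None,
                 memory_mask=None, static_cache=None):
        word_ids, positions = inputs        # each [batch_size, 1] at inference
        outputs = ("logits for", word_ids)  # real cell: embed, decode, project
        new_states = states                 # real cell: caches grown by one step
        return outputs, new_states

cell = DummyCell()
outputs, new_states = cell((["<s>"], [0]), states=[], memory="enc_output")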
    Parameters:
        decoder(callable): A TransformerDecoder instance. Or a wrapper of it that
@@ -1202,11 +1208,11 @@ def __init__(self, decoder, embedding_fn=None, output_fn=None):

    def forward(self,
                inputs,
-                states=None,
-                enc_output=None,
-                trg_slf_attn_bias=None,
-                trg_src_attn_bias=None,
-                static_caches=[]):
+                cache=None,
+                memory=None,
+                tgt_mask=None,
+                memory_mask=None,
+                static_cache=[]):
"""
Produces logits from `inputs` composed by ids and positions.
@@ -1215,27 +1221,29 @@ def forward(self,
                tensors both have int64 data type and with 2D shape
                `[batch_size, sequence_length]` where `sequence_length` is 1
                for inference.
-            states(list): It caches the multi-head attention intermediate results
+            cache(list): It caches the multi-head attention intermediate results
                of history decoding steps. It is a list of dict where the length
                of list is decoder layer number, and each dict has `k` and `v` as
                keys and values are cached results. Default None
-            enc_output(Variable): The output of Transformer encoder. It is a tensor
-                with shape `[batch_size, sequence_length, d_model]`. The data type
+            memory (Variable): The output of Transformer encoder. It is a tensor
+                with shape `[batch_size, source_length, d_model]`. The data type
                should be float32 or float64.
-            trg_slf_attn_bias(Variable, optional): A tensor used in decoder self
-                attention to mask out attention on unwanted target positions. It
-                is a tensor with shape `[batch_size, n_head, target_length, target_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. It can be None when nothing wanted or needed to
-                be masked out. It can be None for inference. The data type should
-                be float32 or float64. Default None
-            trg_src_attn_bias(Variable, optional): A tensor used in decoder-encoder
-                cross attention to mask out unwanted attention on source (encoder output).
-                It is a tensor with shape `[batch_size, n_head, target_length, source_length]`,
-                where the unwanted positions have `-INF` values and the others
-                have 0 values. It can be None when nothing wanted or needed to
-                be masked out. The data type should be float32 or float64. Default None
-            static_caches(list): It stores projected results of encoder output
+            tgt_mask (Variable, optional): A tensor used in self attention
+                to prevent attention to some unwanted positions, usually
+                the subsequent positions. It is a tensor with shape broadcastable
+                to `[batch_size, n_head, target_length, target_length]`,
+                where the unwanted positions have `-INF` values and the others
+                have 0 values. The data type should be float32 or float64. It
+                can be None when nothing needs to be masked out. Default None
+            memory_mask (Variable, optional): A tensor used in decoder-encoder
+                cross attention to prevent attention to some unwanted positions,
+                usually the paddings. It is a tensor with shape broadcastable to
+                `[batch_size, n_head, target_length, source_length]`, where the
+                unwanted positions have `-INF` values and the others have 0 values.
+                The data type should be float32 or float64. It can be None when
+                nothing needs to be masked out. Default None
+            static_cache(list): It stores transformed results of encoder output
                to be used as keys and values in decoder-encoder cross attention.
                It is a list of dict where the length of list is decoder layer
                number, and each dict has `static_k` and `static_v` as keys and
@@ -1250,16 +1258,16 @@
                concatenated into it.
        """
        trg_word, trg_pos = inputs
-        if states and static_caches:
-            for cache, static_cache in zip(states, static_caches):
-                cache.update(static_cache)
+        if cache and static_cache:
+            for c, sc in zip(cache, static_cache):
+                c.update(sc)
        if self.embedding_fn is not None:
            dec_input = self.embedding_fn(trg_word, trg_pos)
-            outputs = self.decoder(dec_input, enc_output, None,
-                                   trg_src_attn_bias, states)
+            outputs = self.decoder(dec_input, memory, tgt_mask, memory_mask,
+                                   cache)
        else:
-            outputs = self.decoder(trg_word, trg_pos, enc_output, None,
-                                   trg_src_attn_bias, states)
+            outputs = self.decoder(trg_word, trg_pos, memory, tgt_mask,
+                                   memory_mask, cache)
        if self.output_fn is not None:
            outputs = self.output_fn(outputs)
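A hedged sketch of the cache plumbing in the body above: `cache` and `static_cache` are per-layer lists of dicts, and the merge must not rebind the `cache` list itself, since it is passed to the decoder afterwards (placeholder strings stand in for real tensors):

num_layers = 2
# incremental k/v, grown step by step during decoding
cache = [{"k": "k%d" % i, "v": "v%d" % i} for i in range(num_layers)]
# encoder-side projections, computed once per source sequence
static_cache = [{"static_k": "sk%d" % i, "static_v": "sv%d" % i}
                for i in range(num_layers)]

# Merge per layer under distinct loop names; reusing `cache` as the loop
# variable would leave it bound to the last layer's dict after the loop.
for layer_cache, layer_static in zip(cache, static_cache):
    layer_cache.update(layer_static)

assert set(cache[0]) == {"k", "v", "static_k", "static_v"}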


1 comment on commit 8637eee

@paddle-bot-old


🕵️ CI failures summary

🔍 Commit ID: 8637eee contains failed CI.
