Update Transformer apis by renaming MultiheadAttention and cal_kv according to comments.

test=develop
guoshengCS committed Aug 23, 2020
1 parent 39a623c commit 48f97e1
Showing 2 changed files with 39 additions and 35 deletions.
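For quick reference, after this commit the renamed layers are exported from `paddle.nn` and the key/value projection helper is `compute_kv`. A minimal usage sketch follows (hedged: the shapes and constructor arguments are illustrative assumptions, not taken from this diff):

import paddle
from paddle.nn import MultiHeadAttention  # export added in this commit

attn = MultiHeadAttention(128, 2)     # embed_dim=128, num_heads=2
query = paddle.rand((2, 4, 128))      # [batch_size, seq_len, embed_dim]
k, v = attn.compute_kv(query, query)  # renamed from cal_kv in this commit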
6 changes: 6 additions & 0 deletions python/paddle/nn/__init__.py
@@ -118,6 +118,12 @@
# from .layer.rnn import RNNCell #DEFINE_ALIAS
# from .layer.rnn import GRUCell #DEFINE_ALIAS
# from .layer.rnn import LSTMCell #DEFINE_ALIAS
from .layer.transformer import MultiHeadAttention
from .layer.transformer import TransformerEncoderLayer
from .layer.transformer import TransformerEncoder
from .layer.transformer import TransformerDecoderLayer
from .layer.transformer import TransformerDecoder
from .layer.transformer import Transformer
from .layer.distance import PairwiseDistance #DEFINE_ALIAS

from .layer import loss #DEFINE_ALIAS
68 changes: 33 additions & 35 deletions python/paddle/nn/layer/transformer.py
@@ -14,7 +14,7 @@

# TODO: define the classes of Transformer neural network
__all__ = [
'MultiheadAttention',
'MultiHeadAttention',
'TransformerEncoderLayer',
'TransformerEncoder',
'TransformerDecoderLayer',
@@ -25,8 +25,6 @@
import copy
import collections

import numpy as np

from ...fluid import layers
from ...fluid.param_attr import ParamAttr
from ...fluid.dygraph import Layer, Linear, Dropout, LayerNorm, LayerList
@@ -66,7 +64,7 @@ def _convert_param_attr_to_list(param_attr, n):
return param_attrs


class MultiheadAttention(Layer):
class MultiHeadAttention(Layer):
"""
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple parallel attention to jointly attend
@@ -104,7 +102,7 @@ class MultiheadAttention(Layer):
query = paddle.rand((2, 4, 128))
# self attention mask: [batch_size, num_heads, query_len, query_len]
attn_mask = paddle.rand((2, 2, 4, 4))
multi_head_attn = paddle.MultiheadAttention(128, 2)
multi_head_attn = paddle.MultiHeadAttention(128, 2)
output = multi_head_attn(query, attn_mask=attn_mask) # [2, 4, 128]
"""

@@ -120,7 +118,7 @@ def __init__(self,
need_weights=False,
weight_attr=None,
bias_attr=None):
super(MultiheadAttention, self).__init__()
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
@@ -158,11 +156,11 @@ def _prepare_qkv(self, query, key, value, cache=None):
is a tensor with shape `[batch_size, value_length, vdim]`.
The data type should be float32 or float64. If None, use `query` as
`value`.
cache (MultiheadAttention.Cache|MultiheadAttention.StaticCache, optional):
cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
It is a namedtuple with `k` and `v` as fields, and stores tensors
shaped `[batch_size, num_heads, length, embed_dim]` which are results
of linear projection, reshape and transpose calculations in
MultiheadAttention. If is an instance of `Cache`, `k` and `v`
MultiHeadAttention. If it is an instance of `Cache`, `k` and `v`
fields reserve intermediate results of previous positions, which are
mostly used for decoder self attention. If it is an instance of
`StaticCache`, `key` and `value` args would be ignored, `k` and
@@ -185,7 +183,7 @@ def _prepare_qkv(self, query, key, value, cache=None):
# for encoder-decoder attention in inference and has cached
k, v = cache.k, cache.v
else:
k, v = self.cal_kv(key, value)
k, v = self.compute_kv(key, value)

if isinstance(cache, self.Cache):
# for decoder self-attention in inference
@@ -195,7 +193,7 @@

return (q, k, v) if cache is None else (q, k, v, cache)

def cal_kv(self, key, value):
def compute_kv(self, key, value):
"""
Applies linear projection on input keys and values, then splits heads
(reshape and transpose) to get keys and values from different representation
@@ -230,13 +228,13 @@ def cal_kv(self, key, value):
def gen_cache(self, key, value=None, type=Cache):
"""
Generates cache for `forward` usage in inference according to arguments.
The generated cache is an instance of `MultiheadAttention.Cache` or an
instance of `MultiheadAttention.StaticCache`.
The generated cache is an instance of `MultiHeadAttention.Cache` or an
instance of `MultiHeadAttention.StaticCache`.
`Cache` or `StaticCache` is a namedtuple with `k` and `v` as fields,
and it stores tensors shaped `[batch_size, num_heads, length, embed_dim]`
which are results of linear projection, reshape and transpose calculations
in MultiheadAttention.
in MultiHeadAttention.
If the generated cache is an instance of `Cache`, `k` and `v` fields
reserve intermediate result tensors of previous positions, and the tensors
@@ -250,8 +248,8 @@ def gen_cache(self, key, value=None, type=Cache):
The cache is generated as follows:
1. If `type` is `StaticCache`, apply `cal_kv(key, value)` and use the results
to create an instance of `StaticCache`.
1. If `type` is `StaticCache`, apply `compute_kv(key, value)` and use the
results to create an instance of `StaticCache`.
2. If `type` is `Cache` and `value` is None, generate empty tensors shaped
`[batch_size, num_heads, 0, embed_dim // num_heads]` and use the results
@@ -270,14 +268,14 @@
is a tensor with shape `[batch_size, value_length, vdim]`.
The data type should be float32 or float64. If None, `key` is only
for batch size reference. Default None.
type (type): It should be `MultiheadAttention.StaticCache` or
`MultiheadAttention.Cache` to indicate the cache type to generate.
type (type): It should be `MultiHeadAttention.StaticCache` or
`MultiHeadAttention.Cache` to indicate the cache type to generate.
Returns:
namedtuple: an instance of `Cache` or `StaticCache` accordingly.
"""
if type == MultiheadAttention.StaticCache: # static_kv
k, v = self.cal_kv(key, value)
if type == MultiHeadAttention.StaticCache: # static_kv
k, v = self.compute_kv(key, value)
return self.StaticCache(k, v)
elif value is None: # incremental_state
k = layers.fill_constant_batch_size_like(
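The two cache variants described above can be generated as in this hedged sketch (constructor arguments and tensor shapes are assumptions, not shown in this diff):

import paddle
import paddle.nn as nn

mha = nn.MultiHeadAttention(128, 2)
memory = paddle.rand((2, 6, 128))  # e.g. encoder output

# StaticCache: k and v computed once from `memory` via compute_kv
static_cache = mha.gen_cache(memory, memory, type=nn.MultiHeadAttention.StaticCache)

# Cache with value=None: empty k/v shaped [batch_size, num_heads, 0, embed_dim // num_heads],
# grown step by step during incremental decoder self-attention
incremental_cache = mha.gen_cache(memory, type=nn.MultiHeadAttention.Cache)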
@@ -320,11 +318,11 @@ def forward(self, query, key, value, attn_mask=None, cache=None):
have 0 values. The data type should be float32 or float64. It can
be None when nothing needs to be prevented from being attended to.
Default None
cache (MultiheadAttention.Cache|MultiheadAttention.StaticCache, optional):
cache (MultiHeadAttention.Cache|MultiHeadAttention.StaticCache, optional):
It is a namedtuple with `k` and `v` as fields, and stores tensors
shaped `[batch_size, num_heads, length, embed_dim]` which are results
of linear projection, reshape and transpose calculations in
MultiheadAttention. If it is an instance of `Cache`, `k` and `v`
MultiHeadAttention. If it is an instance of `Cache`, `k` and `v`
fields reserve intermediate results of previous positions, which are
mostly used for decoder self attention. If it is an instance of
`StaticCache`, `key` and `value` args would be ignored, `k` and
@@ -464,7 +462,7 @@ def __init__(self,
weight_attrs = _convert_param_attr_to_list(weight_attr, 2)
bias_attrs = _convert_param_attr_to_list(bias_attr, 2)

self.self_attn = MultiheadAttention(
self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
@@ -685,13 +683,13 @@ def __init__(self,
weight_attrs = _convert_param_attr_to_list(weight_attr, 3)
bias_attrs = _convert_param_attr_to_list(bias_attr, 3)

self.self_attn = MultiheadAttention(
self.self_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
weight_attr=weight_attrs[0],
bias_attr=bias_attrs[0])
self.cross_attn = MultiheadAttention(
self.cross_attn = MultiHeadAttention(
d_model,
nhead,
dropout=attn_dropout,
@@ -741,8 +739,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
The data type should be float32 or float64. It can be None when
nothing needs to be prevented from being attended to. Default None
cache (tuple, optional): It is a tuple( :code:`(incremental_cache, static_cache)` ),
`incremental_cache` is an instance of `MultiheadAttention.Cache`,
`static_cache` is an instance of `MultiheadAttention.StaticCache.
`incremental_cache` is an instance of `MultiHeadAttention.Cache`,
`static_cache` is an instance of `MultiHeadAttention.StaticCache`.
See `TransformerDecoderLayer.gen_cache` for more details. It is
only used for inference and should be None for training. Default
None.
@@ -753,7 +751,7 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
Or a tuple if `cache` is not None, except for decoder layer output, \
the tuple includes the new cache which is same as input `cache` \
argument but `incremental_cache` in it has an incremental length. \
See `MultiheadAttention.gen_cache` and `MultiheadAttention.forward` \
See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
for more details.
"""
residual = tgt
@@ -793,8 +791,8 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
def gen_cache(self, memory):
"""
Generates cache for `forward` usage. The generated cache is a tuple
composed of an instance of `MultiheadAttention.Cache` and an instance
of `MultiheadAttention.StaticCache`.
composed of an instance of `MultiHeadAttention.Cache` and an instance
of `MultiHeadAttention.StaticCache`.
Parameters:
memory (Tensor): The output of Transformer encoder. It is a tensor
@@ -803,13 +801,13 @@ def gen_cache(self, memory):
Returns:
tuple: It is a tuple( :code:`(incremental_cache, static_cache)` ). \
`incremental_cache` is an instance of `MultiheadAttention.Cache` \
produced by `self_attn.gen_cache(memory, MultiheadAttention.Cache)`, \
`incremental_cache` is an instance of `MultiHeadAttention.Cache` \
produced by `self_attn.gen_cache(memory, MultiHeadAttention.Cache)`, \
it reserves two tensors shaped `[batch_size, nhead, 0, d_model // nhead]`. \
`static_cache` is an instance of `MultiheadAttention.StaticCache` \
produced by `cross_attn.gen_cache(memory, MultiheadAttention.StaticCache)`, \
`static_cache` is an instance of `MultiHeadAttention.StaticCache` \
produced by `cross_attn.gen_cache(memory, MultiHeadAttention.StaticCache)`, \
it reserves two tensors shaped `[batch_size, nhead, source_length, d_model // nhead]`.
See `MultiheadAttention.gen_cache` and `MultiheadAttention.forward` \
See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
for more details.
"""
incremental_cache = self.self_attn.gen_cache(
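Putting the pieces together for inference, a hedged sketch of TransformerDecoderLayer with caching (constructor arguments and shapes are assumptions, not part of this diff):

import paddle
import paddle.nn as nn

decoder_layer = nn.TransformerDecoderLayer(d_model=128, nhead=2, dim_feedforward=512)
memory = paddle.rand((2, 6, 128))   # encoder output
tgt = paddle.rand((2, 1, 128))      # one decoding step

# tuple of (MultiHeadAttention.Cache, MultiHeadAttention.StaticCache)
cache = decoder_layer.gen_cache(memory)
output, cache = decoder_layer(tgt, memory, tgt_mask=None, memory_mask=None, cache=cache)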
@@ -901,7 +899,7 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, cache=None):
Or a tuple if `cache` is not None, except for decoder output, \
the tuple includes the new cache which is same as input `cache` \
argument but `incremental_cache` in it has an incremental length. \
See `MultiheadAttention.gen_cache` and `MultiheadAttention.forward` \
See `MultiHeadAttention.gen_cache` and `MultiHeadAttention.forward` \
for more details.
"""
output = tgt
