[Hackathon 5 No.16] Add EmbeddingBag API to Paddle #57923

Closed · wants to merge 6 commits
1 change: 1 addition & 0 deletions python/paddle/nn/__init__.py
@@ -57,6 +57,7 @@
from .layer.common import Pad3D # noqa: F401
from .layer.common import CosineSimilarity # noqa: F401
from .layer.common import Embedding # noqa: F401
from .layer.common import EmbeddingBag # noqa: F401
from .layer.common import Linear # noqa: F401
from .layer.common import Identity # noqa: F401
from .layer.common import Flatten # noqa: F401
1 change: 1 addition & 0 deletions python/paddle/nn/functional/__init__.py
@@ -133,6 +133,7 @@
from .vision import channel_shuffle # noqa: F401
from .input import one_hot # noqa: F401
from .input import embedding # noqa: F401
from .input import embedding_bag # noqa: F401
from .extension import gather_tree # noqa: F401
from .extension import temporal_shift # noqa: F401

96 changes: 96 additions & 0 deletions python/paddle/nn/functional/input.py
@@ -18,6 +18,8 @@
from ...base.layer_helper import LayerHelper
from ...common_ops_import import Variable
from ...framework import in_dynamic_mode, in_dynamic_or_pir_mode
from ...tensor.math import max, sum
from ...tensor.stat import mean

__all__ = []

@@ -252,3 +254,97 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
            },
        )
        return tmp


def embedding_bag(
    x, weight, padding_idx=None, sparse=False, name=None, mode="mean"
):
    r"""
    Computes sums, means or maxes of 'bags' of embeddings, without instantiating the intermediate embeddings.

    Args:
        x(Tensor): A Tensor with type int32/int64, which contains the id information. The value of the input id should
            satisfy :math:`0 <= id < weight.shape[0]` .
        weight (Tensor): The lookup table. Its shape should have two elements, which indicate the size of the
            dictionary of embeddings and the size of each embedding vector respectively.
        sparse(bool, optional): The flag indicating whether to use sparse update. This parameter only
            affects the performance of the backwards gradient update. It is recommended to set
            True because sparse update is faster. But some optimizers do not support sparse update,
            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
            In these cases, sparse must be False. Default: False.
        padding_idx(int|long|None, optional): padding_idx needs to be in the interval [-weight.shape[0], weight.shape[0]).
            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
            to :math:`weight.shape[0] + padding\_idx` . It will output all-zero padding data whenever lookup
            encounters :math:`padding\_idx` in id, and the padding data will not be updated while training.
            If set to None, it has no effect on the output. Default: None.
        name(str|None, optional): For detailed information, please refer
            to :ref:`api_guide_Name`. Usually name does not need to be set and is
            None by default.
        mode(str, optional): "sum", "mean" or "max". Specifies the way to reduce each bag.
            "sum" computes the sum over each bag, "mean" computes the average of the values in each bag,
            and "max" computes the maximum value over each bag. Default: "mean".

    Returns:
        Tensor: Embedding Tensor mapped by x and reduced over the bag dimension according to :attr:`mode`. The data type is the same as :attr:`weight`.

    Examples:

        .. code-block:: python

            >>> import paddle
            >>> import paddle.nn as nn

            >>> x0 = paddle.arange(3, 6).reshape((3, 1)).astype(paddle.int64)
            >>> w0 = paddle.full(shape=(10, 3), fill_value=2).astype(paddle.float32)

            >>> x = paddle.to_tensor(x0, stop_gradient=False)
            >>> print(x.numpy())
            [[3]
             [4]
             [5]]
            >>> print(x.shape)
            [3, 1]

            >>> w = paddle.to_tensor(w0, stop_gradient=False)
            >>> print(w.numpy())
            [[2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]]
            >>> print(w.shape)
            [10, 3]

            >>> emb = nn.functional.embedding_bag(
            ...     x=x, weight=w, sparse=True, name="embedding", mode="mean")
            >>> print(emb.numpy())
            [[2. 2. 2.]
             [2. 2. 2.]
             [2. 2. 2.]]
            >>> print(emb.shape)
            [3, 3]

    """

    if mode not in ("mean", "sum", "max"):
        raise ValueError(
            f"mode should be one of 'mean', 'sum', 'max', but got {mode}"
        )

    out = embedding(
        x,
        weight=weight,
        padding_idx=padding_idx,
        sparse=sparse,
        name=name,
    )
    if mode == "sum":
        return sum(out, axis=1)
    elif mode == "mean":
        return mean(out, axis=1)
    elif mode == "max":
        return max(out, axis=1)
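As a quick illustration of how the three reduction modes relate, the following minimal sketch (the ids and table values are made up for demonstration, and it assumes the functional API added above is importable as paddle.nn.functional.embedding_bag) reduces two bags of two ids each over a 5 x 3 lookup table and cross-checks the result against a NumPy reference:

import numpy as np
import paddle
import paddle.nn.functional as F

# Two bags of two ids each, looked up in an illustrative 5 x 3 table.
ids = np.array([[1, 2], [3, 4]], dtype="int64")
table = np.arange(15, dtype="float32").reshape((5, 3))

x = paddle.to_tensor(ids)
w = paddle.to_tensor(table)

# "sum", "mean" and "max" reduce the embedding rows of each bag element-wise.
sum_out = F.embedding_bag(x, w, mode="sum")    # shape [2, 3]
mean_out = F.embedding_bag(x, w, mode="mean")  # shape [2, 3]
max_out = F.embedding_bag(x, w, mode="max")    # shape [2, 3]

# NumPy reference: gather the rows, then reduce over the bag dimension (axis 1).
np.testing.assert_allclose(sum_out.numpy(), table[ids].sum(axis=1), rtol=1e-5)
np.testing.assert_allclose(mean_out.numpy(), table[ids].mean(axis=1), rtol=1e-5)
np.testing.assert_allclose(max_out.numpy(), table[ids].max(axis=1), rtol=1e-5)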
1 change: 1 addition & 0 deletions python/paddle/nn/layer/__init__.py
@@ -35,6 +35,7 @@
from .common import Pad3D # noqa: F401
from .common import CosineSimilarity # noqa: F401
from .common import Embedding # noqa: F401
from .common import EmbeddingBag # noqa: F401
from .common import Linear # noqa: F401
from .common import Identity # noqa: F401
from .common import Flatten # noqa: F401
155 changes: 155 additions & 0 deletions python/paddle/nn/layer/common.py
@@ -1523,6 +1523,161 @@ def extra_repr(self):
        return main_str.format(**self.__dict__)


class EmbeddingBag(Layer):
    r"""

    Computes sums, means or maxes of 'bags' of embeddings, without instantiating the intermediate embeddings.

    Parameters:
        num_embeddings (int): The size of the dictionary of embeddings.
        embedding_dim (int): The size of each embedding vector.
        padding_idx(int|long|None, optional): padding_idx needs to be in the interval [-num_embeddings, num_embeddings).
            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
            to :math:`num\_embeddings + padding\_idx` . It will output all-zero padding data whenever lookup
            encounters :math:`padding\_idx` in id, and the padding data will not be updated while training.
            If set to None, it has no effect on the output. Default: None.
        sparse(bool, optional): The flag indicating whether to use sparse update. This parameter only
            affects the performance of the backwards gradient update. It is recommended to set
            True because sparse update is faster. But some optimizers do not support sparse update,
            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
            In these cases, sparse must be False. Default: False.
        weight_attr(ParamAttr, optional): To specify the weight parameter property. Default: None, which means the
            default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
            user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter.
            The local word vector needs to be transformed into numpy format, and the shape of the local word
            vector should be consistent with :attr:`num_embeddings` and :attr:`embedding_dim` . Then
            :ref:`api_paddle_nn_initializer_Assign` is used to load custom or pre-trained word vectors.
            See code example for details.
        name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name does
            not need to be set and is None by default.
        mode(str, optional): "sum", "mean" or "max". Specifies the way to reduce each bag.
            "sum" computes the sum over each bag, "mean" computes the average of the values in each bag,
            and "max" computes the maximum value over each bag. Default: "mean".

    Attribute:
        **weight** (Parameter): the learnable weights of this layer.

    Returns:
        None

    Examples:

        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False)
            >>> embedding = paddle.nn.EmbeddingBag(4, 3, sparse=True)

            >>> w0 = paddle.to_tensor([[0., 0., 0.],
            ...                        [1., 1., 1.],
            ...                        [2., 2., 2.],
            ...                        [3., 3., 3.]], dtype="float32")
            >>> embedding.weight.set_value(w0)
            >>> print(embedding.weight)
            Parameter containing:
            Tensor(shape=[4, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
                   [[0., 0., 0.],
                    [1., 1., 1.],
                    [2., 2., 2.],
                    [3., 3., 3.]])

            >>> adam = paddle.optimizer.Adam(parameters=[embedding.weight], learning_rate=0.01)
            >>> adam.clear_grad()

            >>> out = embedding(x)
            >>> print(out)
            Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
                   [[0., 0., 0.],
                    [1., 1., 1.],
                    [3., 3., 3.]])

            >>> out.backward()
            >>> adam.step()

    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        sparse=False,
        weight_attr=None,
        name=None,
        mode="mean",
    ):
        super().__init__()
        self._num_embeddings = num_embeddings
        self._embedding_dim = embedding_dim
        self._sparse = sparse
        self._is_distributed = False
        self._padding_idx = padding_idx
        self._mode = mode

        if self._mode not in ("mean", "sum", "max"):
            raise ValueError("mode must be one of 'mean', 'sum', 'max'")

        if self._num_embeddings <= 0:
            raise ValueError("num_embeddings must be greater than 0")

        if self._embedding_dim <= 0:
            raise ValueError("embedding_dim must be greater than 0")

        padding_idx = (
            -1
            if padding_idx is None
            else padding_idx
            if padding_idx >= 0
            else (num_embeddings + padding_idx)
        )

        if padding_idx >= num_embeddings or padding_idx < -num_embeddings:
            raise ValueError(
                f"padding_idx must be within [-{num_embeddings}, {num_embeddings})"
            )

        self._dtype = self._helper.get_default_dtype()
        self._size = [self._num_embeddings, self._embedding_dim]

        self._weight_attr = weight_attr
        self._remote_prefetch = False
        self._name = name
        self.weight = self.create_parameter(
            attr=self._weight_attr,
            shape=self._size,
            dtype=self._dtype,
            is_bias=False,
        )

        if in_dynamic_mode() and padding_idx != -1:
            with paddle.no_grad():
                self.weight[padding_idx] = 0.0

    def forward(self, x):
        out = F.embedding(
            x,
            weight=self.weight,
            padding_idx=self._padding_idx,
            sparse=self._sparse,
            name=self._name,
        )
        if self._mode == "sum":
            return paddle.sum(out, axis=1)
        elif self._mode == "mean":
            return paddle.mean(out, axis=1)
        elif self._mode == "max":
            return paddle.max(out, axis=1)

    def extra_repr(self):
        main_str = '{_num_embeddings}, {_embedding_dim}'
        if self._padding_idx is not None:
            main_str += ', padding_idx={_padding_idx}'
        main_str += ', sparse={_sparse}'
        if self._name is not None:
            main_str += ', name={_name}'
        main_str += ', mode={_mode}'
        return main_str.format(**self.__dict__)


class Unfold(Layer):
"""
Returns a col buffer of sliding local blocks of input x, also known
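To make the padding_idx behaviour of the layer concrete, here is a short, illustrative sketch; the weight is randomly initialized, so only shapes and the zeroed contribution are noted in comments:

import paddle

# padding_idx=0: row 0 of the weight is zeroed at construction time (in dynamic
# mode) and lookups of id 0 contribute all-zero rows that are not updated in training.
emb = paddle.nn.EmbeddingBag(5, 3, padding_idx=0, mode="sum")

# Each bag below contains one padding id and one regular id.
x = paddle.to_tensor([[0, 2], [3, 0]], dtype="int64")
out = emb(x)  # shape [2, 3]; equals emb.weight[2] and emb.weight[3] respectively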
53 changes: 53 additions & 0 deletions test/legacy_test/test_layers.py
@@ -671,6 +671,59 @@ def test_embeding(self):
                emb1.weight.numpy(), emb2.weight.numpy()
            )

    def test_embedding_bag(self):
        inp_word = np.array([[[1]]]).astype('int64')
        dict_size = 20
        with self.static_graph():
            data_t = paddle.static.data(
                name='word', shape=[-1, 1], dtype='int64'
            )
            data_t.desc.set_need_check_feed(False)
            emb2 = paddle.nn.EmbeddingBag(
                dict_size, 32, weight_attr='emb.w', sparse=False, mode="mean"
            )
            emb_rlt = emb2(data_t)
            static_rlt2 = self.get_static_graph_result(
                feed={'word': inp_word}, fetch_list=[emb_rlt]
            )[0]
        with self.dynamic_graph():
            emb2 = paddle.nn.EmbeddingBag(
                dict_size, 32, weight_attr='emb.w', sparse=False, mode="mean"
            )
            dy_rlt = emb2(to_variable(inp_word))
            dy_rlt_value = dy_rlt.numpy()

        np.testing.assert_allclose(static_rlt2[0], dy_rlt_value[0])

        with self.dynamic_graph():
            custom_weight = np.random.randn(dict_size, 32).astype("float32")
            weight_attr = base.ParamAttr(
                initializer=paddle.nn.initializer.Assign(custom_weight)
            )
            emb1 = paddle.nn.EmbeddingBag(
                dict_size, 32, sparse=False, mode="mean"
            )
            emb2 = paddle.nn.EmbeddingBag(
                dict_size,
                32,
                weight_attr=weight_attr,
                sparse=False,
                mode="mean",
            )
            rep1 = emb1(to_variable(inp_word))
            rep2 = emb2(to_variable(inp_word))
            self.assertFalse(np.array_equal(emb1.weight.numpy(), custom_weight))
            np.testing.assert_array_equal(emb2.weight.numpy(), custom_weight)
            self.assertFalse(np.array_equal(rep1.numpy(), rep2.numpy()))
            emb2.weight.set_value(emb1.weight.numpy())
            rep2 = emb2(to_variable(inp_word))
            np.testing.assert_array_equal(rep1.numpy(), rep2.numpy())

            emb2.weight = emb1.weight
            np.testing.assert_array_equal(
                emb1.weight.numpy(), emb2.weight.numpy()
            )

    def test_one_hot(self):
        with self.dynamic_graph():
            label = base.dygraph.to_variable(np.array([[1], [1], [3], [0]]))
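The weight_attr docstring above points to a code example for loading custom or pre-trained word vectors through the Assign initializer. A minimal illustrative sketch follows; the array, vocabulary size, and embedding dimension are made up for demonstration:

import numpy as np
import paddle

# Illustrative "pre-trained" vectors for a 4-word vocabulary with dimension 3.
pretrained = np.arange(12, dtype="float32").reshape((4, 3))

weight_attr = paddle.ParamAttr(
    initializer=paddle.nn.initializer.Assign(pretrained)
)
emb = paddle.nn.EmbeddingBag(4, 3, weight_attr=weight_attr, mode="mean")

x = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64")
out = emb(x)  # shape [2, 3]: the averages of rows 0/1 and rows 2/3 of `pretrained`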