From 6fcfd74ada07552a85fad2b43c69ea472a1a4b5d Mon Sep 17 00:00:00 2001
From: lisamhy <763334840@qq.com>
Date: Sun, 8 Oct 2023 08:53:31 +0000
Subject: [PATCH 1/6] embeddingbag

---
 python/paddle/nn/functional/input.py |  95 ++++++++++++++++
 python/paddle/nn/layer/common.py     | 156 +++++++++++++++++++++++++++
 2 files changed, 251 insertions(+)

diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py
index e38797a1115ae..1be327af5448b 100644
--- a/python/paddle/nn/functional/input.py
+++ b/python/paddle/nn/functional/input.py
@@ -18,6 +18,7 @@
 from ...base.layer_helper import LayerHelper
 from ...common_ops_import import Variable
 from ...framework import in_dynamic_mode, in_dynamic_or_pir_mode
+from ...tensor.math import max, mean, sum
 
 __all__ = []
 
@@ -252,3 +253,97 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
         },
     )
     return tmp
+
+
+def embedding_bag(
+    x, weight, padding_idx=None, sparse=False, name=None, mode="mean"
+):
+    r"""
+    Computes sums, means or maxes of 'bags' of embeddings, without instantiating the intermediate embeddings.
+
+    Args:
+        x(Tensor): A Tensor with type int32/int64, which contains the id information. The value of the input id should
+            satisfy :math:`0 <= id < weight.shape[0]` .
+        weight (Tensor): The lookup table. A 2-D Tensor whose two dimensions indicate the size of the
+            dictionary of embeddings and the size of each embedding vector respectively.
+        sparse(bool, optional): The flag indicating whether to use sparse update. This parameter only
+            affects the performance of the backward gradient update. It is recommended to set it to
+            True because sparse update is faster. But some optimizers do not support sparse update,
+            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
+            In these cases, sparse must be False. Default: False.
+        padding_idx(int|long|None, optional): padding_idx needs to be in the interval [-weight.shape[0], weight.shape[0]).
+            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
+            to :math:`weight.shape[0] + padding\_idx` . It will output all-zero padding data whenever lookup
+            encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
+            If set to None, it has no effect on the output. Default: None.
+        name(str|None, optional): For detailed information, please refer
+            to :ref:`api_guide_Name`. Usually name does not need to be set and
+            is None by default.
+        mode(str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag.
+            "sum" computes the sum over each bag, "mean" computes the average of the values in the bag,
+            and "max" computes the max value over each bag. Default: "mean".
+
+    Returns:
+        Tensor: Embedding Tensor mapped by x. The data type is the same as :attr:`weight`.
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> import paddle
+            >>> import paddle.nn as nn
+
+            >>> x0 = paddle.arange(3, 6).reshape((3, 1)).astype(paddle.int64)
+            >>> w0 = paddle.full(shape=(10, 3), fill_value=2).astype(paddle.float32)
+
+            >>> x = paddle.to_tensor(x0, stop_gradient=False)
+            >>> print(x.numpy())
+            [[3]
+             [4]
+             [5]]
+            >>> print(x.shape)
+            [3, 1]
+
+            >>> w = paddle.to_tensor(w0, stop_gradient=False)
+            >>> print(w.numpy())
+            [[2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]]
+            >>> print(w.shape)
+            [10, 3]
+
+            >>> emb = nn.functional.embedding_bag(
+            ...     x=x, weight=w, sparse=True, name="embedding", mode="mean")
+            >>> print(emb.numpy())
+            [[[2. 2. 2.]]
+             [[2. 2. 2.]]
+             [[2. 2. 2.]]]
+            >>> print(emb.shape)
+            [3, 1, 3]
+
+    """
+
+    if mode not in ("mean", "sum", "max"):
+        raise ValueError(
+            f"mode should be one of 'mean', 'sum', 'max', but got {mode}"
+        )
+
+    out = embedding(
+        x,
+        weight=weight,
+        padding_idx=padding_idx,
+        sparse=sparse,
+        name=name,
+    )
+    if mode == "sum":
+        return sum(out, axis=1)
+    elif mode == "mean":
+        return mean(out, axis=1)
+    elif mode == "max":
+        return max(out, axis=1)
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index e041ed505b260..a88b11881d6f9 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -1523,6 +1523,162 @@ def extra_repr(self):
         return main_str.format(**self.__dict__)
 
 
+
+class EmbeddingBag(Layer):
+    r"""
+
+    Computes sums, means or maxes of 'bags' of embeddings, without instantiating the intermediate embeddings.
+
+    Parameters:
+        num_embeddings (int): Just one element which indicates the size of the dictionary of embeddings.
+        embedding_dim (int): Just one element which indicates the size of each embedding vector.
+        padding_idx(int|long|None, optional): padding_idx needs to be in the interval [-num_embeddings, num_embeddings).
+            If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
+            to :math:`num\_embeddings + padding\_idx` . It will output all-zero padding data whenever lookup
+            encounters :math:`padding\_idx` in id. And the padding data will not be updated while training.
+            If set to None, it has no effect on the output. Default: None.
+        sparse(bool, optional): The flag indicating whether to use sparse update. This parameter only
+            affects the performance of the backward gradient update. It is recommended to set it to
+            True because sparse update is faster. But some optimizers do not support sparse update,
+            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
+            In these cases, sparse must be False. Default: False.
+        weight_attr(ParamAttr, optional): To specify the weight parameter property. Default: None, which means the
+            default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
+            user-defined or pre-trained word vectors can be loaded with the :attr:`weight_attr` parameter.
+            The local word vector needs to be transformed into numpy format, and the shape of the local word
+            vector should be consistent with :attr:`num_embeddings` . Then :ref:`api_paddle_nn_initializer_Assign`
+            is used to load custom or pre-trained word vectors. See code example for details.
+        name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`. Usually name does
+            not need to be set and is None by default.
+        mode(str, optional): "sum", "mean" or "max". Specifies the way to reduce the bag.
+            "sum" computes the sum over each bag, "mean" computes the average of the values in the bag,
+            and "max" computes the max value over each bag. Default: "mean".
+
+    Attribute:
+        **weight** (Parameter): the learnable weights of this layer.
+
+    Returns:
+        None
+
+    Examples:
+
+        .. code-block:: python
+
+            >>> import paddle
+
+            >>> x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False)
+            >>> embedding = paddle.nn.EmbeddingBag(4, 3, sparse=True)
+
+            >>> w0 = paddle.to_tensor([[0., 0., 0.],
+            ...                        [1., 1., 1.],
+            ...                        [2., 2., 2.],
+            ...                        [3., 3., 3.]], dtype="float32")
+            >>> embedding.weight.set_value(w0)
+            >>> print(embedding.weight)
+            Parameter containing:
+            Tensor(shape=[4, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
+            [[0., 0., 0.],
+             [1., 1., 1.],
+             [2., 2., 2.],
+             [3., 3., 3.]])
+
+            >>> adam = paddle.optimizer.Adam(parameters=[embedding.weight], learning_rate=0.01)
+            >>> adam.clear_grad()
+
+            >>> out = embedding(x)
+            >>> print(out)
+            Tensor(shape=[3, 1, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
+            [[[0., 0., 0.]],
+             [[1., 1., 1.]],
+             [[3., 3., 3.]]])
+
+            >>> out.backward()
+            >>> adam.step()
+
+    """
+
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        padding_idx=None,
+        sparse=False,
+        weight_attr=None,
+        name=None,
+        mode="mean",
+    ):
+        super().__init__()
+        self._num_embeddings = num_embeddings
+        self._embedding_dim = embedding_dim
+        self._sparse = sparse
+        self._is_distributed = False
+        self._padding_idx = padding_idx
+        self._mode = mode
+
+        if self._mode not int ("mean", "sum", "max"):
+            raise ValueError("mode must be one of 'mean', 'sum', 'max'")
+
+        if self._num_embeddings <= 0:
+            raise ValueError("num_embeddings must be greater than 0")
+
+        if self._embedding_dim <= 0:
+            raise ValueError("embedding_dim must be greater than 0")
+
+        padding_idx = (
+            -1
+            if padding_idx is None
+            else padding_idx
+            if padding_idx >= 0
+            else (num_embeddings + padding_idx)
+        )
+
+        if padding_idx >= num_embeddings or padding_idx < -num_embeddings:
+            raise ValueError(
+                f"padding_idx must be within [-{num_embeddings}, {num_embeddings})"
+            )
+
+        self._dtype = self._helper.get_default_dtype()
+        self._size = [self._num_embeddings, self._embedding_dim]
+
+        self._weight_attr = weight_attr
+        self._remote_prefetch = False
+        self._name = name
+        self.weight = self.create_parameter(
+            attr=self._weight_attr,
+            shape=self._size,
+            dtype=self._dtype,
+            is_bias=False,
+        )
+
+        if in_dynamic_mode() and padding_idx != -1:
+            with paddle.no_grad():
+                self.weight[padding_idx] = 0.0
+
+    def forward(self, x):
+        out = F.embedding(
+            x,
+            weight=self.weight,
+            padding_idx=self._padding_idx,
+            sparse=self._sparse,
+            name=self._name,
+        )
+        if self._mode == "sum":
+            return paddle.sum(out, axis=1)
+        elif self._mode == "mean":
+            return paddle.mean(out, axis=1)
+        elif self._mode == "max":
+            return paddle.max(out, axis=1)
+
+    def extra_repr(self):
+        main_str = '{_num_embeddings}, {_embedding_dim}'
+        if self._padding_idx is not None:
+            main_str += ', padding_idx={_padding_idx}'
+        main_str += ', sparse={_sparse}'
+        if self._name is not None:
+            main_str += ', name={_name}'
+        main_str += ', mode={_mode}'
+        return main_str.format(**self.__dict__)
+
+
 class Unfold(Layer):
     """
     Returns a col buffer of sliding local blocks of input x, also known

From 988e5e6fb847d55b1a7744b9af2cce19c1dd9a12 Mon Sep 17 00:00:00 2001
From: lisamhy <763334840@qq.com>
Date: Sun, 8 Oct 2023 08:57:11 +0000
Subject: [PATCH 2/6] fix

---
 python/paddle/nn/functional/__init__.py | 1 +
 python/paddle/nn/layer/__init__.py      | 1 +
 2 files changed, 2 insertions(+)

diff --git a/python/paddle/nn/functional/__init__.py b/python/paddle/nn/functional/__init__.py
index 87f2eabba1f59..a164303a59f83 100644
--- a/python/paddle/nn/functional/__init__.py
+++ b/python/paddle/nn/functional/__init__.py
@@ -133,6 +133,7 @@
 from .vision import channel_shuffle  # noqa: F401
 from .input import one_hot  # noqa: F401
 from .input import embedding  # noqa: F401
+from .input import embedding_bag  # noqa: F401
 from .extension import gather_tree  # noqa: F401
 from .extension import temporal_shift  # noqa: F401
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index f83b8454456ff..0714324059df7 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -35,6 +35,7 @@
 from .common import Pad3D  # noqa: F401
 from .common import CosineSimilarity  # noqa: F401
 from .common import Embedding  # noqa: F401
+from .common import EmbeddingBag  # noqa: F401
 from .common import Linear  # noqa: F401
 from .common import Identity  # noqa: F401
 from .common import Flatten  # noqa: F401

From ff12ca81a788123c8cb4c97b02690ff6672aa938 Mon Sep 17 00:00:00 2001
From: lisamhy <763334840@qq.com>
Date: Sun, 8 Oct 2023 09:28:49 +0000
Subject: [PATCH 3/6] ut

---
 python/paddle/nn/layer/common.py               |   3 +-
 test/legacy_test/test_layers.py                |  53 +++++++
 .../test_nn_functional_embeddingbag.py         | 140 ++++++++++++++++++
 3 files changed, 194 insertions(+), 2 deletions(-)
 create mode 100644 test/legacy_test/test_nn_functional_embeddingbag.py

diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index a88b11881d6f9..9317581f5a950 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -1523,7 +1523,6 @@ def extra_repr(self):
         return main_str.format(**self.__dict__)
 
 
-
 class EmbeddingBag(Layer):
     r"""
 
@@ -1614,7 +1613,7 @@ def __init__(
         self._padding_idx = padding_idx
         self._mode = mode
 
-        if self._mode not int ("mean", "sum", "max"):
+        if self._mode not in ("mean", "sum", "max"):
             raise ValueError("mode must be one of 'mean', 'sum', 'max'")
 
         if self._num_embeddings <= 0:
diff --git a/test/legacy_test/test_layers.py b/test/legacy_test/test_layers.py
index 5d8087e7138d3..7089d60bbd3c8 100644
--- a/test/legacy_test/test_layers.py
+++ b/test/legacy_test/test_layers.py
@@ -671,6 +671,59 @@ def test_embeding(self):
             emb1.weight.numpy(), emb2.weight.numpy()
         )
 
+    def test_embedding_bag(self):
+        inp_word = np.array([[[1]]]).astype('int64')
+        dict_size = 20
+        with self.static_graph():
+            data_t = paddle.static.data(
+                name='word', shape=[-1, 1], dtype='int64'
+            )
+            data_t.desc.set_need_check_feed(False)
+            emb2 = paddle.nn.EmbeddingBag(
+                dict_size, 32, weight_attr='emb.w', sparse=False, mode="mean"
+            )
+            emb_rlt = emb2(data_t)
+            static_rlt2 = self.get_static_graph_result(
+                feed={'word': inp_word}, fetch_list=[emb_rlt]
+            )[0]
+        with self.dynamic_graph():
+            emb2 = paddle.nn.EmbeddingBag(
+                dict_size, 32, weight_attr='emb.w', sparse=False, mode="mean"
+            )
+            dy_rlt = emb2(to_variable(inp_word))
+            dy_rlt_value = dy_rlt.numpy()
+
+        np.testing.assert_allclose(static_rlt2[0], dy_rlt_value[0])
+
+        with self.dynamic_graph():
+            custom_weight = np.random.randn(dict_size, 32).astype("float32")
+            weight_attr = base.ParamAttr(
+                initializer=paddle.nn.initializer.Assign(custom_weight)
+            )
+            emb1 = paddle.nn.EmbeddingBag(
+                dict_size, 32, sparse=False, mode="mean"
+            )
+            emb2 = paddle.nn.EmbeddingBag(
+                dict_size,
+                32,
+                weight_attr=weight_attr,
+                sparse=False,
+                mode="mean",
+            )
+            rep1 = emb1(to_variable(inp_word))
+            rep2 = emb2(to_variable(inp_word))
+            self.assertFalse(np.array_equal(emb1.weight.numpy(), custom_weight))
+            np.testing.assert_array_equal(emb2.weight.numpy(), custom_weight)
+            self.assertFalse(np.array_equal(rep1.numpy(), rep2.numpy()))
+            emb2.weight.set_value(emb1.weight.numpy())
+            rep2 = emb2(to_variable(inp_word))
+            np.testing.assert_array_equal(rep1.numpy(), rep2.numpy())
+
+            emb2.weight = emb1.weight
+            np.testing.assert_array_equal(
+                emb1.weight.numpy(), emb2.weight.numpy()
+            )
+
     def test_one_hot(self):
         with self.dynamic_graph():
             label = base.dygraph.to_variable(np.array([[1], [1], [3], [0]]))
diff --git a/test/legacy_test/test_nn_functional_embeddingbag.py b/test/legacy_test/test_nn_functional_embeddingbag.py
new file mode 100644
index 0000000000000..bb5e875634ebd
--- /dev/null
+++ b/test/legacy_test/test_nn_functional_embeddingbag.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import base
+from paddle.nn import functional
+
+
+class EmbeddingDygraph(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+
+    def test_1(self):
+        x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
+        paddle.disable_static(paddle.CPUPlace())
+        x = paddle.to_tensor(x_data, stop_gradient=False)
+
+        embedding_bag = paddle.nn.embedding_bag(
+            10, 3, sparse=True, padding_idx=9
+        )
+
+        w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
+        embedding_bag.weight.set_value(w0)
+
+        adam = paddle.optimizer.Adam(
+            parameters=[embedding_bag.weight], learning_rate=0.01
+        )
+        adam.clear_grad()
+
+        out = embedding_bag(x)
+        out.backward()
+        adam.step()
+
+    def test_2(self):
+        x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
+        y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
+        paddle.disable_static(paddle.CPUPlace())
+        x = paddle.to_tensor(x_data, stop_gradient=False)
+        y = paddle.to_tensor(y_data, stop_gradient=False)
+
+        with self.assertRaises(ValueError):
+            embedding_bag = paddle.nn.embedding_bag(
+                10, 3, padding_idx=11, sparse=True
+            )
+
+        with self.assertRaises(ValueError):
+            embedding_bag = paddle.nn.embedding_bag(-1, 3, sparse=True)
+
+        with self.assertRaises(ValueError):
+            embedding_bag = paddle.nn.embedding_bag(10, -3, sparse=True)
+
+
+class EmbeddingStatic(unittest.TestCase):
+    def test_1(self):
+        prog = base.Program()
+        with base.program_guard(prog):
+
+            def test_bad_x():
+                initializer = paddle.nn.initializer.Assign(
+                    np.random.random(size=(128, 100))
+                )
+
+                param_attr = base.ParamAttr(
+                    name="emb_weight",
+                    learning_rate=0.5,
+                    initializer=initializer,
+                    trainable=True,
+                )
+
+                weight = prog.global_block().create_parameter(
+                    (128, 100), attr=param_attr, dtype="float32"
+                )
+
+                label = paddle.static.data(
+                    name="label",
+                    shape=[-1, 4],
+                    dtype="int64",
+                )
+
+                emb = functional.embedding_bag(
+                    x=label, weight=weight, sparse=True, name="embedding_bag"
+                )
+
+            test_bad_x()
+
+    def test_2(self):
+        prog = base.Program()
+        with base.program_guard(prog):
+
+            def test_bad_x():
+                initializer = paddle.nn.initializer.Assign(
+                    np.random.random(size=(128, 100))
+                )
+
+                param_attr = base.ParamAttr(
+                    name="emb_weight",
+                    learning_rate=0.5,
+                    initializer=initializer,
+                    trainable=True,
+                )
+
+                weight = prog.global_block().create_parameter(
+                    (128, 100), attr=param_attr, dtype="float32"
+                )
+
+                label = paddle.static.data(
+                    name="label",
+                    shape=[-1, 4],
+                    dtype="int32",
+                )
+
+                emb = functional.embedding_bag(
+                    x=label,
+                    weight=weight,
+                    padding_idx=129,
+                    sparse=True,
+                    name="embedding_bag",
+                )
+
+            with self.assertRaises(ValueError):
+                test_bad_x()
+
+
+if __name__ == '__main__':
+    unittest.main()

From c42aaae46da790746a51328e3b39e9223bff6f2e Mon Sep 17 00:00:00 2001
From: lisamhy <763334840@qq.com>
Date: Mon, 9 Oct 2023 02:08:18 +0000
Subject: [PATCH 4/6] fix

---
 python/paddle/nn/functional/input.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py
index 1be327af5448b..6e2dee304047f 100644
--- a/python/paddle/nn/functional/input.py
+++ b/python/paddle/nn/functional/input.py
@@ -18,7 +18,8 @@
 from ...base.layer_helper import LayerHelper
 from ...common_ops_import import Variable
 from ...framework import in_dynamic_mode, in_dynamic_or_pir_mode
-from ...tensor.math import max, mean, sum
+from ...tensor.math import max, sum
+from ...tensor.stat import mean
 
 __all__ = []
 

From 9b3ad241f59cd4932e274e42ad084b918573d6d8 Mon Sep 17 00:00:00 2001
From: lisamhy <763334840@qq.com>
Date: Mon, 9 Oct 2023 03:24:05 +0000
Subject: [PATCH 5/6] fix

---
 python/paddle/nn/__init__.py                   |  1 +
 .../test_nn_functional_embeddingbag.py         | 19 +++++++++----------
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index dbef7079c1bf3..75bb129ec7a57 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -57,6 +57,7 @@
 from .layer.common import Pad3D  # noqa: F401
 from .layer.common import CosineSimilarity  # noqa: F401
 from .layer.common import Embedding  # noqa: F401
+from .layer.common import EmbeddingBag  # noqa: F401
 from .layer.common import Linear  # noqa: F401
 from .layer.common import Identity  # noqa: F401
 from .layer.common import Flatten  # noqa: F401
diff --git a/test/legacy_test/test_nn_functional_embeddingbag.py b/test/legacy_test/test_nn_functional_embeddingbag.py
index bb5e875634ebd..f811c1b12f3a0 100644
--- a/test/legacy_test/test_nn_functional_embeddingbag.py
+++ b/test/legacy_test/test_nn_functional_embeddingbag.py
@@ -18,19 +18,15 @@
 
 import paddle
 from paddle import base
-from paddle.nn import functional
 
 
 class EmbeddingDygraph(unittest.TestCase):
-    def setUp(self):
-        paddle.disable_static()
-
     def test_1(self):
         x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
         paddle.disable_static(paddle.CPUPlace())
         x = paddle.to_tensor(x_data, stop_gradient=False)
 
-        embedding_bag = paddle.nn.embedding_bag(
+        embedding_bag = paddle.nn.EmbeddingBag(
             10, 3, sparse=True, padding_idx=9
         )
 
@@ -54,18 +50,21 @@ def test_2(self):
         y = paddle.to_tensor(y_data, stop_gradient=False)
 
         with self.assertRaises(ValueError):
-            embedding_bag = paddle.nn.embedding_bag(
+            embedding_bag = paddle.nn.EmbeddingBag(
                 10, 3, padding_idx=11, sparse=True
             )
 
         with self.assertRaises(ValueError):
-            embedding_bag = paddle.nn.embedding_bag(-1, 3, sparse=True)
+            embedding_bag = paddle.nn.EmbeddingBag(-1, 3, sparse=True)
 
         with self.assertRaises(ValueError):
-            embedding_bag = paddle.nn.embedding_bag(10, -3, sparse=True)
+            embedding_bag = paddle.nn.EmbeddingBag(10, -3, sparse=True)
 
 
 class EmbeddingStatic(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+
     def test_1(self):
         prog = base.Program()
         with base.program_guard(prog):
@@ -92,7 +91,7 @@ def test_bad_x():
                     dtype="int64",
                 )
 
-                emb = functional.embedding_bag(
+                emb = paddle.nn.functional.embedding_bag(
                     x=label, weight=weight, sparse=True, name="embedding_bag"
                 )
 
@@ -124,7 +123,7 @@ def test_bad_x():
                     dtype="int32",
                 )
 
-                emb = functional.embedding_bag(
+                emb = paddle.nn.functional.embedding_bag(
                     x=label,
                     weight=weight,
                     padding_idx=129,

From 2910f98faf650a0211703622b74cdda53a6be12d Mon Sep 17 00:00:00 2001
From: lisamhy <763334840@qq.com>
Date: Mon, 9 Oct 2023 09:14:39 +0000
Subject: [PATCH 6/6] fix

---
 python/paddle/nn/functional/input.py           |   8 ++++----
 python/paddle/nn/layer/common.py               |   8 +-
 .../test_nn_functional_embeddingbag.py         | 157 ++++++++++++++++++
 3 files changed, 165 insertions(+), 8 deletions(-)

diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py
index 6e2dee304047f..b9c3b34f193bf 100644
--- a/python/paddle/nn/functional/input.py
+++ b/python/paddle/nn/functional/input.py
@@ -322,11 +322,11 @@ def embedding_bag(
 
             >>> emb = nn.functional.embedding_bag(
             ...     x=x, weight=w, sparse=True, name="embedding", mode="mean")
             >>> print(emb.numpy())
-            [[[2. 2. 2.]]
-             [[2. 2. 2.]]
-             [[2. 2. 2.]]]
+            [[2. 2. 2.]
+             [2. 2. 2.]
+             [2. 2. 2.]]
             >>> print(emb.shape)
-            [3, 1, 3]
+            [3, 3]
 
     """
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index 9317581f5a950..e47e2d9b5583e 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -1585,10 +1585,10 @@ class EmbeddingBag(Layer):
 
             >>> out = embedding(x)
             >>> print(out)
-            Tensor(shape=[3, 1, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
-            [[[0., 0., 0.]],
-             [[1., 1., 1.]],
-             [[3., 3., 3.]]])
+            Tensor(shape=[3, 3], dtype=float32, place=Place(cpu), stop_gradient=False,
+            [[0., 0., 0.],
+             [1., 1., 1.],
+             [3., 3., 3.]])
 
             >>> out.backward()
             >>> adam.step()
diff --git a/test/legacy_test/test_nn_functional_embeddingbag.py b/test/legacy_test/test_nn_functional_embeddingbag.py
index f811c1b12f3a0..91ff2c4c50535 100644
--- a/test/legacy_test/test_nn_functional_embeddingbag.py
+++ b/test/legacy_test/test_nn_functional_embeddingbag.py
@@ -29,6 +29,7 @@ def test_1(self):
         embedding_bag = paddle.nn.EmbeddingBag(
             10, 3, sparse=True, padding_idx=9
         )
+        embedding_bag.extra_repr()
 
         w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
         embedding_bag.weight.set_value(w0)
@@ -60,6 +61,53 @@ def test_2(self):
         with self.assertRaises(ValueError):
             embedding_bag = paddle.nn.EmbeddingBag(10, -3, sparse=True)
 
+        with self.assertRaises(ValueError):
+            embedding_bag = paddle.nn.EmbeddingBag(
+                10, 3, sparse=True, mode="min"
+            )
+
+    def test_3(self):
+        x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
+        paddle.disable_static(paddle.CPUPlace())
+        x = paddle.to_tensor(x_data, stop_gradient=False)
+
+        embedding_bag = paddle.nn.EmbeddingBag(
+            10, 3, sparse=True, padding_idx=9, mode="sum"
+        )
+
+        w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
+        embedding_bag.weight.set_value(w0)
+
+        adam = paddle.optimizer.Adam(
+            parameters=[embedding_bag.weight], learning_rate=0.01
+        )
+        adam.clear_grad()
+
+        out = embedding_bag(x)
+        out.backward()
+        adam.step()
+
+    def test_4(self):
+        x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
+        paddle.disable_static(paddle.CPUPlace())
+        x = paddle.to_tensor(x_data, stop_gradient=False)
+
+        embedding_bag = paddle.nn.EmbeddingBag(
+            10, 3, sparse=True, padding_idx=9, mode="max"
+        )
+
+        w0 = np.full(shape=(10, 3), fill_value=2).astype(np.float32)
+        embedding_bag.weight.set_value(w0)
+
+        adam = paddle.optimizer.Adam(
+            parameters=[embedding_bag.weight], learning_rate=0.01
+        )
+        adam.clear_grad()
+
+        out = embedding_bag(x)
+        out.backward()
+        adam.step()
+
 
 class EmbeddingStatic(unittest.TestCase):
     def setUp(self):
         paddle.enable_static()
@@ -134,6 +182,115 @@ def test_bad_x():
         with self.assertRaises(ValueError):
             test_bad_x()
 
+    def test_3(self):
+        prog = base.Program()
+        with base.program_guard(prog):
+
+            def test_bad_x():
+                initializer = paddle.nn.initializer.Assign(
+                    np.random.random(size=(128, 100))
+                )
+
+                param_attr = base.ParamAttr(
+                    name="emb_weight",
+                    learning_rate=0.5,
+                    initializer=initializer,
+                    trainable=True,
+                )
+
+                weight = prog.global_block().create_parameter(
+                    (128, 100), attr=param_attr, dtype="float32"
+                )
+
+                label = paddle.static.data(
+                    name="label",
+                    shape=[-1, 4],
+                    dtype="int64",
+                )
+
+                emb = paddle.nn.functional.embedding_bag(
+                    x=label,
+                    weight=weight,
+                    sparse=True,
+                    name="embedding_bag",
+                    mode="min",
+                )
+
+            with self.assertRaises(ValueError):
+                test_bad_x()
+
+    def test_4(self):
+        prog = base.Program()
+        with base.program_guard(prog):
+
+            def test_bad_x():
+                initializer = paddle.nn.initializer.Assign(
+                    np.random.random(size=(128, 100))
+                )
+
+                param_attr = base.ParamAttr(
+                    name="emb_weight",
+                    learning_rate=0.5,
+                    initializer=initializer,
+                    trainable=True,
+                )
+
+                weight = prog.global_block().create_parameter(
+                    (128, 100), attr=param_attr, dtype="float32"
+                )
+
+                label = paddle.static.data(
+                    name="label",
+                    shape=[-1, 4],
+                    dtype="int64",
+                )
+
+                emb = paddle.nn.functional.embedding_bag(
+                    x=label,
+                    weight=weight,
+                    sparse=True,
+                    name="embedding_bag",
+                    mode="sum",
+                )
+
+            test_bad_x()
+
+    def test_5(self):
+        prog = base.Program()
+        with base.program_guard(prog):
+
+            def test_bad_x():
+                initializer = paddle.nn.initializer.Assign(
+                    np.random.random(size=(128, 100))
+                )
+
+                param_attr = base.ParamAttr(
+                    name="emb_weight",
+                    learning_rate=0.5,
+                    initializer=initializer,
+                    trainable=True,
+                )
+
+                weight = prog.global_block().create_parameter(
+                    (128, 100), attr=param_attr, dtype="float32"
+                )
+
+                label = paddle.static.data(
+                    name="label",
+                    shape=[-1, 4],
+                    dtype="int64",
+                )
+
+                emb = paddle.nn.functional.embedding_bag(
+                    x=label,
+                    weight=weight,
+                    sparse=True,
+                    name="embedding_bag",
+                    mode="max",
+                )
+
+            test_bad_x()
+
 
 if __name__ == '__main__':
     unittest.main()
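Reviewer note (not part of the patches): a minimal usage sketch of the API this series adds, assuming all six patches are applied. The tensor values and the names `x`, `w`, `layer` and the `out_*` variables are illustrative only; shapes follow the docstring examples above.

    import paddle
    import paddle.nn as nn
    import paddle.nn.functional as F

    # Two bags of two ids each: x has shape [batch=2, bag_size=2].
    x = paddle.to_tensor([[0, 1], [2, 3]], dtype="int64")
    # Lookup table: [num_embeddings=4, embedding_dim=3].
    w = paddle.arange(12, dtype="float32").reshape((4, 3))

    # Functional form: embedding lookup ([2, 2, 3]) reduced over axis=1.
    out_sum = F.embedding_bag(x=x, weight=w, mode="sum")    # shape [2, 3]
    out_mean = F.embedding_bag(x=x, weight=w, mode="mean")  # shape [2, 3]

    # Layer form: owns its weight parameter and reduces the same way.
    layer = nn.EmbeddingBag(4, 3, mode="max")
    out_max = layer(x)                                      # shape [2, 3]

Each row of `x` is treated as one fixed-size bag, which matches the `sum`/`mean`/`max` branches over `axis=1` in `embedding_bag` and `EmbeddingBag.forward` above; unlike torch.nn.EmbeddingBag, there is no offsets/ragged-bag input in this implementation.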