Skip to content

Commit

Permalink
【PIR API adaptor No.23】Migrate paddle.Tensor.bincount into pir (#58692)
Browse files Browse the repository at this point in the history
* 【PIR API adaptor No.23】Migrate paddle.Tensor.bincount into pir

* fix bug

* fix bug

* Update test/legacy_test/test_bincount_op.py

Co-authored-by: Lu Qi <[email protected]>

---------

Co-authored-by: Lu Qi <[email protected]>
  • Loading branch information
GreatV and MarioLulab authored Nov 8, 2023
1 parent 6aafcf2 commit e5ef5d3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
10 changes: 8 additions & 2 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import paddle
from paddle import _C_ops
from paddle.base.libpaddle import DataType
from paddle.common_ops_import import VarDesc
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

Expand Down Expand Up @@ -1785,10 +1786,15 @@ def bincount(x, weights=None, minlength=0, name=None):
Tensor(shape=[6], dtype=float32, place=Place(cpu), stop_gradient=True,
[0. , 2.19999981, 0.40000001, 0. , 0.50000000, 0.50000000])
"""
if x.dtype not in [paddle.int32, paddle.int64]:
if x.dtype not in [
paddle.int32,
paddle.int64,
DataType.INT32,
DataType.INT64,
]:
raise TypeError("Elements in Input(x) should all be integers")

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.bincount(x, weights, minlength)
else:
helper = LayerHelper('bincount', **locals())
Expand Down
10 changes: 6 additions & 4 deletions test/legacy_test/test_bincount_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,19 @@
import paddle.inference as paddle_infer
from paddle import base
from paddle.base.framework import in_dygraph_mode
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()


class TestBincountOpAPI(unittest.TestCase):
"""Test bincount api."""

@test_with_pir_api
def test_static_graph(self):
startup_program = base.Program()
train_program = base.Program()
with base.program_guard(train_program, startup_program):
startup_program = paddle.static.Program()
train_program = paddle.static.Program()
with paddle.static.program_guard(train_program, startup_program):
inputs = paddle.static.data(name='input', dtype='int64', shape=[7])
weights = paddle.static.data(
name='weights', dtype='int64', shape=[7]
Expand Down Expand Up @@ -152,7 +154,7 @@ def init_test_case(self):
self.Out = np.bincount(self.np_input, minlength=self.minlength)

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)


class TestCase1(TestBincountOp):
Expand Down

0 comments on commit e5ef5d3

Please sign in to comment.