Skip to content

Commit

Permalink
【Fix PIR Unittest No.527 BUAA】Fix some test case in PIR (PaddlePaddle…
Browse files Browse the repository at this point in the history
…#66194)

* fix test_tril_triu_op

* fix test_where_op

* fix cmake

* fix codestyle

* fix

* restore test_where_op.py
  • Loading branch information
uanu2002 authored and Dale1314 committed Jul 28, 2024
1 parent bef1891 commit 34ee7fa
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
52 changes: 50 additions & 2 deletions python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,7 +1695,31 @@ def tril(
[5 , 0 , 0 , 0 ],
[9 , 10, 0 , 0 ]])
"""
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.tril(x, diagonal)
elif in_pir_mode():
op_type = 'tril'
assert x is not None, f'x cannot be None in {op_type}'
check_variable_and_dtype(
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'bool',
'complex64',
'complex128',
],
op_type,
)
if len(x.shape) < 2:
raise ValueError(f"x shape in {op_type} must be at least 2-D")
if not isinstance(diagonal, (int,)):
raise TypeError(f"diagonal in {op_type} must be a python Int")
return _C_ops.tril(x, diagonal)
else:
return _tril_triu_op(LayerHelper('tril', **locals()))
Expand Down Expand Up @@ -1776,7 +1800,31 @@ def triu(
[0 , 10, 11, 12]])
"""
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.triu(x, diagonal)
elif in_pir_mode():
op_type = 'triu'
assert x is not None, f'x cannot be None in {op_type}'
check_variable_and_dtype(
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'bool',
'complex64',
'complex128',
],
op_type,
)
if len(x.shape) < 2:
raise ValueError(f"x shape in {op_type} must be at least 2-D")
if not isinstance(diagonal, (int,)):
raise TypeError(f"diagonal in {op_type} must be a python Int")
return _C_ops.triu(x, diagonal)
else:
return _tril_triu_op(LayerHelper('triu', **locals()))
Expand Down
File renamed without changes.

0 comments on commit 34ee7fa

Please sign in to comment.