Skip to content

Commit

Permalink
tensor_array slice in PIR (#60503)
Browse files Browse the repository at this point in the history
* use slice_array, now will meet error of destory opresult still in use

* disable the pir test until the bug fixed
  • Loading branch information
zoooo0820 authored Jan 5, 2024
1 parent a11aabd commit 116c892
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
21 changes: 15 additions & 6 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,19 @@ def slice_is_same_to_original(start, end, step):


def parse_index(x, indices):
advanced_index = [None] * 2 * len(x.shape) # content is (dim, index)
from .framework import in_pir_mode

if in_pir_mode():
is_tensor_array = x.is_dense_tensor_array_type()
else:
is_tensor_array = (
hasattr(x, "desc")
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
)

advanced_index = (
[] if is_tensor_array else [None] * 2 * len(x.shape)
) # content is (dim, index)
# for set_value / slice / strided_slice OP
decrease_axes = []
axes = []
Expand All @@ -267,11 +279,6 @@ def parse_index(x, indices):
indices = replace_ellipsis(x, indices)
indices, none_axes = replace_none(indices)

is_tensor_array = (
hasattr(x, "desc")
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
)

estimated_dim = 0
dim = 0
for i, slice_item in enumerate(indices):
Expand Down Expand Up @@ -740,6 +747,8 @@ def get_tensor_with_basic_indexing(
if isinstance(end, (list, tuple)):
if paddle.utils._contain_var(end):
end = paddle.utils.get_int_tensor_list(end)
if x.is_dense_tensor_array_type():
return paddle._pir_ops.slice_array_dense(x, st)
out = paddle._C_ops.slice(
x,
axes,
Expand Down
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def init_dygraph_func(self):
test_list_pop_in_while_loop,
]

# TODO(zhangbo): Refine BuildOpFrom for op with sub_block
def train(self, to_static=False):
with base.dygraph.guard():
if to_static:
Expand Down

0 comments on commit 116c892

Please sign in to comment.