Skip to content

Commit

Permalink
[CINN] Support recompute "shape" of dynamic shape and vector type (#68913)
Browse files Browse the repository at this point in the history

* fix recompute bug

* fix recompute bug
  • Loading branch information
chen2016013 authored Oct 28, 2024
1 parent 4933629 commit 5a7f85f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,8 @@ const phi::DDim &GetTensorDims(Type type) {
} else if (auto sparse_csr_tensr_type =
type.dyn_cast<SparseCsrTensorType>()) {
return sparse_csr_tensr_type.dims();
} else if (auto dense_array_type = type.dyn_cast<DenseTensorArrayType>()) {
return dense_array_type.dims();
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"Currently, we can only get shape for dense and selsect rows type."));
Expand Down
36 changes: 25 additions & 11 deletions python/paddle/decomposition/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,13 +554,9 @@ def _ban_recomputation(value_node):
)
value_id_dict[value_node.id] = value_node

# todo(wanghao107) hack for dynamic shape
if is_dynamic_value_node(value_node):
weight = 1
else:
weight = _get_node_weight(
value_node, placeholder_value_nodes=inputs | outputs
)
weight = _get_node_weight(
value_node, placeholder_value_nodes=inputs | outputs
)

# Creates the weights on the "node" edge
nx_graph.add_edge(
Expand Down Expand Up @@ -942,16 +938,34 @@ def is_dynamic_value_node(value_node):
raise ValueError(f"value node not found in program: {value_node} ")


def cal_value_node_size(value_node):
# todo(wanghao107) hack for dynamic shape
def is_vector_value_node(value_node):
    """Return True if ``value_node``'s type is a vector type.

    Args:
        value_node: a pir Value whose ``type()`` supports ``as_vec_type()``.

    Returns:
        bool: True when ``type().as_vec_type()`` yields a non-None vector
        type, False otherwise.

    Raises:
        ValueError: if the value node does not expose the expected
            ``type()/as_vec_type()`` interface.
    """
    try:
        return value_node.type().as_vec_type() is not None
    # Catch only Exception (a bare ``except:`` would also swallow
    # KeyboardInterrupt/SystemExit) and chain the original cause so the
    # underlying failure stays visible in the traceback.
    except Exception as e:
        raise ValueError(f"value node illegal: {value_node} ") from e


def cal_value_node_size_impl(value_node):
    """Return the byte size of a single (non-vector) value node.

    The size is the product of the node's static dims multiplied by the
    byte width of its dtype.

    Args:
        value_node: a pir Value exposing ``shape`` and ``dtype``.

    Returns:
        int: number of bytes occupied by the tensor described by the node.
    """
    # NOTE: as pasted, this span interleaved the pre-commit lines
    # (``return 1`` and a reduce over ``value_node.shape``) with the
    # post-commit ones, which is not valid Python; this is the intended
    # post-commit implementation.
    if is_dynamic_value_node(value_node):
        # Dynamic dims are encoded as -1; drop them so the product only
        # accounts for the statically-known extents.
        value_node_shape = [i for i in value_node.shape if i != -1]
    else:
        value_node_shape = value_node.shape
    return (
        functools.reduce(lambda x, y: x * y, value_node_shape, 1)
        * _PADDLE_DTYPE_2_NBYTES[value_node.dtype]
    )


def cal_value_node_size(value_node):
    """Return the total byte size of ``value_node``.

    Vector-typed nodes report the sum of their children's sizes; plain
    nodes delegate directly to ``cal_value_node_size_impl``.
    """
    # Guard clause: non-vector nodes are the common case.
    if not is_vector_value_node(value_node):
        return cal_value_node_size_impl(value_node)
    children = value_node.type().as_vec_type().as_list()
    return sum(cal_value_node_size_impl(child) for child in children)


def cal_value_nodes_dist_to_backward(all_ops, required_fw_value_nodes):
dist_from_bw = backward_utils.ValueDict()
# calculate value node the shortest dist to backward graph
Expand Down

0 comments on commit 5a7f85f

Please sign in to comment.