diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 3a5716877a59d8..a13283d185b1f9 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -473,8 +473,8 @@ void BindValue(py::module *m) {
            [](Value &self, Value &op_value) {
              self.ReplaceAllUsesWith(op_value);
            })
-      .def("__eq__", &Value::operator==)
-      .def("__eq__",
+      .def("is_name", &Value::operator==)
+      .def("is_name",
            [](Value &self, OpResult &other) {
              return self.impl() == other.Value::impl();
            })
@@ -661,8 +661,8 @@ void BindOpResult(py::module *m) {
   OVERRIDE_COMPARE_OP_FOR_EACH(__gt__, greater_than);
   OVERRIDE_COMPARE_OP_FOR_EACH(__ge__, greater_equal);
 
-  op_result.def("__eq__", &OpResult::operator==)
-      .def("__eq__",
+  op_result.def("is_name", &OpResult::operator==)
+      .def("is_name",
            [](OpResult &self, Value &other) {
              return self.Value::impl() == other.impl();
            })
diff --git a/python/paddle/autograd/ir_backward.py b/python/paddle/autograd/ir_backward.py
index 39d73cc54e9ac1..b672d7dc72c395 100644
--- a/python/paddle/autograd/ir_backward.py
+++ b/python/paddle/autograd/ir_backward.py
@@ -157,19 +157,23 @@ def prepare_grad_outputs(grad_outputs, outputs, state):
 
 def some_in_set(value_list, value_set):
     def operand2value(values):
-        value_set = set()
+        res = []
         for item in values:
             if isinstance(item, paddle.pir.OpOperand):
-                value_set.add(item.source())
+                res.append(item.source())
             else:
-                value_set.add(item)
-        return value_set
-
-    if operand2value(value_list) & operand2value(value_set):
-        return True
-    else:
+                res.append(item)
+        return res
+
+    def check_value(original_value: list, compare_value: list) -> bool:
+        for i in original_value:
+            for j in compare_value:
+                if i.is_name(j):
+                    return True
         return False
 
+    return check_value(operand2value(value_list), operand2value(value_set))
+
 
 def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
     '''
@@ -179,20 +183,36 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
     pruned op in total_ops is uneffective_ops, else is effective_ops
     '''
+
+    def append_list(value, inputs: list) -> list:
+        if isinstance(value, (list, set, tuple)):
+            for input in inputs:
+                for v in value:
+                    if input.is_name(v):
+                        break
+                else:
+                    # when break is not triggered, enter the else branch
+                    inputs.append(v)
+        else:
+            inputs.append(value)
+        return inputs
+
+    inputs_list = list(inputs_set)
+
     intersection_op_flags = [True] * len(total_ops)
     union_op_flags = [False] * len(total_ops)
 
     # from input to output
-    if inputs_set:
+    if inputs_list:
         for i, op in enumerate(total_ops):
-            if some_in_set(op.results(), inputs_set):
+            if some_in_set(op.results(), inputs_list):
                 union_op_flags[i] = True
                 continue
 
-            if some_in_set(op.operands_source(), inputs_set):
+            if some_in_set(op.operands_source(), inputs_list):
                 union_op_flags[i] = True
                 for value in op.results():
                     if value not in no_grad_set:
-                        inputs_set.add(value)
+                        append_list(value, inputs_list)
             else:
                 intersection_op_flags[i] = False
 
@@ -245,6 +265,13 @@ def update_no_grad_set_after_prune(
     from inputs to outputs add value not in the path to no_grad_set,
     from outputs to inputs add value not in the path to no_grad_set,
     '''
+
+    def check_exist(value, res_set: set) -> bool:
+        for res in res_set:
+            if value.is_name(res):
+                return True
+        return False
+
     inputs_set = set(inputs)
     if inputs_set:
         for op in block.ops:
@@ -255,17 +282,17 @@ def update_no_grad_set_after_prune(
 
     for op in effective_forward_ops:
         for value in op.operands_source():
-            if value not in inputs_set:
+            if not check_exist(value, inputs_set):
                 no_grad_set.add(value)
 
     outputs_set = set(outputs)
     no_grad_set_tmp = set()
     for op in reversed(effective_forward_ops):
-        for output in op.results():
-            if output not in outputs_set and not some_in_set(
-                [output], set(op.operands_source())
+        for output_res in op.results():
+            if not check_exist(output_res, outputs_set) and not some_in_set(
+                [output_res], set(op.operands_source())
             ):
-                no_grad_set_tmp.add(output)
+                no_grad_set_tmp.add(output_res)
 
         for input in op.operands_source():
             if input not in no_grad_set:
@@ -358,6 +385,20 @@ def append_backward_ops(
     else continue to next op.
     '''
 
+    def check_in_keys(value, state_valuegrad: collections.defaultdict) -> bool:
+        for opresult in state_valuegrad.keys():
+            if value.is_name(opresult):
+                return True
+        return False
+
+    def take_defaultdict_list_value(
+        value, state_valuegrad: collections.defaultdict
+    ):
+        for opresult in state_valuegrad.keys():
+            if value.is_name(opresult):
+                return state_valuegrad[opresult]
+        return []
+
     def make_output_with_output_grad(op):
         zero_flag = [False] * op.num_results()
         outputs = []
@@ -365,8 +406,11 @@
         for i, value in enumerate(op.results()):
             new_value = [value]
             if (
-                value in state.value_to_valuegrad
-                and len(state.value_to_valuegrad[value]) > 1
+                check_in_keys(value, state.value_to_valuegrad)
+                and len(
+                    take_defaultdict_list_value(value, state.value_to_valuegrad)
+                )
+                > 1
             ):
                 # one value is input of more than one fwd_op,
                 # so more than one bwd_op create input_grad,
@@ -550,7 +594,14 @@ def update_input_grad_map(op, input_grads):
         else:
             if op.num_operands() == 0 and op.num_results() != 0:
                 for value in op.results():
-                    if len(state.value_to_valuegrad[value]) > 1:
+                    if (
+                        len(
+                            take_defaultdict_list_value(
+                                value, state.value_to_valuegrad
+                            )
+                        )
+                        > 1
+                    ):
                         # need add sum op
                         paddle.add_n(
                             [
@@ -577,6 +628,19 @@ def update_input_grad_map(op, input_grads):
 
 
 def create_backward_prune_set(inputs, outputs, no_grad_set, state):
+    def append_list(value, inputs_list: list) -> list:
+        if isinstance(value, (list, set, tuple)):
+            for input in inputs_list:
+                for v in value:
+                    if input.is_name(v):
+                        break
+                else:
+                    # when break is not triggered, enter the else branch
+                    inputs_list.append(v)
+        else:
+            inputs_list.append(value)
+        return inputs_list
+
     outputs_set = set()
     for input_ in inputs:
         if not input_.use_empty():
@@ -586,16 +650,16 @@
         else:
             logging.warning("input privided by inputs has no use")
 
-    inputs_set = set()
+    inputs_list = []
     for output in outputs:
         if state.value_to_valuegrad[output] != []:
-            inputs_set.add(state.value_to_valuegrad[output][0][0])
+            inputs_list.append(state.value_to_valuegrad[output][0][0])
     inputs_set_tmp = set()
-    for out_grad in inputs_set:
+    for out_grad in inputs_list:
         if not out_grad.use_empty():
             for item in out_grad.first_use().owner().operands_source():
                 inputs_set_tmp.add(item)
-    inputs_set.update(inputs_set_tmp)
+    append_list(inputs_set_tmp, inputs_list)
 
     no_gradvar_set = set()  # grad_value of value in no_grad_set
     for key in state.value_to_valuegrad:
@@ -606,7 +670,7 @@
         for item in state.value_to_sumvaluegrad[key][0]:
             no_gradvar_set.add(item)
 
-    return outputs_set, inputs_set, no_gradvar_set
+    return outputs_set, inputs_list, no_gradvar_set
 
 
 def remove_op(block, op, state):
diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py
index add6ea93b96ba5..5602a3c521b740 100644
--- a/python/paddle/pir/math_op_patch.py
+++ b/python/paddle/pir/math_op_patch.py
@@ -396,10 +396,10 @@ def __impl__(self, other_var):
         ),
         # for logical compare
         # TODO(gouzil): Open after deleting c++ logic
-        # (
-        #     '__eq__',
-        #     _binary_creator_('__eq__', paddle.tensor.equal, False, None),
-        # ),
+        (
+            '__eq__',
+            _binary_creator_('__eq__', paddle.tensor.equal, False, None),
+        ),
         (
             '__ne__',
             _binary_creator_('__ne__', paddle.tensor.not_equal, False, None),
diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py
index 95a6be11bf5011..eebadd864ccb6b 100644
--- a/test/legacy_test/test_math_op_patch_pir.py
+++ b/test/legacy_test/test_math_op_patch_pir.py
@@ -272,9 +272,9 @@ def test_equal_and_nequal(self):
         x_np = np.array([3, 4, 10, 14, 9, 18]).astype('float32')
         y_np = np.array([3, 4, 11, 15, 8, 18]).astype('float32')
         # TODO(gouzil): Open after deleting c++ logic
-        # res_np_b = x_np == y_np
-        # res_np_c = paddle.equal(paddle.to_tensor(x_np), paddle.to_tensor(y_np))
-        # res_np_d = x_np.__eq__(y_np)
+        res_np_b = x_np == y_np
+        res_np_c = paddle.equal(paddle.to_tensor(x_np), paddle.to_tensor(y_np))
+        res_np_d = x_np.__eq__(y_np)
         res_np_e = x_np != y_np
         res_np_f = paddle.not_equal(
             paddle.to_tensor(x_np), paddle.to_tensor(y_np)
         )
@@ -286,20 +286,20 @@
         with program_guard:
             x = paddle.static.data(name="x", shape=[-1, 1], dtype='float32')
            y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32')
-            # b = x == y
-            # c = x.equal(y)
-            # d = x.__eq__(y)
+            b = x == y
+            c = x.equal(y)
+            d = x.__eq__(y)
             e = x != y
             f = x.not_equal(y)
             g = x.__ne__(y)
-            (e_np, f_np, g_np) = exe.run(
+            (b_np, c_np, d_np, e_np, f_np, g_np) = exe.run(
                 main_program,
                 feed={"x": x_np, "y": y_np},
-                fetch_list=[e, f, g],
+                fetch_list=[b, c, d, e, f, g],
             )
-            # np.testing.assert_array_equal(res_np_b, b_np)
-            # np.testing.assert_array_equal(res_np_c, c_np)
-            # np.testing.assert_array_equal(res_np_d, d_np)
+            np.testing.assert_array_equal(res_np_b, b_np)
+            np.testing.assert_array_equal(res_np_c, c_np)
+            np.testing.assert_array_equal(res_np_d, d_np)
             np.testing.assert_array_equal(res_np_e, e_np)
             np.testing.assert_array_equal(res_np_f, f_np)
             np.testing.assert_array_equal(res_np_g, g_np)
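
Note on the semantics this patch splits apart, with a minimal sketch. The pybind change renames the old identity comparison (same underlying impl()) from __eq__ to is_name, which frees math_op_patch.py to bind __eq__ to paddle.tensor.equal, so x == y builds an elementwise equal op instead of answering "same SSA value". Because __eq__ no longer returns a Python bool, set membership and intersection on values stop working, which is why ir_backward.py replaces `value in set` checks with explicit is_name scans (check_value, check_exist, check_in_keys, append_list). The snippet below is a sketch only: it assumes a program built under the PIR API (as the test harness in test_math_op_patch_pir.py arranges), and the variable names are illustrative rather than taken from the diff.

    # Sketch: identity vs. elementwise comparison after this change.
    # Assumes the PIR API is active; names are illustrative only.
    import paddle

    paddle.enable_static()
    main_program = paddle.static.Program()
    with paddle.static.program_guard(main_program):
        x = paddle.static.data(name="x", shape=[-1, 1], dtype='float32')
        y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32')

        # Identity: do the two handles refer to the same SSA value?
        # This is the old Value.__eq__ behavior under its new name.
        assert x.is_name(x)
        assert not x.is_name(y)

        # Elementwise: `==` now lowers to paddle.tensor.equal and yields
        # a new boolean Value in the graph, not a Python bool, so it can
        # no longer back `in`/set membership tests.
        b = x == y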