[WIP][Hackathon 5th No.49][pir] add OpResult.__eq__ - Part 6 #58791

Closed
wants to merge 4 commits
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/pir.cc
@@ -473,8 +473,8 @@ void BindValue(py::module *m) {
           [](Value &self, Value &op_value) {
             self.ReplaceAllUsesWith(op_value);
           })
-      .def("__eq__", &Value::operator==)
-      .def("__eq__",
+      .def("is_name", &Value::operator==)
+      .def("is_name",
           [](Value &self, OpResult &other) {
             return self.impl() == other.Value::impl();
           })
@@ -661,8 +661,8 @@ void BindOpResult(py::module *m) {
   OVERRIDE_COMPARE_OP_FOR_EACH(__gt__, greater_than);
   OVERRIDE_COMPARE_OP_FOR_EACH(__ge__, greater_equal);

-  op_result.def("__eq__", &OpResult::operator==)
-      .def("__eq__",
+  op_result.def("is_name", &OpResult::operator==)
+      .def("is_name",
           [](OpResult &self, Value &other) {
             return self.Value::impl() == other.impl();
           })
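Note: the identity comparison previously bound as `__eq__` is re-exposed as `is_name`, freeing `__eq__` for the elementwise semantics enabled in `math_op_patch.py` below. A minimal Python sketch of why the two collide (toy classes, not the Paddle API):

```python
class FakeTensor:
    """Stand-in for an elementwise comparison result."""

    def __bool__(self):
        raise ValueError("truth value of an elementwise result is ambiguous")


class FakeValue:
    """Toy model of a pir Value after this change."""

    def __init__(self, impl):
        self._impl = impl

    def __eq__(self, other):
        return FakeTensor()  # __eq__ now builds an elementwise comparison

    def is_name(self, other):
        return self._impl is other._impl  # identity check on the underlying impl


shared_impl = object()
a, b = FakeValue(shared_impl), FakeValue(shared_impl)
print(a.is_name(b))  # True: both wrap the same impl
try:
    _ = a in [b]     # list membership falls back to __eq__, then bool()
except ValueError as exc:
    print("membership via __eq__ fails:", exc)
```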
114 changes: 89 additions & 25 deletions python/paddle/autograd/ir_backward.py
@@ -157,19 +157,23 @@ def prepare_grad_outputs(grad_outputs, outputs, state):

 def some_in_set(value_list, value_set):
     def operand2value(values):
-        value_set = set()
+        res = []
         for item in values:
             if isinstance(item, paddle.pir.OpOperand):
-                value_set.add(item.source())
+                res.append(item.source())
             else:
-                value_set.add(item)
-        return value_set
+                res.append(item)
+        return res

-    if operand2value(value_list) & operand2value(value_set):
-        return True
-    else:
-        return False
+    def check_value(original_value: list, compare_value: list) -> bool:
+        for i in original_value:
+            for j in compare_value:
+                if i.is_name(j):
+                    return True
+        return False
+
+    return check_value(operand2value(value_list), operand2value(value_set))


 def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
     '''
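Note: the rewritten `some_in_set` above keeps its signature and call sites but swaps the hash-based set intersection for a pairwise `is_name` scan (O(n*m)), since values whose `__eq__` is elementwise can no longer live in hash-keyed containers. A standalone sketch under that assumption (hypothetical `V` class, not the Paddle API):

```python
class V:
    """Hypothetical stand-in for a pir Value."""

    def __init__(self, impl):
        self.impl = impl

    def is_name(self, other):
        return self.impl is other.impl  # mirrors the impl() comparison in pir.cc


def check_value(original_value, compare_value):
    # pairwise scan: True if any element of one list identifies
    # with any element of the other
    for i in original_value:
        for j in compare_value:
            if i.is_name(j):
                return True
    return False


shared = object()
print(check_value([V(shared)], [V(object()), V(shared)]))  # True
print(check_value([V(object())], [V(object())]))           # False
```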
@@ -179,20 +183,36 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
     pruned op in total_ops is uneffective_ops, else is effective_ops

     '''
+
+    def append_list(value, inputs: list) -> list:
+        if isinstance(value, (list, set, tuple)):
+            for input in inputs:
+                for v in value:
+                    if input.is_name(v):
+                        break
+                else:
+                    # when break is not triggered, enter the else branch
+                    inputs.append(v)
+        else:
+            inputs.append(value)
+        return inputs
+
+    inputs_list = list(inputs_set)
+
     intersection_op_flags = [True] * len(total_ops)
     union_op_flags = [False] * len(total_ops)
     # from input to output
-    if inputs_set:
+    if inputs_list:
         for i, op in enumerate(total_ops):
-            if some_in_set(op.results(), inputs_set):
+            if some_in_set(op.results(), inputs_list):
                 union_op_flags[i] = True
                 continue

-            if some_in_set(op.operands_source(), inputs_set):
+            if some_in_set(op.operands_source(), inputs_list):
                 union_op_flags[i] = True
                 for value in op.results():
                     if value not in no_grad_set:
-                        inputs_set.add(value)
+                        append_list(value, inputs_list)
             else:
                 intersection_op_flags[i] = False
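Note: the `append_list` helper added above leans on Python's `for`/`else`, where the `else` suite runs only when the loop finishes without hitting `break`. A standalone illustration of the idiom as a dedup-append (toy data; plain `==` stands in for `is_name`, and the loop nesting is simplified to the common dedup pattern):

```python
existing = ["a", "b"]
candidates = ["b", "c"]
for item in candidates:
    for known in existing:
        if known == item:      # stand-in for the is_name() identity check
            break              # duplicate found: skip the else suite
    else:
        existing.append(item)  # inner loop ended without break: item is new
print(existing)                # ['a', 'b', 'c']
```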

@@ -245,6 +265,13 @@ def update_no_grad_set_after_prune(
     from inputs to outputs add value not in the path to no_grad_set,
     from outputs to inputs add value not in the path to no_grad_set,
     '''
+
+    def check_exist(value, res_set: set) -> bool:
+        for res in res_set:
+            if value.is_name(res):
+                return True
+        return False
+
     inputs_set = set(inputs)
     if inputs_set:
         for op in block.ops:
@@ -255,17 +282,17 @@

     for op in effective_forward_ops:
         for value in op.operands_source():
-            if value not in inputs_set:
+            if not check_exist(value, inputs_set):
                 no_grad_set.add(value)

     outputs_set = set(outputs)
     no_grad_set_tmp = set()
     for op in reversed(effective_forward_ops):
-        for output in op.results():
-            if output not in outputs_set and not some_in_set(
-                [output], set(op.operands_source())
+        for output_res in op.results():
+            if not check_exist(output_res, outputs_set) and not some_in_set(
+                [output_res], set(op.operands_source())
             ):
-                no_grad_set_tmp.add(output)
+                no_grad_set_tmp.add(output_res)

         for input in op.operands_source():
             if input not in no_grad_set:
@@ -358,15 +385,32 @@ def append_backward_ops(
         else continue to next op.
     '''

+    def check_in_keys(value, state_valuegrad: collections.defaultdict) -> bool:
+        for opresult in state_valuegrad.keys():
+            if value.is_name(opresult):
+                return True
+        return False
+
+    def take_defaultdict_list_value(
+        value, state_valuegrad: collections.defaultdict
+    ):
+        for opresult in state_valuegrad.keys():
+            if value.is_name(opresult):
+                return state_valuegrad[opresult]
+        return []
+
     def make_output_with_output_grad(op):
         zero_flag = [False] * op.num_results()
         outputs = []
         output_grads = []
         for i, value in enumerate(op.results()):
             new_value = [value]
             if (
-                value in state.value_to_valuegrad
-                and len(state.value_to_valuegrad[value]) > 1
+                check_in_keys(value, state.value_to_valuegrad)
+                and len(
+                    take_defaultdict_list_value(value, state.value_to_valuegrad)
+                )
+                > 1
             ):
                 # one value is input of more than one fwd_op,
                 # so more than one bwd_op create input_grad,
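Note: `check_in_keys` and `take_defaultdict_list_value` replace `value in state.value_to_valuegrad` and direct indexing: a dict lookup hashes the key and then compares with `==`, and neither step works once `__eq__` is elementwise, so both helpers fall back to a linear scan over `.keys()`. A toy rendering (hypothetical `V` class, not the Paddle API):

```python
from collections import defaultdict


class V:
    """Hypothetical stand-in for a pir OpResult/Value."""

    def __init__(self, impl):
        self.impl = impl

    def is_name(self, other):
        return self.impl is other.impl


def take_defaultdict_list_value(value, state_valuegrad):
    # linear scan over the keys: match by underlying identity
    # instead of hash + __eq__
    for opresult in state_valuegrad.keys():
        if value.is_name(opresult):
            return state_valuegrad[opresult]
    return []


state_valuegrad = defaultdict(list)
shared = object()
state_valuegrad[V(shared)].append("grad_0")

probe = V(shared)  # a different wrapper around the same underlying value
print(probe in state_valuegrad)                             # False: hash lookup misses
print(take_defaultdict_list_value(probe, state_valuegrad))  # ['grad_0']
```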
@@ -550,7 +594,14 @@ def update_input_grad_map(op, input_grads):
         else:
             if op.num_operands() == 0 and op.num_results() != 0:
                 for value in op.results():
-                    if len(state.value_to_valuegrad[value]) > 1:
+                    if (
+                        len(
+                            take_defaultdict_list_value(
+                                value, state.value_to_valuegrad
+                            )
+                        )
+                        > 1
+                    ):
                         # need add sum op
                         paddle.add_n(
                             [
@@ -577,6 +628,19 @@ def update_input_grad_map(op, input_grads):


 def create_backward_prune_set(inputs, outputs, no_grad_set, state):
+    def append_list(value, inputs_list: list) -> list:
+        if isinstance(value, (list, set, tuple)):
+            for input in inputs_list:
+                for v in value:
+                    if input.is_name(v):
+                        break
+                else:
+                    # when break is not triggered, enter the else branch
+                    inputs_list.append(v)
+        else:
+            inputs_list.append(value)
+        return inputs_list
+
     outputs_set = set()
     for input_ in inputs:
         if not input_.use_empty():
@@ -586,16 +650,16 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
         else:
             logging.warning("input privided by inputs has no use")

-    inputs_set = set()
+    inputs_list = []
     for output in outputs:
         if state.value_to_valuegrad[output] != []:
-            inputs_set.add(state.value_to_valuegrad[output][0][0])
+            inputs_list.append(state.value_to_valuegrad[output][0][0])
     inputs_set_tmp = set()
-    for out_grad in inputs_set:
+    for out_grad in inputs_list:
         if not out_grad.use_empty():
             for item in out_grad.first_use().owner().operands_source():
                 inputs_set_tmp.add(item)
-    inputs_set.update(inputs_set_tmp)
+    append_list(inputs_set_tmp, inputs_list)

     no_gradvar_set = set()  # grad_value of value in no_grad_set
     for key in state.value_to_valuegrad:
@@ -606,7 +670,7 @@
         for item in state.value_to_sumvaluegrad[key][0]:
             no_gradvar_set.add(item)

-    return outputs_set, inputs_set, no_gradvar_set
+    return outputs_set, inputs_list, no_gradvar_set


 def remove_op(block, op, state):
8 changes: 4 additions & 4 deletions python/paddle/pir/math_op_patch.py
@@ -396,10 +396,10 @@ def __impl__(self, other_var):
         ),
         # for logical compare
         # TODO(gouzil): Open after deleting c++ logic
-        # (
-        #     '__eq__',
-        #     _binary_creator_('__eq__', paddle.tensor.equal, False, None),
-        # ),
+        (
+            '__eq__',
+            _binary_creator_('__eq__', paddle.tensor.equal, False, None),
+        ),
         (
             '__ne__',
             _binary_creator_('__ne__', paddle.tensor.not_equal, False, None),
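Note: with this tuple uncommented, `==` on pir values lowers to `paddle.tensor.equal` and yields an elementwise result rather than a Python bool, matching the semantics NumPy users already know:

```python
import numpy as np

# NumPy analogy for the new `==` semantics: elementwise, no single bool.
x = np.array([3.0, 4.0, 10.0])
y = np.array([3.0, 4.0, 11.0])
print(x == y)    # [ True  True False]
# bool(x == y)   # would raise: truth value of an array is ambiguous
```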
22 changes: 11 additions & 11 deletions test/legacy_test/test_math_op_patch_pir.py
@@ -272,9 +272,9 @@ def test_equal_and_nequal(self):
         x_np = np.array([3, 4, 10, 14, 9, 18]).astype('float32')
         y_np = np.array([3, 4, 11, 15, 8, 18]).astype('float32')
         # TODO(gouzil): Open after deleting c++ logic
-        # res_np_b = x_np == y_np
-        # res_np_c = paddle.equal(paddle.to_tensor(x_np), paddle.to_tensor(y_np))
-        # res_np_d = x_np.__eq__(y_np)
+        res_np_b = x_np == y_np
+        res_np_c = paddle.equal(paddle.to_tensor(x_np), paddle.to_tensor(y_np))
+        res_np_d = x_np.__eq__(y_np)
         res_np_e = x_np != y_np
         res_np_f = paddle.not_equal(
             paddle.to_tensor(x_np), paddle.to_tensor(y_np)
@@ -286,20 +286,20 @@ def test_equal_and_nequal(self):
         with program_guard:
             x = paddle.static.data(name="x", shape=[-1, 1], dtype='float32')
             y = paddle.static.data(name="y", shape=[-1, 1], dtype='float32')
-            # b = x == y
-            # c = x.equal(y)
-            # d = x.__eq__(y)
+            b = x == y
+            c = x.equal(y)
+            d = x.__eq__(y)
             e = x != y
             f = x.not_equal(y)
             g = x.__ne__(y)
-            (e_np, f_np, g_np) = exe.run(
+            (b_np, c_np, d_np, e_np, f_np, g_np) = exe.run(
                 main_program,
                 feed={"x": x_np, "y": y_np},
-                fetch_list=[e, f, g],
+                fetch_list=[b, c, d, e, f, g],
             )
-            # np.testing.assert_array_equal(res_np_b, b_np)
-            # np.testing.assert_array_equal(res_np_c, c_np)
-            # np.testing.assert_array_equal(res_np_d, d_np)
+            np.testing.assert_array_equal(res_np_b, b_np)
+            np.testing.assert_array_equal(res_np_c, c_np)
+            np.testing.assert_array_equal(res_np_d, d_np)
             np.testing.assert_array_equal(res_np_e, e_np)
             np.testing.assert_array_equal(res_np_f, f_np)
             np.testing.assert_array_equal(res_np_g, g_np)