Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Support set custom OP attibute in PIR mode #63176

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,29 @@ Value GetOutputValueByName(const Program &program, const std::string &name) {
return value;
}

pir::Attribute PyObject2Attribute(const py::object &obj) {
if (py::isinstance<py::bool_>(obj)) {
return BoolAttribute::get(IrContext::Instance(), obj.cast<bool>());
} else if (py::isinstance<py::int_>(obj)) {
return pir::Int64Attribute::get(IrContext::Instance(), obj.cast<int64_t>());
} else if (py::isinstance<py::float_>(obj)) {
return pir::FloatAttribute::get(IrContext::Instance(), obj.cast<float>());
} else if (py::isinstance<py::str>(obj)) {
return pir::StrAttribute::get(IrContext::Instance(),
obj.cast<std::string>());
} else if (py::isinstance<py::list>(obj)) {
auto list = obj.cast<py::list>();
std::vector<Attribute> attrs;
for (auto &item : list) {
attrs.push_back(PyObject2Attribute(item.cast<py::object>()));
}
return pir::ArrayAttribute::get(IrContext::Instance(), attrs);
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"The type of attr should be bool, int, float, str or list."));
}
}

void BindProgram(py::module *m) {
py::class_<Program, std::shared_ptr<Program>> program(
*m, "Program", py::dynamic_attr(), R"DOC(
Expand Down Expand Up @@ -521,6 +544,11 @@ void BindOperation(py::module *m) {
}
return attrs_dict;
})
.def("set_attr",
[](Operation &self, const std::string &attr_name, py::object attr) {
auto attr_value = PyObject2Attribute(attr);
self.set_attribute(attr_name, attr_value);
})
.def("set_scheduling_priority",
[](Operation &self, int64_t priority) {
self.set_attribute("scheduling_priority",
Expand Down Expand Up @@ -638,6 +666,10 @@ void BindOperation(py::module *m) {
"callstack",
[](Operation &self) -> py::list {
py::list callstack_list;
if (!self.HasAttribute(paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName())) {
return callstack_list;
}
pir::Attribute op_callstack = self.attribute<pir::Attribute>(
paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName());
Expand All @@ -663,17 +695,11 @@ void BindOperation(py::module *m) {
},
[](Operation &self,
const std::vector<std::string> &callstack) -> void {
std::vector<pir::Attribute> op_callstack_infos;
for (auto str : callstack) {
op_callstack_infos.push_back(
pir::StrAttribute::get(pir::IrContext::Instance(), str));
}
auto op_callstack_attr = PyObject2Attribute(py::cast(callstack));

self.set_attribute(
paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName(),
pir::ArrayAttribute::get(pir::IrContext::Instance(),
op_callstack_infos));
self.set_attribute(paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName(),
op_callstack_attr);
})
.def("dist_attr", [](Operation &self) {
if (self.HasAttribute(kAttrOpDistAttr)) {
Expand Down Expand Up @@ -1054,6 +1080,10 @@ struct PyInsertionPoint {
void BindInsertionPoint(pybind11::module *m) {
py::class_<PyInsertionPoint> ir_insertion_point(*m, "InsertionPoint", R"DOC(
InsertionPoint class represents the insertion point in the Builder.)DOC");
ir_insertion_point.def(
"block",
[](PyInsertionPoint &self) { return self.value.first; },
return_value_policy::reference);
}

std::list<Operation *>::const_iterator list_offset(const Block *block,
Expand Down
16 changes: 13 additions & 3 deletions python/paddle/nn/layer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from paddle.base.layer_helper_base import LayerHelperBase
from paddle.base.param_attr import ParamAttr
from paddle.framework import use_pir_api
from paddle.profiler.utils import in_profiler_mode
from paddle.utils import deprecated

Expand Down Expand Up @@ -79,16 +80,25 @@ def set_op_customized_attrs_post_hook(layer, inputs, outputs):
"""
if not in_dygraph_mode() and layer._op_recorder.is_valid:
start = layer._op_recorder.start
end = len(default_main_program().current_block().ops)
if use_pir_api():
current_block = (
paddle.base.libpaddle.pir.get_current_insertion_point().block()
)
else:
current_block = default_main_program().current_block()
end = len(current_block.ops)
assert start >= 0 and end >= start
ops = default_main_program().current_block().ops[start:end]
ops = current_block.ops[start:end]

layer._op_recorder.end = end
layer._op_recorder.ops = ops

for op in ops:
for attr_name, val in layer._customized_attrs.items():
op._set_attr(attr_name, val)
if use_pir_api():
op.set_attr(attr_name, val)
else:
op._set_attr(attr_name, val)

# remove pre-hook and post-hook
for hook_helper in layer._op_recorder.hooks:
Expand Down
177 changes: 118 additions & 59 deletions test/dygraph_to_static/test_op_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,32 @@

import unittest

from dygraph_to_static_utils import Dy2StTestBase, test_ast_only
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pt_and_pir,
)

import paddle
from paddle.framework import use_pir_api
from paddle.static import InputSpec


def walk(block, fn):
fn(block)
for op in block.ops:
for sub_block in op.blocks():
walk(sub_block, fn)


def run_on_each_op(block, fn):
def check_block(block):
for op in block.ops:
fn(op)

walk(block, check_block)


class MySub(paddle.nn.Layer):
def __init__(self):
super().__init__()
Expand All @@ -28,63 +48,76 @@ def forward(self, x, y, name=None):
return paddle.subtract(x, y, name)


class NetWithOpAttr(paddle.nn.Layer):
def __init__(self, in_num, out_num):
super().__init__()

self.linear = paddle.nn.Linear(in_num, out_num)
self.bn = paddle.nn.BatchNorm(out_num)
self.sub = MySub()
def create_net_with_op_attr_class():
class NetWithOpAttr(paddle.nn.Layer):
def __init__(self, in_num, out_num):
super().__init__()

def forward(self, x):
out = self.linear(x)
out = self.sub(out, x)
out = self.bn(out)
return out
self.linear = paddle.nn.Linear(in_num, out_num)
self.bn = paddle.nn.BatchNorm(out_num)
self.sub = MySub()

@paddle.jit.to_static(input_spec=[InputSpec([10, 16])], full_graph=True)
def with_cond(self, x):
if paddle.mean(x) > 0.0:
def forward(self, x):
out = self.linear(x)
else:
out = self.sub(x, x)
out = self.bn(out)
return out
out = self.sub(out, x)
out = self.bn(out)
return out

@paddle.jit.to_static(input_spec=[InputSpec([10, 16])], full_graph=True)
def with_cond(self, x):
if paddle.mean(x) > 0.0:
out = self.linear(x)
else:
out = self.sub(x, x)
out = self.bn(out)
return out

return NetWithOpAttr


class CheckOpAttr(Dy2StTestBase):
def setUp(self):
self.in_num = 16
self.out_num = 16
self.x = paddle.randn([10, self.in_num])
self.expected_results()

def expected_results(self):
self.fc_attrs = {
fc_attrs = {
"int_val": 10,
"int_vals": [10, 20],
"float_val": 3.8,
"float_vals": [3.8, -0.2],
}
self.bn_attrs = {"bool_val": True, "bool_vals": [True, False]}
self.sub_attrs = {"int_vals": [10, 20], "bool_vals": [True, False]}

self.infos = {
'matmul': self.fc_attrs,
'elementwise_add': self.fc_attrs,
'batch_norm': self.bn_attrs,
'tanh': self.bn_attrs,
'elementwise_sub': self.sub_attrs,
}
bn_attrs = {"bool_val": True, "bool_vals": [True, False]}
sub_attrs = {"int_vals": [10, 20], "bool_vals": [True, False]}

if use_pir_api():
infos = {
'pd_op.matmul': fc_attrs,
'pd_op.add': fc_attrs,
'pd_op.batch_norm_': bn_attrs,
'pd_op.subtract': sub_attrs,
}
else:
infos = {
'matmul': fc_attrs,
'elementwise_add': fc_attrs,
'batch_norm': bn_attrs,
'tanh': bn_attrs,
'elementwise_sub': sub_attrs,
}
return fc_attrs, bn_attrs, sub_attrs, infos

@test_ast_only
@test_legacy_and_pt_and_pir
def test_set_op_attrs(self):
net = NetWithOpAttr(self.in_num, self.out_num)
fc_attrs, bn_attrs, sub_attrs, _ = self.expected_results()
net = create_net_with_op_attr_class()(self.in_num, self.out_num)
# set attrs
net.linear._set_op_attrs(self.fc_attrs)
net.linear._set_op_attrs(fc_attrs)
net.bn._set_op_attrs({"bool_val": False}) # test overwrite behavior
net.bn._set_op_attrs(self.bn_attrs)
net.sub._set_op_attrs(self.sub_attrs)
net.bn._set_op_attrs(bn_attrs)
net.sub._set_op_attrs(sub_attrs)
# assert hooks exist.
self.assertEqual(len(net.linear._forward_pre_hooks), 1)
self.assertEqual(len(net.linear._forward_post_hooks), 1)
Expand All @@ -101,34 +134,58 @@ def test_set_op_attrs(self):
self.assertEqual(len(net.linear._forward_post_hooks), 0)

def check_op_attrs(self, main_program):
for cur_block in main_program.blocks:
ops = cur_block.ops
for op in ops:
if op.type not in self.infos:
continue
for attr_name, expect_vals in self.infos[op.type].items():
op_vals = op.desc.attr(attr_name)
if not isinstance(expect_vals, list):
expect_vals = [expect_vals]
op_vals = [op_vals]

for op_val, expect_val in zip(op_vals, expect_vals):
if isinstance(op_val, float):
# C++ vs python: 3.799999952316284 ~= 3.8
self.assertAlmostEqual(op_val, expect_val)
else:
self.assertEqual(op_val, expect_val)
_, _, _, infos = self.expected_results()
if not use_pir_api():
for cur_block in main_program.blocks:
ops = cur_block.ops
for op in ops:
if op.type not in infos:
continue
for attr_name, expect_vals in infos[op.type].items():
op_vals = op.desc.attr(attr_name)
if not isinstance(expect_vals, list):
expect_vals = [expect_vals]
op_vals = [op_vals]

for op_val, expect_val in zip(op_vals, expect_vals):
if isinstance(op_val, float):
# C++ vs python: 3.799999952316284 ~= 3.8
self.assertAlmostEqual(op_val, expect_val)
else:
self.assertEqual(op_val, expect_val)
return
global_block = main_program.global_block()

def check_op(op):
if op.name() not in infos:
return
for attr_name, expect_vals in infos[op.name()].items():
op_vals = op.attrs()[attr_name]
if not isinstance(expect_vals, list):
expect_vals = [expect_vals]
op_vals = [op_vals]

for op_val, expect_val in zip(op_vals, expect_vals):
if isinstance(op_val, float):
# C++ vs python: 3.799999952316284 ~= 3.8
self.assertAlmostEqual(op_val, expect_val)
else:
self.assertEqual(op_val, expect_val)

run_on_each_op(global_block, check_op)

@test_ast_only
@test_legacy_and_pt_and_pir
def test_set_op_attrs_with_sub_block(self):
net = NetWithOpAttr(self.in_num, self.out_num)
fc_attrs, bn_attrs, sub_attrs, _ = self.expected_results()
net = create_net_with_op_attr_class()(self.in_num, self.out_num)
# set attrs
net.linear._set_op_attrs(
{"int_vals": [0, 0]}
) # test overwrite behavior
net.linear._set_op_attrs(self.fc_attrs)
net.bn._set_op_attrs(self.bn_attrs)
net.sub._set_op_attrs(self.sub_attrs)
net.linear._set_op_attrs(fc_attrs)
net.bn._set_op_attrs(bn_attrs)
net.sub._set_op_attrs(sub_attrs)
# assert hooks exist.
self.assertEqual(len(net.linear._forward_pre_hooks), 1)
self.assertEqual(len(net.linear._forward_post_hooks), 1)
Expand All @@ -140,11 +197,13 @@ def test_set_op_attrs_with_sub_block(self):
self.assertEqual(len(net.linear._forward_pre_hooks), 0)
self.assertEqual(len(net.linear._forward_post_hooks), 0)

@test_legacy_and_pt_and_pir
def test_type_error(self):
net = NetWithOpAttr(self.in_num, self.out_num)
fc_attrs, _, _, _ = self.expected_results()
net = create_net_with_op_attr_class()(self.in_num, self.out_num)
# attrs should be dict
with self.assertRaises(TypeError):
net.linear._set_op_attrs([self.fc_attrs])
net.linear._set_op_attrs([fc_attrs])


if __name__ == '__main__':
Expand Down