[PIR] Support setting custom OP attributes in PIR mode
SigureMo committed Apr 2, 2024
1 parent 980f6f8 commit 59e78be
Showing 3 changed files with 171 additions and 72 deletions.
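In short, this commit gives PIR `Operation` objects a Python-facing `set_attr` method, backed by a new C++ helper that converts a Python `bool`, `int`, `float`, `str`, or a (possibly nested) list of these into the matching `pir::Attribute`. A minimal sketch of the resulting behavior, assuming PIR mode is entered via the test-style `paddle.pir_utils.IrGuard` (the guard and the relu op below are illustrative, not part of this diff):

import paddle

with paddle.pir_utils.IrGuard():  # assumption: switches the static graph to PIR
    main_program = paddle.static.Program()
    with paddle.static.program_guard(main_program):
        x = paddle.static.data("x", [10, 16], "float32")
        y = paddle.nn.functional.relu(x)
    op = main_program.global_block().ops[-1]  # the pd_op.relu operation
    op.set_attr("int_val", 10)               # stored as pir::Int64Attribute
    op.set_attr("float_vals", [3.8, -0.2])   # stored as pir::ArrayAttribute
    assert op.attrs()["int_val"] == 10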
50 changes: 40 additions & 10 deletions paddle/fluid/pybind/pir.cc
@@ -190,6 +190,29 @@ Value GetOutputValueByName(const Program &program, const std::string &name) {
  return value;
}

pir::Attribute PyObject2Attribute(const py::object &obj) {
  // bool must be checked before int: Python bool is a subclass of int.
  if (py::isinstance<py::bool_>(obj)) {
    return pir::BoolAttribute::get(pir::IrContext::Instance(),
                                   obj.cast<bool>());
  } else if (py::isinstance<py::int_>(obj)) {
    return pir::Int64Attribute::get(pir::IrContext::Instance(),
                                    obj.cast<int64_t>());
  } else if (py::isinstance<py::float_>(obj)) {
    return pir::FloatAttribute::get(pir::IrContext::Instance(),
                                    obj.cast<float>());
  } else if (py::isinstance<py::str>(obj)) {
    return pir::StrAttribute::get(pir::IrContext::Instance(),
                                  obj.cast<std::string>());
  } else if (py::isinstance<py::list>(obj)) {
    auto list = obj.cast<py::list>();
    std::vector<pir::Attribute> attrs;
    for (auto &item : list) {
      attrs.push_back(PyObject2Attribute(item.cast<py::object>()));
    }
    return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs);
  } else {
    PADDLE_THROW(common::errors::InvalidArgument(
        "The type of attr should be bool, int, float, str or list."));
  }
}

void BindProgram(py::module *m) {
  py::class_<Program, std::shared_ptr<Program>> program(
      *m, "Program", py::dynamic_attr(), R"DOC(
@@ -521,6 +544,11 @@ void BindOperation(py::module *m) {
                 }
                 return attrs_dict;
               })
      .def("set_attr",
           [](Operation &self, const std::string &attr_name, py::object attr) {
             auto attr_value = PyObject2Attribute(attr);
             self.set_attribute(attr_name, attr_value);
           })
      .def("set_scheduling_priority",
           [](Operation &self, int64_t priority) {
             self.set_attribute("scheduling_priority",
@@ -638,6 +666,10 @@ void BindOperation(py::module *m) {
"callstack",
[](Operation &self) -> py::list {
py::list callstack_list;
if (!self.HasAttribute(paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName())) {
return callstack_list;
}
pir::Attribute op_callstack = self.attribute<pir::Attribute>(
paddle::framework::OpProtoAndCheckerMaker::
OpCreationCallstackAttrName());
@@ -663,17 +695,11 @@
          },
          [](Operation &self,
             const std::vector<std::string> &callstack) -> void {
            std::vector<pir::Attribute> op_callstack_infos;
            for (auto str : callstack) {
              op_callstack_infos.push_back(
                  pir::StrAttribute::get(pir::IrContext::Instance(), str));
            }
            auto op_callstack_attr = PyObject2Attribute(py::cast(callstack));

            self.set_attribute(
                paddle::framework::OpProtoAndCheckerMaker::
                    OpCreationCallstackAttrName(),
                pir::ArrayAttribute::get(pir::IrContext::Instance(),
                                         op_callstack_infos));
            self.set_attribute(paddle::framework::OpProtoAndCheckerMaker::
                                   OpCreationCallstackAttrName(),
                               op_callstack_attr);
          })
      .def("dist_attr", [](Operation &self) {
        if (self.HasAttribute(kAttrOpDistAttr)) {
@@ -1054,6 +1080,10 @@ struct PyInsertionPoint {
void BindInsertionPoint(pybind11::module *m) {
  py::class_<PyInsertionPoint> ir_insertion_point(*m, "InsertionPoint", R"DOC(
    InsertionPoint class represents the insertion point in the Builder.)DOC");
  ir_insertion_point.def(
      "block",
      [](PyInsertionPoint &self) { return self.value.first; },
      return_value_policy::reference);
}

std::list<Operation *>::const_iterator list_offset(const Block *block,
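The new `block` accessor on `InsertionPoint` is what the dy2st hook in `layers.py` below relies on. A sketch of how it can be exercised from Python, under the same PIR-mode assumption as above:

import paddle

with paddle.pir_utils.IrGuard():
    main_program = paddle.static.Program()
    with paddle.static.program_guard(main_program):
        x = paddle.static.data("x", [10, 16], "float32")
        y = paddle.nn.functional.relu(x)
        # the insertion point currently targets the block receiving new ops
        block = paddle.base.libpaddle.pir.get_current_insertion_point().block()
        print([op.name() for op in block.ops])  # includes 'pd_op.relu'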
16 changes: 13 additions & 3 deletions python/paddle/nn/layer/layers.py
@@ -47,6 +47,7 @@
)
from paddle.base.layer_helper_base import LayerHelperBase
from paddle.base.param_attr import ParamAttr
from paddle.framework import use_pir_api
from paddle.profiler.utils import in_profiler_mode
from paddle.utils import deprecated

@@ -79,16 +80,25 @@ def set_op_customized_attrs_post_hook(layer, inputs, outputs):
"""
if not in_dygraph_mode() and layer._op_recorder.is_valid:
start = layer._op_recorder.start
end = len(default_main_program().current_block().ops)
if use_pir_api():
current_block = (
paddle.base.libpaddle.pir.get_current_insertion_point().block()
)
else:
current_block = default_main_program().current_block()
end = len(current_block.ops)
assert start >= 0 and end >= start
ops = default_main_program().current_block().ops[start:end]
ops = current_block.ops[start:end]

layer._op_recorder.end = end
layer._op_recorder.ops = ops

for op in ops:
for attr_name, val in layer._customized_attrs.items():
op._set_attr(attr_name, val)
if use_pir_api():
op.set_attr(attr_name, val)
else:
op._set_attr(attr_name, val)

# remove pre-hook and post-hook
for hook_helper in layer._op_recorder.hooks:
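In PIR mode the program under construction exposes no `current_block()` bookkeeping, so the hook instead asks the global insertion point which block is currently receiving ops. Condensed, the two paths look like this (illustrative sketch; `layer` stands for any sublayer with a valid op recorder):

import paddle
from paddle.framework import use_pir_api

if use_pir_api():
    # PIR: ops land wherever the current insertion point points
    block = paddle.base.libpaddle.pir.get_current_insertion_point().block()
else:
    # legacy: ops land in the main program's current block
    block = paddle.static.default_main_program().current_block()
recorded_ops = block.ops[layer._op_recorder.start : layer._op_recorder.end]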
177 changes: 118 additions & 59 deletions test/dygraph_to_static/test_op_attr.py
@@ -14,12 +14,32 @@

import unittest

from dygraph_to_static_utils import Dy2StTestBase, test_ast_only
from dygraph_to_static_utils import (
    Dy2StTestBase,
    test_ast_only,
    test_legacy_and_pt_and_pir,
)

import paddle
from paddle.framework import use_pir_api
from paddle.static import InputSpec


def walk(block, fn):
    fn(block)
    for op in block.ops:
        for sub_block in op.blocks():
            walk(sub_block, fn)


def run_on_each_op(block, fn):
    def check_block(block):
        for op in block.ops:
            fn(op)

    walk(block, check_block)


class MySub(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
@@ -28,63 +48,76 @@ def forward(self, x, y, name=None):
        return paddle.subtract(x, y, name)


class NetWithOpAttr(paddle.nn.Layer):
    def __init__(self, in_num, out_num):
        super().__init__()

        self.linear = paddle.nn.Linear(in_num, out_num)
        self.bn = paddle.nn.BatchNorm(out_num)
        self.sub = MySub()
def create_net_with_op_attr_class():
    class NetWithOpAttr(paddle.nn.Layer):
        def __init__(self, in_num, out_num):
            super().__init__()

    def forward(self, x):
        out = self.linear(x)
        out = self.sub(out, x)
        out = self.bn(out)
        return out
            self.linear = paddle.nn.Linear(in_num, out_num)
            self.bn = paddle.nn.BatchNorm(out_num)
            self.sub = MySub()

    @paddle.jit.to_static(input_spec=[InputSpec([10, 16])], full_graph=True)
    def with_cond(self, x):
        if paddle.mean(x) > 0.0:
        def forward(self, x):
            out = self.linear(x)
        else:
            out = self.sub(x, x)
        out = self.bn(out)
        return out
            out = self.sub(out, x)
            out = self.bn(out)
            return out

        @paddle.jit.to_static(input_spec=[InputSpec([10, 16])], full_graph=True)
        def with_cond(self, x):
            if paddle.mean(x) > 0.0:
                out = self.linear(x)
            else:
                out = self.sub(x, x)
            out = self.bn(out)
            return out

    return NetWithOpAttr


class CheckOpAttr(Dy2StTestBase):
    def setUp(self):
        self.in_num = 16
        self.out_num = 16
        self.x = paddle.randn([10, self.in_num])
        self.expected_results()

    def expected_results(self):
        self.fc_attrs = {
        fc_attrs = {
            "int_val": 10,
            "int_vals": [10, 20],
            "float_val": 3.8,
            "float_vals": [3.8, -0.2],
        }
        self.bn_attrs = {"bool_val": True, "bool_vals": [True, False]}
        self.sub_attrs = {"int_vals": [10, 20], "bool_vals": [True, False]}

        self.infos = {
            'matmul': self.fc_attrs,
            'elementwise_add': self.fc_attrs,
            'batch_norm': self.bn_attrs,
            'tanh': self.bn_attrs,
            'elementwise_sub': self.sub_attrs,
        }
        bn_attrs = {"bool_val": True, "bool_vals": [True, False]}
        sub_attrs = {"int_vals": [10, 20], "bool_vals": [True, False]}

        if use_pir_api():
            infos = {
                'pd_op.matmul': fc_attrs,
                'pd_op.add': fc_attrs,
                'pd_op.batch_norm_': bn_attrs,
                'pd_op.subtract': sub_attrs,
            }
        else:
            infos = {
                'matmul': fc_attrs,
                'elementwise_add': fc_attrs,
                'batch_norm': bn_attrs,
                'tanh': bn_attrs,
                'elementwise_sub': sub_attrs,
            }
        return fc_attrs, bn_attrs, sub_attrs, infos

    @test_ast_only
    @test_legacy_and_pt_and_pir
    def test_set_op_attrs(self):
        net = NetWithOpAttr(self.in_num, self.out_num)
        fc_attrs, bn_attrs, sub_attrs, _ = self.expected_results()
        net = create_net_with_op_attr_class()(self.in_num, self.out_num)
        # set attrs
        net.linear._set_op_attrs(self.fc_attrs)
        net.linear._set_op_attrs(fc_attrs)
        net.bn._set_op_attrs({"bool_val": False})  # test overwrite behavior
        net.bn._set_op_attrs(self.bn_attrs)
        net.sub._set_op_attrs(self.sub_attrs)
        net.bn._set_op_attrs(bn_attrs)
        net.sub._set_op_attrs(sub_attrs)
        # assert hooks exist.
        self.assertEqual(len(net.linear._forward_pre_hooks), 1)
        self.assertEqual(len(net.linear._forward_post_hooks), 1)
@@ -101,34 +134,58 @@ def test_set_op_attrs(self):
        self.assertEqual(len(net.linear._forward_post_hooks), 0)

    def check_op_attrs(self, main_program):
        for cur_block in main_program.blocks:
            ops = cur_block.ops
            for op in ops:
                if op.type not in self.infos:
                    continue
                for attr_name, expect_vals in self.infos[op.type].items():
                    op_vals = op.desc.attr(attr_name)
                    if not isinstance(expect_vals, list):
                        expect_vals = [expect_vals]
                        op_vals = [op_vals]

                    for op_val, expect_val in zip(op_vals, expect_vals):
                        if isinstance(op_val, float):
                            # C++ vs python: 3.799999952316284 ~= 3.8
                            self.assertAlmostEqual(op_val, expect_val)
                        else:
                            self.assertEqual(op_val, expect_val)
        _, _, _, infos = self.expected_results()
        if not use_pir_api():
            for cur_block in main_program.blocks:
                ops = cur_block.ops
                for op in ops:
                    if op.type not in infos:
                        continue
                    for attr_name, expect_vals in infos[op.type].items():
                        op_vals = op.desc.attr(attr_name)
                        if not isinstance(expect_vals, list):
                            expect_vals = [expect_vals]
                            op_vals = [op_vals]

                        for op_val, expect_val in zip(op_vals, expect_vals):
                            if isinstance(op_val, float):
                                # C++ vs python: 3.799999952316284 ~= 3.8
                                self.assertAlmostEqual(op_val, expect_val)
                            else:
                                self.assertEqual(op_val, expect_val)
            return
        global_block = main_program.global_block()

        def check_op(op):
            if op.name() not in infos:
                return
            for attr_name, expect_vals in infos[op.name()].items():
                op_vals = op.attrs()[attr_name]
                if not isinstance(expect_vals, list):
                    expect_vals = [expect_vals]
                    op_vals = [op_vals]

                for op_val, expect_val in zip(op_vals, expect_vals):
                    if isinstance(op_val, float):
                        # C++ vs python: 3.799999952316284 ~= 3.8
                        self.assertAlmostEqual(op_val, expect_val)
                    else:
                        self.assertEqual(op_val, expect_val)

        run_on_each_op(global_block, check_op)

    @test_ast_only
    @test_legacy_and_pt_and_pir
    def test_set_op_attrs_with_sub_block(self):
        net = NetWithOpAttr(self.in_num, self.out_num)
        fc_attrs, bn_attrs, sub_attrs, _ = self.expected_results()
        net = create_net_with_op_attr_class()(self.in_num, self.out_num)
        # set attrs
        net.linear._set_op_attrs(
            {"int_vals": [0, 0]}
        )  # test overwrite behavior
        net.linear._set_op_attrs(self.fc_attrs)
        net.bn._set_op_attrs(self.bn_attrs)
        net.sub._set_op_attrs(self.sub_attrs)
        net.linear._set_op_attrs(fc_attrs)
        net.bn._set_op_attrs(bn_attrs)
        net.sub._set_op_attrs(sub_attrs)
        # assert hooks exist.
        self.assertEqual(len(net.linear._forward_pre_hooks), 1)
        self.assertEqual(len(net.linear._forward_post_hooks), 1)
@@ -140,11 +197,13 @@ def test_set_op_attrs_with_sub_block(self):
        self.assertEqual(len(net.linear._forward_pre_hooks), 0)
        self.assertEqual(len(net.linear._forward_post_hooks), 0)

    @test_legacy_and_pt_and_pir
    def test_type_error(self):
        net = NetWithOpAttr(self.in_num, self.out_num)
        fc_attrs, _, _, _ = self.expected_results()
        net = create_net_with_op_attr_class()(self.in_num, self.out_num)
        # attrs should be dict
        with self.assertRaises(TypeError):
            net.linear._set_op_attrs([self.fc_attrs])
            net.linear._set_op_attrs([fc_attrs])


if __name__ == '__main__':
    unittest.main()
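Putting the pieces together, a condensed PIR-mode run in the style of these tests (an illustrative sketch; it assumes the PIR dy2st mode that `@test_legacy_and_pt_and_pir` switches on):

net = create_net_with_op_attr_class()(16, 16)
net.linear._set_op_attrs({"int_val": 10})  # installs forward pre/post hooks
net.with_cond(paddle.randn([10, 16]))      # dy2st tracing fires the hooks
# every op recorded for net.linear (e.g. pd_op.matmul, pd_op.add) now
# reports op.attrs()["int_val"] == 10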
