Unification of BF16 enablement process #31034

Merged · 3 commits · Feb 23, 2021
91 changes: 47 additions & 44 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1829,9 +1829,8 @@ PDNode *patterns::OpDequant::operator()() {
auto any_op = pattern->NewNode(any_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return (node->Op()->Type() == "matmul" ||
node->Op()->Type() == "conv2d" ||
node->Op()->Type() == "fc");
return (node->Op()->HasAttr("force_fp32_output") ||
node->Op()->HasProtoAttr("force_fp32_output"));
});
auto dequant_in = pattern->NewNode(dequant_in_repr())
->assert_is_op_input("dequantize", "Input");
@@ -1865,6 +1864,44 @@ PDNode *patterns::DequantScale::operator()() {
return scale_out;
}

PDNode *patterns::ScaleQuant::operator()() {
auto scale_in = pattern->NewNode(scale_in_repr())
->AsInput()
->assert_is_op_input("scale", "X");
auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");

auto quant_in = pattern->NewNode(quant_in_repr())
->AsInput()
->assert_is_op_input("quantize", "Input");
auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");

scale_op->LinksFrom({scale_in}).LinksTo({quant_in});
quant_op->LinksFrom({quant_in});

return quant_op;
}

PDNode *patterns::QuantConv::operator()() {
auto quant_in = pattern->NewNode(quant_in_repr())
->AsInput()
->assert_is_op_input("quantize", "Input");
auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");

auto conv_in = pattern->NewNode(conv_in_repr())
->AsInput()
->assert_is_op_input("conv2d", "Input");
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
conv_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});

quant_op->LinksFrom({quant_in}).LinksTo({conv_in});
conv_op->LinksFrom({conv_in});

return quant_op;
}

PDNode *patterns::ScaleMatmul::operator()() {
auto scale_in = pattern->NewNode(scale_in_repr())
->AsInput()
@@ -2191,10 +2228,11 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "conv2d_transpose", "elementwise_add",
"elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"});
std::unordered_set<std::string>({"concat", "conv2d", "conv2d_transpose",
"elementwise_add", "elementwise_mul",
"fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "relu", "reshape2",
"softmax", "sum", "transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
@@ -2240,33 +2278,19 @@ PDNode *patterns::LastBfloat16Ops::operator()() {
"bfloat16";
});
auto *op_out = pattern->NewNode(op_out_repr())->AsOutput();

auto *next_op = pattern->NewNode(next_op_repr())->assert_is_op();
next_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});

op->LinksTo({op_out});
next_op->LinksFrom({op_out});
return next_op;
return op_out;
}

PDNode *patterns::FirstBfloat16Ops::operator()() {
auto *prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") !=
"bfloat16";
});
auto *op_in = pattern->NewNode(op_in_repr())->AsOutput();
auto *op_in = pattern->NewNode(op_in_repr())->AsInput();

auto *op = pattern->NewNode(op_repr())->assert_is_op();
op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});

prev_op->LinksTo({op_in});
op->LinksFrom({op_in});
return op;
}
@@ -2280,27 +2304,6 @@ PDNode *patterns::DuplicatedInputs::operator()() {
return op;
}

PDNode *patterns::UnnecessaryReorders::operator()() {
auto prev_op = pattern->NewNode(prev_op_repr())->assert_is_op();
prev_op->assert_more([&](Node *node) {
return node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
"bfloat16";
});

auto *quant_in = pattern->NewNode(quant_in_repr())
->assert_is_op_input("quantize", "Input");

auto *quant_op = pattern->NewNode(quant_op_repr())->assert_is_op("quantize");

auto *quant_out = pattern->NewNode(quant_out_repr())
->assert_is_op_output("quantize", "Output");

prev_op->LinksTo({quant_in});
quant_op->LinksFrom({quant_in}).LinksTo({quant_out});

return quant_out;
}

PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs",
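The new ScaleQuant and QuantConv patterns above are consumed the same way as the other patterns in this file: a pass builds the pattern on a GraphPatternDetector and registers a handler that receives every matched subgraph. The sketch below is illustrative only — the function name FuseQuantConvExample and the handler body are assumptions, not code from this PR — and simply shows how the nodes declared for QuantConv would typically be retrieved inside a pass.

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical example: match the quantize -> conv2d(bfloat16) pairs
// found by patterns::QuantConv and log them.
void FuseQuantConvExample(Graph* graph) {
  GraphPatternDetector gpd;
  patterns::QuantConv quant_conv{gpd.mutable_pattern(), "quant_conv_example"};
  quant_conv();  // build the pattern defined in graph_pattern_detector.cc

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // The macro looks up the nodes declared with PATTERN_DECL_NODE.
    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_conv);
    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, quant_conv);
    VLOG(4) << "matched " << quant_op->Op()->Type() << " -> "
            << conv_op->Op()->Type();
  };
  gpd(graph, handler);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle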
40 changes: 26 additions & 14 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1135,11 +1135,36 @@ struct DequantScale : public PatternBase {

PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out);

PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(scale_out);
};

// Scale + Quantize
struct ScaleQuant : public PatternBase {
ScaleQuant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "scale_quant") {}

PDNode* operator()();

PATTERN_DECL_NODE(scale_in);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
};

// Quantize + Conv2d
struct QuantConv : public PatternBase {
QuantConv(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_conv") {}

PDNode* operator()();

PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(conv_in);
PATTERN_DECL_NODE(conv_op);
};

// Scale + Matmul
struct ScaleMatmul : public PatternBase {
ScaleMatmul(PDPattern* pattern, const std::string& name_scope)
@@ -1338,15 +1363,13 @@ struct LastBfloat16Ops : public PatternBase {

PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(op_out);
PATTERN_DECL_NODE(next_op);
};

struct FirstBfloat16Ops : public PatternBase {
FirstBfloat16Ops(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "first_bfloat16_ops") {}
PDNode* operator()();

PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(op_in);
PATTERN_DECL_NODE(op);
};
@@ -1360,17 +1383,6 @@ struct DuplicatedInputs : public PatternBase {
PATTERN_DECL_NODE(op);
};

struct UnnecessaryReorders : public PatternBase {
UnnecessaryReorders(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "unnecessary_reorders") {}
PDNode* operator()();

PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out);
};

// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
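A recurring check in the patterns above — including the simplified FirstBfloat16Ops and LastBfloat16Ops, which now match only the boundary operator instead of its neighbour — is whether an operator is marked for bfloat16 execution via its mkldnn_data_type attribute. A minimal standalone helper with that check could look as follows; the name IsBfloat16Op is hypothetical, since the patterns in this PR inline the check in assert_more lambdas as shown above.

#include "paddle/fluid/framework/ir/node.h"

namespace paddle {
namespace framework {
namespace ir {

// Hypothetical helper mirroring the assert_more checks above: an op node
// runs under bfloat16 if its "mkldnn_data_type" attribute equals "bfloat16".
inline bool IsBfloat16Op(Node* node) {
  return node->IsOp() && node->Op() != nullptr &&
         node->Op()->GetAttrIfExists<std::string>("mkldnn_data_type") ==
             "bfloat16";
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle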