Skip to content

Commit

Permalink
Merge pull request #16526 from NHZlX/refine_trt_anakin
Browse files Browse the repository at this point in the history
refine subgraph trt and anakin
  • Loading branch information
NHZlX authored Mar 29, 2019
2 parents e014950 + 7cde2d9 commit 3e6aa49
Show file tree
Hide file tree
Showing 21 changed files with 435 additions and 142 deletions.
19 changes: 5 additions & 14 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,12 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base)
pass_library(simplify_anakin_detection_pattern_pass inference)
pass_library(anakin_fillconstant_elementwisemul_fuse inference)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference)

# There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will
# be detected by our pass. The index here represents the number of structures in the
# pattern. We use index 3 ~ 6, because these quantities of structures are
# common in the models.
foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
endforeach()

foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
endforeach()
if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference)
endif()

if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include <memory>
#include <string>

#include "paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h"
#include "paddle/fluid/framework/ir/fillconstant_elementwisemul_fuse.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"

namespace paddle {
Expand All @@ -29,8 +29,8 @@ namespace ir {
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out);

void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
void FillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "fillconstant_elementwisemul_fuse";
FusePassBase::Init(pattern_name, graph);

GraphPatternDetector gpd;
Expand All @@ -39,8 +39,8 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
->assert_is_op_input("elementwise_mul", "X")
->AsInput();

patterns::AnakinFillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(),
pattern_name);
patterns::FillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(),
pattern_name);
pattern(x);

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Expand Down Expand Up @@ -79,5 +79,5 @@ void AnakinFillconstantElementwisemulFuse::ApplyImpl(ir::Graph* graph) const {
} // namespace framework
} // namespace paddle

REGISTER_PASS(anakin_fillconstant_elementwisemul_fuse,
paddle::framework::ir::AnakinFillconstantElementwisemulFuse);
REGISTER_PASS(fillconstant_elementwisemul_fuse,
paddle::framework::ir::FillconstantElementwisemulFuse);
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ namespace paddle {
namespace framework {
namespace ir {

class AnakinFillconstantElementwisemulFuse : public FusePassBase {
class FillconstantElementwisemulFuse : public FusePassBase {
public:
virtual ~AnakinFillconstantElementwisemulFuse() {}
virtual ~FillconstantElementwisemulFuse() {}

protected:
void ApplyImpl(ir::Graph* graph) const override;
Expand Down
94 changes: 83 additions & 11 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1471,7 +1471,8 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
}

PDNode *patterns::AnakinDetectionPattern::operator()(
std::vector<PDNode *> conv_in, int times) {
std::vector<PDNode *> conv_in, int times, std::string priorbox_type,
bool is_reshape) {
// The times represents the repeat times of the
// {prior_box, prior_box_loc_out, flatten, prior_box_var_out, reshape}
const int kNumFields = 7;
Expand All @@ -1486,37 +1487,38 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
const int kMultiClassSecondInputNmsOffset = times + 1;

std::vector<PDNode *> nodes;
std::string op_after_priorbox = is_reshape ? "reshape2" : "flatten2";

for (int i = 0; i < times; i++) {
nodes.push_back(
pattern->NewNode(GetNodeName("prior_box" + std::to_string(i)))
->assert_is_op("density_prior_box"));
->assert_is_op(priorbox_type));
nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i)))
->assert_is_op_output("density_prior_box", "Boxes")
->assert_is_op_input("reshape2", "X")
->assert_is_op_output(priorbox_type, "Boxes")
->assert_is_op_input(op_after_priorbox, "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("reshape1" + std::to_string(i)))
->assert_is_op("reshape2"));
->assert_is_op(op_after_priorbox));

nodes.push_back(
pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i)))
->assert_is_op_output("reshape2")
->assert_is_op_output(op_after_priorbox)
->assert_is_op_nth_input("concat", "X", i)
->AsIntermediate());

nodes.push_back(
pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i)))
->assert_is_op_output("density_prior_box", "Variances")
->assert_is_op_input("reshape2", "X")
->assert_is_op_output(priorbox_type, "Variances")
->assert_is_op_input(op_after_priorbox, "X")
->AsIntermediate());
nodes.push_back(
pattern->NewNode(GetNodeName("reshape2" + std::to_string(i)))
->assert_is_op("reshape2"));
->assert_is_op(op_after_priorbox));

nodes.push_back(
pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i)))
->assert_is_op_output("reshape2")
->assert_is_op_output(op_after_priorbox)
->assert_is_op_nth_input("concat", "X", i)
->AsIntermediate());
}
Expand Down Expand Up @@ -1612,7 +1614,7 @@ PDNode *patterns::AnakinDetectionPattern::operator()(
return multiclass_nms_out;
}

PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()(
PDNode *patterns::FillConstantElementWiseMulFuse::operator()(
PDNode *elementwise_op_input) {
auto fill_constant =
pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant");
Expand All @@ -1635,6 +1637,76 @@ PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()(
return elementwise_mul_out;
}

void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
const std::string &op_type,
const std::string &weight_name,
int times) {
const int kNumFields = 5;
const int kQuantizedWeightOffset = 0;
const int kQuantizedOpOffset = 1;
const int kQuantizedOpOutOffset = 2;
const int kDequantOpOffset = 3;
const int kDequantOpOutOffset = 4;
// the quant op always be one.
auto quant_op_in_scale =
pattern->NewNode(GetNodeName("quant_op_in_scale"))
->assert_is_op_input("fake_quantize_range_abs_max", "InScale")
->AsInput();
auto quant_op = pattern->NewNode(GetNodeName("quant_op"))
->assert_is_op("fake_quantize_range_abs_max");

auto quant_op_out_scale =
pattern->NewNode(GetNodeName("quant_op_out_scale"))
->assert_is_op_output("fake_quantize_range_abs_max", "OutScale")
->assert_is_op_input("fake_dequantize_max_abs", "Scale")
->AsIntermediate();

auto quant_op_out =
pattern->NewNode(GetNodeName("quant_op_out"))
->assert_is_op_output("fake_quantize_range_abs_max", "Out")
->assert_is_op_input(op_type)
->AsIntermediate();

// there are 'times' quantized and dequant op
std::vector<PDNode *> nodes;
for (int i = 0; i < times; i++) {
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_weight") + std::to_string(i))
->assert_is_op_input(op_type, weight_name)
->AsInput());
nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op") + std::to_string(i))
->assert_is_op(op_type));

nodes.push_back(
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
->assert_is_op_output(op_type)
->assert_is_op_input("fake_dequantize_max_abs", "X")
->AsIntermediate());

nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
->assert_is_op("fake_dequantize_max_abs"));
nodes.push_back(
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
->assert_is_op_output("fake_dequantize_max_abs", "Out")
->AsOutput());
}

quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
quant_op_out->LinksFrom({quant_op});
for (int i = 0; i < times; i++) {
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOffset]});
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
{nodes[i * kNumFields + kDequantOpOffset]});
}
}

} // namespace ir
} // namespace framework
} // namespace paddle
25 changes: 21 additions & 4 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,8 @@ struct AnakinDetectionPattern : public PatternBase {
AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "anakin_detect_pattern") {}

PDNode* operator()(std::vector<PDNode*> conv_inputs, int times);
PDNode* operator()(std::vector<PDNode*> conv_inputs, int times,
std::string priorbox_type, bool is_reshape);

std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
Expand All @@ -859,9 +860,9 @@ struct AnakinDetectionPattern : public PatternBase {
}
};

struct AnakinFillConstantElementWiseMulFuse : public PatternBase {
AnakinFillConstantElementWiseMulFuse(PDPattern* pattern,
const std::string& name_scope)
struct FillConstantElementWiseMulFuse : public PatternBase {
FillConstantElementWiseMulFuse(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"anakin_fillconstant_elementwisemul_fuse") {}

Expand All @@ -874,6 +875,22 @@ struct AnakinFillConstantElementWiseMulFuse : public PatternBase {
PATTERN_DECL_NODE(elementwise_mul_out);
};

struct QuantDequantOpFuse : public PatternBase {
QuantDequantOpFuse(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}

void operator()(PDNode* quant_op_input, const std::string& op_name,
const std::string& weight_name, int times = 1);

std::string GetNodeName(const std::string& op_type) {
return PDNodeName(name_scope_, repr_, id_, op_type);
}

PDNode* GetPDNode(const std::string& op_type) {
return pattern->RetrieveNode(GetNodeName(op_type));
}
};

} // namespace patterns

// Link two ir::Nodes from each other.
Expand Down
Loading

0 comments on commit 3e6aa49

Please sign in to comment.