fuse vit attention for faster-rcnn on BML #54139

Merged
merged 1 commit on Jun 8, 2023
7 changes: 7 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -200,6 +200,13 @@ if(WITH_MKLDNN)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
pass_library(quant_dequant_mkldnn_pass inference DIR mkldnn)
pass_library(compute_propagate_scales_mkldnn_pass inference DIR mkldnn)
pass_library(self_attention_fuse_pass inference DIR mkldnn)
if(WITH_AVX
AND AVX512F_FOUND
AND AVX512F_FLAG)
set_target_properties(self_attention_fuse_pass
PROPERTIES COMPILE_FLAGS "-mfma ${AVX512F_FLAG}")
endif()
endif()

if(WITH_IPU)
75 changes: 75 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2615,6 +2615,81 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) {
return reshape2_out;
}

PDNode *patterns::SelfAttention::operator()(PDNode *in) {
in->AsInput();

std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto transpose2_0_op =
pattern->NewNode(transpose2_0_op_repr())->assert_is_op("transpose2");
auto transpose2_0_out = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("slice", "Input")
->AsIntermediate();
auto slice_0_op = pattern->NewNode(slice_0_op_repr())->assert_is_op("slice");
auto slice_0_out = pattern->NewNode(slice_0_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_ops_input(matmul_ops, "X")
->AsIntermediate();
auto slice_1_op = pattern->NewNode(slice_1_op_repr())->assert_is_op("slice");
auto slice_1_out = pattern->NewNode(slice_1_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto slice_2_op = pattern->NewNode(slice_2_op_repr())->assert_is_op("slice");
auto slice_2_out = pattern->NewNode(slice_2_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_ops_input(matmul_ops, "Y")
->AsIntermediate();
auto matmul_0_op =
pattern->NewNode(matmul_0_op_repr())->assert_is_ops(matmul_ops);
auto matmul_0_out = pattern->NewNode(matmul_0_out_repr())
->assert_is_ops_output(matmul_ops, "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto matmul_1_op =
pattern->NewNode(matmul_1_op_repr())->assert_is_ops(matmul_ops);
auto matmul_1_out = pattern->NewNode(matmul_1_out_repr())
->assert_is_ops_output(matmul_ops, "Out")
->assert_is_op_input("softmax", "X")
->AsIntermediate();
auto transpose2_1_op =
pattern->NewNode(transpose2_1_op_repr())->assert_is_op("transpose2");
auto transpose2_1_out = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_ops_input(matmul_ops, "Y")
->AsIntermediate();
auto softmax_op =
pattern->NewNode(softmax_op_repr())->assert_is_op("softmax");
auto softmax_out = pattern->NewNode(softmax_out_repr())
->assert_is_op_output("softmax", "Out")
->assert_is_ops_input(matmul_ops, "X")
->AsIntermediate();
auto transpose2_2_op =
pattern->NewNode(transpose2_2_op_repr())->assert_is_op("transpose2");
auto transpose2_2_out = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2", "Out")
->AsOutput();
transpose2_0_op->LinksFrom({in});
transpose2_0_out->LinksFrom({transpose2_0_op});
slice_0_op->LinksFrom({transpose2_0_out});
slice_0_out->LinksFrom({slice_0_op});
slice_1_op->LinksFrom({transpose2_0_out});
slice_1_out->LinksFrom({slice_1_op});
slice_2_op->LinksFrom({transpose2_0_out});
slice_2_out->LinksFrom({slice_2_op});
transpose2_1_op->LinksFrom({slice_1_out});
transpose2_1_out->LinksFrom({transpose2_1_op});
matmul_1_op->LinksFrom({slice_0_out, transpose2_1_out});
matmul_1_out->LinksFrom({matmul_1_op});
softmax_op->LinksFrom({matmul_1_out});
softmax_out->LinksFrom({softmax_op});
matmul_0_op->LinksFrom({softmax_out, slice_2_out});
matmul_0_out->LinksFrom({matmul_0_op});
transpose2_2_op->LinksFrom({matmul_0_out});
transpose2_2_out->LinksFrom({transpose2_2_op});
return transpose2_2_out;
}

PDNode *patterns::ConvElementwiseadd2Act::operator()(
PDNode *conv_in, const std::unordered_set<std::string> &conv_act_set) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
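For orientation, the subgraph matched by patterns::SelfAttention above corresponds to the scaled dot-product attention block of a ViT layer operating on a fused-QKV tensor. A minimal sketch of the equivalent computation follows (Python/PaddlePaddle dygraph; the transpose perms and the mapping of slice_0/1/2 to Q/K/V are assumptions based on common ViT implementations, since the pattern itself only constrains op types and their links):

import paddle

def vit_self_attention(x, alpha=1.0):
    # x: [batch_size, seq_len, 3, num_heads, head_size] (fused QKV projection).
    qkv = paddle.transpose(x, perm=[2, 0, 3, 1, 4])  # transpose2_0
    q, k, v = qkv[0], qkv[1], qkv[2]                 # slice_0 / slice_1 / slice_2
    k_t = paddle.transpose(k, perm=[0, 1, 3, 2])     # transpose2_1
    scores = alpha * paddle.matmul(q, k_t)           # matmul_1 (optional alpha attr)
    weights = paddle.nn.functional.softmax(scores)   # softmax
    out = paddle.matmul(weights, v)                  # matmul_0
    return paddle.transpose(out, perm=[0, 2, 1, 3])  # transpose2_2

After the pass runs, this whole chain is replaced by a single self_dp_attention op that takes x as input and produces the final transpose output directly.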
27 changes: 27 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1491,6 +1491,33 @@ struct VitAttention : public PatternBase {
PATTERN_DECL_NODE(reshape2_out);
};

// Self-attention block in ViT (Vision Transformer); the PATTERN_DECL_NODE
// entries below generate the *_repr() accessors used when the pattern is
// built in graph_pattern_detector.cc.
struct SelfAttention : public PatternBase {
SelfAttention(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "vit_block") {}

PDNode* operator()(PDNode* in);

PATTERN_DECL_NODE(transpose2_0_op);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_op);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_op);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(matmul_0_op);
PATTERN_DECL_NODE(matmul_0_out);
PATTERN_DECL_NODE(matmul_1_op);
PATTERN_DECL_NODE(matmul_1_out);
PATTERN_DECL_NODE(slice_0_op);
PATTERN_DECL_NODE(slice_0_out);
PATTERN_DECL_NODE(slice_1_op);
PATTERN_DECL_NODE(slice_1_out);
PATTERN_DECL_NODE(slice_2_op);
PATTERN_DECL_NODE(slice_2_out);
PATTERN_DECL_NODE(softmax_op);
PATTERN_DECL_NODE(softmax_out);
};

// Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase {
150 changes: 150 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc
@@ -0,0 +1,150 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h"

#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

// Bind each matched pattern node to a local ir::Node* of the same name.
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(transpose2_0_op); \
GET_IR_NODE(transpose2_0_out); \
GET_IR_NODE(slice_0_op); \
GET_IR_NODE(slice_0_out); \
GET_IR_NODE(slice_1_op); \
GET_IR_NODE(slice_1_out); \
GET_IR_NODE(slice_2_op); \
GET_IR_NODE(slice_2_out); \
GET_IR_NODE(matmul_0_op); \
GET_IR_NODE(matmul_0_out); \
GET_IR_NODE(matmul_1_op); \
GET_IR_NODE(matmul_1_out); \
GET_IR_NODE(transpose2_1_op); \
GET_IR_NODE(transpose2_1_out); \
GET_IR_NODE(softmax_op); \
GET_IR_NODE(softmax_out); \
GET_IR_NODE(transpose2_2_op); \
GET_IR_NODE(transpose2_2_out);

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void SelfAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
#if !defined(__AVX512F__) || !defined(PADDLE_WITH_MKLML) || \
!defined(PADDLE_WITH_MKLDNN)
LOG(WARNING) << "AVX512F, MKLML or oneDNN is not available; skipping "
                "self_attention_fuse_pass.";
return;
#endif
// Build the pattern to match.
GraphPatternDetector gpd;
const std::string pattern_name = "self_attention_fuse";
FusePassBase::Init(pattern_name, graph);

// pattern
PDNode* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("transpose2", "X")
->AsInput();
patterns::SelfAttention pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);

int fusion_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
// Replace the matched subgraph with a single self_dp_attention op.
OpDesc desc(transpose2_0_op->Op()->Block());
desc.SetType("self_dp_attention");
desc.SetInput("X", {subgraph.at(x)->Name()});
desc.SetOutput("Out", {transpose2_2_out->Name()});

std::vector<int64_t> in_shape = subgraph.at(x)->Var()->GetShape();
std::vector<int64_t> shape = transpose2_0_out->Var()->GetShape();
// Input shape should be [batch_size, seq_len, 3, num_heads, head_size];
// transpose2_0 output shape should be [3, batch_size, num_heads, seq_len,
// head_size], so shape[2] is the number of heads.
if (in_shape.size() != 5 || in_shape[2] != 3 || shape.size() != 5 ||
shape[0] != 3 || shape[2] != in_shape[3]) {
LOG(WARNING) << "Self-attention shape mismatch!";
return;
}
desc.SetAttr("head_number", static_cast<int>(shape[2]));
// The legacy matmul op carries a scaling attribute "alpha"; matmul_v2 does
// not, so default to 1.0 when the attribute is absent.
float alpha = 1.0;
if (matmul_1_op->Op()->HasAttr("alpha"))
alpha = PADDLE_GET_CONST(float, matmul_1_op->Op()->GetAttr("alpha"));
desc.SetAttr("alpha", alpha);

// Create a new node for the fused op.
auto self_attention_node = graph->CreateOpNode(&desc);

// Link inputs and outputs.
PADDLE_ENFORCE_NE(subgraph.count(x),
0,
platform::errors::NotFound(
"Detector did not find input x of self attention."));

IR_NODE_LINK_TO(subgraph.at(x), self_attention_node); // Input
IR_NODE_LINK_TO(self_attention_node, transpose2_2_out); // Output

// Delete the unneeded nodes.
std::unordered_set<const Node*> marked_nodes({transpose2_0_op,
transpose2_0_out,
slice_0_op,
slice_0_out,
slice_1_op,
slice_1_out,
slice_2_op,
slice_2_out,
matmul_0_op,
matmul_0_out,
matmul_1_op,
matmul_1_out,
transpose2_1_op,
transpose2_1_out,
softmax_op,
softmax_out,
transpose2_2_op});

GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
AddStatis(fusion_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
PrettyLogDetail(
"--- fused %d self attention (of scaled_dp_attention) with %s",
fusion_count,
pattern_name);
}
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(self_attention_fuse_pass,
paddle::framework::ir::SelfAttentionFusePass);
REGISTER_PASS_CAPABILITY(self_attention_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("transpose2", 0)
.EQ("slice", 0)
.EQ("scale", 0)
.EQ("softmax", 0)
.EQ("matmul_v2", 0));
41 changes: 41 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h
@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Fuses the self-attention structure in ViT-style models.

class Graph;

class SelfAttentionFusePass : public FusePassBase {
public:
virtual ~SelfAttentionFusePass() {}

protected:
void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -367,6 +367,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"fc_mkldnn_pass",
"fc_act_mkldnn_fuse_pass",
"fc_elementwise_add_mkldnn_fuse_pass", //
"self_attention_fuse_pass", //
"batch_norm_act_fuse_pass", //
"softplus_activation_onednn_fuse_pass", //
"shuffle_channel_mkldnn_detect_pass", //
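Since the pass is appended to the CpuPassStrategy MKLDNN list, it should be triggered by enabling oneDNN in the inference config. A minimal usage sketch (Python inference API; the model paths are placeholders):

import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")  # placeholder paths
config.enable_mkldnn()        # select the CPU MKLDNN pass list
config.switch_ir_optim(True)  # run IR fuse passes, incl. self_attention_fuse_pass
predictor = paddle_infer.create_predictor(config)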
9 changes: 9 additions & 0 deletions paddle/fluid/operators/fused/CMakeLists.txt
@@ -11,6 +11,7 @@ register_operators(
fusion_conv_inception_op
fused_fc_elementwise_layernorm_op
multihead_matmul_op
self_dp_attention_op
skip_layernorm_op
yolo_box_head_op
yolo_box_post_op
@@ -33,6 +34,14 @@ register_operators(
# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
op_library(fusion_lstm_op)
if(WITH_AVX
AND AVX512F_FOUND
AND AVX512F_FLAG
AND WITH_MKL)
op_library(self_dp_attention_op)
set_target_properties(self_dp_attention_op PROPERTIES COMPILE_FLAGS
"-mfma ${AVX512F_FLAG}")
endif()

if(WITH_XPU)
op_library(resnet_basic_block_op)