From 7a862b4987deeeaa354e7a92faef5da3eea98760 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Tue, 22 Aug 2023 12:15:43 +0000 Subject: [PATCH 1/3] use ir::get_type_name --- paddle/fluid/ir/drr/api/drr_pattern_base.h | 6 +++--- test/cpp/ir/pattern_rewrite/drr_test.cc | 15 ++++++++++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/ir/drr/api/drr_pattern_base.h b/paddle/fluid/ir/drr/api/drr_pattern_base.h index b6a742b9ffedd..f6dd0075be811 100644 --- a/paddle/fluid/ir/drr/api/drr_pattern_base.h +++ b/paddle/fluid/ir/drr/api/drr_pattern_base.h @@ -14,14 +14,14 @@ #pragma once -#include - #include "paddle/fluid/ir/drr/api/drr_pattern_context.h" #include "paddle/fluid/ir/drr/drr_rewrite_pattern.h" +#include "paddle/ir/core/type_name.h" namespace ir { namespace drr { +template class DrrPatternBase { public: virtual ~DrrPatternBase() = default; @@ -34,7 +34,7 @@ class DrrPatternBase { DrrPatternContext drr_context; this->operator()(&drr_context); return std::make_unique( - typeid(*this).name(), drr_context, ir_context, benefit); + ir::get_type_name(), drr_context, ir_context, benefit); } }; diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc index 6f610e8ae6c9f..b3803f93df0a2 100644 --- a/test/cpp/ir/pattern_rewrite/drr_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_test.cc @@ -24,7 +24,8 @@ #include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h" #include "paddle/ir/transforms/dead_code_elimination_pass.h" -class RemoveRedundentReshapePattern : public ir::drr::DrrPatternBase { +class RemoveRedundentReshapePattern + : public ir::drr::DrrPatternBase { public: void operator()(ir::drr::DrrPatternContext *ctx) const override { // Source patterns:待匹配的子图 @@ -44,7 +45,8 @@ class RemoveRedundentReshapePattern : public ir::drr::DrrPatternBase { } }; -class FoldBroadcastToConstantPattern : public ir::drr::DrrPatternBase { +class FoldBroadcastToConstantPattern + : public ir::drr::DrrPatternBase { public: void operator()(ir::drr::DrrPatternContext *ctx) const override { ir::drr::SourcePattern pat = ctx->SourcePattern(); @@ -71,7 +73,8 @@ class FoldBroadcastToConstantPattern : public ir::drr::DrrPatternBase { } }; -class RemoveRedundentTransposePattern : public ir::drr::DrrPatternBase { +class RemoveRedundentTransposePattern + : public ir::drr::DrrPatternBase { public: void operator()(ir::drr::DrrPatternContext *ctx) const override { // Source pattern: 待匹配的子图 @@ -93,7 +96,8 @@ class RemoveRedundentTransposePattern : public ir::drr::DrrPatternBase { } }; -class RemoveRedundentCastPattern : public ir::drr::DrrPatternBase { +class RemoveRedundentCastPattern + : public ir::drr::DrrPatternBase { void operator()(ir::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("tmp") = @@ -106,7 +110,8 @@ class RemoveRedundentCastPattern : public ir::drr::DrrPatternBase { } }; -class RemoveUselessCastPattern : public ir::drr::DrrPatternBase { +class RemoveUselessCastPattern + : public ir::drr::DrrPatternBase { public: void operator()(ir::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); From a1d8c7e7680fe99ccd32716b6f6b028c2eb7287b Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 23 Aug 2023 07:10:02 +0000 Subject: [PATCH 2/3] support compute attrbute in drr pattern --- paddle/fluid/ir/drr/api/drr_pattern_context.h | 35 ++++++++++++++++--- paddle/fluid/ir/drr/ir_operation_creator.cc | 20 +++++++++-- paddle/fluid/ir/drr/match_context_impl.h | 11 ++++-- test/cpp/ir/pattern_rewrite/drr_test.cc | 13 +++++-- 4 files changed, 69 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/ir/drr/api/drr_pattern_context.h b/paddle/fluid/ir/drr/api/drr_pattern_context.h index 45be4b3963a7b..fe9aec52528db 100644 --- a/paddle/fluid/ir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/ir/drr/api/drr_pattern_context.h @@ -14,11 +14,13 @@ #pragma once +#include #include #include #include #include #include +#include #include "paddle/fluid/ir/drr/api/match_context.h" @@ -34,9 +36,9 @@ class PatternGraph; class SourcePatternGraph; class ResultPatternGraph; -class Attribute { +class NormalAttribute { public: - explicit Attribute(const std::string& name) : attr_name_(name) {} + explicit NormalAttribute(const std::string& name) : attr_name_(name) {} const std::string& name() const { return attr_name_; } @@ -44,6 +46,23 @@ class Attribute { std::string attr_name_; }; +class ComputeAttribute { + public: + explicit ComputeAttribute( + const std::function& attr_compute_func) + : attr_compute_func_(attr_compute_func) {} + + const std::function& attr_compute_func() + const { + return attr_compute_func_; + } + + private: + std::function attr_compute_func_; +}; + +using Attribute = std::variant; + class TensorShape { public: explicit TensorShape(const std::string& tensor_name) @@ -245,7 +264,13 @@ class ResultPattern { return ctx_->ResultTensorPattern(name); } - Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); } + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } + Attribute Attr( + std::function attr_compute_func) const { + return ComputeAttribute(attr_compute_func); + } private: friend class SourcePattern; @@ -269,7 +294,9 @@ class SourcePattern { return ctx_->SourceTensorPattern(name); } - Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); } + Attribute Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); + } void RequireEqual(const TensorShape& first, const TensorShape& second) { ctx_->RequireEqual(first, second); diff --git a/paddle/fluid/ir/drr/ir_operation_creator.cc b/paddle/fluid/ir/drr/ir_operation_creator.cc index 729a47cc3a691..15a9bb2df5083 100644 --- a/paddle/fluid/ir/drr/ir_operation_creator.cc +++ b/paddle/fluid/ir/drr/ir_operation_creator.cc @@ -39,7 +39,14 @@ ir::AttributeMap CreateAttributeMap(const OpCall& op_call, const MatchContextImpl& src_match_ctx) { ir::AttributeMap attr_map; for (const auto& kv : op_call.attributes()) { - attr_map[kv.first] = src_match_ctx.GetIrAttr(kv.second.name()); + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + attr_map[kv.first] = src_match_ctx.GetIrAttr(arg.name()); + } + }, + kv.second); } return attr_map; } @@ -48,7 +55,16 @@ template T GetAttr(const std::string& attr_name, const OpCall& op_call, const MatchContextImpl& src_match_ctx) { - return src_match_ctx.Attr(op_call.attributes().at(attr_name).name()); + const auto& attr = op_call.attributes().at(attr_name); + if (std::holds_alternative(attr)) { + return src_match_ctx.Attr(std::get(attr).name()); + } else if (std::holds_alternative(attr)) { + MatchContext ctx(std::make_shared(src_match_ctx)); + return std::any_cast( + std::get(attr).attr_compute_func()(ctx)); + } else { + IR_THROW("Unknown attrbute type for : %s.", attr_name); + } } Operation* CreateOperation(const OpCall& op_call, diff --git a/paddle/fluid/ir/drr/match_context_impl.h b/paddle/fluid/ir/drr/match_context_impl.h index 6a184e6d45527..bc2ccae99e9f2 100644 --- a/paddle/fluid/ir/drr/match_context_impl.h +++ b/paddle/fluid/ir/drr/match_context_impl.h @@ -107,15 +107,22 @@ class MatchContextImpl final { operation_map_.emplace(op_call, op); const auto& attrs = op_call->attributes(); for (const auto& kv : attrs) { - BindIrAttr(kv.second.name(), op->get()->attribute(kv.first)); + std::visit( + [&](auto&& arg) { + if constexpr (std::is_same_v, + NormalAttribute>) { + BindIrAttr(arg.name(), op->get()->attribute(kv.first)); + } + }, + kv.second); } } + private: void BindIrAttr(const std::string& attr_name, ir::Attribute attr) { attr_map_.emplace(attr_name, attr); } - private: std::unordered_map> tensor_map_; std::unordered_map> operation_map_; diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc index b3803f93df0a2..19e05b9378d92 100644 --- a/test/cpp/ir/pattern_rewrite/drr_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_test.cc @@ -88,9 +88,18 @@ class RemoveRedundentTransposePattern // Result patterns: 要替换的子图 ir::drr::ResultPattern res = pat.ResultPattern(); - // todo 先简单用perm2替换 + const auto &new_perm_attr = + res.Attr([](const ir::drr::MatchContext &match_ctx) -> std::any { + const auto &perm1 = match_ctx.Attr>("perm_1"); + const auto &perm2 = match_ctx.Attr>("perm_2"); + std::vector new_perm; + for (int v : perm2) { + new_perm.emplace_back(perm1[v]); + } + return new_perm; + }); const auto &tranpose_continuous = - res.Op("pd.transpose", {{"perm", pat.Attr("perm_2")}}); + res.Op("pd.transpose", {{"perm", new_perm_attr}}); res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); } From cd8af21c2addf4655a022fd6f789eb777cee726f Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 23 Aug 2023 07:27:44 +0000 Subject: [PATCH 3/3] refine code --- paddle/fluid/ir/drr/api/drr_pattern_context.h | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/ir/drr/api/drr_pattern_context.h b/paddle/fluid/ir/drr/api/drr_pattern_context.h index fe9aec52528db..85be176fcf478 100644 --- a/paddle/fluid/ir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/ir/drr/api/drr_pattern_context.h @@ -46,19 +46,19 @@ class NormalAttribute { std::string attr_name_; }; +using AttrComputeFunc = std::function; + class ComputeAttribute { public: - explicit ComputeAttribute( - const std::function& attr_compute_func) + explicit ComputeAttribute(const AttrComputeFunc& attr_compute_func) : attr_compute_func_(attr_compute_func) {} - const std::function& attr_compute_func() - const { + const AttrComputeFunc& attr_compute_func() const { return attr_compute_func_; } private: - std::function attr_compute_func_; + AttrComputeFunc attr_compute_func_; }; using Attribute = std::variant; @@ -267,8 +267,7 @@ class ResultPattern { Attribute Attr(const std::string& attr_name) const { return NormalAttribute(attr_name); } - Attribute Attr( - std::function attr_compute_func) const { + Attribute Attr(const AttrComputeFunc& attr_compute_func) const { return ComputeAttribute(attr_compute_func); }