From 0f8b3f54667c67c978c17b155e2773eef3bb0c8e Mon Sep 17 00:00:00 2001 From: zhangbaizhou Date: Wed, 8 May 2024 09:20:11 +0000 Subject: [PATCH] support LiftToAnchorPattern for reduce tree pattern --- .../operator/transforms/cinn_to_pd_util.cc | 4 +-- .../graph_transformer/matcher.h | 29 ++++++++++++++++--- paddle/cinn/operator_fusion/pattern_graph.cc | 11 +++---- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc index 387e69ff42d633..96cbeb792a28ec 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_to_pd_util.cc @@ -37,8 +37,8 @@ pir::Attribute ArrayAttributeToIntArrayAttribute( } const auto& handler_reduce_max_op = - [&](::pir::Operation* op, - const ::pir::Builder& builder) -> ::pir::Operation* { + [](::pir::Operation* op, + const ::pir::Builder& builder) -> ::pir::Operation* { VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op"; auto cinn_op = op->dyn_cast(); auto attr = cinn_op.attributes(); diff --git a/paddle/cinn/operator_fusion/graph_transformer/matcher.h b/paddle/cinn/operator_fusion/graph_transformer/matcher.h index db43fa23ee1eeb..de62a28b18adb6 100644 --- a/paddle/cinn/operator_fusion/graph_transformer/matcher.h +++ b/paddle/cinn/operator_fusion/graph_transformer/matcher.h @@ -88,6 +88,21 @@ struct CanFuseReduceTreeAndTrivialMatcher { } }; +struct LiftToAnchorPatternMatcher { + template + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + bool not_reduce_tree = + !StmtPatternGraphMatcher>()(graph, node) && + !StmtPatternGraphMatcher>()(graph, + node); + bool reduce_tree_with_single_reduce = + StmtPatternGraphMatcher>()(graph, node) && + std::get>(node->stmt_pattern()).childs().size() == + 0; + return not_reduce_tree || reduce_tree_with_single_reduce; + } +}; + struct RecomputeNodeMatcher { template bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { @@ -103,8 +118,11 @@ struct HasUpstreamAnchorMatcher { const PatternNodePtr& upstream, const PatternNodePtr& downstream) { return graph.policy_manager() - .template GetPolicy() - ->HasUpstreamAnchor(upstream, downstream); + .template GetPolicy() + ->CanFuse(upstream, downstream) && + graph.policy_manager() + .template GetPolicy() + ->HasUpstreamAnchor(upstream, downstream); } }; @@ -114,8 +132,11 @@ struct HasDownstreamAnchorMatcher { const PatternNodePtr& upstream, const PatternNodePtr& downstream) { return graph.policy_manager() - .template GetPolicy() - ->HasDownstreamAnchor(upstream, downstream); + .template GetPolicy() + ->CanFuse(upstream, downstream) && + graph.policy_manager() + .template GetPolicy() + ->HasDownstreamAnchor(upstream, downstream); } }; diff --git a/paddle/cinn/operator_fusion/pattern_graph.cc b/paddle/cinn/operator_fusion/pattern_graph.cc index 7ad6041ec81b8b..4ddc52c5dd1493 100644 --- a/paddle/cinn/operator_fusion/pattern_graph.cc +++ b/paddle/cinn/operator_fusion/pattern_graph.cc @@ -178,17 +178,14 @@ void PatternGraph::ReduceTree_Trivial_Fusion() { template void PatternGraph::LiftToAnchorPattern() { - GraphTransformer< - NodePattern, - T, - And>>, - Not>>>, - LiftToAnchorPatternOperation>(this); + GraphTransformer(this); } template void PatternGraph::AnchorPatternFusion() { - // TODO(@wuzhanfei) GraphTransformer