Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#17 from Fridge003/multi-down
Browse files Browse the repository at this point in the history
support LiftToAnchorPattern for reduce tree pattern
  • Loading branch information
feifei-111 authored May 8, 2024
2 parents e10f83a + 0f8b3f5 commit f84c7f4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ pir::Attribute ArrayAttributeToIntArrayAttribute(
}

const auto& handler_reduce_max_op =
[&](::pir::Operation* op,
const ::pir::Builder& builder) -> ::pir::Operation* {
[](::pir::Operation* op,
const ::pir::Builder& builder) -> ::pir::Operation* {
VLOG(6) << "transform " << op->name() << " from cinn_op to pd_op";
auto cinn_op = op->dyn_cast<cinn::dialect::ReduceMaxOp>();
auto attr = cinn_op.attributes();
Expand Down
29 changes: 25 additions & 4 deletions paddle/cinn/operator_fusion/graph_transformer/matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,21 @@ struct CanFuseReduceTreeAndTrivialMatcher {
}
};

struct LiftToAnchorPatternMatcher {
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
bool not_reduce_tree =
!StmtPatternGraphMatcher<ReduceTreePattern<T>>()(graph, node) &&
!StmtPatternGraphMatcher<ReduceTreePlusTrivialPattern<T>>()(graph,
node);
bool reduce_tree_with_single_reduce =
StmtPatternGraphMatcher<ReduceTreePattern<T>>()(graph, node) &&
std::get<ReduceTreePattern<T>>(node->stmt_pattern()).childs().size() ==
0;
return not_reduce_tree || reduce_tree_with_single_reduce;
}
};

struct RecomputeNodeMatcher {
template <typename T>
bool operator()(const PatternGraph<T>& graph, const PatternNodePtr<T>& node) {
Expand All @@ -103,8 +118,11 @@ struct HasUpstreamAnchorMatcher {
const PatternNodePtr<T>& upstream,
const PatternNodePtr<T>& downstream) {
return graph.policy_manager()
.template GetPolicy<AnchorSearchPolicy>()
->HasUpstreamAnchor(upstream, downstream);
.template GetPolicy<GeneralTopoPolicy>()
->CanFuse(upstream, downstream) &&
graph.policy_manager()
.template GetPolicy<AnchorSearchPolicy>()
->HasUpstreamAnchor(upstream, downstream);
}
};

Expand All @@ -114,8 +132,11 @@ struct HasDownstreamAnchorMatcher {
const PatternNodePtr<T>& upstream,
const PatternNodePtr<T>& downstream) {
return graph.policy_manager()
.template GetPolicy<AnchorSearchPolicy>()
->HasDownstreamAnchor(upstream, downstream);
.template GetPolicy<GeneralTopoPolicy>()
->CanFuse(upstream, downstream) &&
graph.policy_manager()
.template GetPolicy<AnchorSearchPolicy>()
->HasDownstreamAnchor(upstream, downstream);
}
};

Expand Down
11 changes: 4 additions & 7 deletions paddle/cinn/operator_fusion/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,14 @@ void PatternGraph<T>::ReduceTree_Trivial_Fusion() {

template <typename T>
void PatternGraph<T>::LiftToAnchorPattern() {
GraphTransformer<
NodePattern,
T,
And<Not<StmtPatternGraphMatcher<ReduceTreePattern<T>>>,
Not<StmtPatternGraphMatcher<ReduceTreePlusTrivialPattern<T>>>>,
LiftToAnchorPatternOperation>(this);
GraphTransformer<NodePattern,
T,
LiftToAnchorPatternMatcher,
LiftToAnchorPatternOperation>(this);
}

template <typename T>
void PatternGraph<T>::AnchorPatternFusion() {
// TODO(@wuzhanfei)
GraphTransformer<ReverseTopoNodePairPattern,
T,
HasUpstreamAnchorMatcher,
Expand Down

0 comments on commit f84c7f4

Please sign in to comment.