Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#1 from Fridge003/cinn_tmp
Browse files Browse the repository at this point in the history
spliting
  • Loading branch information
feifei-111 authored Mar 20, 2024
2 parents 890c560 + 05aeb8f commit dfee88f
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 112 deletions.
16 changes: 7 additions & 9 deletions paddle/cinn/frontend/cluster_ops/cluster_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/frontend/group_pattern/cluster_policy.h"
#include "paddle/cinn/frontend/cluster_ops/cluster_policy.h"

namespace cinn::frontend {

std::shared_ptr<ClusteringPolicy> MakeLoopAlignableClusteringPolicy(
const pir::ShapeConstraintIRAnalysis* shape_analysis) {
return std::make_shared<LoopAlignableClusteringPolicy>(shape_analysis);
}

class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
public:
explicit LoopAlignableClusteringPolicy(
Expand Down Expand Up @@ -233,8 +226,13 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
return GetRank(reduce_op->result(result_idx)) == shardable_axes.size();
}
}

const pir::ShapeConstraintIRAnalysis* shape_analysis_;
};

std::shared_ptr<ClusteringPolicy> MakeLoopAlignableClusteringPolicy(
const pir::ShapeConstraintIRAnalysis* shape_analysis) {
return std::make_shared<LoopAlignableClusteringPolicy>(shape_analysis);
}


} // namespace cinn::frontend
8 changes: 2 additions & 6 deletions paddle/cinn/frontend/cluster_ops/cluster_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

#pragma once

#include "paddle/cinn/frontend/group_pattern.h"
#include "paddle/cinn/frontend/cluster_ops/group_pattern.h"
#include "paddle/cinn/frontend/cluster_ops/shardable_axes_provider.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"

namespace cinn::frontend {
Expand Down Expand Up @@ -43,9 +44,4 @@ class ClusteringPolicy {

std::shared_ptr<ClusteringPolicy> MakeLoopAlignableClusteringPolicy(
const pir::ShapeConstraintIRAnalysis* shape_analysis);

GroupPattern GenerateGroupPatternFromOpList(
const std::vector<const pir::Operation*>& ops,
const std::shared_ptr<ShardableAxesProvider>& shardable_axes_provider);

} // namespace cinn::frontend
111 changes: 15 additions & 96 deletions paddle/cinn/frontend/cluster_ops/clustering_engine.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/frontend/cluster_ops/cluster_engine.h"

class ClusteringEngine {
public:
Expand Down Expand Up @@ -476,102 +491,6 @@ class ClusteringEngine {
"clustering_policy_->CanActAsSink() returns false all the time.";
}

using ShardableAxes4ValueT =
std::function<std::optional<const ShardableAxes*>(pir::Value)>;
ShardableAxes4ValueT MakeInferedShardableAxes4Value(
const std::vector<const StmtPattern*>& stmt_ptrs) {
const OpSetPtr ops = [&] {
auto ops = std::make_shared<OpSet>();
for (const auto* stmt_ptr : stmt_ptrs) {
VisitStmtOp(*stmt_ptr, [&](const auto* op) { ops->insert(op); });
}
return ops;
}();
auto value2shardable_axes = shardable_axes_inferer_.InferShardableAxes(ops);
return [map = std::move(value2shardable_axes)](
pir::Value value) -> std::optional<const ShardableAxes*> {
const auto& iter = map.find(value);
if (iter == map.end()) return std::nullopt;
return &iter->second;
};
}

common::TopoWalker<const StmtPattern*> MakeTopoWalker(
const OpTopo& op_topo, const std::vector<StmtPattern>& stmt_patterns) {
using StmtPtrs = std::vector<const StmtPattern*>;
using Op2OwnerStmtPtrs =
std::unordered_map<const pir::Operation*, StmtPtrs>;
auto op2owner_stmt_ptr = std::make_shared<Op2OwnerStmtPtrs>();
for (const auto& stmt : stmt_patterns) {
VisitStmtOp(stmt, [&](const pir::Operation* op) {
(*op2owner_stmt_ptr)[op].push_back(&stmt);
});
}
using NodeVisitor = std::function<void(const StmtPattern*)>;
auto VisitInput = [=](const StmtPattern* stmt, const NodeVisitor& DoEach) {
VisitStmtOp(*stmt, [&](const auto* op) {
op_topo.VisitInputOp(op, [&](const auto* input_op) {
const auto& owners_iter = op2owner_stmt_ptr->find(input_op);
if (owners_iter == op2owner_stmt_ptr->end()) return;
if (owners_iter->second.size() != 1) return;
const auto* owner_stmt = *owners_iter->second.begin();
if (owner_stmt == stmt) return;
DoEach(owner_stmt);
});
});
};
auto VisitOutput = [=](const StmtPattern* stmt, const NodeVisitor& DoEach) {
const auto* sink = GetStmtSoleSinkOp(*stmt);
op_topo.VisitOutputOp(sink, [&](const pir::Operation* op) {
const auto& owners_iter = op2owner_stmt_ptr->find(op);
if (owners_iter == op2owner_stmt_ptr->end()) return;
for (const StmtPattern* stmt : owners_iter->second) {
DoEach(stmt);
}
});
};
const auto& TryPushBack = [](const auto* stmt, auto* stmts) {
if (std::find(stmts->begin(), stmts->end(), stmt) == stmts->end()) {
stmts->push_back(stmt);
}
};
using EdgeCache =
std::unordered_map<const StmtPattern*, std::vector<const StmtPattern*>>;
auto stmt2inputs = std::make_shared<EdgeCache>();
auto stmt2outputs = std::make_shared<EdgeCache>();
for (const auto& stmt : stmt_patterns) {
(void)(*stmt2inputs)[&stmt];
VisitInput(&stmt, [&](const auto* input) {
TryPushBack(input, &(*stmt2inputs)[&stmt]);
});
(void)(*stmt2outputs)[&stmt];
VisitOutput(&stmt, [&](const auto* output) {
TryPushBack(output, &(*stmt2outputs)[&stmt]);
});
}

auto VisitCachedInput = [stmt2inputs](const auto* stmt,
const NodeVisitor& DoEach) {
const auto& map = (*stmt2inputs);
const auto& iter = map.find(stmt);
if (iter == map.end()) return;
for (const auto* input : iter->second) {
DoEach(input);
}
};
auto VisitCachedOutput = [stmt2outputs](const auto* stmt,
const NodeVisitor& DoEach) {
const auto& map = (*stmt2outputs);
const auto& iter = map.find(stmt);
if (iter == map.end()) return;
for (const auto* output : iter->second) {
DoEach(output);
}
};
return common::TopoWalker<const StmtPattern*>(VisitCachedInput,
VisitCachedOutput);
}

const std::vector<const pir::Operation*> ops_;
const std::shared_ptr<ClusteringPolicy> clustering_policy_;
ShardableAxesInferer shardable_axes_inferer_;
Expand Down
120 changes: 120 additions & 0 deletions paddle/cinn/frontend/cluster_ops/clustering_engine.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/cinn/frontend/cluster_ops/cluster_policy.h"
#include "paddle/cinn/frontend/cluster_ops/common_utils.h"
#include "paddle/cinn/frontend/cluster_ops/group_pattern.h"
#include "paddle/cinn/frontend/cluster_ops/pattern_utils.h"
#include "paddle/cinn/frontend/cluster_ops/shardable_axes_provider.h"


namespace cinn::frontend {

class ClusteringEngine {
public:
ClusteringEngine(const std::vector<const pir::Operation*>& ops,
const ShardableAxesInferer& shardable_axes_inferer,
const std::shared_ptr<ClusteringPolicy>& clustering_policy);

ClusteringResult ClusterOps();

private:
void SortStmtsList(
std::vector<std::vector<const StmtPattern*>>* stmt_ptrs,
const std::function<size_t(const pir::Operation*)>& OrderValue4Op);

template <typename DoEachComponentT>
void VisitConnectedComponent(
const common::BfsWalker<const StmtPattern*>& walker,
const std::vector<StmtPattern>& stmt_patterns,
const DoEachComponentT& DoEachComponent);

common::BfsWalker<const StmtPattern*> MakeAcyclicSameClusterBfsWalker(
const std::vector<StmtPattern>& stmt_patterns);

using IsAcyclicConnectedT =
std::function<bool(const StmtPattern* src, const StmtPattern* dst)>;
using ClusterRoot4StmtT =
std::function<const StmtPattern*(const StmtPattern*)>;

IsAcyclicConnectedT MakePredicatorIsAcyclicConnected(
const common::TopoWalker<const StmtPattern*>& walker,
const std::vector<StmtPattern>& stmt_patterns,
const ClusterRoot4StmtT& ClusterRoot4Stmt);

struct TopoClosure {
std::list<const StmtPattern*> sources;
std::list<const StmtPattern*> sinks;
std::unordered_set<const StmtPattern*> stmts;
};

using IsReachableT =
std::function<bool(const StmtPattern* src, const StmtPattern* dst)>;

using TopoClosure4RootStmtT =
std::function<std::optional<const TopoClosure*>(const StmtPattern*)>;

using AllTopClosureUpstreams4StmtT =
std::function<const std::set<const StmtPattern*>*(const StmtPattern*)>;

AllTopClosureUpstreams4StmtT MakeAllTopClosureUpstreams4Stmt(
const common::TopoWalker<const StmtPattern*>& entire_topo_walker,
const std::vector<StmtPattern>& stmt_patterns,
const ClusterRoot4StmtT& ClusterRoot4Stmt);

TopoClosure4RootStmtT MakeTopoClosure4RootStmt(
const common::TopoWalker<const StmtPattern*>& entire_topo_walker,
const std::vector<StmtPattern>& stmt_patterns,
const ClusterRoot4StmtT& ClusterRoot4Stmt);

std::unordered_set<const StmtPattern*> CollectSubGraphAllStmts(
const common::TopoWalker<const StmtPattern*>& entire_topo_walker,
const IsReachableT& IsReachable,
const std::list<const StmtPattern*> sources,
const std::list<const StmtPattern*> sinks);

template <typename DoEachStmtAndTopoClosureUpstreamsT>
void VisitStmtTopoClosureUpstreams(
const common::TopoWalker<const StmtPattern*>& entire_topo_walker,
const TopoClosure& topo_closure,
const DoEachStmtAndTopoClosureUpstreamsT&
DoEachStmtAndTopoClosureUpstreams);

IsReachableT MakeIsReachable(
const common::TopoWalker<const StmtPattern*>& walker,
const std::vector<StmtPattern>& stmt_patterns);

std::function<const StmtPattern*(const StmtPattern*)> MakeClusterRoot4Stmt(
const common::TopoWalker<const StmtPattern*>& topo_walker,
const std::vector<StmtPattern>& stmt_patterns);

template <typename DoEachComponentT>
void VisitClusterStmts(const common::TopoWalker<const StmtPattern*>& walker,
const std::vector<StmtPattern>& stmt_patterns,
const DoEachComponentT& DoEachComponent);

template <typename DoEachComponentT>
void VisitInferedClusterStmts(
const common::TopoWalker<const StmtPattern*>& entire_topo_walker,
const std::vector<const StmtPattern*>& stmt_ptrs,
const DoEachComponentT& DoEachComponent);

const std::vector<const pir::Operation*> ops_;
const std::shared_ptr<ClusteringPolicy> clustering_policy_;
ShardableAxesInferer shardable_axes_inferer_;
const OpTopo op_topo_;
};

} // namespace cinn::frontend
77 changes: 76 additions & 1 deletion paddle/cinn/frontend/cluster_ops/pattern_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,79 @@ void SortStmtPtrs(
return lhs_order < rhs_order;
};
std::sort(stmt_ptrs->begin(), stmt_ptrs->end(), Cmp);
}
}
common::TopoWalker<const StmtPattern*> MakeTopoWalker(
const OpTopo& op_topo, const std::vector<StmtPattern>& stmt_patterns) {
using StmtPtrs = std::vector<const StmtPattern*>;
using Op2OwnerStmtPtrs =
std::unordered_map<const pir::Operation*, StmtPtrs>;
auto op2owner_stmt_ptr = std::make_shared<Op2OwnerStmtPtrs>();
for (const auto& stmt : stmt_patterns) {
VisitStmtOp(stmt, [&](const pir::Operation* op) {
(*op2owner_stmt_ptr)[op].push_back(&stmt);
});
}
using NodeVisitor = std::function<void(const StmtPattern*)>;
auto VisitInput = [=](const StmtPattern* stmt, const NodeVisitor& DoEach) {
VisitStmtOp(*stmt, [&](const auto* op) {
op_topo.VisitInputOp(op, [&](const auto* input_op) {
const auto& owners_iter = op2owner_stmt_ptr->find(input_op);
if (owners_iter == op2owner_stmt_ptr->end()) return;
if (owners_iter->second.size() != 1) return;
const auto* owner_stmt = *owners_iter->second.begin();
if (owner_stmt == stmt) return;
DoEach(owner_stmt);
});
});
};
auto VisitOutput = [=](const StmtPattern* stmt, const NodeVisitor& DoEach) {
const auto* sink = GetStmtSoleSinkOp(*stmt);
op_topo.VisitOutputOp(sink, [&](const pir::Operation* op) {
const auto& owners_iter = op2owner_stmt_ptr->find(op);
if (owners_iter == op2owner_stmt_ptr->end()) return;
for (const StmtPattern* stmt : owners_iter->second) {
DoEach(stmt);
}
});
};
const auto& TryPushBack = [](const auto* stmt, auto* stmts) {
if (std::find(stmts->begin(), stmts->end(), stmt) == stmts->end()) {
stmts->push_back(stmt);
}
};
using EdgeCache =
std::unordered_map<const StmtPattern*, std::vector<const StmtPattern*>>;
auto stmt2inputs = std::make_shared<EdgeCache>();
auto stmt2outputs = std::make_shared<EdgeCache>();
for (const auto& stmt : stmt_patterns) {
(void)(*stmt2inputs)[&stmt];
VisitInput(&stmt, [&](const auto* input) {
TryPushBack(input, &(*stmt2inputs)[&stmt]);
});
(void)(*stmt2outputs)[&stmt];
VisitOutput(&stmt, [&](const auto* output) {
TryPushBack(output, &(*stmt2outputs)[&stmt]);
});
}

auto VisitCachedInput = [stmt2inputs](const auto* stmt,
const NodeVisitor& DoEach) {
const auto& map = (*stmt2inputs);
const auto& iter = map.find(stmt);
if (iter == map.end()) return;
for (const auto* input : iter->second) {
DoEach(input);
}
};
auto VisitCachedOutput = [stmt2outputs](const auto* stmt,
const NodeVisitor& DoEach) {
const auto& map = (*stmt2outputs);
const auto& iter = map.find(stmt);
if (iter == map.end()) return;
for (const auto* output : iter->second) {
DoEach(output);
}
};
return common::TopoWalker<const StmtPattern*>(VisitCachedInput,
VisitCachedOutput);
}
2 changes: 2 additions & 0 deletions paddle/cinn/frontend/cluster_ops/pattern_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
common::TopoWalker<const StmtPattern*> MakeTopoWalker(
const OpTopo& op_topo, const std::vector<StmtPattern>& stmt_patterns);

0 comments on commit dfee88f

Please sign in to comment.