Skip to content

Commit

Permalink
[pir]Adding Set and Get attr method for pir passes (PaddlePaddle#60253)
Browse files Browse the repository at this point in the history
* [pir]Adding Set and Get method for pir passes

* fix codestyle

* Update constant_folding_pass.cc
  • Loading branch information
zhangyuqin1998 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent 53368f2 commit aec1a81
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 33 deletions.
15 changes: 12 additions & 3 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -806,9 +806,18 @@ bool AnalysisPredictor::PrepareExecutor() {

//----------------------------------------------------------------------------------------------//
// Basic pass required by the framework
gpu_pm.AddPass(
::pir::CreateParamsSyncAmongDevicesPass(place_, sub_scope_));
gpu_pm.AddPass(::pir::CreateConstantFoldingPass(place_, sub_scope_));
auto params_sync_among_devices_pass =
::pir::CreateParamsSyncAmongDevicesPass();
params_sync_among_devices_pass->SetNotOwned(pir::kPlaceAttr, &place_);
params_sync_among_devices_pass->SetNotOwned(pir::kParamScopeAttr,
sub_scope_);

auto constant_folding_pass = ::pir::CreateConstantFoldingPass();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place_);
constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, sub_scope_);

gpu_pm.AddPass(std::move(params_sync_among_devices_pass));
gpu_pm.AddPass(std::move(constant_folding_pass));
gpu_pm.AddPass(::pir::CreateDeadCodeEliminationPass());
gpu_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass());
//----------------------------------------------------------------------------------------------//
Expand Down
30 changes: 21 additions & 9 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,28 @@ class ConstantFoldingPattern : public pir::RewritePattern {

class ConstantFoldingPass : public pir::Pass {
public:
explicit ConstantFoldingPass(const phi::Place& place,
paddle::framework::Scope* scope)
: pir::Pass("constant_folding_pass", 1), place_(place), scope_(scope) {
PADDLE_ENFORCE_NOT_NULL(
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
}
ConstantFoldingPass()
: pir::Pass("constant_folding_pass", 1),
place_(phi::CPUPlace{}),
scope_(nullptr) {}

private:
bool Initialize(pir::IrContext* context) override {
IR_ENFORCE(Has(pir::kPlaceAttr),
"Pass initialize failed."
"When using ConstantFoldingPass, place attribute is required!"
"Use Set method to set the place attribute.");
IR_ENFORCE(Has(pir::kParamScopeAttr),
"Pass initialize failed."
"When using ConstantFoldingPass, scope attribute is required!"
"Use Set method to set the scope attribute.");

place_ = Get<phi::Place>(pir::kPlaceAttr);
scope_ = &Get<paddle::framework::Scope>(pir::kParamScopeAttr);

PADDLE_ENFORCE_NOT_NULL(
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));

pir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
Expand Down Expand Up @@ -354,9 +367,8 @@ class ConstantFoldingPass : public pir::Pass {

namespace pir {

std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope) {
return std::make_unique<ConstantFoldingPass>(place, scope);
std::unique_ptr<Pass> CreateConstantFoldingPass() {
return std::make_unique<ConstantFoldingPass>();
}

} // namespace pir
3 changes: 1 addition & 2 deletions paddle/fluid/pir/transforms/constant_folding_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateConstantFoldingPass(
const phi::Place& place, paddle::framework::Scope* scope);
IR_API std::unique_ptr<Pass> CreateConstantFoldingPass();

} // namespace pir
29 changes: 21 additions & 8 deletions paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,30 @@ namespace {

class ParamsSyncAmongDevicesPass : public pir::Pass {
public:
ParamsSyncAmongDevicesPass(const phi::Place& place,
paddle::framework::Scope* scope)
: pir::Pass("params_sync_among_devices_pass", 0),
place_(place),
scope_(scope) {
ParamsSyncAmongDevicesPass()
: pir::Pass("params_sync_among_devices_pass", 0) {}

bool Initialize(pir::IrContext* context) override {
IR_ENFORCE(Has(pir::kPlaceAttr),
"Pass initialize failed."
"When using ConstantFoldingPass, place attribute is required!"
"Use Set method to set the place attribute.");
IR_ENFORCE(Has(pir::kParamScopeAttr),
"Pass initialize failed."
"When using ConstantFoldingPass, scope attribute is required!"
"Use Set method to set the scope attribute.");

place_ = Get<phi::Place>(pir::kPlaceAttr);
scope_ = &Get<paddle::framework::Scope>(pir::kParamScopeAttr);

PADDLE_ENFORCE_NOT_NULL(
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
PADDLE_ENFORCE(
paddle::platform::is_gpu_place(place_) ||
paddle::platform::is_cpu_place(place_),
phi::errors::PreconditionNotMet(
"params_sync_among_devices_pass should run on cpu or gpu."));
return true;
}

void Run(pir::Operation* op) override {
Expand Down Expand Up @@ -94,9 +108,8 @@ class ParamsSyncAmongDevicesPass : public pir::Pass {

namespace pir {

std::unique_ptr<pir::Pass> CreateParamsSyncAmongDevicesPass(
const phi::Place& place, paddle::framework::Scope* scope) {
return std::make_unique<ParamsSyncAmongDevicesPass>(place, scope);
std::unique_ptr<pir::Pass> CreateParamsSyncAmongDevicesPass() {
return std::make_unique<ParamsSyncAmongDevicesPass>();
}

} // namespace pir
3 changes: 1 addition & 2 deletions paddle/fluid/pir/transforms/params_sync_among_devices_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateParamsSyncAmongDevicesPass(
const phi::Place& place, paddle::framework::Scope* scope);
IR_API std::unique_ptr<Pass> CreateParamsSyncAmongDevicesPass();

} // namespace pir
77 changes: 77 additions & 0 deletions paddle/pir/pass/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

#pragma once

#include <any>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/common/enforce.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/pass/analysis_manager.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
Expand Down Expand Up @@ -68,6 +71,9 @@ struct PassInfo {

} // namespace detail

static const char kParamScopeAttr[] = "__param_scope__";
static const char kPlaceAttr[] = "__place__";

/// We can access pass only from PassManager.
class IR_API Pass {
public:
Expand All @@ -82,6 +88,74 @@ class IR_API Pass {

const detail::PassInfo& pass_info() const { return pass_info_; }

// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType& Get(const std::string& attr_name) const {
IR_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"Attribute %s not registered for pass.",
attr_name);
try {
return *std::any_cast<AttrType*>(attrs_.at(attr_name));
} catch (std::bad_any_cast&) {
auto TypeToString = [](const std::type_info& info) -> std::string {
if (std::type_index(info) == std::type_index(typeid(bool*))) {
return "bool";
} else if (std::type_index(info) == std::type_index(typeid(int*))) {
return "int";
} else if (std::type_index(info) ==
std::type_index(typeid(const int*))) {
return "const int";
} else if (std::type_index(info) ==
std::type_index(typeid(std::string*))) {
return "std::string";
}
return info.name();
};

IR_THROW("Invalid type for attritube %s, expected: %s, actual: %s.",
attr_name,
TypeToString(typeid(AttrType*)),
TypeToString(attrs_.at(attr_name).type()));
}
}

bool Has(const std::string& attr_name) const {
return attrs_.count(attr_name) > 0;
}

void Erase(const std::string& attr_name) {
if (!Has(attr_name)) {
return;
}
if (attr_dels_.find(attr_name) != attr_dels_.end()) {
attr_dels_[attr_name]();
attr_dels_.erase(attr_name);
}
attrs_.erase(attr_name);
}

// Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType>
void Set(const std::string& attr_name, AttrType* attr) {
VLOG(3) << "Setting the attribute " << attr_name << " for the pass "
<< name();
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(8) << "deleting " << attr_name;
delete attr;
};
}

// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string& attr_name, AttrType* attr) {
IR_ENFORCE(0 == attrs_.count(attr_name),
"Attribute %s already set in the pass.",
attr_name);
attrs_[attr_name] = attr;
}

protected:
virtual void Run(Operation* op) = 0;

Expand All @@ -108,6 +182,9 @@ class IR_API Pass {

friend class PassManager;
friend class detail::PassAdaptor;

std::unordered_map<std::string, std::any> attrs_;
std::unordered_map<std::string, std::function<void(void)>> attr_dels_;
};

class PatternRewritePass : public Pass {
Expand Down
9 changes: 7 additions & 2 deletions test/cpp/pir/pattern_rewrite/drr_attention_fuse_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,13 @@ TEST(DrrTest, AttentionFuse) {

pir::PassManager pm(ctx);
pm.AddPass(pir::CreateAttentionFusePass());
paddle::framework::Scope scope;
pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
std::unique_ptr<pir::Pass> constant_folding_pass =
pir::CreateConstantFoldingPass();
phi::Place place = phi::CPUPlace();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place);
constant_folding_pass->Set(pir::kParamScopeAttr,
new paddle::framework::Scope());
pm.AddPass(std::move(constant_folding_pass));
pm.EnableIRPrinting();

CHECK_EQ(pm.Run(&program), true);
Expand Down
34 changes: 27 additions & 7 deletions test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,13 @@ TEST(pattern_rewrite, Patterns) {
pm.AddPass(pir::CreateConv2dBnFusePass());
pm.AddPass(pir::CreateConv2dAddActFusePass());
pm.AddPass(pir::CreateConv2dAddFusePass());
paddle::framework::Scope scope;
pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
std::unique_ptr<pir::Pass> constant_folding_pass =
pir::CreateConstantFoldingPass();
phi::Place place = phi::CPUPlace();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place);
constant_folding_pass->Set(pir::kParamScopeAttr,
new paddle::framework::Scope());
pm.AddPass(std::move(constant_folding_pass));
pm.AddPass(pir::CreateDeadCodeEliminationPass());
// pm.EnablePassTiming();
pm.EnableIRPrinting();
Expand Down Expand Up @@ -475,7 +480,12 @@ TEST(constant_folding, ConstantFolding) {
BuildConstantFoldingProgram(&program, ctx, &scope);

pir::PassManager pm(ctx);
pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
std::unique_ptr<pir::Pass> constant_folding_pass =
pir::CreateConstantFoldingPass();
phi::Place place = phi::CPUPlace();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place);
constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, &scope);
pm.AddPass(std::move(constant_folding_pass));
pm.AddPass(pir::CreateDeadCodeEliminationPass());
pm.EnableIRPrinting();

Expand Down Expand Up @@ -537,8 +547,13 @@ TEST(constant_folding, ConstantFolding_Combine) {
BuildConcatProgram(&program, ctx);

pir::PassManager pm(ctx);
paddle::framework::Scope scope;
pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
std::unique_ptr<pir::Pass> constant_folding_pass =
pir::CreateConstantFoldingPass();
phi::Place place = phi::CPUPlace();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place);
constant_folding_pass->Set(pir::kParamScopeAttr,
new paddle::framework::Scope());
pm.AddPass(std::move(constant_folding_pass));
pm.AddPass(pir::CreateDeadCodeEliminationPass());
pm.EnableIRPrinting();

Expand Down Expand Up @@ -573,8 +588,13 @@ TEST(constant_folding, ConstantFolding_MultiOutput) {
BuildMultiOutputProgram(&program, ctx);

pir::PassManager pm(ctx);
paddle::framework::Scope scope;
pm.AddPass(pir::CreateConstantFoldingPass(phi::CPUPlace{}, &scope));
std::unique_ptr<pir::Pass> constant_folding_pass =
pir::CreateConstantFoldingPass();
phi::Place place = phi::CPUPlace();
constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place);
constant_folding_pass->Set(pir::kParamScopeAttr,
new paddle::framework::Scope());
pm.AddPass(std::move(constant_folding_pass));
pm.AddPass(pir::CreateDeadCodeEliminationPass());
pm.EnableIRPrinting();

Expand Down

0 comments on commit aec1a81

Please sign in to comment.