Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#20 from zyfncg/drr_attr
Browse files Browse the repository at this point in the history
[DRR] Support compute attribute in drr pattern
  • Loading branch information
yuanlehome authored Aug 23, 2023
2 parents 16f066a + cd8af21 commit 524f15f
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 10 deletions.
34 changes: 30 additions & 4 deletions paddle/fluid/ir/drr/api/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

#pragma once

#include <any>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <variant>

#include "paddle/fluid/ir/drr/api/match_context.h"

Expand All @@ -34,16 +36,33 @@ class PatternGraph;
class SourcePatternGraph;
class ResultPatternGraph;

class Attribute {
class NormalAttribute {
public:
explicit Attribute(const std::string& name) : attr_name_(name) {}
explicit NormalAttribute(const std::string& name) : attr_name_(name) {}

const std::string& name() const { return attr_name_; }

private:
std::string attr_name_;
};

using AttrComputeFunc = std::function<std::any(const MatchContext&)>;

class ComputeAttribute {
public:
explicit ComputeAttribute(const AttrComputeFunc& attr_compute_func)
: attr_compute_func_(attr_compute_func) {}

const AttrComputeFunc& attr_compute_func() const {
return attr_compute_func_;
}

private:
AttrComputeFunc attr_compute_func_;
};

using Attribute = std::variant<NormalAttribute, ComputeAttribute>;

class TensorShape {
public:
explicit TensorShape(const std::string& tensor_name)
Expand Down Expand Up @@ -245,7 +264,12 @@ class ResultPattern {
return ctx_->ResultTensorPattern(name);
}

Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); }
Attribute Attr(const std::string& attr_name) const {
return NormalAttribute(attr_name);
}
Attribute Attr(const AttrComputeFunc& attr_compute_func) const {
return ComputeAttribute(attr_compute_func);
}

private:
friend class SourcePattern;
Expand All @@ -269,7 +293,9 @@ class SourcePattern {
return ctx_->SourceTensorPattern(name);
}

Attribute Attr(const std::string& attr_name) { return Attribute(attr_name); }
Attribute Attr(const std::string& attr_name) const {
return NormalAttribute(attr_name);
}

void RequireEqual(const TensorShape& first, const TensorShape& second) {
ctx_->RequireEqual(first, second);
Expand Down
20 changes: 18 additions & 2 deletions paddle/fluid/ir/drr/ir_operation_creator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@ ir::AttributeMap CreateAttributeMap(const OpCall& op_call,
const MatchContextImpl& src_match_ctx) {
ir::AttributeMap attr_map;
for (const auto& kv : op_call.attributes()) {
attr_map[kv.first] = src_match_ctx.GetIrAttr(kv.second.name());
std::visit(
[&](auto&& arg) {
if constexpr (std::is_same_v<std::decay_t<decltype(arg)>,
NormalAttribute>) {
attr_map[kv.first] = src_match_ctx.GetIrAttr(arg.name());
}
},
kv.second);
}
return attr_map;
}
Expand All @@ -48,7 +55,16 @@ template <typename T>
T GetAttr(const std::string& attr_name,
const OpCall& op_call,
const MatchContextImpl& src_match_ctx) {
return src_match_ctx.Attr<T>(op_call.attributes().at(attr_name).name());
const auto& attr = op_call.attributes().at(attr_name);
if (std::holds_alternative<NormalAttribute>(attr)) {
return src_match_ctx.Attr<T>(std::get<NormalAttribute>(attr).name());
} else if (std::holds_alternative<ComputeAttribute>(attr)) {
MatchContext ctx(std::make_shared<MatchContextImpl>(src_match_ctx));
return std::any_cast<T>(
std::get<ComputeAttribute>(attr).attr_compute_func()(ctx));
} else {
IR_THROW("Unknown attrbute type for : %s.", attr_name);
}
}

Operation* CreateOperation(const OpCall& op_call,
Expand Down
11 changes: 9 additions & 2 deletions paddle/fluid/ir/drr/match_context_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,22 @@ class MatchContextImpl final {
operation_map_.emplace(op_call, op);
const auto& attrs = op_call->attributes();
for (const auto& kv : attrs) {
BindIrAttr(kv.second.name(), op->get()->attribute(kv.first));
std::visit(
[&](auto&& arg) {
if constexpr (std::is_same_v<std::decay_t<decltype(arg)>,
NormalAttribute>) {
BindIrAttr(arg.name(), op->get()->attribute(kv.first));
}
},
kv.second);
}
}

private:
void BindIrAttr(const std::string& attr_name, ir::Attribute attr) {
attr_map_.emplace(attr_name, attr);
}

private:
std::unordered_map<std::string, std::shared_ptr<IrValue>> tensor_map_;
std::unordered_map<const OpCall*, std::shared_ptr<IrOperation>>
operation_map_;
Expand Down
13 changes: 11 additions & 2 deletions test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,18 @@ class RemoveRedundentTransposePattern

// Result patterns: 要替换的子图
ir::drr::ResultPattern res = pat.ResultPattern();
// todo 先简单用perm2替换
const auto &new_perm_attr =
res.Attr([](const ir::drr::MatchContext &match_ctx) -> std::any {
const auto &perm1 = match_ctx.Attr<std::vector<int>>("perm_1");
const auto &perm2 = match_ctx.Attr<std::vector<int>>("perm_2");
std::vector<int> new_perm;
for (int v : perm2) {
new_perm.emplace_back(perm1[v]);
}
return new_perm;
});
const auto &tranpose_continuous =
res.Op("pd.transpose", {{"perm", pat.Attr("perm_2")}});
res.Op("pd.transpose", {{"perm", new_perm_attr}});

res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose"));
}
Expand Down

0 comments on commit 524f15f

Please sign in to comment.