Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#37 from zyfncg/drr_pass
Browse files Browse the repository at this point in the history
Fix some bugs
  • Loading branch information
yuanlehome authored Oct 10, 2023
2 parents d624336 + 4993a22 commit 8df68b8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 34 deletions.
57 changes: 27 additions & 30 deletions paddle/fluid/pir/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ class DrrRewritePattern : public pir::RewritePattern {
const MatchContextImpl& res_match_ctx,
pir::PatternRewriter& rewriter) const { // NOLINT
for (const auto& output_name : result_pattern_graph_->output_tensors()) {
if (source_pattern_graph_->output_tensors().count(output_name)) {
if (source_pattern_graph_->id2owend_tensor().count(output_name)) {
const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name);
const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name);
rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get());
Expand Down Expand Up @@ -736,35 +736,32 @@ class DrrRewritePattern : public pir::RewritePattern {
result_pattern_graph.output_tensors());
std::vector<const OpCall*> deleted_ops;
std::unordered_set<const OpCall*> deleted_ops_set;
std::for_each(
topo_order_ops.rbegin(),
topo_order_ops.rend(),
[&deleted_ops,
&deleted_ops_set,
&backward_visited_tensor_set,
&forward_deleted_ops](const OpCall* op_call) {
bool all_comsumer_deleted = true;
for (const auto* output : op_call->outputs()) {
if (backward_visited_tensor_set.count(output->name())) {
for (const auto* consumer : output->consumers()) {
if (!deleted_ops_set.count(consumer)) {
all_comsumer_deleted = false;
}
}
} else if (output->consumers().empty()) {
continue;
} else {
all_comsumer_deleted = false;
}
}
if (all_comsumer_deleted && forward_deleted_ops.count(op_call)) {
deleted_ops_set.insert(op_call);
deleted_ops.push_back(op_call);
for (const auto* input : op_call->inputs()) {
backward_visited_tensor_set.insert(input->name());
}
}
});
std::for_each(topo_order_ops.rbegin(),
topo_order_ops.rend(),
[&deleted_ops,
&deleted_ops_set,
&backward_visited_tensor_set,
&forward_deleted_ops](const OpCall* op_call) {
bool all_comsumer_deleted = true;
bool from_backward_visited_tensor = false;
for (const auto* output : op_call->outputs()) {
if (backward_visited_tensor_set.count(output->name())) {
from_backward_visited_tensor = true;
} else if (output->consumers().empty()) {
continue;
} else {
all_comsumer_deleted = false;
}
}
if (all_comsumer_deleted && from_backward_visited_tensor &&
forward_deleted_ops.count(op_call)) {
deleted_ops_set.insert(op_call);
deleted_ops.push_back(op_call);
for (const auto* input : op_call->inputs()) {
backward_visited_tensor_set.insert(input->name());
}
}
});

// Delete Operation with topo order from output tensors.
for (const auto* op_call : deleted_ops) {
Expand Down
10 changes: 6 additions & 4 deletions paddle/fluid/pir/drr/ir_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class IrShape {
int64_t at(int idx) const { return dims_.at(idx); }

private:
const phi::DDim& dims_;
const phi::DDim dims_;
};

class IrDtype {
Expand All @@ -45,19 +45,21 @@ class IrDtype {
bool operator==(IrDtype other) const { return dtype_ == other.dtype_; }

private:
pir::Type dtype_;
const pir::Type dtype_;
};

class IrValue : public TensorInterface {
public:
explicit IrValue(const pir::Value& value)
: value_(value),
shape_((value && value.type())
shape_((value && value.type() &&
value.type().dyn_cast<paddle::dialect::DenseTensorType>())
? value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims()
: phi::DDim{}),
dtype_((value && value.type())
dtype_((value && value.type() &&
value.type().dyn_cast<paddle::dialect::DenseTensorType>())
? value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()
Expand Down

0 comments on commit 8df68b8

Please sign in to comment.