Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
fix special arg
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Dec 21, 2022
1 parent f234bc9 commit c3e4843
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
37 changes: 30 additions & 7 deletions cinn/hlir/pass/common_subexpression_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ using common::GraphEdge;
using common::GraphNode;

using InputToNodeMap = std::unordered_map<std::string, std::unordered_set<Node*>>;
using shape_dict_t = absl::flat_hash_map<std::string, framework::shape_t>;

std::unordered_set<std::string> unordered_ops = {
"elementwise_add",
Expand All @@ -54,7 +55,7 @@ std::unordered_set<std::string> unordered_ops = {
"reduce_min",
};

bool IsSameSubexpression(Node* op1, Node* op2, const absl::flat_hash_map<std::string, framework::shape_t>& shape_dict) {
bool IsSameSubexpression(Node* op1, Node* op2, shape_dict_t& shape_dict) {
auto op1_in_edges = op1->inlinks_in_order(true);
auto op2_in_edges = op2->inlinks_in_order(true);
auto op1_inputs_size = op1_in_edges.size();
Expand Down Expand Up @@ -94,12 +95,34 @@ bool IsSameSubexpression(Node* op1, Node* op2, const absl::flat_hash_map<std::st
}
}
}
return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) {
if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) {
if (op1->op()->name == "reshape") {
auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as<NodeData>();
auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as<NodeData>();
return shape_dict[op1_sink_node->id()] == shape_dict[op2_sink_node->id()];
} else {
auto* op1_sink_node = op1->outlinks_in_order(true)[0]->sink()->safe_as<NodeData>();
auto* op2_sink_node = op2->outlinks_in_order(true)[0]->sink()->safe_as<NodeData>();
if (shape_dict[op1_sink_node->id()].size() != shape_dict[op2_sink_node->id()].size()) {
return false;
}
return true;
});
return std::all_of(op1->attrs.attr_store.begin(), op1->attrs.attr_store.end(), [&](auto attr) {
if (!op2->attrs.attr_store.count(attr.first) || op2->attrs.attr_store[attr.first] != attr.second) {
if (attr.first == "axis" || attr.first == "dim") {
auto op1_axis = absl::get<int>(attr.second);
auto op2_axis = absl::get<int>(op2->attrs.attr_store[attr.first]);
if (op1_axis < 0) {
op1_axis += shape_dict[op1_sink_node->id()].size();
}
if (op2_axis < 0) {
op2_axis += shape_dict[op1_sink_node->id()].size();
}
return op2_axis == op1_axis;
}
return false;
}
return true;
});
}
}

void RemoveNode(framework::Graph* graph, Node* node) {
Expand Down Expand Up @@ -136,8 +159,8 @@ void ReplaceNode(NodeData* src_new, NodeData* src_old, Node* trt) {

int CommonSubexpressionElimination(Graph* graph, std::vector<GraphNode*>& store_nodes, InputToNodeMap in2node) {
std::unordered_map<std::string, std::vector<Node*>> expr_map;
auto& shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, framework::shape_t>>("infershape");
int remove_num = 0;
auto shape_dict = graph->GetAttrs<absl::flat_hash_map<std::string, framework::shape_t>>("infershape");
int remove_num = 0;
for (auto& graph_node : store_nodes) {
auto node = graph_node->safe_as<Node>();
if (node) {
Expand Down
17 changes: 8 additions & 9 deletions cinn/hlir/pass/common_subexpression_elimination_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) {
auto add_1 = program.add(A, B);
auto add_2 = program.add(B, A);
auto add = program.add(add_1, add_2);
auto max_1 = program.reduce_max(add, {-1}, false);
auto max_2 = program.reduce_max(add, {1}, false);
auto t_1 = program.transpose(add, {0, 1});
auto t_2 = program.transpose(add, {0, 1});
auto max = program.reduce_max(add, {0}, true);

Target target = common::DefaultTarget();
program.SetInputs({A, B});
program.Validate();
LOG(INFO) << "Program:\n" << program;
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
LOG(INFO) << "graph:\n" << graph->Visualize();
// LOG(INFO) << "graph:\n" << graph->Visualize();

hlir::framework::ApplyPass(graph.get(), "InferShape");
hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass");
Expand All @@ -75,7 +75,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) {
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
ASSERT_EQ(prerun_instrs.size(), 0);
ASSERT_EQ(run_instrs.size(), 2);
ASSERT_EQ(run_instrs.size(), 4);

scope->Var<hlir::framework::Tensor>("A");
scope->Var<hlir::framework::Tensor>("B");
Expand All @@ -95,16 +95,16 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) {
Program program;
auto sub_1 = program.elementwise_sub(A, A);
auto sub_2 = program.elementwise_sub(A, A);
auto add_1 = program.add(B, sub_1);
auto add_2 = program.add(sub_2, B);
auto add_1 = program.reshape(B, {4, -1});
auto add_2 = program.reshape(B, {4, 8});
auto add = program.add(add_1, add_2);

Target target = common::DefaultTarget();
program.SetInputs({A, B});
program.Validate();
LOG(INFO) << "Program:\n" << program;
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
LOG(INFO) << "graph:\n" << graph->Visualize();
// LOG(INFO) << "graph:\n" << graph->Visualize();

hlir::framework::ApplyPass(graph.get(), "InferShape");
hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass");
Expand Down Expand Up @@ -146,7 +146,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
program.Validate();
LOG(INFO) << "Program:\n" << program;
auto graph = std::make_shared<hlir::framework::Graph>(program, target);
LOG(INFO) << "graph:\n" << graph->Visualize();
// LOG(INFO) << "graph:\n" << graph->Visualize();

hlir::framework::ApplyPass(graph.get(), "InferShape");
hlir::framework::ApplyPass(graph.get(), "CommonSubexpressionEliminationPass");
Expand All @@ -167,7 +167,6 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
SetRandData<float>(B1, target);

runtime_program->Execute();
LOG(INFO) << "graph:\n" << graph->Visualize();
}

} // namespace frontend
Expand Down

0 comments on commit c3e4843

Please sign in to comment.