Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add complex test case
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Dec 31, 2022
1 parent c3e4843 commit c2342d6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
2 changes: 1 addition & 1 deletion cinn/hlir/pass/common_subexpression_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ int CommonSubexpressionElimination(Graph* graph, std::vector<GraphNode*>& store_
}
return remove_num;
}
//

void CommonSubexpressionEliminationPass(Graph* graph) {
VLOG(3) << "CommonSubexpressionEliminationPass...!";
std::unordered_map<std::string, std::vector<Node*>> expr_map;
Expand Down
47 changes: 30 additions & 17 deletions cinn/hlir/pass/common_subexpression_elimination_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) {
auto add_1 = program.add(A, B);
auto add_2 = program.add(B, A);
auto add = program.add(add_1, add_2);
auto t_1 = program.transpose(add, {0, 1});
auto t_2 = program.transpose(add, {0, 1});
auto t_1 = program.transpose(add, {1, 0});
auto t_2 = program.transpose(add, {1, 0});
auto max = program.reduce_max(add, {0}, true);

Target target = common::DefaultTarget();
Expand Down Expand Up @@ -93,11 +93,11 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) {
Placeholder B(Float(32), {32, 1}, "B", true);

Program program;
auto sub_1 = program.elementwise_sub(A, A);
auto sub_2 = program.elementwise_sub(A, A);
auto add_1 = program.reshape(B, {4, -1});
auto add_2 = program.reshape(B, {4, 8});
auto add = program.add(add_1, add_2);
auto add_1 = program.add(A, A);
auto add_2 = program.add(A, A);
auto reshape_1 = program.reshape(B, {4, -1});
auto reshape_2 = program.reshape(B, {4, 8});
auto add = program.add(reshape_1, reshape_2);

Target target = common::DefaultTarget();
program.SetInputs({A, B});
Expand Down Expand Up @@ -129,17 +129,30 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) {
}

TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
Placeholder A(Float(32), {32, 16}, "A");
Placeholder B(Float(32), {32, 1}, "B", true);
Placeholder A(Float(32), {1, 3, 224, 224}, "A");
Placeholder B(Float(32), {1, 1, 224, 224}, "B", true);

absl::flat_hash_map<std::string, Program::attr_t> attrs;
attrs["stride"] = std::vector<int>({2, 2});
attrs["dilation"] = std::vector<int>({1, 1});
attrs["padding"] = std::vector<int>({3, 3});
std::string src_layout = "NCHW";
attrs["data_format"] = src_layout;

Program program;
auto sub_1 = program.elementwise_sub(A, A);
auto sub_2 = program.elementwise_sub(A, A);
auto const_1 = program.fill_constant<float>({32, 16}, 1.0f, "", false, "const1");
auto const_2 = program.fill_constant<float>({32, 16}, 1.0f, "", false, "const2");
auto const_3 = program.fill_constant<float>({32, 16}, 2.0f, "", false, "const3");
auto out1 = program.add(const_1, const_3);
auto out2 = program.add(const_2, const_3);
auto add_1 = program.add(A, B);
auto weight_1 = program.fill_constant<float>({64, 3, 7, 7}, 1.0f, "", false, "w1");
auto weight_2 = program.fill_constant<float>({64, 3, 7, 7}, 1.0f, "", false, "w2");
auto bias = program.fill_constant<float>({1, 64, 112, 112}, 2.0f, "", false, "b1");
auto conv_1 = program.conv2d(add_1, weight_1, attrs);
auto add_2 = program.add(conv_1, bias);
auto relu_1 = program.relu(add_2);
auto conv_2 = program.conv2d(add_1, weight_2, attrs);
auto add_3 = program.add(conv_2, bias);
auto relu_2 = program.relu(add_3);
auto out1 = program.add(relu_1, add_2);
auto out2 = program.add(add_2, relu_2);
auto out = program.multiply(out1, out2);

Target target = common::DefaultTarget();
program.SetInputs({A, B});
Expand All @@ -157,7 +170,7 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
auto& prerun_instrs = runtime_program->GetPreRunInstructions();
auto& run_instrs = runtime_program->GetRunInstructions();
ASSERT_EQ(prerun_instrs.size(), 0);
ASSERT_EQ(run_instrs.size(), 4);
ASSERT_EQ(run_instrs.size(), 8);
scope->Var<hlir::framework::Tensor>("A");
scope->Var<hlir::framework::Tensor>("B");

Expand Down

0 comments on commit c2342d6

Please sign in to comment.