Skip to content

Commit

Permalink
[Lang] MatrixNdarray refactor part10: Remove redundant MatrixInitStmt…
Browse files Browse the repository at this point in the history
… generated from scalarization (#6171)

Related issue = #5873,
#5819

This PR is working "Part ④" in
#5873.
  • Loading branch information
jim19930609 authored Sep 30, 2022
1 parent b491cee commit 2a3ac5c
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 39 deletions.
4 changes: 4 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,10 @@ class MatrixInitStmt : public Stmt {
TI_STMT_REG_FIELDS;
}

bool has_global_side_effect() const override {
return false;
}

TI_STMT_DEF_FIELDS(ret_type, values);
TI_DEFINE_ACCEPT_AND_CLONE
};
Expand Down
3 changes: 3 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ void compile_to_offloads(IRNode *ir,

if (config.real_matrix && config.real_matrix_scalarize) {
irpass::scalarize(ir);

// Remove redundant MatrixInitStmt inserted during scalarization
irpass::die(ir);
print("Scalarized");
}

Expand Down
7 changes: 0 additions & 7 deletions taichi/transforms/die.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,6 @@ class DIE : public IRVisitor {
}
stmt->all_blocks_accept(this, true);
}

void visit(MatrixInitStmt *stmt) override {
register_usage(stmt);
for (auto &elts : stmt->values) {
elts->accept(this);
}
}
};

namespace irpass {
Expand Down
5 changes: 0 additions & 5 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,6 @@ void scalarize(IRNode *root) {
TI_AUTO_PROF;
Scalarize scalarize_pass(root);
ScalarizePointers scalarize_pointers_pass(root);

/* TODO(zhanlue): Remove redundant MatrixInitStmt
Scalarize pass will generate temporary MatrixInitStmts, which are only used
as rvalues. Remove these MatrixInitStmts since it's no longer needed.
*/
}

} // namespace irpass
Expand Down
55 changes: 28 additions & 27 deletions tests/cpp/transforms/scalarize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,24 @@ TEST(Scalarize, ScalarizeGlobalStore) {
irpass::die(block.get());

EXPECT_EQ(block->size(), 2 /*const*/ + 1 /*argload*/ + 1 /*external_ptr*/ +
1 /*matrix_init*/ + 4 /*const*/ +
4 /*matrix_ptr*/ + 4 /*store*/);
4 /*const*/ + 4 /*matrix_ptr*/ + 4 /*store*/);

// Check for scalarized statements
EXPECT_EQ(block->statements[5]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[6]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[7]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[4]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[5]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[6]->is<GlobalStoreStmt>(), true);

EXPECT_EQ(block->statements[8]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[9]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[10]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[7]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[8]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[9]->is<GlobalStoreStmt>(), true);

EXPECT_EQ(block->statements[11]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[12]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[13]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[10]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[11]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[12]->is<GlobalStoreStmt>(), true);

EXPECT_EQ(block->statements[14]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[15]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[16]->is<GlobalStoreStmt>(), true);
EXPECT_EQ(block->statements[13]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[14]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[15]->is<GlobalStoreStmt>(), true);
}

TEST(Scalarize, ScalarizeGlobalLoad) {
Expand All @@ -86,6 +85,7 @@ TEST(Scalarize, ScalarizeGlobalLoad) {
/*
TensorType<4 x i32>* %1 = ExternalPtrStmt()
TensorType<4 x i32> %2 = LoadStmt(%1)
StoreStmt(%1, %2)
*/
Type *tensor_type = type_factory.get_tensor_type(
{2, 2}, type_factory.get_primitive_type(PrimitiveTypeID::i32));
Expand All @@ -96,14 +96,17 @@ TEST(Scalarize, ScalarizeGlobalLoad) {
argload_stmt, indices); // fake ExternalPtrStmt
src_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

block->push_back<GlobalLoadStmt>(src_stmt);
auto load_stmt = block->push_back<GlobalLoadStmt>(src_stmt);

// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::die(block.get());

EXPECT_EQ(block->size(), 1 /*argload*/ + 1 /*external_ptr*/ + 4 /*const*/ +
4 /*matrix_ptr*/ + 4 /*load*/ +
1 /*matrix_init*/);
4 /*matrix_ptr*/ + 4 /*load*/ + 4 /*const*/ +
4 /*matrix_ptr*/ + 4 /*store*/);

// Check for scalarized statements
EXPECT_EQ(block->statements[2]->is<ConstStmt>(), true);
Expand All @@ -121,8 +124,6 @@ TEST(Scalarize, ScalarizeGlobalLoad) {
EXPECT_EQ(block->statements[11]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[12]->is<MatrixPtrStmt>(), true);
EXPECT_EQ(block->statements[13]->is<GlobalLoadStmt>(), true);

EXPECT_EQ(block->statements[14]->is<MatrixInitStmt>(), true);
}

TEST(Scalarize, ScalarizeLocalStore) {
Expand Down Expand Up @@ -157,13 +158,13 @@ TEST(Scalarize, ScalarizeLocalStore) {
block->push_back<MatrixInitStmt>(std::move(matrix_init_vals));
matrix_init_stmt->ret_type = tensor_type;

// LocalStoreStmt survives irpass::die()
block->push_back<LocalStoreStmt>(dest_stmt, matrix_init_stmt);

irpass::scalarize(block.get());
irpass::die(block.get());

EXPECT_EQ(block->size(),
2 /*const*/ + 1 /*matrix_init*/ + 4 /*alloca*/ + 4 /*store*/);
EXPECT_EQ(block->size(), 2 /*const*/ + 4 /*alloca*/ + 4 /*store*/);

// Check for scalarized statements
EXPECT_EQ(block->statements[0]->is<AllocaStmt>(), true);
Expand All @@ -173,12 +174,11 @@ TEST(Scalarize, ScalarizeLocalStore) {

EXPECT_EQ(block->statements[4]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[5]->is<ConstStmt>(), true);
EXPECT_EQ(block->statements[6]->is<MatrixInitStmt>(), true);

EXPECT_EQ(block->statements[6]->is<LocalStoreStmt>(), true);
EXPECT_EQ(block->statements[7]->is<LocalStoreStmt>(), true);
EXPECT_EQ(block->statements[8]->is<LocalStoreStmt>(), true);
EXPECT_EQ(block->statements[9]->is<LocalStoreStmt>(), true);
EXPECT_EQ(block->statements[10]->is<LocalStoreStmt>(), true);
}

TEST(Scalarize, ScalarizeLocalLoad) {
Expand All @@ -204,12 +204,15 @@ TEST(Scalarize, ScalarizeLocalLoad) {
Stmt *src_stmt = block->push_back<AllocaStmt>(tensor_type);
src_stmt->ret_type = type_factory.get_pointer_type(tensor_type);

block->push_back<LocalLoadStmt>(src_stmt);
auto load_stmt = block->push_back<LocalLoadStmt>(src_stmt);

// Without this GlobalStoreStmt, nothing survives irpass::die()
block->push_back<GlobalStoreStmt>(src_stmt, load_stmt);

irpass::scalarize(block.get());
irpass::die(block.get());

EXPECT_EQ(block->size(), 4 /*alloca*/ + 4 /*load*/ + 1 /*matrix_init*/);
EXPECT_EQ(block->size(), 4 /*alloca*/ + 4 /*load*/ + 4 /*store*/);

// Check for scalarized statements
EXPECT_EQ(block->statements[0]->is<AllocaStmt>(), true);
Expand All @@ -221,8 +224,6 @@ TEST(Scalarize, ScalarizeLocalLoad) {
EXPECT_EQ(block->statements[5]->is<LocalLoadStmt>(), true);
EXPECT_EQ(block->statements[6]->is<LocalLoadStmt>(), true);
EXPECT_EQ(block->statements[7]->is<LocalLoadStmt>(), true);

EXPECT_EQ(block->statements[8]->is<MatrixInitStmt>(), true);
}

} // namespace taichi::lang

0 comments on commit 2a3ac5c

Please sign in to comment.