Skip to content

Commit

Permalink
[ir] MatrixField refactor 8/n: Rename PtrOffsetStmt to MatrixPtrStmt (#…
Browse files Browse the repository at this point in the history
…6187)

Issue: #5959

### Brief Summary

`PtrOffsetStmt` is for getting a pointer to an element of a matrix.
Let's make it follow the naming convention (`XXXPtrStmt`) for clarity.
  • Loading branch information
strongoier authored Sep 28, 2022
1 parent 827ab30 commit 6b84c99
Show file tree
Hide file tree
Showing 21 changed files with 93 additions and 93 deletions.
14 changes: 7 additions & 7 deletions taichi/analysis/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
if (!var1 || !var2)
return AliasResult::different;

// TODO: further optimize with offset inside PtrOffsetStmt
// TODO: further optimize with offset inside MatrixPtrStmt
// If at least one of var1 and var2 is local, they will be treated here.
auto retrieve_local = [&](Stmt *var) {
if (var->is<AllocaStmt>()) {
return var;
} else if (var->is<PtrOffsetStmt>() &&
var->cast<PtrOffsetStmt>()->offset_used_as_index()) {
return var->cast<PtrOffsetStmt>()->origin;
} else if (var->is<MatrixPtrStmt>() &&
var->cast<MatrixPtrStmt>()->offset_used_as_index()) {
return var->cast<MatrixPtrStmt>()->origin;
} else {
return (Stmt *)nullptr;
}
Expand All @@ -30,9 +30,9 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
Stmt *origin2 = retrieve_local(var2);
if (origin1 != nullptr && origin2 != nullptr) {
if (origin1 == origin2) {
if (var1->is<PtrOffsetStmt>() && var2->is<PtrOffsetStmt>()) {
auto diff = value_diff_ptr_index(var1->cast<PtrOffsetStmt>()->offset,
var2->cast<PtrOffsetStmt>()->offset);
if (var1->is<MatrixPtrStmt>() && var2->is<MatrixPtrStmt>()) {
auto diff = value_diff_ptr_index(var1->cast<MatrixPtrStmt>()->offset,
var2->cast<MatrixPtrStmt>()->offset);
if (diff.is_diff_certain) {
return diff.diff_range == 0 ? AliasResult::same
: AliasResult::different;
Expand Down
6 changes: 3 additions & 3 deletions taichi/analysis/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ class IRVerifier : public BasicStmtVisitor {

void visit(LocalLoadStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<PtrOffsetStmt>());
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>());
}

void visit(LocalStoreStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->dest->is<AllocaStmt>() ||
(stmt->dest->is<PtrOffsetStmt>() &&
stmt->dest->cast<PtrOffsetStmt>()->offset_used_as_index()));
(stmt->dest->is<MatrixPtrStmt>() &&
stmt->dest->cast<MatrixPtrStmt>()->offset_used_as_index()));
}

void visit(LoopIndexStmt *stmt) override {
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1876,7 +1876,7 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) {
}
}

void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
void TaskCodeGenLLVM::visit(MatrixPtrStmt *stmt) {
if (stmt->offset_used_as_index()) {
auto type = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed());
llvm_val[stmt] =
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(GlobalPtrStmt *stmt) override;

void visit(PtrOffsetStmt *stmt) override;
void visit(MatrixPtrStmt *stmt) override;

void store_quant_int(llvm::Value *ptr,
llvm::Type *physical_type,
Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class TaskCodegen : public IRVisitor {
}
}

void visit(PtrOffsetStmt *stmt) override {
void visit(MatrixPtrStmt *stmt) override {
spirv::SType data_type =
ir_->get_primitive_type(stmt->element_type().ptr_removed());
spirv::SType ptr_type =
Expand Down
2 changes: 1 addition & 1 deletion taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ PER_STATEMENT(ReturnStmt)
PER_STATEMENT(ArgLoadStmt)
PER_STATEMENT(ReferenceStmt)
PER_STATEMENT(ExternalPtrStmt)
PER_STATEMENT(PtrOffsetStmt)
PER_STATEMENT(MatrixPtrStmt)
PER_STATEMENT(ConstStmt)
PER_STATEMENT(AllocaStmt)
PER_STATEMENT(UnaryOpStmt)
Expand Down
10 changes: 5 additions & 5 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,12 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
for (int i = 0; i < num_nodes; i++) {
for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
auto stmt = nodes[i]->block->statements[j].get();
if ((stmt->is<PtrOffsetStmt>() &&
stmt->as<PtrOffsetStmt>()->origin->is<AllocaStmt>()) ||
if ((stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(!after_lower_access &&
(stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
stmt->is<BlockLocalPtrStmt>() || stmt->is<ThreadLocalPtrStmt>() ||
stmt->is<GlobalTemporaryStmt>() || stmt->is<PtrOffsetStmt>()))) {
stmt->is<GlobalTemporaryStmt>() || stmt->is<MatrixPtrStmt>()))) {
// TODO: unify them
// A global pointer that may contain some data before this kernel.
nodes[start_node]->reach_gen.insert(stmt);
Expand Down Expand Up @@ -679,8 +679,8 @@ void ControlFlowGraph::live_variable_analysis(
if (stmt->is<AllocaStmt>() || stmt->is<AdStackAllocaStmt>()) {
return false;
}
if (stmt->is<PtrOffsetStmt>() &&
stmt->cast<PtrOffsetStmt>()->origin->is<AllocaStmt>()) {
if (stmt->is<MatrixPtrStmt>() &&
stmt->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>()) {
return false;
}
if (auto *gptr = stmt->cast<GlobalPtrStmt>();
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
offset_stmt = ctx->push_back<BinaryOpStmt>(BinaryOpType::mul, offset_stmt,
stride_stmt);
}
return ctx->push_back<PtrOffsetStmt>(var->stmt, offset_stmt);
return ctx->push_back<MatrixPtrStmt>(var->stmt, offset_stmt);
}

void MatrixExpression::type_check(CompileConfig *config) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ MatrixOfGlobalPtrStmt::MatrixOfGlobalPtrStmt(const std::vector<SNode *> &snodes,
TI_STMT_REG_FIELDS;
}

PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) {
MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, Stmt *offset_input) {
origin = origin_input;
offset = offset_input;
if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>() ||
Expand All @@ -89,7 +89,7 @@ PtrOffsetStmt::PtrOffsetStmt(Stmt *origin_input, Stmt *offset_input) {
element_type() = origin->cast<GlobalPtrStmt>()->ret_type;
} else {
TI_ERROR(
"PtrOffsetStmt must be used for AllocaStmt / GlobalTemporaryStmt "
"MatrixPtrStmt must be used for AllocaStmt / GlobalTemporaryStmt "
"(locally) or GlobalPtrStmt / MatrixOfGlobalPtrStmt / ExternalPtrStmt "
"(globally).")
}
Expand Down
16 changes: 8 additions & 8 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,18 @@ class MatrixOfGlobalPtrStmt : public Stmt {
};

/**
* An accessing tensor element operation.
* A pointer to an element of a matrix.
*/
class PtrOffsetStmt : public Stmt {
class MatrixPtrStmt : public Stmt {
public:
Stmt *origin{nullptr};
Stmt *offset{nullptr};

PtrOffsetStmt(Stmt *, Stmt *);
MatrixPtrStmt(Stmt *, Stmt *);

/* TODO(zhanlue/yi): Unify semantics of offset in PrtOffsetStmt
/* TODO(zhanlue/yi): Unify semantics of offset in MatrixPtrStmt
There is a hack in PtrOffsetStmt in terms of the semantics of "offset",
There is a hack in MatrixPtrStmt in terms of the semantics of "offset",
where "offset" can be interpreted as "number of bytes" or "index" in
different upper-level code paths
Expand All @@ -408,7 +408,7 @@ class PtrOffsetStmt : public Stmt {
if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>() ||
origin->is<ExternalPtrStmt>()) {
TI_ASSERT_INFO(origin->ret_type.ptr_removed()->is<TensorType>(),
"PtrOffsetStmt can only be used for TensorType.");
"MatrixPtrStmt can only be used for TensorType.");
return true;
}
return false;
Expand Down Expand Up @@ -663,8 +663,8 @@ class LocalStoreStmt : public Stmt {

LocalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) {
TI_ASSERT(dest->is<AllocaStmt>() ||
(dest->is<PtrOffsetStmt>() &&
dest->cast<PtrOffsetStmt>()->offset_used_as_index()));
(dest->is<MatrixPtrStmt>() &&
dest->cast<MatrixPtrStmt>()->offset_used_as_index()));
TI_STMT_REG_FIELDS;
}

Expand Down
54 changes: 27 additions & 27 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
using BasicStmtVisitor::visit;

void visit(LocalLoadStmt *stmt) override {
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<PtrOffsetStmt>());
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<MatrixPtrStmt>());
touched_allocas_.insert(stmt->src);
}

void visit(LocalStoreStmt *stmt) override {
TI_ASSERT(stmt->dest->is<AllocaStmt>() || stmt->dest->is<PtrOffsetStmt>());
TI_ASSERT(stmt->dest->is<AllocaStmt>() || stmt->dest->is<MatrixPtrStmt>());
touched_allocas_.insert(stmt->dest);
}

Expand Down Expand Up @@ -570,7 +570,7 @@ class ADTransform : public IRVisitor {
// do nothing.
}

void visit(PtrOffsetStmt *stmt) override {
void visit(MatrixPtrStmt *stmt) override {
// do nothing.
}

Expand Down Expand Up @@ -988,9 +988,9 @@ class MakeAdjoint : public ADTransform {

GlobalPtrStmt *src = nullptr;
bool is_ptr_offset = false;
if (stmt->src->is<PtrOffsetStmt>()) {
if (stmt->src->is<MatrixPtrStmt>()) {
is_ptr_offset = true;
src = stmt->src->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
src = stmt->src->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
} else {
src = stmt->src->as<GlobalPtrStmt>();
}
Expand All @@ -1008,8 +1008,8 @@ class MakeAdjoint : public ADTransform {
snode = snode->get_adjoint();
auto adj_ptr = insert<GlobalPtrStmt>(snode, src->indices);
if (is_ptr_offset) {
adj_ptr = insert<PtrOffsetStmt>(adj_ptr,
stmt->src->as<PtrOffsetStmt>()->offset);
adj_ptr = insert<MatrixPtrStmt>(adj_ptr,
stmt->src->as<MatrixPtrStmt>()->offset);
}
insert<AtomicOpStmt>(AtomicOpType::add, adj_ptr, load(adjoint(stmt)));
}
Expand All @@ -1024,9 +1024,9 @@ class MakeAdjoint : public ADTransform {

GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
if (stmt->dest->is<MatrixPtrStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}
Expand All @@ -1040,8 +1040,8 @@ class MakeAdjoint : public ADTransform {
snode = snode->get_adjoint();
auto adjoint_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
if (is_ptr_offset) {
adjoint_ptr = insert<PtrOffsetStmt>(
adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
adjoint_ptr = insert<MatrixPtrStmt>(
adjoint_ptr, stmt->dest->as<MatrixPtrStmt>()->offset);
}
accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
stmt->parent->erase(stmt);
Expand All @@ -1051,9 +1051,9 @@ class MakeAdjoint : public ADTransform {
// erase and replace with global load adjoint
GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
if (stmt->dest->is<MatrixPtrStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}
Expand All @@ -1068,8 +1068,8 @@ class MakeAdjoint : public ADTransform {
snode = snode->get_adjoint();
auto adjoint_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
if (is_ptr_offset) {
adjoint_ptr = insert<PtrOffsetStmt>(
adjoint_ptr, stmt->dest->as<PtrOffsetStmt>()->offset);
adjoint_ptr = insert<MatrixPtrStmt>(
adjoint_ptr, stmt->dest->as<MatrixPtrStmt>()->offset);
}
accumulate(stmt->val, insert<GlobalLoadStmt>(adjoint_ptr));
stmt->parent->erase(stmt);
Expand Down Expand Up @@ -1315,9 +1315,9 @@ class MakeDual : public ADTransform {
// issue global store to dual
GlobalPtrStmt *src = nullptr;
bool is_ptr_offset = false;
if (stmt->src->is<PtrOffsetStmt>()) {
if (stmt->src->is<MatrixPtrStmt>()) {
is_ptr_offset = true;
src = stmt->src->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
src = stmt->src->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
} else {
src = stmt->src->as<GlobalPtrStmt>();
}
Expand All @@ -1334,18 +1334,18 @@ class MakeDual : public ADTransform {
snode = snode->get_dual();
auto dual_ptr = insert<GlobalPtrStmt>(snode, src->indices);
if (is_ptr_offset) {
dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
stmt->src->as<PtrOffsetStmt>()->offset);
dual_ptr = insert<MatrixPtrStmt>(dual_ptr,
stmt->src->as<MatrixPtrStmt>()->offset);
}
accumulate(stmt, insert<GlobalLoadStmt>(dual_ptr));
}

void visit(GlobalStoreStmt *stmt) override {
GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
if (stmt->dest->is<MatrixPtrStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}
Expand All @@ -1358,18 +1358,18 @@ class MakeDual : public ADTransform {
snode = snode->get_dual();
auto dual_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
if (is_ptr_offset) {
dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
stmt->dest->as<PtrOffsetStmt>()->offset);
dual_ptr = insert<MatrixPtrStmt>(dual_ptr,
stmt->dest->as<MatrixPtrStmt>()->offset);
}
insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
}

void visit(AtomicOpStmt *stmt) override {
GlobalPtrStmt *dest = nullptr;
bool is_ptr_offset = false;
if (stmt->dest->is<PtrOffsetStmt>()) {
if (stmt->dest->is<MatrixPtrStmt>()) {
is_ptr_offset = true;
dest = stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>();
dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
} else {
dest = stmt->dest->as<GlobalPtrStmt>();
}
Expand All @@ -1382,8 +1382,8 @@ class MakeDual : public ADTransform {
snode = snode->get_dual();
auto dual_ptr = insert<GlobalPtrStmt>(snode, dest->indices);
if (is_ptr_offset) {
dual_ptr = insert<PtrOffsetStmt>(dual_ptr,
stmt->dest->as<PtrOffsetStmt>()->offset);
dual_ptr = insert<MatrixPtrStmt>(dual_ptr,
stmt->dest->as<MatrixPtrStmt>()->offset);
}
insert<AtomicOpStmt>(AtomicOpType::add, dual_ptr, load(dual(stmt->val)));
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/demote_atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ class DemoteAtomics : public BasicStmtVisitor {
}
}
if (stmt->dest->is<AllocaStmt>() ||
(stmt->dest->is<PtrOffsetStmt>() &&
stmt->dest->cast<PtrOffsetStmt>()->origin->is<AllocaStmt>())) {
(stmt->dest->is<MatrixPtrStmt>() &&
stmt->dest->cast<MatrixPtrStmt>()->origin->is<AllocaStmt>())) {
demote = true;
is_local = true;
}
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ class DemoteOperations : public BasicStmtVisitor {
modifier.insert_before(stmt, std::move(rhs_store));
for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) {
auto idx = Stmt::make<ConstStmt>(TypedConstant(i));
auto lhs_i = Stmt::make<PtrOffsetStmt>(lhs_ptr, idx.get());
auto rhs_i = Stmt::make<PtrOffsetStmt>(rhs_ptr, idx.get());
auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get());
auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get());
auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get());
auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get());
auto cur_lhs = lhs_load.get();
Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/flag_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ class FlagAccess : public IRVisitor {
if (stmt->dest->is<GlobalPtrStmt>()) {
stmt->dest->as<GlobalPtrStmt>()->activate = true;
}
if (stmt->dest->is<PtrOffsetStmt>()) {
if (stmt->dest->as<PtrOffsetStmt>()->is_unlowered_global_ptr()) {
stmt->dest->as<PtrOffsetStmt>()->origin->as<GlobalPtrStmt>()->activate =
if (stmt->dest->is<MatrixPtrStmt>()) {
if (stmt->dest->as<MatrixPtrStmt>()->is_unlowered_global_ptr()) {
stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>()->activate =
true;
}
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ class IRPrinter : public IRVisitor {
print_raw(s);
}

void visit(PtrOffsetStmt *stmt) override {
void visit(MatrixPtrStmt *stmt) override {
std::string s =
fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(),
stmt->name(), stmt->origin->name(), stmt->offset->name());
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/lower_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class LowerAccess : public IRVisitor {
}

// TODO: this seems to be redundant
void visit(PtrOffsetStmt *stmt) override {
void visit(MatrixPtrStmt *stmt) override {
if (!stmt->is_unlowered_global_ptr())
return;
auto ptr = stmt->origin->as<GlobalPtrStmt>();
Expand Down
Loading

0 comments on commit 6b84c99

Please sign in to comment.