From bf49f41c2b3803a4944c5794f083cfd0048ff148 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Wed, 7 Apr 2021 23:54:21 +0800 Subject: [PATCH 1/8] [IR] [refactor] Unify field names in load/store/atomic statements --- taichi/analysis/data_source_analysis.cpp | 12 ++-- taichi/analysis/gather_snode_read_writes.cpp | 6 +- taichi/analysis/has_store_or_atomic.cpp | 2 +- taichi/analysis/last_store_or_atomic.cpp | 6 +- taichi/analysis/same_statements.cpp | 2 +- taichi/analysis/verify.cpp | 4 +- taichi/backends/metal/codegen_metal.cpp | 36 ++++++------ taichi/backends/opengl/codegen_opengl.cpp | 24 ++++---- taichi/codegen/codegen_llvm.cpp | 34 +++++------ taichi/ir/control_flow_graph.cpp | 18 +++--- taichi/ir/state_machine.cpp | 12 ++-- taichi/ir/statements.cpp | 10 ++-- taichi/ir/statements.h | 28 ++++----- taichi/program/async_utils.cpp | 8 +-- taichi/transforms/auto_diff.cpp | 38 ++++++------ taichi/transforms/bit_loop_vectorize.cpp | 12 ++-- taichi/transforms/flag_access.cpp | 4 +- taichi/transforms/insert_scratch_pad.cpp | 4 +- taichi/transforms/ir_printer.cpp | 8 +-- taichi/transforms/loop_vectorize.cpp | 10 ++-- taichi/transforms/lower_access.cpp | 18 +++--- taichi/transforms/make_thread_local.cpp | 4 +- taichi/transforms/offload.cpp | 8 +-- .../transforms/optimize_bit_struct_stores.cpp | 4 +- taichi/transforms/simplify.cpp | 14 ++--- taichi/transforms/type_check.cpp | 36 ++++++------ taichi/transforms/variable_optimization.cpp | 58 +++++++++---------- taichi/transforms/vector_split.cpp | 20 +++---- 28 files changed, 220 insertions(+), 220 deletions(-) diff --git a/taichi/analysis/data_source_analysis.cpp b/taichi/analysis/data_source_analysis.cpp index 687ffdc95b664..d77e84b13af5a 100644 --- a/taichi/analysis/data_source_analysis.cpp +++ b/taichi/analysis/data_source_analysis.cpp @@ -10,13 +10,13 @@ std::vector get_load_pointers(Stmt *load_stmt) { // If load_stmt loads some variables or a stack, return the pointers of them. if (auto local_load = load_stmt->cast()) { std::vector result; - for (auto &address : local_load->ptr.data) { + for (auto &address : local_load->src.data) { if (std::find(result.begin(), result.end(), address.var) == result.end()) result.push_back(address.var); } return result; } else if (auto global_load = load_stmt->cast()) { - return std::vector(1, global_load->ptr); + return std::vector(1, global_load->src); } else if (auto atomic = load_stmt->cast()) { return std::vector(1, atomic->dest); } else if (auto stack_load_top = load_stmt->cast()) { @@ -46,9 +46,9 @@ Stmt *get_store_data(Stmt *store_stmt) { // stores. return store_stmt; } else if (auto local_store = store_stmt->cast()) { - return local_store->data; + return local_store->val; } else if (auto global_store = store_stmt->cast()) { - return global_store->data; + return global_store->val; } else { return nullptr; } @@ -60,9 +60,9 @@ std::vector get_store_destination(Stmt *store_stmt) { // The statement itself provides a data source (const [0]). return std::vector(1, store_stmt); } else if (auto local_store = store_stmt->cast()) { - return std::vector(1, local_store->ptr); + return std::vector(1, local_store->dest); } else if (auto global_store = store_stmt->cast()) { - return std::vector(1, global_store->ptr); + return std::vector(1, global_store->dest); } else if (auto atomic = store_stmt->cast()) { return std::vector(1, atomic->dest); } else if (auto external_func = store_stmt->cast()) { diff --git a/taichi/analysis/gather_snode_read_writes.cpp b/taichi/analysis/gather_snode_read_writes.cpp index af525e4a4ac53..45abb217ad84c 100644 --- a/taichi/analysis/gather_snode_read_writes.cpp +++ b/taichi/analysis/gather_snode_read_writes.cpp @@ -17,17 +17,17 @@ gather_snode_read_writes(IRNode *root) { bool read = false, write = false; if (auto global_load = stmt->cast()) { read = true; - ptr = global_load->ptr; + ptr = global_load->src; } else if (auto global_store = stmt->cast()) { write = true; - ptr = global_store->ptr; + ptr = global_store->dest; } else if (auto global_atomic = stmt->cast()) { read = true; write = true; ptr = global_atomic->dest; } if (ptr) { - if (GlobalPtrStmt *global_ptr = ptr->cast()) { + if (auto *global_ptr = ptr->cast()) { for (auto &snode : global_ptr->snodes.data) { if (read) accessed.first.emplace(snode); diff --git a/taichi/analysis/has_store_or_atomic.cpp b/taichi/analysis/has_store_or_atomic.cpp index b498413e38aaf..04d546673364c 100644 --- a/taichi/analysis/has_store_or_atomic.cpp +++ b/taichi/analysis/has_store_or_atomic.cpp @@ -25,7 +25,7 @@ class LocalStoreSearcher : public BasicStmtVisitor { void visit(LocalStoreStmt *stmt) override { for (auto var : vars) { - if (stmt->ptr == var) { + if (stmt->dest == var) { result = true; break; } diff --git a/taichi/analysis/last_store_or_atomic.cpp b/taichi/analysis/last_store_or_atomic.cpp index 54868b76a8135..744e4ef1c62d9 100644 --- a/taichi/analysis/last_store_or_atomic.cpp +++ b/taichi/analysis/last_store_or_atomic.cpp @@ -24,7 +24,7 @@ class LocalStoreForwarder : public BasicStmtVisitor { } void visit(LocalStoreStmt *stmt) override { - if (stmt->ptr == var) { + if (stmt->dest == var) { is_valid = true; result = stmt; } @@ -70,8 +70,8 @@ class LocalStoreForwarder : public BasicStmtVisitor { } else { TI_ASSERT(true_stmt->is()); TI_ASSERT(false_stmt->is()); - if (true_stmt->as()->data != - false_stmt->as()->data) { + if (true_stmt->as()->val != + false_stmt->as()->val) { // two branches finally store the variable differently is_valid = false; } else { diff --git a/taichi/analysis/same_statements.cpp b/taichi/analysis/same_statements.cpp index 2691f54cb82cd..e53b4baa35a5a 100644 --- a/taichi/analysis/same_statements.cpp +++ b/taichi/analysis/same_statements.cpp @@ -152,7 +152,7 @@ class IRNodeComparator : public IRVisitor { } else { bool same_value = false; if (auto global_load = stmt->cast()) { - if (auto global_ptr = global_load->ptr->cast()) { + if (auto global_ptr = global_load->src->cast()) { TI_ASSERT(global_ptr->width() == 1); if (possibly_modified_states_.count(ir_bank_->get_async_state( global_ptr->snodes[0], AsyncState::Type::value)) == 0) { diff --git a/taichi/analysis/verify.cpp b/taichi/analysis/verify.cpp index 3c8f8c30cc950..21e4d54a4d664 100644 --- a/taichi/analysis/verify.cpp +++ b/taichi/analysis/verify.cpp @@ -104,13 +104,13 @@ class IRVerifier : public BasicStmtVisitor { void visit(LocalLoadStmt *stmt) override { basic_verify(stmt); for (int i = 0; i < stmt->width(); i++) { - TI_ASSERT(stmt->ptr[i].var->is()); + TI_ASSERT(stmt->src[i].var->is()); } } void visit(LocalStoreStmt *stmt) override { basic_verify(stmt); - TI_ASSERT(stmt->ptr->is()); + TI_ASSERT(stmt->dest->is()); } void visit(LoopIndexStmt *stmt) override { diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 733b61d3ca337..0adb7963f1ad4 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -172,14 +172,14 @@ class KernelCodegen : public IRVisitor { void visit(LocalLoadStmt *stmt) override { // TODO: optimize for partially vectorized load... bool linear_index = true; - for (int i = 0; i < (int)stmt->ptr.size(); i++) { - if (stmt->ptr[i].offset != i) { + for (int i = 0; i < (int)stmt->src.size(); i++) { + if (stmt->src[i].offset != i) { linear_index = false; } } if (stmt->same_source() && linear_index && - stmt->width() == stmt->ptr[0].var->width()) { - auto ptr = stmt->ptr[0].var; + stmt->width() == stmt->src[0].var->width()) { + auto ptr = stmt->src[0].var; emit("const {} {}({});", metal_data_type_name(stmt->element_type()), stmt->raw_name(), ptr->raw_name()); } else { @@ -188,7 +188,7 @@ class KernelCodegen : public IRVisitor { } void visit(LocalStoreStmt *stmt) override { - emit(R"({} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name()); + emit(R"({} = {};)", stmt->dest->raw_name(), stmt->val->raw_name()); } void visit(GetRootStmt *stmt) override { @@ -335,8 +335,8 @@ class KernelCodegen : public IRVisitor { void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - if (!is_ret_type_bit_pointer(stmt->ptr)) { - emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name()); + if (!is_ret_type_bit_pointer(stmt->dest)) { + emit(R"(*{} = {};)", stmt->dest->raw_name(), stmt->val->raw_name()); return; } handle_bit_pointer_global_store(stmt); @@ -345,8 +345,8 @@ class KernelCodegen : public IRVisitor { void visit(GlobalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); std::string rhs_expr; - if (!is_ret_type_bit_pointer(stmt->ptr)) { - rhs_expr = fmt::format("*{}", stmt->ptr->raw_name()); + if (!is_ret_type_bit_pointer(stmt->src)) { + rhs_expr = fmt::format("*{}", stmt->src->raw_name()); } else { rhs_expr = construct_bit_pointer_global_load(stmt); } @@ -832,30 +832,30 @@ class KernelCodegen : public IRVisitor { } void handle_bit_pointer_global_store(GlobalStoreStmt *stmt) { - auto *ptr_type = stmt->ptr->ret_type->as(); + auto *ptr_type = stmt->dest->ret_type->as(); TI_ASSERT(ptr_type->is_bit_pointer()); auto *pointee_type = ptr_type->get_pointee_type(); CustomIntType *cit = nullptr; std::string store_value_expr; if (auto *cit_cast = pointee_type->cast()) { cit = cit_cast; - store_value_expr = stmt->data->raw_name(); + store_value_expr = stmt->val->raw_name(); } else if (auto *cft = pointee_type->cast()) { validate_cft_for_metal(cft); auto *digits_cit = cft->get_digits_type()->as(); cit = digits_cit; store_value_expr = construct_float_to_custom_int_expr( - stmt->data, cft->get_scale(), digits_cit); + stmt->val, cft->get_scale(), digits_cit); } else { TI_NOT_IMPLEMENTED; } - // Type of |stmt->ptr| is SNodeBitPointer + // Type of |stmt->dest| is SNodeBitPointer const auto num_bits = cit->get_num_bits(); if (is_full_bits(num_bits)) { - emit("mtl_set_full_bits({}, {});", stmt->ptr->raw_name(), + emit("mtl_set_full_bits({}, {});", stmt->dest->raw_name(), store_value_expr); } else { - emit("mtl_set_partial_bits({},", stmt->ptr->raw_name()); + emit("mtl_set_partial_bits({},", stmt->dest->raw_name()); emit(" {},", store_value_expr); emit(" /*bits=*/{});", num_bits); } @@ -863,15 +863,15 @@ class KernelCodegen : public IRVisitor { // Returns the expression of the load result std::string construct_bit_pointer_global_load(GlobalLoadStmt *stmt) const { - auto *ptr_type = stmt->ptr->ret_type->as(); + auto *ptr_type = stmt->src->ret_type->as(); TI_ASSERT(ptr_type->is_bit_pointer()); auto *pointee_type = ptr_type->get_pointee_type(); if (auto *cit = pointee_type->cast()) { - return construct_load_as_custom_int(stmt->ptr, cit); + return construct_load_as_custom_int(stmt->src, cit); } else if (auto *cft = pointee_type->cast()) { validate_cft_for_metal(cft); const auto loaded = construct_load_as_custom_int( - stmt->ptr, cft->get_digits_type()->as()); + stmt->src, cft->get_digits_type()->as()); // Computes `float(digits_expr) * scale` // See LLVM backend's reconstruct_custom_float() return fmt::format("(static_cast({}) * {})", loaded, diff --git a/taichi/backends/opengl/codegen_opengl.cpp b/taichi/backends/opengl/codegen_opengl.cpp index 8bc28bd68d07b..83227ba68f57d 100644 --- a/taichi/backends/opengl/codegen_opengl.cpp +++ b/taichi/backends/opengl/codegen_opengl.cpp @@ -405,11 +405,11 @@ class KernelGen : public IRVisitor { void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - auto dt = stmt->data->element_type(); + auto dt = stmt->val->element_type(); emit("_{}_{}_[{} >> {}] = {};", - ptr_signats.at(stmt->ptr->id), // throw out_of_range if not a pointer - opengl_data_type_short_name(dt), stmt->ptr->short_name(), - opengl_data_address_shifter(dt), stmt->data->short_name()); + ptr_signats.at(stmt->dest->id), // throw out_of_range if not a pointer + opengl_data_type_short_name(dt), stmt->dest->short_name(), + opengl_data_address_shifter(dt), stmt->val->short_name()); } void visit(GlobalLoadStmt *stmt) override { @@ -417,8 +417,8 @@ class KernelGen : public IRVisitor { auto dt = stmt->element_type(); emit("{} {} = _{}_{}_[{} >> {}];", opengl_data_type_name(stmt->element_type()), stmt->short_name(), - ptr_signats.at(stmt->ptr->id), opengl_data_type_short_name(dt), - stmt->ptr->short_name(), opengl_data_address_shifter(dt)); + ptr_signats.at(stmt->src->id), opengl_data_type_short_name(dt), + stmt->src->short_name(), opengl_data_address_shifter(dt)); } void visit(ExternalPtrStmt *stmt) override { @@ -648,23 +648,23 @@ class KernelGen : public IRVisitor { void visit(LocalLoadStmt *stmt) override { bool linear_index = true; - for (int i = 0; i < (int)stmt->ptr.size(); i++) { - if (stmt->ptr[i].offset != i) { + for (int i = 0; i < (int)stmt->src.size(); i++) { + if (stmt->src[i].offset != i) { linear_index = false; } } if (stmt->same_source() && linear_index && - stmt->width() == stmt->ptr[0].var->width()) { - auto ptr = stmt->ptr[0].var; + stmt->width() == stmt->src[0].var->width()) { + auto src = stmt->src[0].var; emit("{} {} = {};", opengl_data_type_name(stmt->element_type()), - stmt->short_name(), ptr->short_name()); + stmt->short_name(), src->short_name()); } else { TI_NOT_IMPLEMENTED; } } void visit(LocalStoreStmt *stmt) override { - emit("{} = {};", stmt->ptr->short_name(), stmt->data->short_name()); + emit("{} = {};", stmt->dest->short_name(), stmt->val->short_name()); } void visit(AllocaStmt *alloca) override { diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index d91eb8707051c..91d5918448a18 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -957,7 +957,7 @@ void CodeGenLLVM::visit(KernelReturnStmt *stmt) { void CodeGenLLVM::visit(LocalLoadStmt *stmt) { TI_ASSERT(stmt->width() == 1); - llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->ptr[0].var]); + llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->src[0].var]); } void CodeGenLLVM::visit(LocalStoreStmt *stmt) { @@ -965,7 +965,7 @@ void CodeGenLLVM::visit(LocalStoreStmt *stmt) { if (mask && stmt->width() != 1) { TI_NOT_IMPLEMENTED } else { - builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); + builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]); } } @@ -1133,16 +1133,16 @@ void CodeGenLLVM::visit(GlobalPtrStmt *stmt) { void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { TI_ASSERT(!stmt->parent->mask() || stmt->width() == 1); - TI_ASSERT(llvm_val[stmt->data]); - TI_ASSERT(llvm_val[stmt->ptr]); - auto ptr_type = stmt->ptr->ret_type->as(); + TI_ASSERT(llvm_val[stmt->val]); + TI_ASSERT(llvm_val[stmt->dest]); + auto ptr_type = stmt->dest->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto pointee_type = ptr_type->get_pointee_type(); llvm::Value *store_value = nullptr; CustomIntType *cit = nullptr; if (auto cit_ = pointee_type->cast()) { cit = cit_; - store_value = llvm_val[stmt->data]; + store_value = llvm_val[stmt->val]; } else if (auto cft = pointee_type->cast()) { llvm::Value *digit_bits = nullptr; auto digits_cit = cft->get_digits_type()->as(); @@ -1154,7 +1154,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { // f32 = 1 sign bit + 8 exponent bits + 23 fraction bits auto f32_bits = builder->CreateBitCast( - llvm_val[stmt->data], llvm::Type::getInt32Ty(*llvm_context)); + llvm_val[stmt->val], llvm::Type::getInt32Ty(*llvm_context)); // Rounding to nearest here. Note that if the digits overflows then the // carry-on will contribute to the exponent, which is desired. if (cft->get_digit_bits() < 23) { @@ -1184,7 +1184,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { auto exponent_cit = exp->as(); - auto digits_snode = stmt->ptr->as()->output_snode; + auto digits_snode = stmt->dest->as()->output_snode; auto exponent_snode = digits_snode->exp_snode; auto exponent_offset = get_exponent_offset(exponent_bits, cft); @@ -1195,7 +1195,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { // Compute the bit pointer of the exponent bits. TI_ASSERT(digits_snode->parent == exponent_snode->parent); auto exponent_bit_ptr = - offset_bit_ptr(llvm_val[stmt->ptr], exponent_snode->bit_offset - + offset_bit_ptr(llvm_val[stmt->dest], exponent_snode->bit_offset - digits_snode->bit_offset); store_custom_int(exponent_bit_ptr, exponent_cit, exponent_bits); store_value = digit_bits; @@ -1210,36 +1210,36 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { store_value = builder->CreateSelect(exp_non_zero, store_value, tlctx->get_constant(0)); } else { - digit_bits = llvm_val[stmt->data]; + digit_bits = llvm_val[stmt->val]; store_value = float_to_custom_int(cft, digits_cit, digit_bits); } cit = digits_cit; } else { TI_NOT_IMPLEMENTED } - store_custom_int(llvm_val[stmt->ptr], cit, store_value); + store_custom_int(llvm_val[stmt->dest], cit, store_value); } else { - builder->CreateStore(llvm_val[stmt->data], llvm_val[stmt->ptr]); + builder->CreateStore(llvm_val[stmt->val], llvm_val[stmt->dest]); } } void CodeGenLLVM::visit(GlobalLoadStmt *stmt) { int width = stmt->width(); TI_ASSERT(width == 1); - auto ptr_type = stmt->ptr->ret_type->as(); + auto ptr_type = stmt->src->ret_type->as(); if (ptr_type->is_bit_pointer()) { auto val_type = ptr_type->get_pointee_type(); if (val_type->is()) { - llvm_val[stmt] = load_as_custom_int(llvm_val[stmt->ptr], val_type); + llvm_val[stmt] = load_as_custom_int(llvm_val[stmt->src], val_type); } else if (auto cft = val_type->cast()) { - TI_ASSERT(stmt->ptr->is()); - llvm_val[stmt] = load_custom_float(stmt->ptr); + TI_ASSERT(stmt->src->is()); + llvm_val[stmt] = load_custom_float(stmt->src); } else { TI_NOT_IMPLEMENTED } } else { llvm_val[stmt] = builder->CreateLoad(tlctx->get_data_type(stmt->ret_type), - llvm_val[stmt->ptr]); + llvm_val[stmt->src]); } } diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 0657da1fed4b8..e026f1d213bbb 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -251,10 +251,10 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { Stmt *result = nullptr; if (auto local_load = stmt->cast()) { bool regular = true; - auto alloca = local_load->ptr[0].var; + auto alloca = local_load->src[0].var; for (int l = 0; l < stmt->width(); l++) { - if (local_load->ptr[l].offset != l || - local_load->ptr[l].var != alloca) { + if (local_load->src[l].offset != l || + local_load->src[l].var != alloca) { regular = false; } } @@ -263,7 +263,7 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { } } else if (auto global_load = stmt->cast()) { if (!after_lower_access) { - result = get_store_forwarding_data(global_load->ptr, i); + result = get_store_forwarding_data(global_load->src, i); } } if (result) { @@ -282,11 +282,11 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { } // Identical store elimination if (auto local_store = stmt->cast()) { - result = get_store_forwarding_data(local_store->ptr, i); + result = get_store_forwarding_data(local_store->dest, i); if (result) { if (result->is()) { // special case of alloca (initialized to 0) - if (auto stored_data = local_store->data->cast()) { + if (auto stored_data = local_store->val->cast()) { bool all_zero = true; for (auto &val : stored_data->val.data) { if (!val.equal_value(0)) { @@ -302,7 +302,7 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { } } else { // not alloca - if (irpass::analysis::same_value(result, local_store->data)) { + if (irpass::analysis::same_value(result, local_store->val)) { erase(i); // This causes end_location-- i--; // to cancel i++ in the for loop modified = true; @@ -311,8 +311,8 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access) { } } else if (auto global_store = stmt->cast()) { if (!after_lower_access) { - result = get_store_forwarding_data(global_store->ptr, i); - if (irpass::analysis::same_value(result, global_store->data)) { + result = get_store_forwarding_data(global_store->dest, i); + if (irpass::analysis::same_value(result, global_store->val)) { erase(i); // This causes end_location-- i--; // to cancel i++ in the for loop modified = true; diff --git a/taichi/ir/state_machine.cpp b/taichi/ir/state_machine.cpp index 8d5d8d0075541..4c747c4cc08f2 100644 --- a/taichi/ir/state_machine.cpp +++ b/taichi/ir/state_machine.cpp @@ -29,14 +29,14 @@ bool StateMachine::same_data(Stmt *store_stmt1, Stmt *store_stmt2) { if (!store_stmt2->is()) return false; return irpass::analysis::same_statements( - store_stmt1->as()->data, - store_stmt2->as()->data); + store_stmt1->as()->val, + store_stmt2->as()->val); } else { if (!store_stmt2->is()) return false; return irpass::analysis::same_statements( - store_stmt1->as()->data, - store_stmt2->as()->data); + store_stmt1->as()->val, + store_stmt2->as()->val); } } @@ -146,9 +146,9 @@ void StateMachine::load(Stmt *load_stmt) { if (last_store_forwardable) { // store-forwarding if (last_store->is()) - load_stmt->replace_with(last_store->as()->data); + load_stmt->replace_with(last_store->as()->val); else - load_stmt->replace_with(last_store->as()->data); + load_stmt->replace_with(last_store->as()->val); load_stmt->parent->erase(load_stmt); throw IRModified(); } diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 7b8083c007d88..57c6700c275d2 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -155,14 +155,14 @@ Stmt *LocalLoadStmt::previous_store_or_alloca_in_block() { if (parent->statements[i]->is()) { auto store = parent->statements[i]->as(); // TI_ASSERT(store->width() == 1); - if (store->ptr == this->ptr[0].var) { + if (store->dest == this->src[0].var) { // found return store; } } else if (parent->statements[i]->is()) { auto alloca = parent->statements[i]->as(); // TI_ASSERT(alloca->width() == 1); - if (alloca == this->ptr[0].var) { + if (alloca == this->src[0].var) { return alloca; } } @@ -171,8 +171,8 @@ Stmt *LocalLoadStmt::previous_store_or_alloca_in_block() { } bool LocalLoadStmt::same_source() const { - for (int i = 1; i < (int)ptr.size(); i++) { - if (ptr[i].var != ptr[0].var) + for (int i = 1; i < (int)src.size(); i++) { + if (src[i].var != src[0].var) return false; } return true; @@ -180,7 +180,7 @@ bool LocalLoadStmt::same_source() const { bool LocalLoadStmt::has_source(Stmt *alloca) const { for (int i = 0; i < width(); i++) { - if (ptr[i].var == alloca) + if (src[i].var == alloca) return true; } return false; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index cd5c24717354e..46e552bb08ee3 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -364,9 +364,9 @@ class LoopUniqueStmt : public Stmt { class GlobalLoadStmt : public Stmt { public: - Stmt *ptr; + Stmt *src; - GlobalLoadStmt(Stmt *ptr) : ptr(ptr) { + GlobalLoadStmt(Stmt *src) : src(src) { TI_STMT_REG_FIELDS; } @@ -378,15 +378,15 @@ class GlobalLoadStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type, ptr); + TI_STMT_DEF_FIELDS(ret_type, src); TI_DEFINE_ACCEPT_AND_CLONE; }; class GlobalStoreStmt : public Stmt { public: - Stmt *ptr, *data; + Stmt *dest, *val; - GlobalStoreStmt(Stmt *ptr, Stmt *data) : ptr(ptr), data(data) { + GlobalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) { TI_STMT_REG_FIELDS; } @@ -394,15 +394,15 @@ class GlobalStoreStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type, ptr, data); + TI_STMT_DEF_FIELDS(ret_type, dest, val); TI_DEFINE_ACCEPT_AND_CLONE; }; class LocalLoadStmt : public Stmt { public: - LaneAttribute ptr; + LaneAttribute src; - LocalLoadStmt(const LaneAttribute &ptr) : ptr(ptr) { + LocalLoadStmt(const LaneAttribute &src) : src(src) { TI_STMT_REG_FIELDS; } @@ -419,17 +419,17 @@ class LocalLoadStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type, ptr); + TI_STMT_DEF_FIELDS(ret_type, src); TI_DEFINE_ACCEPT_AND_CLONE; }; class LocalStoreStmt : public Stmt { public: - Stmt *ptr; - Stmt *data; + Stmt *dest; + Stmt *val; - LocalStoreStmt(Stmt *ptr, Stmt *data) : ptr(ptr), data(data) { - TI_ASSERT(ptr->is()); + LocalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) { + TI_ASSERT(dest->is()); TI_STMT_REG_FIELDS; } @@ -445,7 +445,7 @@ class LocalStoreStmt : public Stmt { return false; } - TI_STMT_DEF_FIELDS(ret_type, ptr, data); + TI_STMT_DEF_FIELDS(ret_type, dest, val); TI_DEFINE_ACCEPT_AND_CLONE; }; diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index d3b164c08d93c..ddac3c7b2bceb 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -159,16 +159,16 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { // For a global load, GlobalPtrStmt has already been handled in // get_meta_input_value_states(). if (auto global_store = stmt->cast()) { - if (auto ptr = global_store->ptr->cast()) { - for (auto &snode : ptr->snodes.data) { + if (auto dest = global_store->dest->cast()) { + for (auto &snode : dest->snodes.data) { meta.output_states.insert( ir_bank->get_async_state(snode, AsyncState::Type::value)); } } } if (auto global_atomic = stmt->cast()) { - if (auto ptr = global_atomic->dest->cast()) { - for (auto &snode : ptr->snodes.data) { + if (auto dest = global_atomic->dest->cast()) { + for (auto &snode : dest->snodes.data) { // input_state is already handled in // get_meta_input_value_states(). meta.output_states.insert( diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 20f463d65ef03..6456e672cd6c7 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -48,11 +48,11 @@ class IdentifyIndependentBlocks : public BasicStmtVisitor { // TODO: remove this abuse since it *gathers nothing* irpass::analysis::gather_statements(block, [&](Stmt *stmt) -> bool { if (auto local_load = stmt->cast(); local_load) { - for (auto &lane : local_load->ptr.data) { + for (auto &lane : local_load->src.data) { touched_allocas.insert(lane.var->as()); } } else if (auto local_store = stmt->cast(); local_store) { - touched_allocas.insert(local_store->ptr->as()); + touched_allocas.insert(local_store->dest->as()); } return false; }); @@ -192,7 +192,7 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { alloc->parent, [&](Stmt *s) { if (auto store = s->cast()) - return store->ptr == alloc; + return store->dest == alloc; else if (auto atomic = s->cast()) { return atomic->dest == alloc; } else { @@ -218,13 +218,13 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { void visit(LocalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - if (stmt->ptr[0].var->is()) - stmt->replace_with(Stmt::make(stmt->ptr[0].var)); + if (stmt->src[0].var->is()) + stmt->replace_with(Stmt::make(stmt->src[0].var)); } void visit(LocalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - stmt->replace_with(Stmt::make(stmt->ptr, stmt->data)); + stmt->replace_with(Stmt::make(stmt->dest, stmt->val)); } }; @@ -647,9 +647,9 @@ class MakeAdjoint : public IRVisitor { void visit(GlobalLoadStmt *stmt) override { // issue global store to adjoint - GlobalPtrStmt *ptr = stmt->ptr->as(); - TI_ASSERT(ptr->width() == 1); - auto snodes = ptr->snodes; + GlobalPtrStmt *src = stmt->src->as(); + TI_ASSERT(src->width() == 1); + auto snodes = src->snodes; if (!snodes[0]->has_grad()) { // No adjoint SNode. Do nothing return; @@ -660,36 +660,36 @@ class MakeAdjoint : public IRVisitor { } TI_ASSERT(snodes[0]->get_grad() != nullptr); snodes[0] = snodes[0]->get_grad(); - auto adj_ptr = insert(snodes, ptr->indices); + auto adj_ptr = insert(snodes, src->indices); insert(AtomicOpType::add, adj_ptr, load(adjoint(stmt))); } void visit(GlobalStoreStmt *stmt) override { // erase and replace with global load adjoint - GlobalPtrStmt *ptr = stmt->ptr->as(); - TI_ASSERT(ptr->width() == 1); - auto snodes = ptr->snodes; + GlobalPtrStmt *dest = stmt->dest->as(); + TI_ASSERT(dest->width() == 1); + auto snodes = dest->snodes; if (!snodes[0]->has_grad()) { // no gradient (likely integer types) return; } TI_ASSERT(snodes[0]->get_grad() != nullptr); snodes[0] = snodes[0]->get_grad(); - auto adjoint_ptr = insert(snodes, ptr->indices); + auto adjoint_ptr = insert(snodes, dest->indices); auto load = insert(adjoint_ptr); - accumulate(stmt->data, load); + accumulate(stmt->val, load); stmt->parent->erase(stmt); } void visit(AtomicOpStmt *stmt) override { // erase and replace with global load adjoint - GlobalPtrStmt *ptr = stmt->dest->as(); - TI_ASSERT(ptr->width() == 1); - auto snodes = ptr->snodes; + GlobalPtrStmt *dest = stmt->dest->as(); + TI_ASSERT(dest->width() == 1); + auto snodes = dest->snodes; if (snodes[0]->has_grad()) { TI_ASSERT(snodes[0]->get_grad() != nullptr); snodes[0] = snodes[0]->get_grad(); - auto adjoint_ptr = insert(snodes, ptr->indices); + auto adjoint_ptr = insert(snodes, dest->indices); accumulate(stmt->val, insert(adjoint_ptr)); } else { // no gradient (likely integer types) diff --git a/taichi/transforms/bit_loop_vectorize.cpp b/taichi/transforms/bit_loop_vectorize.cpp index 266970e6f4fe0..8a127e9a5a091 100644 --- a/taichi/transforms/bit_loop_vectorize.cpp +++ b/taichi/transforms/bit_loop_vectorize.cpp @@ -38,12 +38,12 @@ class BitLoopVectorize : public IRVisitor { } void visit(GlobalLoadStmt *stmt) override { - auto ptr_type = stmt->ptr->ret_type->as(); + auto ptr_type = stmt->src->ret_type->as(); if (in_struct_for_loop && bit_vectorize != 1) { if (auto cit = ptr_type->get_pointee_type()->cast()) { // rewrite the previous GlobalPtrStmt's return type from *cit to // *phy_type - auto ptr = stmt->ptr->cast(); + auto ptr = stmt->src->cast(); auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type( bit_array_physical_type, false); DataType new_ret_type(ptr_physical_type); @@ -125,12 +125,12 @@ class BitLoopVectorize : public IRVisitor { } void visit(GlobalStoreStmt *stmt) override { - auto ptr_type = stmt->ptr->ret_type->as(); + auto ptr_type = stmt->dest->ret_type->as(); if (in_struct_for_loop && bit_vectorize != 1) { if (auto cit = ptr_type->get_pointee_type()->cast()) { // rewrite the previous GlobalPtrStmt's return type from *cit to // *phy_type - auto ptr = stmt->ptr->cast(); + auto ptr = stmt->dest->cast(); auto ptr_physical_type = TypeFactory::get_instance().get_pointer_type( bit_array_physical_type, false); DataType new_ret_type(ptr_physical_type); @@ -174,7 +174,7 @@ class BitLoopVectorize : public IRVisitor { } else if (stmt->op_type == BinaryOpType::cmp_eq) { if (auto lhs = stmt->lhs->cast()) { // case 0: lhs is a vectorized global load from the bit array - if (auto ptr = lhs->ptr->cast(); + if (auto ptr = lhs->src->cast(); ptr && ptr->is_bit_vectorized) { int32 rhs_val = get_constant_value(stmt->rhs); // TODO: we limit 1 for now, 0 should be easy to implement by a @@ -196,7 +196,7 @@ class BitLoopVectorize : public IRVisitor { } } else if (auto lhs = stmt->lhs->cast()) { // case 1: lhs is a local load from a local adder structure - auto it = transformed_atomics.find(lhs->ptr[0].var); + auto it = transformed_atomics.find(lhs->src[0].var); if (it != transformed_atomics.end()) { int32 rhs_val = get_constant_value(stmt->rhs); // TODO: we limit 2 and 3 for now, the other case should be diff --git a/taichi/transforms/flag_access.cpp b/taichi/transforms/flag_access.cpp index b49b140066abe..02831b376b4dd 100644 --- a/taichi/transforms/flag_access.cpp +++ b/taichi/transforms/flag_access.cpp @@ -51,8 +51,8 @@ class FlagAccess : public IRVisitor { } void visit(GlobalStoreStmt *stmt) { - if (stmt->ptr->is()) { - stmt->ptr->as()->activate = true; + if (stmt->dest->is()) { + stmt->dest->as()->activate = true; } } diff --git a/taichi/transforms/insert_scratch_pad.cpp b/taichi/transforms/insert_scratch_pad.cpp index c7587f1b3765f..482ccb0195753 100644 --- a/taichi/transforms/insert_scratch_pad.cpp +++ b/taichi/transforms/insert_scratch_pad.cpp @@ -118,12 +118,12 @@ class BLSAnalysis : public BasicStmtVisitor { // Do not eliminate global data access void visit(GlobalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); // TODO: support vectorization - record_access(stmt->ptr, AccessFlag::read); + record_access(stmt->src, AccessFlag::read); } void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); // TODO: support vectorization - record_access(stmt->ptr, AccessFlag::write); + record_access(stmt->dest, AccessFlag::write); } void visit(AtomicOpStmt *stmt) override { diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 6a8399ab9a6c7..16eb936e95cf3 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -390,22 +390,22 @@ class IRPrinter : public IRVisitor { void visit(LocalLoadStmt *stmt) override { print("{}{} = local load [{}]", stmt->type_hint(), stmt->name(), - to_string(stmt->ptr)); + to_string(stmt->src)); } void visit(LocalStoreStmt *stmt) override { print("{}{} : local store [{} <- {}]", stmt->type_hint(), stmt->name(), - stmt->ptr->name(), stmt->data->name()); + stmt->dest->name(), stmt->val->name()); } void visit(GlobalLoadStmt *stmt) override { print("{}{} = global load {}", stmt->type_hint(), stmt->name(), - stmt->ptr->name()); + stmt->src->name()); } void visit(GlobalStoreStmt *stmt) override { print("{}{} : global store [{} <- {}]", stmt->type_hint(), stmt->name(), - stmt->ptr->name(), stmt->data->name()); + stmt->dest->name(), stmt->val->name()); } void visit(ElementShuffleStmt *stmt) override { diff --git a/taichi/transforms/loop_vectorize.cpp b/taichi/transforms/loop_vectorize.cpp index 79bb55efd861c..8156ee84f1c3c 100644 --- a/taichi/transforms/loop_vectorize.cpp +++ b/taichi/transforms/loop_vectorize.cpp @@ -89,17 +89,17 @@ class LoopVectorize : public IRVisitor { return; int original_width = stmt->width(); widen_type(stmt->ret_type, vectorize); - stmt->ptr.repeat(vectorize); + stmt->src.repeat(vectorize); // TODO: this can be buggy - int stride = stmt->ptr[original_width - 1].offset + 1; - if (stmt->ptr[0].var->width() != 1) { + int stride = stmt->src[original_width - 1].offset + 1; + if (stmt->src[0].var->width() != 1) { for (int i = 0; i < vectorize; i++) { for (int j = 0; j < original_width; j++) { - stmt->ptr[i * original_width + j].offset += i * stride; + stmt->src[i * original_width + j].offset += i * stride; } } } - if (loop_var && stmt->same_source() && stmt->ptr[0].var == loop_var) { + if (loop_var && stmt->same_source() && stmt->src[0].var == loop_var) { // insert_before_me LaneAttribute const_offsets; const_offsets.resize(vectorize * original_width); diff --git a/taichi/transforms/lower_access.cpp b/taichi/transforms/lower_access.cpp index eaa8171cdb6e1..16df1a33a33d2 100644 --- a/taichi/transforms/lower_access.cpp +++ b/taichi/transforms/lower_access.cpp @@ -197,21 +197,21 @@ class LowerAccess : public IRVisitor { } void visit(GlobalLoadStmt *stmt) override { - if (stmt->ptr->is()) { + if (stmt->src->is()) { // No need to activate for all read accesses - auto lowered = lower_vector_ptr(stmt->ptr->as(), false); - stmt->ptr = lowered.back().get(); + auto lowered = lower_vector_ptr(stmt->src->as(), false); + stmt->src = lowered.back().get(); modifier.insert_before(stmt, std::move(lowered)); } } void visit(GlobalStoreStmt *stmt) override { - if (stmt->ptr->is()) { - auto ptr = stmt->ptr->as(); + if (stmt->dest->is()) { + auto ptr = stmt->dest->as(); // If ptr already has activate = false, no need to activate all the // generated micro-access ops. Otherwise, activate the nodes. auto lowered = lower_vector_ptr(ptr, ptr->activate); - stmt->ptr = lowered.back().get(); + stmt->dest = lowered.back().get(); modifier.insert_before(stmt, std::move(lowered)); } } @@ -246,9 +246,9 @@ class LowerAccess : public IRVisitor { } void visit(LocalStoreStmt *stmt) override { - if (stmt->data->is()) { - auto lowered = lower_vector_ptr(stmt->data->as(), true); - stmt->data = lowered.back().get(); + if (stmt->val->is()) { + auto lowered = lower_vector_ptr(stmt->val->as(), true); + stmt->val = lowered.back().get(); modifier.insert_before(stmt, std::move(lowered)); } } diff --git a/taichi/transforms/make_thread_local.cpp b/taichi/transforms/make_thread_local.cpp index c0051acd4f31e..e95e4e7befef6 100644 --- a/taichi/transforms/make_thread_local.cpp +++ b/taichi/transforms/make_thread_local.cpp @@ -54,11 +54,11 @@ std::vector find_global_reduction_destinations( auto related_global_mem_ops = irpass::analysis::gather_statements(offload, [&](Stmt *stmt) { if (auto load = stmt->cast()) { - if (irpass::analysis::maybe_same_address(load->ptr, dest)) { + if (irpass::analysis::maybe_same_address(load->src, dest)) { return true; } } else if (auto store = stmt->cast()) { - if (irpass::analysis::maybe_same_address(store->ptr, dest)) { + if (irpass::analysis::maybe_same_address(store->dest, dest)) { return true; } } else if (auto atomic = stmt->cast()) { diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index e38867fc9ed96..ea6c6e303bf47 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -434,7 +434,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { // Replace local LD/ST with global LD/ST void visit(LocalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - auto alloca = stmt->ptr[0].var; + auto alloca = stmt->src[0].var; if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) return; @@ -450,10 +450,10 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { } void visit(LocalStoreStmt *stmt) override { - if (visit_operand(stmt, stmt->locate_operand(&stmt->data))) + if (visit_operand(stmt, stmt->locate_operand(&stmt->val))) throw IRModified(); TI_ASSERT(stmt->width() == 1); - auto alloca = stmt->ptr; + auto alloca = stmt->dest; if (local_to_global_offset.find(alloca) == local_to_global_offset.end()) return; @@ -462,7 +462,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { auto ptr = replacement.push_back( local_to_global_offset[alloca], ret_type); - replacement.push_back(ptr, stmt->data); + replacement.push_back(ptr, stmt->val); stmt->parent->replace_with(stmt, std::move(replacement)); throw IRModified(); diff --git a/taichi/transforms/optimize_bit_struct_stores.cpp b/taichi/transforms/optimize_bit_struct_stores.cpp index dae773da0d27d..0947eceb97972 100644 --- a/taichi/transforms/optimize_bit_struct_stores.cpp +++ b/taichi/transforms/optimize_bit_struct_stores.cpp @@ -25,7 +25,7 @@ class CreateBitStructStores : public BasicStmtVisitor { } void visit(GlobalStoreStmt *stmt) override { - auto get_ch = stmt->ptr->cast(); + auto get_ch = stmt->dest->cast(); if (!get_ch || get_ch->input_snode->type != SNodeType::bit_struct) return; @@ -41,7 +41,7 @@ class CreateBitStructStores : public BasicStmtVisitor { get_ch->output_snode->owns_shared_exponent) { auto s = Stmt::make(get_ch->input_ptr, std::vector{get_ch->chid}, - std::vector{stmt->data}); + std::vector{stmt->val}); stmt->replace_with(VecStatement(std::move(s))); } } diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 83b4a1e984823..7ebab9ec23fc6 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -84,7 +84,7 @@ class BasicBlockSimplify : public IRVisitor { auto &bstmt_data = *bstmt; if (typeid(bstmt_data) == typeid(*stmt)) { auto bstmt_ = bstmt->as(); - bool same = stmt->ptr == bstmt_->ptr; + bool same = stmt->src == bstmt_->src; if (same) { // no store to the var? bool has_store = false; @@ -107,10 +107,10 @@ class BasicBlockSimplify : public IRVisitor { [&](Stmt *s) { if (auto store = s->cast()) return irpass::analysis::maybe_same_address( - store->ptr, stmt->ptr); + store->dest, stmt->src); else if (auto atomic = s->cast()) return irpass::analysis::maybe_same_address( - atomic->dest, stmt->ptr); + atomic->dest, stmt->src); else return false; }) @@ -433,17 +433,17 @@ class BasicBlockSimplify : public IRVisitor { auto store = clause[i]->as(); auto lanes = LaneAttribute(); for (int l = 0; l < store->width(); l++) { - lanes.push_back(LocalAddress(store->ptr, l)); + lanes.push_back(LocalAddress(store->dest, l)); } auto load = if_stmt->insert_before_me(Stmt::make(lanes)); irpass::type_check(load); auto select = if_stmt->insert_before_me( Stmt::make(TernaryOpType::select, if_stmt->cond, - true_branch ? store->data : load, - true_branch ? load : store->data)); + true_branch ? store->val : load, + true_branch ? load : store->val)); irpass::type_check(select); - store->data = select; + store->val = select; if_stmt->insert_before_me(std::move(clause[i])); } else { if_stmt->insert_before_me(std::move(clause[i])); diff --git a/taichi/transforms/type_check.cpp b/taichi/transforms/type_check.cpp index 69379c66d3214..1194121f7bb18 100644 --- a/taichi/transforms/type_check.cpp +++ b/taichi/transforms/type_check.cpp @@ -81,35 +81,35 @@ class TypeCheck : public IRVisitor { void visit(LocalLoadStmt *stmt) { TI_ASSERT(stmt->width() == 1); - auto lookup = stmt->ptr[0].var->ret_type; + auto lookup = stmt->src[0].var->ret_type; stmt->ret_type = lookup; } void visit(LocalStoreStmt *stmt) { - if (stmt->ptr->ret_type->is_primitive(PrimitiveTypeID::unknown)) { + if (stmt->dest->ret_type->is_primitive(PrimitiveTypeID::unknown)) { // Infer data type for alloca - stmt->ptr->ret_type = stmt->data->ret_type; + stmt->dest->ret_type = stmt->val->ret_type; } auto common_container_type = - promoted_type(stmt->ptr->ret_type, stmt->data->ret_type); + promoted_type(stmt->dest->ret_type, stmt->val->ret_type); - auto old_data = stmt->data; - if (stmt->ptr->ret_type != stmt->data->ret_type) { - stmt->data = - insert_type_cast_before(stmt, stmt->data, stmt->ptr->ret_type); + auto old_data = stmt->val; + if (stmt->dest->ret_type != stmt->val->ret_type) { + stmt->val = + insert_type_cast_before(stmt, stmt->val, stmt->dest->ret_type); } - if (stmt->ptr->ret_type != common_container_type) { + if (stmt->dest->ret_type != common_container_type) { TI_WARN( "[{}] Local store may lose precision (target = {}, value = {}) at", - stmt->name(), stmt->ptr->ret_data_type_name(), + stmt->name(), stmt->dest->ret_data_type_name(), old_data->ret_data_type_name(), stmt->id); TI_WARN("\n{}", stmt->tb); } - stmt->ret_type = stmt->ptr->ret_type; + stmt->ret_type = stmt->dest->ret_type; } void visit(GlobalLoadStmt *stmt) { - auto pointee_type = stmt->ptr->ret_type.ptr_removed(); + auto pointee_type = stmt->src->ret_type.ptr_removed(); if (auto bit_struct = pointee_type->cast()) { stmt->ret_type = bit_struct->get_physical_type(); } else { @@ -159,21 +159,21 @@ class TypeCheck : public IRVisitor { } void visit(GlobalStoreStmt *stmt) { - auto dst_value_type = stmt->ptr->ret_type.ptr_removed(); + auto dst_value_type = stmt->dest->ret_type.ptr_removed(); if (dst_value_type->is() || dst_value_type->is()) { // We force the value type to be the compute_type of the bit pointer. // Casting from compute_type to physical_type is handled in codegen. dst_value_type = dst_value_type->get_compute_type(); } - auto promoted = promoted_type(dst_value_type, stmt->data->ret_type); - auto input_type = stmt->data->ret_data_type_name(); - if (dst_value_type != stmt->data->ret_type) { - stmt->data = insert_type_cast_before(stmt, stmt->data, dst_value_type); + auto promoted = promoted_type(dst_value_type, stmt->val->ret_type); + auto input_type = stmt->val->ret_data_type_name(); + if (dst_value_type != stmt->val->ret_type) { + stmt->val = insert_type_cast_before(stmt, stmt->val, dst_value_type); } // TODO: do not use "promoted" here since u8 + u8 = i32 in C++ and storing // u8 to u8 leads to extra warnings. - if (dst_value_type != promoted && dst_value_type != stmt->data->ret_type) { + if (dst_value_type != promoted && dst_value_type != stmt->val->ret_type) { TI_WARN("[{}] Global store may lose precision: {} <- {}, at", stmt->name(), dst_value_type->to_string(), input_type); TI_WARN("\n{}", stmt->tb); diff --git a/taichi/transforms/variable_optimization.cpp b/taichi/transforms/variable_optimization.cpp index 109289fb36d3a..8ad5d4f41d39a 100644 --- a/taichi/transforms/variable_optimization.cpp +++ b/taichi/transforms/variable_optimization.cpp @@ -134,18 +134,18 @@ class AllocaOptimize : public VariableOptimize { void visit(LocalStoreStmt *stmt) override { if (maybe_run) - get_state_machine(stmt->ptr).maybe_store(stmt); + get_state_machine(stmt->dest).maybe_store(stmt); else - get_state_machine(stmt->ptr).store(stmt); + get_state_machine(stmt->dest).store(stmt); } void visit(LocalLoadStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - TI_ASSERT(stmt->ptr[0].offset == 0); + TI_ASSERT(stmt->src[0].offset == 0); if (maybe_run) - get_state_machine(stmt->ptr[0].var).maybe_load(); + get_state_machine(stmt->src[0].var).maybe_load(); else - get_state_machine(stmt->ptr[0].var).load(stmt); + get_state_machine(stmt->src[0].var).load(stmt); } void visit(IfStmt *if_stmt) override { @@ -248,21 +248,21 @@ class GlobalTempOptimize : public VariableOptimize { } void visit(GlobalStoreStmt *stmt) override { - if (!stmt->ptr->is()) + if (!stmt->dest->is()) return; if (maybe_run) - get_state_machine(stmt->ptr).maybe_store(stmt); + get_state_machine(stmt->dest).maybe_store(stmt); else - get_state_machine(stmt->ptr).store(stmt); + get_state_machine(stmt->dest).store(stmt); } void visit(GlobalLoadStmt *stmt) override { - if (!stmt->ptr->is()) + if (!stmt->src->is()) return; if (maybe_run) - get_state_machine(stmt->ptr).maybe_load(); + get_state_machine(stmt->src).maybe_load(); else - get_state_machine(stmt->ptr).load(stmt); + get_state_machine(stmt->src).load(stmt); } void visit(IfStmt *if_stmt) override { @@ -393,13 +393,13 @@ class GlobalPtrOptimize : public VariableOptimize { } void visit(GlobalStoreStmt *stmt) override { - if (!stmt->ptr->is()) + if (!stmt->dest->is()) return; if (maybe_run) - get_state_machine(stmt->ptr).maybe_store(stmt); + get_state_machine(stmt->dest).maybe_store(stmt); else - get_state_machine(stmt->ptr).store(stmt); - auto dest = stmt->ptr->as(); + get_state_machine(stmt->dest).store(stmt); + auto dest = stmt->dest->as(); for (auto &var : state_machines[dest->snodes[0]->id]) { if (var.first != dest && irpass::analysis::maybe_same_address(dest, var.first)) { @@ -409,13 +409,13 @@ class GlobalPtrOptimize : public VariableOptimize { } void visit(GlobalLoadStmt *stmt) override { - if (!stmt->ptr->is()) + if (!stmt->src->is()) return; if (maybe_run) - get_state_machine(stmt->ptr).maybe_load(); + get_state_machine(stmt->src).maybe_load(); else - get_state_machine(stmt->ptr).load(stmt); - auto dest = stmt->ptr->as(); + get_state_machine(stmt->src).load(stmt); + auto dest = stmt->src->as(); for (auto &var : state_machines[dest->snodes[0]->id]) { if (var.first != dest && irpass::analysis::maybe_same_address(dest, var.first)) { @@ -540,30 +540,30 @@ class OtherVariableOptimize : public VariableOptimize { } void visit(GlobalStoreStmt *stmt) override { - if (stmt->ptr->is()) + if (stmt->dest->is()) return; if (maybe_run) - get_state_machine(stmt->ptr).maybe_store(stmt); + get_state_machine(stmt->dest).maybe_store(stmt); else - get_state_machine(stmt->ptr).store(stmt); + get_state_machine(stmt->dest).store(stmt); for (auto &var : state_machines) { - if (var.first != stmt->ptr && - irpass::analysis::maybe_same_address(stmt->ptr, var.first)) { + if (var.first != stmt->dest && + irpass::analysis::maybe_same_address(stmt->dest, var.first)) { var.second.maybe_store(stmt); } } } void visit(GlobalLoadStmt *stmt) override { - if (stmt->ptr->is()) + if (stmt->src->is()) return; if (maybe_run) - get_state_machine(stmt->ptr).maybe_load(); + get_state_machine(stmt->src).maybe_load(); else - get_state_machine(stmt->ptr).load(stmt); + get_state_machine(stmt->src).load(stmt); for (auto &var : state_machines) { - if (var.first != stmt->ptr && - irpass::analysis::maybe_same_address(stmt->ptr, var.first)) { + if (var.first != stmt->src && + irpass::analysis::maybe_same_address(stmt->src, var.first)) { var.second.maybe_load(); } } diff --git a/taichi/transforms/vector_split.cpp b/taichi/transforms/vector_split.cpp index aae6900a0fe73..ea7bcdbc0418a 100644 --- a/taichi/transforms/vector_split.cpp +++ b/taichi/transforms/vector_split.cpp @@ -118,12 +118,12 @@ class BasicBlockVectorSplit : public IRVisitor { if (stmt_->is()) { auto stmt = stmt_->as(); for (int l = 0; l < stmt->width(); l++) { - auto *old_var = stmt->ptr[l].var; + auto *old_var = stmt->src[l].var; if (origin2split.find(old_var) != origin2split.end()) { auto new_var = - origin2split[old_var][stmt->ptr[l].offset / max_width]; - stmt->ptr[l].var = new_var; - stmt->ptr[l].offset %= max_width; + origin2split[old_var][stmt->src[l].offset / max_width]; + stmt->src[l].var = new_var; + stmt->src[l].offset %= max_width; // TI_WARN("replaced..."); } } @@ -183,7 +183,7 @@ class BasicBlockVectorSplit : public IRVisitor { int new_width = need_split ? max_width : stmt->width(); ptr.reserve(new_width); for (int j = 0; j < new_width; j++) { - LocalAddress addr(stmt->ptr[lane_start(i) + j]); + LocalAddress addr(stmt->src[lane_start(i) + j]); if (origin2split.find(addr.var) == origin2split.end()) { ptr.push_back(addr); } else { @@ -197,21 +197,21 @@ class BasicBlockVectorSplit : public IRVisitor { void visit(LocalStoreStmt *stmt) override { for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->ptr, i), - lookup(stmt->data, i)); + current_split[i] = Stmt::make(lookup(stmt->dest, i), + lookup(stmt->val, i)); } } void visit(GlobalLoadStmt *stmt) override { for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->ptr, i)); + current_split[i] = Stmt::make(lookup(stmt->src, i)); } } void visit(GlobalStoreStmt *stmt) override { for (int i = 0; i < current_split_factor; i++) { - current_split[i] = Stmt::make(lookup(stmt->ptr, i), - lookup(stmt->data, i)); + current_split[i] = Stmt::make(lookup(stmt->dest, i), + lookup(stmt->val, i)); } } From cc820f477674e844b92df71bc7060fab9fc8ce3a Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Wed, 7 Apr 2021 11:58:48 -0400 Subject: [PATCH 2/8] [skip ci] enforce code format --- taichi/codegen/codegen_llvm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 91d5918448a18..a570375221240 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1196,7 +1196,7 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) { TI_ASSERT(digits_snode->parent == exponent_snode->parent); auto exponent_bit_ptr = offset_bit_ptr(llvm_val[stmt->dest], exponent_snode->bit_offset - - digits_snode->bit_offset); + digits_snode->bit_offset); store_custom_int(exponent_bit_ptr, exponent_cit, exponent_bits); store_value = digit_bits; From 39f8c307452e8f83c49bc68864e2214046df1453 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 8 Apr 2021 00:11:01 +0800 Subject: [PATCH 3/8] Fix CE --- taichi/backends/cuda/codegen_cuda.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/taichi/backends/cuda/codegen_cuda.cpp b/taichi/backends/cuda/codegen_cuda.cpp index 2f5646f7bdce8..68cfdd955c2c4 100644 --- a/taichi/backends/cuda/codegen_cuda.cpp +++ b/taichi/backends/cuda/codegen_cuda.cpp @@ -417,7 +417,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { } void visit(GlobalLoadStmt *stmt) override { - if (auto get_ch = stmt->ptr->cast(); get_ch) { + if (auto get_ch = stmt->src->cast(); get_ch) { bool should_cache_as_read_only = false; if (current_offload->mem_access_opt.has_flag( get_ch->output_snode, SNodeAccessFlag::read_only)) { @@ -425,7 +425,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { } if (should_cache_as_read_only) { auto dtype = stmt->ret_type; - if (auto ptr_type = stmt->ptr->ret_type->as(); + if (auto ptr_type = stmt->src->ret_type->as(); ptr_type->is_bit_pointer()) { // Bit pointer case. auto val_type = ptr_type->get_pointee_type(); @@ -436,13 +436,13 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { if (auto cit = val_type->cast()) { int_in_mem = val_type; dtype = cit->get_physical_type(); - auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->ptr]); + auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]); data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype)); auto data = create_intrinsic_load(dtype, data_ptr); llvm_val[stmt] = extract_custom_int(data, bit_offset, int_in_mem); } else if (auto cft = val_type->cast()) { // TODO: support __ldg - llvm_val[stmt] = load_custom_float(stmt->ptr); + llvm_val[stmt] = load_custom_float(stmt->src); } else { TI_NOT_IMPLEMENTED; } @@ -450,7 +450,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM { // Byte pointer case. // Issue an CUDA "__ldg" instruction so that data are cached in // the CUDA read-only data cache. - llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->ptr]); + llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->src]); } } else { CodeGenLLVM::visit(stmt); From 18ca4494d5660d2b80e6279458289b027ec50f0b Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 8 Apr 2021 14:02:20 +0800 Subject: [PATCH 4/8] Update taichi/ir/statements.h Co-authored-by: Ye Kuang --- taichi/ir/statements.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 46e552bb08ee3..5296e0ef106e1 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -384,7 +384,8 @@ class GlobalLoadStmt : public Stmt { class GlobalStoreStmt : public Stmt { public: - Stmt *dest, *val; + Stmt *dest; + Stmt *val; GlobalStoreStmt(Stmt *dest, Stmt *val) : dest(dest), val(val) { TI_STMT_REG_FIELDS; From d4f41d0183214583e8484093c01a34dbd5794e71 Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 8 Apr 2021 14:05:17 +0800 Subject: [PATCH 5/8] code format --- taichi/transforms/auto_diff.cpp | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 6456e672cd6c7..ed139a492579e 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -188,18 +188,16 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { void visit(AllocaStmt *alloc) override { TI_ASSERT(alloc->width() == 1); - bool load_only = irpass::analysis::gather_statements( - alloc->parent, - [&](Stmt *s) { - if (auto store = s->cast()) - return store->dest == alloc; - else if (auto atomic = s->cast()) { - return atomic->dest == alloc; - } else { - return false; - } - }) - .empty(); + bool load_only = + irpass::analysis::gather_statements(alloc->parent, [&](Stmt *s) { + if (auto store = s->cast()) + return store->dest == alloc; + else if (auto atomic = s->cast()) { + return atomic->dest == alloc; + } else { + return false; + } + }).empty(); if (!load_only) { auto dtype = alloc->ret_type; auto stack_alloca = Stmt::make( From 64b737f708950daa2259d4402bb2f2abe3357d3a Mon Sep 17 00:00:00 2001 From: Taichi Gardener Date: Thu, 8 Apr 2021 02:05:47 -0400 Subject: [PATCH 6/8] [skip ci] enforce code format --- taichi/transforms/auto_diff.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index ed139a492579e..6456e672cd6c7 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -188,16 +188,18 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { void visit(AllocaStmt *alloc) override { TI_ASSERT(alloc->width() == 1); - bool load_only = - irpass::analysis::gather_statements(alloc->parent, [&](Stmt *s) { - if (auto store = s->cast()) - return store->dest == alloc; - else if (auto atomic = s->cast()) { - return atomic->dest == alloc; - } else { - return false; - } - }).empty(); + bool load_only = irpass::analysis::gather_statements( + alloc->parent, + [&](Stmt *s) { + if (auto store = s->cast()) + return store->dest == alloc; + else if (auto atomic = s->cast()) { + return atomic->dest == alloc; + } else { + return false; + } + }) + .empty(); if (!load_only) { auto dtype = alloc->ret_type; auto stack_alloca = Stmt::make( From aabe6d7d7cf48da675c1cb6cf2c9829887133c2d Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 8 Apr 2021 14:09:27 +0800 Subject: [PATCH 7/8] Revert "[skip ci] enforce code format" This reverts commit 64b737f7 --- taichi/transforms/auto_diff.cpp | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/taichi/transforms/auto_diff.cpp b/taichi/transforms/auto_diff.cpp index 6456e672cd6c7..ed139a492579e 100644 --- a/taichi/transforms/auto_diff.cpp +++ b/taichi/transforms/auto_diff.cpp @@ -188,18 +188,16 @@ class ReplaceLocalVarWithStacks : public BasicStmtVisitor { void visit(AllocaStmt *alloc) override { TI_ASSERT(alloc->width() == 1); - bool load_only = irpass::analysis::gather_statements( - alloc->parent, - [&](Stmt *s) { - if (auto store = s->cast()) - return store->dest == alloc; - else if (auto atomic = s->cast()) { - return atomic->dest == alloc; - } else { - return false; - } - }) - .empty(); + bool load_only = + irpass::analysis::gather_statements(alloc->parent, [&](Stmt *s) { + if (auto store = s->cast()) + return store->dest == alloc; + else if (auto atomic = s->cast()) { + return atomic->dest == alloc; + } else { + return false; + } + }).empty(); if (!load_only) { auto dtype = alloc->ret_type; auto stack_alloca = Stmt::make( From 4f4acc51aff1c2d26334780152126977dcc30b9b Mon Sep 17 00:00:00 2001 From: xumingkuan Date: Thu, 8 Apr 2021 14:21:17 +0800 Subject: [PATCH 8/8] fix build error --- taichi/backends/cc/codegen_cc.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index d2ee5e968417b..e41eeaf374a12 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -126,12 +126,12 @@ class CCTransformer : public IRVisitor { TI_ASSERT(stmt->width() == 1); emit("{} = *{};", define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()), - stmt->ptr->raw_name()); + stmt->src->raw_name()); } void visit(GlobalStoreStmt *stmt) override { TI_ASSERT(stmt->width() == 1); - emit("*{} = {};", stmt->ptr->raw_name(), stmt->data->raw_name()); + emit("*{} = {};", stmt->dest->raw_name(), stmt->val->raw_name()); } void visit(GlobalTemporaryStmt *stmt) override { @@ -202,21 +202,21 @@ class CCTransformer : public IRVisitor { void visit(LocalLoadStmt *stmt) override { bool linear_index = true; - for (int i = 0; i < (int)stmt->ptr.size(); i++) { - if (stmt->ptr[i].offset != i) { + for (int i = 0; i < (int)stmt->src.size(); i++) { + if (stmt->src[i].offset != i) { linear_index = false; } } TI_ASSERT(stmt->same_source() && linear_index && - stmt->width() == stmt->ptr[0].var->width()); + stmt->width() == stmt->src[0].var->width()); auto var = define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()); - emit("{} = {};", var, stmt->ptr[0].var->raw_name()); + emit("{} = {};", var, stmt->src[0].var->raw_name()); } void visit(LocalStoreStmt *stmt) override { - emit("{} = {};", stmt->ptr->raw_name(), stmt->data->raw_name()); + emit("{} = {};", stmt->dest->raw_name(), stmt->val->raw_name()); } void visit(ExternalFuncCallStmt *stmt) override {