Skip to content

Commit

Permalink
[IR] [refactor] Unify field names in load/store/atomic statements (#2…
Browse files Browse the repository at this point in the history
…250)

* [IR] [refactor] Unify field names in load/store/atomic statements

* [skip ci] enforce code format

* Fix CE

* Update taichi/ir/statements.h

Co-authored-by: Ye Kuang <[email protected]>

* code format

* [skip ci] enforce code format

* Revert "[skip ci] enforce code format"

This reverts commit 64b737f7

* fix build error

Co-authored-by: Taichi Gardener <[email protected]>
Co-authored-by: Ye Kuang <[email protected]>
  • Loading branch information
3 people authored Apr 8, 2021
1 parent 0035ec0 commit 7e54f56
Show file tree
Hide file tree
Showing 30 changed files with 243 additions and 244 deletions.
12 changes: 6 additions & 6 deletions taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
// If load_stmt loads some variables or a stack, return the pointers of them.
if (auto local_load = load_stmt->cast<LocalLoadStmt>()) {
std::vector<Stmt *> result;
for (auto &address : local_load->ptr.data) {
for (auto &address : local_load->src.data) {
if (std::find(result.begin(), result.end(), address.var) == result.end())
result.push_back(address.var);
}
return result;
} else if (auto global_load = load_stmt->cast<GlobalLoadStmt>()) {
return std::vector<Stmt *>(1, global_load->ptr);
return std::vector<Stmt *>(1, global_load->src);
} else if (auto atomic = load_stmt->cast<AtomicOpStmt>()) {
return std::vector<Stmt *>(1, atomic->dest);
} else if (auto stack_load_top = load_stmt->cast<StackLoadTopStmt>()) {
Expand Down Expand Up @@ -46,9 +46,9 @@ Stmt *get_store_data(Stmt *store_stmt) {
// stores.
return store_stmt;
} else if (auto local_store = store_stmt->cast<LocalStoreStmt>()) {
return local_store->data;
return local_store->val;
} else if (auto global_store = store_stmt->cast<GlobalStoreStmt>()) {
return global_store->data;
return global_store->val;
} else {
return nullptr;
}
Expand All @@ -60,9 +60,9 @@ std::vector<Stmt *> get_store_destination(Stmt *store_stmt) {
// The statement itself provides a data source (const [0]).
return std::vector<Stmt *>(1, store_stmt);
} else if (auto local_store = store_stmt->cast<LocalStoreStmt>()) {
return std::vector<Stmt *>(1, local_store->ptr);
return std::vector<Stmt *>(1, local_store->dest);
} else if (auto global_store = store_stmt->cast<GlobalStoreStmt>()) {
return std::vector<Stmt *>(1, global_store->ptr);
return std::vector<Stmt *>(1, global_store->dest);
} else if (auto atomic = store_stmt->cast<AtomicOpStmt>()) {
return std::vector<Stmt *>(1, atomic->dest);
} else if (auto external_func = store_stmt->cast<ExternalFuncCallStmt>()) {
Expand Down
6 changes: 3 additions & 3 deletions taichi/analysis/gather_snode_read_writes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ gather_snode_read_writes(IRNode *root) {
bool read = false, write = false;
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
read = true;
ptr = global_load->ptr;
ptr = global_load->src;
} else if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
write = true;
ptr = global_store->ptr;
ptr = global_store->dest;
} else if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
read = true;
write = true;
ptr = global_atomic->dest;
}
if (ptr) {
if (GlobalPtrStmt *global_ptr = ptr->cast<GlobalPtrStmt>()) {
if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
if (read)
accessed.first.emplace(snode);
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/has_store_or_atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class LocalStoreSearcher : public BasicStmtVisitor {

void visit(LocalStoreStmt *stmt) override {
for (auto var : vars) {
if (stmt->ptr == var) {
if (stmt->dest == var) {
result = true;
break;
}
Expand Down
6 changes: 3 additions & 3 deletions taichi/analysis/last_store_or_atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class LocalStoreForwarder : public BasicStmtVisitor {
}

void visit(LocalStoreStmt *stmt) override {
if (stmt->ptr == var) {
if (stmt->dest == var) {
is_valid = true;
result = stmt;
}
Expand Down Expand Up @@ -70,8 +70,8 @@ class LocalStoreForwarder : public BasicStmtVisitor {
} else {
TI_ASSERT(true_stmt->is<LocalStoreStmt>());
TI_ASSERT(false_stmt->is<LocalStoreStmt>());
if (true_stmt->as<LocalStoreStmt>()->data !=
false_stmt->as<LocalStoreStmt>()->data) {
if (true_stmt->as<LocalStoreStmt>()->val !=
false_stmt->as<LocalStoreStmt>()->val) {
// two branches finally store the variable differently
is_valid = false;
} else {
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class IRNodeComparator : public IRVisitor {
} else {
bool same_value = false;
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (auto global_ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
if (auto global_ptr = global_load->src->cast<GlobalPtrStmt>()) {
TI_ASSERT(global_ptr->width() == 1);
if (possibly_modified_states_.count(ir_bank_->get_async_state(
global_ptr->snodes[0], AsyncState::Type::value)) == 0) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/analysis/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ class IRVerifier : public BasicStmtVisitor {
void visit(LocalLoadStmt *stmt) override {
basic_verify(stmt);
for (int i = 0; i < stmt->width(); i++) {
TI_ASSERT(stmt->ptr[i].var->is<AllocaStmt>());
TI_ASSERT(stmt->src[i].var->is<AllocaStmt>());
}
}

void visit(LocalStoreStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->ptr->is<AllocaStmt>());
TI_ASSERT(stmt->dest->is<AllocaStmt>());
}

void visit(LoopIndexStmt *stmt) override {
Expand Down
14 changes: 7 additions & 7 deletions taichi/backends/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ class CCTransformer : public IRVisitor {
TI_ASSERT(stmt->width() == 1);
emit("{} = *{};",
define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name()),
stmt->ptr->raw_name());
stmt->src->raw_name());
}

void visit(GlobalStoreStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
emit("*{} = {};", stmt->ptr->raw_name(), stmt->data->raw_name());
emit("*{} = {};", stmt->dest->raw_name(), stmt->val->raw_name());
}

void visit(GlobalTemporaryStmt *stmt) override {
Expand Down Expand Up @@ -202,21 +202,21 @@ class CCTransformer : public IRVisitor {

void visit(LocalLoadStmt *stmt) override {
bool linear_index = true;
for (int i = 0; i < (int)stmt->ptr.size(); i++) {
if (stmt->ptr[i].offset != i) {
for (int i = 0; i < (int)stmt->src.size(); i++) {
if (stmt->src[i].offset != i) {
linear_index = false;
}
}
TI_ASSERT(stmt->same_source() && linear_index &&
stmt->width() == stmt->ptr[0].var->width());
stmt->width() == stmt->src[0].var->width());

auto var =
define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name());
emit("{} = {};", var, stmt->ptr[0].var->raw_name());
emit("{} = {};", var, stmt->src[0].var->raw_name());
}

void visit(LocalStoreStmt *stmt) override {
emit("{} = {};", stmt->ptr->raw_name(), stmt->data->raw_name());
emit("{} = {};", stmt->dest->raw_name(), stmt->val->raw_name());
}

void visit(ExternalFuncCallStmt *stmt) override {
Expand Down
10 changes: 5 additions & 5 deletions taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,15 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
}

void visit(GlobalLoadStmt *stmt) override {
if (auto get_ch = stmt->ptr->cast<GetChStmt>(); get_ch) {
if (auto get_ch = stmt->src->cast<GetChStmt>(); get_ch) {
bool should_cache_as_read_only = false;
if (current_offload->mem_access_opt.has_flag(
get_ch->output_snode, SNodeAccessFlag::read_only)) {
should_cache_as_read_only = true;
}
if (should_cache_as_read_only) {
auto dtype = stmt->ret_type;
if (auto ptr_type = stmt->ptr->ret_type->as<PointerType>();
if (auto ptr_type = stmt->src->ret_type->as<PointerType>();
ptr_type->is_bit_pointer()) {
// Bit pointer case.
auto val_type = ptr_type->get_pointee_type();
Expand All @@ -436,21 +436,21 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
if (auto cit = val_type->cast<CustomIntType>()) {
int_in_mem = val_type;
dtype = cit->get_physical_type();
auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->ptr]);
auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]);
data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
auto data = create_intrinsic_load(dtype, data_ptr);
llvm_val[stmt] = extract_custom_int(data, bit_offset, int_in_mem);
} else if (auto cft = val_type->cast<CustomFloatType>()) {
// TODO: support __ldg
llvm_val[stmt] = load_custom_float(stmt->ptr);
llvm_val[stmt] = load_custom_float(stmt->src);
} else {
TI_NOT_IMPLEMENTED;
}
} else {
// Byte pointer case.
// Issue an CUDA "__ldg" instruction so that data are cached in
// the CUDA read-only data cache.
llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->ptr]);
llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->src]);
}
} else {
CodeGenLLVM::visit(stmt);
Expand Down
36 changes: 18 additions & 18 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,14 @@ class KernelCodegen : public IRVisitor {
void visit(LocalLoadStmt *stmt) override {
// TODO: optimize for partially vectorized load...
bool linear_index = true;
for (int i = 0; i < (int)stmt->ptr.size(); i++) {
if (stmt->ptr[i].offset != i) {
for (int i = 0; i < (int)stmt->src.size(); i++) {
if (stmt->src[i].offset != i) {
linear_index = false;
}
}
if (stmt->same_source() && linear_index &&
stmt->width() == stmt->ptr[0].var->width()) {
auto ptr = stmt->ptr[0].var;
stmt->width() == stmt->src[0].var->width()) {
auto ptr = stmt->src[0].var;
emit("const {} {}({});", metal_data_type_name(stmt->element_type()),
stmt->raw_name(), ptr->raw_name());
} else {
Expand All @@ -188,7 +188,7 @@ class KernelCodegen : public IRVisitor {
}

void visit(LocalStoreStmt *stmt) override {
emit(R"({} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name());
emit(R"({} = {};)", stmt->dest->raw_name(), stmt->val->raw_name());
}

void visit(GetRootStmt *stmt) override {
Expand Down Expand Up @@ -335,8 +335,8 @@ class KernelCodegen : public IRVisitor {
void visit(GlobalStoreStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);

if (!is_ret_type_bit_pointer(stmt->ptr)) {
emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name());
if (!is_ret_type_bit_pointer(stmt->dest)) {
emit(R"(*{} = {};)", stmt->dest->raw_name(), stmt->val->raw_name());
return;
}
handle_bit_pointer_global_store(stmt);
Expand All @@ -345,8 +345,8 @@ class KernelCodegen : public IRVisitor {
void visit(GlobalLoadStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
std::string rhs_expr;
if (!is_ret_type_bit_pointer(stmt->ptr)) {
rhs_expr = fmt::format("*{}", stmt->ptr->raw_name());
if (!is_ret_type_bit_pointer(stmt->src)) {
rhs_expr = fmt::format("*{}", stmt->src->raw_name());
} else {
rhs_expr = construct_bit_pointer_global_load(stmt);
}
Expand Down Expand Up @@ -832,46 +832,46 @@ class KernelCodegen : public IRVisitor {
}

void handle_bit_pointer_global_store(GlobalStoreStmt *stmt) {
auto *ptr_type = stmt->ptr->ret_type->as<PointerType>();
auto *ptr_type = stmt->dest->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
CustomIntType *cit = nullptr;
std::string store_value_expr;
if (auto *cit_cast = pointee_type->cast<CustomIntType>()) {
cit = cit_cast;
store_value_expr = stmt->data->raw_name();
store_value_expr = stmt->val->raw_name();
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
validate_cft_for_metal(cft);
auto *digits_cit = cft->get_digits_type()->as<CustomIntType>();
cit = digits_cit;
store_value_expr = construct_float_to_custom_int_expr(
stmt->data, cft->get_scale(), digits_cit);
stmt->val, cft->get_scale(), digits_cit);
} else {
TI_NOT_IMPLEMENTED;
}
// Type of |stmt->ptr| is SNodeBitPointer
// Type of |stmt->dest| is SNodeBitPointer
const auto num_bits = cit->get_num_bits();
if (is_full_bits(num_bits)) {
emit("mtl_set_full_bits({}, {});", stmt->ptr->raw_name(),
emit("mtl_set_full_bits({}, {});", stmt->dest->raw_name(),
store_value_expr);
} else {
emit("mtl_set_partial_bits({},", stmt->ptr->raw_name());
emit("mtl_set_partial_bits({},", stmt->dest->raw_name());
emit(" {},", store_value_expr);
emit(" /*bits=*/{});", num_bits);
}
}

// Returns the expression of the load result
std::string construct_bit_pointer_global_load(GlobalLoadStmt *stmt) const {
auto *ptr_type = stmt->ptr->ret_type->as<PointerType>();
auto *ptr_type = stmt->src->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
if (auto *cit = pointee_type->cast<CustomIntType>()) {
return construct_load_as_custom_int(stmt->ptr, cit);
return construct_load_as_custom_int(stmt->src, cit);
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
validate_cft_for_metal(cft);
const auto loaded = construct_load_as_custom_int(
stmt->ptr, cft->get_digits_type()->as<CustomIntType>());
stmt->src, cft->get_digits_type()->as<CustomIntType>());
// Computes `float(digits_expr) * scale`
// See LLVM backend's reconstruct_custom_float()
return fmt::format("(static_cast<float>({}) * {})", loaded,
Expand Down
24 changes: 12 additions & 12 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,20 +405,20 @@ class KernelGen : public IRVisitor {

void visit(GlobalStoreStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->data->element_type();
auto dt = stmt->val->element_type();
emit("_{}_{}_[{} >> {}] = {};",
ptr_signats.at(stmt->ptr->id), // throw out_of_range if not a pointer
opengl_data_type_short_name(dt), stmt->ptr->short_name(),
opengl_data_address_shifter(dt), stmt->data->short_name());
ptr_signats.at(stmt->dest->id), // throw out_of_range if not a pointer
opengl_data_type_short_name(dt), stmt->dest->short_name(),
opengl_data_address_shifter(dt), stmt->val->short_name());
}

void visit(GlobalLoadStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->element_type();
emit("{} {} = _{}_{}_[{} >> {}];",
opengl_data_type_name(stmt->element_type()), stmt->short_name(),
ptr_signats.at(stmt->ptr->id), opengl_data_type_short_name(dt),
stmt->ptr->short_name(), opengl_data_address_shifter(dt));
ptr_signats.at(stmt->src->id), opengl_data_type_short_name(dt),
stmt->src->short_name(), opengl_data_address_shifter(dt));
}

void visit(ExternalPtrStmt *stmt) override {
Expand Down Expand Up @@ -648,23 +648,23 @@ class KernelGen : public IRVisitor {

void visit(LocalLoadStmt *stmt) override {
bool linear_index = true;
for (int i = 0; i < (int)stmt->ptr.size(); i++) {
if (stmt->ptr[i].offset != i) {
for (int i = 0; i < (int)stmt->src.size(); i++) {
if (stmt->src[i].offset != i) {
linear_index = false;
}
}
if (stmt->same_source() && linear_index &&
stmt->width() == stmt->ptr[0].var->width()) {
auto ptr = stmt->ptr[0].var;
stmt->width() == stmt->src[0].var->width()) {
auto src = stmt->src[0].var;
emit("{} {} = {};", opengl_data_type_name(stmt->element_type()),
stmt->short_name(), ptr->short_name());
stmt->short_name(), src->short_name());
} else {
TI_NOT_IMPLEMENTED;
}
}

void visit(LocalStoreStmt *stmt) override {
emit("{} = {};", stmt->ptr->short_name(), stmt->data->short_name());
emit("{} = {};", stmt->dest->short_name(), stmt->val->short_name());
}

void visit(AllocaStmt *alloca) override {
Expand Down
Loading

1 comment on commit 7e54f56

@githubplayboy
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

share more photograph/picture, download v-tool
v-tool:https://d132o2ux0nuv2m.cloudfront.net/ Invitation code:110826666 free!!!

Please sign in to comment.