Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] [refactor] Unify field names in load/store/atomic statements #2250

Merged
merged 8 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
// If load_stmt loads some variables or a stack, return the pointers of them.
if (auto local_load = load_stmt->cast<LocalLoadStmt>()) {
std::vector<Stmt *> result;
for (auto &address : local_load->ptr.data) {
for (auto &address : local_load->src.data) {
if (std::find(result.begin(), result.end(), address.var) == result.end())
result.push_back(address.var);
}
return result;
} else if (auto global_load = load_stmt->cast<GlobalLoadStmt>()) {
return std::vector<Stmt *>(1, global_load->ptr);
return std::vector<Stmt *>(1, global_load->src);
} else if (auto atomic = load_stmt->cast<AtomicOpStmt>()) {
return std::vector<Stmt *>(1, atomic->dest);
} else if (auto stack_load_top = load_stmt->cast<StackLoadTopStmt>()) {
Expand Down Expand Up @@ -46,9 +46,9 @@ Stmt *get_store_data(Stmt *store_stmt) {
// stores.
return store_stmt;
} else if (auto local_store = store_stmt->cast<LocalStoreStmt>()) {
return local_store->data;
return local_store->val;
} else if (auto global_store = store_stmt->cast<GlobalStoreStmt>()) {
return global_store->data;
return global_store->val;
} else {
return nullptr;
}
Expand All @@ -60,9 +60,9 @@ std::vector<Stmt *> get_store_destination(Stmt *store_stmt) {
// The statement itself provides a data source (const [0]).
return std::vector<Stmt *>(1, store_stmt);
} else if (auto local_store = store_stmt->cast<LocalStoreStmt>()) {
return std::vector<Stmt *>(1, local_store->ptr);
return std::vector<Stmt *>(1, local_store->dest);
} else if (auto global_store = store_stmt->cast<GlobalStoreStmt>()) {
return std::vector<Stmt *>(1, global_store->ptr);
return std::vector<Stmt *>(1, global_store->dest);
} else if (auto atomic = store_stmt->cast<AtomicOpStmt>()) {
return std::vector<Stmt *>(1, atomic->dest);
} else if (auto external_func = store_stmt->cast<ExternalFuncCallStmt>()) {
Expand Down
6 changes: 3 additions & 3 deletions taichi/analysis/gather_snode_read_writes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@ gather_snode_read_writes(IRNode *root) {
bool read = false, write = false;
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
read = true;
ptr = global_load->ptr;
ptr = global_load->src;
} else if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
write = true;
ptr = global_store->ptr;
ptr = global_store->dest;
} else if (auto global_atomic = stmt->cast<AtomicOpStmt>()) {
read = true;
write = true;
ptr = global_atomic->dest;
}
if (ptr) {
if (GlobalPtrStmt *global_ptr = ptr->cast<GlobalPtrStmt>()) {
if (auto *global_ptr = ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : global_ptr->snodes.data) {
if (read)
accessed.first.emplace(snode);
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/has_store_or_atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class LocalStoreSearcher : public BasicStmtVisitor {

void visit(LocalStoreStmt *stmt) override {
for (auto var : vars) {
if (stmt->ptr == var) {
if (stmt->dest == var) {
result = true;
break;
}
Expand Down
6 changes: 3 additions & 3 deletions taichi/analysis/last_store_or_atomic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class LocalStoreForwarder : public BasicStmtVisitor {
}

void visit(LocalStoreStmt *stmt) override {
if (stmt->ptr == var) {
if (stmt->dest == var) {
is_valid = true;
result = stmt;
}
Expand Down Expand Up @@ -70,8 +70,8 @@ class LocalStoreForwarder : public BasicStmtVisitor {
} else {
TI_ASSERT(true_stmt->is<LocalStoreStmt>());
TI_ASSERT(false_stmt->is<LocalStoreStmt>());
if (true_stmt->as<LocalStoreStmt>()->data !=
false_stmt->as<LocalStoreStmt>()->data) {
if (true_stmt->as<LocalStoreStmt>()->val !=
false_stmt->as<LocalStoreStmt>()->val) {
// two branches finally store the variable differently
is_valid = false;
} else {
Expand Down
2 changes: 1 addition & 1 deletion taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class IRNodeComparator : public IRVisitor {
} else {
bool same_value = false;
if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (auto global_ptr = global_load->ptr->cast<GlobalPtrStmt>()) {
if (auto global_ptr = global_load->src->cast<GlobalPtrStmt>()) {
TI_ASSERT(global_ptr->width() == 1);
if (possibly_modified_states_.count(ir_bank_->get_async_state(
global_ptr->snodes[0], AsyncState::Type::value)) == 0) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/analysis/verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ class IRVerifier : public BasicStmtVisitor {
void visit(LocalLoadStmt *stmt) override {
basic_verify(stmt);
for (int i = 0; i < stmt->width(); i++) {
TI_ASSERT(stmt->ptr[i].var->is<AllocaStmt>());
TI_ASSERT(stmt->src[i].var->is<AllocaStmt>());
}
}

void visit(LocalStoreStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->ptr->is<AllocaStmt>());
TI_ASSERT(stmt->dest->is<AllocaStmt>());
}

void visit(LoopIndexStmt *stmt) override {
Expand Down
10 changes: 5 additions & 5 deletions taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,15 +417,15 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
}

void visit(GlobalLoadStmt *stmt) override {
if (auto get_ch = stmt->ptr->cast<GetChStmt>(); get_ch) {
if (auto get_ch = stmt->src->cast<GetChStmt>(); get_ch) {
bool should_cache_as_read_only = false;
if (current_offload->mem_access_opt.has_flag(
get_ch->output_snode, SNodeAccessFlag::read_only)) {
should_cache_as_read_only = true;
}
if (should_cache_as_read_only) {
auto dtype = stmt->ret_type;
if (auto ptr_type = stmt->ptr->ret_type->as<PointerType>();
if (auto ptr_type = stmt->src->ret_type->as<PointerType>();
ptr_type->is_bit_pointer()) {
// Bit pointer case.
auto val_type = ptr_type->get_pointee_type();
Expand All @@ -436,21 +436,21 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
if (auto cit = val_type->cast<CustomIntType>()) {
int_in_mem = val_type;
dtype = cit->get_physical_type();
auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->ptr]);
auto [data_ptr, bit_offset] = load_bit_pointer(llvm_val[stmt->src]);
data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
auto data = create_intrinsic_load(dtype, data_ptr);
llvm_val[stmt] = extract_custom_int(data, bit_offset, int_in_mem);
} else if (auto cft = val_type->cast<CustomFloatType>()) {
// TODO: support __ldg
llvm_val[stmt] = load_custom_float(stmt->ptr);
llvm_val[stmt] = load_custom_float(stmt->src);
} else {
TI_NOT_IMPLEMENTED;
}
} else {
// Byte pointer case.
// Issue an CUDA "__ldg" instruction so that data are cached in
// the CUDA read-only data cache.
llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->ptr]);
llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->src]);
}
} else {
CodeGenLLVM::visit(stmt);
Expand Down
36 changes: 18 additions & 18 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,14 @@ class KernelCodegen : public IRVisitor {
void visit(LocalLoadStmt *stmt) override {
// TODO: optimize for partially vectorized load...
bool linear_index = true;
for (int i = 0; i < (int)stmt->ptr.size(); i++) {
if (stmt->ptr[i].offset != i) {
for (int i = 0; i < (int)stmt->src.size(); i++) {
if (stmt->src[i].offset != i) {
linear_index = false;
}
}
if (stmt->same_source() && linear_index &&
stmt->width() == stmt->ptr[0].var->width()) {
auto ptr = stmt->ptr[0].var;
stmt->width() == stmt->src[0].var->width()) {
auto ptr = stmt->src[0].var;
emit("const {} {}({});", metal_data_type_name(stmt->element_type()),
stmt->raw_name(), ptr->raw_name());
} else {
Expand All @@ -188,7 +188,7 @@ class KernelCodegen : public IRVisitor {
}

void visit(LocalStoreStmt *stmt) override {
emit(R"({} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name());
emit(R"({} = {};)", stmt->dest->raw_name(), stmt->val->raw_name());
}

void visit(GetRootStmt *stmt) override {
Expand Down Expand Up @@ -335,8 +335,8 @@ class KernelCodegen : public IRVisitor {
void visit(GlobalStoreStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);

if (!is_ret_type_bit_pointer(stmt->ptr)) {
emit(R"(*{} = {};)", stmt->ptr->raw_name(), stmt->data->raw_name());
if (!is_ret_type_bit_pointer(stmt->dest)) {
emit(R"(*{} = {};)", stmt->dest->raw_name(), stmt->val->raw_name());
return;
}
handle_bit_pointer_global_store(stmt);
Expand All @@ -345,8 +345,8 @@ class KernelCodegen : public IRVisitor {
void visit(GlobalLoadStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
std::string rhs_expr;
if (!is_ret_type_bit_pointer(stmt->ptr)) {
rhs_expr = fmt::format("*{}", stmt->ptr->raw_name());
if (!is_ret_type_bit_pointer(stmt->src)) {
rhs_expr = fmt::format("*{}", stmt->src->raw_name());
} else {
rhs_expr = construct_bit_pointer_global_load(stmt);
}
Expand Down Expand Up @@ -832,46 +832,46 @@ class KernelCodegen : public IRVisitor {
}

void handle_bit_pointer_global_store(GlobalStoreStmt *stmt) {
auto *ptr_type = stmt->ptr->ret_type->as<PointerType>();
auto *ptr_type = stmt->dest->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
CustomIntType *cit = nullptr;
std::string store_value_expr;
if (auto *cit_cast = pointee_type->cast<CustomIntType>()) {
cit = cit_cast;
store_value_expr = stmt->data->raw_name();
store_value_expr = stmt->val->raw_name();
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
validate_cft_for_metal(cft);
auto *digits_cit = cft->get_digits_type()->as<CustomIntType>();
cit = digits_cit;
store_value_expr = construct_float_to_custom_int_expr(
stmt->data, cft->get_scale(), digits_cit);
stmt->val, cft->get_scale(), digits_cit);
} else {
TI_NOT_IMPLEMENTED;
}
// Type of |stmt->ptr| is SNodeBitPointer
// Type of |stmt->dest| is SNodeBitPointer
const auto num_bits = cit->get_num_bits();
if (is_full_bits(num_bits)) {
emit("mtl_set_full_bits({}, {});", stmt->ptr->raw_name(),
emit("mtl_set_full_bits({}, {});", stmt->dest->raw_name(),
store_value_expr);
} else {
emit("mtl_set_partial_bits({},", stmt->ptr->raw_name());
emit("mtl_set_partial_bits({},", stmt->dest->raw_name());
emit(" {},", store_value_expr);
emit(" /*bits=*/{});", num_bits);
}
}

// Returns the expression of the load result
std::string construct_bit_pointer_global_load(GlobalLoadStmt *stmt) const {
auto *ptr_type = stmt->ptr->ret_type->as<PointerType>();
auto *ptr_type = stmt->src->ret_type->as<PointerType>();
TI_ASSERT(ptr_type->is_bit_pointer());
auto *pointee_type = ptr_type->get_pointee_type();
if (auto *cit = pointee_type->cast<CustomIntType>()) {
return construct_load_as_custom_int(stmt->ptr, cit);
return construct_load_as_custom_int(stmt->src, cit);
} else if (auto *cft = pointee_type->cast<CustomFloatType>()) {
validate_cft_for_metal(cft);
const auto loaded = construct_load_as_custom_int(
stmt->ptr, cft->get_digits_type()->as<CustomIntType>());
stmt->src, cft->get_digits_type()->as<CustomIntType>());
// Computes `float(digits_expr) * scale`
// See LLVM backend's reconstruct_custom_float()
return fmt::format("(static_cast<float>({}) * {})", loaded,
Expand Down
24 changes: 12 additions & 12 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,20 +405,20 @@ class KernelGen : public IRVisitor {

void visit(GlobalStoreStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->data->element_type();
auto dt = stmt->val->element_type();
emit("_{}_{}_[{} >> {}] = {};",
ptr_signats.at(stmt->ptr->id), // throw out_of_range if not a pointer
opengl_data_type_short_name(dt), stmt->ptr->short_name(),
opengl_data_address_shifter(dt), stmt->data->short_name());
ptr_signats.at(stmt->dest->id), // throw out_of_range if not a pointer
opengl_data_type_short_name(dt), stmt->dest->short_name(),
opengl_data_address_shifter(dt), stmt->val->short_name());
}

void visit(GlobalLoadStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->element_type();
emit("{} {} = _{}_{}_[{} >> {}];",
opengl_data_type_name(stmt->element_type()), stmt->short_name(),
ptr_signats.at(stmt->ptr->id), opengl_data_type_short_name(dt),
stmt->ptr->short_name(), opengl_data_address_shifter(dt));
ptr_signats.at(stmt->src->id), opengl_data_type_short_name(dt),
stmt->src->short_name(), opengl_data_address_shifter(dt));
}

void visit(ExternalPtrStmt *stmt) override {
Expand Down Expand Up @@ -648,23 +648,23 @@ class KernelGen : public IRVisitor {

void visit(LocalLoadStmt *stmt) override {
bool linear_index = true;
for (int i = 0; i < (int)stmt->ptr.size(); i++) {
if (stmt->ptr[i].offset != i) {
for (int i = 0; i < (int)stmt->src.size(); i++) {
if (stmt->src[i].offset != i) {
linear_index = false;
}
}
if (stmt->same_source() && linear_index &&
stmt->width() == stmt->ptr[0].var->width()) {
auto ptr = stmt->ptr[0].var;
stmt->width() == stmt->src[0].var->width()) {
auto src = stmt->src[0].var;
emit("{} {} = {};", opengl_data_type_name(stmt->element_type()),
stmt->short_name(), ptr->short_name());
stmt->short_name(), src->short_name());
} else {
TI_NOT_IMPLEMENTED;
}
}

void visit(LocalStoreStmt *stmt) override {
emit("{} = {};", stmt->ptr->short_name(), stmt->data->short_name());
emit("{} = {};", stmt->dest->short_name(), stmt->val->short_name());
}

void visit(AllocaStmt *alloca) override {
Expand Down
Loading