Skip to content

Commit

Permalink
[ir] Update the codegen for the refactor of Alloca
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
lin-hitonami committed Jun 2, 2023
1 parent 95b24bb commit 707c10c
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 15 deletions.
4 changes: 2 additions & 2 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {

void visit(AllocaStmt *stmt) override {
// Override shared memory codegen logic for large shared memory
if (stmt->ret_type->is<TensorType>() && stmt->is_shared) {
auto tensor_type = stmt->ret_type->cast<TensorType>();
auto tensor_type = stmt->ret_type.ptr_removed()->cast<TensorType>();
if (tensor_type && stmt->is_shared) {
size_t shared_array_bytes =
tensor_type->get_num_elements() *
data_type_size(tensor_type->get_element_type());
Expand Down
13 changes: 6 additions & 7 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
}

void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
if (stmt->ret_type->is<TensorType>()) {
auto tensor_type = stmt->ret_type->cast<TensorType>();
auto alloca_type = stmt->ret_type.ptr_removed();
if (alloca_type->is<TensorType>()) {
auto tensor_type = alloca_type->cast<TensorType>();
auto type = tlctx->get_data_type(tensor_type);
if (stmt->is_shared) {
auto base = new llvm::GlobalVariable(
Expand All @@ -141,12 +142,10 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
llvm_val[stmt] = create_entry_block_alloca(type);
}
} else {
llvm_val[stmt] =
create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer());
llvm_val[stmt] = create_entry_block_alloca(alloca_type);
// initialize as zero if element is not a pointer
if (!stmt->ret_type.is_pointer())
builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0),
llvm_val[stmt]);
if (!alloca_type->is<PointerType>())
builder->CreateStore(tlctx->get_constant(alloca_type, 0), llvm_val[stmt]);
}
}

Expand Down
4 changes: 1 addition & 3 deletions taichi/codegen/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,8 @@ class LLVMModuleBuilder {
return alloca;
}

llvm::Value *create_entry_block_alloca(DataType dt, bool is_pointer = false) {
llvm::Value *create_entry_block_alloca(DataType dt) {
auto type = tlctx->get_data_type(dt);
if (is_pointer)
type = llvm::PointerType::get(type, 0);
return create_entry_block_alloca(type);
}

Expand Down
6 changes: 3 additions & 3 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ class TaskCodegen : public IRVisitor {

void visit(AllocaStmt *alloca) override {
spirv::Value ptr_val;
if (alloca->ret_type->is<TensorType>()) {
auto tensor_type = alloca->ret_type->cast<TensorType>();
auto alloca_type = alloca->ret_type.ptr_removed();
if (auto tensor_type = alloca_type->cast<TensorType>()) {
auto elem_num = tensor_type->get_num_elements();
spirv::SType elem_type =
ir_->get_primitive_type(tensor_type->get_element_type());
Expand All @@ -288,7 +288,7 @@ class TaskCodegen : public IRVisitor {
}
} else {
// Alloca for a single variable
spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
spirv::SType src_type = ir_->get_primitive_type(alloca_type);
ptr_val = ir_->alloca_variable(src_type);
ir_->store_variable(ptr_val, ir_->get_zero(src_type));
}
Expand Down

0 comments on commit 707c10c

Please sign in to comment.