[ir] Update the codegen for the refactor of Alloca #8124

Merged · 1 commit · Jun 5, 2023

taichi/codegen/cuda/codegen_cuda.cpp (2 additions, 2 deletions)

```diff
@@ -172,8 +172,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
 
   void visit(AllocaStmt *stmt) override {
     // Override shared memory codegen logic for large shared memory
-    if (stmt->ret_type->is<TensorType>() && stmt->is_shared) {
-      auto tensor_type = stmt->ret_type->cast<TensorType>();
+    auto tensor_type = stmt->ret_type.ptr_removed()->cast<TensorType>();
+    if (tensor_type && stmt->is_shared) {
       size_t shared_array_bytes =
           tensor_type->get_num_elements() *
           data_type_size(tensor_type->get_element_type());
```
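
All four backends apply the same fix: after the Alloca refactor, an `AllocaStmt`'s `ret_type` carries the pointer type of the allocated value, so codegen must strip one level of pointer with `ptr_removed()` before testing whether a tensor is being allocated. A minimal sketch of that unwrap, using hypothetical stand-ins for Taichi's type classes (the real hierarchy lives in `taichi/ir/type.h`):

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical stand-ins for Taichi's type hierarchy, for illustration only.
struct Type {
  virtual ~Type() = default;
  template <typename T> T *cast() { return dynamic_cast<T *>(this); }
  template <typename T> bool is() { return cast<T>() != nullptr; }
};

struct TensorType : Type {
  explicit TensorType(int n) : num_elements(n) {}
  int num_elements;
};

struct PointerType : Type {
  explicit PointerType(Type *p) : pointee(p) {}
  Type *pointee;
};

// Mimics DataType::ptr_removed(): strip one level of pointer, if present.
Type *ptr_removed(Type *t) {
  auto *p = t->cast<PointerType>();
  return p ? p->pointee : t;
}

int main() {
  TensorType tensor(16);
  PointerType ret_type(&tensor);  // post-refactor: ret_type points at the value

  // The old check fails now: the pointer itself is not a TensorType.
  assert(!ret_type.is<TensorType>());

  // The new check unwraps first, then casts the pointee.
  auto *tensor_type = ptr_removed(&ret_type)->cast<TensorType>();
  assert(tensor_type && tensor_type->num_elements == 16);
  std::printf("tensor alloca, %d elements\n", tensor_type->num_elements);
  return 0;
}
```

The same unwrap-then-cast shape appears in the LLVM and SPIR-V visitors below.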
taichi/codegen/llvm/codegen_llvm.cpp (6 additions, 7 deletions)

```diff
@@ -126,8 +126,9 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
 }
 
 void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
-  if (stmt->ret_type->is<TensorType>()) {
-    auto tensor_type = stmt->ret_type->cast<TensorType>();
+  auto alloca_type = stmt->ret_type.ptr_removed();
+  if (alloca_type->is<TensorType>()) {
+    auto tensor_type = alloca_type->cast<TensorType>();
     auto type = tlctx->get_data_type(tensor_type);
     if (stmt->is_shared) {
       auto base = new llvm::GlobalVariable(
@@ -141,12 +142,10 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
       llvm_val[stmt] = create_entry_block_alloca(type);
     }
   } else {
-    llvm_val[stmt] =
-        create_entry_block_alloca(stmt->ret_type, stmt->ret_type.is_pointer());
+    llvm_val[stmt] = create_entry_block_alloca(alloca_type);
     // initialize as zero if element is not a pointer
-    if (!stmt->ret_type.is_pointer())
-      builder->CreateStore(tlctx->get_constant(stmt->ret_type, 0),
-                           llvm_val[stmt]);
+    if (!alloca_type->is<PointerType>())
+      builder->CreateStore(tlctx->get_constant(alloca_type, 0), llvm_val[stmt]);
   }
 }
 
```
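
In the scalar branch the alloca is now created for the unwrapped type, and the zero-initialization guard tests whether the pointee itself is a pointer (`alloca_type->is<PointerType>()`) instead of re-testing `ret_type`, which after the refactor is always a pointer for an alloca. A rough sketch of the IR this path emits for a plain `i32` local, written against the LLVM C++ API directly rather than Taichi's `tlctx`/`create_entry_block_alloca` wrappers:

```cpp
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

// Emit the equivalent of a zero-initialized scalar alloca:
//   %local = alloca i32
//   store i32 0, i32* %local
int main() {
  llvm::LLVMContext ctx;
  llvm::Module module("alloca_demo", ctx);

  auto *fn_ty = llvm::FunctionType::get(llvm::Type::getVoidTy(ctx),
                                        /*isVarArg=*/false);
  auto *fn = llvm::Function::Create(fn_ty, llvm::Function::ExternalLinkage,
                                    "kernel", &module);
  llvm::IRBuilder<> builder(llvm::BasicBlock::Create(ctx, "entry", fn));

  auto *i32 = llvm::Type::getInt32Ty(ctx);
  auto *slot = builder.CreateAlloca(i32, /*ArraySize=*/nullptr, "local");
  // "initialize as zero if element is not a pointer"
  builder.CreateStore(llvm::ConstantInt::get(i32, 0), slot);
  builder.CreateRetVoid();

  llvm::verifyModule(module, &llvm::errs());
  module.print(llvm::outs(), nullptr);
  return 0;
}
```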
taichi/codegen/llvm/llvm_codegen_utils.h (1 addition, 3 deletions)

```diff
@@ -81,10 +81,8 @@ class LLVMModuleBuilder {
     return alloca;
   }
 
-  llvm::Value *create_entry_block_alloca(DataType dt, bool is_pointer = false) {
+  llvm::Value *create_entry_block_alloca(DataType dt) {
     auto type = tlctx->get_data_type(dt);
-    if (is_pointer)
-      type = llvm::PointerType::get(type, 0);
     return create_entry_block_alloca(type);
   }
 
```
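
Since callers now strip the pointer with `ptr_removed()` before calling, no call site asks the helper to re-wrap the mapped type, so the `is_pointer` flag is dead and the `DataType` overload collapses into a pure delegation. For context, the `llvm::Type *` overload it delegates to presumably follows the standard entry-block idiom, sketched here as a compilable fragment from the common pattern, not copied from Taichi:

```cpp
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"

// Assumed shape of the overload taking a raw llvm::Type: stack slots are
// emitted at the top of the function's entry block, wherever the code
// generator currently sits, so LLVM's mem2reg pass can later promote them
// to SSA registers.
llvm::AllocaInst *create_entry_block_alloca(llvm::Function *func,
                                            llvm::Type *type) {
  llvm::BasicBlock &entry = func->getEntryBlock();
  // A temporary builder pinned to the start of the entry block, so the
  // current insertion point of the main builder is left untouched.
  llvm::IRBuilder<> tmp(&entry, entry.begin());
  return tmp.CreateAlloca(type);
}
```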
taichi/codegen/spirv/spirv_codegen.cpp (3 additions, 3 deletions)

```diff
@@ -274,8 +274,8 @@ class TaskCodegen : public IRVisitor {
 
   void visit(AllocaStmt *alloca) override {
     spirv::Value ptr_val;
-    if (alloca->ret_type->is<TensorType>()) {
-      auto tensor_type = alloca->ret_type->cast<TensorType>();
+    auto alloca_type = alloca->ret_type.ptr_removed();
+    if (auto tensor_type = alloca_type->cast<TensorType>()) {
       auto elem_num = tensor_type->get_num_elements();
       spirv::SType elem_type =
           ir_->get_primitive_type(tensor_type->get_element_type());
@@ -288,7 +288,7 @@ class TaskCodegen : public IRVisitor {
       }
     } else {
       // Alloca for a single variable
-      spirv::SType src_type = ir_->get_primitive_type(alloca->element_type());
+      spirv::SType src_type = ir_->get_primitive_type(alloca_type);
       ptr_val = ir_->alloca_variable(src_type);
       ir_->store_variable(ptr_val, ir_->get_zero(src_type));
     }
```
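
The SPIR-V visitor additionally folds the separate `is<TensorType>()` test and `cast<TensorType>()` call into one declaration-in-condition, and the scalar path derives its type from `alloca_type` rather than `alloca->element_type()`. A small sketch of the declaration-in-condition shape, again with hypothetical stand-in types:

```cpp
#include <cstdio>

// Stand-ins as in the earlier sketch; cast<T>() returns nullptr on mismatch.
struct Type {
  virtual ~Type() = default;
  template <typename T> T *cast() { return dynamic_cast<T *>(this); }
};
struct TensorType : Type { int elems = 4; };
struct PrimitiveType : Type {};

void emit_alloca(Type *alloca_type) {
  // One declaration in the condition replaces the is<TensorType>() check
  // followed by a second cast<TensorType>() call; tensor_type is scoped to
  // the true branch and cannot be misused in the else arm.
  if (auto *tensor_type = alloca_type->cast<TensorType>()) {
    std::printf("tensor alloca, %d elements\n", tensor_type->elems);
  } else {
    std::printf("scalar alloca\n");
  }
}

int main() {
  TensorType t;
  PrimitiveType p;
  emit_alloca(&t);  // tensor path
  emit_alloca(&p);  // scalar path
  return 0;
}
```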