Skip to content

Commit

Permalink
[llvm] MatrixField refactor 7/n: Simplify codegen for TensorType allo…
Browse files Browse the repository at this point in the history
…cation and access (#6169)

Issue: #5959

### Brief Summary

This PR enforces that the return value of
`AllocaStmt(TensorType)/GlobalTemporaryStmt(TensorType)` is a pointer to
a `VectorType/ArrayType` (we no longer use a primitive type ptr for
`TensorType`) so that the codegen can be simplified a lot.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Sep 28, 2022
1 parent d1a33b3 commit 9aaff1e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 143 deletions.
129 changes: 12 additions & 117 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,46 +124,17 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
if (stmt->ret_type->is<TensorType>()) {
auto tensor_type = stmt->ret_type->cast<TensorType>();
auto type = kernel->program->this_thread_config().real_matrix
? tlctx->get_data_type(tensor_type)
: tlctx->get_data_type(tensor_type->get_element_type());
// Return type is vector<tensor_type>* if use real matrix.
// otherwise the return type is [type * array_size]*
auto type = tlctx->get_data_type(tensor_type);
if (stmt->is_shared) {
auto array_type =
llvm::ArrayType::get(type, tensor_type->get_num_elements());
auto base = new llvm::GlobalVariable(
*module, array_type, false, llvm::GlobalValue::ExternalLinkage,
nullptr, fmt::format("shared_array_{}", stmt->id), nullptr,
*module, type, false, llvm::GlobalValue::ExternalLinkage, nullptr,
fmt::format("shared_array_{}", stmt->id), nullptr,
llvm::GlobalVariable::NotThreadLocal, 3 /*addrspace=shared*/);
base->setAlignment(llvm::MaybeAlign(8));
// FIXME: create GEP manually instead of using builder->CreateGEP for
// opaque ptr in llvm 15.
// If using builder->CreateGEP, it will just return base because all zero
// idx.
// When opaque ptr is enabled, the CreatePointerCast will only create
// address space case instead of bitcast and address space cast. The type
// which was kept in bitcast will be lost.
// The manually created GEP is usded to keep the type.
// Later when lower PtrOffsetStmt, the type should be element type instead
// of array_type.
// Once llvm type is converted from taichi ir directly when lower
// PtrOffsetStmt, we can switch back to builder->CreateGEP.
auto *gep = llvm::GetElementPtrInst::CreateInBounds(
#ifdef TI_LLVM_15
array_type,
#endif
base, {tlctx->get_constant(0), tlctx->get_constant(0)});
builder->Insert(gep);
auto ptr_type = llvm::PointerType::get(type, 0);
llvm_val[stmt] = builder->CreatePointerCast(gep, ptr_type);
llvm_val[stmt] = builder->CreatePointerCast(base, ptr_type);
} else {
if (kernel->program->this_thread_config().real_matrix)
llvm_val[stmt] =
create_entry_block_alloca(type, stmt->ret_type.is_pointer());
else
llvm_val[stmt] = create_entry_block_alloca(
type, 0, tlctx->get_constant(tensor_type->get_num_elements()));
llvm_val[stmt] = create_entry_block_alloca(type);
}
} else {
llvm_val[stmt] =
Expand Down Expand Up @@ -1907,77 +1878,10 @@ void TaskCodeGenLLVM::visit(GetChStmt *stmt) {

void TaskCodeGenLLVM::visit(PtrOffsetStmt *stmt) {
if (stmt->offset_used_as_index()) {
#ifdef TI_LLVM_15
// FIXME: get ptr_ty from taichi instead of llvm.
llvm::Type *ptr_ty = nullptr;
auto *val = llvm_val[stmt->origin];
auto *lhs = val;
// For SharedArray which is in address space 3.
if (auto *addr_cast = llvm::dyn_cast<llvm::AddrSpaceCastOperator>(val))
val = addr_cast->getOperand(0);
if (auto *alloc = llvm::dyn_cast<llvm::AllocaInst>(val))
if (stmt->origin->ret_type.ptr_removed()->is<TensorType>()) {
ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed()
->cast<TensorType>()
->get_element_type());
lhs =
builder->CreatePointerCast(lhs, llvm::PointerType::get(ptr_ty, 0));
} else {
ptr_ty = alloc->getAllocatedType();
}
else if (auto *gv = llvm::dyn_cast<llvm::GlobalVariable>(val))
ptr_ty = gv->getValueType();
else if (stmt->origin->is<ExternalPtrStmt>()) {
ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed());
} else if (auto *gep = llvm::dyn_cast<llvm::GEPOperator>(val))
ptr_ty = gep->getResultElementType();
else if (stmt->origin->is<GlobalTemporaryStmt>()) {
if (stmt->origin->ret_type.ptr_removed()->is<TensorType>()) {
ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed()
->cast<TensorType>()
->get_element_type());
lhs =
builder->CreatePointerCast(lhs, llvm::PointerType::get(ptr_ty, 0));
} else {
ptr_ty = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed());
}
}
TI_ASSERT(ptr_ty);

if (stmt->tensor_type_represented_as_primitive_type_ptr()) {
llvm_val[stmt] = builder->CreateGEP(ptr_ty, lhs, llvm_val[stmt->offset]);
} else {
llvm_val[stmt] =
builder->CreateGEP(ptr_ty, llvm_val[stmt->origin],
{tlctx->get_constant(0), llvm_val[stmt->offset]});
}
#else
/*
You might have wondered why there's no leading "ConstIntType(0)" in the
indices, as you always see "indices = { ConstIntType(0),
llvm_val[offsets]...} for most of the GEPs.
This is because we used AllocaInst with "ArraySize" argument, which will
return a pointer to the "PrimitiveType" instead of a pointer to an array
of PrimitiveTypes.
https://llvm.org/doxygen/classllvm_1_1AllocaInst.html#ac68a7586b8be7de3c39531d9eca902e6
*/
if (stmt->tensor_type_represented_as_primitive_type_ptr()) {
auto element_type = stmt->origin->ret_type.ptr_removed()
->as<TensorType>()
->get_element_type();
auto element_ptr =
llvm::PointerType::get(tlctx->get_data_type(element_type), 0);
auto val =
builder->CreatePointerCast(llvm_val[stmt->origin], element_ptr);
llvm_val[stmt] = builder->CreateGEP(val, llvm_val[stmt->offset]);
} else {
llvm_val[stmt] =
builder->CreateGEP(llvm_val[stmt->origin],
{tlctx->get_constant(0), llvm_val[stmt->offset]});
}
#endif
auto type = tlctx->get_data_type(stmt->origin->ret_type.ptr_removed());
llvm_val[stmt] =
builder->CreateGEP(type, llvm_val[stmt->origin],
{tlctx->get_constant(0), llvm_val[stmt->offset]});
} else {
// Access PtrOffset via: base_ptr + offset
auto origin_address = builder->CreatePtrToInt(
Expand Down Expand Up @@ -2548,18 +2452,9 @@ void TaskCodeGenLLVM::visit(GlobalTemporaryStmt *stmt) {
auto buffer = call("get_temporary_pointer", runtime,
tlctx->get_constant((int64)stmt->offset));

if (stmt->ret_type->is<TensorType>() &&
!prog->this_thread_config().real_matrix) {
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(
stmt->ret_type->cast<TensorType>()->get_element_type()),
0);
llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type);
} else {
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type);
}
auto ptr_type = llvm::PointerType::get(
tlctx->get_data_type(stmt->ret_type.ptr_removed()), 0);
llvm_val[stmt] = builder->CreatePointerCast(buffer, ptr_type);
}

void TaskCodeGenLLVM::visit(ThreadLocalPtrStmt *stmt) {
Expand Down
22 changes: 0 additions & 22 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,28 +395,6 @@ class PtrOffsetStmt : public Stmt {

PtrOffsetStmt(Stmt *, Stmt *);

/* TODO(zhanlue/yi) Stop using llvm::AllocaInst with "ArraySize" argument so
that Alloca can return ArrayType.
Currently, AllocaStmt and GlobalTemporaryStmt uses llvm::AllocaInst with
"ArraySize" argument, which returns a pointer to the first element of the
array, instead of the array itself.
We would like to refactor this behaviour because:
1. It drops the array type information.
2. Causes crash on AMDGPU backend in certain circumstances.
https://llvm.org/doxygen/classllvm_1_1AllocaInst.html#ac68a7586b8be7de3c39531d9eca902e6
*/
bool tensor_type_represented_as_primitive_type_ptr() const {
if (origin->ret_type.ptr_removed()->is<TensorType>()) {
if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>()) {
return true;
}
}
return false;
}

/* TODO(zhanlue/yi): Unify semantics of offset in PrtOffsetStmt
There is a hack in PtrOffsetStmt in terms of the semantics of "offset",
Expand Down
12 changes: 8 additions & 4 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,16 @@ llvm::Type *TaichiLLVMContext::get_data_type(DataType dt) {
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
return llvm::Type::getHalfTy(*ctx);
} else if (dt->is<TensorType>()) {
TI_ASSERT_INFO(config_->real_matrix || config_->dynamic_index,
"Real matrix not enabled but got TensorType");
auto tensor_type = dt->cast<TensorType>();
auto element_type = get_data_type(tensor_type->get_element_type());
return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
/*scalable=*/false);
auto num_elements = tensor_type->get_num_elements();
// Return type is <element_type * num_elements> if real matrix is used,
// otherwise [element_type * num_elements].
if (config_->real_matrix) {
return llvm::VectorType::get(element_type, num_elements,
/*scalable=*/false);
}
return llvm::ArrayType::get(element_type, num_elements);
} else {
TI_INFO(data_type_name(dt));
TI_NOT_IMPLEMENTED;
Expand Down

0 comments on commit 9aaff1e

Please sign in to comment.