[type] [cuda] Support bit-level pointer on cuda backend #2065

Merged (14 commits) on Nov 30, 2020
48 changes: 38 additions & 10 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -321,6 +321,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
fmt::format("cuda_rand_{}", data_type_short_name(stmt->ret_type)),
{get_context()});
}

void visit(RangeForStmt *for_stmt) override {
create_naive_range_for(for_stmt);
}
@@ -385,6 +386,21 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
return true; // on CUDA, pass the argument by value
}

llvm::Value *create_intrinsic_load(const DataType &dtype,
llvm::Value *data_ptr) {
auto llvm_dtype = llvm_type(dtype);
auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0);
llvm::Intrinsic::ID intrin;
if (is_real(dtype)) {
intrin = llvm::Intrinsic::nvvm_ldg_global_f;
} else {
intrin = llvm::Intrinsic::nvvm_ldg_global_i;
}
return builder->CreateIntrinsic(
intrin, {llvm_dtype, llvm_dtype_ptr},
{data_ptr, tlctx->get_constant(data_type_size(dtype))});
}
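Note: for readers unfamiliar with these NVVM intrinsics, nvvm_ldg_global_f/i emit the same read-only-cache load (PTX "ld.global.nc") that CUDA's __ldg() produces; it is only valid while the loaded memory is not written during the kernel. A minimal CUDA-level sketch of the equivalent source construct (illustrative only, not part of this PR):

// Illustrative CUDA sketch (assumption: not Taichi code). The intrinsic
// load built above behaves like __ldg(): a load routed through the
// read-only data cache, safe because the kernel never writes `in`.
__global__ void scale_by_two(const float *__restrict__ in, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = __ldg(&in[i]) * 2.0f;  // cached, read-only load of in[i]
}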

void visit(GlobalLoadStmt *stmt) override {
if (auto get_ch = stmt->ptr->cast<GetChStmt>(); get_ch) {
bool should_cache_as_read_only = false;
@@ -398,18 +414,30 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
// Issue a CUDA "__ldg" instruction so that data are cached in
// the CUDA read-only data cache.
auto dtype = stmt->ret_type;
- auto llvm_dtype = llvm_type(dtype);
- auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0);
- llvm::Intrinsic::ID intrin;
- if (is_real(dtype)) {
- intrin = llvm::Intrinsic::nvvm_ldg_global_f;
+ if (auto ptr_type = stmt->ptr->ret_type->as<PointerType>();
+ ptr_type->is_bit_pointer()) {
+ auto val_type = ptr_type->get_pointee_type();
+ llvm::Value *data_ptr = nullptr;
+ llvm::Value *bit_offset = nullptr;
+ if (auto cit = val_type->cast<CustomIntType>()) {
+ dtype = cit->get_physical_type();
+ } else if (auto cft = val_type->cast<CustomFloatType>()) {
+ dtype = cft->get_compute_type()
+ ->as<CustomIntType>()
+ ->get_physical_type();
+ } else {
+ TI_NOT_IMPLEMENTED;
+ }
+ read_bit_pointer(llvm_val[stmt->ptr], data_ptr, bit_offset);
+ data_ptr = builder->CreateBitCast(data_ptr, llvm_ptr_type(dtype));
+ auto data = create_intrinsic_load(dtype, data_ptr);
+ llvm_val[stmt] = extract_custom_int(data, bit_offset, val_type);
+ if (val_type->is<CustomFloatType>()) {
+ llvm_val[stmt] = reconstruct_custom_float(llvm_val[stmt], val_type);
+ }
} else {
- intrin = llvm::Intrinsic::nvvm_ldg_global_i;
+ llvm_val[stmt] = create_intrinsic_load(dtype, llvm_val[stmt->ptr]);
}

- llvm_val[stmt] = builder->CreateIntrinsic(
- intrin, {llvm_dtype, llvm_dtype_ptr},
- {llvm_val[stmt->ptr], tlctx->get_constant(data_type_size(dtype))});
} else {
CodeGenLLVM::visit(stmt);
}
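Note: a rough C++ model of the bit-pointer load path above (the layout here is an assumption for illustration; the real decoding lives in read_bit_pointer):

#include <cstdint>

// Sketch (assumption: simplified model, not Taichi's actual runtime layout).
// A bit pointer addresses a value that need not be byte-aligned, so it is
// carried as a pointer to the containing physical word plus a bit offset.
struct BitPtr {
  uint32_t *word;  // physical word holding the packed value
  int bit_offset;  // position of the value's least significant bit
};

// A load through a bit pointer is one word-sized load (the intrinsic load
// above on CUDA), followed by the shifting done in extract_custom_int.
uint32_t load_physical_word(BitPtr p) {
  return *p.word;
}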
55 changes: 35 additions & 20 deletions taichi/codegen/codegen_llvm.cpp
@@ -1141,6 +1141,8 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {
}
llvm::Value *byte_ptr = nullptr, *bit_offset = nullptr;
read_bit_pointer(llvm_val[stmt->ptr], byte_ptr, bit_offset);
// TODO(type): CUDA only supports atomicCAS on 32- and 64-bit integers
// try to support CustomInt/Float types with 16-bit or 8-bit physical types
auto func_name = fmt::format("set_partial_bits_b{}",
data_type_bits(cit->get_physical_type()));
builder->CreateCall(
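Note: the TODO above concerns this store path. A simplified stand-in for what a set_partial_bits_b32-style runtime helper does (an assumption for illustration, not the runtime's exact source); since CUDA's atomicCAS only exists for 32- and 64-bit words, this pattern is what currently limits the physical types on that backend:

// Sketch (assumption: simplified, illustrative version). Store an n-bit
// field inside a 32-bit word without disturbing the neighboring bits,
// retrying with atomicCAS until no other thread has raced us.
__device__ void store_partial_bits(unsigned int *ptr, unsigned int offset,
                                   unsigned int bits, unsigned int value) {
  unsigned int mask = ((~0u) >> (32 - bits)) << offset;  // field mask
  unsigned int old = *ptr, assumed;
  do {
    assumed = old;
    unsigned int updated = (old & ~mask) | ((value << offset) & mask);
    old = atomicCAS(ptr, assumed, updated);
  } while (assumed != old);
}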
@@ -1157,25 +1159,33 @@ void CodeGenLLVM::visit(GlobalStoreStmt *stmt) {

llvm::Value *CodeGenLLVM::load_as_custom_int(Stmt *ptr, Type *load_type) {
auto *cit = load_type->as<CustomIntType>();
- // 1. load bit pointer
+ // load bit pointer
llvm::Value *byte_ptr, *bit_offset;
read_bit_pointer(llvm_val[ptr], byte_ptr, bit_offset);

auto bit_level_container = builder->CreateLoad(builder->CreateBitCast(
byte_ptr, llvm_ptr_type(cit->get_physical_type())));
- // 2. bit shifting

+ return extract_custom_int(bit_level_container, bit_offset, load_type);
+ }

+ llvm::Value *CodeGenLLVM::extract_custom_int(llvm::Value *physical_value,
+ llvm::Value *bit_offset,
+ Type *load_type) {
+ // bit shifting
// first left shift `physical_type - (offset + num_bits)`
// then right shift `physical_type - num_bits`
+ auto cit = load_type->as<CustomIntType>();
auto bit_end =
builder->CreateAdd(bit_offset, tlctx->get_constant(cit->get_num_bits()));
auto left = builder->CreateSub(
tlctx->get_constant(data_type_bits(cit->get_physical_type())), bit_end);
auto right = builder->CreateSub(
tlctx->get_constant(data_type_bits(cit->get_physical_type())),
tlctx->get_constant(cit->get_num_bits()));
- left = builder->CreateIntCast(left, bit_level_container->getType(), false);
- right = builder->CreateIntCast(right, bit_level_container->getType(), false);
- auto step1 = builder->CreateShl(bit_level_container, left);
+ left = builder->CreateIntCast(left, physical_value->getType(), false);
+ right = builder->CreateIntCast(right, physical_value->getType(), false);
+ auto step1 = builder->CreateShl(physical_value, left);
llvm::Value *step2 = nullptr;
[Comment from the PR author on lines +1172 to 1189: This function is separated from load_as_custom_int with slight changes.]

if (cit->get_is_signed())
@@ -1187,29 +1197,34 @@ llvm::Value *CodeGenLLVM::load_as_custom_int(Stmt *ptr, Type *load_type) {
cit->get_is_signed());
}
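Note: the two shifts above are a standard bitfield extract. A self-contained sketch of the same arithmetic (assumptions: 32-bit physical type, hypothetical helpers, arithmetic right shift on signed ints):

#include <cstdint>

// Left-shift until the field's top bit reaches bit 31, then shift back
// down: an arithmetic right shift sign-extends signed fields (CreateAShr),
// a logical one zero-fills unsigned fields (CreateLShr).
int32_t extract_signed(uint32_t word, int bit_offset, int num_bits) {
  int left = 32 - (bit_offset + num_bits);  // physical_bits - bit_end
  int right = 32 - num_bits;                // physical_bits - num_bits
  return (int32_t)(word << left) >> right;
}

uint32_t extract_unsigned(uint32_t word, int bit_offset, int num_bits) {
  int left = 32 - (bit_offset + num_bits);
  int right = 32 - num_bits;
  return (word << left) >> right;
}

For example, a signed 13-bit field at bit offset 3 uses left = 32 - 16 = 16 and right = 32 - 13 = 19.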

llvm::Value *CodeGenLLVM::reconstruct_custom_float(llvm::Value *digits,
Type *load_type) {
// Compute float(digits) * scale
auto cft = load_type->as<CustomFloatType>();
llvm::Value *cast = nullptr;
auto compute_type = cft->get_compute_type()->as<PrimitiveType>();
if (cft->get_digits_type()->cast<CustomIntType>()->get_is_signed()) {
cast = builder->CreateSIToFP(digits, llvm_type(compute_type));
} else {
cast = builder->CreateUIToFP(digits, llvm_type(compute_type));
}
llvm::Value *s =
llvm::ConstantFP::get(*llvm_context, llvm::APFloat(cft->get_scale()));
s = builder->CreateFPCast(s, llvm_type(compute_type));
return builder->CreateFMul(cast, s);
}
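Note: reconstruct_custom_float emits exactly the fixed-point decode float(digits) * scale. As a plain sketch (assumptions: f32 compute type, hypothetical helper):

// Decode a CustomFloatType value: an integer digits field plus a
// compile-time scale, so decoding is one int-to-float cast and one multiply.
float decode_custom_float(int digits, float scale) {
  return (float)digits * scale;  // e.g. digits = 7, scale = 0.1f -> ~0.7f
}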

void CodeGenLLVM::visit(GlobalLoadStmt *stmt) {
int width = stmt->width();
TI_ASSERT(width == 1);
- if (auto ptr_type = stmt->ptr->ret_type->cast<PointerType>();
- ptr_type->is_bit_pointer()) {
+ auto ptr_type = stmt->ptr->ret_type->as<PointerType>();
+ if (ptr_type->is_bit_pointer()) {
auto val_type = ptr_type->get_pointee_type();
if (val_type->is<CustomIntType>()) {
llvm_val[stmt] = load_as_custom_int(stmt->ptr, val_type);
} else if (auto cft = val_type->cast<CustomFloatType>()) {
auto digits = load_as_custom_int(stmt->ptr, cft->get_digits_type());
- // Compute float(digits) * scale
- llvm::Value *cast = nullptr;
- auto compute_type = cft->get_compute_type()->as<PrimitiveType>();
- if (cft->get_digits_type()->cast<CustomIntType>()->get_is_signed()) {
- cast = builder->CreateSIToFP(digits, llvm_type(compute_type));
- } else {
- cast = builder->CreateUIToFP(digits, llvm_type(compute_type));
- }
- llvm::Value *s =
- llvm::ConstantFP::get(*llvm_context, llvm::APFloat(cft->get_scale()));
- s = builder->CreateFPCast(s, llvm_type(compute_type));
- auto scaled = builder->CreateFMul(cast, s);
- llvm_val[stmt] = scaled;
+ llvm_val[stmt] = reconstruct_custom_float(digits, val_type);
} else {
TI_NOT_IMPLEMENTED
}
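Note: with extract_custom_int and reconstruct_custom_float factored out (see the review comment above), the CUDA backend reuses the same decode steps after its intrinsic load; schematically (illustrative pseudocode, not Taichi source):

// word   = load physical word behind the bit pointer
//          (plain load on CPU, nvvm_ldg_global intrinsic on CUDA)
// digits = extract_custom_int(word, bit_offset, val_type)
// value  = val_type is CustomFloatType
//            ? reconstruct_custom_float(digits, val_type)
//            : digits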
6 changes: 6 additions & 0 deletions taichi/codegen/codegen_llvm.h
@@ -198,6 +198,12 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *load_as_custom_int(Stmt *ptr, Type *load_type);

llvm::Value *extract_custom_int(llvm::Value *physical_value,
llvm::Value *bit_offset,
Type *load_type);

llvm::Value *reconstruct_custom_float(llvm::Value *digits, Type *load_type);

void visit(GlobalLoadStmt *stmt) override;

void visit(ElementShuffleStmt *stmt) override;
4 changes: 2 additions & 2 deletions tests/python/test_bit_array.py
@@ -2,7 +2,7 @@
import numpy as np


- @ti.test(arch=ti.cpu, debug=True, cfg_optimization=False)
+ @ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False)
def test_1D_bit_array():
ci1 = ti.type_factory_.get_custom_int_type(1, False)

@@ -28,7 +28,7 @@ def verify_val():
verify_val()


- @ti.test(arch=ti.cpu, debug=True, cfg_optimization=False)
+ @ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False)
def test_2D_bit_array():
ci1 = ti.type_factory_.get_custom_int_type(1, False)

4 changes: 2 additions & 2 deletions tests/python/test_bit_struct.py
@@ -2,7 +2,7 @@
import numpy as np


- @ti.test(arch=ti.cpu, debug=True, cfg_optimization=False)
+ @ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False)
def test_simple_array():
ci13 = ti.type_factory_.get_custom_int_type(13, True)
cu19 = ti.type_factory_.get_custom_int_type(19, False)
@@ -36,7 +36,7 @@ def verify_val():
verify_val.__wrapped__()


- @ti.test(arch=ti.cpu, debug=True, cfg_optimization=False)
+ @ti.test(require=ti.extension.quant, debug=True, cfg_optimization=False)
def test_custom_int_load_and_store():
ci13 = ti.type_factory_.get_custom_int_type(13, True)
cu14 = ti.type_factory_.get_custom_int_type(14, False)
2 changes: 1 addition & 1 deletion tests/python/test_custom_float.py
@@ -2,7 +2,7 @@
from pytest import approx


- @ti.test(arch=ti.cpu, cfg_optimization=False)
+ @ti.test(require=ti.extension.quant, cfg_optimization=False)
def test_custom_float():
ci13 = ti.type_factory_.get_custom_int_type(13, True)
cft = ti.type_factory_.get_custom_float_type(ci13, ti.f32.get_ptr(), 0.1)