Skip to content

Commit

Permalink
[bug] Fix struct field error on bool on cuda (#8134)
Browse files Browse the repository at this point in the history
Issue: #8086 

### Brief Summary

This pull request adds support for bool types in
`TaskCodeGenCUDA::create_intrinsic_load` which solves bit width issue on
CUDA runtime. The problem was that `nvvm_ldg_global_i` does not support
1 bit integer. So we cast pointers and values to `i8` to solve this
issue.

### Walkthrough

+ Added value and pointer casting in
`TaskCodeGenCUDA::create_intrinsic_load`
+ Added test case

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
listerily and pre-commit-ci[bot] authored Jun 5, 2023
1 parent f18af28 commit a54ac8a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
11 changes: 11 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,17 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
// Issue an "__ldg" instruction to cache data in the read-only data cache.
auto intrin = ty->isFloatingPointTy() ? llvm::Intrinsic::nvvm_ldg_global_f
: llvm::Intrinsic::nvvm_ldg_global_i;
// Special treatment for bool types. As nvvm_ldg_global_i does not support
// 1-bit integer, so we convert them to i8.
if (ty->getScalarSizeInBits() == 1) {
auto *new_ty = tlctx->get_data_type<uint8>();
auto *new_ptr =
builder->CreatePointerCast(ptr, llvm::PointerType::get(new_ty, 0));
auto *v = builder->CreateIntrinsic(
intrin, {new_ty, llvm::PointerType::get(new_ty, 0)},
{new_ptr, tlctx->get_constant(new_ty->getScalarSizeInBits())});
return builder->CreateIsNotNull(v);
}
return builder->CreateIntrinsic(
intrin, {ty, llvm::PointerType::get(ty, 0)},
{ptr, tlctx->get_constant(ty->getScalarSizeInBits())});
Expand Down
32 changes: 32 additions & 0 deletions tests/python/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,35 @@ def k() -> int:
return x.testme()

assert k() == 42


@test_utils.test(arch=[ti.cpu, ti.cuda, ti.amdgpu])
def test_struct_field_with_bool():
@ti.dataclass
class S:
a: ti.i16
b: bool
c: ti.i16

sf = S.field(shape=(10, 1))
sf[0, 0].b = False
sf[0, 0].a = 0xFFFF
sf[0, 0].c = 0xFFFF

def foo() -> S:
return sf[0, 0]

assert foo().a == -1
assert foo().c == -1
assert foo().b == False

sf[1, 0].a = 0x0000
sf[1, 0].c = 0x0000
sf[1, 0].b = True

def bar() -> S:
return sf[1, 0]

assert bar().a == 0
assert bar().c == 0
assert bar().b == True

0 comments on commit a54ac8a

Please sign in to comment.