From a54ac8a48f2c9636b385b9a325541a880dd924cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A7=8B=E4=BA=91=E6=9C=AA=E4=BA=91?= Date: Mon, 5 Jun 2023 16:00:10 +0800 Subject: [PATCH] [bug] Fix struct field error on bool on cuda (#8134) Issue: #8086 ### Brief Summary This pull request adds support for bool types in `TaskCodeGenCUDA::create_intrinsic_load`, which fixes a bit-width issue on the CUDA runtime. The problem was that `nvvm_ldg_global_i` does not support 1-bit integers, so we cast pointers and values to `i8` to work around it. ### Walkthrough + Added value and pointer casting in `TaskCodeGenCUDA::create_intrinsic_load` + Added test case --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/codegen/cuda/codegen_cuda.cpp | 11 ++++++++++ tests/python/test_struct.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 5989713e67ad5..734b941d5fe8a 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -599,6 +599,17 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { // Issue an "__ldg" instruction to cache data in the read-only data cache. auto intrin = ty->isFloatingPointTy() ? llvm::Intrinsic::nvvm_ldg_global_f : llvm::Intrinsic::nvvm_ldg_global_i; + // Special treatment for bool types: nvvm_ldg_global_i does not support + // 1-bit integers, so we convert them to i8. 
+ if (ty->getScalarSizeInBits() == 1) { + auto *new_ty = tlctx->get_data_type<uint8>(); + auto *new_ptr = + builder->CreatePointerCast(ptr, llvm::PointerType::get(new_ty, 0)); + auto *v = builder->CreateIntrinsic( + intrin, {new_ty, llvm::PointerType::get(new_ty, 0)}, + {new_ptr, tlctx->get_constant(new_ty->getScalarSizeInBits())}); + return builder->CreateIsNotNull(v); + } return builder->CreateIntrinsic( intrin, {ty, llvm::PointerType::get(ty, 0)}, {ptr, tlctx->get_constant(ty->getScalarSizeInBits())}); diff --git a/tests/python/test_struct.py b/tests/python/test_struct.py index a5584fc0c49e5..979e10d25a4e4 100644 --- a/tests/python/test_struct.py +++ b/tests/python/test_struct.py @@ -154,3 +154,35 @@ def k() -> int: return x.testme() assert k() == 42 + + +@test_utils.test(arch=[ti.cpu, ti.cuda, ti.amdgpu]) +def test_struct_field_with_bool(): + @ti.dataclass + class S: + a: ti.i16 + b: bool + c: ti.i16 + + sf = S.field(shape=(10, 1)) + sf[0, 0].b = False + sf[0, 0].a = 0xFFFF + sf[0, 0].c = 0xFFFF + + def foo() -> S: + return sf[0, 0] + + assert foo().a == -1 + assert foo().c == -1 + assert foo().b == False + + sf[1, 0].a = 0x0000 + sf[1, 0].c = 0x0000 + sf[1, 0].b = True + + def bar() -> S: + return sf[1, 0] + + assert bar().a == 0 + assert bar().c == 0 + assert bar().b == True