diff --git a/taichi/codegen/spirv/kernel_utils.cpp b/taichi/codegen/spirv/kernel_utils.cpp index d485e3269c2cb..23db156e22231 100644 --- a/taichi/codegen/spirv/kernel_utils.cpp +++ b/taichi/codegen/spirv/kernel_utils.cpp @@ -68,13 +68,13 @@ KernelContextAttributes::KernelContextAttributes( auto tensor_dtype = tensor_type->get_element_type(); TI_ASSERT(tensor_dtype->is()); ra.dtype = tensor_dtype->cast()->type; - dt_bytes = data_type_size(tensor_dtype); + dt_bytes = data_type_size_gfx(tensor_dtype); ra.is_array = true; ra.stride = tensor_type->get_num_elements() * dt_bytes; } else { TI_ASSERT(kr.dt->is()); ra.dtype = kr.dt->cast()->type; - dt_bytes = data_type_size(kr.dt); + dt_bytes = data_type_size_gfx(kr.dt); ra.is_array = false; ra.stride = dt_bytes; } @@ -90,7 +90,7 @@ KernelContextAttributes::KernelContextAttributes( const size_t dt_bytes = (attribs.is_array && !is_ret) ? (has_buffer_ptr ? sizeof(uint64_t) : sizeof(uint32_t)) - : data_type_size(PrimitiveType::get(attribs.dtype)); + : data_type_size_gfx(PrimitiveType::get(attribs.dtype)); // Align bytes to the nearest multiple of dt_bytes bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes; attribs.offset_in_mem = bytes; diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 19f03bfb1846f..b6dccb69ede85 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -583,7 +583,9 @@ class TaskCodegen : public IRVisitor { } else { auto buffer_value = get_buffer_value(BufferType::Args, PrimitiveType::i32); - const auto val_type = args_struct_types_.at(arg_id); + bool is_bool = arg_type->is_primitive(PrimitiveTypeID::u1); + const auto val_type = + is_bool ? ir_->i32_type() : args_struct_types_.at(arg_id); spirv::Value buffer_val = ir_->make_value( spv::OpAccessChain, ir_->get_pointer_type(val_type, spv::StorageClassUniform), @@ -594,6 +596,10 @@ class TaskCodegen : public IRVisitor { return; } spirv::Value val = ir_->load_variable(buffer_val, val_type); + if (is_bool) { + val = ir_->make_value(spv::OpINotEqual, ir_->bool_type(), val, + ir_->int_immediate_number(ir_->i32_type(), 0)); + } ir_->register_value(stmt->raw_name(), val); } } @@ -612,14 +618,20 @@ class TaskCodegen : public IRVisitor { // Now we only support one ret auto dt = stmt->element_types()[0]; for (int i = 0; i < stmt->values.size(); i++) { + auto val_type = ir_->get_primitive_type(dt); + if (dt->is_primitive(PrimitiveTypeID::u1)) { + val_type = ir_->i32_type(); + } spirv::Value buffer_val = ir_->make_value( - spv::OpAccessChain, - ir_->get_storage_pointer_type(ir_->get_primitive_type(dt)), + spv::OpAccessChain, ir_->get_storage_pointer_type(val_type), get_buffer_value(BufferType::Rets, dt), ir_->int_immediate_number(ir_->i32_type(), 0), ir_->int_immediate_number(ir_->i32_type(), i)); buffer_val.flag = ValueKind::kVariablePtr; spirv::Value val = ir_->query_value(stmt->values[i]->raw_name()); + if (dt->is_primitive(PrimitiveTypeID::u1)) { + val = ir_->select(val, ir_->const_i32_one_, ir_->const_i32_zero_); + } ir_->store_variable(buffer_val, val); } } @@ -2141,16 +2153,19 @@ class TaskCodegen : public IRVisitor { if (ptr_val.stype.dt == PrimitiveType::u64) { ti_buffer_type = dt; + } else if (dt->is_primitive(PrimitiveTypeID::u1)) { + ti_buffer_type = PrimitiveType::i32; } auto buf_ptr = at_buffer(ptr, ti_buffer_type); auto val_bits = ir_->load_variable(buf_ptr, ir_->get_primitive_type(ti_buffer_type)); - auto ret = ti_buffer_type == dt - ? val_bits - : ir_->make_value(spv::OpBitcast, - ir_->get_primitive_type(dt), val_bits); - return ret; + if (dt->is_primitive(PrimitiveTypeID::u1)) + return ir_->cast(ir_->bool_type(), val_bits); + return ti_buffer_type == dt + ? val_bits + : ir_->make_value(spv::OpBitcast, ir_->get_primitive_type(dt), + val_bits); } void store_buffer(const Stmt *ptr, spirv::Value val) { @@ -2160,6 +2175,10 @@ class TaskCodegen : public IRVisitor { if (ptr_val.stype.dt == PrimitiveType::u64) { ti_buffer_type = val.stype.dt; + } else if (val.stype.dt->is_primitive(PrimitiveTypeID::u1)) { + ti_buffer_type = PrimitiveType::i32; + val = ir_->make_value(spv::OpSelect, ir_->i32_type(), val, + ir_->const_i32_one_, ir_->const_i32_zero_); } auto buf_ptr = at_buffer(ptr, ti_buffer_type); diff --git a/taichi/codegen/spirv/spirv_ir_builder.cpp b/taichi/codegen/spirv/spirv_ir_builder.cpp index 9829cf0e8bb15..719b87115a66d 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.cpp +++ b/taichi/codegen/spirv/spirv_ir_builder.cpp @@ -559,7 +559,11 @@ SType IRBuilder::get_storage_pointer_type(const SType &value_type) { return get_pointer_type(value_type, storage_class); } -SType IRBuilder::get_array_type(const SType &value_type, uint32_t num_elems) { +SType IRBuilder::get_array_type(const SType &_value_type, uint32_t num_elems) { + auto value_type = _value_type; + if (value_type.dt->is_primitive(PrimitiveTypeID::u1)) { + value_type = i32_type(); + } SType arr_type; arr_type.id = id_counter_++; arr_type.flag = TypeKind::kPtr; diff --git a/taichi/inc/constants.h b/taichi/inc/constants.h index 7071abff74809..eb28602a5022e 100644 --- a/taichi/inc/constants.h +++ b/taichi/inc/constants.h @@ -39,6 +39,13 @@ constexpr std::size_t cuda_dynamic_shared_array_threshold_bytes = 49152; // TODO: get this at runtime constexpr std::size_t default_shared_mem_size = 65536; +// Specialization for bool type. This solves the issue that return type ti.u1 +// always returns 0 in vulkan. This issue is caused by data endianness. +template +bool taichi_union_cast_with_different_sizes(G g) { + return g != 0; +} + template T taichi_union_cast_with_different_sizes(G g) { union { diff --git a/taichi/ir/type_utils.cpp b/taichi/ir/type_utils.cpp index 2ad7ceacd2151..9d7466bf841f6 100644 --- a/taichi/ir/type_utils.cpp +++ b/taichi/ir/type_utils.cpp @@ -69,6 +69,17 @@ int data_type_size(DataType t) { } } +// In GLSL, boolean types are special as they are 4 bytes in length. If we +// ignore this characteristic, issues will arise when we pass values to kernels +// or read return values from kernels. +int data_type_size_gfx(DataType t) { + if (t->is_primitive(PrimitiveTypeID::u1)) { + return 4; + } else { + return data_type_size(t); + } +} + std::string tensor_type_format_helper(const std::vector &shape, std::string format_str, int dim) { diff --git a/taichi/ir/type_utils.h b/taichi/ir/type_utils.h index bb8c687170d6c..c5177df63e455 100644 --- a/taichi/ir/type_utils.h +++ b/taichi/ir/type_utils.h @@ -12,6 +12,8 @@ TI_DLL_EXPORT std::string data_type_name(DataType t); TI_DLL_EXPORT int data_type_size(DataType t); +TI_DLL_EXPORT int data_type_size_gfx(DataType t); + TI_DLL_EXPORT std::string data_type_format(DataType dt, Arch arch = Arch::x64); inline int data_type_bits(DataType t) { diff --git a/taichi/runtime/gfx/runtime.cpp b/taichi/runtime/gfx/runtime.cpp index be9b89eeb9e15..3c8aac2e385c0 100644 --- a/taichi/runtime/gfx/runtime.cpp +++ b/taichi/runtime/gfx/runtime.cpp @@ -169,7 +169,7 @@ class HostDeviceContextBlitter { const auto &ret = ctx_attribs_->rets()[i]; void *device_ptr = (uint8_t *)device_base + ret.offset_in_mem; const auto dt = PrimitiveType::get(ret.dtype); - const auto num = ret.stride / data_type_size(dt); + const auto num = ret.stride / data_type_size_gfx(dt); for (int j = 0; j < num; ++j) { // (penguinliong) Again, it's the module loader's responsibility to // check the data type availability. @@ -820,7 +820,7 @@ GfxRuntime::get_struct_type_with_data_layout_impl( member_align = member_align_; member_size = size; } else if (auto tensor_type = member.type->cast()) { - size_t element_size = data_type_size(tensor_type->get_element_type()); + size_t element_size = data_type_size_gfx(tensor_type->get_element_type()); size_t num_elements = tensor_type->get_num_elements(); if (num_elements == 2) { member_align = element_size * 2; @@ -843,7 +843,7 @@ GfxRuntime::get_struct_type_with_data_layout_impl( } } else { TI_ASSERT(member.type->is()); - member_size = data_type_size(member.type); + member_size = data_type_size_gfx(member.type); member_align = member_size; } bytes = align_up(bytes, member_align);