Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[spirv] [ir] Support type u1 as arg, in buffer and as return value #8018

Merged
merged 3 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions taichi/codegen/spirv/kernel_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ KernelContextAttributes::KernelContextAttributes(
auto tensor_dtype = tensor_type->get_element_type();
TI_ASSERT(tensor_dtype->is<PrimitiveType>());
ra.dtype = tensor_dtype->cast<PrimitiveType>()->type;
dt_bytes = data_type_size(tensor_dtype);
dt_bytes = data_type_size_gfx(tensor_dtype);
ra.is_array = true;
ra.stride = tensor_type->get_num_elements() * dt_bytes;
} else {
TI_ASSERT(kr.dt->is<PrimitiveType>());
ra.dtype = kr.dt->cast<PrimitiveType>()->type;
dt_bytes = data_type_size(kr.dt);
dt_bytes = data_type_size_gfx(kr.dt);
ra.is_array = false;
ra.stride = dt_bytes;
}
Expand All @@ -90,7 +90,7 @@ KernelContextAttributes::KernelContextAttributes(
const size_t dt_bytes =
(attribs.is_array && !is_ret)
? (has_buffer_ptr ? sizeof(uint64_t) : sizeof(uint32_t))
: data_type_size(PrimitiveType::get(attribs.dtype));
: data_type_size_gfx(PrimitiveType::get(attribs.dtype));
// Round bytes up to the next multiple of dt_bytes
bytes = (bytes + dt_bytes - 1) / dt_bytes * dt_bytes;
attribs.offset_in_mem = bytes;
Expand Down
35 changes: 27 additions & 8 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,9 @@ class TaskCodegen : public IRVisitor {
} else {
auto buffer_value =
get_buffer_value(BufferType::Args, PrimitiveType::i32);
const auto val_type = args_struct_types_.at(arg_id);
bool is_bool = arg_type->is_primitive(PrimitiveTypeID::u1);
const auto val_type =
is_bool ? ir_->i32_type() : args_struct_types_.at(arg_id);
spirv::Value buffer_val = ir_->make_value(
spv::OpAccessChain,
ir_->get_pointer_type(val_type, spv::StorageClassUniform),
Expand All @@ -594,6 +596,10 @@ class TaskCodegen : public IRVisitor {
return;
}
spirv::Value val = ir_->load_variable(buffer_val, val_type);
if (is_bool) {
val = ir_->make_value(spv::OpINotEqual, ir_->bool_type(), val,
ir_->int_immediate_number(ir_->i32_type(), 0));
}
ir_->register_value(stmt->raw_name(), val);
}
}
Expand All @@ -612,14 +618,20 @@ class TaskCodegen : public IRVisitor {
// Now we only support one ret
auto dt = stmt->element_types()[0];
for (int i = 0; i < stmt->values.size(); i++) {
auto val_type = ir_->get_primitive_type(dt);
if (dt->is_primitive(PrimitiveTypeID::u1)) {
val_type = ir_->i32_type();
}
spirv::Value buffer_val = ir_->make_value(
spv::OpAccessChain,
ir_->get_storage_pointer_type(ir_->get_primitive_type(dt)),
spv::OpAccessChain, ir_->get_storage_pointer_type(val_type),
get_buffer_value(BufferType::Rets, dt),
ir_->int_immediate_number(ir_->i32_type(), 0),
ir_->int_immediate_number(ir_->i32_type(), i));
buffer_val.flag = ValueKind::kVariablePtr;
spirv::Value val = ir_->query_value(stmt->values[i]->raw_name());
if (dt->is_primitive(PrimitiveTypeID::u1)) {
val = ir_->select(val, ir_->const_i32_one_, ir_->const_i32_zero_);
}
ir_->store_variable(buffer_val, val);
}
}
Expand Down Expand Up @@ -2141,16 +2153,19 @@ class TaskCodegen : public IRVisitor {

if (ptr_val.stype.dt == PrimitiveType::u64) {
ti_buffer_type = dt;
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
ti_buffer_type = PrimitiveType::i32;
}

auto buf_ptr = at_buffer(ptr, ti_buffer_type);
auto val_bits =
ir_->load_variable(buf_ptr, ir_->get_primitive_type(ti_buffer_type));
auto ret = ti_buffer_type == dt
? val_bits
: ir_->make_value(spv::OpBitcast,
ir_->get_primitive_type(dt), val_bits);
return ret;
if (dt->is_primitive(PrimitiveTypeID::u1))
return ir_->cast(ir_->bool_type(), val_bits);
return ti_buffer_type == dt
? val_bits
: ir_->make_value(spv::OpBitcast, ir_->get_primitive_type(dt),
val_bits);
}

void store_buffer(const Stmt *ptr, spirv::Value val) {
Expand All @@ -2160,6 +2175,10 @@ class TaskCodegen : public IRVisitor {

if (ptr_val.stype.dt == PrimitiveType::u64) {
ti_buffer_type = val.stype.dt;
} else if (val.stype.dt->is_primitive(PrimitiveTypeID::u1)) {
ti_buffer_type = PrimitiveType::i32;
val = ir_->make_value(spv::OpSelect, ir_->i32_type(), val,
ir_->const_i32_one_, ir_->const_i32_zero_);
}

auto buf_ptr = at_buffer(ptr, ti_buffer_type);
Expand Down
6 changes: 5 additions & 1 deletion taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,11 @@ SType IRBuilder::get_storage_pointer_type(const SType &value_type) {
return get_pointer_type(value_type, storage_class);
}

SType IRBuilder::get_array_type(const SType &value_type, uint32_t num_elems) {
SType IRBuilder::get_array_type(const SType &_value_type, uint32_t num_elems) {
auto value_type = _value_type;
if (value_type.dt->is_primitive(PrimitiveTypeID::u1)) {
value_type = i32_type();
}
SType arr_type;
arr_type.id = id_counter_++;
arr_type.flag = TypeKind::kPtr;
Expand Down
7 changes: 7 additions & 0 deletions taichi/inc/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ constexpr std::size_t cuda_dynamic_shared_array_threshold_bytes = 49152;
// TODO: get this at runtime
constexpr std::size_t default_shared_mem_size = 65536;

// Intended as the bool case of taichi_union_cast_with_different_sizes: the
// result is true iff `g` compares unequal to 0, rather than going through the
// union used by the generic overload. Per the original note, this is meant to
// address a ti.u1 return value always reading back as 0 on Vulkan, which was
// attributed to data endianness.
//
// NOTE(review): the first template parameter here is an unnamed non-type
// `bool`, not the type `bool`, and function templates cannot be partially
// specialized. A call spelled `taichi_union_cast_with_different_sizes<bool>(x)`
// therefore cannot bind to this overload and resolves to the generic
// <typename T, typename G> template instead. Confirm against the call sites
// whether this overload is ever selected; a real fix likely needs a
// std::is_same_v<T, bool> branch inside the generic template.
template <bool, typename G>
bool taichi_union_cast_with_different_sizes(G g) {
return g != 0;
}

template <typename T, typename G>
T taichi_union_cast_with_different_sizes(G g) {
union {
Expand Down
11 changes: 11 additions & 0 deletions taichi/ir/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ int data_type_size(DataType t) {
}
}

// Byte size of `t` as used by the gfx (SPIR-V) code paths. GLSL booleans are
// 4 bytes long, so u1 is reported as 4 here; every other type defers to
// data_type_size(). Ignoring this causes problems when values are passed to
// kernels or when return values are read back from kernels.
int data_type_size_gfx(DataType t) {
  const bool is_bool = t->is_primitive(PrimitiveTypeID::u1);
  return is_bool ? 4 : data_type_size(t);
}

std::string tensor_type_format_helper(const std::vector<int> &shape,
std::string format_str,
int dim) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ TI_DLL_EXPORT std::string data_type_name(DataType t);

TI_DLL_EXPORT int data_type_size(DataType t);

TI_DLL_EXPORT int data_type_size_gfx(DataType t);

TI_DLL_EXPORT std::string data_type_format(DataType dt, Arch arch = Arch::x64);

inline int data_type_bits(DataType t) {
Expand Down
6 changes: 3 additions & 3 deletions taichi/runtime/gfx/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ class HostDeviceContextBlitter {
const auto &ret = ctx_attribs_->rets()[i];
void *device_ptr = (uint8_t *)device_base + ret.offset_in_mem;
const auto dt = PrimitiveType::get(ret.dtype);
const auto num = ret.stride / data_type_size(dt);
const auto num = ret.stride / data_type_size_gfx(dt);
for (int j = 0; j < num; ++j) {
// (penguinliong) Again, it's the module loader's responsibility to
// check the data type availability.
Expand Down Expand Up @@ -820,7 +820,7 @@ GfxRuntime::get_struct_type_with_data_layout_impl(
member_align = member_align_;
member_size = size;
} else if (auto tensor_type = member.type->cast<lang::TensorType>()) {
size_t element_size = data_type_size(tensor_type->get_element_type());
size_t element_size = data_type_size_gfx(tensor_type->get_element_type());
size_t num_elements = tensor_type->get_num_elements();
if (num_elements == 2) {
member_align = element_size * 2;
Expand All @@ -843,7 +843,7 @@ GfxRuntime::get_struct_type_with_data_layout_impl(
}
} else {
TI_ASSERT(member.type->is<PrimitiveType>());
member_size = data_type_size(member.type);
member_size = data_type_size_gfx(member.type);
member_align = member_size;
}
bytes = align_up(bytes, member_align);
Expand Down