Skip to content

Commit

Permalink
[Vulkan] Fixing floating point load/store/atomics on global temps and…
Browse files Browse the repository at this point in the history
… context buffers (#2796)

* enabling atomic float but fix a capability (should not set any capabilities by the existance of extensions, probing is still required)

* Fix floating point load/store from gtmps and context buffers & detect shader atomic add fallback function call support (spv_variable_ptrs)

* Auto Format

* Forgot to include spv variable ptr capabilities

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
bobcao3 and taichi-gardener authored Aug 26, 2021
1 parent 58d591f commit ed389b8
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 28 deletions.
38 changes: 30 additions & 8 deletions taichi/backends/vulkan/codegen_vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ class TaskCodegen : public IRVisitor {
TI_ASSERT(ptr_to_buffers_.count(stmt) == 0);
ptr_to_buffers_[stmt] = BuffersEnum::Root;

spirv::SType dt_ptr =
ir_->get_pointer_type(ir_->get_primitive_buffer_type(out_snode->dt),
spv::StorageClassStorageBuffer);
spirv::SType dt_ptr = ir_->get_pointer_type(
ir_->get_primitive_buffer_type(true, out_snode->dt),
spv::StorageClassStorageBuffer);
val = ir_->make_value(spv::OpAccessChain, dt_ptr, input_ptr_val, offset);
} else {
spirv::SType snode_array =
Expand Down Expand Up @@ -344,17 +344,20 @@ class TaskCodegen : public IRVisitor {
void visit(GlobalStoreStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
const auto dt = stmt->val->element_type();
bool struct_compiled = false;
spirv::Value buffer_ptr;
spirv::Value val = ir_->query_value(stmt->val->raw_name());
if (ptr_to_buffers_.at(stmt->dest) == BuffersEnum::Root) {
buffer_ptr = ir_->query_value(stmt->dest->raw_name());
buffer_ptr.flag =
spirv::ValueKind::kVariablePtr; // make this value could store/load
struct_compiled = true;
} else {
buffer_ptr = at_buffer(stmt->dest, dt);
}

const auto &primitive_buffer_type = ir_->get_primitive_buffer_type(dt);
const auto &primitive_buffer_type =
ir_->get_primitive_buffer_type(struct_compiled, dt);
if (buffer_ptr.stype.element_type_id == val.stype.id) {
// No bit cast
ir_->store_variable(buffer_ptr, val);
Expand All @@ -368,17 +371,20 @@ class TaskCodegen : public IRVisitor {
void visit(GlobalLoadStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);
auto dt = stmt->element_type();
bool struct_compiled = false;
spirv::Value buffer_ptr;
spirv::Value val;
if (ptr_to_buffers_.at(stmt->src) == BuffersEnum::Root) {
buffer_ptr = ir_->query_value(stmt->src->raw_name());
buffer_ptr.flag =
spirv::ValueKind::kVariablePtr; // make this value could store/load
struct_compiled = true;
} else {
buffer_ptr = at_buffer(stmt->src, dt);
}

const auto &primitive_buffer_type = ir_->get_primitive_buffer_type(dt);
const auto &primitive_buffer_type =
ir_->get_primitive_buffer_type(struct_compiled, dt);
if (buffer_ptr.stype.element_type_id == val.stype.id) {
// No bit cast
val = ir_->load_variable(buffer_ptr, primitive_buffer_type);
Expand Down Expand Up @@ -790,26 +796,38 @@ class TaskCodegen : public IRVisitor {
const auto dt = stmt->dest->element_type().ptr_removed();

spirv::Value addr_ptr;
bool is_compiled_struct = false;
if (ptr_to_buffers_.at(stmt->dest) == BuffersEnum::Root) {
addr_ptr = ir_->query_value(stmt->dest->raw_name());
addr_ptr.flag =
spirv::ValueKind::kVariablePtr; // make this value could store/load
is_compiled_struct = true;
} else {
addr_ptr = at_buffer(stmt->dest, dt);
}
spirv::Value data = ir_->query_value(stmt->val->raw_name());
spirv::Value val;
if (dt->is_primitive(PrimitiveTypeID::f32)) {
if (device_->get_cap(DeviceCapability::vk_has_atomic_float_add) &&
stmt->op_type == AtomicOpType::add) {
stmt->op_type == AtomicOpType::add && is_compiled_struct) {
val = ir_->make_value(
spv::OpAtomicFAddEXT, ir_->get_primitive_type(dt), addr_ptr,
ir_->uint_immediate_number(ir_->u32_type(), 1),
ir_->uint_immediate_number(ir_->u32_type(), 0), data);
} else {
} else if (device_->get_cap(DeviceCapability::vk_has_spv_variable_ptr)) {
spirv::Value func = ir_->float_atomic(stmt->op_type);
val = ir_->make_value(spv::OpFunctionCall, ir_->f32_type(), func,
addr_ptr, data);
} else {
if (is_compiled_struct) {
TI_ERROR(
"Atomic operation requires either shader atomic float capability "
"or OpVariablePtr capability");
} else {
TI_ERROR(
"Atomic operation on global temporaries or context buffers "
"requires OpVariablePtr capability");
}
}
} else if (is_integral(dt)) {
spv::Op op;
Expand Down Expand Up @@ -1173,7 +1191,9 @@ class TaskCodegen : public IRVisitor {
ir_->make_value(spv::OpShiftRightArithmetic, ir_->i32_type(), ptr_val,
ir_->int_immediate_number(ir_->i32_type(), 2));
spirv::Value ret = ir_->struct_array_access(
ir_->get_primitive_buffer_type(dt), buffer, idx_val);
ir_->get_primitive_buffer_type(
ptr_to_buffers_.at(ptr) == BuffersEnum::Root, dt),
buffer, idx_val);
return ret;
}

Expand Down Expand Up @@ -1295,6 +1315,7 @@ class KernelCodegen {
explicit KernelCodegen(const Params &params)
: params_(params), ctx_attribs_(*params.kernel) {
spirv_opt_ = std::make_unique<spvtools::Optimizer>(SPV_ENV_VULKAN_1_2);
spirv_tools_ = std::make_unique<spvtools::SpirvTools>(SPV_ENV_VULKAN_1_2);

spirv_opt_->SetMessageConsumer(spriv_message_consumer);

Expand Down Expand Up @@ -1374,6 +1395,7 @@ class KernelCodegen {
KernelContextAttributes ctx_attribs_;

std::unique_ptr<spvtools::Optimizer> spirv_opt_;
std::unique_ptr<spvtools::SpirvTools> spirv_tools_;
spvtools::OptimizerOptions _spirv_opt_options;
};

Expand Down
32 changes: 22 additions & 10 deletions taichi/backends/vulkan/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ void IRBuilder::init_header() {
.commit(&header_);
}

if (device_->get_cap(cap::vk_has_spv_variable_ptr)) {
ib_.begin(spv::OpCapability)
.add(spv::CapabilityVariablePointers)
.commit(&header_);
ib_.begin(spv::OpCapability)
.add(spv::CapabilityVariablePointersStorageBuffer)
.commit(&header_);
}

if (device_->get_cap(cap::vk_has_int8)) {
ib_.begin(spv::OpCapability).add(spv::CapabilityInt8).commit(&header_);
}
Expand Down Expand Up @@ -240,16 +249,19 @@ SType IRBuilder::get_primitive_type(const DataType &dt) const {
}
}

SType IRBuilder::get_primitive_buffer_type(const DataType &dt) const {
if (dt->is_primitive(PrimitiveTypeID::f32) &&
device_->get_cap(cap::vk_has_atomic_float_add)) {
return t_fp32_;
} else if (dt->is_primitive(PrimitiveTypeID::f64) &&
device_->get_cap(cap::vk_has_atomic_float64_add)) {
return t_fp64_;
} else if (dt->is_primitive(PrimitiveTypeID::i64) &&
device_->get_cap(cap::vk_has_atomic_i64)) {
return t_int64_;
SType IRBuilder::get_primitive_buffer_type(const bool struct_compiled,
const DataType &dt) const {
if (struct_compiled) {
if (dt->is_primitive(PrimitiveTypeID::f32) &&
device_->get_cap(cap::vk_has_atomic_float_add)) {
return t_fp32_;
} else if (dt->is_primitive(PrimitiveTypeID::f64) &&
device_->get_cap(cap::vk_has_atomic_float64_add)) {
return t_fp64_;
} else if (dt->is_primitive(PrimitiveTypeID::i64) &&
device_->get_cap(cap::vk_has_atomic_i64)) {
return t_int64_;
}
}
return t_int32_;
}
Expand Down
3 changes: 2 additions & 1 deletion taichi/backends/vulkan/spirv_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,8 @@ class IRBuilder {
// Get the spirv type for a given Taichi data type
SType get_primitive_type(const DataType &dt) const;
// Get the spirv type for the buffer for a given Taichi data type
SType get_primitive_buffer_type(const DataType &dt) const;
SType get_primitive_buffer_type(const bool struct_compiled,
const DataType &dt) const;
// Get the pointer type that points to value_type
SType get_pointer_type(const SType &value_type,
spv::StorageClass storage_class);
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/vulkan/spirv_snode_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class SpirvSNodeCompiler {
SNodeSTypeTbl *snode_id_array_stype_tbl_) {
const auto &sn = sn_desc.snode;
if (sn->is_place()) {
return ir_->get_primitive_buffer_type(sn->dt);
return ir_->get_primitive_buffer_type(true, sn->dt);
} else {
SType sn_type = ir_->get_null_type();
sn_type.snode_desc = sn_desc;
Expand Down
13 changes: 5 additions & 8 deletions taichi/backends/vulkan/vulkan_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,13 +438,11 @@ void EmbeddedVulkanDevice::create_logical_device() {
has_swapchain = true;
enabled_extensions.push_back(ext.extensionName);
} else if (name == VK_EXT_SHADER_ATOMIC_FLOAT_EXTENSION_NAME) {
// ti_device_->set_cap(DeviceCapability::vk_has_atomic_float_add, true);
// enabled_extensions.push_back(ext.extensionName);
enabled_extensions.push_back(ext.extensionName);
} else if (name == "VK_EXT_shader_atomic_float2") {
// FIXME: This feature requires vulkan headers with
// VK_EXT_shader_atomic_float2
/*
capability_.has_atomic_float_minmax = true;
enabled_extensions.push_back(ext.extensionName);
*/
} else if (name == VK_KHR_SHADER_ATOMIC_INT64_EXTENSION_NAME) {
Expand Down Expand Up @@ -537,14 +535,13 @@ void EmbeddedVulkanDevice::create_logical_device() {
vkGetPhysicalDeviceFeatures2KHR(physical_device_, &features2);

if (shader_atomic_float_feature.shaderBufferFloat32AtomicAdd) {
// ti_device_->set_cap(DeviceCapability::vk_has_atomic_float_add, true);
ti_device_->set_cap(DeviceCapability::vk_has_atomic_float_add, true);
} else if (shader_atomic_float_feature.shaderBufferFloat64AtomicAdd) {
// ti_device_->set_cap(DeviceCapability::vk_has_atomic_float64_add,
// true);
ti_device_->set_cap(DeviceCapability::vk_has_atomic_float64_add, true);
} else if (shader_atomic_float_feature.shaderBufferFloat32Atomics) {
// ti_device_->set_cap(DeviceCapability::vk_has_atomic_float, true);
ti_device_->set_cap(DeviceCapability::vk_has_atomic_float, true);
} else if (shader_atomic_float_feature.shaderBufferFloat64Atomics) {
// ti_device_->set_cap(DeviceCapability::vk_has_atomic_float64, true);
ti_device_->set_cap(DeviceCapability::vk_has_atomic_float64, true);
}
*pNextEnd = &shader_atomic_float_feature;
pNextEnd = &shader_atomic_float_feature.pNext;
Expand Down

0 comments on commit ed389b8

Please sign in to comment.