Skip to content

Commit

Permalink
gpu: Fix Chaining of Accesses for GeneralBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
spencer-lunarg committed Dec 17, 2024
1 parent b54d52d commit 91dfde0
Show file tree
Hide file tree
Showing 7 changed files with 974 additions and 57 deletions.
38 changes: 24 additions & 14 deletions layers/gpu/spirv/descriptor_class_general_buffer_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
*/

#include "descriptor_class_general_buffer_pass.h"
#include "instruction.h"
#include "module.h"
#include <spirv/unified1/spirv.hpp>
#include <iostream>
Expand Down Expand Up @@ -42,18 +43,20 @@ uint32_t DescriptorClassGeneralBufferPass::GetLinkFunctionId() {

uint32_t DescriptorClassGeneralBufferPass::CreateFunctionCall(BasicBlock& block, InstructionIt* inst_it,
const InjectionData& injection_data) {
assert(access_chain_inst_ && var_inst_);
assert(!access_chain_insts_.empty() && var_inst_);
const Constant& set_constant = module_.type_manager_.GetConstantUInt32(descriptor_set_);
const Constant& binding_constant = module_.type_manager_.GetConstantUInt32(descriptor_binding_);
const uint32_t descriptor_index_id = CastToUint32(descriptor_index_id_, block, inst_it); // might be int32

// For now, only do bounds check for non-aggregate types
// TODO - Do bounds check for aggregate loads and stores
const Type* pointer_type = module_.type_manager_.FindTypeById(access_chain_inst_->TypeId());
//
// Grab front() as it will be the "final" type we access
const Type* pointer_type = module_.type_manager_.FindTypeById(access_chain_insts_.front()->TypeId());
const Type* pointee_type = module_.type_manager_.FindTypeById(pointer_type->inst_.Word(3));
if (pointee_type && pointee_type->spv_type_ != SpvType::kArray && pointee_type->spv_type_ != SpvType::kRuntimeArray &&
pointee_type->spv_type_ != SpvType::kStruct) {
descriptor_offset_id_ = GetLastByte(*var_inst_, *access_chain_inst_, block, inst_it); // Get Last Byte Index
descriptor_offset_id_ = GetLastByte(*var_inst_, access_chain_insts_, block, inst_it); // Get Last Byte Index
} else {
descriptor_offset_id_ = module_.type_manager_.GetConstantZeroUint32().Id();
}
Expand All @@ -75,7 +78,6 @@ uint32_t DescriptorClassGeneralBufferPass::CreateFunctionCall(BasicBlock& block,
}

void DescriptorClassGeneralBufferPass::Reset() {
access_chain_inst_ = nullptr;
var_inst_ = nullptr;
target_instruction_ = nullptr;
descriptor_set_ = 0;
Expand All @@ -91,15 +93,23 @@ bool DescriptorClassGeneralBufferPass::RequiresInstrumentation(const Function& f
return false;
}

// TODO - Should have loop to walk Load/Store to the Pointer,
// this case will not cover things such as OpCopyObject or double OpAccessChains
access_chain_inst_ = function.FindInstruction(inst.Operand(0));
if (!access_chain_inst_ || access_chain_inst_->Opcode() != spv::OpAccessChain) {
const Instruction* next_access_chain = function.FindInstruction(inst.Operand(0));
if (!next_access_chain || next_access_chain->Opcode() != spv::OpAccessChain) {
return false;
}

const uint32_t variable_id = access_chain_inst_->Operand(0);
const Variable* variable = module_.type_manager_.FindVariableById(variable_id);
access_chain_insts_.clear(); // only clear right before we know we will need again

const Variable* variable = nullptr;
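// We need to walk down possibly multiple chained OpAccessChains to get the variable
// (TODO - OpCopyObject is not walked through yet; hitting one ends the loop without a variable and we skip instrumentation)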
while (next_access_chain && next_access_chain->Opcode() == spv::OpAccessChain) {
access_chain_insts_.push_back(next_access_chain);
const uint32_t access_chain_base_id = next_access_chain->Operand(0);
variable = module_.type_manager_.FindVariableById(access_chain_base_id);
if (variable) {
break; // found
}
next_access_chain = function.FindInstruction(access_chain_base_id);
}
if (!variable) {
return false;
}
Expand Down Expand Up @@ -134,15 +144,15 @@ bool DescriptorClassGeneralBufferPass::RequiresInstrumentation(const Function& f
// A load through a descriptor array will have at least 3 operands. We
// do not want to instrument loads of descriptors here which are part of
// an image-based reference.
if (is_descriptor_array && access_chain_inst_->Length() >= 6) {
descriptor_index_id_ = access_chain_inst_->Operand(1);
if (is_descriptor_array && access_chain_insts_.back()->Length() >= 6) {
descriptor_index_id_ = access_chain_insts_.back()->Operand(1);
} else {
// There is no array of this descriptor, so we essentially have an array of 1
descriptor_index_id_ = module_.type_manager_.GetConstantZeroUint32().Id();
}

for (const auto& annotation : module_.annotations_) {
if (annotation->Opcode() == spv::OpDecorate && annotation->Word(1) == variable_id) {
if (annotation->Opcode() == spv::OpDecorate && annotation->Word(1) == variable->Id()) {
if (annotation->Word(2) == spv::DecorationDescriptorSet) {
descriptor_set_ = annotation->Word(3);
} else if (annotation->Word(2) == spv::DecorationBinding) {
Expand Down
7 changes: 6 additions & 1 deletion layers/gpu/spirv/descriptor_class_general_buffer_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ class DescriptorClassGeneralBufferPass : public Pass {
uint32_t link_function_id = 0;
uint32_t GetLinkFunctionId();

const Instruction* access_chain_inst_ = nullptr;
// List of OpAccessChains from the Store/Load down to the OpVariable
// The front() will be closest to the exact spot accessed
// The back() will be closest to the OpVariable
// (note GLSL will try to always create a single large OpAccessChain)
std::vector<const Instruction*> access_chain_insts_;
// The OpVariable that is being accessed
const Instruction* var_inst_ = nullptr;

uint32_t descriptor_set_ = 0;
Expand Down
17 changes: 12 additions & 5 deletions layers/gpu/spirv/pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
*/

#include "pass.h"
#include "instruction.h"
#include "module.h"
#include "gpu/shaders/gpuav_error_codes.h"

Expand Down Expand Up @@ -208,17 +209,17 @@ const Instruction* Pass::GetMemeberDecoration(uint32_t id, uint32_t member_index
// Find outermost buffer type and its access chain index.
// Because access chains indexes can be runtime values, we need to build arithmetic logic in the SPIR-V to get the runtime value of
// the indexing
uint32_t Pass::GetLastByte(const Instruction& var_inst, const Instruction& access_chain_inst, BasicBlock& block,
uint32_t Pass::GetLastByte(const Instruction& var_inst, std::vector<const Instruction*>& access_chain_insts, BasicBlock& block,
InstructionIt* inst_it) {
const Type* pointer_type = module_.type_manager_.FindTypeById(var_inst.TypeId());
const Type* descriptor_type = module_.type_manager_.FindTypeById(pointer_type->inst_.Word(3));

uint32_t current_type_id = 0;
uint32_t ac_word_index = 4;
uint32_t ac_word_index = 4; // points to first "Index" operand of an OpAccessChain

if (descriptor_type->spv_type_ == SpvType::kArray || descriptor_type->spv_type_ == SpvType::kRuntimeArray) {
current_type_id = descriptor_type->inst_.Operand(0);
ac_word_index++;
ac_word_index++; // this jumps over the array of descriptors so we first start on the descriptor itself
} else if (descriptor_type->spv_type_ == SpvType::kStruct) {
current_type_id = descriptor_type->Id();
} else {
Expand All @@ -236,8 +237,9 @@ uint32_t Pass::GetLastByte(const Instruction& var_inst, const Instruction& acces
uint32_t matrix_stride_id = 0;
bool in_matrix = false;

while (ac_word_index < access_chain_inst.Length()) {
const uint32_t ac_index_id = access_chain_inst.Word(ac_word_index);
auto access_chain_iter = access_chain_insts.rbegin();
while (access_chain_iter != access_chain_insts.rend()) {
const uint32_t ac_index_id = (*access_chain_iter)->Word(ac_word_index);
uint32_t current_offset_id = 0;

const Type* current_type = module_.type_manager_.FindTypeById(current_type_id);
Expand Down Expand Up @@ -335,7 +337,12 @@ uint32_t Pass::GetLastByte(const Instruction& var_inst, const Instruction& acces
block.CreateInstruction(spv::OpIAdd, {uint32_type.Id(), new_sum_id, sum_id, current_offset_id}, inst_it);
sum_id = new_sum_id;
}

ac_word_index++;
if (ac_word_index >= (*access_chain_iter)->Length()) {
++access_chain_iter;
ac_word_index = 4; // reset
}
}

// Add in offset of last byte of referenced object
Expand Down
2 changes: 1 addition & 1 deletion layers/gpu/spirv/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class Pass {
const Instruction* GetDecoration(uint32_t id, spv::Decoration decoration);
const Instruction* GetMemeberDecoration(uint32_t id, uint32_t member_index, spv::Decoration decoration);

uint32_t GetLastByte(const Instruction& var_inst, const Instruction& access_chain_inst, BasicBlock& block,
uint32_t GetLastByte(const Instruction& var_inst, std::vector<const Instruction*>& access_chain_insts, BasicBlock& block,
InstructionIt* inst_it);
// Generate SPIR-V needed to help convert things to be uniformly uint32_t
// If no inst_it is passed in, any new instructions will be added to end of the Block
Expand Down
5 changes: 5 additions & 0 deletions tests/framework/layer_validation_tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ class GpuAVDescriptorIndexingTest : public GpuAVTest {
void InitGpuVUDescriptorIndexing();
};

// Test fixture for the GPU-AV "descriptor class general buffer" tests.
// Derives from GpuAVTest and adds a single helper; the helper's definition is not in this header.
class GpuAVDescriptorClassGeneralBuffer : public GpuAVTest {
public:
// Helper for storage-buffer tests driven by a compute shader (going by its name; body defined elsewhere).
//   shader         - shader source text
//   is_glsl        - presumably true when `shader` is GLSL rather than SPIR-V assembly - TODO confirm
//   buffer_size    - a VkDeviceSize; presumably the size of the storage buffer the test creates - TODO confirm
//   expected_error - defaults to nullptr; presumably nullptr means no validation error is expected - TODO confirm
void ComputeStorageBufferTest(const char *shader, bool is_glsl, VkDeviceSize buffer_size, const char *expected_error = nullptr);
};

class GpuAVRayQueryTest : public GpuAVTest {
public:
void InitGpuAVRayQuery();
Expand Down
Loading

0 comments on commit 91dfde0

Please sign in to comment.