diff --git a/common/output_stream.cpp b/common/output_stream.cpp index cfb4a99f..d2989e9f 100644 --- a/common/output_stream.cpp +++ b/common/output_stream.cpp @@ -2201,6 +2201,19 @@ void SpvReflectToYaml::Write(std::ostream& os) { << SafeString(sm_.push_constant_blocks[i].name) << std::endl; } + // uint32_t specialization_constant_count; + os << t1 << "specialization_constant_count: " << sm_.specialization_constant_count << ",\n"; + // SpvReflectSpecializationConstant* specialization_constants; + os << t1 << "specialization_constants:" << std::endl; + for (uint32_t i = 0; i < sm_.specialization_constant_count; ++i) { + os << t2 << "- *sc" << i << " # " << SafeString(sm_.specialization_constants[i].name) << std::endl; + os << t3 << "spirv_id: " << sm_.specialization_constants[i].spirv_id << std::endl; + os << t3 << "constant_id: " << sm_.specialization_constants[i].constant_id << std::endl; + os << t3 << "size: " << sm_.specialization_constants[i].size << std::endl; + os << t3 << "default_value (as float): " << sm_.specialization_constants[i].default_value.float_value << std::endl; + os << t3 << "default_value (as int): " << sm_.specialization_constants[i].default_value.int_bool_value << std::endl; + } + if (verbosity_ >= 2) { // struct Internal { os << t1 << "_internal:" << std::endl; diff --git a/spirv_reflect.c b/spirv_reflect.c index 2cb45d04..33e57458 100644 --- a/spirv_reflect.c +++ b/spirv_reflect.c @@ -134,6 +134,7 @@ typedef struct SpvReflectPrvDecorations { SpvReflectPrvNumberDecoration component; SpvReflectPrvNumberDecoration offset; SpvReflectPrvNumberDecoration uav_counter_buffer; + SpvReflectPrvNumberDecoration specialization_constant; SpvReflectPrvStringDecoration semantic; uint32_t array_stride; uint32_t matrix_stride; @@ -563,7 +564,7 @@ static uint32_t FindBaseId(SpvReflectPrvParser* p_parser, SpvReflectPrvAccessChain* ac) { uint32_t base_id = ac->base_id; SpvReflectPrvNode* base_node = FindNode(p_parser, base_id); - while (base_node->op != SpvOpVariable) { + while (base_node && base_node->op != SpvOpVariable) { assert(base_node->op == SpvOpLoad); UNCHECKED_READU32(p_parser, base_node->word_offset + 3, base_id); SpvReflectPrvAccessChain* base_ac = FindAccessChain(p_parser, base_id); @@ -706,6 +707,7 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser) p_parser->nodes[i].decorations.offset.value = (uint32_t)INVALID_VALUE; p_parser->nodes[i].decorations.uav_counter_buffer.value = (uint32_t)INVALID_VALUE; p_parser->nodes[i].decorations.built_in = (SpvBuiltIn)INVALID_VALUE; + p_parser->nodes[i].decorations.specialization_constant.value = (SpvBuiltIn)INVALID_VALUE; } // Mark source file id node p_parser->source_file_id = (uint32_t)INVALID_VALUE; @@ -917,7 +919,12 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser) case SpvOpSpecConstantTrue: case SpvOpSpecConstantFalse: - case SpvOpSpecConstant: + case SpvOpSpecConstant: { + CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id); + CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id); + p_node->is_type = true; + } + break; case SpvOpSpecConstantComposite: case SpvOpSpecConstantOp: { CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id); @@ -1433,6 +1440,7 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser) } break; case SpvDecorationRelaxedPrecision: + case SpvDecorationSpecId: case SpvDecorationBlock: case SpvDecorationBufferBlock: case SpvDecorationColMajor: @@ -1590,6 +1598,13 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser) } break; + case SpvDecorationSpecId: { + uint32_t word_offset = p_node->word_offset + member_offset+ 3; + CHECKED_READU32(p_parser, word_offset, p_target_decorations->specialization_constant.value); + p_target_decorations->specialization_constant.word_offset = word_offset; + } + break; + case SpvReflectDecorationHlslCounterBufferGOOGLE: { uint32_t word_offset = p_node->word_offset + member_offset+ 3; CHECKED_READU32(p_parser, word_offset, p_target_decorations->uav_counter_buffer.value); @@ -1898,6 +1913,12 @@ static SpvReflectResult ParseType( p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_ACCELERATION_STRUCTURE; } break; + + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstant: { + } + break; } if (result == SPV_REFLECT_RESULT_SUCCESS) { @@ -1912,6 +1933,71 @@ static SpvReflectResult ParseType( return result; } +static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module) +{ + p_module->specialization_constant_count = 0; + p_module->specialization_constants = NULL; + for (size_t i = 0; i < p_parser->node_count; ++i) { + SpvReflectPrvNode* p_node = &(p_parser->nodes[i]); + if (p_node->op == SpvOpSpecConstantTrue || p_node->op == SpvOpSpecConstantFalse || p_node->op == SpvOpSpecConstant) { + p_module->specialization_constant_count++; + } + } + + if (p_module->specialization_constant_count == 0) { + return SPV_REFLECT_RESULT_SUCCESS; + } + + p_module->specialization_constants = (SpvReflectSpecializationConstant*)calloc(p_module->specialization_constant_count, sizeof(SpvReflectSpecializationConstant)); + + uint32_t index = 0; + + for (size_t i = 0; i < p_parser->node_count; ++i) { + SpvReflectPrvNode* p_node = &(p_parser->nodes[i]); + switch(p_node->op) { + default: continue; + case SpvOpSpecConstantTrue: { + p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL; + p_module->specialization_constants[index].default_value.int_bool_value = 1; + p_module->specialization_constants[index].size = sizeof(uint32_t); + } break; + case SpvOpSpecConstantFalse: { + p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL; + p_module->specialization_constants[index].default_value.int_bool_value = 0; + p_module->specialization_constants[index].size = sizeof(uint32_t); + } break; + case SpvOpSpecConstant: { + SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS; + uint32_t element_type_id = (uint32_t)INVALID_VALUE; + uint32_t default_value = 0; + IF_READU32(result, p_parser, p_node->word_offset + 1, element_type_id); + IF_READU32(result, p_parser, p_node->word_offset + 3, default_value); + + SpvReflectPrvNode* p_next_node = FindNode(p_parser, element_type_id); + + if (p_next_node->op == SpvOpTypeInt) { + p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_INT; + p_module->specialization_constants[index].size = sizeof(int32_t); + } else if (p_next_node->op == SpvOpTypeFloat) { + p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT; + p_module->specialization_constants[index].size = sizeof(float); + } else { + return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED; + } + + p_module->specialization_constants[index].default_value.int_bool_value = default_value; //bits are the same for int and float + } break; + } + + p_module->specialization_constants[index].name = p_node->name; + p_module->specialization_constants[index].constant_id = p_node->decorations.specialization_constant.value; + p_module->specialization_constants[index].spirv_id = p_node->result_id; + index++; + } + + return SPV_REFLECT_RESULT_SUCCESS; +} + static SpvReflectResult ParseTypes( SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module) @@ -2703,7 +2789,7 @@ static SpvReflectResult ParseDescriptorBlockVariableUsage( else if (IsPointerToPointer(p_parser, p_access_chain->result_type_id)) { // Remember block var for this access chain for downstream dereference p_access_chain->block_var = p_member_var; - } + } else { // Clear UNUSED flag for remaining variables MarkSelfAndAllMemberVarsAsUsed(p_member_var); @@ -2765,11 +2851,6 @@ static SpvReflectResult ParseDescriptorBlocks( if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } - - if (is_parent_rta) { - p_descriptor->block.size = 0; - p_descriptor->block.padded_size = 0; - } } return SPV_REFLECT_RESULT_SUCCESS; @@ -3595,6 +3676,7 @@ static SpvReflectResult ParsePushConstantBlocks( SpvReflectBlockVariable* p_push_constant = &p_module->push_constant_blocks[push_constant_index]; p_push_constant->spirv_id = p_node->result_id; SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, p_push_constant); + p_push_constant->name = p_node->name; if (result != SPV_REFLECT_RESULT_SUCCESS) { return result; } @@ -3959,6 +4041,10 @@ static SpvReflectResult CreateShaderModule( result = ParsePushConstantBlocks(&parser, p_module); SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); } + if (result == SPV_REFLECT_RESULT_SUCCESS) { + result = ParseSpecializationConstants(&parser, p_module); + SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); + } if (result == SPV_REFLECT_RESULT_SUCCESS) { result = ParseEntryPoints(&parser, p_module); SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS); @@ -4122,6 +4208,7 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module) } SafeFree(p_module->capabilities); SafeFree(p_module->entry_points); + SafeFree(p_module->specialization_constants); // Push constants for (size_t i = 0; i < p_module->push_constant_block_count; ++i) { @@ -4611,6 +4698,36 @@ SpvReflectResult spvReflectEnumerateEntryPointPushConstantBlocks( return SPV_REFLECT_RESULT_SUCCESS; } +SpvReflectResult spvReflectEnumerateSpecializationConstants( + const SpvReflectShaderModule* p_module, + uint32_t* p_count, + SpvReflectSpecializationConstant** pp_constants +) +{ + if (IsNull(p_module)) { + return SPV_REFLECT_RESULT_ERROR_NULL_POINTER; + } + if (IsNull(p_count)) { + return SPV_REFLECT_RESULT_ERROR_NULL_POINTER; + } + + if (IsNotNull(pp_constants)) { + if (*p_count != p_module->specialization_constant_count) { + return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH; + } + + for (uint32_t index = 0; index < *p_count; ++index) { + SpvReflectSpecializationConstant *p_constant = (SpvReflectSpecializationConstant*)&p_module->specialization_constants[index]; + pp_constants[index] = p_constant; + } + } + else { + *p_count = p_module->specialization_constant_count; + } + + return SPV_REFLECT_RESULT_SUCCESS; +} + const SpvReflectDescriptorBinding* spvReflectGetDescriptorBinding( const SpvReflectShaderModule* p_module, uint32_t binding_number, diff --git a/spirv_reflect.h b/spirv_reflect.h index 1a76aad9..1d4f4c30 100644 --- a/spirv_reflect.h +++ b/spirv_reflect.h @@ -332,6 +332,26 @@ typedef struct SpvReflectTypeDescription { struct SpvReflectTypeDescription* members; } SpvReflectTypeDescription; +/*! @struct SpvReflectSpecializationConstant +*/ + +typedef enum SpvReflectSpecializationConstantType { + SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL = 0, + SPV_REFLECT_SPECIALIZATION_CONSTANT_INT = 1, + SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT = 2, +} SpvReflectSpecializationConstantType; + +typedef struct SpvReflectSpecializationConstant { + const char* name; + uint32_t spirv_id; + uint32_t constant_id; + uint32_t size; + SpvReflectSpecializationConstantType constant_type; + union { + float float_value; + uint32_t int_bool_value; + } default_value; +} SpvReflectSpecializationConstant; /*! @struct SpvReflectInterfaceVariable @@ -503,6 +523,9 @@ typedef struct SpvReflectShaderModule { SpvReflectInterfaceVariable* interface_variables; // Uses value(s) from first entry point uint32_t push_constant_block_count; // Uses value(s) from first entry point SpvReflectBlockVariable* push_constant_blocks; // Uses value(s) from first entry point + uint32_t specialization_constant_count; // Uses value(s) from first entry point + SpvReflectSpecializationConstant* specialization_constants; // Uses value(s) from first entry point + struct Internal { SpvReflectModuleFlags module_flags; @@ -914,6 +937,33 @@ SpvReflectResult spvReflectEnumerateEntryPointPushConstantBlocks( ); +/*! @fn spvReflectEnumerateSpecializationConstants + @brief Note: If the module contains multiple entry points, this will only get + the specialization constant blocks for the first one. + @param p_module Pointer to an instance of SpvReflectShaderModule. + @param p_count If pp_blocks is NULL, the module's specialization constant + count will be stored here. + If pp_blocks is not NULL, *p_count must + contain the module's specialization constant count. + @param pp_constants If NULL, the module's specialization constant count + will be written to *p_count. + If non-NULL, pp_blocks must point to an + array with *p_count entries, where pointers to + the module's specialization constant blocks will be written. + The caller must not free the variables written + to this array. + @return If successful, returns SPV_REFLECT_RESULT_SUCCESS. + Otherwise, the error code indicates the cause of the + failure. + +*/ +SpvReflectResult spvReflectEnumerateSpecializationConstants( + const SpvReflectShaderModule* p_module, + uint32_t* p_count, + SpvReflectSpecializationConstant** pp_constants +); + + /*! @fn spvReflectGetDescriptorBinding @param p_module Pointer to an instance of SpvReflectShaderModule. @@ -1504,6 +1554,7 @@ class ShaderModule { SpvReflectResult EnumeratePushConstants(uint32_t* p_count, SpvReflectBlockVariable** pp_blocks) const { return EnumeratePushConstantBlocks(p_count, pp_blocks); } + SpvReflectResult EnumerateSpecializationConstants(uint32_t* p_count, SpvReflectSpecializationConstant** pp_constants) const; const SpvReflectDescriptorBinding* GetDescriptorBinding(uint32_t binding_number, uint32_t set_number, SpvReflectResult* p_result = nullptr) const; const SpvReflectDescriptorBinding* GetEntryPointDescriptorBinding(const char* entry_point, uint32_t binding_number, uint32_t set_number, SpvReflectResult* p_result = nullptr) const; @@ -1951,6 +2002,27 @@ inline SpvReflectResult ShaderModule::EnumeratePushConstantBlocks( return m_result; } +/*! @fn EnumerateSpecializationConstants + + @param p_count + @param pp_constants + @return + +*/ +inline SpvReflectResult ShaderModule::EnumerateSpecializationConstants( + uint32_t* p_count, + SpvReflectSpecializationConstant** pp_constants +) const +{ + m_result = spvReflectEnumerateSpecializationConstants( + &m_module, + p_count, + pp_constants + ); + return m_result; +} + + /*! @fn EnumerateEntryPointPushConstantBlocks @param entry_point