Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for specialization constants. They are now automaticall… #197

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions common/output_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2201,6 +2201,19 @@ void SpvReflectToYaml::Write(std::ostream& os) {
<< SafeString(sm_.push_constant_blocks[i].name) << std::endl;
}

// uint32_t specialization_constant_count;
os << t1 << "specialization_constant_count: " << sm_.specialization_constant_count << ",\n";
// SpvReflectSpecializationConstant* specialization_constants;
os << t1 << "specialization_constants:" << std::endl;
for (uint32_t i = 0; i < sm_.specialization_constant_count; ++i) {
os << t2 << "- *sc" << i << " # " << SafeString(sm_.specialization_constants[i].name) << std::endl;
os << t3 << "spirv_id: " << sm_.specialization_constants[i].spirv_id << std::endl;
os << t3 << "constant_id: " << sm_.specialization_constants[i].constant_id << std::endl;
os << t3 << "size: " << sm_.specialization_constants[i].size << std::endl;
os << t3 << "default_value (as float): " << sm_.specialization_constants[i].default_value.float_value << std::endl;
os << t3 << "default_value (as int): " << sm_.specialization_constants[i].default_value.int_bool_value << std::endl;
}

if (verbosity_ >= 2) {
// struct Internal {
os << t1 << "_internal:" << std::endl;
Expand Down
133 changes: 125 additions & 8 deletions spirv_reflect.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ typedef struct SpvReflectPrvDecorations {
SpvReflectPrvNumberDecoration component;
SpvReflectPrvNumberDecoration offset;
SpvReflectPrvNumberDecoration uav_counter_buffer;
SpvReflectPrvNumberDecoration specialization_constant;
SpvReflectPrvStringDecoration semantic;
uint32_t array_stride;
uint32_t matrix_stride;
Expand Down Expand Up @@ -563,7 +564,7 @@ static uint32_t FindBaseId(SpvReflectPrvParser* p_parser,
SpvReflectPrvAccessChain* ac) {
uint32_t base_id = ac->base_id;
SpvReflectPrvNode* base_node = FindNode(p_parser, base_id);
while (base_node->op != SpvOpVariable) {
while (base_node && base_node->op != SpvOpVariable) {
assert(base_node->op == SpvOpLoad);
UNCHECKED_READU32(p_parser, base_node->word_offset + 3, base_id);
SpvReflectPrvAccessChain* base_ac = FindAccessChain(p_parser, base_id);
Expand Down Expand Up @@ -706,6 +707,7 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser)
p_parser->nodes[i].decorations.offset.value = (uint32_t)INVALID_VALUE;
p_parser->nodes[i].decorations.uav_counter_buffer.value = (uint32_t)INVALID_VALUE;
p_parser->nodes[i].decorations.built_in = (SpvBuiltIn)INVALID_VALUE;
p_parser->nodes[i].decorations.specialization_constant.value = (SpvBuiltIn)INVALID_VALUE;
}
// Mark source file id node
p_parser->source_file_id = (uint32_t)INVALID_VALUE;
Expand Down Expand Up @@ -917,7 +919,12 @@ static SpvReflectResult ParseNodes(SpvReflectPrvParser* p_parser)

case SpvOpSpecConstantTrue:
case SpvOpSpecConstantFalse:
case SpvOpSpecConstant:
case SpvOpSpecConstant: {
CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id);
CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
p_node->is_type = true;
}
break;
case SpvOpSpecConstantComposite:
case SpvOpSpecConstantOp: {
CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id);
Expand Down Expand Up @@ -1433,6 +1440,7 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser)
}
break;
case SpvDecorationRelaxedPrecision:
case SpvDecorationSpecId:
case SpvDecorationBlock:
case SpvDecorationBufferBlock:
case SpvDecorationColMajor:
Expand Down Expand Up @@ -1590,6 +1598,13 @@ static SpvReflectResult ParseDecorations(SpvReflectPrvParser* p_parser)
}
break;

case SpvDecorationSpecId: {
uint32_t word_offset = p_node->word_offset + member_offset+ 3;
CHECKED_READU32(p_parser, word_offset, p_target_decorations->specialization_constant.value);
p_target_decorations->specialization_constant.word_offset = word_offset;
}
break;

case SpvReflectDecorationHlslCounterBufferGOOGLE: {
uint32_t word_offset = p_node->word_offset + member_offset+ 3;
CHECKED_READU32(p_parser, word_offset, p_target_decorations->uav_counter_buffer.value);
Expand Down Expand Up @@ -1898,6 +1913,12 @@ static SpvReflectResult ParseType(
p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_ACCELERATION_STRUCTURE;
}
break;

case SpvOpSpecConstantTrue:
case SpvOpSpecConstantFalse:
case SpvOpSpecConstant: {
}
break;
}

if (result == SPV_REFLECT_RESULT_SUCCESS) {
Expand All @@ -1912,6 +1933,71 @@ static SpvReflectResult ParseType(
return result;
}

static SpvReflectResult ParseSpecializationConstants(SpvReflectPrvParser* p_parser, SpvReflectShaderModule* p_module)
{
p_module->specialization_constant_count = 0;
p_module->specialization_constants = NULL;
for (size_t i = 0; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
if (p_node->op == SpvOpSpecConstantTrue || p_node->op == SpvOpSpecConstantFalse || p_node->op == SpvOpSpecConstant) {
p_module->specialization_constant_count++;
}
}

if (p_module->specialization_constant_count == 0) {
return SPV_REFLECT_RESULT_SUCCESS;
}

p_module->specialization_constants = (SpvReflectSpecializationConstant*)calloc(p_module->specialization_constant_count, sizeof(SpvReflectSpecializationConstant));

uint32_t index = 0;

for (size_t i = 0; i < p_parser->node_count; ++i) {
SpvReflectPrvNode* p_node = &(p_parser->nodes[i]);
switch(p_node->op) {
default: continue;
case SpvOpSpecConstantTrue: {
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL;
p_module->specialization_constants[index].default_value.int_bool_value = 1;
p_module->specialization_constants[index].size = sizeof(uint32_t);
} break;
case SpvOpSpecConstantFalse: {
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL;
p_module->specialization_constants[index].default_value.int_bool_value = 0;
p_module->specialization_constants[index].size = sizeof(uint32_t);
} break;
case SpvOpSpecConstant: {
SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
uint32_t element_type_id = (uint32_t)INVALID_VALUE;
uint32_t default_value = 0;
IF_READU32(result, p_parser, p_node->word_offset + 1, element_type_id);
IF_READU32(result, p_parser, p_node->word_offset + 3, default_value);

SpvReflectPrvNode* p_next_node = FindNode(p_parser, element_type_id);

if (p_next_node->op == SpvOpTypeInt) {
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_INT;
p_module->specialization_constants[index].size = sizeof(int32_t);
} else if (p_next_node->op == SpvOpTypeFloat) {
p_module->specialization_constants[index].constant_type = SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT;
p_module->specialization_constants[index].size = sizeof(float);
} else {
return SPV_REFLECT_RESULT_ERROR_PARSE_FAILED;
}

p_module->specialization_constants[index].default_value.int_bool_value = default_value; //bits are the same for int and float
} break;
}

p_module->specialization_constants[index].name = p_node->name;
p_module->specialization_constants[index].constant_id = p_node->decorations.specialization_constant.value;
p_module->specialization_constants[index].spirv_id = p_node->result_id;
index++;
}

return SPV_REFLECT_RESULT_SUCCESS;
}

static SpvReflectResult ParseTypes(
SpvReflectPrvParser* p_parser,
SpvReflectShaderModule* p_module)
Expand Down Expand Up @@ -2703,7 +2789,7 @@ static SpvReflectResult ParseDescriptorBlockVariableUsage(
else if (IsPointerToPointer(p_parser, p_access_chain->result_type_id)) {
// Remember block var for this access chain for downstream dereference
p_access_chain->block_var = p_member_var;
}
}
else {
// Clear UNUSED flag for remaining variables
MarkSelfAndAllMemberVarsAsUsed(p_member_var);
Expand Down Expand Up @@ -2765,11 +2851,6 @@ static SpvReflectResult ParseDescriptorBlocks(
if (result != SPV_REFLECT_RESULT_SUCCESS) {
return result;
}

if (is_parent_rta) {
p_descriptor->block.size = 0;
p_descriptor->block.padded_size = 0;
}
}

return SPV_REFLECT_RESULT_SUCCESS;
Expand Down Expand Up @@ -3595,6 +3676,7 @@ static SpvReflectResult ParsePushConstantBlocks(
SpvReflectBlockVariable* p_push_constant = &p_module->push_constant_blocks[push_constant_index];
p_push_constant->spirv_id = p_node->result_id;
SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, p_push_constant);
p_push_constant->name = p_node->name;
if (result != SPV_REFLECT_RESULT_SUCCESS) {
return result;
}
Expand Down Expand Up @@ -3959,6 +4041,10 @@ static SpvReflectResult CreateShaderModule(
result = ParsePushConstantBlocks(&parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
}
if (result == SPV_REFLECT_RESULT_SUCCESS) {
result = ParseSpecializationConstants(&parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
}
if (result == SPV_REFLECT_RESULT_SUCCESS) {
result = ParseEntryPoints(&parser, p_module);
SPV_REFLECT_ASSERT(result == SPV_REFLECT_RESULT_SUCCESS);
Expand Down Expand Up @@ -4122,6 +4208,7 @@ void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module)
}
SafeFree(p_module->capabilities);
SafeFree(p_module->entry_points);
SafeFree(p_module->specialization_constants);

// Push constants
for (size_t i = 0; i < p_module->push_constant_block_count; ++i) {
Expand Down Expand Up @@ -4611,6 +4698,36 @@ SpvReflectResult spvReflectEnumerateEntryPointPushConstantBlocks(
return SPV_REFLECT_RESULT_SUCCESS;
}

SpvReflectResult spvReflectEnumerateSpecializationConstants(
const SpvReflectShaderModule* p_module,
uint32_t* p_count,
SpvReflectSpecializationConstant** pp_constants
)
{
if (IsNull(p_module)) {
return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
}
if (IsNull(p_count)) {
return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
}

if (IsNotNull(pp_constants)) {
if (*p_count != p_module->specialization_constant_count) {
return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
}

for (uint32_t index = 0; index < *p_count; ++index) {
SpvReflectSpecializationConstant *p_constant = (SpvReflectSpecializationConstant*)&p_module->specialization_constants[index];
pp_constants[index] = p_constant;
}
}
else {
*p_count = p_module->specialization_constant_count;
}

return SPV_REFLECT_RESULT_SUCCESS;
}

const SpvReflectDescriptorBinding* spvReflectGetDescriptorBinding(
const SpvReflectShaderModule* p_module,
uint32_t binding_number,
Expand Down
72 changes: 72 additions & 0 deletions spirv_reflect.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,26 @@ typedef struct SpvReflectTypeDescription {
struct SpvReflectTypeDescription* members;
} SpvReflectTypeDescription;

/*! @struct SpvReflectSpecializationConstant
*/

typedef enum SpvReflectSpecializationConstantType {
SPV_REFLECT_SPECIALIZATION_CONSTANT_BOOL = 0,
SPV_REFLECT_SPECIALIZATION_CONSTANT_INT = 1,
SPV_REFLECT_SPECIALIZATION_CONSTANT_FLOAT = 2,
} SpvReflectSpecializationConstantType;

typedef struct SpvReflectSpecializationConstant {
const char* name;
uint32_t spirv_id;
uint32_t constant_id;
uint32_t size;
SpvReflectSpecializationConstantType constant_type;
union {
float float_value;
uint32_t int_bool_value;
} default_value;
} SpvReflectSpecializationConstant;

/*! @struct SpvReflectInterfaceVariable

Expand Down Expand Up @@ -503,6 +523,9 @@ typedef struct SpvReflectShaderModule {
SpvReflectInterfaceVariable* interface_variables; // Uses value(s) from first entry point
uint32_t push_constant_block_count; // Uses value(s) from first entry point
SpvReflectBlockVariable* push_constant_blocks; // Uses value(s) from first entry point
uint32_t specialization_constant_count; // Uses value(s) from first entry point
SpvReflectSpecializationConstant* specialization_constants; // Uses value(s) from first entry point


struct Internal {
SpvReflectModuleFlags module_flags;
Expand Down Expand Up @@ -914,6 +937,33 @@ SpvReflectResult spvReflectEnumerateEntryPointPushConstantBlocks(
);


/*! @fn spvReflectEnumerateSpecializationConstants
@brief Note: If the module contains multiple entry points, this will only get
the specialization constant blocks for the first one.
@param p_module Pointer to an instance of SpvReflectShaderModule.
@param p_count If pp_blocks is NULL, the module's specialization constant
count will be stored here.
If pp_blocks is not NULL, *p_count must
contain the module's specialization constant count.
@param pp_constants If NULL, the module's specialization constant count
will be written to *p_count.
If non-NULL, pp_blocks must point to an
array with *p_count entries, where pointers to
the module's specialization constant blocks will be written.
The caller must not free the variables written
to this array.
@return If successful, returns SPV_REFLECT_RESULT_SUCCESS.
Otherwise, the error code indicates the cause of the
failure.

*/
SpvReflectResult spvReflectEnumerateSpecializationConstants(
const SpvReflectShaderModule* p_module,
uint32_t* p_count,
SpvReflectSpecializationConstant** pp_constants
);


/*! @fn spvReflectGetDescriptorBinding

@param p_module Pointer to an instance of SpvReflectShaderModule.
Expand Down Expand Up @@ -1504,6 +1554,7 @@ class ShaderModule {
SpvReflectResult EnumeratePushConstants(uint32_t* p_count, SpvReflectBlockVariable** pp_blocks) const {
return EnumeratePushConstantBlocks(p_count, pp_blocks);
}
SpvReflectResult EnumerateSpecializationConstants(uint32_t* p_count, SpvReflectSpecializationConstant** pp_constants) const;

const SpvReflectDescriptorBinding* GetDescriptorBinding(uint32_t binding_number, uint32_t set_number, SpvReflectResult* p_result = nullptr) const;
const SpvReflectDescriptorBinding* GetEntryPointDescriptorBinding(const char* entry_point, uint32_t binding_number, uint32_t set_number, SpvReflectResult* p_result = nullptr) const;
Expand Down Expand Up @@ -1951,6 +2002,27 @@ inline SpvReflectResult ShaderModule::EnumeratePushConstantBlocks(
return m_result;
}

/*! @fn EnumerateSpecializationConstants

@param p_count
@param pp_constants
@return

*/
inline SpvReflectResult ShaderModule::EnumerateSpecializationConstants(
uint32_t* p_count,
SpvReflectSpecializationConstant** pp_constants
) const
{
m_result = spvReflectEnumerateSpecializationConstants(
&m_module,
p_count,
pp_constants
);
return m_result;
}


/*! @fn EnumerateEntryPointPushConstantBlocks

@param entry_point
Expand Down