Skip to content

Commit

Permalink
[SPIR-V] Make access chains more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Jul 21, 2024
1 parent 9c64573 commit 59e02a7
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 18 deletions.
85 changes: 77 additions & 8 deletions Sources/backends/spirv.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,56 @@ typedef enum spirv_opcode {
SPIRV_OPCODE_LABEL = 248
} spirv_opcode;

static type_id find_access_type(int *indices, int indices_size, type_id base_type) {
if (indices_size == 1) {
if (base_type == float2_id || base_type == float3_id || base_type == float4_id) {
return float_id;
}
else {
type *t = get_type(base_type);
assert(indices[0] < t->members.size);
return t->members.m[indices[0]].type.type;
}
}
else {
type *t = get_type(base_type);
assert(indices[0] < t->members.size);
return find_access_type(&indices[1], indices_size - 1, t->members.m[indices[0]].type.type);
}
}

static void vector_member_indices(int *input_indices, int *output_indices, int indices_size, type_id base_type) {
if (base_type == float2_id || base_type == float3_id || base_type == float4_id) {
type *t = get_type(base_type);

if (strcmp(get_name(t->members.m[input_indices[0]].name), "x") == 0 || strcmp(get_name(t->members.m[input_indices[0]].name), "r") == 0) {
output_indices[0] = 0;
}
else if (strcmp(get_name(t->members.m[input_indices[0]].name), "y") == 0 || strcmp(get_name(t->members.m[input_indices[0]].name), "g") == 0) {
output_indices[0] = 1;
}
else if (strcmp(get_name(t->members.m[input_indices[0]].name), "z") == 0 || strcmp(get_name(t->members.m[input_indices[0]].name), "b") == 0) {
output_indices[0] = 2;
}
else if (strcmp(get_name(t->members.m[input_indices[0]].name), "w") == 0 || strcmp(get_name(t->members.m[input_indices[0]].name), "a") == 0) {
output_indices[0] = 3;
}
else {
// assert(false);
output_indices[0] = 0; // TODO
}
}
else {
output_indices[0] = input_indices[0];
}

if (indices_size > 1) {
type *t = get_type(base_type);
assert(input_indices[0] < t->members.size);
vector_member_indices(&input_indices[1], &output_indices[1], indices_size - 1, t->members.m[input_indices[0]].type.type);
}
}

typedef enum addressing_model { ADDRESSING_MODEL_LOGICAL = 0 } addressing_model;

typedef enum memory_model { MEMORY_MODEL_SIMPLE = 0, MEMORY_MODEL_GLSL450 = 1 } memory_model;
Expand Down Expand Up @@ -477,7 +527,8 @@ static void write_vertex_input_decorations(instructions_buffer *instructions, ui
}
}

static uint32_t write_op_function_preallocated(instructions_buffer *instructions, uint32_t result_type, function_control control, uint32_t function_type, uint32_t result) {
static uint32_t write_op_function_preallocated(instructions_buffer *instructions, uint32_t result_type, function_control control, uint32_t function_type,
uint32_t result) {
uint32_t operands[] = {result_type, result, (uint32_t)control, function_type};
write_instruction(instructions, WORD_COUNT(operands), SPIRV_OPCODE_FUNCTION, operands);
return result;
Expand Down Expand Up @@ -611,8 +662,8 @@ static uint32_t convert_kong_index_to_spirv_index(uint64_t index) {
static uint32_t output_var;
static uint32_t input_var;

static void write_function(instructions_buffer *instructions, function *f, uint32_t function_id, shader_stage stage, bool main, type_id input, uint32_t input_var, type_id output,
uint32_t output_var) {
static void write_function(instructions_buffer *instructions, function *f, uint32_t function_id, shader_stage stage, bool main, type_id input,
uint32_t input_var, type_id output, uint32_t output_var) {
write_op_function_preallocated(instructions, void_type, FUNCTION_CONTROL_NONE, void_function_type, function_id);
write_label(instructions);

Expand Down Expand Up @@ -644,7 +695,8 @@ static void write_function(instructions_buffer *instructions, function *f, uint3
opcode *o = (opcode *)&data[index];
switch (o->type) {
case OPCODE_VAR: {
uint32_t result = write_op_variable(instructions, convert_type_to_spirv_index(o->op_var.var.type.type), STORAGE_CLASS_FUNCTION);
uint32_t result =
write_op_variable(instructions, convert_pointer_type_to_spirv_index(o->op_var.var.type.type, STORAGE_CLASS_FUNCTION), STORAGE_CLASS_FUNCTION);
hmput(index_map, o->op_var.var.index, result);
break;
}
Expand Down Expand Up @@ -724,8 +776,25 @@ static void write_function(instructions_buffer *instructions, function *f, uint3
for (size_t i = 0; i < indices_size; ++i) {
indices[i] = (int)o->op_store_member.member_indices[i];
}
uint32_t pointer = write_op_access_chain(instructions, convert_type_to_spirv_index(o->op_store_member.to.type.type),
convert_kong_index_to_spirv_index(o->op_store_member.to.index), indices, indices_size);

type_id access_kong_type = find_access_type(indices, indices_size, o->op_store_member.to.type.type);

uint32_t access_type = 0;

switch (o->op_store_member.to.kind) {
case VARIABLE_LOCAL:
access_type = convert_pointer_type_to_spirv_index(access_kong_type, STORAGE_CLASS_FUNCTION);
break;
case VARIABLE_GLOBAL:
access_type = convert_pointer_type_to_spirv_index(access_kong_type, STORAGE_CLASS_OUTPUT);
break;
}

int spirv_indices[256];
vector_member_indices(indices, spirv_indices, indices_size, o->op_store_member.to.type.type);

uint32_t pointer =
write_op_access_chain(instructions, access_type, convert_kong_index_to_spirv_index(o->op_store_member.to.index), spirv_indices, indices_size);
write_op_store(instructions, pointer, convert_kong_index_to_spirv_index(o->op_store_member.from.index));
break;
}
Expand Down Expand Up @@ -807,8 +876,8 @@ static void write_function(instructions_buffer *instructions, function *f, uint3
write_function_end(instructions);
}

static void write_functions(instructions_buffer *instructions, function *main, uint32_t entry_point, shader_stage stage, type_id input, uint32_t input_var, type_id output,
uint32_t output_var) {
static void write_functions(instructions_buffer *instructions, function *main, uint32_t entry_point, shader_stage stage, type_id input, uint32_t input_var,
type_id output, uint32_t output_var) {
write_function(instructions, main, entry_point, stage, true, input, input_var, output, output_var);
}

Expand Down
21 changes: 11 additions & 10 deletions Sources/compiler.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ variable find_local_var(block *b, name_id name) {
variable var;
var.index = b->vars.v[i].variable_id;
var.type = b->vars.v[i].type;
var.kind = VARIABLE_LOCAL;
return var;
}
}
Expand All @@ -57,10 +58,11 @@ static uint64_t next_variable_id = 1;

variable all_variables[1024 * 1024];

variable allocate_variable(type_ref type) {
variable allocate_variable(type_ref type, variable_kind kind) {
variable v;
v.index = next_variable_id;
v.type = type;
v.kind = kind;
all_variables[v.index] = v;
++next_variable_id;
return v;
Expand Down Expand Up @@ -100,7 +102,7 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
case OPERATOR_MULTIPLY: {
variable right_var = emit_expression(code, parent, right);
variable left_var = emit_expression(code, parent, left);
variable result_var = allocate_variable(right_var.type);
variable result_var = allocate_variable(right_var.type, VARIABLE_LOCAL);

opcode o;
switch (e->binary.op) {
Expand Down Expand Up @@ -149,8 +151,7 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
case OPERATOR_MINUS_ASSIGN:
case OPERATOR_PLUS_ASSIGN:
case OPERATOR_DIVIDE_ASSIGN:
case OPERATOR_MULTIPLY_ASSIGN:
{
case OPERATOR_MULTIPLY_ASSIGN: {
variable v = emit_expression(code, parent, right);

switch (left->kind) {
Expand Down Expand Up @@ -301,7 +302,7 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
o.type = OPCODE_NOT;
o.size = OP_SIZE(o, op_not);
o.op_not.from = v;
o.op_not.to = allocate_variable(v.type);
o.op_not.to = allocate_variable(v.type, VARIABLE_LOCAL);
emit_op(code, &o);
return o.op_not.to;
}
Expand All @@ -323,7 +324,7 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
type_ref t;
init_type_ref(&t, NO_NAME);
t.type = float_id;
variable v = allocate_variable(t);
variable v = allocate_variable(t, VARIABLE_LOCAL);

opcode o;
o.type = OPCODE_LOAD_CONSTANT;
Expand Down Expand Up @@ -362,7 +363,7 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
type_ref t;
init_type_ref(&t, NO_NAME);
t.type = float4_id;
variable v = allocate_variable(t);
variable v = allocate_variable(t, VARIABLE_LOCAL);

opcode o;
o.type = OPCODE_CALL;
Expand All @@ -382,7 +383,7 @@ variable emit_expression(opcodes *code, block *parent, expression *e) {
return v;
}
case EXPRESSION_MEMBER: {
variable v = allocate_variable(e->type);
variable v = allocate_variable(e->type, VARIABLE_LOCAL);

opcode o;
o.type = OPCODE_LOAD_MEMBER;
Expand Down Expand Up @@ -533,7 +534,7 @@ void convert_globals(void) {
type_ref t;
init_type_ref(&t, NO_NAME);
t.type = g.type;
variable v = allocate_variable(t);
variable v = allocate_variable(t, VARIABLE_GLOBAL);
allocated_globals[allocated_globals_size].g = g;
allocated_globals[allocated_globals_size].variable_id = v.index;
allocated_globals_size += 1;
Expand All @@ -553,7 +554,7 @@ void convert_function_block(opcodes *code, struct statement *block) {
error(context, "Expected a block");
}
for (size_t i = 0; i < block->block.vars.size; ++i) {
variable var = allocate_variable(block->block.vars.v[i].type);
variable var = allocate_variable(block->block.vars.v[i].type, VARIABLE_LOCAL);
block->block.vars.v[i].variable_id = var.index;
}
for (size_t i = 0; i < block->block.statements.size; ++i) {
Expand Down
3 changes: 3 additions & 0 deletions Sources/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
#include <stddef.h>
#include <stdint.h>

typedef enum variable_kind { VARIABLE_GLOBAL, VARIABLE_LOCAL } variable_kind;

typedef struct variable {
variable_kind kind;
uint64_t index;
type_ref type;
} variable;
Expand Down
3 changes: 3 additions & 0 deletions tests/in/test.kong
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ fun pos(input: VertexIn): FragmentIn {
fun pixel(input: FragmentIn): float4 {
var color: float4;
color.r = 1.0;
color.g = 0.0;
color.b = 0.0;
color.a = 1.0;
return color;
}

Expand Down

0 comments on commit 59e02a7

Please sign in to comment.