Skip to content

Commit

Permalink
Write compute descriptor sets
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Sep 12, 2024
1 parent 0f2e011 commit bcbe5f5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 15 deletions.
41 changes: 38 additions & 3 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ static size_t all_descriptor_sets_count = 0;
static void write_root_signature(char *hlsl, size_t *offset) {
uint32_t cbv_index = 0;
uint32_t srv_index = 0;
uint32_t uav_index = 0;
uint32_t sampler_index = 0;

*offset += sprintf(&hlsl[*offset], "[RootSignature(\"RootFlags(ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT)");
Expand Down Expand Up @@ -355,8 +356,17 @@ static void write_root_signature(char *hlsl, size_t *offset) {
else {
*offset += sprintf(&hlsl[*offset], ", ");
}
*offset += sprintf(&hlsl[*offset], "SRV(t%i)", srv_index);
srv_index += 1;

attribute *write_attribute = find_attribute(&get_global(def->global)->attributes, add_name("write"));

if (write_attribute != NULL) {
*offset += sprintf(&hlsl[*offset], "UAV(u%i)", uav_index);
uav_index += 1;
}
else {
*offset += sprintf(&hlsl[*offset], "SRV(t%i)", srv_index);
srv_index += 1;
}
break;
}
}
Expand Down Expand Up @@ -493,7 +503,8 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
error(context, "Compute function requires a threads attribute with three parameters");
}

*offset += sprintf(&hlsl[*offset], "[numthreads(%i, %i, %i)] %s main(", (int)threads_attribute->parameters[0],
write_root_signature(hlsl, offset);
*offset += sprintf(&hlsl[*offset], "[numthreads(%i, %i, %i)]\n%s main(", (int)threads_attribute->parameters[0],
(int)threads_attribute->parameters[1], (int)threads_attribute->parameters[2], type_string(f->return_type.type));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
Expand Down Expand Up @@ -1263,6 +1274,30 @@ void hlsl_export(char *directory, api_kind d3d) {
for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (has_attribute(&f->attributes, add_name("compute"))) {
global_id all_globals[256];
size_t all_globals_size = 0;

find_referenced_globals(f, all_globals, &all_globals_size);

for (size_t global_index = 0; global_index < all_globals_size; ++global_index) {
global *g = get_global(all_globals[global_index]);
if (g->set != NULL) {
bool found = false;

for (size_t set_index = 0; set_index < all_descriptor_sets_count; ++set_index) {
if (all_descriptor_sets[set_index] == g->set) {
found = true;
break;
}
}

if (!found) {
all_descriptor_sets[all_descriptor_sets_count] = g->set;
all_descriptor_sets_count += 1;
}
}
}

compute_shaders[compute_shaders_size] = f;
compute_shaders_size += 1;
}
Expand Down
51 changes: 39 additions & 12 deletions Sources/integrations/kope.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ static char *type_string(type_id type) {
if (type == float4_id) {
return "kinc_vector4_t";
}
if (type == float3x3_id) {
return "kinc_matrix3x3_t";
}
if (type == float4x4_id) {
return "kinc_matrix4x4_t";
}
Expand Down Expand Up @@ -427,7 +430,7 @@ void kope_export(char *directory, api_kind api) {
for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (has_attribute(&f->attributes, add_name("compute"))) {
fprintf(output, "extern kinc_g4_compute_shader %s;\n\n", get_name(f->name));
fprintf(output, "extern kope_d3d12_pipeline %s;\n\n", get_name(f->name));
}
}

Expand Down Expand Up @@ -561,7 +564,7 @@ void kope_export(char *directory, api_kind api) {
if (api != API_OPENGL) {
bool has_matrices = false;
for (size_t j = 0; j < t->members.size; ++j) {
if (t->members.m[j].type.type == float4x4_id) {
if (t->members.m[j].type.type == float4x4_id || t->members.m[j].type.type == float3x3_id) {
has_matrices = true;
break;
}
Expand All @@ -574,6 +577,9 @@ void kope_export(char *directory, api_kind api) {
if (t->members.m[j].type.type == float4x4_id) {
fprintf(output, "\tkinc_matrix4x4_transpose(&data->%s);\n", get_name(t->members.m[j].name));
}
else if (t->members.m[j].type.type == float3x3_id) {
fprintf(output, "\tkinc_matrix3x3_transpose(&data->%s);\n", get_name(t->members.m[j].name));
}
}
fprintf(output, "\tkope_g5_buffer_unlock(buffer);\n");
}
Expand All @@ -583,6 +589,7 @@ void kope_export(char *directory, api_kind api) {
}
}

uint32_t descriptor_table_index = 0;
for (size_t set_index = 0; set_index < sets_count; ++set_index) {
descriptor_set *set = sets[set_index];

Expand Down Expand Up @@ -623,12 +630,20 @@ void kope_export(char *directory, api_kind api) {
fprintf(output, "\tset->%s = parameters->%s;\n", get_name(get_global(d.global)->name), get_name(get_global(d.global)->name));
other_index += 1;
break;
case DEFINITION_TEX2D:
fprintf(output, "\tkope_d3d12_descriptor_set_set_texture_view_srv(device, &set->set, parameters->%s, %" PRIu64 ");\n",
get_name(get_global(d.global)->name), other_index);
case DEFINITION_TEX2D: {
attribute *write_attribute = find_attribute(&get_global(d.global)->attributes, add_name("write"));
if (write_attribute != NULL) {
fprintf(output, "\tkope_d3d12_descriptor_set_set_texture_view_uav(device, &set->set, parameters->%s, %" PRIu64 ");\n",
get_name(get_global(d.global)->name), other_index);
}
else {
fprintf(output, "\tkope_d3d12_descriptor_set_set_texture_view_srv(device, &set->set, parameters->%s, %" PRIu64 ");\n",
get_name(get_global(d.global)->name), other_index);
}
fprintf(output, "\tset->%s = parameters->%s;\n", get_name(get_global(d.global)->name), get_name(get_global(d.global)->name));
other_index += 1;
break;
}
case DEFINITION_SAMPLER:
fprintf(output, "\tkope_d3d12_descriptor_set_set_sampler(device, &set->set, parameters->%s, %" PRIu64 ");\n",
get_name(get_global(d.global)->name), sampler_index);
Expand All @@ -646,13 +661,22 @@ void kope_export(char *directory, api_kind api) {
case DEFINITION_CONST_CUSTOM:
fprintf(output, "\tkope_d3d12_descriptor_set_prepare_cbv_buffer(list, set->%s);\n", get_name(get_global(d.global)->name));
break;
case DEFINITION_TEX2D:
fprintf(output, "\tkope_d3d12_descriptor_set_prepare_srv_texture(list, set->%s);\n", get_name(get_global(d.global)->name));
case DEFINITION_TEX2D: {
attribute *write_attribute = find_attribute(&get_global(d.global)->attributes, add_name("write"));
if (write_attribute != NULL) {
fprintf(output, "\tkope_d3d12_descriptor_set_prepare_uav_texture(list, set->%s);\n", get_name(get_global(d.global)->name));
}
else {
fprintf(output, "\tkope_d3d12_descriptor_set_prepare_srv_texture(list, set->%s);\n", get_name(get_global(d.global)->name));
}
break;
}
}
}
fprintf(output, "\n\tkope_d3d12_command_list_set_descriptor_table(list, 0, &set->set);\n");
fprintf(output, "\n\tkope_d3d12_command_list_set_descriptor_table(list, %i, &set->set);\n", descriptor_table_index);
fprintf(output, "}\n\n");

descriptor_table_index += (sampler_count > 0) ? 2 : 1;
}

for (type_id i = 0; get_type(i) != NULL; ++i) {
Expand All @@ -671,7 +695,7 @@ void kope_export(char *directory, api_kind api) {
for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (has_attribute(&f->attributes, add_name("compute"))) {
fprintf(output, "kinc_g4_compute_shader %s;\n", get_name(f->name));
fprintf(output, "kope_d3d12_pipeline %s;\n", get_name(f->name));
}
}

Expand All @@ -692,7 +716,7 @@ void kope_export(char *directory, api_kind api) {
for (type_id i = 0; get_type(i) != NULL; ++i) {
type *t = get_type(i);
if (!t->built_in && has_attribute(&t->attributes, add_name("pipe"))) {
fprintf(output, "\tkope_d3d12_pipeline_parameters %s_parameters = {0};\n\n", get_name(t->name));
fprintf(output, "\tkope_d3d12_render_pipeline_parameters %s_parameters = {0};\n\n", get_name(t->name));

name_id vertex_shader_name = NO_NAME;
name_id amplification_shader_name = NO_NAME;
Expand Down Expand Up @@ -913,7 +937,7 @@ void kope_export(char *directory, api_kind api) {
fprintf(output, "\t%s_parameters.fragment.targets[0].write_mask = 0xf;\n\n", get_name(t->name));
}

fprintf(output, "\tkope_d3d12_pipeline_init(&device->d3d12, &%s, &%s_parameters);\n\n", get_name(t->name), get_name(t->name));
fprintf(output, "\tkope_d3d12_render_pipeline_init(&device->d3d12, &%s, &%s_parameters);\n\n", get_name(t->name), get_name(t->name));

if (api == API_OPENGL) {
global_id globals[256];
Expand Down Expand Up @@ -941,7 +965,10 @@ void kope_export(char *directory, api_kind api) {
for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (has_attribute(&f->attributes, add_name("compute"))) {
fprintf(output, "\tkinc_g4_compute_shader_init(&%s, %s_code, %s_code_size);\n", get_name(f->name), get_name(f->name), get_name(f->name));
fprintf(output, "\tkope_d3d12_compute_pipeline_parameters %s_parameters;\n", get_name(f->name));
fprintf(output, "\t%s_parameters.shader.data = %s_code;\n", get_name(f->name), get_name(f->name));
fprintf(output, "\t%s_parameters.shader.size = %s_code_size;\n", get_name(f->name), get_name(f->name));
fprintf(output, "\tkope_d3d12_compute_pipeline_init(&device->d3d12, &%s, &%s_parameters);\n", get_name(f->name), get_name(f->name));
}
}

Expand Down

0 comments on commit bcbe5f5

Please sign in to comment.