Skip to content

Commit

Permalink
Compile a mesh shader
Browse files Browse the repository at this point in the history
  • Loading branch information
RobDangerous committed Aug 27, 2024
1 parent 5bb471c commit 332513b
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 47 deletions.
10 changes: 6 additions & 4 deletions Sources/backends/d3d12.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ static const wchar_t *shader_string(shader_stage stage) {
return L"cs_6_0";
case SHADER_STAGE_RAY_GENERATION:
return L"lib_6_3";
case SHADER_STAGE_MESH:
return L"ms_6_5";
default: {
debug_context context = {0};
error(context, "Unsupported shader stage/version combination");
Expand All @@ -44,10 +46,10 @@ int compile_hlsl_to_d3d12(const char *source, uint8_t **output, size_t *outputle
L"-T", shader_string(stage), // target
L"-Zi", // enable debug info
// L"-D", L"MYDEFINE=1", // a single define
// L"-Fo", L"myshader.bin", // optional. stored in the pdb.
// L"-Fd", L"myshader.pdb", // the file name of the pdb. This must either be supplied or the auto generated file name must be used
// L"-D", L"__XBOX_STRIP_DXIL", // strip DXIL
// L"-Qstrip_reflect", // strip reflection into a seperate blob
// L"-Fo", L"myshader.bin", // optional. stored in the pdb.
// L"-Fd", L"myshader.pdb", // the file name of the pdb. This must either be supplied or the auto generated file name must be used
// L"-D", L"__XBOX_STRIP_DXIL", // strip DXIL
// L"-Qstrip_reflect", // strip reflection into a seperate blob
};

DxcBuffer source_buffer;
Expand Down
78 changes: 75 additions & 3 deletions Sources/backends/hlsl.c
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,33 @@ static void write_functions(char *hlsl, size_t *offset, shader_stage stage, func
}
*offset += sprintf(&hlsl[*offset], ") {\n");
}
else if (stage == SHADER_STAGE_MESH) {
attribute *topology_attribute = find_attribute(&f->attributes, add_name("topology"));
if (topology_attribute == NULL || topology_attribute->paramters_count != 1 || topology_attribute->parameters[0] != 0) {
debug_context context = {0};
error(context, "Mesh function requires a threads attribute with one parameter which has to be \"triangle\"");
}

attribute *threads_attribute = find_attribute(&f->attributes, add_name("threads"));
if (threads_attribute == NULL || threads_attribute->paramters_count != 3) {
debug_context context = {0};
error(context, "Mesh function requires a threads attribute with three parameters");
}

*offset += sprintf(&hlsl[*offset], "[outputtopology(\"triangle\")][numthreads(%i, %i, %i)] %s main(", (int)threads_attribute->parameters[0],
(int)threads_attribute->parameters[1], (int)threads_attribute->parameters[2], type_string(f->return_type.type));
for (uint8_t parameter_index = 0; parameter_index < f->parameters_size; ++parameter_index) {
if (parameter_index == 0) {
*offset +=
sprintf(&hlsl[*offset], "%s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
else {
*offset +=
sprintf(&hlsl[*offset], ", %s _%" PRIu64, type_string(f->parameter_types[parameter_index].type), parameter_ids[parameter_index]);
}
}
*offset += sprintf(&hlsl[*offset], ") {\n");
}
else {
debug_context context = {0};
error(context, "Unsupported shader stage");
Expand Down Expand Up @@ -645,6 +672,34 @@ static void hlsl_export_vertex(char *directory, api_kind d3d, function *main) {
write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

static void hlsl_export_mesh(char *directory, function *main) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
size_t offset = 0;

write_types(hlsl, &offset, SHADER_STAGE_MESH, NO_TYPE, NO_TYPE, main, NULL, 0);

write_globals(hlsl, &offset, main, NULL, 0);

write_functions(hlsl, &offset, SHADER_STAGE_MESH, main, NULL, 0);

char *output = NULL;
size_t output_size = 0;
int result = compile_hlsl_to_d3d12(hlsl, &output, &output_size, SHADER_STAGE_MESH, false);

debug_context context = {0};
check(result == 0, context, "HLSL compilation failed");

char *name = get_name(main->name);

char filename[512];
sprintf(filename, "kong_%s", name);

char var_name[256];
sprintf(var_name, "%s_code", name);

write_bytecode(hlsl, directory, filename, var_name, output, output_size);
}

static void hlsl_export_fragment(char *directory, api_kind d3d, function *main) {
char *hlsl = (char *)calloc(1024 * 1024, 1);
size_t offset = 0;
Expand Down Expand Up @@ -813,35 +868,46 @@ void hlsl_export(char *directory, api_kind d3d) {
function *vertex_shaders[256];
size_t vertex_shaders_size = 0;

function *mesh_shaders[256];
size_t mesh_shaders_size = 0;

function *fragment_shaders[256];
size_t fragment_shaders_size = 0;

for (type_id i = 0; get_type(i) != NULL; ++i) {
type *t = get_type(i);
if (!t->built_in && has_attribute(&t->attributes, add_name("pipe"))) {
name_id vertex_shader_name = NO_NAME;
name_id mesh_shader_name = NO_NAME;
name_id fragment_shader_name = NO_NAME;

for (size_t j = 0; j < t->members.size; ++j) {
if (t->members.m[j].name == add_name("vertex")) {
vertex_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("mesh")) {
mesh_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("fragment")) {
fragment_shader_name = t->members.m[j].value.identifier;
}
}

debug_context context = {0};
check(vertex_shader_name != NO_NAME, context, "vertex shader missing");
check(vertex_shader_name != NO_NAME || mesh_shader_name != NO_NAME, context, "vertex or mesh shader missing");
check(fragment_shader_name != NO_NAME, context, "fragment shader missing");

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->name == vertex_shader_name) {
if (vertex_shader_name != NO_NAME && f->name == vertex_shader_name) {
vertex_shaders[vertex_shaders_size] = f;
vertex_shaders_size += 1;
}
else if (f->name == fragment_shader_name) {
if (mesh_shader_name != NO_NAME && f->name == mesh_shader_name) {
mesh_shaders[mesh_shaders_size] = f;
mesh_shaders_size += 1;
}
if (f->name == fragment_shader_name) {
fragment_shaders[fragment_shaders_size] = f;
fragment_shaders_size += 1;
}
Expand Down Expand Up @@ -906,6 +972,12 @@ void hlsl_export(char *directory, api_kind d3d) {
hlsl_export_vertex(directory, d3d, vertex_shaders[i]);
}

if (d3d == API_DIRECT3D12) {
for (size_t i = 0; i < mesh_shaders_size; ++i) {
hlsl_export_mesh(directory, mesh_shaders[i]);
}
}

for (size_t i = 0; i < fragment_shaders_size; ++i) {
hlsl_export_fragment(directory, d3d, fragment_shaders[i]);
}
Expand Down
95 changes: 56 additions & 39 deletions Sources/integrations/kinc.c
Original file line number Diff line number Diff line change
Expand Up @@ -187,33 +187,42 @@ void kinc_export(char *directory, api_kind api) {
type *t = get_type(i);
if (!t->built_in && has_attribute(&t->attributes, add_name("pipe"))) {
name_id vertex_shader_name = NO_NAME;
name_id mesh_shader_name = NO_NAME;

for (size_t j = 0; j < t->members.size; ++j) {
if (t->members.m[j].name == add_name("vertex")) {
debug_context context = {0};
check(t->members.m[j].value.kind == TOKEN_IDENTIFIER, context, "vertex expects an identifier");
vertex_shader_name = t->members.m[j].value.identifier;
}
if (t->members.m[j].name == add_name("mesh")) {
debug_context context = {0};
check(t->members.m[j].value.kind == TOKEN_IDENTIFIER, context, "mesh expects an identifier");
mesh_shader_name = t->members.m[j].value.identifier;
}
}

debug_context context = {0};
check(vertex_shader_name != NO_NAME, context, "No vertex shader name found");
check(vertex_shader_name != NO_NAME || mesh_shader_name != NO_NAME, context, "No vertex or mesh shader name found");

type_id vertex_input = NO_TYPE;
if (vertex_shader_name != NO_NAME) {

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->name == vertex_shader_name) {
check(f->parameters_size > 0, context, "Vertex function requires at least one parameter");
vertex_input = f->parameter_types[0].type;
break;
type_id vertex_input = NO_TYPE;

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->name == vertex_shader_name) {
check(f->parameters_size > 0, context, "Vertex function requires at least one parameter");
vertex_input = f->parameter_types[0].type;
break;
}
}
}

check(vertex_input != NO_TYPE, context, "No vertex input found");
check(vertex_input != NO_TYPE, context, "No vertex input found");

vertex_inputs[vertex_inputs_size] = vertex_input;
vertex_inputs_size += 1;
vertex_inputs[vertex_inputs_size] = vertex_input;
vertex_inputs_size += 1;
}
}
}

Expand Down Expand Up @@ -454,6 +463,7 @@ void kinc_export(char *directory, api_kind api) {
fprintf(output, "\tkinc_g4_pipeline_init(&%s);\n\n", get_name(t->name));

name_id vertex_shader_name = NO_NAME;
name_id mesh_shader_name = NO_NAME;
name_id fragment_shader_name = NO_NAME;

for (size_t j = 0; j < t->members.size; ++j) {
Expand All @@ -470,6 +480,9 @@ void kinc_export(char *directory, api_kind api) {
fprintf(output, "\t%s.vertex_shader = &%s;\n\n", get_name(t->name), get_name(t->members.m[j].value.identifier));
vertex_shader_name = t->members.m[j].value.identifier;
}
if (t->members.m[j].name == add_name("mesh")) {
mesh_shader_name = t->members.m[j].value.identifier;
}
else if (t->members.m[j].name == add_name("fragment")) {
if (api == API_METAL || api == API_WEBGPU) {
fprintf(output, "\tkinc_g4_shader_init(&%s, \"%s\", 0, KINC_G4_SHADER_TYPE_FRAGMENT);\n",
Expand Down Expand Up @@ -538,7 +551,7 @@ void kinc_export(char *directory, api_kind api) {

{
debug_context context = {0};
check(vertex_shader_name != NO_NAME, context, "No vertex shader name found");
check(vertex_shader_name != NO_NAME || mesh_shader_name != NO_NAME, context, "No vertex or mesh shader name found");
check(fragment_shader_name != NO_NAME, context, "No fragment shader name found");
}

Expand All @@ -547,14 +560,16 @@ void kinc_export(char *directory, api_kind api) {

type_id vertex_input = NO_TYPE;

for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->name == vertex_shader_name) {
vertex_function = f;
debug_context context = {0};
check(f->parameters_size > 0, context, "Vertex function requires at least one parameter");
vertex_input = f->parameter_types[0].type;
break;
if (vertex_shader_name != NO_NAME) {
for (function_id i = 0; get_function(i) != NULL; ++i) {
function *f = get_function(i);
if (f->name == vertex_shader_name) {
vertex_function = f;
debug_context context = {0};
check(f->parameters_size > 0, context, "Vertex function requires at least one parameter");
vertex_input = f->parameter_types[0].type;
break;
}
}
}

Expand All @@ -568,31 +583,33 @@ void kinc_export(char *directory, api_kind api) {

{
debug_context context = {0};
check(vertex_function != NULL, context, "Vertex function not found");
check(vertex_shader_name == NO_NAME || vertex_function != NULL, context, "Vertex function not found");
check(fragment_function != NULL, context, "Fragment function not found");
check(vertex_input != NO_TYPE, context, "No vertex input found");
check(vertex_function == NULL || vertex_input != NO_TYPE, context, "No vertex input found");
}

for (type_id i = 0; get_type(i) != NULL; ++i) {
if (i == vertex_input) {
type *t = get_type(i);
fprintf(output, "\tkinc_g4_vertex_structure_init(&%s_structure);\n", get_name(t->name));
for (size_t j = 0; j < t->members.size; ++j) {
if (api == API_OPENGL) {
fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s_%s\", %s);\n", get_name(t->name), get_name(t->name),
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
}
else {
fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s\", %s);\n", get_name(t->name),
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
if (vertex_function != NULL) {
for (type_id i = 0; get_type(i) != NULL; ++i) {
if (i == vertex_input) {
type *t = get_type(i);
fprintf(output, "\tkinc_g4_vertex_structure_init(&%s_structure);\n", get_name(t->name));
for (size_t j = 0; j < t->members.size; ++j) {
if (api == API_OPENGL) {
fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s_%s\", %s);\n", get_name(t->name), get_name(t->name),
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
}
else {
fprintf(output, "\tkinc_g4_vertex_structure_add(&%s_structure, \"%s\", %s);\n", get_name(t->name),
get_name(t->members.m[j].name), structure_type(t->members.m[j].type.type));
}
}
fprintf(output, "\n");
}
fprintf(output, "\n");
}
}

fprintf(output, "\t%s.input_layout[0] = &%s_structure;\n", get_name(t->name), get_name(get_type(vertex_input)->name));
fprintf(output, "\t%s.input_layout[1] = NULL;\n\n", get_name(t->name));
fprintf(output, "\t%s.input_layout[0] = &%s_structure;\n", get_name(t->name), get_name(get_type(vertex_input)->name));
fprintf(output, "\t%s.input_layout[1] = NULL;\n\n", get_name(t->name));
}

if (fragment_function->return_type.array_size > 0) {
fprintf(output, "\t%s.color_attachment_count = %i;\n", get_name(t->name), fragment_function->return_type.array_size);
Expand Down
3 changes: 3 additions & 0 deletions Sources/parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ static definition parse_function(state_t *state);
static definition parse_const(state_t *state, attribute_list attributes);

static double attribute_parameter_to_number(name_id attribute_name, name_id parameter_name) {
if (attribute_name == add_name("topology") && parameter_name == add_name("triangle")) {
return 0;
}
debug_context context = {0};
error(context, "Unknown attribute parameter %s", get_name(parameter_name));
return 0;
Expand Down
1 change: 1 addition & 0 deletions Sources/shader_stage.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

typedef enum shader_stage {
SHADER_STAGE_VERTEX,
SHADER_STAGE_MESH,
SHADER_STAGE_FRAGMENT,
SHADER_STAGE_COMPUTE,
SHADER_STAGE_RAY_GENERATION,
Expand Down
37 changes: 36 additions & 1 deletion tests/in/test.kong
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ fun comp(): void {

// based on https://landelare.github.io/2023/02/18/dxr-tutorial.html

struct Payload {
/*struct Payload {
color: float3;
allow_reflection: bool;
missed: bool;
Expand Down Expand Up @@ -99,4 +99,39 @@ struct RayPipe {
gen = sendrays;
miss = raymissed;
closest = closesthit;
}*/

struct VertexIn {
position: float3;
}

struct FragmentIn {
position: float4;
}

//fun amplify(): void {}

#[topology(triangle), threads(32, 1, 1)]
fun meshy(): void {

}

fun pixel(input: FragmentIn): float4 {
var color: float4;
color.r = 0.0;
color.g = 1.0;
color.b = 0.0;
color.a = 1.0;

var a: int = 3;
a += 2;

return color;
}

#[pipe]
struct Pipe {
//prim = amplify;
mesh = meshy;
fragment = pixel;
}

0 comments on commit 332513b

Please sign in to comment.