Skip to content

Commit

Permalink
Fixing bug when writing primitives in metal mesh shaders (#5069)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dynamitos authored Sep 25, 2024
1 parent f5bf5ba commit 84fef05
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 90 deletions.
2 changes: 2 additions & 0 deletions source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,8 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO
emitOperand(setIndices->getIndex(), getInfo(EmitOp::General));
m_writer->emit("*");
m_writer->emitUInt64(numIndices);
m_writer->emit("+");
m_writer->emitUInt64(i);
m_writer->emit(",(");
emitOperand(setIndices->getElementValue(), getInfo(EmitOp::General));
m_writer->emit(")[");
Expand Down
71 changes: 71 additions & 0 deletions tests/metal/simple-mesh.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
//TEST:SIMPLE(filecheck=METAL): -entry meshMain -stage mesh -target metal

//
// Mesh shader
//

const static float2 positions[3] = {
float2(0.0, -0.5),
float2(0.5, 0.5),
float2(-0.5, 0.5)
};

const static float3 colors[3] = {
float3(1.0, 1.0, 0.0),
float3(0.0, 1.0, 1.0),
float3(1.0, 0.0, 1.0)
};

struct MeshPayload
{
int exponent;
};


struct Vertex
{
float4 pos : SV_Position;
float3 color : Color;
int index : Index;
int value : Value;
};

struct Primitive
{
uint prim : SV_PrimitiveID;
};

const static uint MAX_VERTS = 12;
const static uint MAX_PRIMS = 4;

[outputtopology("triangle")]
[numthreads(12, 1, 1)]
void meshMain(
in uint tig: SV_GroupIndex,
in payload MeshPayload meshPayload,
// METAL: const MeshPayload_0 object_data* meshPayload_0
OutputVertices<Vertex, MAX_VERTS> verts,
OutputIndices<uint3, MAX_PRIMS> triangles,
OutputPrimitives<Primitive, MAX_PRIMS> primitives
)
{
const uint numVertices = 12;
const uint numPrimitives = 4;
SetMeshOutputCounts(numVertices, numPrimitives);

if (tig < numVertices)
{
const int tri = tig / 3;
verts[tig] = { float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent)) };
}

if (tig < numPrimitives)
{
// METAL: _slang_mesh.set_index({{.*}}+0,{{.*}}[0]);
// METAL: _slang_mesh.set_index({{.*}}+1,{{.*}}[1]);
// METAL: _slang_mesh.set_index({{.*}}+2,{{.*}}[2]);
triangles[tig] = tig * 3 + uint3(0, 1, 2);
// METAL: _slang_mesh.set_primitive({{.*}}
primitives[tig] = { tig };
}
}
95 changes: 5 additions & 90 deletions tests/metal/simple-task.slang
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
//TEST:SIMPLE(filecheck=CHECK): -entry taskMain -stage amplification -target metal

//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer

uniform RWStructuredBuffer<float> outputBuffer;
//TEST:SIMPLE(filecheck=METAL): -entry taskMain -stage amplification -target metal

cbuffer Uniforms
{
Expand All @@ -18,95 +14,14 @@ struct MeshPayload
int exponent;
};

// CHECK: MeshPayload_0 object_data* _slang_mesh_payload
// CHECK: mesh_grid_properties _slang_mgp
// METAL: MeshPayload_0 object_data* _slang_mesh_payload
// METAL: mesh_grid_properties _slang_mgp
[numthreads(1,1,1)]
void taskMain()
{
// CHECK: _slang_mesh_payload
// CHECK: _slang_mgp.set_threadgroups_per_grid
// METAL: _slang_mesh_payload
// METAL: _slang_mgp.set_threadgroups_per_grid
MeshPayload p;
p.exponent = 3;
DispatchMesh(1, 1, 1, p);
}

//
// Mesh shader
//

const static float2 positions[3] = {
float2(0.0, -0.5),
float2(0.5, 0.5),
float2(-0.5, 0.5)
};

const static float3 colors[3] = {
float3(1.0, 1.0, 0.0),
float3(0.0, 1.0, 1.0),
float3(1.0, 0.0, 1.0)
};

struct Vertex
{
float4 pos : SV_Position;
float3 color : Color;
int index : Index;
int value : Value;
};

struct Primitive
{
uint prim : SV_PrimitiveID;
};

const static uint MAX_VERTS = 12;
const static uint MAX_PRIMS = 4;

[outputtopology("triangle")]
[numthreads(12, 1, 1)]
void meshMain(
in uint tig: SV_GroupIndex,
in payload MeshPayload meshPayload,
// Check that we correctly generate the specific 'in payload' that HLSL
// requires:
// HLSL: , in payload MeshPayload
OutputVertices<Vertex, MAX_VERTS> verts,
OutputIndices<uint3, MAX_PRIMS> triangles,
OutputPrimitives<Primitive, MAX_PRIMS> primitives
)
{
const uint numVertices = 12;
const uint numPrimitives = 4;
SetMeshOutputCounts(numVertices, numPrimitives);

if (tig < numVertices)
{
const int tri = tig / 3;
verts[tig] = { float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent)) };
}

if (tig < numPrimitives)
{
triangles[tig] = tig * 3 + uint3(0, 1, 2);
primitives[tig] = { tig };
}
}

//
// Fragment Shader
//

struct Fragment
{
float4 color : SV_Target;
};

Fragment fragmentMain(Vertex input)
{
outputBuffer[input.index] = input.value;

Fragment output;
output.color = float4(input.color, 1.0);
return output;
}

0 comments on commit 84fef05

Please sign in to comment.