Skip to content

Commit

Permalink
[msl-out] Properly rename entry point arguments for struct members. (g…
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy authored Mar 8, 2022
1 parent 7984537 commit c84aa77
Show file tree
Hide file tree
Showing 10 changed files with 383 additions and 176 deletions.
92 changes: 73 additions & 19 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3288,9 +3288,6 @@ impl<W: Write> Writer<W> {

writeln!(self.out)?;

let stage_out_name = format!("{}Output", fun_name);
let stage_in_name = format!("{}Input", fun_name);

let (em_str, in_mode, out_mode) = match ep.stage {
crate::ShaderStage::Vertex => (
"vertex",
Expand All @@ -3307,35 +3304,44 @@ impl<W: Write> Writer<W> {
}
};

let mut argument_members = Vec::new();
// List all the Naga `EntryPoint`'s `Function`'s arguments,
// flattening structs into their members. In Metal, we will pass
// each of these values to the entry point as a separate argument—
// except for the varyings, handled next.
let mut flattened_arguments = Vec::new();
for (arg_index, arg) in fun.arguments.iter().enumerate() {
match module.types[arg.ty].inner {
crate::TypeInner::Struct { ref members, .. } => {
for (member_index, member) in members.iter().enumerate() {
argument_members.push((
NameKey::StructMember(arg.ty, member_index as u32),
let member_index = member_index as u32;
flattened_arguments.push((
NameKey::StructMember(arg.ty, member_index),
member.ty,
member.binding.as_ref(),
))
));
}
}
_ => argument_members.push((
_ => flattened_arguments.push((
NameKey::EntryPointArgument(ep_index as _, arg_index as u32),
arg.ty,
arg.binding.as_ref(),
)),
}
}

// Identify the varyings among the argument values, and emit a
// struct type named `<fun>Input` to hold them.
let stage_in_name = format!("{}Input", fun_name);
let varyings_member_name = self.namer.call("varyings");
let mut varying_count = 0;
if !argument_members.is_empty() {
let mut has_varyings = false;
if !flattened_arguments.is_empty() {
writeln!(self.out, "struct {} {{", stage_in_name)?;
for &(ref name_key, ty, binding) in argument_members.iter() {
for &(ref name_key, ty, binding) in flattened_arguments.iter() {
let binding = match binding {
Some(ref binding @ &crate::Binding::Location { .. }) => binding,
_ => continue,
};
varying_count += 1;
has_varyings = true;
let name = &self.names[name_key];
let ty_name = TypeContext {
handle: ty,
Expand All @@ -3352,6 +3358,9 @@ impl<W: Write> Writer<W> {
writeln!(self.out, "}};")?;
}

// Define a struct type named for the return value, if any, named
// `<fun>Output`.
let stage_out_name = format!("{}Output", fun_name);
let result_member_name = self.namer.call("member");
let result_type_name = match fun.result {
Some(ref result) => {
Expand Down Expand Up @@ -3444,23 +3453,46 @@ impl<W: Write> Writer<W> {
}
None => "void",
};
writeln!(self.out, "{} {} {}(", em_str, result_type_name, fun_name)?;

// Write the entry point function's name, and begin its argument list.
writeln!(self.out, "{} {} {}(", em_str, result_type_name, fun_name)?;
let mut is_first_argument = true;
if varying_count != 0 {

// If we have produced a struct holding the `EntryPoint`'s
// `Function`'s arguments' varyings, pass that struct first.
if has_varyings {
writeln!(
self.out,
" {} {} [[stage_in]]",
stage_in_name, varyings_member_name
)?;
is_first_argument = false;
}
for &(ref name_key, ty, binding) in argument_members.iter() {

// Then pass the remaining arguments not included in the varyings
// struct.
//
// Since `Namer.reset` wasn't expecting struct members to be
// suddenly injected into the normal namespace like this,
// `self.names` doesn't keep them distinct from other variables.
// Generate fresh names for these arguments, and remember the
// mapping.
let mut flattened_member_names = FastHashMap::default();
for &(ref name_key, ty, binding) in flattened_arguments.iter() {
let binding = match binding {
Some(ref binding @ &crate::Binding::BuiltIn(..)) => binding,
_ => continue,
};
let name = &self.names[name_key];
let name = if let NameKey::StructMember(ty, index) = *name_key {
// We should always insert a fresh entry here, but use
// `or_insert` to get a reference to the `String` we just
// inserted.
flattened_member_names
.entry(NameKey::StructMember(ty, index))
.or_insert_with(|| self.namer.call(&self.names[name_key]))
} else {
&self.names[name_key]
};
let ty_name = TypeContext {
handle: ty,
arena: &module.types,
Expand All @@ -3479,6 +3511,11 @@ impl<W: Write> Writer<W> {
resolved.try_fmt_decorated(&mut self.out)?;
writeln!(self.out)?;
}

// Those global variables used by this entry point and its callees
// get passed as arguments. `Private` globals are an exception, they
// don't outlive this invocation, so we declare them below as locals
// within the entry point.
for (handle, var) in module.global_variables.iter() {
let usage = fun_info[handle];
if usage.is_empty() || var.space == crate::AddressSpace::Private {
Expand Down Expand Up @@ -3534,6 +3571,8 @@ impl<W: Write> Writer<W> {
writeln!(self.out)?;
}

// If this entry uses any variable-length arrays, their sizes are
// passed as a final struct-typed argument.
if supports_array_length {
// this is checked earlier
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
Expand Down Expand Up @@ -3603,7 +3642,16 @@ impl<W: Write> Writer<W> {
}
}

// Now refactor the inputs in a way that the rest of the code expects
// Now take the arguments that we gathered into structs, and the
// structs that we flattened into arguments, and emit local
// variables with initializers that put everything back the way the
// body code expects.
//
// If we had to generate fresh names for struct members passed as
// arguments, be sure to use those names when rebuilding the struct.
//
// "Each day, I change some zeros to ones, and some ones to zeros.
// The rest, I leave alone."
for (arg_index, arg) in fun.arguments.iter().enumerate() {
let arg_name =
&self.names[&NameKey::EntryPointArgument(ep_index as _, arg_index as u32)];
Expand All @@ -3618,8 +3666,14 @@ impl<W: Write> Writer<W> {
arg_name
)?;
for (member_index, member) in members.iter().enumerate() {
let name =
&self.names[&NameKey::StructMember(arg.ty, member_index as u32)];
let key = NameKey::StructMember(arg.ty, member_index as u32);
// If it's not in the varying struct, then we should
// have passed it as its own argument and assigned
// it a new name.
let name = match member.binding {
Some(crate::Binding::BuiltIn(_)) => &flattened_member_names[&key],
_ => &self.names[&key],
};
if member_index != 0 {
write!(self.out, ", ")?;
}
Expand Down
14 changes: 14 additions & 0 deletions tests/in/interface.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,17 @@ fn compute(
) {
output[0] = global_id.x + local_id.x + local_index + wg_id.x + num_wgs.x;
}

struct Input1 {
@builtin(vertex_index) index: u32;
};

struct Input2 {
@builtin(instance_index) index: u32;
};

@stage(vertex)
fn vertex_two_structs(in1: Input1, in2: Input2) -> @builtin(position) vec4<f32> {
var index = 2u;
return vec4<f32>(f32(in1.index), f32(in2.index), f32(index), 0.0);
}
16 changes: 16 additions & 0 deletions tests/out/hlsl/interface.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ struct FragmentOutput {
linear float color : SV_Target0;
};

struct Input1_ {
uint index : SV_VertexID;
};

struct Input2_ {
uint index : SV_InstanceID;
};

groupshared uint output[1];

struct VertexOutput_vertex {
Expand Down Expand Up @@ -72,3 +80,11 @@ void compute(uint3 global_id : SV_DispatchThreadID, uint3 local_id : SV_GroupThr
output[0] = ((((global_id.x + local_id.x) + local_index) + wg_id.x) + uint3(_NagaConstants.base_vertex, _NagaConstants.base_instance, _NagaConstants.other).x);
return;
}

float4 vertex_two_structs(Input1_ in1_, Input2_ in2_) : SV_Position
{
uint index = 2u;

uint _expr9 = index;
return float4(float((_NagaConstants.base_vertex + in1_.index)), float((_NagaConstants.base_instance + in2_.index)), float(_expr9), 0.0);
}
2 changes: 1 addition & 1 deletion tests/out/hlsl/interface.hlsl.config
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
vertex=(vertex:vs_5_1 )
vertex=(vertex:vs_5_1 vertex_two_structs:vs_5_1 )
fragment=(fragment:ps_5_1 )
compute=(compute:cs_5_1 )
24 changes: 24 additions & 0 deletions tests/out/msl/interface.msl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ struct FragmentOutput {
struct type_4 {
uint inner[1];
};
struct Input1_ {
uint index;
};
struct Input2_ {
uint index;
};

struct vertex_Input {
uint color [[attribute(10)]];
Expand Down Expand Up @@ -73,3 +79,21 @@ kernel void compute_(
output.inner[0] = (((global_id.x + local_id.x) + local_index) + wg_id.x) + num_wgs.x;
return;
}


struct vertex_two_structsInput {
};
struct vertex_two_structsOutput {
metal::float4 member_3 [[position]];
float _point_size [[point_size]];
};
vertex vertex_two_structsOutput vertex_two_structs(
uint index_1 [[vertex_id]]
, uint index_2 [[instance_id]]
) {
const Input1_ in1_ = { index_1 };
const Input2_ in2_ = { index_2 };
uint index = 2u;
uint _e9 = index;
return vertex_two_structsOutput { metal::float4(static_cast<float>(in1_.index), static_cast<float>(in2_.index), static_cast<float>(_e9), 0.0), 1.0 };
}
103 changes: 54 additions & 49 deletions tests/out/spv/interface.compute.spvasm
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 46
; Bound: 49
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %32 "compute" %20 %23 %25 %28 %30
OpExecutionMode %32 LocalSize 1 1 1
OpMemberDecorate %12 0 Offset 0
OpMemberDecorate %12 1 Offset 16
OpEntryPoint GLCompute %35 "compute" %23 %26 %28 %31 %33
OpExecutionMode %35 LocalSize 1 1 1
OpMemberDecorate %13 0 Offset 0
OpMemberDecorate %13 1 Offset 4
OpMemberDecorate %13 2 Offset 8
OpDecorate %15 ArrayStride 4
OpDecorate %20 BuiltIn GlobalInvocationId
OpDecorate %23 BuiltIn LocalInvocationId
OpDecorate %25 BuiltIn LocalInvocationIndex
OpDecorate %28 BuiltIn WorkgroupId
OpDecorate %30 BuiltIn NumWorkgroups
OpMemberDecorate %13 1 Offset 16
OpMemberDecorate %14 0 Offset 0
OpMemberDecorate %14 1 Offset 4
OpMemberDecorate %14 2 Offset 8
OpDecorate %16 ArrayStride 4
OpMemberDecorate %18 0 Offset 0
OpMemberDecorate %19 0 Offset 0
OpDecorate %23 BuiltIn GlobalInvocationId
OpDecorate %26 BuiltIn LocalInvocationId
OpDecorate %28 BuiltIn LocalInvocationIndex
OpDecorate %31 BuiltIn WorkgroupId
OpDecorate %33 BuiltIn NumWorkgroups
%2 = OpTypeVoid
%4 = OpTypeFloat 32
%3 = OpConstant %4 1.0
Expand All @@ -27,42 +29,45 @@ OpDecorate %30 BuiltIn NumWorkgroups
%9 = OpTypeInt 32 1
%8 = OpConstant %9 1
%10 = OpConstant %9 0
%11 = OpTypeVector %4 4
%12 = OpTypeStruct %11 %4
%13 = OpTypeStruct %4 %6 %4
%14 = OpTypeBool
%15 = OpTypeArray %6 %8
%16 = OpTypeVector %6 3
%18 = OpTypePointer Workgroup %15
%17 = OpVariable %18 Workgroup
%21 = OpTypePointer Input %16
%20 = OpVariable %21 Input
%23 = OpVariable %21 Input
%26 = OpTypePointer Input %6
%25 = OpVariable %26 Input
%28 = OpVariable %21 Input
%30 = OpVariable %21 Input
%33 = OpTypeFunction %2
%35 = OpTypePointer Workgroup %6
%44 = OpConstant %6 0
%32 = OpFunction %2 None %33
%19 = OpLabel
%22 = OpLoad %16 %20
%24 = OpLoad %16 %23
%27 = OpLoad %6 %25
%29 = OpLoad %16 %28
%31 = OpLoad %16 %30
OpBranch %34
%34 = OpLabel
%36 = OpCompositeExtract %6 %22 0
%37 = OpCompositeExtract %6 %24 0
%38 = OpIAdd %6 %36 %37
%39 = OpIAdd %6 %38 %27
%40 = OpCompositeExtract %6 %29 0
%11 = OpConstant %6 2
%12 = OpTypeVector %4 4
%13 = OpTypeStruct %12 %4
%14 = OpTypeStruct %4 %6 %4
%15 = OpTypeBool
%16 = OpTypeArray %6 %8
%17 = OpTypeVector %6 3
%18 = OpTypeStruct %6
%19 = OpTypeStruct %6
%21 = OpTypePointer Workgroup %16
%20 = OpVariable %21 Workgroup
%24 = OpTypePointer Input %17
%23 = OpVariable %24 Input
%26 = OpVariable %24 Input
%29 = OpTypePointer Input %6
%28 = OpVariable %29 Input
%31 = OpVariable %24 Input
%33 = OpVariable %24 Input
%36 = OpTypeFunction %2
%38 = OpTypePointer Workgroup %6
%47 = OpConstant %6 0
%35 = OpFunction %2 None %36
%22 = OpLabel
%25 = OpLoad %17 %23
%27 = OpLoad %17 %26
%30 = OpLoad %6 %28
%32 = OpLoad %17 %31
%34 = OpLoad %17 %33
OpBranch %37
%37 = OpLabel
%39 = OpCompositeExtract %6 %25 0
%40 = OpCompositeExtract %6 %27 0
%41 = OpIAdd %6 %39 %40
%42 = OpCompositeExtract %6 %31 0
%43 = OpIAdd %6 %41 %42
%45 = OpAccessChain %35 %17 %44
OpStore %45 %43
%42 = OpIAdd %6 %41 %30
%43 = OpCompositeExtract %6 %32 0
%44 = OpIAdd %6 %42 %43
%45 = OpCompositeExtract %6 %34 0
%46 = OpIAdd %6 %44 %45
%48 = OpAccessChain %38 %20 %47
OpStore %48 %46
OpReturn
OpFunctionEnd
Loading

0 comments on commit c84aa77

Please sign in to comment.