From af31bab15855f516ae40a86d576986b81e9e3210 Mon Sep 17 00:00:00 2001 From: Brad Werth Date: Tue, 6 Feb 2024 16:35:17 -0800 Subject: [PATCH] Add a not-yet-working vertex pulling flag to Metal pipelines. This is an early effort to add infrastructure to support vertex pulling transformation of Metal shaders. It is *not* a working transformation that generates valid, useful shaders. It includes: 1) It adds a experimental_vertex_pulling_transform flag to msl::PipelineOptions. This flag defaults to false but can be forcibly set to true by naga tests. 2) When the flag is set, generated msl vertex shaders are passed an additional vertex id parameter, plus an additional parameter for each bound vertex buffer, plus the _mslBufferSizes struct which is normally only used for dynamically sized buffers. 3) A new naga test is added which exercises this flag and demonstrates the effect of the transform. Future work will make the transformed shaders valid, and add tests that transformed shaders produce correct results. --- Cargo.lock | 1 + naga/CHANGELOG.md | 1 + naga/Cargo.toml | 3 +- naga/src/back/msl/mod.rs | 215 ++++++++++++++++++ naga/src/back/msl/writer.rs | 127 +++++++++-- naga/tests/in/interface.param.ron | 2 + .../in/vertex-pulling-transform.param.ron | 15 ++ naga/tests/in/vertex-pulling-transform.wgsl | 29 +++ .../out/msl/vertex-pulling-transform.msl | 58 +++++ naga/tests/snapshots.rs | 1 + wgpu-hal/src/metal/device.rs | 24 ++ 11 files changed, 453 insertions(+), 23 deletions(-) create mode 100644 naga/tests/in/vertex-pulling-transform.param.ron create mode 100644 naga/tests/in/vertex-pulling-transform.wgsl create mode 100644 naga/tests/out/msl/vertex-pulling-transform.msl diff --git a/Cargo.lock b/Cargo.lock index 230e6f41654..97b923bb823 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2138,6 +2138,7 @@ dependencies = [ "hlsl-snapshots", "indexmap", "log", + "metal", "num-traits", "petgraph", "pp-rs", diff --git a/naga/CHANGELOG.md b/naga/CHANGELOG.md index a92d0c4f97f..d2e0515ebd1 100644 --- a/naga/CHANGELOG.md +++ b/naga/CHANGELOG.md @@ -79,6 +79,7 @@ For changelogs after v0.14, see [the wgpu changelog](../CHANGELOG.md). - Add and fix minimum Metal version checks for optional functionality. ([#2486](https://github.com/gfx-rs/naga/pull/2486)) **@teoxoy** - Make varyings' struct members unique. ([#2521](https://github.com/gfx-rs/naga/pull/2521)) **@evahop** +- Add experimental vertex pulling transform flag. ([#5254](https://github.com/gfx-rs/wgpu/pull/5254)) **@bradwerth** #### GLSL-OUT diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 5cc078ad993..f6c69e1066f 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -24,7 +24,7 @@ default = [] dot-out = [] glsl-in = ["pp-rs"] glsl-out = [] -msl-out = [] +msl-out = ["metal"] serialize = ["serde", "bitflags/serde", "indexmap/serde"] deserialize = ["serde", "bitflags/serde", "indexmap/serde"] arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"] @@ -51,6 +51,7 @@ codespan-reporting = { version = "0.11.0" } rustc-hash = "1.1.0" indexmap = { version = "2", features = ["std"] } log = "0.4" +metal = { version = "0.27.0", git = "https://github.com/gfx-rs/metal-rs", rev = "ff8fd3d6dc7792852f8a015458d7e6d42d7fb352", optional = true } num-traits = "0.2" spirv = { version = "0.3", optional = true } thiserror = "1.0.57" diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 2c7cdea6af5..3e8a9b640ab 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -33,6 +33,8 @@ mod keywords; pub mod sampler; mod writer; +use metal::MTLVertexFormat::*; + pub use writer::Writer; pub type Slot = u8; @@ -222,6 +224,208 @@ impl Default for Options { } } +/// Vertex Format for a [`VertexAttribute`] (input). +/// +/// Corresponds to [WebGPU `GPUVertexFormat`]( +/// https://gpuweb.github.io/gpuweb/#enumdef-gpuvertexformat). +#[repr(u32)] +#[derive(Copy, Clone, Debug, Default, Hash, Eq, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum VertexFormat { + #[default] + /// Two unsigned bytes (u8). `vec2` in shaders. + Uint8x2 = 0, + /// Four unsigned bytes (u8). `vec4` in shaders. + Uint8x4 = 1, + /// Two signed bytes (i8). `vec2` in shaders. + Sint8x2 = 2, + /// Four signed bytes (i8). `vec4` in shaders. + Sint8x4 = 3, + /// Two unsigned bytes (u8). [0, 255] converted to float [0, 1] `vec2` in shaders. + Unorm8x2 = 4, + /// Four unsigned bytes (u8). [0, 255] converted to float [0, 1] `vec4` in shaders. + Unorm8x4 = 5, + /// Two signed bytes (i8). [-127, 127] converted to float [-1, 1] `vec2` in shaders. + Snorm8x2 = 6, + /// Four signed bytes (i8). [-127, 127] converted to float [-1, 1] `vec4` in shaders. + Snorm8x4 = 7, + /// Two unsigned shorts (u16). `vec2` in shaders. + Uint16x2 = 8, + /// Four unsigned shorts (u16). `vec4` in shaders. + Uint16x4 = 9, + /// Two signed shorts (i16). `vec2` in shaders. + Sint16x2 = 10, + /// Four signed shorts (i16). `vec4` in shaders. + Sint16x4 = 11, + /// Two unsigned shorts (u16). [0, 65535] converted to float [0, 1] `vec2` in shaders. + Unorm16x2 = 12, + /// Four unsigned shorts (u16). [0, 65535] converted to float [0, 1] `vec4` in shaders. + Unorm16x4 = 13, + /// Two signed shorts (i16). [-32767, 32767] converted to float [-1, 1] `vec2` in shaders. + Snorm16x2 = 14, + /// Four signed shorts (i16). [-32767, 32767] converted to float [-1, 1] `vec4` in shaders. + Snorm16x4 = 15, + /// Two half-precision floats (no Rust equiv). `vec2` in shaders. + Float16x2 = 16, + /// Four half-precision floats (no Rust equiv). `vec4` in shaders. + Float16x4 = 17, + /// One single-precision float (f32). `f32` in shaders. + Float32 = 18, + /// Two single-precision floats (f32). `vec2` in shaders. + Float32x2 = 19, + /// Three single-precision floats (f32). `vec3` in shaders. + Float32x3 = 20, + /// Four single-precision floats (f32). `vec4` in shaders. + Float32x4 = 21, + /// One unsigned int (u32). `u32` in shaders. + Uint32 = 22, + /// Two unsigned ints (u32). `vec2` in shaders. + Uint32x2 = 23, + /// Three unsigned ints (u32). `vec3` in shaders. + Uint32x3 = 24, + /// Four unsigned ints (u32). `vec4` in shaders. + Uint32x4 = 25, + /// One signed int (i32). `i32` in shaders. + Sint32 = 26, + /// Two signed ints (i32). `vec2` in shaders. + Sint32x2 = 27, + /// Three signed ints (i32). `vec3` in shaders. + Sint32x3 = 28, + /// Four signed ints (i32). `vec4` in shaders. + Sint32x4 = 29, + /// One double-precision float (f64). `f32` in shaders. Requires [`Features::VERTEX_ATTRIBUTE_64BIT`]. + Float64 = 30, + /// Two double-precision floats (f64). `vec2` in shaders. Requires [`Features::VERTEX_ATTRIBUTE_64BIT`]. + Float64x2 = 31, + /// Three double-precision floats (f64). `vec3` in shaders. Requires [`Features::VERTEX_ATTRIBUTE_64BIT`]. + Float64x3 = 32, + /// Four double-precision floats (f64). `vec4` in shaders. Requires [`Features::VERTEX_ATTRIBUTE_64BIT`]. + Float64x4 = 33, + /// Three unsigned 10-bit integers and one 2-bit integer, packed into a 32-bit integer (u32). [0, 1024] converted to float [0, 1] `vec4` in shaders. + #[cfg_attr(feature = "serde", serde(rename = "unorm10-10-10-2"))] + Unorm10_10_10_2 = 34, +} + +impl From for VertexFormat { + fn from(value: u32) -> Self { + use VertexFormat::*; + match value { + 0 => Uint8x2, + 1 => Uint8x4, + 2 => Sint8x2, + 3 => Sint8x4, + 4 => Unorm8x2, + 5 => Unorm8x4, + 6 => Snorm8x2, + 7 => Snorm8x4, + 8 => Uint16x2, + 9 => Uint16x4, + 10 => Sint16x2, + 11 => Sint16x4, + 12 => Unorm16x2, + 13 => Unorm16x4, + 14 => Snorm16x2, + 15 => Snorm16x4, + 16 => Float16x2, + 17 => Float16x4, + 18 => Float32, + 19 => Float32x2, + 20 => Float32x3, + 21 => Float32x4, + 22 => Uint32, + 23 => Uint32x2, + 24 => Uint32x3, + 25 => Uint32x4, + 26 => Sint32, + 27 => Sint32x2, + 28 => Sint32x3, + 29 => Sint32x4, + 30 => Float64, + 31 => Float64x2, + 32 => Float64x3, + 33 => Float64x4, + 34 => Unorm10_10_10_2, + _ => panic!("Can't convert value {}", value), + } + } +} + +impl Into for VertexFormat { + fn into(self) -> metal::MTLVertexFormat { + use VertexFormat::*; + match self { + Uint8x2 => UChar2, + Uint8x4 => UChar4, + Sint8x2 => Char2, + Sint8x4 => Char4, + Unorm8x2 => UChar2Normalized, + Unorm8x4 => UChar4Normalized, + Snorm8x2 => Char2Normalized, + Snorm8x4 => Char4Normalized, + Uint16x2 => UShort2, + Uint16x4 => UShort4, + Sint16x2 => Short2, + Sint16x4 => Short4, + Unorm16x2 => UShort2Normalized, + Unorm16x4 => UShort4Normalized, + Snorm16x2 => Short2Normalized, + Snorm16x4 => Short4Normalized, + Float16x2 => Half2, + Float16x4 => Half4, + Float32 => Float, + Float32x2 => Float2, + Float32x3 => Float3, + Float32x4 => Float4, + Uint32 => UInt, + Uint32x2 => UInt2, + Uint32x3 => UInt3, + Uint32x4 => UInt4, + Sint32 => Int, + Sint32x2 => Int2, + Sint32x3 => Int3, + Sint32x4 => Int4, + Float64 => unimplemented!(), + Float64x2 => unimplemented!(), + Float64x3 => unimplemented!(), + Float64x4 => unimplemented!(), + Unorm10_10_10_2 => UInt1010102Normalized, + } + } +} + +/// A mapping of vertex buffers and their attributes to shader +/// locations. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct AttributeMapping { + /// Shader location associated with this attribute + pub shader_location: u32, + /// Offset in bytes from start of vertex buffer structure + pub offset: u32, + /// Format code to help us unpack the attribute into the type + /// used by the shader. Codes correspond to a 0-based index of + /// . + /// The conversion process is described by + /// . + pub format: VertexFormat, +} + +/// A description of a vertex buffer with all the information we +/// need to address the attributes within it. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct VertexBufferMapping { + /// Shader location associated with this buffer + pub id: u32, + /// Size of the structure in bytes + pub stride: u32, + /// Vec of the attributes within the structure + pub attributes: Vec, +} + /// A subset of options that are meant to be changed per pipeline. #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -234,6 +438,17 @@ pub struct PipelineOptions { /// /// Enable this for vertex shaders with point primitive topologies. pub allow_and_force_point_size: bool, + + /// Experimental + /// If set, when generating the Metal vertex shader, transform it + /// to receive the vertex buffers, lengths, and vertex id as args, + /// and bounds-check the vertex id and use the index into the + /// vertex buffers to access attributes, rather than using Metal's + /// [[stage-in]] assembled attribute data. + pub experimental_vertex_pulling_transform: bool, + + /// Only used if experimental_vertex_pulling_transform is set. + pub vertex_buffer_mappings: Vec, } impl Options { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 7031d043617..42e150ce41a 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -2958,7 +2958,7 @@ impl Writer { // follow-up with any global resources used let mut separate = !arguments.is_empty(); let fun_info = &context.expression.mod_info[function]; - let mut supports_array_length = false; + let mut needs_buffer_sizes = false; for (handle, var) in context.expression.module.global_variables.iter() { if fun_info[handle].is_empty() { continue; @@ -2972,10 +2972,10 @@ impl Writer { } write!(self.out, "{name}")?; } - supports_array_length |= + needs_buffer_sizes |= needs_array_length(var.ty, &context.expression.module.types); } - if supports_array_length { + if needs_buffer_sizes { if separate { write!(self.out, ", ")?; } @@ -3294,13 +3294,18 @@ impl Writer { } } - if !indices.is_empty() { + if !indices.is_empty() || !pipeline_options.vertex_buffer_mappings.is_empty() { writeln!(self.out, "struct _mslBufferSizes {{")?; for idx in indices { writeln!(self.out, "{}uint size{};", back::INDENT, idx)?; } + for vbm in &pipeline_options.vertex_buffer_mappings { + let idx = vbm.id; + writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?; + } + writeln!(self.out, "}};")?; writeln!(self.out)?; } @@ -3659,13 +3664,13 @@ impl Writer { let fun_info = &mod_info[fun_handle]; pass_through_globals.clear(); - let mut supports_array_length = false; + let mut needs_buffer_sizes = false; for (handle, var) in module.global_variables.iter() { if !fun_info[handle].is_empty() { if var.space.needs_pass_through() { pass_through_globals.push(handle); } - supports_array_length |= needs_array_length(var.ty, &module.types); + needs_buffer_sizes |= needs_array_length(var.ty, &module.types); } } @@ -3702,7 +3707,7 @@ impl Writer { let separator = separate( !pass_through_globals.is_empty() || index + 1 != fun.arguments.len() - || supports_array_length, + || needs_buffer_sizes, ); writeln!( self.out, @@ -3723,13 +3728,13 @@ impl Writer { reference: true, }; let separator = - separate(index + 1 != pass_through_globals.len() || supports_array_length); + separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes); write!(self.out, "{}", back::INDENT)?; tyvar.try_fmt(&mut self.out)?; writeln!(self.out, "{separator}")?; } - if supports_array_length { + if needs_buffer_sizes { writeln!( self.out, "{}constant _mslBufferSizes& _buffer_sizes", @@ -3801,11 +3806,12 @@ impl Writer { ); // Is any global variable used by this entry point dynamically sized? - let supports_array_length = module - .global_variables - .iter() - .filter(|&(handle, _)| !fun_info[handle].is_empty()) - .any(|(_, var)| needs_array_length(var.ty, &module.types)); + let needs_buffer_sizes = pipeline_options.experimental_vertex_pulling_transform + || module + .global_variables + .iter() + .filter(|&(handle, _)| !fun_info[handle].is_empty()) + .any(|(_, var)| needs_array_length(var.ty, &module.types)); // skip this entry point if any global bindings are missing, // or their types are incompatible. @@ -3863,7 +3869,7 @@ impl Writer { | crate::AddressSpace::WorkGroup => {} } } - if supports_array_length { + if needs_buffer_sizes { if let Err(err) = options.resolve_sizes_buffer(ep) { ep_error = Some(err); } @@ -3974,6 +3980,7 @@ impl Writer { // `Output`. let stage_out_name = format!("{fun_name}Output"); let result_member_name = self.namer.call("member"); + let result_return_statement: &str; let result_type_name = match fun.result { Some(ref result) => { let mut result_members = Vec::new(); @@ -4045,11 +4052,47 @@ impl Writer { )?; } writeln!(self.out, "}};")?; + result_return_statement = "return {}"; &stage_out_name } - None => "void", + None => { + result_return_statement = "return"; + "void" + } }; + // If we're doing a vertex pulling transform, generate the names + // we need: a vertex index arg, a name and type name for each buffer, + // based on the buffer id, then generate the structs associated with + // each buffer. + let v_id = self.namer.call("v_id"); + let mut buffer_names = Vec::::new(); + let mut buffer_ty_names = Vec::::new(); + if pipeline_options.experimental_vertex_pulling_transform { + for mapping in pipeline_options.vertex_buffer_mappings.iter() { + let buffer_id = mapping.id; + let buffer_stride = mapping.stride; + let buffer_name = self.namer.call(format!("v_{buffer_id}_in").as_str()); + let buffer_ty = self.namer.call(format!("v_{buffer_id}_in_type").as_str()); + + assert!( + buffer_stride > 0, + "Vertex pulling requires a non-zero buffer stride." + ); + + // Define a structure of bytes of the appropriate size. + // When we access the attributes, we'll be unpacking these + // bytes at some offset. + writeln!( + self.out, + "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};" + )?; + + buffer_names.push(buffer_name); + buffer_ty_names.push(buffer_ty); + } + } + // Write the entry point function's name, and begin its argument list. writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?; let mut is_first_argument = true; @@ -4282,16 +4325,37 @@ impl Writer { writeln!(self.out)?; } - // If this entry uses any variable-length arrays, their sizes are - // passed as a final struct-typed argument. - if supports_array_length { - // this is checked earlier - let resolved = options.resolve_sizes_buffer(ep).unwrap(); - let separator = if module.global_variables.is_empty() { + if pipeline_options.experimental_vertex_pulling_transform { + let separator = if is_first_argument { + is_first_argument = false; ' ' } else { ',' }; + + // Write the [[vertex_id]] argument. + writeln!(self.out, "{separator} uint {v_id} [[vertex_id]]")?; + + // Read the pipeline options we specified earlier, output one + // argument for every vertex buffer, using the names and type + // names we generated earlier. + for (i, vbm) in pipeline_options.vertex_buffer_mappings.iter().enumerate() { + let buffer_name = &buffer_names[i]; + let buffer_ty_name = &buffer_ty_names[i]; + let buffer_id = vbm.id; + writeln!( + self.out, + "{separator} constant {buffer_ty_name} *{buffer_name} [[buffer({buffer_id})]]" + )?; + } + } + + // If this entry uses any variable-length arrays, their sizes are + // passed as a final struct-typed argument. + if needs_buffer_sizes { + // this is checked earlier + let resolved = options.resolve_sizes_buffer(ep).unwrap(); + let separator = if is_first_argument { ' ' } else { ',' }; write!( self.out, "{separator} constant _mslBufferSizes& _buffer_sizes", @@ -4303,6 +4367,25 @@ impl Writer { // end of the entry point argument list writeln!(self.out, ") {{")?; + // Starting the function body. + if pipeline_options.experimental_vertex_pulling_transform { + // Output the bounds check against all the limits in the _buffer_sizes struct. + write!(self.out, "{}if (", back::Level(1))?; + + let mut is_first_comparison = true; + for vbm in &pipeline_options.vertex_buffer_mappings { + let idx = vbm.id; + let separator = if is_first_comparison { "" } else { " || " }; + write!( + self.out, + "{separator}{v_id} >= _buffer_sizes.buffer_size{idx}" + )?; + is_first_comparison = false; + } + + writeln!(self.out, ") {{ {result_return_statement}; }}")?; + } + if need_workgroup_variables_initialization { self.write_workgroup_variables_initialization( module, diff --git a/naga/tests/in/interface.param.ron b/naga/tests/in/interface.param.ron index 4d85661767b..ca369812bfe 100644 --- a/naga/tests/in/interface.param.ron +++ b/naga/tests/in/interface.param.ron @@ -27,5 +27,7 @@ ), msl_pipeline: ( allow_and_force_point_size: true, + experimental_vertex_pulling_transform: false, + vertex_buffer_mappings: [], ), ) diff --git a/naga/tests/in/vertex-pulling-transform.param.ron b/naga/tests/in/vertex-pulling-transform.param.ron new file mode 100644 index 00000000000..3ab3e4cf831 --- /dev/null +++ b/naga/tests/in/vertex-pulling-transform.param.ron @@ -0,0 +1,15 @@ +( + msl_pipeline: ( + allow_and_force_point_size: false, + experimental_vertex_pulling_transform: true, + vertex_buffer_mappings: [( + id: 1, + stride: 4, + attributes: [( + shader_location: 1, + offset: 0, + format: Float32, + )], + )], + ), +) diff --git a/naga/tests/in/vertex-pulling-transform.wgsl b/naga/tests/in/vertex-pulling-transform.wgsl new file mode 100644 index 00000000000..4f02bb84df3 --- /dev/null +++ b/naga/tests/in/vertex-pulling-transform.wgsl @@ -0,0 +1,29 @@ +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, + @location(1) texcoord: vec2, +} + +struct VertexInput { + @location(0) position: vec4, + @location(1) normal: vec3, + @location(2) texcoord: vec2, +} + +@group(0) @binding(0) var mvp_matrix: mat4x4; + +@vertex +fn render_vertex(v_in: VertexInput) -> VertexOutput +{ + var v_out: VertexOutput; + v_out.position = v_in.position * mvp_matrix; + v_out.color = do_lighting(v_in.position, + v_in.normal); + v_out.texcoord = v_in.texcoord; + return v_out; +} + +fn do_lighting(position: vec4, normal: vec3) -> vec4 { + // blah blah blah + return vec4(0); +} diff --git a/naga/tests/out/msl/vertex-pulling-transform.msl b/naga/tests/out/msl/vertex-pulling-transform.msl new file mode 100644 index 00000000000..dba8cec5290 --- /dev/null +++ b/naga/tests/out/msl/vertex-pulling-transform.msl @@ -0,0 +1,58 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct _mslBufferSizes { + uint buffer_size1; +}; + +struct VertexOutput { + metal::float4 position; + metal::float4 color; + metal::float2 texcoord; +}; +struct VertexInput { + metal::float4 position; + metal::float3 normal; + metal::float2 texcoord; +}; + +metal::float4 do_lighting( + metal::float4 position, + metal::float3 normal +) { + return metal::float4(0.0); +} + +struct render_vertexInput { + metal::float4 position [[attribute(0)]]; + metal::float3 normal [[attribute(1)]]; + metal::float2 texcoord [[attribute(2)]]; +}; +struct render_vertexOutput { + metal::float4 position [[position]]; + metal::float4 color [[user(loc0), center_perspective]]; + metal::float2 texcoord [[user(loc1), center_perspective]]; +}; +struct v_1_in_type { metal::uchar data[4]; }; +vertex render_vertexOutput render_vertex( + render_vertexInput varyings [[stage_in]] +, constant metal::float4x4& mvp_matrix [[user(fake0)]] +, uint v_id [[vertex_id]] +, constant v_1_in_type *v_1_in [[buffer(1)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] +) { + if (v_id >= _buffer_sizes.buffer_size1) { return {}; } + const VertexInput v_in = { varyings.position, varyings.normal, varyings.texcoord }; + VertexOutput v_out = {}; + metal::float4x4 _e5 = mvp_matrix; + v_out.position = v_in.position * _e5; + metal::float4 _e10 = do_lighting(v_in.position, v_in.normal); + v_out.color = _e10; + v_out.texcoord = v_in.texcoord; + VertexOutput _e13 = v_out; + const auto _tmp = _e13; + return render_vertexOutput { _tmp.position, _tmp.color, _tmp.texcoord }; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 3e45faeb166..658638fe809 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -867,6 +867,7 @@ fn convert_wgsl() { "overrides-ray-query", Targets::IR | Targets::SPIRV | Targets::METAL, ), + ("vertex-pulling-transform", Targets::METAL), ]; for &(name, targets) in inputs.iter() { diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 0906d215106..870a3500824 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -63,6 +63,7 @@ impl super::Device { fn load_shader( &self, stage: &crate::ProgrammableStage, + vertex_buffer_mappings: Vec, layout: &super::PipelineLayout, primitive_class: metal::MTLPrimitiveTopologyClass, naga_stage: naga::ShaderStage, @@ -120,6 +121,8 @@ impl super::Device { metal::MTLPrimitiveTopologyClass::Point => true, _ => false, }, + experimental_vertex_pulling_transform: false, + vertex_buffer_mappings, }; let (source, info) = @@ -832,8 +835,27 @@ impl crate::Device for super::Device { // Vertex shader let (vs_lib, vs_info) = { + let mut vertex_buffer_mappings = Vec::::new(); + for (i, vbl) in desc.vertex_buffers.iter().enumerate() { + let mut attributes = Vec::::new(); + for attribute in vbl.attributes.iter() { + attributes.push(naga::back::msl::AttributeMapping { + shader_location: attribute.shader_location, + offset: attribute.offset as u32, + format: (attribute.format as u32).into(), + }); + } + + vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { + id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, + stride: vbl.array_stride as u32, + attributes, + }); + } + let vs = self.load_shader( &desc.vertex_stage, + vertex_buffer_mappings, desc.layout, primitive_class, naga::ShaderStage::Vertex, @@ -861,6 +883,7 @@ impl crate::Device for super::Device { Some(ref stage) => { let fs = self.load_shader( stage, + vec![], desc.layout, primitive_class, naga::ShaderStage::Fragment, @@ -1053,6 +1076,7 @@ impl crate::Device for super::Device { let cs = self.load_shader( &desc.stage, + vec![], desc.layout, metal::MTLPrimitiveTopologyClass::Unspecified, naga::ShaderStage::Compute,