From 241b1c4a618718d005793fff933344a0c4a72489 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sun, 26 Feb 2023 00:27:03 -0800 Subject: [PATCH] msl: ray query support --- src/back/mod.rs | 23 +++ src/back/msl/mod.rs | 12 +- src/back/msl/writer.rs | 199 +++++++++++++++++++-- tests/in/ray-query.param.ron | 8 + tests/in/ray-query.wgsl | 13 +- tests/out/msl/binding-arrays.msl | 2 +- tests/out/msl/bounds-check-image-rzsw.msl | 2 +- tests/out/msl/bounds-check-zero-atomic.msl | 2 +- tests/out/msl/bounds-check-zero.msl | 2 +- tests/out/msl/policy-mix.msl | 2 +- tests/out/msl/ray-query.msl | 58 ++++++ tests/snapshots.rs | 2 +- 12 files changed, 295 insertions(+), 30 deletions(-) create mode 100644 tests/out/msl/ray-query.msl diff --git a/src/back/mod.rs b/src/back/mod.rs index 56223ac2bb..ce46e8cdbd 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -218,3 +218,26 @@ impl crate::Statement { } } } + +bitflags::bitflags! { + /// Ray flags. + #[derive(Default)] + pub struct RayFlag: u32 { + const OPAQUE = 0x01; + const NO_OPAQUE = 0x02; + const TERMINATE_ON_FIRST_HIT = 0x04; + const SKIP_CLOSEST_HIT_SHADER = 0x08; + const CULL_FRONT_FACING = 0x10; + const CULL_BACK_FACING = 0x20; + const CULL_OPAQUE = 0x40; + const CULL_NO_OPAQUE = 0x80; + const SKIP_TRIANGLES = 0x100; + const SKIP_AABBS = 0x200; + } +} + +#[repr(u32)] +enum RayIntersectionType { + Triangle = 1, + BoundingBox = 4, +} diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 271557fbf2..ca09e2af97 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -312,7 +312,7 @@ impl Options { } } - const fn resolve_push_constants( + fn resolve_push_constants( &self, stage: crate::ShaderStage, ) -> Result { @@ -324,10 +324,7 @@ impl Options { match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), - texture: None, - sampler: None, - binding_array_size: None, - mutable: false, + ..Default::default() })), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", @@ -346,10 +343,7 @@ impl Options { match slot { Some(slot) => Ok(ResolvedBinding::Resource(BindTarget { buffer: Some(slot), - texture: None, - sampler: None, - binding_array_size: None, - mutable: false, + ..Default::default() })), None if self.fake_missing_bindings => Ok(ResolvedBinding::User { prefix: "fake", diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 9f1d0e6877..5a9e39b876 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -25,6 +25,13 @@ const WRAPPED_ARRAY_FIELD: &str = "inner"; // Some more general handling of pointers is needed to be implemented here. const ATOMIC_REFERENCE: &str = "&"; +const RT_NAMESPACE: &str = "metal::raytracing"; +const RAY_QUERY_TYPE: &str = "_RayQuery"; +const RAY_QUERY_FIELD_INTERSECTOR: &str = "intersector"; +const RAY_QUERY_FIELD_INTERSECTION: &str = "intersection"; +const RAY_QUERY_FIELD_READY: &str = "ready"; +const RAY_QUERY_FUN_MAP_INTERSECTION: &str = "_map_intersection_type"; + /// Write the Metal name for a Naga numeric type: scalar, vector, or matrix. /// /// The `sizes` slice determines whether this function writes a @@ -194,8 +201,11 @@ impl<'a> Display for TypeContext<'a> { crate::TypeInner::Sampler { comparison: _ } => { write!(out, "{NAMESPACE}::sampler") } - crate::TypeInner::AccelerationStructure | crate::TypeInner::RayQuery => { - unreachable!("Ray queries are not supported yet"); + crate::TypeInner::AccelerationStructure => { + write!(out, "{RT_NAMESPACE}::instance_acceleration_structure") + } + crate::TypeInner::RayQuery => { + write!(out, "{RAY_QUERY_TYPE}") } crate::TypeInner::BindingArray { base, size } => { let base_tyname = Self { @@ -1863,8 +1873,39 @@ impl Writer { write!(self.out, ")")?; } } - // hot supported yet - crate::Expression::RayQueryGetIntersection { .. } => unreachable!(), + crate::Expression::RayQueryGetIntersection { query, committed } => { + if !committed { + unimplemented!() + } + let ty = context.module.special_types.ray_intersection.unwrap(); + let type_name = &self.names[&NameKey::Type(ty)]; + write!(self.out, "{type_name} {{{RAY_QUERY_FUN_MAP_INTERSECTION}(")?; + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.type)")?; + let fields = [ + "distance", + "user_instance_id", + "instance_id", + "", // SBT offset + "geometry_id", + "primitive_id", + "triangle_barycentric_coord", + "triangle_front_facing", + "", // padding + "object_to_world_transform", + "world_to_object_transform", + ]; + for field in fields { + write!(self.out, ", ")?; + if field.is_empty() { + write!(self.out, "{{}}")?; + } else { + self.put_expression(query, context, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.{field}")?; + } + } + write!(self.out, "}}")?; + } } Ok(()) } @@ -2323,6 +2364,7 @@ impl Writer { ) { use crate::Expression; self.need_bake_expressions.clear(); + for (expr_handle, expr) in func.expressions.iter() { // Expressions whose reference count is above the // threshold should always be stored in temporaries. @@ -2330,6 +2372,16 @@ impl Writer { let min_ref_count = func.expressions[expr_handle].bake_ref_count(); if min_ref_count <= expr_info.ref_count { self.need_bake_expressions.insert(expr_handle); + } else { + match expr_info.ty { + // force ray desc to be baked: it's used multiple times internally + TypeResolution::Handle(h) + if Some(h) == context.module.special_types.ray_desc => + { + self.need_bake_expressions.insert(expr_handle); + } + _ => {} + } } if let Expression::Math { fun, arg, arg1, .. } = *expr { @@ -2341,11 +2393,11 @@ impl Writer { // times, once for each component (see `put_dot_product`), so to // avoid duplicated evaluation, we must bake integer operands. - use crate::TypeInner; // check what kind of product this is depending // on the resolve type of the Dot function itself - let inner = context.resolve_type(expr_handle); - if let TypeInner::Scalar { kind, .. } = *inner { + if let crate::TypeInner::Scalar { kind, .. } = + *context.resolve_type(expr_handle) + { match kind { crate::ScalarKind::Sint | crate::ScalarKind::Uint => { self.need_bake_expressions.insert(arg); @@ -2770,7 +2822,100 @@ impl Writer { // done writeln!(self.out, ";")?; } - crate::Statement::RayQuery { .. } => unreachable!(), + crate::Statement::RayQuery { query, ref fun } => { + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + //TODO: how to deal with winding? + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.assume_geometry_type({RT_NAMESPACE}::geometry_type::triangle);")?; + { + let f_opaque = back::RayFlag::CULL_OPAQUE.bits(); + let f_no_opaque = back::RayFlag::CULL_NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.set_opacity_cull_mode((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::opacity_cull_mode::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::opacity_cull_mode::none);")?; + } + { + let f_opaque = back::RayFlag::OPAQUE.bits(); + let f_no_opaque = back::RayFlag::NO_OPAQUE.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTOR}.force_opacity((")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::opaque : (")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".flags & {f_no_opaque}) != 0 ? {RT_NAMESPACE}::forced_opacity::non_opaque : ")?; + writeln!(self.out, "{RT_NAMESPACE}::forced_opacity::none);")?; + } + { + let flag = back::RayFlag::TERMINATE_ON_FIRST_HIT.bits(); + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.accept_any_intersection((" + )?; + self.put_expression(descriptor, &context.expression, true)?; + writeln!(self.out, ".flags & {flag}) != 0);")?; + } + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + write!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION} = ")?; + self.put_expression(query, &context.expression, true)?; + write!( + self.out, + ".{RAY_QUERY_FIELD_INTERSECTOR}.intersect({RT_NAMESPACE}::ray(" + )?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".origin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".dir, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmin, ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".tmax), ")?; + self.put_expression(acceleration_structure, &context.expression, true)?; + write!(self.out, ", ")?; + self.put_expression(descriptor, &context.expression, true)?; + write!(self.out, ".cull_mask);")?; + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = true;")?; + } + crate::RayQueryFunction::Proceed { result } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY};")?; + //TODO: actually proceed? + + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_READY} = false;")?; + } + crate::RayQueryFunction::Terminate => { + write!(self.out, "{level}")?; + self.put_expression(query, &context.expression, true)?; + writeln!(self.out, ".{RAY_QUERY_FIELD_INTERSECTION}.abort();")?; + } + } + } } } @@ -2882,14 +3027,41 @@ impl Writer { writeln!(self.out)?; // Work around Metal bug where `uint` is not available by default writeln!(self.out, "using {NAMESPACE}::uint;")?; - writeln!(self.out)?; + if module.types.iter().any(|(_, t)| match t.inner { + crate::TypeInner::RayQuery => true, + _ => false, + }) { + let tab = back::INDENT; + writeln!(self.out, "struct {RAY_QUERY_TYPE} {{")?; + let full_type = format!("{RT_NAMESPACE}::intersector<{RT_NAMESPACE}::instancing, {RT_NAMESPACE}::triangle_data, {RT_NAMESPACE}::world_space_data>"); + writeln!(self.out, "{tab}{full_type} {RAY_QUERY_FIELD_INTERSECTOR};")?; + writeln!( + self.out, + "{tab}{full_type}::result_type {RAY_QUERY_FIELD_INTERSECTION};" + )?; + writeln!(self.out, "{tab}bool {RAY_QUERY_FIELD_READY} = false;")?; + writeln!(self.out, "}};")?; + writeln!(self.out, "constexpr {NAMESPACE}::uint {RAY_QUERY_FUN_MAP_INTERSECTION}(const {RT_NAMESPACE}::intersection_type ty) {{")?; + let v_triangle = back::RayIntersectionType::Triangle as u32; + let v_bbox = back::RayIntersectionType::BoundingBox as u32; + writeln!( + self.out, + "{tab}return ty=={RT_NAMESPACE}::intersection_type::triangle ? {v_triangle} : " + )?; + writeln!( + self.out, + "{tab}{tab}ty=={RT_NAMESPACE}::intersection_type::bounding_box ? {v_bbox} : 0;" + )?; + writeln!(self.out, "}}")?; + } if options .bounds_check_policies .contains(index::BoundsCheckPolicy::ReadZeroSkipWrite) { self.put_default_constructible()?; } + writeln!(self.out)?; { let mut indices = vec![]; @@ -2931,11 +3103,12 @@ impl Writer { /// /// [`ReadZeroSkipWrite`]: index::BoundsCheckPolicy::ReadZeroSkipWrite fn put_default_constructible(&mut self) -> BackendResult { + let tab = back::INDENT; writeln!(self.out, "struct DefaultConstructible {{")?; - writeln!(self.out, " template")?; - writeln!(self.out, " operator T() && {{")?; - writeln!(self.out, " return T {{}};")?; - writeln!(self.out, " }}")?; + writeln!(self.out, "{tab}template")?; + writeln!(self.out, "{tab}operator T() && {{")?; + writeln!(self.out, "{tab}{tab}return T {{}};")?; + writeln!(self.out, "{tab}}}")?; writeln!(self.out, "}};")?; Ok(()) } diff --git a/tests/in/ray-query.param.ron b/tests/in/ray-query.param.ron index 9d8666954d..64d2f1b868 100644 --- a/tests/in/ray-query.param.ron +++ b/tests/in/ray-query.param.ron @@ -3,4 +3,12 @@ spv: ( version: (1, 4), ), + msl: ( + lang_version: (2, 4), + spirv_cross_compatibility: false, + fake_missing_bindings: true, + zero_initialize_workgroup_memory: false, + per_stage_map: (), + inline_samplers: [], + ), ) diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 5eabf3a2d3..b755d8f60a 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -2,8 +2,17 @@ var acc_struct: acceleration_structure; /* -let RAY_FLAG_NONE = 0u; -let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 4u; +let RAY_FLAG_NONE = 0x00u; +let RAY_FLAG_OPAQUE = 0x01u; +let RAY_FLAG_NO_OPAQUE = 0x02u; +let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u; +let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u; +let RAY_FLAG_CULL_FRONT_FACING = 0x10u; +let RAY_FLAG_CULL_BACK_FACING = 0x20u; +let RAY_FLAG_CULL_OPAQUE = 0x40u; +let RAY_FLAG_CULL_NO_OPAQUE = 0x80u; +let RAY_FLAG_SKIP_TRIANGLES = 0x100u; +let RAY_FLAG_SKIP_AABBS = 0x200u; let RAY_QUERY_INTERSECTION_NONE = 0u; let RAY_QUERY_INTERSECTION_TRIANGLE = 1u; diff --git a/tests/out/msl/binding-arrays.msl b/tests/out/msl/binding-arrays.msl index da1078b5d8..694f79452d 100644 --- a/tests/out/msl/binding-arrays.msl +++ b/tests/out/msl/binding-arrays.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct UniformIndex { uint index; }; diff --git a/tests/out/msl/bounds-check-image-rzsw.msl b/tests/out/msl/bounds-check-image-rzsw.msl index 9032af14ca..eeb03c9849 100644 --- a/tests/out/msl/bounds-check-image-rzsw.msl +++ b/tests/out/msl/bounds-check-image-rzsw.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + constant metal::int2 const_type_4_ = {0, 0}; constant metal::int3 const_type_7_ = {0, 0, 0}; constant metal::float4 const_type_2_ = {0.0, 0.0, 0.0, 0.0}; diff --git a/tests/out/msl/bounds-check-zero-atomic.msl b/tests/out/msl/bounds-check-zero-atomic.msl index 95028ee796..daaa079233 100644 --- a/tests/out/msl/bounds-check-zero-atomic.msl +++ b/tests/out/msl/bounds-check-zero-atomic.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct _mslBufferSizes { uint size0; }; diff --git a/tests/out/msl/bounds-check-zero.msl b/tests/out/msl/bounds-check-zero.msl index fece92de35..816983d98b 100644 --- a/tests/out/msl/bounds-check-zero.msl +++ b/tests/out/msl/bounds-check-zero.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct _mslBufferSizes { uint size0; }; diff --git a/tests/out/msl/policy-mix.msl b/tests/out/msl/policy-mix.msl index 842c57e58c..7eb0c61ede 100644 --- a/tests/out/msl/policy-mix.msl +++ b/tests/out/msl/policy-mix.msl @@ -3,13 +3,13 @@ #include using metal::uint; - struct DefaultConstructible { template operator T() && { return T {}; } }; + struct type_1 { metal::float4 inner[10]; }; diff --git a/tests/out/msl/ray-query.msl b/tests/out/msl/ray-query.msl new file mode 100644 index 0000000000..2a09737873 --- /dev/null +++ b/tests/out/msl/ray-query.msl @@ -0,0 +1,58 @@ +// language: metal2.4 +#include +#include + +using metal::uint; +struct _RayQuery { + metal::raytracing::intersector intersector; + metal::raytracing::intersector::result_type intersection; + bool ready = false; +}; +constexpr metal::uint _map_intersection_type(const metal::raytracing::intersection_type ty) { + return ty==metal::raytracing::intersection_type::triangle ? 1 : + ty==metal::raytracing::intersection_type::bounding_box ? 4 : 0; +} + +struct Output { + uint visible_; +}; +struct RayDesc { + uint flags; + uint cull_mask; + float tmin; + float tmax; + metal::float3 origin; + metal::float3 dir; +}; +struct RayIntersection { + uint kind; + float t; + uint instance_custom_index; + uint instance_id; + uint sbt_record_offset; + uint geometry_index; + uint primitive_index; + metal::float2 barycentrics; + bool front_face; + char _pad9[11]; + metal::float4x3 object_to_world; + metal::float4x3 world_to_object; +}; + +kernel void main_( + metal::raytracing::instance_acceleration_structure acc_struct [[user(fake0)]] +, device Output& output [[user(fake0)]] +) { + _RayQuery rq = {}; + RayDesc _e12 = RayDesc {4u, 255u, 0.10000000149011612, 100.0, metal::float3(0.0), metal::float3(0.0, 1.0, 0.0)}; + rq.intersector.assume_geometry_type(metal::raytracing::geometry_type::triangle); + rq.intersector.set_opacity_cull_mode((_e12.flags & 64) != 0 ? metal::raytracing::opacity_cull_mode::opaque : (_e12.flags & 128) != 0 ? metal::raytracing::opacity_cull_mode::non_opaque : metal::raytracing::opacity_cull_mode::none); + rq.intersector.force_opacity((_e12.flags & 1) != 0 ? metal::raytracing::forced_opacity::opaque : (_e12.flags & 2) != 0 ? metal::raytracing::forced_opacity::non_opaque : metal::raytracing::forced_opacity::none); + rq.intersector.accept_any_intersection((_e12.flags & 4) != 0); + rq.intersection = rq.intersector.intersect(metal::raytracing::ray(_e12.origin, _e12.dir, _e12.tmin, _e12.tmax), acc_struct, _e12.cull_mask); rq.ready = true; + bool _e13 = rq.ready; + rq.ready = false; + RayIntersection intersection = RayIntersection {_map_intersection_type(rq.intersection.type), rq.intersection.distance, rq.intersection.user_instance_id, rq.intersection.instance_id, {}, rq.intersection.geometry_id, rq.intersection.primitive_id, rq.intersection.triangle_barycentric_coord, rq.intersection.triangle_front_facing, {}, rq.intersection.object_to_world_transform, rq.intersection.world_to_object_transform}; + output.visible_ = static_cast(intersection.kind == 0u); + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 0b3c1f0ce5..9fe7d46c37 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -559,7 +559,7 @@ fn convert_wgsl() { ("sprite", Targets::SPIRV), ("force_point_size_vertex_shader_webgl", Targets::GLSL), ("invariant", Targets::GLSL), - ("ray-query", Targets::SPIRV), + ("ray-query", Targets::SPIRV | Targets::METAL), ]; for &(name, targets) in inputs.iter() {