From 40f2aeb03924071f0141e008172353c792e8be9b Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 21 Mar 2023 22:47:53 -0700 Subject: [PATCH] Address Jim's review notes, use typegen module for atomic struct --- src/back/dot/mod.rs | 28 +++++++----- src/back/mod.rs | 4 +- src/back/spv/writer.rs | 2 +- src/front/type_gen.rs | 49 +++++++++++++++++++++ src/front/wgsl/error.rs | 6 +++ src/front/wgsl/lower/mod.rs | 53 ++++------------------- src/lib.rs | 6 +-- tests/in/ray-query.wgsl | 4 +- tests/out/wgsl/atomicCompareExchange.wgsl | 4 +- 9 files changed, 90 insertions(+), 66 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index d293c3adc1..1eebbee067 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -254,19 +254,25 @@ impl StatementGraph { } S::RayQuery { query, ref fun } => { self.dependencies.push((id, query, "query")); - if let crate::RayQueryFunction::Initialize { - acceleration_structure, - descriptor, - } = *fun - { - self.dependencies.push(( - id, + match *fun { + crate::RayQueryFunction::Initialize { acceleration_structure, - "acceleration_structure", - )); - self.dependencies.push((id, descriptor, "descriptor")); + descriptor, + } => { + self.dependencies.push(( + id, + acceleration_structure, + "acceleration_structure", + )); + self.dependencies.push((id, descriptor, "descriptor")); + "RayQueryInitialize" + } + crate::RayQueryFunction::Proceed { result } => { + self.emits.push((id, result)); + "RayQueryProceed" + } + crate::RayQueryFunction::Terminate => "RayQueryTerminate", } - "RayQuery" } }; // Set the last node to the merge node diff --git a/src/back/mod.rs b/src/back/mod.rs index f51262524c..8467ee787b 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -234,8 +234,8 @@ bitflags::bitflags! { const NO_OPAQUE = 0x02; const TERMINATE_ON_FIRST_HIT = 0x04; const SKIP_CLOSEST_HIT_SHADER = 0x08; - const CULL_FRONT_FACING = 0x10; - const CULL_BACK_FACING = 0x20; + const CULL_BACK_FACING = 0x10; + const CULL_FRONT_FACING = 0x20; const CULL_OPAQUE = 0x40; const CULL_NO_OPAQUE = 0x80; const SKIP_TRIANGLES = 0x100; diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 800a40ed68..ba235e6d03 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -973,7 +973,7 @@ impl Writer { self.write_type_declaration_local(id, local); - // If it's an type that needs SPIR-V capabilities, request them now, + // If it's a type that needs SPIR-V capabilities, request them now, // so write_type_declaration_local can stay infallible. self.request_type_capabilities(&ty.inner)?; diff --git a/src/front/type_gen.rs b/src/front/type_gen.rs index b695b52792..3cdaa33268 100644 --- a/src/front/type_gen.rs +++ b/src/front/type_gen.rs @@ -5,6 +5,55 @@ Type generators. use crate::{arena::Handle, span::Span}; impl crate::Module { + pub(super) fn generate_atomic_compare_exchange_result( + &mut self, + kind: crate::ScalarKind, + width: crate::Bytes, + ) -> Handle { + let bool_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + }, + }, + Span::UNDEFINED, + ); + let scalar_ty = self.types.insert( + crate::Type { + name: None, + inner: crate::TypeInner::Scalar { kind, width }, + }, + Span::UNDEFINED, + ); + + self.types.insert( + crate::Type { + name: Some(format!( + "__atomic_compare_exchange_result<{kind:?},{width}>" + )), + inner: crate::TypeInner::Struct { + members: vec![ + crate::StructMember { + name: Some("old_value".to_string()), + ty: scalar_ty, + binding: None, + offset: 0, + }, + crate::StructMember { + name: Some("exchanged".to_string()), + ty: bool_ty, + binding: None, + offset: 4, + }, + ], + span: 8, + }, + }, + Span::UNDEFINED, + ) + } /// Populate this module's [`SpecialTypes::ray_desc`] type. /// /// [`SpecialTypes::ray_desc`] is the type of the [`descriptor`] operand of diff --git a/src/front/wgsl/error.rs b/src/front/wgsl/error.rs index a4e6540237..2e71a76624 100644 --- a/src/front/wgsl/error.rs +++ b/src/front/wgsl/error.rs @@ -188,6 +188,7 @@ pub enum Error<'a> { MissingAttribute(&'static str, Span), InvalidAtomicPointer(Span), InvalidAtomicOperandType(Span), + InvalidRayQueryPointer(Span), Pointer(&'static str, Span), NotPointer(Span), NotReference(&'static str, Span), @@ -526,6 +527,11 @@ impl<'a> Error<'a> { labels: vec![(span, "atomic operand type is invalid".into())], notes: vec![], }, + Error::InvalidRayQueryPointer(span) => ParseError { + message: "ray query operation is done on a pointer to a non-ray-query".to_string(), + labels: vec![(span, "ray query pointer is invalid".into())], + notes: vec![], + }, Error::NotPointer(span) => ParseError { message: "the operand of the `*` operator must be a pointer".to_string(), labels: vec![(span, "expression is not a pointer".into())], diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 8df3460b2e..314eea52ec 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -641,6 +641,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let span = tu.decls.get_span(decl_handle); let decl = &tu.decls[decl_handle]; + //NOTE: This is done separately from `resolve_ast_type` because `RayDesc` may be + // first encountered in a local constructor invocation. //TODO: find a nicer way? if let Some(dep) = decl.dependencies.iter().find(|dep| dep.ident == "RayDesc") { let ty_handle = ctx.module.generate_ray_desc_type(); @@ -1733,50 +1735,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let expression = match *ctx.resolved_inner(value) { crate::TypeInner::Scalar { kind, width } => { - let bool_ty = ctx.module.types.insert( - crate::Type { - name: None, - inner: crate::TypeInner::Scalar { - kind: crate::ScalarKind::Bool, - width: crate::BOOL_WIDTH, - }, - }, - Span::UNDEFINED, - ); - let scalar_ty = ctx.module.types.insert( - crate::Type { - name: None, - inner: crate::TypeInner::Scalar { kind, width }, - }, - Span::UNDEFINED, - ); - let struct_ty = ctx.module.types.insert( - crate::Type { - name: Some( - "__atomic_compare_exchange_result".to_string(), - ), - inner: crate::TypeInner::Struct { - members: vec![ - crate::StructMember { - name: Some("old_value".to_string()), - ty: scalar_ty, - binding: None, - offset: 0, - }, - crate::StructMember { - name: Some("exchanged".to_string()), - ty: bool_ty, - binding: None, - offset: 4, - }, - ], - span: 8, - }, - }, - Span::UNDEFINED, - ); crate::Expression::AtomicResult { - ty: struct_ty, + //TODO: cache this to avoid generating duplicate types + ty: ctx + .module + .generate_atomic_compare_exchange_result(kind, width), comparison: true, } } @@ -2449,12 +2412,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { crate::TypeInner::RayQuery => Ok(pointer), ref other => { log::error!("Pointer type to {:?} passed to ray query op", other); - Err(Error::InvalidAtomicPointer(span)) + Err(Error::InvalidRayQueryPointer(span)) } }, ref other => { log::error!("Type {:?} passed to ray query op", other); - Err(Error::InvalidAtomicPointer(span)) + Err(Error::InvalidRayQueryPointer(span)) } } } diff --git a/src/lib.rs b/src/lib.rs index f91dc7ce56..a70015d16d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1460,7 +1460,7 @@ pub enum Expression { /// Return an intersection found by `query`. /// - /// If `committed` is true, return the committed result available when + /// If `committed` is true, return the committed result available when RayQueryGetIntersection { query: Handle, committed: bool, @@ -1848,13 +1848,13 @@ pub struct SpecialTypes { /// /// Call [`Module::generate_ray_desc_type`] to populate this if /// needed and return the handle. - ray_desc: Option>, + pub ray_desc: Option>, /// Type for `RayIntersection`. /// /// Call [`Module::generate_ray_intersection_type`] to populate /// this if needed and return the handle. - ray_intersection: Option>, + pub ray_intersection: Option>, } /// Shader module. diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 1a9c967490..4826547ded 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -7,8 +7,8 @@ let RAY_FLAG_OPAQUE = 0x01u; let RAY_FLAG_NO_OPAQUE = 0x02u; let RAY_FLAG_TERMINATE_ON_FIRST_HIT = 0x04u; let RAY_FLAG_SKIP_CLOSEST_HIT_SHADER = 0x08u; -let RAY_FLAG_CULL_FRONT_FACING = 0x10u; -let RAY_FLAG_CULL_BACK_FACING = 0x20u; +let RAY_FLAG_CULL_BACK_FACING = 0x10u; +let RAY_FLAG_CULL_FRONT_FACING = 0x20u; let RAY_FLAG_CULL_OPAQUE = 0x40u; let RAY_FLAG_CULL_NO_OPAQUE = 0x80u; let RAY_FLAG_SKIP_TRIANGLES = 0x100u; diff --git a/tests/out/wgsl/atomicCompareExchange.wgsl b/tests/out/wgsl/atomicCompareExchange.wgsl index 2c213c8fec..bfad298fab 100644 --- a/tests/out/wgsl/atomicCompareExchange.wgsl +++ b/tests/out/wgsl/atomicCompareExchange.wgsl @@ -1,9 +1,9 @@ -struct gen___atomic_compare_exchange_result { +struct gen___atomic_compare_exchange_resultSint4_ { old_value: i32, exchanged: bool, } -struct gen___atomic_compare_exchange_result_1 { +struct gen___atomic_compare_exchange_resultUint4_ { old_value: u32, exchanged: bool, }