From 8c5cabe442c62e56504cf18356b3d702720db9c5 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 17 Mar 2023 22:53:13 -0700 Subject: [PATCH] ray query: validation, better test --- src/valid/expression.rs | 23 +++- src/valid/function.rs | 95 +++++++++++--- tests/in/ray-query.wgsl | 14 +- tests/out/spv/ray-query.spvasm | 229 ++++++++++++++++++++------------- 4 files changed, 250 insertions(+), 111 deletions(-) diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 43fe1cc6c6..b923b88532 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -35,6 +35,8 @@ pub enum ExpressionError { InvalidPointerType(Handle), #[error("Array length of {0:?} can't be done")] InvalidArrayType(Handle), + #[error("Get intersection of {0:?} can't be done")] + InvalidRayQueryType(Handle), #[error("Splatting {0:?} can't be done")] InvalidSplatType(Handle), #[error("Swizzling {0:?} can't be done")] @@ -1426,7 +1428,26 @@ impl super::Validator { return Err(ExpressionError::InvalidArrayType(expr)); } }, - E::RayQueryProceedResult | E::RayQueryGetIntersection { .. } => ShaderStages::all(), + E::RayQueryProceedResult => ShaderStages::all(), + E::RayQueryGetIntersection { + query, + committed: _, + } => match resolver[query] { + Ti::Pointer { + base, + space: crate::AddressSpace::Function, + } => match resolver.types[base].inner { + Ti::RayQuery => ShaderStages::all(), + ref other => { + log::error!("Intersection result of a pointer to {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, + ref other => { + log::error!("Intersection result of {:?}", other); + return Err(ExpressionError::InvalidRayQueryType(query)); + } + }, }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 037464bd0c..d96f2c3b0b 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -47,8 +47,6 @@ pub enum AtomicError { InvalidPointer(Handle), #[error("Operand {0:?} has invalid type.")] InvalidOperand(Handle), - #[error("Result expression {0:?} has already been introduced earlier")] - ResultAlreadyInScope(Handle), #[error("Result type for {0:?} doesn't match the statement")] ResultTypeMismatch(Handle), } @@ -129,6 +127,14 @@ pub enum FunctionError { }, #[error("Atomic operation is invalid")] InvalidAtomic(#[from] AtomicError), + #[error("Ray Query {0:?} is not a local variable")] + InvalidRayQueryExpression(Handle), + #[error("Acceleration structure {0:?} is not a matching expression")] + InvalidAccelerationStructure(Handle), + #[error("Ray descriptor {0:?} is not a matching expression")] + InvalidRayDescriptor(Handle), + #[error("Ray Query {0:?} does not have a matching type")] + InvalidRayQueryType(Handle), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -167,8 +173,10 @@ struct BlockContext<'a> { info: &'a FunctionInfo, expressions: &'a Arena, types: &'a UniqueArena, + local_vars: &'a Arena, global_vars: &'a Arena, functions: &'a Arena, + special_types: &'a crate::SpecialTypes, prev_infos: &'a [FunctionInfo], return_type: Option>, } @@ -186,8 +194,10 @@ impl<'a> BlockContext<'a> { info, expressions: &fun.expressions, types: &module.types, + local_vars: &fun.local_variables, global_vars: &module.global_variables, functions: &module.functions, + special_types: &module.special_types, prev_infos, return_type: fun.result.as_ref().map(|fr| fr.ty), } @@ -297,6 +307,21 @@ impl super::Validator { Ok(callee_info.available_stages) } + #[cfg(feature = "validate")] + fn emit_expression( + &mut self, + handle: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + if self.valid_expression_set.insert(handle.index()) { + self.valid_expression_list.push(handle); + Ok(()) + } else { + Err(FunctionError::ExpressionAlreadyInScope(handle) + .with_span_handle(handle, context.expressions)) + } + } + #[cfg(feature = "validate")] fn validate_atomic( &mut self, @@ -345,13 +370,7 @@ impl super::Validator { } } - if self.valid_expression_set.insert(result.index()) { - self.valid_expression_list.push(result); - } else { - return Err(AtomicError::ResultAlreadyInScope(result) - .with_span_handle(result, context.expressions) - .into_other()); - } + self.emit_expression(result, context)?; match context.expressions[result] { crate::Expression::AtomicResult { ty, comparison } if { @@ -399,12 +418,7 @@ impl super::Validator { match *statement { S::Emit(ref range) => { for handle in range.clone() { - if self.valid_expression_set.insert(handle.index()) { - self.valid_expression_list.push(handle); - } else { - return Err(FunctionError::ExpressionAlreadyInScope(handle) - .with_span_handle(handle, context.expressions)); - } + self.emit_expression(handle, context)?; } } S::Block(ref block) => { @@ -801,8 +815,55 @@ impl super::Validator { } => { self.validate_atomic(pointer, fun, value, result, context)?; } - S::RayQuery { query: _, fun: _ } => { - //TODO + S::RayQuery { query, ref fun } => { + let query_var = match *context.get_expression(query) { + crate::Expression::LocalVariable(var) => &context.local_vars[var], + ref other => { + log::error!("Unexpected ray query expression {other:?}"); + return Err(FunctionError::InvalidRayQueryExpression(query) + .with_span_static(span, "invalid query expression")); + } + }; + match context.types[query_var.ty].inner { + Ti::RayQuery => {} + ref other => { + log::error!("Unexpected ray query type {other:?}"); + return Err(FunctionError::InvalidRayQueryType(query_var.ty) + .with_span_static(span, "invalid query type")); + } + } + match *fun { + crate::RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + match *context + .resolve_type(acceleration_structure, &self.valid_expression_set)? + { + Ti::AccelerationStructure => {} + _ => { + return Err(FunctionError::InvalidAccelerationStructure( + acceleration_structure, + ) + .with_span_static(span, "invalid acceleration structure")) + } + } + let desc_ty_given = + context.resolve_type(descriptor, &self.valid_expression_set)?; + let desc_ty_expected = context + .special_types + .ray_desc + .map(|handle| &context.types[handle].inner); + if Some(desc_ty_given) != desc_ty_expected { + return Err(FunctionError::InvalidRayDescriptor(descriptor) + .with_span_static(span, "invalid ray descriptor")); + } + } + crate::RayQueryFunction::Proceed { result } => { + self.emit_expression(result, context)?; + } + crate::RayQueryFunction::Terminate => {} + } } } } diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl index 5eabf3a2d3..dafd170f81 100644 --- a/tests/in/ray-query.wgsl +++ b/tests/in/ray-query.wgsl @@ -36,19 +36,29 @@ struct RayIntersection { struct Output { visible: u32, + normal: vec3, } @group(0) @binding(1) var output: Output; +fn get_torus_normal(world_point: vec3, intersection: RayIntersection) -> vec3 { + let local_point = intersection.world_to_object * vec4(world_point, 1.0); + let point_on_guiding_line = normalize(local_point.xy) * 2.4; + let world_point_on_guiding_line = intersection.object_to_world * vec4(point_on_guiding_line, 0.0, 1.0); + return normalize(world_point - world_point_on_guiding_line); +} + @compute @workgroup_size(1) fn main() { var rq: ray_query; - rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), vec3(0.0, 1.0, 0.0))); + let dir = vec3(0.0, 1.0, 0.0); + rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3(0.0), dir)); - rayQueryProceed(&rq); + while (rayQueryProceed(&rq)) {} let intersection = rayQueryGetCommittedIntersection(&rq); output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE); + output.normal = get_torus_normal(dir * intersection.t, intersection); } diff --git a/tests/out/spv/ray-query.spvasm b/tests/out/spv/ray-query.spvasm index 1a1a18bba1..306cda758c 100644 --- a/tests/out/spv/ray-query.spvasm +++ b/tests/out/spv/ray-query.spvasm @@ -1,105 +1,152 @@ ; SPIR-V ; Version: 1.4 ; Generator: rspirv -; Bound: 63 +; Bound: 95 OpCapability RayQueryKHR OpCapability Shader OpExtension "SPV_KHR_ray_query" %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %29 "main" %21 %23 -OpExecutionMode %29 LocalSize 1 1 1 -OpMemberDecorate %13 0 Offset 0 +OpEntryPoint GLCompute %48 "main" %23 %25 +OpExecutionMode %48 LocalSize 1 1 1 OpMemberDecorate %15 0 Offset 0 -OpMemberDecorate %15 1 Offset 4 -OpMemberDecorate %15 2 Offset 8 -OpMemberDecorate %15 3 Offset 12 -OpMemberDecorate %15 4 Offset 16 -OpMemberDecorate %15 5 Offset 32 -OpMemberDecorate %20 0 Offset 0 -OpMemberDecorate %20 1 Offset 4 -OpMemberDecorate %20 2 Offset 8 -OpMemberDecorate %20 3 Offset 12 -OpMemberDecorate %20 4 Offset 16 -OpMemberDecorate %20 5 Offset 20 -OpMemberDecorate %20 6 Offset 24 -OpMemberDecorate %20 7 Offset 28 -OpMemberDecorate %20 8 Offset 36 -OpMemberDecorate %20 9 Offset 48 -OpMemberDecorate %20 9 ColMajor -OpMemberDecorate %20 9 MatrixStride 16 -OpMemberDecorate %20 10 Offset 112 -OpMemberDecorate %20 10 ColMajor -OpMemberDecorate %20 10 MatrixStride 16 -OpDecorate %21 DescriptorSet 0 -OpDecorate %21 Binding 0 +OpMemberDecorate %15 1 Offset 16 +OpMemberDecorate %19 0 Offset 0 +OpMemberDecorate %19 1 Offset 4 +OpMemberDecorate %19 2 Offset 8 +OpMemberDecorate %19 3 Offset 12 +OpMemberDecorate %19 4 Offset 16 +OpMemberDecorate %19 5 Offset 20 +OpMemberDecorate %19 6 Offset 24 +OpMemberDecorate %19 7 Offset 28 +OpMemberDecorate %19 8 Offset 36 +OpMemberDecorate %19 9 Offset 48 +OpMemberDecorate %19 9 ColMajor +OpMemberDecorate %19 9 MatrixStride 16 +OpMemberDecorate %19 10 Offset 112 +OpMemberDecorate %19 10 ColMajor +OpMemberDecorate %19 10 MatrixStride 16 +OpMemberDecorate %21 0 Offset 0 +OpMemberDecorate %21 1 Offset 4 +OpMemberDecorate %21 2 Offset 8 +OpMemberDecorate %21 3 Offset 12 +OpMemberDecorate %21 4 Offset 16 +OpMemberDecorate %21 5 Offset 32 OpDecorate %23 DescriptorSet 0 -OpDecorate %23 Binding 1 -OpDecorate %24 Block -OpMemberDecorate %24 0 Offset 0 +OpDecorate %23 Binding 0 +OpDecorate %25 DescriptorSet 0 +OpDecorate %25 Binding 1 +OpDecorate %26 Block +OpMemberDecorate %26 0 Offset 0 %2 = OpTypeVoid -%4 = OpTypeInt 32 0 -%3 = OpConstant %4 4 -%5 = OpConstant %4 255 -%7 = OpTypeFloat 32 -%6 = OpConstant %7 0.1 -%8 = OpConstant %7 100.0 -%9 = OpConstant %7 0.0 -%10 = OpConstant %7 1.0 -%11 = OpConstant %4 0 -%12 = OpTypeAccelerationStructureNV -%13 = OpTypeStruct %4 -%14 = OpTypeVector %7 3 -%15 = OpTypeStruct %4 %4 %7 %7 %14 %14 -%16 = OpTypeRayQueryKHR -%17 = OpTypeVector %7 2 -%18 = OpTypeBool -%19 = OpTypeMatrix %14 4 -%20 = OpTypeStruct %4 %7 %4 %4 %4 %4 %4 %17 %18 %19 %19 -%22 = OpTypePointer UniformConstant %12 -%21 = OpVariable %22 UniformConstant -%24 = OpTypeStruct %13 -%25 = OpTypePointer StorageBuffer %24 -%23 = OpVariable %25 StorageBuffer -%27 = OpTypePointer Function %16 -%30 = OpTypeFunction %2 -%32 = OpTypePointer StorageBuffer %13 -%45 = OpConstant %4 1 -%58 = OpTypePointer StorageBuffer %4 -%29 = OpFunction %2 None %30 +%4 = OpTypeFloat 32 +%3 = OpConstant %4 1.0 +%5 = OpConstant %4 2.4 +%6 = OpConstant %4 0.0 +%8 = OpTypeInt 32 0 +%7 = OpConstant %8 4 +%9 = OpConstant %8 255 +%10 = OpConstant %4 0.1 +%11 = OpConstant %4 100.0 +%12 = OpConstant %8 0 +%13 = OpTypeAccelerationStructureNV +%14 = OpTypeVector %4 3 +%15 = OpTypeStruct %8 %14 +%16 = OpTypeVector %4 2 +%17 = OpTypeBool +%18 = OpTypeMatrix %14 4 +%19 = OpTypeStruct %8 %4 %8 %8 %8 %8 %8 %16 %17 %18 %18 +%20 = OpTypeVector %4 4 +%21 = OpTypeStruct %8 %8 %4 %4 %14 %14 +%22 = OpTypeRayQueryKHR +%24 = OpTypePointer UniformConstant %13 +%23 = OpVariable %24 UniformConstant +%26 = OpTypeStruct %15 +%27 = OpTypePointer StorageBuffer %26 +%25 = OpVariable %27 StorageBuffer +%32 = OpTypeFunction %14 %14 %19 +%46 = OpTypePointer Function %22 +%49 = OpTypeFunction %2 +%51 = OpTypePointer StorageBuffer %15 +%72 = OpConstant %8 1 +%85 = OpTypePointer StorageBuffer %8 +%90 = OpTypePointer StorageBuffer %14 +%31 = OpFunction %14 None %32 +%29 = OpFunctionParameter %14 +%30 = OpFunctionParameter %19 %28 = OpLabel -%26 = OpVariable %27 Function -%31 = OpLoad %12 %21 -%33 = OpAccessChain %32 %23 %11 -OpBranch %34 -%34 = OpLabel -%35 = OpCompositeConstruct %14 %9 %9 %9 -%36 = OpCompositeConstruct %14 %9 %10 %9 -%37 = OpCompositeConstruct %15 %3 %5 %6 %8 %35 %36 -%38 = OpCompositeExtract %4 %37 0 -%39 = OpCompositeExtract %4 %37 1 -%40 = OpCompositeExtract %7 %37 2 -%41 = OpCompositeExtract %7 %37 3 -%42 = OpCompositeExtract %14 %37 4 -%43 = OpCompositeExtract %14 %37 5 -OpRayQueryInitializeKHR %26 %31 %38 %39 %42 %40 %43 %41 -%44 = OpRayQueryProceedKHR %18 %26 -%46 = OpRayQueryGetIntersectionTypeKHR %4 %26 %45 -%47 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %4 %26 %45 -%48 = OpRayQueryGetIntersectionInstanceIdKHR %4 %26 %45 -%49 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %4 %26 %45 -%50 = OpRayQueryGetIntersectionGeometryIndexKHR %4 %26 %45 -%51 = OpRayQueryGetIntersectionPrimitiveIndexKHR %4 %26 %45 -%52 = OpRayQueryGetIntersectionTKHR %7 %26 %45 -%53 = OpRayQueryGetIntersectionBarycentricsKHR %17 %26 %45 -%54 = OpRayQueryGetIntersectionFrontFaceKHR %18 %26 %45 -%55 = OpRayQueryGetIntersectionObjectToWorldKHR %19 %26 %45 -%56 = OpRayQueryGetIntersectionWorldToObjectKHR %19 %26 %45 -%57 = OpCompositeConstruct %20 %46 %52 %47 %48 %49 %50 %51 %53 %54 %55 %56 -%59 = OpCompositeExtract %4 %57 0 -%60 = OpIEqual %18 %59 %11 -%61 = OpSelect %4 %60 %45 %11 -%62 = OpAccessChain %58 %33 %11 -OpStore %62 %61 +OpBranch %33 +%33 = OpLabel +%34 = OpCompositeExtract %18 %30 10 +%35 = OpCompositeConstruct %20 %29 %3 +%36 = OpMatrixTimesVector %14 %34 %35 +%37 = OpVectorShuffle %16 %36 %36 0 1 +%38 = OpExtInst %16 %1 Normalize %37 +%39 = OpVectorTimesScalar %16 %38 %5 +%40 = OpCompositeExtract %18 %30 9 +%41 = OpCompositeConstruct %20 %39 %6 %3 +%42 = OpMatrixTimesVector %14 %40 %41 +%43 = OpFSub %14 %29 %42 +%44 = OpExtInst %14 %1 Normalize %43 +OpReturnValue %44 +OpFunctionEnd +%48 = OpFunction %2 None %49 +%47 = OpLabel +%45 = OpVariable %46 Function +%50 = OpLoad %13 %23 +%52 = OpAccessChain %51 %25 %12 +OpBranch %53 +%53 = OpLabel +%54 = OpCompositeConstruct %14 %6 %3 %6 +%55 = OpCompositeConstruct %14 %6 %6 %6 +%56 = OpCompositeConstruct %21 %7 %9 %10 %11 %55 %54 +%57 = OpCompositeExtract %8 %56 0 +%58 = OpCompositeExtract %8 %56 1 +%59 = OpCompositeExtract %4 %56 2 +%60 = OpCompositeExtract %4 %56 3 +%61 = OpCompositeExtract %14 %56 4 +%62 = OpCompositeExtract %14 %56 5 +OpRayQueryInitializeKHR %45 %50 %57 %58 %61 %59 %62 %60 +OpBranch %63 +%63 = OpLabel +OpLoopMerge %64 %66 None +OpBranch %65 +%65 = OpLabel +%67 = OpRayQueryProceedKHR %17 %45 +OpSelectionMerge %68 None +OpBranchConditional %67 %68 %69 +%69 = OpLabel +OpBranch %64 +%68 = OpLabel +OpBranch %70 +%70 = OpLabel +OpBranch %71 +%71 = OpLabel +OpBranch %66 +%66 = OpLabel +OpBranch %63 +%64 = OpLabel +%73 = OpRayQueryGetIntersectionTypeKHR %8 %45 %72 +%74 = OpRayQueryGetIntersectionInstanceCustomIndexKHR %8 %45 %72 +%75 = OpRayQueryGetIntersectionInstanceIdKHR %8 %45 %72 +%76 = OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR %8 %45 %72 +%77 = OpRayQueryGetIntersectionGeometryIndexKHR %8 %45 %72 +%78 = OpRayQueryGetIntersectionPrimitiveIndexKHR %8 %45 %72 +%79 = OpRayQueryGetIntersectionTKHR %4 %45 %72 +%80 = OpRayQueryGetIntersectionBarycentricsKHR %16 %45 %72 +%81 = OpRayQueryGetIntersectionFrontFaceKHR %17 %45 %72 +%82 = OpRayQueryGetIntersectionObjectToWorldKHR %18 %45 %72 +%83 = OpRayQueryGetIntersectionWorldToObjectKHR %18 %45 %72 +%84 = OpCompositeConstruct %19 %73 %79 %74 %75 %76 %77 %78 %80 %81 %82 %83 +%86 = OpCompositeExtract %8 %84 0 +%87 = OpIEqual %17 %86 %12 +%88 = OpSelect %8 %87 %72 %12 +%89 = OpAccessChain %85 %52 %12 +OpStore %89 %88 +%91 = OpCompositeExtract %4 %84 1 +%92 = OpVectorTimesScalar %14 %54 %91 +%93 = OpFunctionCall %14 %31 %92 %84 +%94 = OpAccessChain %90 %52 %72 +OpStore %94 %93 OpReturn OpFunctionEnd \ No newline at end of file