Skip to content

Commit

Permalink
ray query: validation, better test
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Mar 18, 2023
1 parent 8e13d58 commit 8c5cabe
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 111 deletions.
23 changes: 22 additions & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub enum ExpressionError {
InvalidPointerType(Handle<crate::Expression>),
#[error("Array length of {0:?} can't be done")]
InvalidArrayType(Handle<crate::Expression>),
#[error("Get intersection of {0:?} can't be done")]
InvalidRayQueryType(Handle<crate::Expression>),
#[error("Splatting {0:?} can't be done")]
InvalidSplatType(Handle<crate::Expression>),
#[error("Swizzling {0:?} can't be done")]
Expand Down Expand Up @@ -1426,7 +1428,26 @@ impl super::Validator {
return Err(ExpressionError::InvalidArrayType(expr));
}
},
E::RayQueryProceedResult | E::RayQueryGetIntersection { .. } => ShaderStages::all(),
E::RayQueryProceedResult => ShaderStages::all(),
E::RayQueryGetIntersection {
query,
committed: _,
} => match resolver[query] {
Ti::Pointer {
base,
space: crate::AddressSpace::Function,
} => match resolver.types[base].inner {
Ti::RayQuery => ShaderStages::all(),
ref other => {
log::error!("Intersection result of a pointer to {:?}", other);
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
ref other => {
log::error!("Intersection result of {:?}", other);
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
};
Ok(stages)
}
Expand Down
95 changes: 78 additions & 17 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ pub enum AtomicError {
InvalidPointer(Handle<crate::Expression>),
#[error("Operand {0:?} has invalid type.")]
InvalidOperand(Handle<crate::Expression>),
#[error("Result expression {0:?} has already been introduced earlier")]
ResultAlreadyInScope(Handle<crate::Expression>),
#[error("Result type for {0:?} doesn't match the statement")]
ResultTypeMismatch(Handle<crate::Expression>),
}
Expand Down Expand Up @@ -129,6 +127,14 @@ pub enum FunctionError {
},
#[error("Atomic operation is invalid")]
InvalidAtomic(#[from] AtomicError),
#[error("Ray Query {0:?} is not a local variable")]
InvalidRayQueryExpression(Handle<crate::Expression>),
#[error("Acceleration structure {0:?} is not a matching expression")]
InvalidAccelerationStructure(Handle<crate::Expression>),
#[error("Ray descriptor {0:?} is not a matching expression")]
InvalidRayDescriptor(Handle<crate::Expression>),
#[error("Ray Query {0:?} does not have a matching type")]
InvalidRayQueryType(Handle<crate::Type>),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
Expand Down Expand Up @@ -167,8 +173,10 @@ struct BlockContext<'a> {
info: &'a FunctionInfo,
expressions: &'a Arena<crate::Expression>,
types: &'a UniqueArena<crate::Type>,
local_vars: &'a Arena<crate::LocalVariable>,
global_vars: &'a Arena<crate::GlobalVariable>,
functions: &'a Arena<crate::Function>,
special_types: &'a crate::SpecialTypes,
prev_infos: &'a [FunctionInfo],
return_type: Option<Handle<crate::Type>>,
}
Expand All @@ -186,8 +194,10 @@ impl<'a> BlockContext<'a> {
info,
expressions: &fun.expressions,
types: &module.types,
local_vars: &fun.local_variables,
global_vars: &module.global_variables,
functions: &module.functions,
special_types: &module.special_types,
prev_infos,
return_type: fun.result.as_ref().map(|fr| fr.ty),
}
Expand Down Expand Up @@ -297,6 +307,21 @@ impl super::Validator {
Ok(callee_info.available_stages)
}

#[cfg(feature = "validate")]
fn emit_expression(
&mut self,
handle: Handle<crate::Expression>,
context: &BlockContext,
) -> Result<(), WithSpan<FunctionError>> {
if self.valid_expression_set.insert(handle.index()) {
self.valid_expression_list.push(handle);
Ok(())
} else {
Err(FunctionError::ExpressionAlreadyInScope(handle)
.with_span_handle(handle, context.expressions))
}
}

#[cfg(feature = "validate")]
fn validate_atomic(
&mut self,
Expand Down Expand Up @@ -345,13 +370,7 @@ impl super::Validator {
}
}

if self.valid_expression_set.insert(result.index()) {
self.valid_expression_list.push(result);
} else {
return Err(AtomicError::ResultAlreadyInScope(result)
.with_span_handle(result, context.expressions)
.into_other());
}
self.emit_expression(result, context)?;
match context.expressions[result] {
crate::Expression::AtomicResult { ty, comparison }
if {
Expand Down Expand Up @@ -399,12 +418,7 @@ impl super::Validator {
match *statement {
S::Emit(ref range) => {
for handle in range.clone() {
if self.valid_expression_set.insert(handle.index()) {
self.valid_expression_list.push(handle);
} else {
return Err(FunctionError::ExpressionAlreadyInScope(handle)
.with_span_handle(handle, context.expressions));
}
self.emit_expression(handle, context)?;
}
}
S::Block(ref block) => {
Expand Down Expand Up @@ -801,8 +815,55 @@ impl super::Validator {
} => {
self.validate_atomic(pointer, fun, value, result, context)?;
}
S::RayQuery { query: _, fun: _ } => {
//TODO
S::RayQuery { query, ref fun } => {
let query_var = match *context.get_expression(query) {
crate::Expression::LocalVariable(var) => &context.local_vars[var],
ref other => {
log::error!("Unexpected ray query expression {other:?}");
return Err(FunctionError::InvalidRayQueryExpression(query)
.with_span_static(span, "invalid query expression"));
}
};
match context.types[query_var.ty].inner {
Ti::RayQuery => {}
ref other => {
log::error!("Unexpected ray query type {other:?}");
return Err(FunctionError::InvalidRayQueryType(query_var.ty)
.with_span_static(span, "invalid query type"));
}
}
match *fun {
crate::RayQueryFunction::Initialize {
acceleration_structure,
descriptor,
} => {
match *context
.resolve_type(acceleration_structure, &self.valid_expression_set)?
{
Ti::AccelerationStructure => {}
_ => {
return Err(FunctionError::InvalidAccelerationStructure(
acceleration_structure,
)
.with_span_static(span, "invalid acceleration structure"))
}
}
let desc_ty_given =
context.resolve_type(descriptor, &self.valid_expression_set)?;
let desc_ty_expected = context
.special_types
.ray_desc
.map(|handle| &context.types[handle].inner);
if Some(desc_ty_given) != desc_ty_expected {
return Err(FunctionError::InvalidRayDescriptor(descriptor)
.with_span_static(span, "invalid ray descriptor"));
}
}
crate::RayQueryFunction::Proceed { result } => {
self.emit_expression(result, context)?;
}
crate::RayQueryFunction::Terminate => {}
}
}
}
}
Expand Down
14 changes: 12 additions & 2 deletions tests/in/ray-query.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,29 @@ struct RayIntersection {

struct Output {
visible: u32,
normal: vec3<f32>,
}

@group(0) @binding(1)
var<storage, read_write> output: Output;

fn get_torus_normal(world_point: vec3<f32>, intersection: RayIntersection) -> vec3<f32> {
let local_point = intersection.world_to_object * vec4<f32>(world_point, 1.0);
let point_on_guiding_line = normalize(local_point.xy) * 2.4;
let world_point_on_guiding_line = intersection.object_to_world * vec4<f32>(point_on_guiding_line, 0.0, 1.0);
return normalize(world_point - world_point_on_guiding_line);
}

@compute @workgroup_size(1)
fn main() {
var rq: ray_query;

rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), vec3<f32>(0.0, 1.0, 0.0)));
let dir = vec3<f32>(0.0, 1.0, 0.0);
rayQueryInitialize(&rq, acc_struct, RayDesc(RAY_FLAG_TERMINATE_ON_FIRST_HIT, 0xFFu, 0.1, 100.0, vec3<f32>(0.0), dir));

rayQueryProceed(&rq);
while (rayQueryProceed(&rq)) {}

let intersection = rayQueryGetCommittedIntersection(&rq);
output.visible = u32(intersection.kind == RAY_QUERY_INTERSECTION_NONE);
output.normal = get_torus_normal(dir * intersection.t, intersection);
}
Loading

0 comments on commit 8c5cabe

Please sign in to comment.