Skip to content

Commit

Permalink
Support DXR in wgpu-hal & naga. (#6777)
Browse files Browse the repository at this point in the history
Co-authored-by: Connor Fitzgerald <[email protected]>
  • Loading branch information
Vecvec and cwfitzgerald authored Jan 15, 2025
1 parent 2cfea40 commit 21de7f7
Show file tree
Hide file tree
Showing 20 changed files with 1,008 additions and 60 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ By @wumpf in [#6849](https://github.com/gfx-rs/wgpu/pull/6849).
- Add build support for Apple Vision Pro. By @guusw in [#6611](https://github.com/gfx-rs/wgpu/pull/6611).
- Add `raw_handle` method to access raw Metal textures in [#6894](https://github.com/gfx-rs/wgpu/pull/6894).

#### D3D12

- Support DXR (DirectX Ray-tracing) in wgpu-hal. By @Vecvec in [#6777](https://github.com/gfx-rs/wgpu/pull/6777)

#### Changes

##### Naga
Expand Down
1 change: 0 additions & 1 deletion examples/src/ray_cube_compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ struct Example {
impl crate::framework::Example for Example {
fn required_features() -> wgpu::Features {
wgpu::Features::TEXTURE_BINDING_ARRAY
| wgpu::Features::STORAGE_RESOURCE_BINDING_ARRAY
| wgpu::Features::VERTEX_WRITABLE_STORAGE
| wgpu::Features::EXPERIMENTAL_RAY_QUERY
| wgpu::Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE
Expand Down
33 changes: 25 additions & 8 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,9 @@ impl<W: Write> super::Writer<'_, W> {
&crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {}
}
}
if module.special_types.ray_desc.is_some() {
self.write_ray_desc_from_ray_desc_constructor_function(module)?;
}

Ok(())
}
Expand All @@ -852,16 +855,30 @@ impl<W: Write> super::Writer<'_, W> {
expressions: &crate::Arena<crate::Expression>,
) -> BackendResult {
for (handle, _) in expressions.iter() {
if let crate::Expression::Compose { ty, .. } = expressions[handle] {
match module.types[ty].inner {
crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
let constructor = WrappedConstructor { ty };
if self.wrapped.constructors.insert(constructor) {
self.write_wrapped_constructor_function(module, constructor)?;
match expressions[handle] {
crate::Expression::Compose { ty, .. } => {
match module.types[ty].inner {
crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => {
let constructor = WrappedConstructor { ty };
if self.wrapped.constructors.insert(constructor) {
self.write_wrapped_constructor_function(module, constructor)?;
}
}
_ => {}
};
}
crate::Expression::RayQueryGetIntersection { committed, .. } => {
if committed {
if !self.written_committed_intersection {
self.write_committed_intersection_function(module)?;
self.written_committed_intersection = true;
}
} else if !self.written_candidate_intersection {
self.write_candidate_intersection_function(module)?;
self.written_candidate_intersection = true;
}
_ => {}
};
}
_ => {}
}
}
Ok(())
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ pub const RESERVED: &[&str] = &[
"TextureBuffer",
"ConstantBuffer",
"RayQuery",
"RayDesc",
// Naga utilities
super::writer::MODF_FUNCTION,
super::writer::FREXP_FUNCTION,
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ accessing individual columns by dynamic index.
mod conv;
mod help;
mod keywords;
mod ray;
mod storage;
mod writer;

Expand Down Expand Up @@ -331,6 +332,8 @@ pub struct Writer<'a, W> {
/// Set of expressions that have associated temporary variables
named_expressions: crate::NamedExpressions,
wrapped: Wrapped,
written_committed_intersection: bool,
written_candidate_intersection: bool,
continue_ctx: back::continue_forward::ContinueCtx,

/// A reference to some part of a global variable, lowered to a series of
Expand Down
163 changes: 163 additions & 0 deletions naga/src/back/hlsl/ray.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
use crate::back::hlsl::BackendResult;
use crate::{RayQueryIntersection, TypeInner};
use std::fmt::Write;

impl<W: Write> super::Writer<'_, W> {
// constructs hlsl RayDesc from wgsl RayDesc
pub(super) fn write_ray_desc_from_ray_desc_constructor_function(
&mut self,
module: &crate::Module,
) -> BackendResult {
write!(self.out, "RayDesc RayDescFromRayDesc_(")?;
self.write_type(module, module.special_types.ray_desc.unwrap())?;
writeln!(self.out, " arg0) {{")?;
writeln!(self.out, " RayDesc ret = (RayDesc)0;")?;
writeln!(self.out, " ret.Origin = arg0.origin;")?;
writeln!(self.out, " ret.TMin = arg0.tmin;")?;
writeln!(self.out, " ret.Direction = arg0.dir;")?;
writeln!(self.out, " ret.TMax = arg0.tmax;")?;
writeln!(self.out, " return ret;")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
Ok(())
}
pub(super) fn write_committed_intersection_function(
&mut self,
module: &crate::Module,
) -> BackendResult {
self.write_type(module, module.special_types.ray_intersection.unwrap())?;
write!(self.out, " GetCommittedIntersection(")?;
self.write_value_type(module, &TypeInner::RayQuery)?;
writeln!(self.out, " rq) {{")?;
write!(self.out, " ")?;
self.write_type(module, module.special_types.ray_intersection.unwrap())?;
write!(self.out, " ret = (")?;
self.write_type(module, module.special_types.ray_intersection.unwrap())?;
writeln!(self.out, ")0;")?;
writeln!(self.out, " ret.kind = rq.CommittedStatus();")?;
writeln!(
self.out,
" if( rq.CommittedStatus() == COMMITTED_NOTHING) {{}} else {{"
)?;
writeln!(self.out, " ret.t = rq.CommittedRayT();")?;
writeln!(
self.out,
" ret.instance_custom_index = rq.CommittedInstanceID();"
)?;
writeln!(
self.out,
" ret.instance_id = rq.CommittedInstanceIndex();"
)?;
writeln!(
self.out,
" ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex();"
)?;
writeln!(
self.out,
" ret.geometry_index = rq.CommittedGeometryIndex();"
)?;
writeln!(
self.out,
" ret.primitive_index = rq.CommittedPrimitiveIndex();"
)?;
writeln!(
self.out,
" if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) {{"
)?;
writeln!(
self.out,
" ret.barycentrics = rq.CommittedTriangleBarycentrics();"
)?;
writeln!(
self.out,
" ret.front_face = rq.CommittedTriangleFrontFace();"
)?;
writeln!(self.out, " }}")?;
writeln!(
self.out,
" ret.object_to_world = rq.CommittedObjectToWorld4x3();"
)?;
writeln!(
self.out,
" ret.world_to_object = rq.CommittedWorldToObject4x3();"
)?;
writeln!(self.out, " }}")?;
writeln!(self.out, " return ret;")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
Ok(())
}
pub(super) fn write_candidate_intersection_function(
&mut self,
module: &crate::Module,
) -> BackendResult {
self.write_type(module, module.special_types.ray_intersection.unwrap())?;
write!(self.out, " GetCandidateIntersection(")?;
self.write_value_type(module, &TypeInner::RayQuery)?;
writeln!(self.out, " rq) {{")?;
write!(self.out, " ")?;
self.write_type(module, module.special_types.ray_intersection.unwrap())?;
write!(self.out, " ret = (")?;
self.write_type(module, module.special_types.ray_intersection.unwrap())?;
writeln!(self.out, ")0;")?;
writeln!(self.out, " CANDIDATE_TYPE kind = rq.CandidateType();")?;
writeln!(
self.out,
" if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{"
)?;
writeln!(
self.out,
" ret.kind = {};",
RayQueryIntersection::Triangle as u32
)?;
writeln!(self.out, " ret.t = rq.CandidateTriangleRayT();")?;
writeln!(
self.out,
" ret.barycentrics = rq.CandidateTriangleBarycentrics();"
)?;
writeln!(
self.out,
" ret.front_face = rq.CandidateTriangleFrontFace();"
)?;
writeln!(self.out, " }} else {{")?;
writeln!(
self.out,
" ret.kind = {};",
RayQueryIntersection::Aabb as u32
)?;
writeln!(self.out, " }}")?;

writeln!(
self.out,
" ret.instance_custom_index = rq.CandidateInstanceID();"
)?;
writeln!(
self.out,
" ret.instance_id = rq.CandidateInstanceIndex();"
)?;
writeln!(
self.out,
" ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();"
)?;
writeln!(
self.out,
" ret.geometry_index = rq.CandidateGeometryIndex();"
)?;
writeln!(
self.out,
" ret.primitive_index = rq.CandidatePrimitiveIndex();"
)?;
writeln!(
self.out,
" ret.object_to_world = rq.CandidateObjectToWorld4x3();"
)?;
writeln!(
self.out,
" ret.world_to_object = rq.CandidateWorldToObject4x3();"
)?;
writeln!(self.out, " return ret;")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
Ok(())
}
}
79 changes: 67 additions & 12 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::{
use crate::{
back::{self, Baked},
proc::{self, index, ExpressionKindTracker, NameKey},
valid, Handle, Module, Scalar, ScalarKind, ShaderStage, TypeInner,
valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner,
};
use std::{fmt, mem};

Expand Down Expand Up @@ -104,6 +104,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
entry_point_io: Vec::new(),
named_expressions: crate::NamedExpressions::default(),
wrapped: super::Wrapped::default(),
written_committed_intersection: false,
written_candidate_intersection: false,
continue_ctx: back::continue_forward::ContinueCtx::default(),
temp_access_chain: Vec::new(),
need_bake_expressions: Default::default(),
Expand All @@ -123,6 +125,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.entry_point_io.clear();
self.named_expressions.clear();
self.wrapped.clear();
self.written_committed_intersection = false;
self.written_candidate_intersection = false;
self.continue_ctx.clear();
self.need_bake_expressions.clear();
}
Expand Down Expand Up @@ -1218,6 +1222,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => {
self.write_array_size(module, base, size)?;
}
TypeInner::AccelerationStructure => {
write!(self.out, "RaytracingAccelerationStructure")?;
}
TypeInner::RayQuery => {
// these are constant flags, there are dynamic flags also but constant flags are not supported by naga
write!(self.out, "RayQuery<RAY_FLAG_NONE>")?;
}
_ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))),
}

Expand Down Expand Up @@ -1375,15 +1386,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_array_size(module, base, size)?;
}

write!(self.out, " = ")?;
// Write the local initializer if needed
if let Some(init) = local.init {
self.write_expr(module, init, func_ctx)?;
} else {
// Zero initialize local variables
self.write_default_init(module, local.ty)?;
match module.types[local.ty].inner {
// from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed
TypeInner::RayQuery => {}
_ => {
write!(self.out, " = ")?;
// Write the local initializer if needed
if let Some(init) = local.init {
self.write_expr(module, init, func_ctx)?;
} else {
// Zero initialize local variables
self.write_default_init(module, local.ty)?;
}
}
}

// Finish the local with `;` and add a newline (only for readability)
writeln!(self.out, ";")?
}
Expand Down Expand Up @@ -2250,7 +2266,37 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
} => {
self.write_switch(module, func_ctx, level, selector, cases)?;
}
Statement::RayQuery { .. } => unreachable!(),
Statement::RayQuery { query, ref fun } => match *fun {
RayQueryFunction::Initialize {
acceleration_structure,
descriptor,
} => {
write!(self.out, "{level}")?;
self.write_expr(module, query, func_ctx)?;
write!(self.out, ".TraceRayInline(")?;
self.write_expr(module, acceleration_structure, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, descriptor, func_ctx)?;
write!(self.out, ".flags, ")?;
self.write_expr(module, descriptor, func_ctx)?;
write!(self.out, ".cull_mask, ")?;
write!(self.out, "RayDescFromRayDesc_(")?;
self.write_expr(module, descriptor, func_ctx)?;
writeln!(self.out, "));")?;
}
RayQueryFunction::Proceed { result } => {
write!(self.out, "{level}")?;
let name = Baked(result).to_string();
write!(self.out, "const bool {name} = ")?;
self.named_expressions.insert(result, name);
self.write_expr(module, query, func_ctx)?;
writeln!(self.out, ".Proceed();")?;
}
RayQueryFunction::Terminate => {
self.write_expr(module, query, func_ctx)?;
writeln!(self.out, ".Abort();")?;
}
},
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let name = Baked(result).to_string();
Expand Down Expand Up @@ -3608,8 +3654,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, reject, func_ctx)?;
write!(self.out, ")")?
}
// Not supported yet
Expression::RayQueryGetIntersection { .. } => unreachable!(),
Expression::RayQueryGetIntersection { query, committed } => {
if committed {
write!(self.out, "GetCommittedIntersection(")?;
self.write_expr(module, query, func_ctx)?;
write!(self.out, ")")?;
} else {
write!(self.out, "GetCandidateIntersection(")?;
self.write_expr(module, query, func_ctx)?;
write!(self.out, ")")?;
}
}
// Nothing to do here, since call expression already cached
Expression::CallResult(_)
| Expression::AtomicResult { .. }
Expand Down
7 changes: 7 additions & 0 deletions naga/tests/in/ray-query.param.ron
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,11 @@
per_entry_point_map: {},
inline_samplers: [],
),
hlsl: (
shader_model: V6_5,
binding_map: {},
fake_missing_bindings: true,
special_constants_binding: None,
zero_initialize_workgroup_memory: true,
)
)
Loading

0 comments on commit 21de7f7

Please sign in to comment.