Skip to content

Commit

Permalink
Add support for vecN<i32> and vecN<u32> to dot() function (gfx-…
Browse files Browse the repository at this point in the history
…rs#1689)

* Allow vecN<i32> and vecN<u32> in `dot()`, first changes

* Added a test case

* Fix the test

* Changes to baking of expressions, incl args of integer dot product

* Implemented requested changes for glsl backend

* Added support for integer dot product on MSL backend

* Removed outdated code for hlsl and wgls writers

* Implement in spv backend

* Commit modified outputs from running the tests

* cargo fmt

* Applied requested changes for both MSL and GLSL back

* Changes to spv back

* Committed all test output changes

* Cargo fmt

* Added a comment w.r.t. VK_KHR_shader_integer_dot_product

* Implemented requested svp change

* Minor change to test case

This is because I wanted to highlight the fact that the correct
id is used in the last sum of the integer dot product expression

* Changed function signature

since it could not fail, changed it to simply return `void`
  • Loading branch information
francesco-cattoglio authored Feb 3, 2022
1 parent 42bf354 commit b235973
Show file tree
Hide file tree
Showing 11 changed files with 442 additions and 42 deletions.
92 changes: 85 additions & 7 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ pub struct Writer<'a, W> {
block_id: IdGenerator,
/// Set of expressions that have associated temporary variables.
named_expressions: crate::NamedExpressions,
/// Set of expressions that need to be baked to avoid unnecessary repetition in output
need_bake_expressions: crate::NeedBakeExpressions,
}

impl<'a, W: Write> Writer<'a, W> {
Expand Down Expand Up @@ -468,6 +470,7 @@ impl<'a, W: Write> Writer<'a, W> {

block_id: IdGenerator::default(),
named_expressions: crate::NamedExpressions::default(),
need_bake_expressions: crate::NeedBakeExpressions::default(),
};

// Find all features required to print this module
Expand Down Expand Up @@ -1000,6 +1003,45 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
/// Clears `need_bake_expressions` set before adding to it
fn update_expressions_to_bake(&mut self, func: &crate::Function, info: &valid::FunctionInfo) {
use crate::Expression;
self.need_bake_expressions.clear();
for expr in func.expressions.iter() {
let expr_info = &info[expr.0];
let min_ref_count = func.expressions[expr.0].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr.0);
}
// if the expression is a Dot product with integer arguments,
// then the args needs baking as well
if let (
fun_handle,
&Expression::Math {
fun: crate::MathFunction::Dot,
arg,
arg1,
..
},
) = expr
{
let inner = info[fun_handle].ty.inner_with(&self.module.types);
if let TypeInner::Scalar { kind, .. } = *inner {
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
}
}
}

/// Helper method used to get a name for a global
///
/// Globals have different naming schemes depending on their binding:
Expand Down Expand Up @@ -1151,6 +1193,7 @@ impl<'a, W: Write> Writer<'a, W> {
};

self.named_expressions.clear();
self.update_expressions_to_bake(func, info);

// Write the function header
//
Expand Down Expand Up @@ -1401,6 +1444,33 @@ impl<'a, W: Write> Writer<'a, W> {
Ok(())
}

/// Helper method used to output a dot product as an arithmetic expression
///
fn write_dot_product(
&mut self,
arg: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
size: usize,
) -> BackendResult {
write!(self.out, "(")?;

let arg0_name = &self.named_expressions[&arg];
let arg1_name = &self.named_expressions[&arg1];

// This will print an extra '+' at the beginning but that is fine in glsl
for index in 0..size {
let component = back::COMPONENTS[index];
write!(
self.out,
" + {}.{} * {}.{}",
arg0_name, component, arg1_name, component
)?;
}

write!(self.out, ")")?;
Ok(())
}

/// Helper method used to write structs
///
/// # Notes
Expand Down Expand Up @@ -1490,13 +1560,10 @@ impl<'a, W: Write> Writer<'a, W> {
// Otherwise, we could accidentally write variable name instead of full expression.
// Also, we use sanitized names! It defense backend from generating variable with name from reserved keywords.
Some(self.namer.call(name))
} else if self.need_bake_expressions.contains(&handle) {
Some(format!("{}{}", super::BAKE_PREFIX, handle.index()))
} else {
let min_ref_count = ctx.expressions[handle].bake_ref_count();
if min_ref_count <= info.ref_count {
Some(format!("{}{}", super::BAKE_PREFIX, handle.index()))
} else {
None
}
None
};

if let Some(name) = expr_name {
Expand Down Expand Up @@ -2538,7 +2605,18 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
Mf::Dot => match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Vector {
kind: crate::ScalarKind::Float,
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.write_dot_product(arg, arg1.unwrap(), size as usize)
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
),
},
Mf::Outer => "outerProduct",
Mf::Cross => "cross",
Mf::Distance => "distance",
Expand Down
101 changes: 94 additions & 7 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,8 @@ pub struct Writer<W> {
out: W,
names: FastHashMap<NameKey, String>,
named_expressions: crate::NamedExpressions,
/// Set of expressions that need to be baked to avoid unnecessary repetition in output
need_bake_expressions: crate::NeedBakeExpressions,
namer: proc::Namer,
#[cfg(test)]
put_expression_stack_pointers: FastHashSet<*const ()>,
Expand Down Expand Up @@ -526,6 +528,7 @@ impl<W: Write> Writer<W> {
out,
names: FastHashMap::default(),
named_expressions: crate::NamedExpressions::default(),
need_bake_expressions: crate::NeedBakeExpressions::default(),
namer: proc::Namer::default(),
#[cfg(test)]
put_expression_stack_pointers: Default::default(),
Expand Down Expand Up @@ -827,6 +830,33 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Emit code for the arithmetic expression of the dot product.
///
fn put_dot_product(
&mut self,
arg: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
size: usize,
) -> BackendResult {
write!(self.out, "(")?;

let arg0_name = &self.named_expressions[&arg];
let arg1_name = &self.named_expressions[&arg1];

// This will print an extra '+' at the beginning but that is fine in msl
for index in 0..size {
let component = back::COMPONENTS[index];
write!(
self.out,
" + {}.{} * {}.{}",
arg0_name, component, arg1_name, component
)?;
}

write!(self.out, ")")?;
Ok(())
}

/// Emit code for the expression `expr_handle`.
///
/// The `is_scoped` argument is true if the surrounding operators have the
Expand Down Expand Up @@ -1216,7 +1246,18 @@ impl<W: Write> Writer<W> {
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
Mf::Dot => match *context.resolve_type(arg) {
crate::TypeInner::Vector {
kind: crate::ScalarKind::Float,
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.put_dot_product(arg, arg1.unwrap(), size as usize)
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
),
},
Mf::Outer => return Err(Error::UnsupportedCall(format!("{:?}", fun))),
Mf::Cross => "cross",
Mf::Distance => "distance",
Expand Down Expand Up @@ -1810,6 +1851,55 @@ impl<W: Write> Writer<W> {
Ok(())
}

/// Helper method used to find which expressions of a given function require baking
///
/// # Notes
/// This function overwrites the contents of `self.need_bake_expressions`
fn update_expressions_to_bake(
&mut self,
func: &crate::Function,
info: &valid::FunctionInfo,
context: &ExpressionContext,
) {
use crate::Expression;
self.need_bake_expressions.clear();
for expr in func.expressions.iter() {
// Expressions whose reference count is above the
// threshold should always be stored in temporaries.
let expr_info = &info[expr.0];
let min_ref_count = func.expressions[expr.0].bake_ref_count();
if min_ref_count <= expr_info.ref_count {
self.need_bake_expressions.insert(expr.0);
}
// if the expression is a Dot product with integer arguments,
// then the args needs baking as well
if let (
fun_handle,
&Expression::Math {
fun: crate::MathFunction::Dot,
arg,
arg1,
..
},
) = expr
{
use crate::TypeInner;
// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(fun_handle);
if let TypeInner::Scalar { kind, .. } = *inner {
match kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
}
}
}

fn start_baking_expression(
&mut self,
handle: Handle<crate::Expression>,
Expand Down Expand Up @@ -1913,12 +2003,7 @@ impl<W: Write> Writer<W> {
if context.expression.guarded_indices.contains(handle.index()) {
true
} else {
// Expressions whose reference count is above the
// threshold should always be stored in temporaries.
let min_ref_count = context.expression.function.expressions
[handle]
.bake_ref_count();
min_ref_count <= info.ref_count
self.need_bake_expressions.contains(&handle)
};

if bake {
Expand Down Expand Up @@ -2763,6 +2848,7 @@ impl<W: Write> Writer<W> {
result_struct: None,
};
self.named_expressions.clear();
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
}
Expand Down Expand Up @@ -3226,6 +3312,7 @@ impl<W: Write> Writer<W> {
result_struct: Some(&stage_out_name),
};
self.named_expressions.clear();
self.update_expressions_to_bake(fun, fun_info, &context.expression);
self.put_block(back::Level(1), &fun.body, &context)?;
writeln!(self.out, "}}")?;
if ep_index + 1 != module.entry_points.len() {
Expand Down
Loading

0 comments on commit b235973

Please sign in to comment.