Skip to content

Commit

Permalink
Add unpacking builtins
Browse files Browse the repository at this point in the history
  • Loading branch information
cwfitzgerald committed Oct 15, 2023
1 parent f98d6ae commit 4b4773c
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 73 deletions.
112 changes: 64 additions & 48 deletions tests/tests/shader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
//! shader is run on the input buffer which generates an output buffer. This
//! buffer is then read and compared to a given output.
use std::{borrow::Cow, fmt::Debug};
use std::{borrow::Cow, fmt::Debug, iter::zip, ops::Range};

use bytemuck::Pod;
use wgpu::{
Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry,
BindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor,
Expand Down Expand Up @@ -36,6 +37,33 @@ impl InputStorageType {
}
}

trait ComparisonValue: Debug {
fn compare(&self, actual: u32) -> bool;
}

impl ComparisonValue for u32 {
fn compare(&self, actual: u32) -> bool {
actual == *self
}
}

impl ComparisonValue for f32 {
fn compare(&self, actual: u32) -> bool {
let value: &f32 = bytemuck::cast_ref(&actual);
value == self
}
}

impl<T> ComparisonValue for Range<T>
where
T: Debug + PartialOrd + Pod,
{
fn compare(&self, actual: u32) -> bool {
let value: &T = bytemuck::cast_ref(&actual);
self.contains(value)
}
}

/// Describes a single test of a shader.
struct ShaderTest {
/// Human readable name
Expand All @@ -57,12 +85,7 @@ struct ShaderTest {
/// List of values will be written to the input buffer.
input_values: Vec<u32>,
/// List of lists of valid expected outputs from the shader.
output_values: Vec<Vec<u32>>,
/// Function which compares the output values to the resulting values and
/// prints a message on failure.
///
/// Defaults [`Self::default_comparison_function`].
output_comparison_fn: fn(&str, &[u32], &[Vec<u32>]) -> bool,
output_values: Vec<Vec<Box<dyn ComparisonValue>>>,
/// Value to pre-initialize the output buffer to. Often u32::MAX so
/// that writing a 0 looks different than not writing a value at all.
///
Expand All @@ -75,45 +98,43 @@ struct ShaderTest {
failures: Backends,
}
impl ShaderTest {
fn default_comparison_function<O: bytemuck::Pod + Debug + PartialEq>(
test_name: &str,
actual_values: &[u32],
expected_values: &[Vec<u32>],
) -> bool {
let cast_actual = bytemuck::cast_slice::<u32, O>(actual_values);

fn default_comparison_function(&self, test_name: &str, actual_values: &[u32]) -> bool {
// When printing the error message, we want to trim `cast_actual` to the length
// of the longest set of expected values. This tracks that value.
let mut max_relevant_value_count = 0;

for expected in expected_values {
let cast_expected = bytemuck::cast_slice::<u32, O>(expected);
let mut succeeded = false;
's: for expected in &self.output_values {
max_relevant_value_count = max_relevant_value_count.max(expected.len());

// We shorten the actual to the length of the expected.
if &cast_actual[0..cast_expected.len()] == cast_expected {
return true;
for (actual, expected) in zip(actual_values, expected) {
if !expected.compare(*actual) {
continue 's;
}
}

max_relevant_value_count = max_relevant_value_count.max(cast_expected.len());
succeeded = true;
break 's;
}
if succeeded {
return true;
}

// We haven't found a match, lets print an error.

eprint!(
"Inner test failure. Actual {:?}. Expected",
&cast_actual[0..max_relevant_value_count]
&actual_values[0..max_relevant_value_count]
);

if expected_values.len() != 1 {
if self.output_values.len() != 1 {
eprint!(" one of: ");
} else {
eprint!(": ");
}

for (idx, expected) in expected_values.iter().enumerate() {
let cast_expected = bytemuck::cast_slice::<u32, O>(expected);
eprint!("{cast_expected:?}");
if idx + 1 != expected_values.len() {
for (idx, expected) in self.output_values.iter().enumerate() {
eprint!("{expected:?}");
if idx + 1 != expected.len() {
eprint!(" ");
}
}
Expand All @@ -123,16 +144,15 @@ impl ShaderTest {
false
}

fn new<I, O>(
fn new<I>(
name: String,
custom_struct_members: String,
body: String,
input_values: &[I],
output_values: &[O],
output_values: &[&[impl ComparisonValue + Clone + 'static]],
) -> Self
where
I: bytemuck::Pod,
O: bytemuck::Pod + Debug + PartialEq,
{
Self {
name,
Expand All @@ -141,27 +161,23 @@ impl ShaderTest {
input_type: String::from("CustomStruct"),
output_type: String::from("array<u32>"),
input_values: bytemuck::cast_slice(input_values).to_vec(),
output_values: vec![bytemuck::cast_slice(output_values).to_vec()],
output_comparison_fn: Self::default_comparison_function::<O>,
output_values: output_values
.iter()
.map(|values| {
values
.iter()
.map(|v| {
let v: Box<dyn ComparisonValue> = Box::new(v.clone());
v
})
.collect()
})
.collect(),
output_initialization: u32::MAX,
failures: Backends::empty(),
}
}

/// Add another set of possible outputs. If any of the given
/// output values are seen it's considered a success (i.e. this is OR, not AND).
///
/// Assumes that this type O is the same as the O provided to new.
fn extra_output_values<O: bytemuck::Pod + Debug + PartialEq>(
mut self,
output_values: &[O],
) -> Self {
self.output_values
.push(bytemuck::cast_slice(output_values).to_vec());

self
}

fn failures(mut self, failures: Backends) -> Self {
self.failures = failures;

Expand Down Expand Up @@ -269,7 +285,7 @@ fn shader_input_output_test(
assert!(test.input_values.len() <= MAX_BUFFER_SIZE as usize / 4);
assert!(test.output_values.len() <= MAX_BUFFER_SIZE as usize / 4);

let test_name = test.name;
let test_name = &test.name;

// -- Building shader + pipeline --

Expand Down Expand Up @@ -357,7 +373,7 @@ fn shader_input_output_test(

// -- Check results --

let failure = !(test.output_comparison_fn)(&test_name, typed, &test.output_values);
let failure = !test.default_comparison_function(&test_name, typed);
// We don't immediately panic to let all tests execute
if failure
!= test
Expand Down
Loading

0 comments on commit 4b4773c

Please sign in to comment.