From 4b4773c005dfb24abb17e1a88f68dab30379e176 Mon Sep 17 00:00:00 2001 From: Connor Fitzgerald Date: Sun, 15 Oct 2023 16:02:18 -0400 Subject: [PATCH] Add unpacking builtins --- tests/tests/shader/mod.rs | 112 +++++++------ tests/tests/shader/numeric_builtins.rs | 211 ++++++++++++++++++++++--- tests/tests/shader/struct_layout.rs | 15 +- wgpu-types/src/math.rs | 25 +++ 4 files changed, 290 insertions(+), 73 deletions(-) diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index 5701e981e2..509f55bb7c 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -4,8 +4,9 @@ //! shader is run on the input buffer which generates an output buffer. This //! buffer is then read and compared to a given output. -use std::{borrow::Cow, fmt::Debug}; +use std::{borrow::Cow, fmt::Debug, iter::zip, ops::Range}; +use bytemuck::Pod; use wgpu::{ Backends, BindGroupDescriptor, BindGroupEntry, BindGroupLayoutDescriptor, BindGroupLayoutEntry, BindingType, BufferDescriptor, BufferUsages, CommandEncoderDescriptor, ComputePassDescriptor, @@ -36,6 +37,33 @@ impl InputStorageType { } } +trait ComparisonValue: Debug { + fn compare(&self, actual: u32) -> bool; +} + +impl ComparisonValue for u32 { + fn compare(&self, actual: u32) -> bool { + actual == *self + } +} + +impl ComparisonValue for f32 { + fn compare(&self, actual: u32) -> bool { + let value: &f32 = bytemuck::cast_ref(&actual); + value == self + } +} + +impl ComparisonValue for Range +where + T: Debug + PartialOrd + Pod, +{ + fn compare(&self, actual: u32) -> bool { + let value: &T = bytemuck::cast_ref(&actual); + self.contains(value) + } +} + /// Describes a single test of a shader. struct ShaderTest { /// Human readable name @@ -57,12 +85,7 @@ struct ShaderTest { /// List of values will be written to the input buffer. input_values: Vec, /// List of lists of valid expected outputs from the shader. 
- output_values: Vec>, - /// Function which compares the output values to the resulting values and - /// prints a message on failure. - /// - /// Defaults [`Self::default_comparison_function`]. - output_comparison_fn: fn(&str, &[u32], &[Vec]) -> bool, + output_values: Vec>>, /// Value to pre-initialize the output buffer to. Often u32::MAX so /// that writing a 0 looks different than not writing a value at all. /// @@ -75,45 +98,43 @@ struct ShaderTest { failures: Backends, } impl ShaderTest { - fn default_comparison_function( - test_name: &str, - actual_values: &[u32], - expected_values: &[Vec], - ) -> bool { - let cast_actual = bytemuck::cast_slice::(actual_values); - + fn default_comparison_function(&self, test_name: &str, actual_values: &[u32]) -> bool { // When printing the error message, we want to trim `cast_actual` to the length // of the longest set of expected values. This tracks that value. let mut max_relevant_value_count = 0; - for expected in expected_values { - let cast_expected = bytemuck::cast_slice::(expected); + let mut succeeded = false; + 's: for expected in &self.output_values { + max_relevant_value_count = max_relevant_value_count.max(expected.len()); - // We shorten the actual to the length of the expected. - if &cast_actual[0..cast_expected.len()] == cast_expected { - return true; + for (actual, expected) in zip(actual_values, expected) { + if !expected.compare(*actual) { + continue 's; + } } - - max_relevant_value_count = max_relevant_value_count.max(cast_expected.len()); + succeeded = true; + break 's; + } + if succeeded { + return true; } // We haven't found a match, lets print an error. eprint!( "Inner test failure. Actual {:?}. 
Expected", - &cast_actual[0..max_relevant_value_count] + &actual_values[0..max_relevant_value_count] ); - if expected_values.len() != 1 { + if self.output_values.len() != 1 { eprint!(" one of: "); } else { eprint!(": "); } - for (idx, expected) in expected_values.iter().enumerate() { - let cast_expected = bytemuck::cast_slice::(expected); - eprint!("{cast_expected:?}"); - if idx + 1 != expected_values.len() { + for (idx, expected) in self.output_values.iter().enumerate() { + eprint!("{expected:?}"); + if idx + 1 != expected.len() { eprint!(" "); } } @@ -123,16 +144,15 @@ impl ShaderTest { false } - fn new( + fn new( name: String, custom_struct_members: String, body: String, input_values: &[I], - output_values: &[O], + output_values: &[&[impl ComparisonValue + Clone + 'static]], ) -> Self where I: bytemuck::Pod, - O: bytemuck::Pod + Debug + PartialEq, { Self { name, @@ -141,27 +161,23 @@ impl ShaderTest { input_type: String::from("CustomStruct"), output_type: String::from("array"), input_values: bytemuck::cast_slice(input_values).to_vec(), - output_values: vec![bytemuck::cast_slice(output_values).to_vec()], - output_comparison_fn: Self::default_comparison_function::, + output_values: output_values + .iter() + .map(|values| { + values + .iter() + .map(|v| { + let v: Box = Box::new(v.clone()); + v + }) + .collect() + }) + .collect(), output_initialization: u32::MAX, failures: Backends::empty(), } } - /// Add another set of possible outputs. If any of the given - /// output values are seen it's considered a success (i.e. this is OR, not AND). - /// - /// Assumes that this type O is the same as the O provided to new. 
- fn extra_output_values( - mut self, - output_values: &[O], - ) -> Self { - self.output_values - .push(bytemuck::cast_slice(output_values).to_vec()); - - self - } - fn failures(mut self, failures: Backends) -> Self { self.failures = failures; @@ -269,7 +285,7 @@ fn shader_input_output_test( assert!(test.input_values.len() <= MAX_BUFFER_SIZE as usize / 4); assert!(test.output_values.len() <= MAX_BUFFER_SIZE as usize / 4); - let test_name = test.name; + let test_name = &test.name; // -- Building shader + pipeline -- @@ -357,7 +373,7 @@ fn shader_input_output_test( // -- Check results -- - let failure = !(test.output_comparison_fn)(&test_name, typed, &test.output_values); + let failure = !test.default_comparison_function(&test_name, typed); // We don't immediately panic to let all tests execute if failure != test diff --git a/tests/tests/shader/numeric_builtins.rs b/tests/tests/shader/numeric_builtins.rs index 263a3db99c..05c46cecbc 100644 --- a/tests/tests/shader/numeric_builtins.rs +++ b/tests/tests/shader/numeric_builtins.rs @@ -1,4 +1,7 @@ +use std::ops::Range; + use wgpu::{DownlevelFlags, Limits}; +use wgt::math::f32_next; use crate::shader::{shader_input_output_test, InputStorageType, ShaderTest}; use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; @@ -19,19 +22,15 @@ fn create_clamp_builtin_test() -> Vec { ( 3.0, 2.0, 1.0, &[1.0, 2.0]), ]; - for &(input, low, high, output) in clamp_values { - let mut test = ShaderTest::new( - format!("clamp({input}, 0.0, 10.0) == {output:?})"), + for &(input, low, high, outputs) in clamp_values { + let nested_outputs: Vec<_> = outputs.iter().map(|v| std::slice::from_ref(v)).collect(); + tests.push(ShaderTest::new( + format!("clamp({input}, 0.0, 10.0) == {outputs:?})"), String::from("value: f32, low: f32, high: f32"), String::from("output[0] = bitcast(clamp(input.value, input.low, input.high));"), &[input, low, high], - &[output[0]], - ); - for &extra in &output[1..] 
{ - test = test.extra_output_values(&[extra]); - } - - tests.push(test); + &nested_outputs, + )); } tests @@ -48,6 +47,7 @@ static CLAMP_BUILTIN: GpuTestConfiguration = GpuTestConfiguration::new() shader_input_output_test(ctx, InputStorageType::Storage, create_clamp_builtin_test()); }); +#[allow(clippy::excessive_precision)] fn create_pack_builtin_test() -> Vec { let mut tests = Vec::new(); @@ -178,18 +178,15 @@ fn create_pack_builtin_test() -> Vec { Function::Pack4x8Unorm => "output[0] = pack4x8unorm(input.value);", }; - let mut test = ShaderTest::new( + let outputs: Vec<_> = outputs.iter().map(|v| std::slice::from_ref(v)).collect(); + + tests.push(ShaderTest::new( name, String::from(members), String::from(body), inputs, - &[outputs[0]], - ); - - for &output in &outputs[1..] { - test = test.extra_output_values(&[output]); - } - tests.push(test); + &outputs, + )); } tests @@ -205,3 +202,181 @@ static PACKING_BUILTINS: GpuTestConfiguration = GpuTestConfiguration::new() .run_sync(|ctx| { shader_input_output_test(ctx, InputStorageType::Storage, create_pack_builtin_test()); }); + +#[allow(clippy::excessive_precision)] +fn create_unpack_builtin_test() -> Vec { + let mut tests = Vec::new(); + + // Magic numbers from the spec + // https://github.com/gpuweb/cts/blob/main/src/unittests/floating_point.spec.ts + + pub const ZERO_BOUNDS: Range = f32::MIN_POSITIVE * -1.0..f32::MIN_POSITIVE; + + pub const ONE_BOUNDS_SNORM: Range = 0.999999821186065673828125..1.0000002384185791015625; + + pub const ONE_BOUNDS_UNORM: Range = + 0.9999998509883880615234375..1.0000001490116119384765625; + + pub const NEG_ONE_BOUNDS_SNORM: Range = -1.0 - f32::EPSILON..-0.999999821186065673828125; + + pub const HALF_BOUNDS_2X16_SNORM: Range = + 0.500015079975128173828125..0.5000154972076416015625; + + pub const NEG_HALF_BOUNDS_2X16_SNORM: Range = + -0.4999848306179046630859375..-0.49998462200164794921875; + + pub const HALF_BOUNDS_2X16_UNORM: Range = + 
0.5000074803829193115234375..0.5000078380107879638671875; + + pub const HALF_BOUNDS_4X8_SNORM: Range = + 0.503936827182769775390625..0.503937244415283203125; + + pub const NEG_HALF_BOUNDS_4X8_SNORM: Range = + -0.4960630834102630615234375..-0.49606287479400634765625; + + pub const HALF_BOUNDS_4X8_UNORM: Range = + 0.5019606053829193115234375..0.5019609630107879638671875; + + fn range(value: f32) -> Range { + value..f32_next(value) + } + + #[derive(Clone, Copy)] + enum Function { + Unpack2x16Float, + Unpack2x16Unorm, + Unpack2x16Snorm, + Unpack4x8Snorm, + Unpack4x8Unorm, + } + + #[rustfmt::skip] + let values: &[(Function, u32, &[Range])] = &[ + (Function::Unpack2x16Snorm, 0x00000000, &[ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack2x16Snorm, 0x00007fff, &[ONE_BOUNDS_SNORM, ZERO_BOUNDS]), + (Function::Unpack2x16Snorm, 0x7fff0000, &[ZERO_BOUNDS, ONE_BOUNDS_SNORM]), + (Function::Unpack2x16Snorm, 0x7fff7fff, &[ONE_BOUNDS_SNORM, ONE_BOUNDS_SNORM]), + (Function::Unpack2x16Snorm, 0x80018001, &[NEG_ONE_BOUNDS_SNORM, NEG_ONE_BOUNDS_SNORM]), + (Function::Unpack2x16Snorm, 0x40004000, &[HALF_BOUNDS_2X16_SNORM, HALF_BOUNDS_2X16_SNORM]), + (Function::Unpack2x16Snorm, 0xc001c001, &[NEG_HALF_BOUNDS_2X16_SNORM, NEG_HALF_BOUNDS_2X16_SNORM]), + (Function::Unpack2x16Snorm, 0x0000c001, &[NEG_HALF_BOUNDS_2X16_SNORM, ZERO_BOUNDS]), + (Function::Unpack2x16Snorm, 0xc0010000, &[ZERO_BOUNDS, NEG_HALF_BOUNDS_2X16_SNORM]), + + (Function::Unpack2x16Unorm, 0x00000000, &[ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack2x16Unorm, 0x0000ffff, &[ONE_BOUNDS_UNORM, ZERO_BOUNDS]), + (Function::Unpack2x16Unorm, 0xffff0000, &[ZERO_BOUNDS, ONE_BOUNDS_UNORM]), + (Function::Unpack2x16Unorm, 0xffffffff, &[ONE_BOUNDS_UNORM, ONE_BOUNDS_UNORM]), + (Function::Unpack2x16Unorm, 0x80008000, &[HALF_BOUNDS_2X16_UNORM, HALF_BOUNDS_2X16_UNORM]), + + (Function::Unpack4x8Snorm, 0x00000000, &[ZERO_BOUNDS, ZERO_BOUNDS, ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Snorm, 0x0000007f, &[ONE_BOUNDS_SNORM, ZERO_BOUNDS, 
ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Snorm, 0x00007f00, &[ZERO_BOUNDS, ONE_BOUNDS_SNORM, ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Snorm, 0x007f0000, &[ZERO_BOUNDS, ZERO_BOUNDS, ONE_BOUNDS_SNORM, ZERO_BOUNDS]), + (Function::Unpack4x8Snorm, 0x7f000000, &[ZERO_BOUNDS, ZERO_BOUNDS, ZERO_BOUNDS, ONE_BOUNDS_SNORM]), + (Function::Unpack4x8Snorm, 0x00007f7f, &[ONE_BOUNDS_SNORM, ONE_BOUNDS_SNORM, ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Snorm, 0x7f7f0000, &[ZERO_BOUNDS, ZERO_BOUNDS, ONE_BOUNDS_SNORM, ONE_BOUNDS_SNORM]), + (Function::Unpack4x8Snorm, 0x7f007f00, &[ZERO_BOUNDS, ONE_BOUNDS_SNORM, ZERO_BOUNDS, ONE_BOUNDS_SNORM]), + (Function::Unpack4x8Snorm, 0x007f007f, &[ONE_BOUNDS_SNORM, ZERO_BOUNDS, ONE_BOUNDS_SNORM, ZERO_BOUNDS]), + (Function::Unpack4x8Snorm, 0x7f7f7f7f, &[ONE_BOUNDS_SNORM, ONE_BOUNDS_SNORM, ONE_BOUNDS_SNORM, ONE_BOUNDS_SNORM]), + (Function::Unpack4x8Snorm, 0x81818181, &[NEG_ONE_BOUNDS_SNORM, NEG_ONE_BOUNDS_SNORM, NEG_ONE_BOUNDS_SNORM, NEG_ONE_BOUNDS_SNORM]), + (Function::Unpack4x8Snorm, 0x40404040, &[HALF_BOUNDS_4X8_SNORM, HALF_BOUNDS_4X8_SNORM, HALF_BOUNDS_4X8_SNORM, HALF_BOUNDS_4X8_SNORM]), + (Function::Unpack4x8Snorm, 0xc1c1c1c1, &[NEG_HALF_BOUNDS_4X8_SNORM, NEG_HALF_BOUNDS_4X8_SNORM, NEG_HALF_BOUNDS_4X8_SNORM, NEG_HALF_BOUNDS_4X8_SNORM]), + + (Function::Unpack4x8Unorm, 0x00000000, &[ZERO_BOUNDS, ZERO_BOUNDS, ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Unorm, 0x000000ff, &[ONE_BOUNDS_UNORM, ZERO_BOUNDS, ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Unorm, 0x0000ff00, &[ZERO_BOUNDS, ONE_BOUNDS_UNORM, ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Unorm, 0x00ff0000, &[ZERO_BOUNDS, ZERO_BOUNDS, ONE_BOUNDS_UNORM, ZERO_BOUNDS]), + (Function::Unpack4x8Unorm, 0xff000000, &[ZERO_BOUNDS, ZERO_BOUNDS, ZERO_BOUNDS, ONE_BOUNDS_UNORM]), + (Function::Unpack4x8Unorm, 0x0000ffff, &[ONE_BOUNDS_UNORM, ONE_BOUNDS_UNORM, ZERO_BOUNDS, ZERO_BOUNDS]), + (Function::Unpack4x8Unorm, 0xffff0000, &[ZERO_BOUNDS, ZERO_BOUNDS, 
ONE_BOUNDS_UNORM, ONE_BOUNDS_UNORM]), + (Function::Unpack4x8Unorm, 0xff00ff00, &[ZERO_BOUNDS, ONE_BOUNDS_UNORM, ZERO_BOUNDS, ONE_BOUNDS_UNORM]), + (Function::Unpack4x8Unorm, 0x00ff00ff, &[ONE_BOUNDS_UNORM, ZERO_BOUNDS, ONE_BOUNDS_UNORM, ZERO_BOUNDS]), + (Function::Unpack4x8Unorm, 0xffffffff, &[ONE_BOUNDS_UNORM, ONE_BOUNDS_UNORM, ONE_BOUNDS_UNORM, ONE_BOUNDS_UNORM]), + (Function::Unpack4x8Unorm, 0x80808080, &[HALF_BOUNDS_4X8_UNORM, HALF_BOUNDS_4X8_UNORM, HALF_BOUNDS_4X8_UNORM, HALF_BOUNDS_4X8_UNORM]), + + + (Function::Unpack2x16Float, 0x00000000, &[range(0.0), range(0.0)]), + (Function::Unpack2x16Float, 0x80000000, &[range(0.0), range(0.0)]), + (Function::Unpack2x16Float, 0x00008000, &[range(0.0), range(0.0)]), + (Function::Unpack2x16Float, 0x80008000, &[range(0.0), range(0.0)]), + (Function::Unpack2x16Float, 0x00003c00, &[range(1.0), range(0.0)]), + (Function::Unpack2x16Float, 0x3c000000, &[range(0.0), range(1.0)]), + (Function::Unpack2x16Float, 0x3c003c00, &[range(1.0), range(1.0)]), + (Function::Unpack2x16Float, 0xbc00bc00, &[range(-1.0), range(-1.0)]), + (Function::Unpack2x16Float, 0x49004900, &[range(10.0), range(10.0)]), + (Function::Unpack2x16Float, 0xc900c900, &[range(-10.0), range(-10.0)]), + ]; + + for &(function, input, outputs) in values { + let name = match function { + Function::Unpack2x16Float => format!("unpack2x16float({input:#x?}) == {outputs:?}"), + Function::Unpack2x16Unorm => format!("unpack2x16unorm({input:#x?}) == {outputs:?}"), + Function::Unpack2x16Snorm => format!("unpack2x16snorm({input:#x?}) == {outputs:?}"), + Function::Unpack4x8Snorm => format!("unpack4x8snorm({input:#x?}) == {outputs:?}"), + Function::Unpack4x8Unorm => format!("unpack4x8unorm({input:#x?}) == {outputs:?}"), + }; + + let body = match function { + Function::Unpack2x16Float => { + " + let value = unpack2x16float(input.value); + output[0] = bitcast(value.x); + output[1] = bitcast(value.y); + " + } + Function::Unpack2x16Unorm => { + " + let value = 
unpack2x16unorm(input.value); + output[0] = bitcast(value.x); + output[1] = bitcast(value.y); + " + } + Function::Unpack2x16Snorm => { + " + let value = unpack2x16snorm(input.value); + output[0] = bitcast(value.x); + output[1] = bitcast(value.y); + " + } + Function::Unpack4x8Snorm => { + " + let value = unpack4x8snorm(input.value); + output[0] = bitcast(value.x); + output[1] = bitcast(value.y); + output[2] = bitcast(value.z); + output[3] = bitcast(value.w); + " + } + Function::Unpack4x8Unorm => { + " + let value = unpack4x8unorm(input.value); + output[0] = bitcast(value.x); + output[1] = bitcast(value.y); + output[2] = bitcast(value.z); + output[3] = bitcast(value.w); + " + } + }; + + tests.push(ShaderTest::new( + name, + String::from("value: u32"), + String::from(body), + &[input], + &[outputs], + )); + } + + tests +} + +#[gpu_test] +static UNPACKING_BUILTINS: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .downlevel_flags(DownlevelFlags::COMPUTE_SHADERS) + .limits(Limits::downlevel_defaults()), + ) + .run_sync(|ctx| { + shader_input_output_test(ctx, InputStorageType::Storage, create_unpack_builtin_test()); + }); diff --git a/tests/tests/shader/struct_layout.rs b/tests/tests/shader/struct_layout.rs index f17dceac08..16360a5e24 100644 --- a/tests/tests/shader/struct_layout.rs +++ b/tests/tests/shader/struct_layout.rs @@ -32,12 +32,13 @@ fn create_struct_layout_tests(storage_type: InputStorageType) -> Vec writeln!(loaded, "output[{idx}] = bitcast(loaded.{component});").unwrap(); } + let output_values = (0..components as u32).collect::>(); tests.push(ShaderTest::new( format!("vec{components}<{ty}> - direct"), input_members.clone(), direct, &input_values, - &(0..components as u32).collect::>(), + &[&output_values], )); tests.push(ShaderTest::new( @@ -45,7 +46,7 @@ fn create_struct_layout_tests(storage_type: InputStorageType) -> Vec input_members.clone(), loaded, &input_values, - &(0..components as u32).collect::>(), + 
/// Returns the least `f32` that compares strictly greater than `value`.
///
/// Edge cases:
/// - NaN and positive infinity are returned unchanged.
/// - Both `0.0` and `-0.0` step to the smallest positive subnormal.
/// - The smallest-magnitude negative subnormal steps to `-0.0`.
/// - `f32::MAX` steps to positive infinity; negative infinity steps to `f32::MIN`.
///
/// Copied from the standard library's (at the time unstable) `f32::next_up`:
/// <https://doc.rust-lang.org/src/core/num/f32.rs.html#710-730>
#[must_use]
pub fn f32_next(value: f32) -> f32 {
    // We must use strictly integer arithmetic to prevent denormals from
    // flushing to zero after an arithmetic operation on some platforms.
    const TINY_BITS: u32 = 0x1; // Smallest positive f32.
    const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff;

    let bits = value.to_bits();
    if value.is_nan() || bits == f32::INFINITY.to_bits() {
        return value;
    }

    let magnitude = bits & CLEAR_SIGN_MASK;
    let next_bits = if magnitude == 0 {
        // +0.0 or -0.0: the next value up is the first positive subnormal.
        TINY_BITS
    } else if bits == magnitude {
        // Positive: a larger bit pattern is a larger value.
        bits + 1
    } else {
        // Negative: shrinking the magnitude moves towards zero, i.e. upwards.
        bits - 1
    };
    f32::from_bits(next_bits)
}