From daf7dee8859610be51e74a5554ec10ca689aafa6 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 21 Nov 2023 17:33:28 -0800 Subject: [PATCH] [naga wgsl] Experimental 64-bit floating-point literals. In the WGSL front and back ends, support an `lf` suffix on floating-point literals to yield 64-bit integer literals. --- naga/src/back/wgsl/writer.rs | 4 +- naga/src/front/wgsl/lower/mod.rs | 1 + naga/src/front/wgsl/parse/lexer.rs | 17 ++++++++ naga/src/front/wgsl/parse/number.rs | 44 ++++++++++++++++----- naga/tests/in/f64.param.ron | 12 ++++++ naga/tests/in/f64.wgsl | 13 ++++++ naga/tests/out/glsl/f64.main.Compute.glsl | 19 +++++++++ naga/tests/out/hlsl/f64.hlsl | 19 +++++++++ naga/tests/out/hlsl/f64.ron | 12 ++++++ naga/tests/out/spv/f64.spvasm | 48 +++++++++++++++++++++++ naga/tests/out/wgsl/f64.wgsl | 17 ++++++++ naga/tests/snapshots.rs | 4 ++ 12 files changed, 198 insertions(+), 12 deletions(-) create mode 100644 naga/tests/in/f64.param.ron create mode 100644 naga/tests/in/f64.wgsl create mode 100644 naga/tests/out/glsl/f64.main.Compute.glsl create mode 100644 naga/tests/out/hlsl/f64.hlsl create mode 100644 naga/tests/out/hlsl/f64.ron create mode 100644 naga/tests/out/spv/f64.spvasm create mode 100644 naga/tests/out/wgsl/f64.wgsl diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 1906035e865..10da339968c 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1095,9 +1095,7 @@ impl Writer { crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, - crate::Literal::F64(_) => { - return Err(Error::Custom("unsupported f64 literal".to_string())); - } + crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?, crate::Literal::I64(_) => { return Err(Error::Custom("unsupported i64 literal".to_string())); } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 34940e4515b..f6adcb58e7a 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1464,6 +1464,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f), ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), + ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), ast::Literal::Number(_) => { unreachable!("got abstract numeric type when not expected"); } diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index dc229bb5fa3..0c7a8d537de 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -448,6 +448,7 @@ impl<'a> Lexer<'a> { } #[cfg(test)] +#[track_caller] fn sub_test(source: &str, expected_tokens: &[Token]) { let mut lex = Lexer::new(source); for &token in expected_tokens { @@ -624,6 +625,22 @@ fn test_numbers() { ); } +#[test] +fn double_floats() { + sub_test( + "0x1.2p4lf 0x1p8lf 0.0625lf 625e-4lf 10lf 10l", + &[ + Token::Number(Ok(Number::F64(18.0))), + Token::Number(Ok(Number::F64(256.0))), + Token::Number(Ok(Number::F64(0.0625))), + Token::Number(Ok(Number::F64(0.0625))), + Token::Number(Ok(Number::F64(10.0))), + Token::Number(Ok(Number::I32(10))), + Token::Word("l"), + ], + ) +} + #[test] fn test_tokens() { sub_test("id123_OK", &[Token::Word("id123_OK")]); diff --git a/naga/src/front/wgsl/parse/number.rs b/naga/src/front/wgsl/parse/number.rs index 57a2be61422..3178736990f 100644 --- a/naga/src/front/wgsl/parse/number.rs +++ b/naga/src/front/wgsl/parse/number.rs @@ -16,6 +16,8 @@ pub enum Number { U32(u32), /// Concrete f32 F32(f32), + /// Concrete f64 + F64(f64), } impl Number { @@ -61,9 +63,11 @@ enum IntKind { U32, } +#[derive(Debug)] enum FloatKind { - F32, F16, + F32, + F64, } // The following regexes (from the WGSL spec) will be matched: @@ -104,9 +108,9 @@ fn parse(input: &str) -> (Result, &str) { /// if one of the given patterns are found at the start of the buffer /// returning the corresponding expr for the matched pattern macro_rules! consume_map { - ($bytes:ident, [$($pattern:pat_param => $to:expr),*]) => { + ($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => { match $bytes { - $( &[$pattern, ref rest @ ..] => { $bytes = rest; Some($to) }, )* + $( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )* _ => None, } }; @@ -136,6 +140,16 @@ fn parse(input: &str) -> (Result, &str) { }}; } + macro_rules! consume_float_suffix { + ($bytes:ident) => { + consume_map!($bytes, [ + b'h' => FloatKind::F16, + b'f' => FloatKind::F32, + b'l', b'f' => FloatKind::F64, + ]) + }; + } + /// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str` macro_rules! rest_to_str { ($bytes:ident) => { @@ -190,7 +204,7 @@ fn parse(input: &str) -> (Result, &str) { let number = general_extract.end(bytes); - let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + let kind = consume_float_suffix!(bytes); (parse_hex_float(number, kind), rest_to_str!(bytes)) } else { @@ -219,7 +233,7 @@ fn parse(input: &str) -> (Result, &str) { let exponent = exp_extract.end(bytes); - let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + let kind = consume_float_suffix!(bytes); ( parse_hex_float_missing_period(significand, exponent, kind), @@ -257,7 +271,7 @@ fn parse(input: &str) -> (Result, &str) { let number = general_extract.end(bytes); - let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + let kind = consume_float_suffix!(bytes); (parse_dec_float(number, kind), rest_to_str!(bytes)) } else { @@ -275,7 +289,7 @@ fn parse(input: &str) -> (Result, &str) { let number = general_extract.end(bytes); - let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]); + let kind = consume_float_suffix!(bytes); (parse_dec_float(number, kind), rest_to_str!(bytes)) } else { @@ -289,8 +303,9 @@ fn parse(input: &str) -> (Result, &str) { let kind = consume_map!(bytes, [ b'i' => Kind::Int(IntKind::I32), b'u' => Kind::Int(IntKind::U32), + b'h' => Kind::Float(FloatKind::F16), b'f' => Kind::Float(FloatKind::F32), - b'h' => Kind::Float(FloatKind::F16) + b'l', b'f' => Kind::Float(FloatKind::F64), ]); ( @@ -382,12 +397,17 @@ fn parse_hex_float(input: &str, kind: Option) -> Result Err(NumberError::NotRepresentable), }, + Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) { Ok(num) => Ok(Number::F32(num)), // can only be ParseHexfErrorKind::Inexact but we can't check since it's private _ => Err(NumberError::NotRepresentable), }, - Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), + Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) { + Ok(num) => Ok(Number::F64(num)), + // can only be ParseHexfErrorKind::Inexact but we can't check since it's private + _ => Err(NumberError::NotRepresentable), + }, } } @@ -407,6 +427,12 @@ fn parse_dec_float(input: &str, kind: Option) -> Result { + let num = input.parse::().unwrap(); // will never fail + num.is_finite() + .then_some(Number::F64(num)) + .ok_or(NumberError::NotRepresentable) + } Some(FloatKind::F16) => Err(NumberError::UnimplementedF16), } } diff --git a/naga/tests/in/f64.param.ron b/naga/tests/in/f64.param.ron new file mode 100644 index 00000000000..f1f5359da63 --- /dev/null +++ b/naga/tests/in/f64.param.ron @@ -0,0 +1,12 @@ +( + god_mode: true, + spv: ( + version: (1, 0), + ), + glsl: ( + version: Desktop(420), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/naga/tests/in/f64.wgsl b/naga/tests/in/f64.wgsl new file mode 100644 index 00000000000..268a6184a60 --- /dev/null +++ b/naga/tests/in/f64.wgsl @@ -0,0 +1,13 @@ +var v: f64 = 1lf; +const k: f64 = 2.0lf; + +fn f(x: f64) -> f64 { + let y: f64 = 3e1lf + 4.0e2lf; + var z = y + f64(5); + return x + y + k + 5.0lf; +} + +@compute @workgroup_size(1) +fn main() { + f(6.0lf); +} diff --git a/naga/tests/out/glsl/f64.main.Compute.glsl b/naga/tests/out/glsl/f64.main.Compute.glsl new file mode 100644 index 00000000000..7d20f8d6fcc --- /dev/null +++ b/naga/tests/out/glsl/f64.main.Compute.glsl @@ -0,0 +1,19 @@ +#version 420 core +#extension GL_ARB_compute_shader : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +const double k = 2.0LF; + + +double f(double x) { + double z = 0.0; + double y = (30.0LF + 400.0LF); + z = (y + double(5)); + return (((x + y) + k) + 5.0LF); +} + +void main() { + double _e1 = f(6.0LF); + return; +} + diff --git a/naga/tests/out/hlsl/f64.hlsl b/naga/tests/out/hlsl/f64.hlsl new file mode 100644 index 00000000000..5b87b263412 --- /dev/null +++ b/naga/tests/out/hlsl/f64.hlsl @@ -0,0 +1,19 @@ +static const double k = 2.0L; + +static double v = 1.0L; + +double f(double x) +{ + double z = (double)0; + + double y = (30.0L + 400.0L); + z = (y + double(5)); + return (((x + y) + k) + 5.0L); +} + +[numthreads(1, 1, 1)] +void main() +{ + const double _e1 = f(6.0L); + return; +} diff --git a/naga/tests/out/hlsl/f64.ron b/naga/tests/out/hlsl/f64.ron new file mode 100644 index 00000000000..a07b03300b1 --- /dev/null +++ b/naga/tests/out/hlsl/f64.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/naga/tests/out/spv/f64.spvasm b/naga/tests/out/spv/f64.spvasm new file mode 100644 index 00000000000..cdf70f326d6 --- /dev/null +++ b/naga/tests/out/spv/f64.spvasm @@ -0,0 +1,48 @@ +; SPIR-V +; Version: 1.0 +; Generator: rspirv +; Bound: 33 +OpCapability Shader +OpCapability Float64 +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %28 "main" +OpExecutionMode %28 LocalSize 1 1 1 +%2 = OpTypeVoid +%3 = OpTypeFloat 64 +%4 = OpConstant %3 1.0 +%5 = OpConstant %3 2.0 +%7 = OpTypePointer Private %3 +%6 = OpVariable %7 Private %4 +%11 = OpTypeFunction %3 %3 +%12 = OpConstant %3 30.0 +%13 = OpConstant %3 400.0 +%14 = OpTypeInt 32 1 +%15 = OpConstant %14 5 +%16 = OpConstant %3 5.0 +%18 = OpTypePointer Function %3 +%19 = OpConstantNull %3 +%29 = OpTypeFunction %2 +%30 = OpConstant %3 6.0 +%10 = OpFunction %3 None %11 +%9 = OpFunctionParameter %3 +%8 = OpLabel +%17 = OpVariable %18 Function %19 +OpBranch %20 +%20 = OpLabel +%21 = OpFAdd %3 %12 %13 +%22 = OpConvertSToF %3 %15 +%23 = OpFAdd %3 %21 %22 +OpStore %17 %23 +%24 = OpFAdd %3 %9 %21 +%25 = OpFAdd %3 %24 %5 +%26 = OpFAdd %3 %25 %16 +OpReturnValue %26 +OpFunctionEnd +%28 = OpFunction %2 None %29 +%27 = OpLabel +OpBranch %31 +%31 = OpLabel +%32 = OpFunctionCall %3 %10 %30 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/f64.wgsl b/naga/tests/out/wgsl/f64.wgsl new file mode 100644 index 00000000000..61f9c867e55 --- /dev/null +++ b/naga/tests/out/wgsl/f64.wgsl @@ -0,0 +1,17 @@ +const k: f64 = 2.0lf; + +var v: f64 = 1.0lf; + +fn f(x: f64) -> f64 { + var z: f64; + + let y = (30.0lf + 400.0lf); + z = (y + f64(5)); + return (((x + y) + k) + 5.0lf); +} + +@compute @workgroup_size(1, 1, 1) +fn main() { + let _e1 = f(6.0lf); + return; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index f50ba99e3c9..413ed332661 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -775,6 +775,10 @@ fn convert_wgsl() { Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ("separate-entry-points", Targets::SPIRV | Targets::GLSL), + ( + "f64", + Targets::SPIRV | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {