Skip to content

Commit

Permalink
[naga wgsl] Experimental 64-bit floating-point literals.
Browse files Browse the repository at this point in the history
In the WGSL front and back ends, support an `lf` suffix on
floating-point literals to yield 64-bit integer literals.
  • Loading branch information
jimblandy committed Nov 22, 2023
1 parent 7246226 commit daf7dee
Show file tree
Hide file tree
Showing 12 changed files with 198 additions and 12 deletions.
4 changes: 1 addition & 3 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1095,9 +1095,7 @@ impl<W: Write> Writer<W> {
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
crate::Literal::F64(_) => {
return Err(Error::Custom("unsupported f64 literal".to_string()));
}
crate::Literal::F64(value) => write!(self.out, "{:?}lf", value)?,
crate::Literal::I64(_) => {
return Err(Error::Custom("unsupported i64 literal".to_string()));
}
Expand Down
1 change: 1 addition & 0 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f),
ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i),
ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u),
ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f),
ast::Literal::Number(_) => {
unreachable!("got abstract numeric type when not expected");
}
Expand Down
17 changes: 17 additions & 0 deletions naga/src/front/wgsl/parse/lexer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ impl<'a> Lexer<'a> {
}

#[cfg(test)]
#[track_caller]
fn sub_test(source: &str, expected_tokens: &[Token]) {
let mut lex = Lexer::new(source);
for &token in expected_tokens {
Expand Down Expand Up @@ -624,6 +625,22 @@ fn test_numbers() {
);
}

#[test]
fn double_floats() {
sub_test(
"0x1.2p4lf 0x1p8lf 0.0625lf 625e-4lf 10lf 10l",
&[
Token::Number(Ok(Number::F64(18.0))),
Token::Number(Ok(Number::F64(256.0))),
Token::Number(Ok(Number::F64(0.0625))),
Token::Number(Ok(Number::F64(0.0625))),
Token::Number(Ok(Number::F64(10.0))),
Token::Number(Ok(Number::I32(10))),
Token::Word("l"),
],
)
}

#[test]
fn test_tokens() {
sub_test("id123_OK", &[Token::Word("id123_OK")]);
Expand Down
44 changes: 35 additions & 9 deletions naga/src/front/wgsl/parse/number.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub enum Number {
U32(u32),
/// Concrete f32
F32(f32),
/// Concrete f64
F64(f64),
}

impl Number {
Expand Down Expand Up @@ -61,9 +63,11 @@ enum IntKind {
U32,
}

#[derive(Debug)]
enum FloatKind {
F32,
F16,
F32,
F64,
}

// The following regexes (from the WGSL spec) will be matched:
Expand Down Expand Up @@ -104,9 +108,9 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
/// if one of the given patterns are found at the start of the buffer
/// returning the corresponding expr for the matched pattern
macro_rules! consume_map {
($bytes:ident, [$($pattern:pat_param => $to:expr),*]) => {
($bytes:ident, [$( $($pattern:pat_param),* => $to:expr),* $(,)?]) => {
match $bytes {
$( &[$pattern, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
$( &[ $($pattern),*, ref rest @ ..] => { $bytes = rest; Some($to) }, )*
_ => None,
}
};
Expand Down Expand Up @@ -136,6 +140,16 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
}};
}

macro_rules! consume_float_suffix {
($bytes:ident) => {
consume_map!($bytes, [
b'h' => FloatKind::F16,
b'f' => FloatKind::F32,
b'l', b'f' => FloatKind::F64,
])
};
}

/// maps the given `&[u8]` (tail of the initial `input: &str`) to a `&str`
macro_rules! rest_to_str {
($bytes:ident) => {
Expand Down Expand Up @@ -190,7 +204,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {

let number = general_extract.end(bytes);

let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
let kind = consume_float_suffix!(bytes);

(parse_hex_float(number, kind), rest_to_str!(bytes))
} else {
Expand Down Expand Up @@ -219,7 +233,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {

let exponent = exp_extract.end(bytes);

let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
let kind = consume_float_suffix!(bytes);

(
parse_hex_float_missing_period(significand, exponent, kind),
Expand Down Expand Up @@ -257,7 +271,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {

let number = general_extract.end(bytes);

let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
let kind = consume_float_suffix!(bytes);

(parse_dec_float(number, kind), rest_to_str!(bytes))
} else {
Expand All @@ -275,7 +289,7 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {

let number = general_extract.end(bytes);

let kind = consume_map!(bytes, [b'f' => FloatKind::F32, b'h' => FloatKind::F16]);
let kind = consume_float_suffix!(bytes);

(parse_dec_float(number, kind), rest_to_str!(bytes))
} else {
Expand All @@ -289,8 +303,9 @@ fn parse(input: &str) -> (Result<Number, NumberError>, &str) {
let kind = consume_map!(bytes, [
b'i' => Kind::Int(IntKind::I32),
b'u' => Kind::Int(IntKind::U32),
b'h' => Kind::Float(FloatKind::F16),
b'f' => Kind::Float(FloatKind::F32),
b'h' => Kind::Float(FloatKind::F16)
b'l', b'f' => Kind::Float(FloatKind::F64),
]);

(
Expand Down Expand Up @@ -382,12 +397,17 @@ fn parse_hex_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
_ => Err(NumberError::NotRepresentable),
},
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
Some(FloatKind::F32) => match hexf_parse::parse_hexf32(input, false) {
Ok(num) => Ok(Number::F32(num)),
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
_ => Err(NumberError::NotRepresentable),
},
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
Some(FloatKind::F64) => match hexf_parse::parse_hexf64(input, false) {
Ok(num) => Ok(Number::F64(num)),
// can only be ParseHexfErrorKind::Inexact but we can't check since it's private
_ => Err(NumberError::NotRepresentable),
},
}
}

Expand All @@ -407,6 +427,12 @@ fn parse_dec_float(input: &str, kind: Option<FloatKind>) -> Result<Number, Numbe
.then_some(Number::F32(num))
.ok_or(NumberError::NotRepresentable)
}
Some(FloatKind::F64) => {
let num = input.parse::<f64>().unwrap(); // will never fail
num.is_finite()
.then_some(Number::F64(num))
.ok_or(NumberError::NotRepresentable)
}
Some(FloatKind::F16) => Err(NumberError::UnimplementedF16),
}
}
Expand Down
12 changes: 12 additions & 0 deletions naga/tests/in/f64.param.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(
god_mode: true,
spv: (
version: (1, 0),
),
glsl: (
version: Desktop(420),
writer_flags: (""),
binding_map: { },
zero_initialize_workgroup_memory: true,
),
)
13 changes: 13 additions & 0 deletions naga/tests/in/f64.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
var<private> v: f64 = 1lf;
const k: f64 = 2.0lf;

fn f(x: f64) -> f64 {
let y: f64 = 3e1lf + 4.0e2lf;
var z = y + f64(5);
return x + y + k + 5.0lf;
}

@compute @workgroup_size(1)
fn main() {
f(6.0lf);
}
19 changes: 19 additions & 0 deletions naga/tests/out/glsl/f64.main.Compute.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#version 420 core
#extension GL_ARB_compute_shader : require
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

const double k = 2.0LF;


double f(double x) {
double z = 0.0;
double y = (30.0LF + 400.0LF);
z = (y + double(5));
return (((x + y) + k) + 5.0LF);
}

void main() {
double _e1 = f(6.0LF);
return;
}

19 changes: 19 additions & 0 deletions naga/tests/out/hlsl/f64.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
static const double k = 2.0L;

static double v = 1.0L;

double f(double x)
{
double z = (double)0;

double y = (30.0L + 400.0L);
z = (y + double(5));
return (((x + y) + k) + 5.0L);
}

[numthreads(1, 1, 1)]
void main()
{
const double _e1 = f(6.0L);
return;
}
12 changes: 12 additions & 0 deletions naga/tests/out/hlsl/f64.ron
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)
48 changes: 48 additions & 0 deletions naga/tests/out/spv/f64.spvasm
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 33
OpCapability Shader
OpCapability Float64
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %28 "main"
OpExecutionMode %28 LocalSize 1 1 1
%2 = OpTypeVoid
%3 = OpTypeFloat 64
%4 = OpConstant %3 1.0
%5 = OpConstant %3 2.0
%7 = OpTypePointer Private %3
%6 = OpVariable %7 Private %4
%11 = OpTypeFunction %3 %3
%12 = OpConstant %3 30.0
%13 = OpConstant %3 400.0
%14 = OpTypeInt 32 1
%15 = OpConstant %14 5
%16 = OpConstant %3 5.0
%18 = OpTypePointer Function %3
%19 = OpConstantNull %3
%29 = OpTypeFunction %2
%30 = OpConstant %3 6.0
%10 = OpFunction %3 None %11
%9 = OpFunctionParameter %3
%8 = OpLabel
%17 = OpVariable %18 Function %19
OpBranch %20
%20 = OpLabel
%21 = OpFAdd %3 %12 %13
%22 = OpConvertSToF %3 %15
%23 = OpFAdd %3 %21 %22
OpStore %17 %23
%24 = OpFAdd %3 %9 %21
%25 = OpFAdd %3 %24 %5
%26 = OpFAdd %3 %25 %16
OpReturnValue %26
OpFunctionEnd
%28 = OpFunction %2 None %29
%27 = OpLabel
OpBranch %31
%31 = OpLabel
%32 = OpFunctionCall %3 %10 %30
OpReturn
OpFunctionEnd
17 changes: 17 additions & 0 deletions naga/tests/out/wgsl/f64.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
const k: f64 = 2.0lf;

var<private> v: f64 = 1.0lf;

fn f(x: f64) -> f64 {
var z: f64;

let y = (30.0lf + 400.0lf);
z = (y + f64(5));
return (((x + y) + k) + 5.0lf);
}

@compute @workgroup_size(1, 1, 1)
fn main() {
let _e1 = f(6.0lf);
return;
}
4 changes: 4 additions & 0 deletions naga/tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,10 @@ fn convert_wgsl() {
Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
("separate-entry-points", Targets::SPIRV | Targets::GLSL),
(
"f64",
Targets::SPIRV | Targets::GLSL | Targets::HLSL | Targets::WGSL,
),
];

for &(name, targets) in inputs.iter() {
Expand Down

0 comments on commit daf7dee

Please sign in to comment.