From 42058cf24f45d34a066639815de43abbbc9a990b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 26 Oct 2023 16:20:03 -0700 Subject: [PATCH] [naga] Add `Literal::I64`, for signed 64-bit integer literals. Add an `I64` variant to `crate::Literal`, making `crate::Expression` suitable for representing `AbstractFloat` and `AbstractInt` values in the WGSL front end. Make validation reject uses of `Literal::I64` in constant and function expression arenas unconditionally. Add tests for this. Let the frontends and backends for languages that have 64-bit integers read/write them. --- CHANGELOG.md | 3 + naga/src/back/glsl/mod.rs | 3 + naga/src/back/hlsl/writer.rs | 1 + naga/src/back/msl/writer.rs | 3 + naga/src/back/spv/writer.rs | 3 + naga/src/back/wgsl/writer.rs | 9 +- naga/src/front/spv/mod.rs | 5 + naga/src/lib.rs | 1 + naga/src/proc/constant_evaluator.rs | 16 ++- naga/src/proc/mod.rs | 13 ++- naga/src/valid/expression.rs | 145 ++++++++++++++++++++++++++++ naga/src/valid/type.rs | 7 +- 12 files changed, 200 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2faf1cef6..c3d82a9f56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -80,6 +80,9 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S - Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673). - Add more metal keywords. By @fornwall in [#4707](https://github.com/gfx-rs/wgpu/pull/4707). +- Implement WGSL abstract types (by @jimblandy): + - Add a new `naga::Literal` variant, `I64`, for signed 64-bit literals. [#4711](https://github.com/gfx-rs/wgpu/pull/4711) + ### Bug Fixes #### WGL diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 2ecb421d4d..967ffc468c 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2410,6 +2410,9 @@ impl<'a, W: Write> Writer<'a, W> { crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::I64(_) => { + return Err(Error::Custom("GLSL has no 64-bit integer type".into())); + } } } Expression::Constant(handle) => { diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index dc61805568..40e1e2db39 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2041,6 +2041,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::Literal::F32(value) => write!(self.out, "{value:?}")?, crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::I64(value) => write!(self.out, "{}L", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, }, Expression::Constant(handle) => { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 07077a5c16..7836f0c3b9 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1274,6 +1274,9 @@ impl Writer { crate::Literal::I32(value) => { write!(self.out, "{value}")?; } + crate::Literal::I64(value) => { + write!(self.out, "{value}L")?; + } crate::Literal::Bool(value) => { write!(self.out, "{value}")?; } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index b977b757a1..da3fb3e786 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1179,6 +1179,9 @@ impl Writer { crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()), crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value), crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32), + crate::Literal::I64(value) => { + Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32) + } crate::Literal::Bool(true) => Instruction::constant_true(type_id, id), crate::Literal::Bool(false) => Instruction::constant_false(type_id, id), }; diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index cd9ae076f3..0595a966bf 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1091,13 +1091,16 @@ impl Writer { match literal { // Floats are written using `Debug` instead of `Display` because it always appends the // decimal part even it's zero - crate::Literal::F64(_) => { - return Err(Error::Custom("unsupported f64 literal".to_string())); - } crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, crate::Literal::U32(value) => write!(self.out, "{}u", value)?, crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::F64(_) => { + return Err(Error::Custom("unsupported f64 literal".to_string())); + } + crate::Literal::I64(_) => { + return Err(Error::Custom("unsupported i64 literal".to_string())); + } } } Expression::Constant(handle) => { diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 7a8ba3c1cf..74f58033e5 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4879,6 +4879,11 @@ impl> Frontend { let low = self.next()?; match width { 4 => crate::Literal::I32(low as i32), + 8 => { + inst.expect(5)?; + let high = self.next()?; + crate::Literal::I64((u64::from(high) << 32 | u64::from(low)) as i64) + } _ => return Err(Error::InvalidTypeWidth(width as u32)), } } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index b23630d999..792f3066c7 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -869,6 +869,7 @@ pub enum Literal { F32(f32), U32(u32), I32(i32), + I64(i64), Bool(bool), } diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 6adc97ac3e..b0ef371c84 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -979,28 +979,36 @@ impl<'a> ConstantEvaluator<'a> { Literal::U32(v) => v as i32, Literal::F32(v) => v as i32, Literal::Bool(v) => v as i32, - Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg), + Literal::F64(_) | Literal::I64(_) => { + return Err(ConstantEvaluatorError::InvalidCastArg) + } }), Sc::U32 => Literal::U32(match literal { Literal::I32(v) => v as u32, Literal::U32(v) => v, Literal::F32(v) => v as u32, Literal::Bool(v) => v as u32, - Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg), + Literal::F64(_) | Literal::I64(_) => { + return Err(ConstantEvaluatorError::InvalidCastArg) + } }), Sc::F32 => Literal::F32(match literal { Literal::I32(v) => v as f32, Literal::U32(v) => v as f32, Literal::F32(v) => v, Literal::Bool(v) => v as u32 as f32, - Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg), + Literal::F64(_) | Literal::I64(_) => { + return Err(ConstantEvaluatorError::InvalidCastArg) + } }), Sc::BOOL => Literal::Bool(match literal { Literal::I32(v) => v != 0, Literal::U32(v) => v != 0, Literal::F32(v) => v != 0.0, Literal::Bool(v) => v, - Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg), + Literal::F64(_) | Literal::I64(_) => { + return Err(ConstantEvaluatorError::InvalidCastArg) + } }), _ => return Err(ConstantEvaluatorError::InvalidCastArg), }; diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 5adf271133..897063aca8 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -94,6 +94,10 @@ impl super::Scalar { kind: crate::ScalarKind::Float, width: 8, }; + pub const I64: Self = Self { + kind: crate::ScalarKind::Sint, + width: 8, + }; pub const BOOL: Self = Self { kind: crate::ScalarKind::Bool, width: crate::BOOL_WIDTH, @@ -130,6 +134,7 @@ impl PartialEq for crate::Literal { (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), (Self::U32(a), Self::U32(b)) => a == b, (Self::I32(a), Self::I32(b)) => a == b, + (Self::I64(a), Self::I64(b)) => a == b, (Self::Bool(a), Self::Bool(b)) => a == b, _ => false, } @@ -159,6 +164,10 @@ impl std::hash::Hash for crate::Literal { hasher.write_u8(4); v.hash(hasher); } + Self::I64(v) => { + hasher.write_u8(5); + v.hash(hasher); + } } } } @@ -170,6 +179,7 @@ impl crate::Literal { (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)), (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)), (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), + (value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)), (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), _ => None, @@ -186,7 +196,7 @@ impl crate::Literal { pub const fn width(&self) -> crate::Bytes { match *self { - Self::F64(_) => 8, + Self::F64(_) | Self::I64(_) => 8, Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, Self::Bool(_) => 1, } @@ -197,6 +207,7 @@ impl crate::Literal { Self::F32(_) => crate::Scalar::F32, Self::U32(_) => crate::Scalar::U32, Self::I32(_) => crate::Scalar::I32, + Self::I64(_) => crate::Scalar::I64, Self::Bool(_) => crate::Scalar::BOOL, } } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 840ba90d01..c0cad69f21 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1650,3 +1650,148 @@ pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError> Ok(()) } + +#[cfg(all(test, feature = "validate"))] +/// Validate a module containing the given expression, expecting an error. +fn validate_with_expression( + expr: crate::Expression, + caps: super::Capabilities, +) -> Result> { + use crate::span::Span; + + let mut function = crate::Function::default(); + function.expressions.append(expr, Span::default()); + function.body.push( + crate::Statement::Emit(function.expressions.range_from(0)), + Span::default(), + ); + + let mut module = crate::Module::default(); + module.functions.append(function, Span::default()); + + let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps); + + validator.validate(&module) +} + +#[cfg(all(test, feature = "validate"))] +/// Validate a module containing the given constant expression, expecting an error. +fn validate_with_const_expression( + expr: crate::Expression, + caps: super::Capabilities, +) -> Result> { + use crate::span::Span; + + let mut module = crate::Module::default(); + module.const_expressions.append(expr, Span::default()); + + let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps); + + validator.validate(&module) +} + +/// Using F64 in a function's expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn f64_runtime_literals() { + let result = validate_with_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::Function { + source: super::FunctionError::Expression { + source: super::ExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::MissingCapability { + name: "f64", + flag: "FLOAT64", + } + ),), + .. + }, + .. + } + )); + + let result = validate_with_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default() | super::Capabilities::FLOAT64, + ); + assert!(result.is_ok()); +} + +/// Using F64 in a module's constant expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn f64_const_literals() { + let result = validate_with_const_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::ConstExpression { + source: super::ConstExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::MissingCapability { + name: "f64", + flag: "FLOAT64", + } + )), + .. + } + )); + + let result = validate_with_const_expression( + crate::Expression::Literal(crate::Literal::F64(0.57721_56649)), + super::Capabilities::default() | super::Capabilities::FLOAT64, + ); + assert!(result.is_ok()); +} + +/// Using I64 in a function's expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn i64_runtime_literals() { + let result = validate_with_expression( + crate::Expression::Literal(crate::Literal::I64(1729)), + // There is no capability that enables this. + super::Capabilities::all(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::Function { + source: super::FunctionError::Expression { + source: super::ExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::Unsupported64Bit + ),), + .. + }, + .. + } + )); +} + +/// Using I64 in a module's constant expression arena is forbidden. +#[cfg(feature = "validate")] +#[test] +fn i64_const_literals() { + let result = validate_with_const_expression( + crate::Expression::Literal(crate::Literal::I64(1729)), + // There is no capability that enables this. + super::Capabilities::all(), + ); + let error = result.unwrap_err().into_inner(); + assert!(matches!( + error, + crate::valid::ValidationError::ConstExpression { + source: super::ConstExpressionError::Literal(super::LiteralError::Width( + super::r#type::WidthError::Unsupported64Bit, + ),), + .. + } + )); +} diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 12d5663f40..217a02a90d 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -240,7 +240,12 @@ impl super::Validator { scalar.width == 4 } } - crate::ScalarKind::Sint | crate::ScalarKind::Uint => scalar.width == 4, + crate::ScalarKind::Sint | crate::ScalarKind::Uint => { + if scalar.width == 8 { + return Err(WidthError::Unsupported64Bit); + } + scalar.width == 4 + } }; if good { Ok(())