Skip to content

Commit

Permalink
[naga] Add Literal::I64, for signed 64-bit integer literals.
Browse files Browse the repository at this point in the history
Add an `I64` variant to `crate::Literal`, making `crate::Expression`
suitable for representing `AbstractFloat` and `AbstractInt` values in
the WGSL front end.

Make validation reject uses of `Literal::I64` in constant and function
expression arenas unconditionally. Add tests for this.

Let the frontends and backends for languages that have 64-bit integers
read/write them.
  • Loading branch information
jimblandy committed Nov 17, 2023
1 parent a5c93ca commit 05a12e8
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 9 deletions.
3 changes: 3 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2410,6 +2410,9 @@ impl<'a, W: Write> Writer<'a, W> {
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
crate::Literal::I64(_) => {
return Err(Error::Custom("GLSL has no 64-bit integer type".into()));
}
}
}
Expression::Constant(handle) => {
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2041,6 +2041,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
crate::Literal::I64(value) => write!(self.out, "{}L", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
},
Expression::Constant(handle) => {
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,9 @@ impl<W: Write> Writer<W> {
crate::Literal::I32(value) => {
write!(self.out, "{value}")?;
}
crate::Literal::I64(value) => {
write!(self.out, "{value}L")?;
}
crate::Literal::Bool(value) => {
write!(self.out, "{value}")?;
}
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,9 @@ impl Writer {
crate::Literal::F32(value) => Instruction::constant_32bit(type_id, id, value.to_bits()),
crate::Literal::U32(value) => Instruction::constant_32bit(type_id, id, value),
crate::Literal::I32(value) => Instruction::constant_32bit(type_id, id, value as u32),
crate::Literal::I64(value) => {
Instruction::constant_64bit(type_id, id, value as u32, (value >> 32) as u32)
}
crate::Literal::Bool(true) => Instruction::constant_true(type_id, id),
crate::Literal::Bool(false) => Instruction::constant_false(type_id, id),
};
Expand Down
9 changes: 6 additions & 3 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1091,13 +1091,16 @@ impl<W: Write> Writer<W> {
match literal {
// Floats are written using `Debug` instead of `Display` because it always appends the
// decimal part even it's zero
crate::Literal::F64(_) => {
return Err(Error::Custom("unsupported f64 literal".to_string()));
}
crate::Literal::F32(value) => write!(self.out, "{:?}", value)?,
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
crate::Literal::F64(_) => {
return Err(Error::Custom("unsupported f64 literal".to_string()));
}
crate::Literal::I64(_) => {
return Err(Error::Custom("unsupported i64 literal".to_string()));
}
}
}
Expression::Constant(handle) => {
Expand Down
5 changes: 5 additions & 0 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4879,6 +4879,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let low = self.next()?;
match width {
4 => crate::Literal::I32(low as i32),
8 => {
inst.expect(5)?;
let high = self.next()?;
crate::Literal::I64((u64::from(high) << 32 | u64::from(low)) as i64)
}
_ => return Err(Error::InvalidTypeWidth(width as u32)),
}
}
Expand Down
1 change: 1 addition & 0 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ pub enum Literal {
F32(f32),
U32(u32),
I32(i32),
I64(i64),
Bool(bool),
}

Expand Down
16 changes: 12 additions & 4 deletions naga/src/proc/constant_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,28 +963,36 @@ impl<'a> ConstantEvaluator<'a> {
Literal::U32(v) => v as i32,
Literal::F32(v) => v as i32,
Literal::Bool(v) => v as i32,
Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg),
Literal::F64(_) | Literal::I64(_) => {
return Err(ConstantEvaluatorError::InvalidCastArg)
}
}),
Sc::U32 => Literal::U32(match literal {
Literal::I32(v) => v as u32,
Literal::U32(v) => v,
Literal::F32(v) => v as u32,
Literal::Bool(v) => v as u32,
Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg),
Literal::F64(_) | Literal::I64(_) => {
return Err(ConstantEvaluatorError::InvalidCastArg)
}
}),
Sc::F32 => Literal::F32(match literal {
Literal::I32(v) => v as f32,
Literal::U32(v) => v as f32,
Literal::F32(v) => v,
Literal::Bool(v) => v as u32 as f32,
Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg),
Literal::F64(_) | Literal::I64(_) => {
return Err(ConstantEvaluatorError::InvalidCastArg)
}
}),
Sc::BOOL => Literal::Bool(match literal {
Literal::I32(v) => v != 0,
Literal::U32(v) => v != 0,
Literal::F32(v) => v != 0.0,
Literal::Bool(v) => v,
Literal::F64(_) => return Err(ConstantEvaluatorError::InvalidCastArg),
Literal::F64(_) | Literal::I64(_) => {
return Err(ConstantEvaluatorError::InvalidCastArg)
}
}),
_ => return Err(ConstantEvaluatorError::InvalidCastArg),
};
Expand Down
13 changes: 12 additions & 1 deletion naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ impl super::Scalar {
kind: crate::ScalarKind::Float,
width: 8,
};
pub const I64: Self = Self {
kind: crate::ScalarKind::Sint,
width: 8,
};
pub const BOOL: Self = Self {
kind: crate::ScalarKind::Bool,
width: crate::BOOL_WIDTH,
Expand Down Expand Up @@ -130,6 +134,7 @@ impl PartialEq for crate::Literal {
(Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(),
(Self::U32(a), Self::U32(b)) => a == b,
(Self::I32(a), Self::I32(b)) => a == b,
(Self::I64(a), Self::I64(b)) => a == b,
(Self::Bool(a), Self::Bool(b)) => a == b,
_ => false,
}
Expand Down Expand Up @@ -159,6 +164,10 @@ impl std::hash::Hash for crate::Literal {
hasher.write_u8(4);
v.hash(hasher);
}
Self::I64(v) => {
hasher.write_u8(5);
v.hash(hasher);
}
}
}
}
Expand All @@ -170,6 +179,7 @@ impl crate::Literal {
(value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)),
(value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)),
(value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)),
(value, crate::ScalarKind::Sint, 8) => Some(Self::I64(value as _)),
(1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)),
(0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)),
_ => None,
Expand All @@ -186,7 +196,7 @@ impl crate::Literal {

pub const fn width(&self) -> crate::Bytes {
match *self {
Self::F64(_) => 8,
Self::F64(_) | Self::I64(_) => 8,
Self::F32(_) | Self::U32(_) | Self::I32(_) => 4,
Self::Bool(_) => 1,
}
Expand All @@ -197,6 +207,7 @@ impl crate::Literal {
Self::F32(_) => crate::Scalar::F32,
Self::U32(_) => crate::Scalar::U32,
Self::I32(_) => crate::Scalar::I32,
Self::I64(_) => crate::Scalar::I64,
Self::Bool(_) => crate::Scalar::BOOL,
}
}
Expand Down
145 changes: 145 additions & 0 deletions naga/src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1655,3 +1655,148 @@ pub fn check_literal_value(literal: crate::Literal) -> Result<(), LiteralError>

Ok(())
}

#[cfg(all(test, feature = "validate"))]
/// Validate a module containing the given expression, expecting an error.
fn validate_with_expression(
expr: crate::Expression,
caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
use crate::span::Span;

let mut function = crate::Function::default();
function.expressions.append(expr, Span::default());
function.body.push(
crate::Statement::Emit(function.expressions.range_from(0)),
Span::default(),
);

let mut module = crate::Module::default();
module.functions.append(function, Span::default());

let mut validator = super::Validator::new(super::ValidationFlags::EXPRESSIONS, caps);

validator.validate(&module)
}

#[cfg(all(test, feature = "validate"))]
/// Validate a module containing the given constant expression, expecting an error.
fn validate_with_const_expression(
expr: crate::Expression,
caps: super::Capabilities,
) -> Result<ModuleInfo, crate::span::WithSpan<super::ValidationError>> {
use crate::span::Span;

let mut module = crate::Module::default();
module.const_expressions.append(expr, Span::default());

let mut validator = super::Validator::new(super::ValidationFlags::CONSTANTS, caps);

validator.validate(&module)
}

/// Using F64 in a function's expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn f64_runtime_literals() {
let result = validate_with_expression(
crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
super::Capabilities::default(),
);
let error = result.unwrap_err().into_inner();
assert!(matches!(
error,
crate::valid::ValidationError::Function {
source: super::FunctionError::Expression {
source: super::ExpressionError::Literal(super::LiteralError::Width(
super::r#type::WidthError::MissingCapability {
name: "f64",
flag: "FLOAT64",
}
),),
..
},
..
}
));

let result = validate_with_expression(
crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
super::Capabilities::default() | super::Capabilities::FLOAT64,
);
assert!(result.is_ok());
}

/// Using F64 in a module's constant expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn f64_const_literals() {
let result = validate_with_const_expression(
crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
super::Capabilities::default(),
);
let error = result.unwrap_err().into_inner();
assert!(matches!(
error,
crate::valid::ValidationError::ConstExpression {
source: super::ConstExpressionError::Literal(super::LiteralError::Width(
super::r#type::WidthError::MissingCapability {
name: "f64",
flag: "FLOAT64",
}
)),
..
}
));

let result = validate_with_const_expression(
crate::Expression::Literal(crate::Literal::F64(0.57721_56649)),
super::Capabilities::default() | super::Capabilities::FLOAT64,
);
assert!(result.is_ok());
}

/// Using I64 in a function's expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn i64_runtime_literals() {
let result = validate_with_expression(
crate::Expression::Literal(crate::Literal::I64(1729)),
// There is no capability that enables this.
super::Capabilities::all(),
);
let error = result.unwrap_err().into_inner();
assert!(matches!(
error,
crate::valid::ValidationError::Function {
source: super::FunctionError::Expression {
source: super::ExpressionError::Literal(super::LiteralError::Width(
super::r#type::WidthError::Unsupported64Bit
),),
..
},
..
}
));
}

/// Using I64 in a module's constant expression arena is forbidden.
#[cfg(feature = "validate")]
#[test]
fn i64_const_literals() {
let result = validate_with_const_expression(
crate::Expression::Literal(crate::Literal::I64(1729)),
// There is no capability that enables this.
super::Capabilities::all(),
);
let error = result.unwrap_err().into_inner();
assert!(matches!(
error,
crate::valid::ValidationError::ConstExpression {
source: super::ConstExpressionError::Literal(super::LiteralError::Width(
super::r#type::WidthError::Unsupported64Bit,
),),
..
}
));
}
7 changes: 6 additions & 1 deletion naga/src/valid/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@ impl super::Validator {
scalar.width == 4
}
}
crate::ScalarKind::Sint | crate::ScalarKind::Uint => scalar.width == 4,
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
if scalar.width == 8 {
return Err(WidthError::Unsupported64Bit);
}
scalar.width == 4
}
};
if good {
Ok(())
Expand Down

0 comments on commit 05a12e8

Please sign in to comment.