From 506463060463c52405d1951b78dd09782a77119b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 9 Nov 2023 09:23:15 -0800 Subject: [PATCH 1/9] [naga] Introduce `ScalarKind::AbstractInt` and `AbstractFloat`. Introduce new variants of `naga::ScalarKind`, `AbstractInt` and `AbstractFloat`, for representing WGSL abstract types. --- naga/src/back/glsl/mod.rs | 11 +++++++++ naga/src/back/hlsl/conv.rs | 5 +++- naga/src/back/msl/writer.rs | 4 ++++ naga/src/back/spv/block.rs | 2 +- naga/src/back/spv/image.rs | 4 ++++ naga/src/back/spv/writer.rs | 8 +++++++ naga/src/front/glsl/types.rs | 2 +- naga/src/front/wgsl/to_wgsl.rs | 2 ++ naga/src/lib.rs | 10 ++++++++ naga/src/proc/mod.rs | 6 ++++- naga/src/valid/expression.rs | 42 ++++++++++++++++++++++++++-------- naga/src/valid/type.rs | 11 ++++++++- wgpu-core/src/validation.rs | 4 +++- 13 files changed, 95 insertions(+), 16 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index cd3075f70a..d08a0c02c2 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -3555,6 +3555,9 @@ impl<'a, W: Write> Writer<'a, W> { (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => { write!(self.out, "bool")? } + + (Sk::AbstractInt | Sk::AbstractFloat, _, _) + | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(), }; write!(self.out, "(")?; @@ -4117,6 +4120,11 @@ impl<'a, W: Write> Writer<'a, W> { crate::ScalarKind::Uint => write!(self.out, "0u")?, crate::ScalarKind::Float => write!(self.out, "0.0")?, crate::ScalarKind::Sint => write!(self.out, "0")?, + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".to_string(), + )) + } } Ok(()) @@ -4345,6 +4353,9 @@ const fn glsl_scalar(scalar: crate::Scalar) -> Result, Err prefix: "b", full: "bool", }, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::UnsupportedScalar(scalar)); + } }) } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index da17c35704..b6918ddc42 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -10,7 +10,7 @@ impl crate::ScalarKind { Self::Float => "asfloat", Self::Sint => "asint", Self::Uint => "asuint", - Self::Bool => unreachable!(), + Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(), } } } @@ -30,6 +30,9 @@ impl crate::Scalar { _ => Err(Error::UnsupportedScalar(self)), }, crate::ScalarKind::Bool => Ok("bool"), + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + Err(Error::UnsupportedScalar(self)) + } } } } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 17154c3cd5..de226af87b 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -338,6 +338,10 @@ impl crate::Scalar { kind: Sk::Bool, width: _, } => "bool", + Self { + kind: Sk::AbstractInt | Sk::AbstractFloat, + width: _, + } => unreachable!(), } } } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index df6ecd00ff..84f8581521 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1175,7 +1175,7 @@ impl<'w> BlockContext<'w> { let op = match src_scalar.kind { Sk::Sint | Sk::Uint => spirv::Op::INotEqual, Sk::Float => spirv::Op::FUnordNotEqual, - Sk::Bool => unreachable!(), + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(), }; let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?; diff --git a/naga/src/back/spv/image.rs b/naga/src/back/spv/image.rs index fb9d44e7f0..460c906d47 100644 --- a/naga/src/back/spv/image.rs +++ b/naga/src/back/spv/image.rs @@ -334,6 +334,10 @@ impl<'w> BlockContext<'w> { (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => { unreachable!("we don't allow bool or float for array index") } + (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _) + | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => { + unreachable!("abstract types should never reach backends") + } }; let reconciled_array_index_id = if let Some(cast) = cast { let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value { diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index ef0532b2ea..48fc64bf24 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -824,6 +824,9 @@ impl Writer { Instruction::type_float(id, bits) } Sk::Bool => Instruction::type_bool(id), + Sk::AbstractInt | Sk::AbstractFloat => { + unreachable!("abstract types should never reach the backend"); + } } } @@ -1591,6 +1594,11 @@ impl Writer { | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Bool => true, Sk::Float => false, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::Validation( + "Abstract types should not appear in IR presented to backends", + )) + } }, _ => false, }; diff --git a/naga/src/front/glsl/types.rs b/naga/src/front/glsl/types.rs index f0a2705ad2..e87d76fffc 100644 --- a/naga/src/front/glsl/types.rs +++ b/naga/src/front/glsl/types.rs @@ -205,7 +205,7 @@ pub const fn type_power(scalar: Scalar) -> Option { ScalarKind::Uint => 1, ScalarKind::Float if scalar.width == 4 => 2, ScalarKind::Float => 3, - ScalarKind::Bool => return None, + ScalarKind::Bool | ScalarKind::AbstractInt | ScalarKind::AbstractFloat => return None, }) } diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index cdfa1f0b1f..c8331ace09 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -140,6 +140,8 @@ impl crate::Scalar { crate::ScalarKind::Uint => "u", crate::ScalarKind::Float => "f", crate::ScalarKind::Bool => return "bool".to_string(), + crate::ScalarKind::AbstractInt => return "{AbstractInt}".to_string(), + crate::ScalarKind::AbstractFloat => return "{AbstractFloat}".to_string(), }; format!("{}{}", prefix, self.width * 8) } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index e140ad6aef..e45f463bd9 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -470,6 +470,16 @@ pub enum ScalarKind { Float, /// Boolean type. Bool, + + /// WGSL abstract integer type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractInt, + + /// Abstract floating-point type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractFloat, } /// Characteristics of a scalar type. diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 4f2f5c705d..687527049e 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -71,7 +71,11 @@ impl From for super::ScalarKind { impl super::ScalarKind { pub const fn is_numeric(self) -> bool { match self { - crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float => true, + crate::ScalarKind::Sint + | crate::ScalarKind::Uint + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => true, crate::ScalarKind::Bool => false, } } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 1f57c55441..ba427dfda2 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -670,7 +670,11 @@ impl super::Validator { let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat => left_inner == right_inner, Sk::Bool => false, }, Ti::Matrix { .. } => left_inner == right_inner, @@ -678,14 +682,24 @@ impl super::Validator { }, Bo::Divide | Bo::Modulo => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat => left_inner == right_inner, Sk::Bool => false, }, _ => false, }, Bo::Multiply => { let kind_allowed = match left_inner.scalar_kind() { - Some(Sk::Uint | Sk::Sint | Sk::Float) => true, + Some( + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat, + ) => true, Some(Sk::Bool) | None => false, }; let types_match = match (left_inner, right_inner) { @@ -762,7 +776,11 @@ impl super::Validator { Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat => left_inner == right_inner, Sk::Bool => false, }, ref other => { @@ -784,8 +802,10 @@ impl super::Validator { }, Bo::And | Bo::InclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, - Sk::Float => false, + Sk::Bool | Sk::Sint | Sk::Uint | Sk::AbstractInt => { + left_inner == right_inner + } + Sk::Float | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -794,8 +814,8 @@ impl super::Validator { }, Bo::ExclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Sint | Sk::Uint => left_inner == right_inner, - Sk::Bool | Sk::Float => false, + Sk::Sint | Sk::Uint | Sk::AbstractInt => left_inner == right_inner, + Sk::Bool | Sk::Float | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -823,8 +843,10 @@ impl super::Validator { } }; match base_scalar.kind { - Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, - Sk::Float | Sk::Bool => false, + Sk::Sint | Sk::Uint | Sk::AbstractInt => { + base_size.is_ok() && base_size == shift_size + } + Sk::Float | Sk::AbstractFloat | Sk::Bool => false, } } }; diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 53462fe801..1e3e03fe19 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -143,6 +143,9 @@ pub enum WidthError { #[error("64-bit integers are not yet supported")] Unsupported64Bit, + + #[error("Abstract types may only appear in constant expressions")] + Abstract, } // Only makes sense if `flags.contains(HOST_SHAREABLE)` @@ -248,6 +251,9 @@ impl super::Validator { } scalar.width == 4 } + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(WidthError::Abstract); + } }; if good { Ok(()) @@ -325,7 +331,10 @@ impl super::Validator { } Ti::Atomic(crate::Scalar { kind, width }) => { let good = match kind { - crate::ScalarKind::Bool | crate::ScalarKind::Float => false, + crate::ScalarKind::Bool + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => false, crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, }; if !good { diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index b523bf9a73..caa23592c7 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -560,7 +560,9 @@ impl Resource { } naga::ScalarKind::Sint => wgt::TextureSampleType::Sint, naga::ScalarKind::Uint => wgt::TextureSampleType::Uint, - naga::ScalarKind::Bool => unreachable!(), + naga::ScalarKind::AbstractInt + | naga::ScalarKind::AbstractFloat + | naga::ScalarKind::Bool => unreachable!(), }, view_dimension, multisampled: multi, From 221185a7390cdddeb3a17cad4a61ac348106eda5 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 9 Nov 2023 13:21:33 -0800 Subject: [PATCH 2/9] [naga] Introduce `Literal::AbstractInt` and `AbstractFloat`. Introduce new variants of `naga::Literal`, `AbstractInt` and `AbstractFloat`, for representing WGSL abstract values. --- naga/src/back/glsl/mod.rs | 5 ++ naga/src/back/hlsl/writer.rs | 5 ++ naga/src/back/msl/writer.rs | 3 + naga/src/back/spv/writer.rs | 3 + naga/src/back/wgsl/writer.rs | 5 ++ naga/src/lib.rs | 5 ++ naga/src/proc/constant_evaluator.rs | 113 +++++++++++++++++++++++++++- naga/src/proc/mod.rs | 17 ++++- 8 files changed, 151 insertions(+), 5 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index d08a0c02c2..e1dc906630 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2451,6 +2451,11 @@ impl<'a, W: Write> Writer<'a, W> { crate::Literal::I64(_) => { return Err(Error::Custom("GLSL has no 64-bit integer type".into())); } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } } } Expression::Constant(handle) => { diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 24d54fc0e5..0dd60c6ad7 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2040,6 +2040,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::I64(value) => write!(self.out, "{}L", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } }, Expression::Constant(handle) => { let constant = &module.constants[handle]; diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index de226af87b..f900add71e 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1279,6 +1279,9 @@ impl Writer { crate::Literal::Bool(value) => { write!(self.out, "{value}")?; } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Validation); + } }, crate::Expression::Constant(handle) => { let constant = &module.constants[handle]; diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 48fc64bf24..4db86c93a7 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1187,6 +1187,9 @@ impl Writer { } crate::Literal::Bool(true) => Instruction::constant_true(type_id, id), crate::Literal::Bool(false) => Instruction::constant_false(type_id, id), + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + unreachable!("Abstract types should not appear in IR presented to backends"); + } }; instruction.to_words(&mut self.logical_layout.declarations); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 10da339968..a356c7e2ad 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1099,6 +1099,11 @@ impl Writer { crate::Literal::I64(_) => { return Err(Error::Custom("unsupported i64 literal".to_string())); } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } } } Expression::Constant(handle) => { diff --git a/naga/src/lib.rs b/naga/src/lib.rs index e45f463bd9..b27ebc6764 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -300,6 +300,9 @@ use serde::Serialize; /// Width of a boolean type, in bytes. pub const BOOL_WIDTH: Bytes = 1; +/// Width of abstract types, in bytes. +pub const ABSTRACT_WIDTH: Bytes = 8; + /// Hash map that is faster but not resilient to DoS attacks. pub type FastHashMap = rustc_hash::FxHashMap; /// Hash set that is faster but not resilient to DoS attacks. @@ -881,6 +884,8 @@ pub enum Literal { I32(i32), I64(i64), Bool(bool), + AbstractInt(i64), + AbstractFloat(f64), } #[derive(Debug, PartialEq)] diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index e3c07f9e16..6901ada20c 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -167,6 +167,15 @@ pub enum ConstantEvaluatorError { NotImplemented(String), #[error("{0} operation overflowed")] Overflow(String), + #[error( + "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately" + )] + AutomaticConversionLossy { + value: String, + to_type: &'static str, + }, + #[error("abstract floating-point values cannot be automatically converted to integers")] + AutomaticConversionFloatToInt { to_type: &'static str }, #[error("Division by zero")] DivisionByZero, #[error("Remainder by zero")] @@ -979,6 +988,8 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(_) | Literal::I64(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } + Literal::AbstractInt(v) => i32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => i32::try_from_abstract(v)?, }), Sc::U32 => Literal::U32(match literal { Literal::I32(v) => v as u32, @@ -988,6 +999,8 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(_) | Literal::I64(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } + Literal::AbstractInt(v) => u32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => u32::try_from_abstract(v)?, }), Sc::F32 => Literal::F32(match literal { Literal::I32(v) => v as f32, @@ -997,21 +1010,28 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(_) | Literal::I64(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } + Literal::AbstractInt(v) => f32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => f32::try_from_abstract(v)?, }), Sc::F64 => Literal::F64(match literal { Literal::I32(v) => v as f64, Literal::U32(v) => v as f64, Literal::F32(v) => v as f64, - Literal::Bool(v) => v as u32 as f64, Literal::F64(v) => v, + Literal::Bool(v) => v as u32 as f64, Literal::I64(_) => return Err(ConstantEvaluatorError::InvalidCastArg), + Literal::AbstractInt(v) => f64::try_from_abstract(v)?, + Literal::AbstractFloat(v) => f64::try_from_abstract(v)?, }), Sc::BOOL => Literal::Bool(match literal { Literal::I32(v) => v != 0, Literal::U32(v) => v != 0, Literal::F32(v) => v != 0.0, Literal::Bool(v) => v, - Literal::F64(_) | Literal::I64(_) => { + Literal::F64(_) + | Literal::I64(_) + | Literal::AbstractInt(_) + | Literal::AbstractFloat(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } }), @@ -1828,3 +1848,92 @@ mod tests { } } } + +/// Trait for conversions of abstract values to concrete types. +trait TryFromAbstract: Sized { + /// Convert an abstract literal `value` to `Self`. + /// + /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support + /// WGSL, we follow WGSL's conversion rules here: + /// + /// - WGSL §6.1.2. Conversion Rank says that automatic conversions + /// to integers are either lossless or an error. + /// + /// - WGSL §14.6.4 Floating Point Conversion says that conversions + /// to floating point in constant expressions and override + /// expressions are errors if the value is out of range for the + /// destination type, but rounding is okay. + /// + /// [`AbstractInt`]: crate::Literal::AbstractInt + /// [`Float`]: crate::Literal::Float + fn try_from_abstract(value: T) -> Result; +} + +impl TryFromAbstract for i32 { + fn try_from_abstract(value: i64) -> Result { + i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "i32", + }) + } +} + +impl TryFromAbstract for u32 { + fn try_from_abstract(value: i64) -> Result { + u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "u32", + }) + } +} + +impl TryFromAbstract for f32 { + fn try_from_abstract(value: i64) -> Result { + let f = value as f32; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract for f32 { + fn try_from_abstract(value: f64) -> Result { + let f = value as f32; + if f.is_infinite() { + return Err(ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "f32", + }); + } + Ok(f) + } +} + +impl TryFromAbstract for f64 { + fn try_from_abstract(value: i64) -> Result { + let f = value as f64; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract for f64 { + fn try_from_abstract(value: f64) -> Result { + Ok(value) + } +} + +impl TryFromAbstract for i32 { + fn try_from_abstract(_: f64) -> Result { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" }) + } +} + +impl TryFromAbstract for u32 { + fn try_from_abstract(_: f64) -> Result { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" }) + } +} diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 687527049e..40a342f6ce 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -106,6 +106,14 @@ impl super::Scalar { kind: crate::ScalarKind::Bool, width: crate::BOOL_WIDTH, }; + pub const ABSTRACT_INT: Self = Self { + kind: crate::ScalarKind::AbstractInt, + width: crate::ABSTRACT_WIDTH, + }; + pub const ABSTRACT_FLOAT: Self = Self { + kind: crate::ScalarKind::AbstractFloat, + width: crate::ABSTRACT_WIDTH, + }; /// Construct a float `Scalar` with the given width. /// @@ -148,7 +156,7 @@ impl Eq for crate::Literal {} impl std::hash::Hash for crate::Literal { fn hash(&self, hasher: &mut H) { match *self { - Self::F64(v) => { + Self::F64(v) | Self::AbstractFloat(v) => { hasher.write_u8(0); v.to_bits().hash(hasher); } @@ -168,7 +176,7 @@ impl std::hash::Hash for crate::Literal { hasher.write_u8(4); v.hash(hasher); } - Self::I64(v) => { + Self::I64(v) | Self::AbstractInt(v) => { hasher.write_u8(5); v.hash(hasher); } @@ -202,7 +210,8 @@ impl crate::Literal { match *self { Self::F64(_) | Self::I64(_) => 8, Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, - Self::Bool(_) => 1, + Self::Bool(_) => crate::BOOL_WIDTH, + Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH, } } pub const fn scalar(&self) -> crate::Scalar { @@ -213,6 +222,8 @@ impl crate::Literal { Self::I32(_) => crate::Scalar::I32, Self::I64(_) => crate::Scalar::I64, Self::Bool(_) => crate::Scalar::BOOL, + Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT, + Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT, } } pub const fn scalar_kind(&self) -> crate::ScalarKind { From 9d3b367a161688943d7def10088fd6695400ab0d Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 22 Nov 2023 13:41:47 -0800 Subject: [PATCH 3/9] [naga wgsl-in] Reformat match statement for better patch stability. The large `match` statement in `Lowerer::construct` seems to flop back and forth between two indentation levels as it's edited, making the diffs hard to read. Rewrite it to use deferred initialization of `expr`, so that `cargo fmt` doesn't have to decide whether or not to put the `match` on the same line as `let expr`. This makes subsequent diffs easier to read. --- naga/src/front/wgsl/lower/construction.rs | 69 +++++++++++++---------- 1 file changed, 40 insertions(+), 29 deletions(-) diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index fafda793c0..c2791fe92b 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -163,7 +163,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // above can have mutable access to the type arena. let constructor = constructor_h.borrow_inner(ctx.module); - let expr = match (components, constructor) { + let expr; + match (components, constructor) { // Empty constructor (Components::None, dst_ty) => match dst_ty { Constructor::Type((result_ty, _)) => { @@ -186,11 +187,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .. }, Constructor::Type((_, &crate::TypeInner::Scalar(scalar))), - ) => crate::Expression::As { - expr: component, - kind: scalar.kind, - convert: Some(scalar.width), - }, + ) => { + expr = crate::Expression::As { + expr: component, + kind: scalar.kind, + convert: Some(scalar.width), + }; + } // Vector conversion (vector -> vector) ( @@ -206,11 +209,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { scalar: dst_scalar, }, )), - ) if dst_size == src_size => crate::Expression::As { - expr: component, - kind: dst_scalar.kind, - convert: Some(dst_scalar.width), - }, + ) if dst_size == src_size => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } // Vector conversion (vector -> vector) - partial ( @@ -247,11 +252,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { scalar: dst_scalar, }, )), - ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As { - expr: component, - kind: crate::ScalarKind::Float, - convert: Some(dst_scalar.width), - }, + ) if dst_columns == src_columns && dst_rows == src_rows => { + expr = crate::Expression::As { + expr: component, + kind: crate::ScalarKind::Float, + convert: Some(dst_scalar.width), + }; + } // Matrix conversion (matrix -> matrix) - partial ( @@ -284,10 +291,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .. }, Constructor::PartialVector { size }, - ) => crate::Expression::Splat { - size, - value: component, - }, + ) => { + expr = crate::Expression::Splat { + size, + value: component, + }; + } // Vector constructor (splat) ( @@ -303,10 +312,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { scalar: dst_scalar, }, )), - ) if dst_scalar == src_scalar => crate::Expression::Splat { - size, - value: component, - }, + ) if dst_scalar == src_scalar => { + expr = crate::Expression::Splat { + size, + value: component, + }; + } // Vector constructor (by elements) ( @@ -329,7 +340,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) => { let inner = crate::TypeInner::Vector { size, scalar }; let ty = ctx.ensure_type_exists(inner); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by elements) @@ -377,7 +388,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows, scalar, }); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by columns) @@ -409,7 +420,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows, scalar, }); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; } // Array constructor - infer type @@ -430,7 +441,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; let ty = ctx.ensure_type_exists(inner); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; } // Array or Struct constructor @@ -442,7 +453,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )), ) => { let components = components.into_components_vec(); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; } // ERRORS From bda1c9efb175f11875ce265b15ec87949b75718f Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 22 Nov 2023 13:45:35 -0800 Subject: [PATCH 4/9] [wgsl-in]: Remove `Components::Many::first_component_ty_inner`. Delete the `first_component_ty_inner` field from `front::wgsl::lower::construction::Components::Many`. With the introduction of abstract types, it will no longer be possible to infer the type of the vector being constructed by looking at the type of its first constructor argument alone: automatic conversion rules might need to be applied to that argument. --- naga/src/front/wgsl/lower/construction.rs | 168 ++++++++++------------ naga/src/proc/mod.rs | 4 + 2 files changed, 77 insertions(+), 95 deletions(-) diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index c2791fe92b..378ee66474 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -5,6 +5,7 @@ use crate::{Handle, Span}; use crate::front::wgsl::error::Error; use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; +use crate::front::wgsl::Scalar; /// A cooked form of `ast::ConstructorType` that uses Naga types whenever /// possible. @@ -80,7 +81,6 @@ enum Components<'a> { Many { components: Vec>, spans: Vec, - first_component_ty_inner: &'a crate::TypeInner, }, } @@ -131,30 +131,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_inner, } } - [component, ref rest @ ..] => { - let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; - - let components = std::iter::once(Ok(component)) - .chain( - rest.iter() - .map(|&component| self.expression(component, ctx)), - ) + ref ast_components @ [_, _, ..] => { + let components = ast_components + .iter() + .map(|&expr| self.expression(expr, ctx)) .collect::>()?; - let spans = std::iter::once(span) - .chain( - rest.iter() - .map(|&component| ctx.ast_expressions.get_span(component)), - ) + let spans = ast_components + .iter() + .map(|&expr| ctx.ast_expressions.get_span(expr)) .collect(); - let first_component_ty_inner = super::resolve_inner!(ctx, component); - - Components::Many { - components, - spans, - first_component_ty_inner, - } + Components::Many { components, spans } } }; @@ -255,7 +242,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ) if dst_columns == src_columns && dst_rows == src_rows => { expr = crate::Expression::As { expr: component, - kind: crate::ScalarKind::Float, + kind: dst_scalar.kind, convert: Some(dst_scalar.width), }; } @@ -319,56 +306,39 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; } - // Vector constructor (by elements) + // Vector constructor (by elements), partial + (Components::Many { components, spans }, Constructor::PartialVector { size }) => { + let scalar = + component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + let inner = scalar.to_inner_vector(size); + let ty = ctx.ensure_type_exists(inner); + expr = crate::Expression::Compose { ty, components }; + } + + // Vector constructor (by elements), full type given ( - Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar(scalar) | &crate::TypeInner::Vector { scalar, .. }, - .. - }, - Constructor::PartialVector { size }, - ) - | ( - Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }, - .. - }, - Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + Components::Many { components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { .. })), ) => { - let inner = crate::TypeInner::Vector { size, scalar }; - let ty = ctx.ensure_type_exists(inner); expr = crate::Expression::Compose { ty, components }; } // Matrix constructor (by elements) ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar(scalar), - .. - }, + Components::Many { components, spans }, Constructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar { .. }, - .. - }, - Constructor::Type(( - _, - &crate::TypeInner::Matrix { - columns, - rows, - scalar, - }, - )), - ) => { - let vec_ty = - ctx.ensure_type_exists(crate::TypeInner::Vector { scalar, size: rows }); + Components::Many { components, spans }, + Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), + ) if components.len() == columns as usize * rows as usize => { + let scalar = + component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); let components = components .chunks(rows as usize) @@ -393,28 +363,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Matrix constructor (by columns) ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { scalar, .. }, - .. - }, + Components::Many { components, spans }, Constructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { .. }, - .. - }, - Constructor::Type(( - _, - &crate::TypeInner::Matrix { - columns, - rows, - scalar, - }, - )), + Components::Many { components, spans }, + Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), ) => { + let scalar = + component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, rows, @@ -477,22 +436,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Err(Error::UnexpectedComponents(span)); } - // Parameters are of the wrong type for vector or matrix constructor - ( - Components::Many { spans, .. }, - Constructor::Type(( - _, - &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. }, - )) - | Constructor::PartialVector { .. } - | Constructor::PartialMatrix { .. }, - ) => { - return Err(Error::InvalidConstructorComponentType(spans[0], 0)); - } - // Other types can't be constructed _ => return Err(Error::TypeNotConstructible(ty_span)), - }; + } let expr = ctx.append_expression(expr, span)?; Ok(expr) @@ -557,3 +503,35 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(handle) } } + +/// Compute a vector or matrix's scalar type from those of its +/// constructor arguments. +/// +/// Given `components`, the arguments given to a vector or matrix +/// constructor, return the scalar type of the vector or matrix's +/// elements. +/// +/// The `components` slice must not be empty. All elements' types must +/// have been resolved. +/// +/// If `components` are definitely not acceptable as arguments to such +/// constructors, return `Err(i)`, where `i` is the index in +/// `components` of some problematic argument. +/// +/// This function doesn't fully type-check the arguments, so it may +/// return `Ok` even when the Naga validator will reject the resulting +/// construction expression later. +fn component_scalar_from_constructor_args( + components: &[Handle], + ctx: &mut ExpressionContext<'_, '_, '_>, +) -> Result { + // Since we don't yet implement abstract types, we can settle for + // just inspecting the first element. + let first = components[0]; + ctx.grow_types(first).map_err(|_| 0_usize)?; + let inner = ctx.typifier()[first].inner_with(&ctx.module.types); + match inner.scalar() { + Some(scalar) => Ok(scalar), + None => Err(0), + } +} diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 40a342f6ce..e375bb1af3 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -237,6 +237,10 @@ impl crate::Literal { pub const POINTER_SPAN: u32 = 4; impl super::TypeInner { + /// Return the scalar type of `self`. + /// + /// If `inner` is a scalar, vector, or matrix type, return + /// its scalar type. Otherwise, return `None`. pub const fn scalar(&self) -> Option { use crate::TypeInner as Ti; match *self { From d1519b5c05f99dff6944036ea56a22a4e9e517f9 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 22 Nov 2023 13:50:47 -0800 Subject: [PATCH 5/9] [naga wgsl-in] Implement abstract types for consts, constructors. --- CHANGELOG.md | 40 +- naga/Cargo.toml | 2 +- naga/src/front/wgsl/error.rs | 20 + naga/src/front/wgsl/lower/construction.rs | 213 +++-- naga/src/front/wgsl/lower/conversion.rs | 375 ++++++++ naga/src/front/wgsl/lower/mod.rs | 71 +- naga/src/front/wgsl/parse/lexer.rs | 45 +- naga/src/front/wgsl/parse/number.rs | 28 +- naga/src/front/wgsl/tests.rs | 4 +- naga/src/proc/constant_evaluator.rs | 109 ++- naga/src/valid/expression.rs | 50 +- naga/src/valid/function.rs | 2 +- naga/tests/in/abstract-types.wgsl | 77 ++ naga/tests/out/ir/access.compact.ron | 16 + naga/tests/out/ir/access.ron | 1028 ++++++++++----------- naga/tests/out/ir/collatz.ron | 6 +- naga/tests/out/msl/abstract-types.msl | 59 ++ naga/tests/out/spv/abstract-types.spvasm | 46 + naga/tests/out/spv/ray-query.spvasm | 8 +- naga/tests/out/wgsl/abstract-types.wgsl | 52 ++ naga/tests/snapshots.rs | 4 + naga/tests/wgsl_errors.rs | 25 +- 22 files changed, 1541 insertions(+), 739 deletions(-) create mode 100644 naga/src/front/wgsl/lower/conversion.rs create mode 100644 naga/tests/in/abstract-types.wgsl create mode 100644 naga/tests/out/msl/abstract-types.msl create mode 100644 naga/tests/out/spv/abstract-types.spvasm create mode 100644 naga/tests/out/wgsl/abstract-types.wgsl diff --git a/CHANGELOG.md b/CHANGELOG.md index f12dd95f7f..00450c43a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,8 +102,44 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S - Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673). - Add more metal keywords. By @fornwall in [#4707](https://github.com/gfx-rs/wgpu/pull/4707). -- Implement WGSL abstract types (by @jimblandy): - - Add a new `naga::Literal` variant, `I64`, for signed 64-bit literals. [#4711](https://github.com/gfx-rs/wgpu/pull/4711) +- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743)). + + Abstract types make numeric literals easier to use, by + automatically converting literals and other constant expressions + from abstract numeric types to concrete types when safe and + necessary. For example, to build a vector of floating-point + numbers, Naga previously made you write: + + vec3(1.0, 2.0, 3.0) + + With this change, you can now simply write: + + vec3(1, 2, 3) + + Even though the literals are abstract integers, Naga recognizes + that it is safe and necessary to convert them to `f32` values in + order to build the vector. You can also use abstract values as + initializers for global constants, like this: + + const unit_x: vec2 = vec2(1, 0); + + The literals `1` and `0` are abstract integers, and the expression + `vec2(1, 0)` is an abstract vector. However, Naga recognizes that + it can convert that to the concrete type `vec2` to satisfy + the given type of `unit_x`. + + The WGSL specification permits abstract integers and + floating-point values in almost all contexts, but Naga's support + for this is still incomplete. Many WGSL operators and builtin + functions are specified to produce abstract results when applied + to abstract inputs, but for now Naga simply concretizes them all + before applying the operation. We will expand Naga's abstract type + support in subsequent pull requests. + + As part of this work, the public types `naga::ScalarKind` and + `naga::Literal` now have new variants, `AbstractInt` and `AbstractFloat`. + +- Add a new `naga::Literal` variant, `I64`, for signed 64-bit literals. [#4711](https://github.com/gfx-rs/wgpu/pull/4711) - Emit and init `struct` member padding always. By @ErichDonGubler in [#4701](https://github.com/gfx-rs/wgpu/pull/4701). diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 20b5bd5d25..0cacc493ac 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -31,7 +31,7 @@ deserialize = ["serde", "bitflags/serde", "indexmap/serde"] arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"] spv-in = ["petgraph", "spirv"] spv-out = ["spirv"] -wgsl-in = ["hexf-parse", "unicode-xid"] +wgsl-in = ["hexf-parse", "unicode-xid", "compact"] wgsl-out = [] hlsl-out = [] compact = [] diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index f5acbe2d65..dc10124680 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -251,6 +251,12 @@ pub enum Error<'a> { ExpectedPositiveArrayLength(Span), MissingWorkgroupSize(Span), ConstantEvaluatorError(ConstantEvaluatorError, Span), + AutoConversion { + dest_span: Span, + dest_type: String, + source_span: Span, + source_type: String, + }, } impl<'a> Error<'a> { @@ -712,6 +718,20 @@ impl<'a> Error<'a> { )], notes: vec![], }, + Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError { + message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"), + labels: vec![ + ( + dest_span, + format!("a value of type {dest_type} is required here").into(), + ), + ( + source_span, + format!("this expression has type {source_type}").into(), + ) + ], + notes: vec![], + } } } } diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 378ee66474..5e58c93892 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -116,13 +116,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { components: &[Handle>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { + use crate::proc::TypeResolution as Tr; + let constructor_h = self.constructor(constructor, ctx)?; let components = match *components { [] => Components::None, [component] => { let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; + let component = self.expression_for_abstract(component, ctx)?; let ty_inner = super::resolve_inner!(ctx, component); Components::One { @@ -134,13 +136,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ref ast_components @ [_, _, ..] => { let components = ast_components .iter() - .map(|&expr| self.expression(expr, ctx)) + .map(|&expr| self.expression_for_abstract(expr, ctx)) .collect::>()?; let spans = ast_components .iter() .map(|&expr| ctx.ast_expressions.get_span(expr)) .collect(); + for &component in &components { + ctx.grow_types(component)?; + } + Components::Many { components, spans } } }; @@ -288,18 +294,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Vector constructor (splat) ( Components::One { - component, - ty_inner: &crate::TypeInner::Scalar(src_scalar), + mut component, + ty_inner: &crate::TypeInner::Scalar(_), .. }, - Constructor::Type(( - _, - &crate::TypeInner::Vector { - size, - scalar: dst_scalar, - }, - )), - ) if dst_scalar == src_scalar => { + Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + ) => { + ctx.convert_slice_to_common_scalar(std::slice::from_mut(&mut component), scalar)?; expr = crate::Expression::Splat { size, value: component, @@ -307,37 +308,82 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } // Vector constructor (by elements), partial - (Components::Many { components, spans }, Constructor::PartialVector { size }) => { - let scalar = - component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialVector { size }, + ) => { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; - let inner = scalar.to_inner_vector(size); + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + let inner = consensus_scalar.to_inner_vector(size); let ty = ctx.ensure_type_exists(inner); expr = crate::Expression::Compose { ty, components }; } // Vector constructor (by elements), full type given ( - Components::Many { components, .. }, - Constructor::Type((ty, &crate::TypeInner::Vector { .. })), + Components::Many { mut components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })), ) => { + ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?; expr = crate::Expression::Compose { ty, components }; } - // Matrix constructor (by elements) + // Matrix constructor (by elements), partial ( - Components::Many { components, spans }, + Components::Many { + mut components, + spans, + }, Constructor::PartialMatrix { columns, rows }, - ) - | ( - Components::Many { components, spans }, - Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), ) if components.len() == columns as usize * rows as usize => { - let scalar = - component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns, + rows, + scalar, + }, + )), + ) if components.len() == columns as usize * rows as usize => { + let element = Tr::Value(crate::TypeInner::Scalar(scalar)); + ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?; let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); let components = components @@ -363,28 +409,55 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Matrix constructor (by columns) ( - Components::Many { components, spans }, + Components::Many { + mut components, + spans, + }, Constructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { components, spans }, + Components::Many { + mut components, + spans, + }, Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), ) => { - let scalar = - component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, rows, - scalar, + scalar: consensus_scalar, }); expr = crate::Expression::Compose { ty, components }; } // Array constructor - infer type (components, Constructor::PartialArray) => { - let components = components.into_components_vec(); + let mut components = components.into_components_vec(); + if let Ok(consensus_scalar) = automatic_conversion_consensus(&components, ctx) { + // Note that this will *not* necessarily convert all the + // components to the same type! The `automatic_conversion_consensus` + // function only considers the parameters' leaf scalar + // types; the parameters themselves could be any mix of + // vectors, matrices, and scalars. + // + // But *if* it is possible for this array construction + // expression to be well-typed at all, then all the + // parameters must have the same type constructors (vec, + // matrix, scalar) applied to their leaf scalars, so + // reconciling their scalars is always the right thing to + // do. And if this array construction is not well-typed, + // these conversions will not make it so, and we can let + // validation catch the error. + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + } else { + // There's no consensus scalar. Emit the `Compose` + // expression anyway, and let validation catch the problem. + } let base = ctx.register_type(components[0])?; @@ -403,15 +476,30 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { expr = crate::Expression::Compose { ty, components }; } - // Array or Struct constructor + // Array constructor, explicit type + (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => { + let mut components = components.into_components_vec(); + ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Struct constructor ( components, - Constructor::Type(( - ty, - &crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. }, - )), + Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })), ) => { - let components = components.into_components_vec(); + let mut components = components.into_components_vec(); + let struct_ty_span = ctx.module.types.get_span(ty); + + // Make a vector of the members' type handles in advance, to + // avoid borrowing `members` from `ctx` while we generate + // new code. + let members: Vec> = members.iter().map(|m| m.ty).collect(); + + for (component, &ty) in components.iter_mut().zip(&members) { + *component = + ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?; + } expr = crate::Expression::Compose { ty, components }; } @@ -504,12 +592,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } -/// Compute a vector or matrix's scalar type from those of its -/// constructor arguments. +/// Find the consensus scalar of `components` under WGSL's automatic +/// conversions. /// -/// Given `components`, the arguments given to a vector or matrix -/// constructor, return the scalar type of the vector or matrix's -/// elements. +/// If `components` can all be converted to any common scalar via +/// WGSL's automatic conversions, return the best such scalar. /// /// The `components` slice must not be empty. All elements' types must /// have been resolved. @@ -518,20 +605,36 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { /// constructors, return `Err(i)`, where `i` is the index in /// `components` of some problematic argument. /// -/// This function doesn't fully type-check the arguments, so it may -/// return `Ok` even when the Naga validator will reject the resulting +/// This function doesn't fully type-check the arguments - it only +/// considers their leaf scalar types. This means it may return `Ok` +/// even when the Naga validator will reject the resulting /// construction expression later. -fn component_scalar_from_constructor_args( +fn automatic_conversion_consensus( components: &[Handle], - ctx: &mut ExpressionContext<'_, '_, '_>, + ctx: &ExpressionContext<'_, '_, '_>, ) -> Result { - // Since we don't yet implement abstract types, we can settle for - // just inspecting the first element. - let first = components[0]; - ctx.grow_types(first).map_err(|_| 0_usize)?; - let inner = ctx.typifier()[first].inner_with(&ctx.module.types); - match inner.scalar() { - Some(scalar) => Ok(scalar), - None => Err(0), + let types = &ctx.module.types; + let mut inners = components + .iter() + .map(|&c| ctx.typifier()[c].inner_with(types)); + log::debug!( + "wgsl automatic_conversion_consensus: {:?}", + inners + .clone() + .map(|inner| inner.to_wgsl(&ctx.module.to_ctx())) + .collect::>() + ); + let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?; + for (inner, i) in inners.zip(1..) { + let scalar = inner.scalar().ok_or(i)?; + match best.automatic_conversion_combine(scalar) { + Some(new_best) => { + best = new_best; + } + None => return Err(i), + } } + + log::debug!(" consensus: {:?}", best.to_wgsl()); + Ok(best) } diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs new file mode 100644 index 0000000000..893c425433 --- /dev/null +++ b/naga/src/front/wgsl/lower/conversion.rs @@ -0,0 +1,375 @@ +//! WGSL's automatic conversions for abstract types. + +use crate::{Handle, Span}; + +impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> { + /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_ty`, return an + /// [`AutoConversion`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversion`]: super::Error::AutoConversion + pub fn try_automatic_conversions( + &mut self, + expr: Handle, + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + // Keep the TypeResolution so we can get type names for + // structs in error messages. + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + let goal_inner = goal_ty.inner_with(types); + + // If `expr` already has the requested type, we're done. + if expr_inner.equivalent(goal_inner, types) { + return Ok(expr); + } + + let (_expr_scalar, goal_scalar) = + match expr_inner.automatically_converts_to(goal_inner, types) { + Some(scalars) => scalars, + None => { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + let dest_type = goal_ty.to_wgsl(gctx); + + return Err(super::Error::AutoConversion { + dest_span: goal_span, + dest_type, + source_span: expr_span, + source_type, + }); + } + }; + + let converted = if let crate::TypeInner::Array { .. } = *goal_inner { + let span = self.get_expression_span(expr); + self.as_const_evaluator() + .cast_array(expr, goal_scalar, span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, span))? + } else { + let cast = crate::Expression::As { + expr, + kind: goal_scalar.kind, + convert: Some(goal_scalar.width), + }; + self.append_expression(cast, expr_span)? + }; + + Ok(converted) + } + + /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions. + pub fn try_automatic_conversions_slice( + &mut self, + exprs: &mut [Handle], + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?; + } + + Ok(()) + } + + /// Apply WGSL's automatic conversions to a vector constructor's arguments. + /// + /// When calling a vector constructor like `vec3(...)`, the parameters + /// can be a mix of scalars and vectors, with the latter being spread out to + /// contribute each of their components as a component of the new value. + /// When the element type is explicit, as with `` in the example above, + /// WGSL's automatic conversions should convert abstract scalar and vector + /// parameters to the constructor's required scalar type. + pub fn try_automatic_conversions_for_vector( + &mut self, + exprs: &mut [Handle], + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + use crate::proc::TypeResolution as Tr; + use crate::TypeInner as Ti; + let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar)); + + for (i, expr) in exprs.iter_mut().enumerate() { + // Keep the TypeResolution so we can get full type names + // in error messages. + let expr_resolution = super::resolve!(self, *expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + match *expr_inner { + Ti::Scalar(_) => { + *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?; + } + Ti::Vector { size, scalar: _ } => { + let goal_vector_res = Tr::Value(Ti::Vector { + size, + scalar: goal_scalar, + }); + *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?; + } + _ => { + let span = self.get_expression_span(*expr); + return Err(super::Error::InvalidConstructorComponentType( + span, i as i32, + )); + } + } + } + + Ok(()) + } + + /// Convert all expressions in `exprs` to a common scalar type. + /// + /// Note that the caller is responsible for making sure these + /// conversions are actually justified. This function simply + /// generates `As` expressions, regardless of whether they are + /// permitted WGSL automatic conversions. Callers intending to + /// implement automatic conversions need to determine for + /// themselves whether the casts we we generate are justified, + /// perhaps by calling `TypeInner::automatically_converts_to` or + /// `Scalar::automatic_conversion_combine`. + pub fn convert_slice_to_common_scalar( + &mut self, + exprs: &mut [Handle], + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + let inner = super::resolve_inner!(self, *expr); + // Do nothing if `inner` doesn't even have leaf scalars; + // it's a type error that validation will catch. + if inner.scalar() != Some(goal) { + let cast = crate::Expression::As { + expr: *expr, + kind: goal.kind, + convert: Some(goal.width), + }; + let expr_span = self.get_expression_span(*expr); + *expr = self.append_expression(cast, expr_span)?; + } + } + + Ok(()) + } + + /// Return an expression for the concretized value of `expr`. + /// + /// If `expr` is already concrete, return it unchanged. + pub fn concretize( + &mut self, + mut expr: Handle, + ) -> Result, super::Error<'source>> { + let inner = super::resolve_inner!(self, expr); + if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) { + let concretized = scalar.concretize(); + if concretized != scalar { + let span = self.get_expression_span(expr); + expr = self + .as_const_evaluator() + .cast_array(expr, concretized, span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, span))?; + } + } + + Ok(expr) + } +} + +impl crate::TypeInner { + /// Determine whether `self` automatically converts to `goal`. + /// + /// If WGSL's automatic conversions (excluding the Load Rule) will + /// convert `self` to `goal`, then return a pair `(from, to)`, + /// where `from` and `to` are the scalar types of the leaf values + /// of `self` and `goal`. + /// + /// This function assumes that `self` and `goal` are different + /// types. Callers should first check whether any conversion is + /// needed at all. + /// + /// If the automatic conversions cannot convert `self` to `goal`, + /// return `None`. + fn automatically_converts_to( + &self, + goal: &Self, + types: &crate::UniqueArena, + ) -> Option<(crate::Scalar, crate::Scalar)> { + use crate::ScalarKind as Sk; + use crate::TypeInner as Ti; + + // Automatic conversions only change the scalar type of a value's leaves + // (e.g., `vec4` to `vec4`), never the type + // constructors applied to those scalar types (e.g., never scalar to + // `vec4`, or `vec2` to `vec3`). So first we check that the type + // constructors match, extracting the leaf scalar types in the process. + let expr_scalar; + let goal_scalar; + match (self, goal) { + (&Ti::Scalar(expr), &Ti::Scalar(goal)) => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Vector { + size: expr_size, + scalar: expr, + }, + &Ti::Vector { + size: goal_size, + scalar: goal, + }, + ) if expr_size == goal_size => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Matrix { + rows: expr_rows, + columns: expr_columns, + scalar: expr, + }, + &Ti::Matrix { + rows: goal_rows, + columns: goal_columns, + scalar: goal, + }, + ) if expr_rows == goal_rows && expr_columns == goal_columns => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Array { + base: expr_base, + size: expr_size, + stride: _, + }, + &Ti::Array { + base: goal_base, + size: goal_size, + stride: _, + }, + ) if expr_size == goal_size => { + return types[expr_base] + .inner + .automatically_converts_to(&types[goal_base].inner, types); + } + _ => return None, + } + + match (expr_scalar.kind, goal_scalar.kind) { + (Sk::AbstractFloat, Sk::Float) => {} + (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {} + _ => return None, + } + + log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}"); + Some((expr_scalar, goal_scalar)) + } + + fn automatically_convertible_scalar( + &self, + types: &crate::UniqueArena, + ) -> Option { + use crate::TypeInner as Ti; + match *self { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { + Some(scalar) + } + Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), + Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } + | Ti::Struct { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => None, + } + } +} + +impl crate::Scalar { + /// Find the common type of `self` and `other` under WGSL's + /// automatic conversions. + /// + /// If there are any scalars to which WGSL's automatic conversions + /// will convert both `self` and `other`, return the best such + /// scalar. Otherwise, return `None`. + pub const fn automatic_conversion_combine(self, other: Self) -> Option { + use crate::ScalarKind as Sk; + + match (self.kind, other.kind) { + // When the kinds match... + (Sk::AbstractFloat, Sk::AbstractFloat) + | (Sk::AbstractInt, Sk::AbstractInt) + | (Sk::Sint, Sk::Sint) + | (Sk::Uint, Sk::Uint) + | (Sk::Float, Sk::Float) + | (Sk::Bool, Sk::Bool) => { + if self.width == other.width { + // ... either no conversion is necessary ... + Some(self) + } else { + // ... or no conversion is possible. + // We never convert concrete to concrete, and + // abstract types should have only one size. + None + } + } + + // AbstractInt converts to AbstractFloat. + (Sk::AbstractFloat | Sk::AbstractInt, Sk::AbstractFloat | Sk::AbstractInt) => { + Some(Self { + kind: Sk::AbstractFloat, + width: crate::ABSTRACT_WIDTH, + }) + } + + // AbstractFloat converts to Float. + (Sk::AbstractFloat, Sk::Float) => Some(Self::float(other.width)), + (Sk::Float, Sk::AbstractFloat) => Some(Self::float(self.width)), + + // AbstractInt converts to concrete integer or float. + (Sk::AbstractInt, kind @ (Sk::Uint | Sk::Sint | Sk::Float)) => Some(Self { + kind, + width: other.width, + }), + (kind @ (Sk::Uint | Sk::Sint | Sk::Float), Sk::AbstractInt) => Some(Self { + kind, + width: self.width, + }), + + // AbstractFloat can't be reconciled with concrete integer types. + (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { + None + } + + // Nothing can be reconciled with `bool`. + (Sk::Bool, _) | (_, Sk::Bool) => None, + + // Different concrete types cannot be reconciled. + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None, + } + } + + const fn concretize(self) -> Self { + use crate::ScalarKind as Sk; + match self.kind { + Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self, + Sk::AbstractInt => Self::I32, + Sk::AbstractFloat => Self::F32, + } + } +} diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index a727d6379b..b050ffc343 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -11,6 +11,7 @@ use crate::proc::{ use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; mod construction; +mod conversion; /// Resolves the inner type of a given expression. /// @@ -66,6 +67,7 @@ macro_rules! resolve { &$ctx.typifier()[$expr] }}; } +pub(super) use resolve; /// State for constructing a `crate::Module`. pub struct GlobalContext<'source, 'temp, 'out> { @@ -903,29 +905,39 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::GlobalDeclKind::Const(ref c) => { let mut ectx = ctx.as_const(); - let init = self.expression(c.init, &mut ectx)?; - let inferred_type = ectx.register_type(init)?; - - let explicit_ty = - c.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) - .transpose()?; - - if let Some(explicit) = explicit_ty { - if explicit != inferred_type { - let gctx = ctx.module.to_ctx(); - return Err(Error::InitializationTypeMismatch { - name: c.name.span, - expected: explicit.to_wgsl(&gctx), - got: inferred_type.to_wgsl(&gctx), - }); - } + let mut init = self.expression_for_abstract(c.init, &mut ectx)?; + + let ty; + if let Some(explicit_ty) = c.ty { + let explicit_ty = + self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + init = ectx + .try_automatic_conversions(init, &explicit_ty_res, c.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: c.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + } else { + init = ectx.concretize(init)?; + ty = ectx.register_type(init)?; } let handle = ctx.module.constants.append( crate::Constant { name: Some(c.name.name.to_string()), r#override: crate::Override::None, - ty: inferred_type, + ty, init, }, span, @@ -951,6 +963,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } + // Constant evaluation may leave abstract-typed literals and + // compositions in expression arenas, so we need to compact the module + // to remove unused expressions and types. + crate::compact::compact(&mut module); + Ok(module) } @@ -1449,10 +1466,25 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } /// Lower `expr` and apply the Load Rule if possible. + /// + /// For the time being, this concretizes abstract values, to support + /// consumers that haven't been adapted to consume them yet. Consumers + /// prepared for abstract values can call [`expression_for_abstract`]. + /// + /// [`expression_for_abstract`]: Lowerer::expression_for_abstract fn expression( &mut self, expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let expr = self.expression_for_abstract(expr, ctx)?; + ctx.concretize(expr) + } + + fn expression_for_abstract( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let expr = self.expression_for_reference(expr, ctx)?; ctx.apply_load_rule(expr) @@ -1473,8 +1505,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), - ast::Literal::Number(_) => { - unreachable!("got abstract numeric type when not expected"); + ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i), + ast::Literal::Number(Number::AbstractFloat(f)) => { + crate::Literal::AbstractFloat(f) } ast::Literal::Bool(b) => crate::Literal::Bool(b), }; diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index 1cc5a0c438..9fa9416f11 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -465,13 +465,13 @@ fn test_numbers() { sub_test( "0x123 0X123u 1u 123 0 0i 0x3f", &[ - Token::Number(Ok(Number::I32(291))), + Token::Number(Ok(Number::AbstractInt(291))), Token::Number(Ok(Number::U32(291))), Token::Number(Ok(Number::U32(1))), - Token::Number(Ok(Number::I32(123))), + Token::Number(Ok(Number::AbstractInt(123))), + Token::Number(Ok(Number::AbstractInt(0))), Token::Number(Ok(Number::I32(0))), - Token::Number(Ok(Number::I32(0))), - Token::Number(Ok(Number::I32(63))), + Token::Number(Ok(Number::AbstractInt(63))), ], ); // decimal floating point @@ -479,17 +479,17 @@ fn test_numbers() { "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h", &[ Token::Number(Ok(Number::F32(0.))), - Token::Number(Ok(Number::F32(1.))), - Token::Number(Ok(Number::F32(0.01))), - Token::Number(Ok(Number::F32(12.34))), + Token::Number(Ok(Number::AbstractFloat(1.))), + Token::Number(Ok(Number::AbstractFloat(0.01))), + Token::Number(Ok(Number::AbstractFloat(12.34))), Token::Number(Ok(Number::F32(0.))), Token::Number(Err(NumberError::UnimplementedF16)), - Token::Number(Ok(Number::F32(0.001))), - Token::Number(Ok(Number::F32(43.75))), + Token::Number(Ok(Number::AbstractFloat(0.001))), + Token::Number(Ok(Number::AbstractFloat(43.75))), Token::Number(Ok(Number::F32(16.))), - Token::Number(Ok(Number::F32(0.1875))), + Token::Number(Ok(Number::AbstractFloat(0.1875))), Token::Number(Err(NumberError::UnimplementedF16)), - Token::Number(Ok(Number::F32(0.12109375))), + Token::Number(Ok(Number::AbstractFloat(0.12109375))), Token::Number(Err(NumberError::UnimplementedF16)), ], ); @@ -635,7 +635,7 @@ fn double_floats() { Token::Number(Ok(Number::F64(0.0625))), Token::Number(Ok(Number::F64(0.0625))), Token::Number(Ok(Number::F64(10.0))), - Token::Number(Ok(Number::I32(10))), + Token::Number(Ok(Number::AbstractInt(10))), Token::Word("l"), ], ) @@ -646,13 +646,16 @@ fn test_tokens() { sub_test("id123_OK", &[Token::Word("id123_OK")]); sub_test( "92No", - &[Token::Number(Ok(Number::I32(92))), Token::Word("No")], + &[ + Token::Number(Ok(Number::AbstractInt(92))), + Token::Word("No"), + ], ); sub_test( "2u3o", &[ Token::Number(Ok(Number::U32(2))), - Token::Number(Ok(Number::I32(3))), + Token::Number(Ok(Number::AbstractInt(3))), Token::Word("o"), ], ); @@ -660,7 +663,7 @@ fn test_tokens() { "2.4f44po", &[ Token::Number(Ok(Number::F32(2.4))), - Token::Number(Ok(Number::I32(44))), + Token::Number(Ok(Number::AbstractInt(44))), Token::Word("po"), ], ); @@ -699,13 +702,13 @@ fn test_tokens() { &[ // The 'f' suffixes are taken as a hex digit: // the fractional part is 0x2f / 256. - Token::Number(Ok(Number::F32(1.0 + 0x2f as f32 / 256.0))), - Token::Number(Ok(Number::F32(1.0 + 0x2f as f32 / 256.0))), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("h"), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("H"), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("lf"), ], ) @@ -719,7 +722,7 @@ fn test_variable_decl() { Token::Attribute, Token::Word("group"), Token::Paren('('), - Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::AbstractInt(0))), Token::Paren(')'), Token::Word("var"), Token::Paren('<'), diff --git a/naga/src/front/wgsl/parse/number.rs b/naga/src/front/wgsl/parse/number.rs index 3178736990..fde5e3cee6 100644 --- a/naga/src/front/wgsl/parse/number.rs +++ b/naga/src/front/wgsl/parse/number.rs @@ -20,37 +20,11 @@ pub enum Number { F64(f64), } -impl Number { - /// Convert abstract numbers to a plausible concrete counterpart. - /// - /// Return concrete numbers unchanged. If the conversion would be - /// lossy, return an error. - fn abstract_to_concrete(self) -> Result { - match self { - Number::AbstractInt(num) => i32::try_from(num) - .map(Number::I32) - .map_err(|_| NumberError::NotRepresentable), - Number::AbstractFloat(num) => { - let num = num as f32; - if num.is_finite() { - Ok(Number::F32(num)) - } else { - Err(NumberError::NotRepresentable) - } - } - num => Ok(num), - } - } -} - // TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) { let (result, rest) = parse(input); - ( - Token::Number(result.and_then(Number::abstract_to_concrete)), - rest, - ) + (Token::Number(result), rest) } enum Kind { diff --git a/naga/src/front/wgsl/tests.rs b/naga/src/front/wgsl/tests.rs index 9e3ba2fab6..eb2f8a2eb3 100644 --- a/naga/src/front/wgsl/tests.rs +++ b/naga/src/front/wgsl/tests.rs @@ -76,7 +76,7 @@ fn parse_type_cast() { assert!(parse_str( " fn main() { - let x: vec2 = vec2(0); + let x: vec2 = vec2(0i, 0i); } ", ) @@ -313,7 +313,7 @@ fn parse_texture_load() { " var t: texture_3d; fn foo() { - let r: vec4 = textureLoad(t, vec3(0.0, 1.0, 2.0), 1); + let r: vec4 = textureLoad(t, vec3(0u, 1u, 2u), 1); } ", ) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 6901ada20c..51e447847c 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -277,6 +277,17 @@ impl<'a> ConstantEvaluator<'a> { } } + pub fn to_ctx(&self) -> crate::proc::GlobalCtx { + crate::proc::GlobalCtx { + types: self.types, + constants: self.constants, + const_expressions: match self.function_local_data { + Some(ref data) => data.const_expressions, + None => self.expressions, + }, + } + } + fn check(&self, expr: Handle) -> Result<(), ConstantEvaluatorError> { if let Some(ref function_local_data) = self.function_local_data { if !function_local_data.expression_constness.is_const(expr) { @@ -1035,7 +1046,21 @@ impl<'a> ConstantEvaluator<'a> { return Err(ConstantEvaluatorError::InvalidCastArg) } }), - _ => return Err(ConstantEvaluatorError::InvalidCastArg), + Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal { + Literal::AbstractInt(v) => { + // Overflow is forbidden, but inexact conversions + // are fine. The range of f64 is far larger than + // that of i64, so we don't have to check anything + // here. + v as f64 + } + Literal::AbstractFloat(v) => v, + _ => return Err(ConstantEvaluatorError::InvalidCastArg), + }), + _ => { + log::debug!("Constant evaluator refused to convert value to {target:?}"); + return Err(ConstantEvaluatorError::InvalidCastArg); + } }; Expression::Literal(literal) } @@ -1085,6 +1110,66 @@ impl<'a> ConstantEvaluator<'a> { self.register_evaluated_expr(expr, span) } + /// Convert the scalar leaves of `expr` to `target`, handling arrays. + /// + /// `expr` must be a `Compose` expression whose type is a scalar, vector, + /// matrix, or nested arrays of such. + /// + /// This is basically the same as the [`cast`] method, except that that + /// should only handle Naga [`As`] expressions, which cannot convert arrays. + /// + /// Treat `span` as the location of the resulting expression. + /// + /// [`cast`]: ConstantEvaluator::cast + /// [`As`]: crate::Expression::As + pub fn cast_array( + &mut self, + expr: Handle, + target: crate::Scalar, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let Expression::Compose { ty, ref components } = self.expressions[expr] else { + return self.cast(expr, target, span); + }; + + let crate::TypeInner::Array { base: _, size, stride: _ } = self.types[ty].inner else { + return self.cast(expr, target, span); + }; + + let mut components = components.clone(); + for component in &mut components { + *component = self.cast_array(*component, target, span)?; + } + + let first = components + .first() + .ok_or(ConstantEvaluatorError::InvalidCastArg)?; + let new_base = match self.resolve_type(*first)? { + crate::proc::TypeResolution::Handle(ty) => ty, + crate::proc::TypeResolution::Value(inner) => { + self.types.insert(Type { name: None, inner }, span) + } + }; + let new_base_stride = self.types[new_base].inner.size(self.to_ctx()); + let new_array_ty = self.types.insert( + Type { + name: None, + inner: TypeInner::Array { + base: new_base, + size, + stride: new_base_stride, + }, + }, + span, + ); + + let compose = Expression::Compose { + ty: new_array_ty, + components, + }; + self.register_evaluated_expr(compose, span) + } + fn unary_op( &mut self, op: UnaryOperator, @@ -1339,6 +1424,28 @@ impl<'a> ConstantEvaluator<'a> { Ok(self.expressions.append(expr, span)) } } + + fn resolve_type( + &self, + expr: Handle, + ) -> Result { + use crate::proc::TypeResolution as Tr; + use crate::Expression as Ex; + let resolution = match self.expressions[expr] { + Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()), + Ex::Constant(c) => Tr::Handle(self.constants[c].ty), + Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty), + Ex::Splat { size, value } => { + let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else { + return Err(ConstantEvaluatorError::SplatScalarOnly); + }; + Tr::Value(TypeInner::Vector { scalar, size }) + } + _ => return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant), + }; + + Ok(resolution) + } } #[cfg(test)] diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index ba427dfda2..c82d60f062 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -670,37 +670,23 @@ impl super::Validator { let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat => left_inner == right_inner, - Sk::Bool => false, + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, Ti::Matrix { .. } => left_inner == right_inner, _ => false, }, Bo::Divide | Bo::Modulo => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat => left_inner == right_inner, - Sk::Bool => false, + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, _ => false, }, Bo::Multiply => { let kind_allowed = match left_inner.scalar_kind() { - Some( - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat, - ) => true, - Some(Sk::Bool) | None => false, + Some(Sk::Uint | Sk::Sint | Sk::Float) => true, + Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false, }; let types_match = match (left_inner, right_inner) { // Straight scalar and mixed scalar/vector. @@ -776,12 +762,8 @@ impl super::Validator { Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat => left_inner == right_inner, - Sk::Bool => false, + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -802,10 +784,8 @@ impl super::Validator { }, Bo::And | Bo::InclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Bool | Sk::Sint | Sk::Uint | Sk::AbstractInt => { - left_inner == right_inner - } - Sk::Float | Sk::AbstractFloat => false, + Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -814,8 +794,8 @@ impl super::Validator { }, Bo::ExclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Sint | Sk::Uint | Sk::AbstractInt => left_inner == right_inner, - Sk::Bool | Sk::Float | Sk::AbstractFloat => false, + Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -843,10 +823,8 @@ impl super::Validator { } }; match base_scalar.kind { - Sk::Sint | Sk::Uint | Sk::AbstractInt => { - base_size.is_ok() && base_size == shift_size - } - Sk::Float | Sk::AbstractFloat | Sk::Bool => false, + Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false, } } }; diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f5da7d0764..3b12e59067 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -112,7 +112,7 @@ pub enum FunctionError { InvalidStorePointer(Handle), #[error("The value {0:?} can not be stored")] InvalidStoreValue(Handle), - #[error("Store of {value:?} into {pointer:?} doesn't have matching types")] + #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")] InvalidStoreTypes { pointer: Handle, value: Handle, diff --git a/naga/tests/in/abstract-types.wgsl b/naga/tests/in/abstract-types.wgsl new file mode 100644 index 0000000000..f9f718bf01 --- /dev/null +++ b/naga/tests/in/abstract-types.wgsl @@ -0,0 +1,77 @@ +// i/x: type inferred / explicit +// vX/mX/aX: vector / matrix / array of X +// where X: u/i/f: u32 / i32 / f32 +// s: vector splat +// r: vector spread (vector arg to vector constructor) +// p: "partial" constructor (type parameter inferred) +// u/i/f/ai/af: u32 / i32 / f32 / abstract float / abstract integer as parameter +// _: just for alignment + +// Ensure that: +// - the inferred type is correct. +// - all parameters' types are considered. +// - all parameters are converted to the consensus type. + +const xvupaiai: vec2 = vec2(42, 43); +const xvfpaiai: vec2 = vec2(44, 45); + +const xvupuai: vec2 = vec2(42u, 43); +const xvupaiu: vec2 = vec2(42, 43u); + +const xvuuai: vec2 = vec2(42u, 43); +const xvuaiu: vec2 = vec2(42, 43u); + +const xmfpaiaiaiai: mat2x2 = mat2x2(1, 2, 3, 4); +const xmfpafaiaiai: mat2x2 = mat2x2(1.0, 2, 3, 4); +const xmfpaiafaiai: mat2x2 = mat2x2(1, 2.0, 3, 4); +const xmfpaiaiafai: mat2x2 = mat2x2(1, 2, 3.0, 4); +const xmfpaiaiaiaf: mat2x2 = mat2x2(1, 2, 3, 4.0); + +const ivispai = vec2(1); +const ivfspaf = vec2(1.0); +const ivis_ai = vec2(1); +const ivus_ai = vec2(1); +const ivfs_ai = vec2(1); +const ivfs_af = vec2(1.0); + +const iafafaf = array(1.0, 2.0); +const iafaiai = array(1, 2); + +const iafpafaf = array(1.0, 2.0); +const iafpaiaf = array(1, 2.0); +const iafpafai = array(1.0, 2); +const xafpafaf: array = array(1.0, 2.0); + +struct S { + f: f32, + i: i32, + u: u32, +} + +const s_f_i_u: S = S(1.0f, 1i, 1u); +const s_f_iai: S = S(1.0f, 1i, 1); +const s_fai_u: S = S(1.0f, 1, 1u); +const s_faiai: S = S(1.0f, 1, 1); +const saf_i_u: S = S(1.0, 1i, 1u); +const saf_iai: S = S(1.0, 1i, 1); +const safai_u: S = S(1.0, 1, 1u); +const safaiai: S = S(1.0, 1, 1); + +// Vector construction with spreads +const ivfr_f__f = vec3(vec2(1.0f, 2.0f), 3.0f); +const ivfr_f_af = vec3(vec2(1.0f, 2.0f), 3.0 ); +const ivfraf__f = vec3(vec2 (1.0 , 2.0 ), 3.0f); +const ivfraf_af = vec3(vec2 (1.0 , 2.0 ), 3.0 ); + +const ivf__fr_f = vec3(1.0f, vec2(2.0f, 3.0f)); +const ivf__fraf = vec3(1.0f, vec2 (2.0 , 3.0 )); +const ivf_afr_f = vec3(1.0 , vec2(2.0f, 3.0f)); +const ivf_afraf = vec3(1.0 , vec2 (2.0 , 3.0 )); + +const ivfr_f_ai = vec3(vec2(1.0f, 2.0f), 3 ); +const ivfrai__f = vec3(vec2 (1 , 2 ), 3.0f); +const ivfrai_ai = vec3(vec2 (1 , 2 ), 3 ); + +const ivf__frai = vec3(1.0f, vec2 (2 , 3 )); +const ivf_air_f = vec3(1 , vec2(2.0f, 3.0f)); +const ivf_airai = vec3(1 , vec2 (2 , 3 )); diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 70ea0c4bb5..0670534e90 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -1629,6 +1629,14 @@ 1: "foo", }, body: [ + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 2, end: 3, @@ -2192,6 +2200,14 @@ ], result: None, ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 3, end: 4, diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 55d27c97eb..0670534e90 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -214,16 +214,6 @@ ), ), ), - ( - name: None, - inner: Vector( - size: Bi, - scalar: ( - kind: Float, - width: 4, - ), - ), - ), ( name: None, inner: Matrix( @@ -238,7 +228,7 @@ ( name: None, inner: Array( - base: 19, + base: 18, size: Constant(2), stride: 32, ), @@ -249,7 +239,7 @@ members: [ ( name: Some("am"), - ty: 20, + ty: 19, binding: None, offset: 0, ), @@ -267,14 +257,14 @@ ( name: None, inner: Pointer( - base: 22, + base: 21, space: Function, ), ), ( name: None, inner: Array( - base: 22, + base: 21, size: Constant(10), stride: 4, ), @@ -282,7 +272,7 @@ ( name: None, inner: Array( - base: 24, + base: 23, size: Constant(5), stride: 40, ), @@ -297,15 +287,6 @@ ), ), ), - ( - name: None, - inner: Pointer( - base: 3, - space: Storage( - access: ("LOAD | STORE"), - ), - ), - ), ( name: None, inner: Array( @@ -314,26 +295,6 @@ stride: 4, ), ), - ( - name: None, - inner: Vector( - size: Quad, - scalar: ( - kind: Sint, - width: 4, - ), - ), - ), - ( - name: None, - inner: Vector( - size: Tri, - scalar: ( - kind: Float, - width: 4, - ), - ), - ), ( name: None, inner: Pointer( @@ -344,7 +305,7 @@ ( name: None, inner: Array( - base: 26, + base: 25, size: Constant(2), stride: 16, ), @@ -352,7 +313,7 @@ ( name: None, inner: Pointer( - base: 32, + base: 28, space: Function, ), ), @@ -412,7 +373,7 @@ group: 0, binding: 3, )), - ty: 21, + ty: 20, init: None, ), ], @@ -438,32 +399,6 @@ 6, ], ), - Literal(I32(8)), - Literal(I32(2)), - Literal(I32(10)), - Literal(I32(2)), - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(1)), - Literal(I32(0)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(0)), - Literal(I32(3)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(10)), - Literal(I32(5)), - Literal(I32(5)), - Literal(I32(10)), - Literal(I32(5)), - Literal(I32(0)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(1)), ], functions: [ ( @@ -479,7 +414,7 @@ ( name: Some("t"), ty: 16, - init: Some(54), + init: Some(49), ), ], expressions: [ @@ -507,136 +442,131 @@ base: 9, index: 0, ), - Literal(I32(0)), AccessIndex( base: 10, index: 0, ), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(3), AccessIndex( - base: 14, + base: 13, index: 0, ), Load( pointer: 2, ), Access( - base: 15, - index: 16, + base: 14, + index: 15, ), Load( - pointer: 17, + pointer: 16, ), GlobalVariable(3), AccessIndex( - base: 19, + base: 18, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 20, + base: 19, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 22, + base: 20, index: 1, ), Load( - pointer: 24, + pointer: 21, ), GlobalVariable(3), AccessIndex( - base: 26, + base: 23, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 27, + base: 24, index: 0, ), Load( pointer: 2, ), Access( - base: 29, - index: 30, + base: 25, + index: 26, ), Load( - pointer: 31, + pointer: 27, ), GlobalVariable(3), AccessIndex( - base: 33, + base: 29, index: 0, ), Load( pointer: 2, ), Access( - base: 34, - index: 35, + base: 30, + index: 31, ), - Literal(I32(1)), AccessIndex( - base: 36, + base: 32, index: 1, ), Load( - pointer: 38, + pointer: 33, ), GlobalVariable(3), AccessIndex( - base: 40, + base: 35, index: 0, ), Load( pointer: 2, ), Access( - base: 41, - index: 42, + base: 36, + index: 37, ), Load( pointer: 2, ), Access( - base: 43, - index: 44, + base: 38, + index: 39, ), Load( - pointer: 45, + pointer: 40, ), Literal(F32(1.0)), Splat( size: Bi, - value: 47, + value: 42, ), Literal(F32(2.0)), Splat( size: Bi, - value: 49, + value: 44, ), Literal(F32(3.0)), Splat( size: Bi, - value: 51, + value: 46, ), Compose( ty: 15, components: [ - 48, - 50, - 52, + 43, + 45, + 47, ], ), Compose( ty: 16, components: [ - 53, + 48, ], ), LocalVariable(2), @@ -646,143 +576,138 @@ ), Binary( op: Add, - left: 57, - right: 56, + left: 52, + right: 51, ), AccessIndex( - base: 55, + base: 50, index: 0, ), Literal(F32(6.0)), Splat( size: Bi, - value: 60, + value: 55, ), Literal(F32(5.0)), Splat( size: Bi, - value: 62, + value: 57, ), Literal(F32(4.0)), Splat( size: Bi, - value: 64, + value: 59, ), Compose( ty: 15, components: [ - 61, - 63, - 65, + 56, + 58, + 60, ], ), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 67, + base: 62, index: 0, ), Literal(F32(9.0)), Splat( size: Bi, - value: 70, + value: 64, ), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 72, - index: 73, + base: 66, + index: 67, ), Literal(F32(90.0)), Splat( size: Bi, - value: 75, + value: 69, ), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 77, + base: 71, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 79, + base: 72, index: 1, ), Literal(F32(10.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 83, + base: 75, index: 0, ), Load( pointer: 2, ), Access( - base: 85, - index: 86, + base: 76, + index: 77, ), Literal(F32(20.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 89, - index: 90, + base: 80, + index: 81, ), - Literal(I32(1)), AccessIndex( - base: 91, + base: 82, index: 1, ), Literal(F32(30.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 95, - index: 96, + base: 85, + index: 86, ), Load( pointer: 2, ), Access( - base: 97, - index: 98, + base: 87, + index: 88, ), Literal(F32(40.0)), ], named_expressions: { 8: "l0", - 13: "l1", - 18: "l2", - 25: "l3", - 32: "l4", - 39: "l5", - 46: "l6", + 12: "l1", + 17: "l2", + 22: "l3", + 28: "l4", + 34: "l5", + 41: "l6", }, body: [ Emit(( @@ -802,160 +727,160 @@ end: 10, )), Emit(( - start: 11, - end: 13, + start: 10, + end: 12, )), Emit(( - start: 14, - end: 18, + start: 13, + end: 17, + )), + Emit(( + start: 18, + end: 19, )), Emit(( start: 19, end: 20, )), Emit(( - start: 21, + start: 20, end: 22, )), Emit(( start: 23, - end: 25, + end: 24, )), Emit(( - start: 26, - end: 27, + start: 24, + end: 28, )), Emit(( - start: 28, + start: 29, end: 32, )), Emit(( - start: 33, - end: 36, + start: 32, + end: 34, )), Emit(( - start: 37, - end: 39, + start: 35, + end: 41, )), Emit(( - start: 40, - end: 46, + start: 42, + end: 43, )), Emit(( - start: 47, - end: 48, + start: 44, + end: 45, )), Emit(( - start: 49, - end: 50, + start: 46, + end: 49, )), Emit(( start: 51, - end: 54, - )), - Emit(( - start: 56, - end: 58, + end: 53, )), Store( pointer: 2, - value: 58, + value: 53, ), Emit(( - start: 58, - end: 59, + start: 53, + end: 54, )), Emit(( - start: 60, - end: 61, + start: 55, + end: 56, )), Emit(( - start: 62, - end: 63, + start: 57, + end: 58, )), Emit(( - start: 64, - end: 66, + start: 59, + end: 61, )), Store( - pointer: 59, - value: 66, + pointer: 54, + value: 61, ), Emit(( - start: 66, - end: 67, + start: 61, + end: 62, )), Emit(( - start: 68, - end: 69, + start: 62, + end: 63, )), Emit(( - start: 70, - end: 71, + start: 64, + end: 65, )), Store( - pointer: 69, - value: 71, + pointer: 63, + value: 65, ), Emit(( - start: 71, - end: 74, + start: 65, + end: 68, )), Emit(( - start: 75, - end: 76, + start: 69, + end: 70, )), Store( - pointer: 74, - value: 76, + pointer: 68, + value: 70, ), Emit(( - start: 76, - end: 77, + start: 70, + end: 71, )), Emit(( - start: 78, - end: 79, + start: 71, + end: 72, )), Emit(( - start: 80, - end: 81, + start: 72, + end: 73, )), Store( - pointer: 81, - value: 82, + pointer: 73, + value: 74, ), Emit(( - start: 82, - end: 83, + start: 74, + end: 75, )), Emit(( - start: 84, - end: 87, + start: 75, + end: 78, )), Store( - pointer: 87, - value: 88, + pointer: 78, + value: 79, ), Emit(( - start: 88, - end: 91, + start: 79, + end: 82, )), Emit(( - start: 92, - end: 93, + start: 82, + end: 83, )), Store( - pointer: 93, - value: 94, + pointer: 83, + value: 84, ), Emit(( - start: 94, - end: 99, + start: 84, + end: 89, )), Store( - pointer: 99, - value: 100, + pointer: 89, + value: 90, ), Return( value: None, @@ -974,8 +899,8 @@ ), ( name: Some("t"), - ty: 21, - init: Some(65), + ty: 20, + init: Some(53), ), ], expressions: [ @@ -1003,157 +928,145 @@ base: 9, index: 0, ), - Literal(I32(0)), AccessIndex( base: 10, index: 0, ), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(5), AccessIndex( - base: 14, + base: 13, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 15, + base: 14, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 17, + base: 15, index: 0, ), Load( - pointer: 19, + pointer: 16, ), GlobalVariable(5), AccessIndex( - base: 21, + base: 18, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 22, + base: 19, index: 0, ), Load( pointer: 2, ), Access( - base: 24, - index: 25, + base: 20, + index: 21, ), Load( - pointer: 26, + pointer: 22, ), GlobalVariable(5), AccessIndex( - base: 28, + base: 24, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 29, + base: 25, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 31, + base: 26, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 33, + base: 27, index: 1, ), Load( - pointer: 35, + pointer: 28, ), GlobalVariable(5), AccessIndex( - base: 37, + base: 30, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 38, + base: 31, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 40, + base: 32, index: 0, ), Load( pointer: 2, ), Access( - base: 42, - index: 43, + base: 33, + index: 34, ), Load( - pointer: 44, + pointer: 35, ), GlobalVariable(5), AccessIndex( - base: 46, + base: 37, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 47, + base: 38, index: 0, ), Load( pointer: 2, ), Access( - base: 49, - index: 50, + base: 39, + index: 40, ), - Literal(I32(1)), AccessIndex( - base: 51, + base: 41, index: 1, ), Load( - pointer: 53, + pointer: 42, ), GlobalVariable(5), AccessIndex( - base: 55, + base: 44, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 56, + base: 45, index: 0, ), Load( pointer: 2, ), Access( - base: 58, - index: 59, + base: 46, + index: 47, ), Load( pointer: 2, ), Access( - base: 60, - index: 61, + base: 48, + index: 49, ), Load( - pointer: 62, + pointer: 50, ), - ZeroValue(20), + ZeroValue(19), Compose( - ty: 21, + ty: 20, components: [ - 64, + 52, ], ), LocalVariable(2), @@ -1163,190 +1076,178 @@ ), Binary( op: Add, - left: 68, - right: 67, + left: 56, + right: 55, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - ZeroValue(20), + ZeroValue(19), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 72, + base: 60, index: 0, ), Literal(F32(8.0)), Splat( size: Bi, - value: 75, + value: 62, ), Literal(F32(7.0)), Splat( size: Bi, - value: 77, + value: 64, ), Literal(F32(6.0)), Splat( size: Bi, - value: 79, + value: 66, ), Literal(F32(5.0)), Splat( size: Bi, - value: 81, + value: 68, ), Compose( - ty: 19, + ty: 18, components: [ - 76, - 78, - 80, - 82, + 63, + 65, + 67, + 69, ], ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 84, + base: 71, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 86, + base: 72, index: 0, ), Literal(F32(9.0)), Splat( size: Bi, - value: 89, + value: 74, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 91, + base: 76, index: 0, ), Load( pointer: 2, ), Access( - base: 93, - index: 94, + base: 77, + index: 78, ), Literal(F32(90.0)), Splat( size: Bi, - value: 96, + value: 80, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 98, + base: 82, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 100, + base: 83, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 102, + base: 84, index: 1, ), Literal(F32(10.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 106, + base: 87, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 108, + base: 88, index: 0, ), Load( pointer: 2, ), Access( - base: 110, - index: 111, + base: 89, + index: 90, ), Literal(F32(20.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 114, + base: 93, index: 0, ), Load( pointer: 2, ), Access( - base: 116, - index: 117, + base: 94, + index: 95, ), - Literal(I32(1)), AccessIndex( - base: 118, + base: 96, index: 1, ), Literal(F32(30.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 122, + base: 99, index: 0, ), Load( pointer: 2, ), Access( - base: 124, - index: 125, + base: 100, + index: 101, ), Load( pointer: 2, ), Access( - base: 126, - index: 127, + base: 102, + index: 103, ), Literal(F32(40.0)), ], named_expressions: { 8: "l0", - 13: "l1", - 20: "l2", - 27: "l3", - 36: "l4", - 45: "l5", - 54: "l6", - 63: "l7", + 12: "l1", + 17: "l2", + 23: "l3", + 29: "l4", + 36: "l5", + 43: "l6", + 51: "l7", }, body: [ Emit(( @@ -1366,31 +1267,43 @@ end: 10, )), Emit(( - start: 11, - end: 13, + start: 10, + end: 12, + )), + Emit(( + start: 13, + end: 14, )), Emit(( start: 14, end: 15, )), Emit(( - start: 16, + start: 15, end: 17, )), Emit(( start: 18, - end: 20, + end: 19, )), Emit(( - start: 21, - end: 22, + start: 19, + end: 23, )), Emit(( - start: 23, + start: 24, + end: 25, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 26, end: 27, )), Emit(( - start: 28, + start: 27, end: 29, )), Emit(( @@ -1398,11 +1311,11 @@ end: 31, )), Emit(( - start: 32, - end: 33, + start: 31, + end: 32, )), Emit(( - start: 34, + start: 32, end: 36, )), Emit(( @@ -1410,180 +1323,168 @@ end: 38, )), Emit(( - start: 39, - end: 40, + start: 38, + end: 41, )), Emit(( start: 41, - end: 45, + end: 43, )), Emit(( - start: 46, - end: 47, + start: 44, + end: 45, )), Emit(( - start: 48, + start: 45, end: 51, )), Emit(( start: 52, - end: 54, + end: 53, )), Emit(( start: 55, - end: 56, - )), - Emit(( - start: 57, - end: 63, - )), - Emit(( - start: 64, - end: 65, - )), - Emit(( - start: 67, - end: 69, + end: 57, )), Store( pointer: 2, - value: 69, + value: 57, ), Emit(( - start: 69, - end: 70, + start: 57, + end: 58, )), Store( - pointer: 70, - value: 71, + pointer: 58, + value: 59, ), Emit(( - start: 71, - end: 72, + start: 59, + end: 60, )), Emit(( - start: 73, - end: 74, + start: 60, + end: 61, )), Emit(( - start: 75, - end: 76, + start: 62, + end: 63, )), Emit(( - start: 77, - end: 78, + start: 64, + end: 65, )), Emit(( - start: 79, - end: 80, + start: 66, + end: 67, )), Emit(( - start: 81, - end: 83, + start: 68, + end: 70, )), Store( - pointer: 74, - value: 83, + pointer: 61, + value: 70, ), Emit(( - start: 83, - end: 84, + start: 70, + end: 71, )), Emit(( - start: 85, - end: 86, + start: 71, + end: 72, )), Emit(( - start: 87, - end: 88, + start: 72, + end: 73, )), Emit(( - start: 89, - end: 90, + start: 74, + end: 75, )), Store( - pointer: 88, - value: 90, + pointer: 73, + value: 75, ), Emit(( - start: 90, - end: 91, + start: 75, + end: 76, )), Emit(( - start: 92, - end: 95, + start: 76, + end: 79, )), Emit(( - start: 96, - end: 97, + start: 80, + end: 81, )), Store( - pointer: 95, - value: 97, + pointer: 79, + value: 81, ), Emit(( - start: 97, - end: 98, + start: 81, + end: 82, )), Emit(( - start: 99, - end: 100, + start: 82, + end: 83, )), Emit(( - start: 101, - end: 102, + start: 83, + end: 84, )), Emit(( - start: 103, - end: 104, + start: 84, + end: 85, )), Store( - pointer: 104, - value: 105, + pointer: 85, + value: 86, ), Emit(( - start: 105, - end: 106, + start: 86, + end: 87, )), Emit(( - start: 107, - end: 108, + start: 87, + end: 88, )), Emit(( - start: 109, - end: 112, + start: 88, + end: 91, )), Store( - pointer: 112, - value: 113, + pointer: 91, + value: 92, ), Emit(( - start: 113, - end: 114, + start: 92, + end: 93, )), Emit(( - start: 115, - end: 118, + start: 93, + end: 96, )), Emit(( - start: 119, - end: 120, + start: 96, + end: 97, )), Store( - pointer: 120, - value: 121, + pointer: 97, + value: 98, ), Emit(( - start: 121, - end: 122, + start: 98, + end: 99, )), Emit(( - start: 123, - end: 128, + start: 99, + end: 104, )), Store( - pointer: 128, - value: 129, + pointer: 104, + value: 105, ), Return( value: None, @@ -1595,12 +1496,12 @@ arguments: [ ( name: Some("foo"), - ty: 23, + ty: 22, binding: None, ), ], result: Some(( - ty: 22, + ty: 21, binding: None, )), local_variables: [], @@ -1628,25 +1529,23 @@ arguments: [ ( name: Some("a"), - ty: 25, + ty: 24, binding: None, ), ], result: Some(( - ty: 22, + ty: 21, binding: None, )), local_variables: [], expressions: [ FunctionArgument(0), - Literal(I32(4)), AccessIndex( base: 1, index: 4, ), - Literal(I32(9)), AccessIndex( - base: 3, + base: 2, index: 9, ), ], @@ -1655,15 +1554,15 @@ }, body: [ Emit(( - start: 2, - end: 3, + start: 1, + end: 2, )), Emit(( - start: 4, - end: 5, + start: 2, + end: 3, )), Return( - value: Some(5), + value: Some(3), ), ], ), @@ -1672,7 +1571,7 @@ arguments: [ ( name: Some("p"), - ty: 31, + ty: 27, binding: None, ), ], @@ -1700,7 +1599,7 @@ arguments: [ ( name: Some("foo"), - ty: 33, + ty: 29, binding: None, ), ], @@ -1719,7 +1618,7 @@ value: 4, ), Compose( - ty: 32, + ty: 28, components: [ 3, 5, @@ -1730,6 +1629,14 @@ 1: "foo", }, body: [ + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 2, end: 3, @@ -1764,7 +1671,7 @@ ), ], result: Some(( - ty: 26, + ty: 25, binding: Some(BuiltIn(Position( invariant: false, ))), @@ -1772,12 +1679,12 @@ local_variables: [ ( name: Some("foo"), - ty: 22, + ty: 21, init: Some(2), ), ( name: Some("c2"), - ty: 28, + ty: 26, init: None, ), ], @@ -1859,13 +1766,12 @@ base: 30, index: 5, ), - Literal(I32(0)), AccessIndex( base: 31, index: 0, ), AccessIndex( - base: 33, + base: 32, index: 0, ), CallResult(3), @@ -1878,13 +1784,13 @@ Literal(I32(4)), Literal(I32(5)), Compose( - ty: 28, + ty: 26, components: [ 27, + 35, 36, 37, 38, - 39, ], ), LocalVariable(2), @@ -1892,42 +1798,42 @@ Binary( op: Add, left: 1, - right: 42, + right: 41, ), Access( - base: 41, - index: 43, + base: 40, + index: 42, ), Literal(I32(42)), Access( - base: 41, + base: 40, index: 1, ), Load( - pointer: 46, + pointer: 45, ), - ZeroValue(25), + ZeroValue(24), CallResult(4), Splat( size: Quad, - value: 47, + value: 46, ), As( - expr: 50, + expr: 49, kind: Float, convert: Some(4), ), Binary( op: Multiply, left: 8, - right: 51, + right: 50, ), Literal(F32(2.0)), Compose( - ty: 26, + ty: 25, components: [ + 51, 52, - 53, ], ), ], @@ -1940,9 +1846,9 @@ 17: "b", 27: "a", 29: "c", - 34: "data_pointer", - 35: "foo_value", - 47: "value", + 33: "data_pointer", + 34: "foo_value", + 46: "value", }, body: [ Emit(( @@ -1996,57 +1902,57 @@ end: 31, )), Emit(( - start: 32, - end: 34, + start: 31, + end: 33, )), Call( function: 3, arguments: [ 3, ], - result: Some(35), + result: Some(34), ), Emit(( - start: 35, - end: 36, + start: 34, + end: 35, )), Emit(( - start: 39, - end: 40, + start: 38, + end: 39, )), Store( - pointer: 41, - value: 40, + pointer: 40, + value: 39, ), Emit(( - start: 42, - end: 44, + start: 41, + end: 43, )), Store( - pointer: 44, - value: 45, + pointer: 43, + value: 44, ), Emit(( - start: 45, - end: 47, + start: 44, + end: 46, )), Call( function: 4, arguments: [ - 48, + 47, ], - result: Some(49), + result: Some(48), ), Emit(( - start: 49, - end: 52, + start: 48, + end: 51, )), Emit(( - start: 53, - end: 54, + start: 52, + end: 53, )), Return( - value: Some(54), + value: Some(53), ), ], ), @@ -2060,7 +1966,7 @@ name: Some("foo_frag"), arguments: [], result: Some(( - ty: 26, + ty: 25, binding: Some(Location( location: 0, second_blend_source: false, @@ -2075,84 +1981,82 @@ base: 1, index: 0, ), - Literal(I32(1)), AccessIndex( base: 2, index: 1, ), AccessIndex( - base: 4, + base: 3, index: 2, ), Literal(F32(1.0)), GlobalVariable(2), AccessIndex( - base: 7, + base: 6, index: 0, ), Literal(F32(0.0)), Splat( size: Tri, - value: 9, + value: 8, ), Literal(F32(1.0)), Splat( size: Tri, - value: 11, + value: 10, ), Literal(F32(2.0)), Splat( size: Tri, - value: 13, + value: 12, ), Literal(F32(3.0)), Splat( size: Tri, - value: 15, + value: 14, ), Compose( ty: 6, components: [ - 10, - 12, - 14, - 16, + 9, + 11, + 13, + 15, ], ), GlobalVariable(2), AccessIndex( - base: 18, + base: 17, index: 4, ), Literal(U32(0)), Splat( size: Bi, - value: 20, + value: 19, ), Literal(U32(1)), Splat( size: Bi, - value: 22, + value: 21, ), Compose( ty: 12, components: [ - 21, - 23, + 20, + 22, ], ), GlobalVariable(2), AccessIndex( - base: 25, + base: 24, index: 5, ), - Literal(I32(1)), AccessIndex( - base: 26, + base: 25, index: 1, ), AccessIndex( - base: 28, + base: 26, index: 0, ), Literal(I32(1)), @@ -2161,7 +2065,7 @@ Literal(F32(0.0)), Splat( size: Quad, - value: 33, + value: 31, ), ], named_expressions: {}, @@ -2171,75 +2075,75 @@ end: 2, )), Emit(( - start: 3, - end: 5, + start: 2, + end: 4, )), Store( - pointer: 5, - value: 6, + pointer: 4, + value: 5, ), Emit(( - start: 7, - end: 8, + start: 6, + end: 7, )), Emit(( - start: 9, - end: 10, + start: 8, + end: 9, )), Emit(( - start: 11, - end: 12, + start: 10, + end: 11, )), Emit(( - start: 13, - end: 14, + start: 12, + end: 13, )), Emit(( - start: 15, - end: 17, + start: 14, + end: 16, )), Store( - pointer: 8, - value: 17, + pointer: 7, + value: 16, ), Emit(( - start: 18, - end: 19, + start: 17, + end: 18, )), Emit(( - start: 20, - end: 21, + start: 19, + end: 20, )), Emit(( - start: 22, - end: 24, + start: 21, + end: 23, )), Store( - pointer: 19, - value: 24, + pointer: 18, + value: 23, ), Emit(( - start: 25, - end: 26, + start: 24, + end: 25, )), Emit(( - start: 27, - end: 29, + start: 25, + end: 27, )), Store( - pointer: 29, - value: 30, + pointer: 27, + value: 28, ), Store( - pointer: 31, - value: 32, + pointer: 29, + value: 30, ), Emit(( - start: 33, - end: 34, + start: 31, + end: 32, )), Return( - value: Some(34), + value: Some(32), ), ], ), @@ -2261,7 +2165,7 @@ ), ( name: Some("arr"), - ty: 32, + ty: 28, init: Some(7), ), ], @@ -2279,7 +2183,7 @@ value: 5, ), Compose( - ty: 32, + ty: 28, components: [ 4, 6, @@ -2296,6 +2200,14 @@ ], result: None, ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 3, end: 4, diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index effde120a5..cfc3bfa0ee 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -60,11 +60,7 @@ init: None, ), ], - const_expressions: [ - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(1)), - ], + const_expressions: [], functions: [ ( name: Some("collatz_iterations"), diff --git a/naga/tests/out/msl/abstract-types.msl b/naga/tests/out/msl/abstract-types.msl new file mode 100644 index 0000000000..a9c89f7a9c --- /dev/null +++ b/naga/tests/out/msl/abstract-types.msl @@ -0,0 +1,59 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct type_5 { + float inner[2]; +}; +struct S { + float f; + int i; + uint u; +}; +constant metal::uint2 xvupaiai = metal::uint2(42u, 43u); +constant metal::float2 xvfpaiai = metal::float2(44.0, 45.0); +constant metal::uint2 xvupuai = metal::uint2(42u, 43u); +constant metal::uint2 xvupaiu = metal::uint2(42u, 43u); +constant metal::uint2 xvuuai = metal::uint2(42u, 43u); +constant metal::uint2 xvuaiu = metal::uint2(42u, 43u); +constant metal::float2x2 xmfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::int2 ivispai = metal::int2(1); +constant metal::float2 ivfspaf = metal::float2(1.0); +constant metal::int2 ivis_ai = metal::int2(1); +constant metal::uint2 ivus_ai = metal::uint2(1u); +constant metal::float2 ivfs_ai = metal::float2(1.0); +constant metal::float2 ivfs_af = metal::float2(1.0); +constant type_5 iafafaf = type_5 {1.0, 2.0}; +constant type_5 iafaiai = type_5 {1.0, 2.0}; +constant type_5 iafpafaf = type_5 {1.0, 2.0}; +constant type_5 iafpaiaf = type_5 {1.0, 2.0}; +constant type_5 iafpafai = type_5 {1.0, 2.0}; +constant type_5 xafpafaf = type_5 {1.0, 2.0}; +constant S s_f_i_u = S {1.0, 1, 1u}; +constant S s_f_iai = S {1.0, 1, 1u}; +constant S s_fai_u = S {1.0, 1, 1u}; +constant S s_faiai = S {1.0, 1, 1u}; +constant S saf_i_u = S {1.0, 1, 1u}; +constant S saf_iai = S {1.0, 1, 1u}; +constant S safai_u = S {1.0, 1, 1u}; +constant S safaiai = S {1.0, 1, 1u}; +constant metal::float3 ivfr_f_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfr_f_af = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfraf_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfraf_af = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivf_fr_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_fraf = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_afr_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_afraf = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivfr_f_ai = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfrai_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfrai_ai = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivf_frai = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_air_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_airai = metal::float3(1.0, metal::float2(2.0, 3.0)); diff --git a/naga/tests/out/spv/abstract-types.spvasm b/naga/tests/out/spv/abstract-types.spvasm new file mode 100644 index 0000000000..207a04f564 --- /dev/null +++ b/naga/tests/out/spv/abstract-types.spvasm @@ -0,0 +1,46 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 36 +OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpDecorate %10 ArrayStride 4 +OpMemberDecorate %12 0 Offset 0 +OpMemberDecorate %12 1 Offset 4 +OpMemberDecorate %12 2 Offset 8 +%2 = OpTypeVoid +%4 = OpTypeInt 32 0 +%3 = OpTypeVector %4 2 +%6 = OpTypeFloat 32 +%5 = OpTypeVector %6 2 +%7 = OpTypeMatrix %5 2 +%9 = OpTypeInt 32 1 +%8 = OpTypeVector %9 2 +%11 = OpConstant %4 2 +%10 = OpTypeArray %6 %11 +%12 = OpTypeStruct %6 %9 %4 +%13 = OpTypeVector %6 3 +%14 = OpConstant %4 42 +%15 = OpConstant %4 43 +%16 = OpConstantComposite %3 %14 %15 +%17 = OpConstant %6 44.0 +%18 = OpConstant %6 45.0 +%19 = OpConstantComposite %5 %17 %18 +%20 = OpConstant %6 1.0 +%21 = OpConstant %6 2.0 +%22 = OpConstantComposite %5 %20 %21 +%23 = OpConstant %6 3.0 +%24 = OpConstant %6 4.0 +%25 = OpConstantComposite %5 %23 %24 +%26 = OpConstantComposite %7 %22 %25 +%27 = OpConstant %9 1 +%28 = OpConstantComposite %8 %27 %27 +%29 = OpConstantComposite %5 %20 %20 +%30 = OpConstant %4 1 +%31 = OpConstantComposite %3 %30 %30 +%32 = OpConstantComposite %10 %20 %21 +%33 = OpConstantComposite %12 %20 %27 %30 +%34 = OpConstantComposite %13 %20 %21 %23 +%35 = OpConstantComposite %5 %21 %23 \ No newline at end of file diff --git a/naga/tests/out/spv/ray-query.spvasm b/naga/tests/out/spv/ray-query.spvasm index d96dbb315b..23d5dd1baa 100644 --- a/naga/tests/out/spv/ray-query.spvasm +++ b/naga/tests/out/spv/ray-query.spvasm @@ -66,10 +66,10 @@ OpMemberDecorate %18 0 Offset 0 %47 = OpConstantComposite %5 %27 %25 %27 %48 = OpConstant %4 4 %49 = OpConstant %4 255 -%50 = OpConstant %6 0.1 -%51 = OpConstant %6 100.0 -%52 = OpConstantComposite %5 %27 %27 %27 -%53 = OpConstantComposite %14 %48 %49 %50 %51 %52 %47 +%50 = OpConstantComposite %5 %27 %27 %27 +%51 = OpConstant %6 0.1 +%52 = OpConstant %6 100.0 +%53 = OpConstantComposite %14 %48 %49 %51 %52 %50 %47 %55 = OpTypePointer Function %13 %72 = OpConstant %4 1 %85 = OpTypePointer StorageBuffer %4 diff --git a/naga/tests/out/wgsl/abstract-types.wgsl b/naga/tests/out/wgsl/abstract-types.wgsl new file mode 100644 index 0000000000..4096f9ff02 --- /dev/null +++ b/naga/tests/out/wgsl/abstract-types.wgsl @@ -0,0 +1,52 @@ +struct S { + f: f32, + i: i32, + u: u32, +} + +const xvupaiai: vec2 = vec2(42u, 43u); +const xvfpaiai: vec2 = vec2(44.0, 45.0); +const xvupuai: vec2 = vec2(42u, 43u); +const xvupaiu: vec2 = vec2(42u, 43u); +const xvuuai: vec2 = vec2(42u, 43u); +const xvuaiu: vec2 = vec2(42u, 43u); +const xmfpaiaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpafaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiafaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiaiafai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiaiaiaf: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const ivispai: vec2 = vec2(1); +const ivfspaf: vec2 = vec2(1.0); +const ivis_ai: vec2 = vec2(1); +const ivus_ai: vec2 = vec2(1u); +const ivfs_ai: vec2 = vec2(1.0); +const ivfs_af: vec2 = vec2(1.0); +const iafafaf: array = array(1.0, 2.0); +const iafaiai: array = array(1.0, 2.0); +const iafpafaf: array = array(1.0, 2.0); +const iafpaiaf: array = array(1.0, 2.0); +const iafpafai: array = array(1.0, 2.0); +const xafpafaf: array = array(1.0, 2.0); +const s_f_i_u: S = S(1.0, 1, 1u); +const s_f_iai: S = S(1.0, 1, 1u); +const s_fai_u: S = S(1.0, 1, 1u); +const s_faiai: S = S(1.0, 1, 1u); +const saf_i_u: S = S(1.0, 1, 1u); +const saf_iai: S = S(1.0, 1, 1u); +const safai_u: S = S(1.0, 1, 1u); +const safaiai: S = S(1.0, 1, 1u); +const ivfr_f_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfr_f_af: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfraf_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfraf_af: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivf_fr_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_fraf: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_afr_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_afraf: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivfr_f_ai: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfrai_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfrai_ai: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivf_frai: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_air_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_airai: vec3 = vec3(1.0, vec2(2.0, 3.0)); + diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 4ad17f1a2a..80cf87dc1e 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -783,6 +783,10 @@ fn convert_wgsl() { "f64", Targets::SPIRV | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "abstract-types", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 99257457fb..21785aaf3b 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -502,7 +502,7 @@ fn let_type_mismatch() { r#" const x: i32 = 1.0; "#, - r#"error: the type of `x` is expected to be `i32`, but got `f32` + r#"error: the type of `x` is expected to be `i32`, but got `{AbstractFloat}` ┌─ wgsl:2:19 │ 2 │ const x: i32 = 1.0; @@ -1996,9 +1996,12 @@ fn binding_array_non_struct() { #[test] fn compaction_preserves_spans() { let source = r#" - const a: i32 = -(-(-(-42i))); - const b: vec2 = vec2(42u, 43i); - "#; // ^^^^^^^^^^^^^^^^^^^ correct error span: 68..87 + fn f() { + var a: i32 = -(-(-(-42i))); + var x: i32; + x = 42u; + } + "#; // ^^^ correct error span: 95..98 let mut module = naga::front::wgsl::parse_str(source).expect("source ought to parse"); naga::compact::compact(&mut module); let err = naga::valid::Validator::new( @@ -2011,10 +2014,18 @@ fn compaction_preserves_spans() { // Ideally this would all just be a `matches!` with a big pattern, // but the `Span` API is full of opaque structs. let mut spans = err.spans(); - let first_span = spans.next().expect("error should have at least one span").0; + + // The first span is the whole function. + let _ = spans.next().expect("error should have at least one span"); + + // The second span is the assignment destination. + let dest_span = spans + .next() + .expect("error should have at least two spans") + .0; if !matches!( - first_span.to_range(), - Some(std::ops::Range { start: 68, end: 87 }) + dest_span.to_range(), + Some(std::ops::Range { start: 95, end: 98 }) ) { panic!("Error message has wrong span:\n\n{err:#?}"); } From a3911b50e36c4fdca48e72c72d75075dd5cf881e Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 27 Nov 2023 13:33:38 -0800 Subject: [PATCH 6/9] [naga wgsl-in] Clarify match in `automatic_conversion_join`. Co-authored-by: Teodor Tanasoaia <28601907+teoxoy@users.noreply.github.com> --- naga/src/front/wgsl/lower/conversion.rs | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs index 893c425433..0819c13a66 100644 --- a/naga/src/front/wgsl/lower/conversion.rs +++ b/naga/src/front/wgsl/lower/conversion.rs @@ -330,26 +330,16 @@ impl crate::Scalar { } // AbstractInt converts to AbstractFloat. - (Sk::AbstractFloat | Sk::AbstractInt, Sk::AbstractFloat | Sk::AbstractInt) => { - Some(Self { - kind: Sk::AbstractFloat, - width: crate::ABSTRACT_WIDTH, - }) - } + (Sk::AbstractFloat, Sk::AbstractInt) => Some(self), + (Sk::AbstractInt, Sk::AbstractFloat) => Some(other), // AbstractFloat converts to Float. - (Sk::AbstractFloat, Sk::Float) => Some(Self::float(other.width)), - (Sk::Float, Sk::AbstractFloat) => Some(Self::float(self.width)), + (Sk::AbstractFloat, Sk::Float) => Some(other), + (Sk::Float, Sk::AbstractFloat) => Some(self), // AbstractInt converts to concrete integer or float. - (Sk::AbstractInt, kind @ (Sk::Uint | Sk::Sint | Sk::Float)) => Some(Self { - kind, - width: other.width, - }), - (kind @ (Sk::Uint | Sk::Sint | Sk::Float), Sk::AbstractInt) => Some(Self { - kind, - width: self.width, - }), + (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other), + (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self), // AbstractFloat can't be reconciled with concrete integer types. (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { From 081028d6ddef4efdd0e830696d5f46bfc840519b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 27 Nov 2023 18:16:22 -0800 Subject: [PATCH 7/9] [naga] Improve snapshot output when validation fails. --- naga/tests/snapshots.rs | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 80cf87dc1e..35370e7f07 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -275,7 +275,13 @@ fn check_targets( let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) .validate(module) - .unwrap_or_else(|_| panic!("Naga module validation failed on test '{}'", name.display())); + .unwrap_or_else(|err| { + panic!( + "Naga module validation failed on test `{}`:\n{:?}", + name.display(), + err + ); + }); #[cfg(feature = "compact")] let info = { @@ -292,10 +298,11 @@ fn check_targets( naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) .validate(module) - .unwrap_or_else(|_| { + .unwrap_or_else(|err| { panic!( - "Post-compaction module validation failed on test '{}'", - name.display() + "Post-compaction module validation failed on test '{}':\n<{:?}", + name.display(), + err, ) }) }; From 45dc411018c60dc8ece48f1319155b29aec67f59 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Mon, 27 Nov 2023 18:17:38 -0800 Subject: [PATCH 8/9] [naga wgsl-in] Matrix constructors have only float overloads. --- naga/src/front/wgsl/lower/construction.rs | 4 ++++ naga/tests/in/abstract-types.wgsl | 4 ++++ naga/tests/out/msl/abstract-types.msl | 3 +++ naga/tests/out/wgsl/abstract-types.wgsl | 3 +++ 4 files changed, 14 insertions(+) diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 5e58c93892..2cd030dadc 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -346,6 +346,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { automatic_conversion_consensus(&components, ctx).map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; + // We actually only accept floating-point elements. + let consensus_scalar = consensus_scalar + .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT) + .unwrap_or(consensus_scalar); ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); diff --git a/naga/tests/in/abstract-types.wgsl b/naga/tests/in/abstract-types.wgsl index f9f718bf01..0cf20b5d26 100644 --- a/naga/tests/in/abstract-types.wgsl +++ b/naga/tests/in/abstract-types.wgsl @@ -27,6 +27,10 @@ const xmfpaiafaiai: mat2x2 = mat2x2(1, 2.0, 3, 4); const xmfpaiaiafai: mat2x2 = mat2x2(1, 2, 3.0, 4); const xmfpaiaiaiaf: mat2x2 = mat2x2(1, 2, 3, 4.0); +const imfpaiaiaiai = mat2x2(1, 2, 3, 4); +const imfpafaiaiai = mat2x2(1.0, 2, 3, 4); +const imfpafafafaf = mat2x2(1.0, 2.0, 3.0, 4.0); + const ivispai = vec2(1); const ivfspaf = vec2(1.0); const ivis_ai = vec2(1); diff --git a/naga/tests/out/msl/abstract-types.msl b/naga/tests/out/msl/abstract-types.msl index a9c89f7a9c..af4de25cc6 100644 --- a/naga/tests/out/msl/abstract-types.msl +++ b/naga/tests/out/msl/abstract-types.msl @@ -23,6 +23,9 @@ constant metal::float2x2 xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), constant metal::float2x2 xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); constant metal::float2x2 xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); constant metal::float2x2 xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 imfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 imfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 imfpafafafaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); constant metal::int2 ivispai = metal::int2(1); constant metal::float2 ivfspaf = metal::float2(1.0); constant metal::int2 ivis_ai = metal::int2(1); diff --git a/naga/tests/out/wgsl/abstract-types.wgsl b/naga/tests/out/wgsl/abstract-types.wgsl index 4096f9ff02..538be44df7 100644 --- a/naga/tests/out/wgsl/abstract-types.wgsl +++ b/naga/tests/out/wgsl/abstract-types.wgsl @@ -15,6 +15,9 @@ const xmfpafaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0 const xmfpaiafaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); const xmfpaiaiafai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); const xmfpaiaiaiaf: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const imfpaiaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const imfpafaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const imfpafafafaf: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); const ivispai: vec2 = vec2(1); const ivfspaf: vec2 = vec2(1.0); const ivis_ai: vec2 = vec2(1); From ffa8d16f1c10a98d12156ac4db096c31ead7a8ac Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 28 Nov 2023 21:15:07 -0800 Subject: [PATCH 9/9] [naga wgsl-in] Constructors with types don't make abstract values. When a constructor builtin has an explicit type parameter, like `mat2x2`, it should not produce an abstract matrix, even if its arguments are abstract. --- naga/src/front/wgsl/lower/construction.rs | 30 +++++++++++++++++------ naga/tests/wgsl_errors.rs | 25 ++++++++++++++++--- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 2cd030dadc..c7e4106460 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -411,20 +411,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { expr = crate::Expression::Compose { ty, components }; } - // Matrix constructor (by columns) + // Matrix constructor (by columns), partial ( Components::Many { mut components, spans, }, Constructor::PartialMatrix { columns, rows }, - ) - | ( - Components::Many { - mut components, - spans, - }, - Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), ) => { let consensus_scalar = automatic_conversion_consensus(&components, ctx).map_err(|index| { @@ -439,6 +432,27 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { expr = crate::Expression::Compose { ty, components }; } + // Matrix constructor (by columns), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + ty, + &crate::TypeInner::Matrix { + columns: _, + rows, + scalar, + }, + )), + ) => { + let component_ty = crate::TypeInner::Vector { size: rows, scalar }; + ctx.try_automatic_conversions_slice( + &mut components, + &Tr::Value(component_ty), + span, + )?; + expr = crate::Expression::Compose { ty, components }; + } + // Array constructor - infer type (components, Constructor::PartialArray) => { let mut components = components.into_components_vec(); diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 21785aaf3b..56ca313464 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -209,11 +209,14 @@ fn constructor_parameter_type_mismatch() { _ = mat2x2(array(0, 1), vec2(2, 3)); } "#, - r#"error: invalid type for constructor component at index [0] - ┌─ wgsl:3:33 + r#"error: automatic conversions cannot convert `array<{AbstractInt}, 2>` to `vec2` + ┌─ wgsl:3:21 │ 3 │ _ = mat2x2(array(0, 1), vec2(2, 3)); - │ ^^^^^^^^^^^ invalid component type + │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + │ │ │ + │ │ this expression has type array<{AbstractInt}, 2> + │ a value of type vec2 is required here "#, ); @@ -829,6 +832,22 @@ fn matrix_with_bad_type() { ); } +#[test] +fn matrix_constructor_inferred() { + check( + r#" + const m: mat2x2 = mat2x2(vec2(0), vec2(1)); + "#, + r#"error: the type of `m` is expected to be `mat2x2`, but got `mat2x2` + ┌─ wgsl:2:19 + │ +2 │ const m: mat2x2 = mat2x2(vec2(0), vec2(1)); + │ ^ definition of `m` + +"#, + ); +} + /// Check the result of validating a WGSL program against a pattern. /// /// Unless you are generating code programmatically, the