From 5789746eadf9839c9bcfb0a2b9dfec90ba46713b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 22 Nov 2023 13:50:47 -0800 Subject: [PATCH] [naga wgsl-in] Implement abstract types for consts, constructors. --- CHANGELOG.md | 40 +- naga/Cargo.toml | 2 +- naga/src/front/wgsl/error.rs | 20 + naga/src/front/wgsl/lower/construction.rs | 208 +++-- naga/src/front/wgsl/lower/conversion.rs | 378 ++++++++ naga/src/front/wgsl/lower/mod.rs | 71 +- naga/src/front/wgsl/parse/lexer.rs | 45 +- naga/src/front/wgsl/parse/number.rs | 28 +- naga/src/front/wgsl/tests.rs | 4 +- naga/src/proc/constant_evaluator.rs | 109 ++- naga/src/valid/expression.rs | 50 +- naga/src/valid/function.rs | 2 +- naga/tests/in/abstract-types.wgsl | 77 ++ naga/tests/out/ir/access.compact.ron | 16 + naga/tests/out/ir/access.ron | 1028 ++++++++++----------- naga/tests/out/ir/collatz.ron | 6 +- naga/tests/out/msl/abstract-types.msl | 59 ++ naga/tests/out/spv/abstract-types.spvasm | 46 + naga/tests/out/spv/ray-query.spvasm | 8 +- naga/tests/out/wgsl/abstract-types.wgsl | 52 ++ naga/tests/snapshots.rs | 4 + naga/tests/wgsl_errors.rs | 25 +- 22 files changed, 1539 insertions(+), 739 deletions(-) create mode 100644 naga/src/front/wgsl/lower/conversion.rs create mode 100644 naga/tests/in/abstract-types.wgsl create mode 100644 naga/tests/out/msl/abstract-types.msl create mode 100644 naga/tests/out/spv/abstract-types.spvasm create mode 100644 naga/tests/out/wgsl/abstract-types.wgsl diff --git a/CHANGELOG.md b/CHANGELOG.md index f12dd95f7fe..00450c43a1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,8 +102,44 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S - Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673). - Add more metal keywords. By @fornwall in [#4707](https://github.com/gfx-rs/wgpu/pull/4707). -- Implement WGSL abstract types (by @jimblandy): - - Add a new `naga::Literal` variant, `I64`, for signed 64-bit literals. [#4711](https://github.com/gfx-rs/wgpu/pull/4711) +- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743)). + + Abstract types make numeric literals easier to use, by + automatically converting literals and other constant expressions + from abstract numeric types to concrete types when safe and + necessary. For example, to build a vector of floating-point + numbers, Naga previously made you write: + + vec3(1.0, 2.0, 3.0) + + With this change, you can now simply write: + + vec3(1, 2, 3) + + Even though the literals are abstract integers, Naga recognizes + that it is safe and necessary to convert them to `f32` values in + order to build the vector. You can also use abstract values as + initializers for global constants, like this: + + const unit_x: vec2 = vec2(1, 0); + + The literals `1` and `0` are abstract integers, and the expression + `vec2(1, 0)` is an abstract vector. However, Naga recognizes that + it can convert that to the concrete type `vec2` to satisfy + the given type of `unit_x`. + + The WGSL specification permits abstract integers and + floating-point values in almost all contexts, but Naga's support + for this is still incomplete. Many WGSL operators and builtin + functions are specified to produce abstract results when applied + to abstract inputs, but for now Naga simply concretizes them all + before applying the operation. We will expand Naga's abstract type + support in subsequent pull requests. + + As part of this work, the public types `naga::ScalarKind` and + `naga::Literal` now have new variants, `AbstractInt` and `AbstractFloat`. + +- Add a new `naga::Literal` variant, `I64`, for signed 64-bit literals. [#4711](https://github.com/gfx-rs/wgpu/pull/4711) - Emit and init `struct` member padding always. By @ErichDonGubler in [#4701](https://github.com/gfx-rs/wgpu/pull/4701). diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 20b5bd5d255..0cacc493ac6 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -31,7 +31,7 @@ deserialize = ["serde", "bitflags/serde", "indexmap/serde"] arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"] spv-in = ["petgraph", "spirv"] spv-out = ["spirv"] -wgsl-in = ["hexf-parse", "unicode-xid"] +wgsl-in = ["hexf-parse", "unicode-xid", "compact"] wgsl-out = [] hlsl-out = [] compact = [] diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index f5acbe2d655..dc101246800 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -251,6 +251,12 @@ pub enum Error<'a> { ExpectedPositiveArrayLength(Span), MissingWorkgroupSize(Span), ConstantEvaluatorError(ConstantEvaluatorError, Span), + AutoConversion { + dest_span: Span, + dest_type: String, + source_span: Span, + source_type: String, + }, } impl<'a> Error<'a> { @@ -712,6 +718,20 @@ impl<'a> Error<'a> { )], notes: vec![], }, + Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError { + message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"), + labels: vec![ + ( + dest_span, + format!("a value of type {dest_type} is required here").into(), + ), + ( + source_span, + format!("this expression has type {source_type}").into(), + ) + ], + notes: vec![], + } } } } diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index 378ee664749..1708dce89e8 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -116,13 +116,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { components: &[Handle>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { + use crate::proc::TypeResolution as Tr; + let constructor_h = self.constructor(constructor, ctx)?; let components = match *components { [] => Components::None, [component] => { let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; + let component = self.expression_for_abstract(component, ctx)?; let ty_inner = super::resolve_inner!(ctx, component); Components::One { @@ -134,13 +136,17 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ref ast_components @ [_, _, ..] => { let components = ast_components .iter() - .map(|&expr| self.expression(expr, ctx)) + .map(|&expr| self.expression_for_abstract(expr, ctx)) .collect::>()?; let spans = ast_components .iter() .map(|&expr| ctx.ast_expressions.get_span(expr)) .collect(); + for &component in &components { + ctx.grow_types(component)?; + } + Components::Many { components, spans } } }; @@ -288,18 +294,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Vector constructor (splat) ( Components::One { - component, - ty_inner: &crate::TypeInner::Scalar(src_scalar), + mut component, + ty_inner: &crate::TypeInner::Scalar(_), .. }, - Constructor::Type(( - _, - &crate::TypeInner::Vector { - size, - scalar: dst_scalar, - }, - )), - ) if dst_scalar == src_scalar => { + Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + ) => { + ctx.convert_slice_to_common_scalar(std::slice::from_mut(&mut component), scalar)?; expr = crate::Expression::Splat { size, value: component, @@ -307,37 +308,82 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } // Vector constructor (by elements), partial - (Components::Many { components, spans }, Constructor::PartialVector { size }) => { - let scalar = - component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + ( + Components::Many { + mut components, + spans, + }, + Constructor::PartialVector { size }, + ) => { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; - let inner = scalar.to_inner_vector(size); + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + let inner = consensus_scalar.to_inner_vector(size); let ty = ctx.ensure_type_exists(inner); expr = crate::Expression::Compose { ty, components }; } // Vector constructor (by elements), full type given ( - Components::Many { components, .. }, - Constructor::Type((ty, &crate::TypeInner::Vector { .. })), + Components::Many { mut components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })), ) => { + ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?; expr = crate::Expression::Compose { ty, components }; } - // Matrix constructor (by elements) + // Matrix constructor (by elements), partial ( - Components::Many { components, spans }, + Components::Many { + mut components, + spans, + }, Constructor::PartialMatrix { columns, rows }, - ) - | ( - Components::Many { components, spans }, - Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), ) if components.len() == columns as usize * rows as usize => { - let scalar = - component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), type given + ( + Components::Many { mut components, .. }, + Constructor::Type(( + _, + &crate::TypeInner::Matrix { + columns, + rows, + scalar, + }, + )), + ) if components.len() == columns as usize * rows as usize => { + let element = Tr::Value(crate::TypeInner::Scalar(scalar)); + ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?; let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); let components = components @@ -363,28 +409,55 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // Matrix constructor (by columns) ( - Components::Many { components, spans }, + Components::Many { + mut components, + spans, + }, Constructor::PartialMatrix { columns, rows }, ) | ( - Components::Many { components, spans }, + Components::Many { + mut components, + spans, + }, Constructor::Type((_, &crate::TypeInner::Matrix { columns, rows, .. })), ) => { - let scalar = - component_scalar_from_constructor_args(&components, ctx).map_err(|index| { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { Error::InvalidConstructorComponentType(spans[index], index as i32) })?; + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { columns, rows, - scalar, + scalar: consensus_scalar, }); expr = crate::Expression::Compose { ty, components }; } // Array constructor - infer type (components, Constructor::PartialArray) => { - let components = components.into_components_vec(); + let mut components = components.into_components_vec(); + if let Ok(consensus_scalar) = automatic_conversion_consensus(&components, ctx) { + // Note that this will *not* necessarily convert all the + // components to the same type! The `automatic_conversion_consensus` + // function only considers the parameters' leaf scalar + // types; the parameters themselves could be any mix of + // vectors, matrices, and scalars. + // + // But *if* it is possible for this array construction + // expression to be well-typed at all, then all the + // parameters must have the same type constructors (vec, + // matrix, scalar) applied to their leaf scalars, so + // reconciling their scalars is always the right thing to + // do. And if this array construction is not well-typed, + // these conversions will not make it so, and we can let + // validation catch the error. + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + } else { + // There's no consensus scalar. Emit the `Compose` + // expression anyway, and let validation catch the problem. + } let base = ctx.register_type(components[0])?; @@ -403,15 +476,30 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { expr = crate::Expression::Compose { ty, components }; } - // Array or Struct constructor + // Array constructor, explicit type + (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => { + let mut components = components.into_components_vec(); + ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), span)?; + expr = crate::Expression::Compose { ty, components }; + } + + // Struct constructor ( components, - Constructor::Type(( - ty, - &crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. }, - )), + Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })), ) => { - let components = components.into_components_vec(); + let mut components = components.into_components_vec(); + let struct_ty_span = ctx.module.types.get_span(ty); + + // Make a vector of the members' type handles in advance, to + // avoid borrowing `members` from `ctx` while we generate + // new code. + let members: Vec> = members.iter().map(|m| m.ty).collect(); + + for (component, &ty) in components.iter_mut().zip(&members) { + *component = + ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?; + } expr = crate::Expression::Compose { ty, components }; } @@ -504,12 +592,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } -/// Compute a vector or matrix's scalar type from those of its -/// constructor arguments. +/// Find the consensus scalar of `components` under WGSL's automatic +/// conversions. /// -/// Given `components`, the arguments given to a vector or matrix -/// constructor, return the scalar type of the vector or matrix's -/// elements. +/// If `components` can all be converted to any common scalar via +/// WGSL's automatic conversions, return the best such scalar. /// /// The `components` slice must not be empty. All elements' types must /// have been resolved. @@ -518,20 +605,31 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { /// constructors, return `Err(i)`, where `i` is the index in /// `components` of some problematic argument. /// -/// This function doesn't fully type-check the arguments, so it may -/// return `Ok` even when the Naga validator will reject the resulting +/// This function doesn't fully type-check the arguments - it only +/// considers their leaf scalar types. This means it may return `Ok` +/// even when the Naga validator will reject the resulting /// construction expression later. -fn component_scalar_from_constructor_args( +fn automatic_conversion_consensus( components: &[Handle], - ctx: &mut ExpressionContext<'_, '_, '_>, + ctx: &ExpressionContext<'_, '_, '_>, ) -> Result { - // Since we don't yet implement abstract types, we can settle for - // just inspecting the first element. - let first = components[0]; - ctx.grow_types(first).map_err(|_| 0_usize)?; - let inner = ctx.typifier()[first].inner_with(&ctx.module.types); - match inner.scalar() { - Some(scalar) => Ok(scalar), - None => Err(0), + log::trace!("JIMB: automatic_conversion_consensus"); + let types = &ctx.module.types; + let mut inners = components + .iter() + .map(|&c| ctx.typifier()[c].inner_with(types)); + let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?; + log::trace!(" start: {best:?}"); + for (inner, i) in inners.zip(1..) { + let scalar = inner.scalar().ok_or(i)?; + match best.automatic_conversion_join(scalar) { + Some(new_best) => { + best = new_best; + log::trace!(" new: {best:?}"); + } + None => return Err(i), + } } + + Ok(best) } diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs new file mode 100644 index 00000000000..d332c4fcc52 --- /dev/null +++ b/naga/src/front/wgsl/lower/conversion.rs @@ -0,0 +1,378 @@ +//! WGSL's automatic conversions for abstract types. + +use crate::{Handle, Span}; + +impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> { + /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_ty`, return an + /// [`AutoConversion`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversion`]: super::Error::AutoConversion + pub fn try_automatic_conversions( + &mut self, + expr: Handle, + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + // Keep the TypeResolution so we can get type names for + // structs in error messages. + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + let goal_inner = goal_ty.inner_with(types); + + // If `expr` already has the requested type, we're done. + if expr_inner.equivalent(goal_inner, types) { + return Ok(expr); + } + + let (_expr_scalar, goal_scalar) = + match expr_inner.automatically_converts_to(goal_inner, types) { + Some(scalars) => scalars, + None => { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + let dest_type = goal_ty.to_wgsl(gctx); + + return Err(super::Error::AutoConversion { + dest_span: goal_span, + dest_type, + source_span: expr_span, + source_type, + }); + } + }; + + let converted = if let crate::TypeInner::Array { .. } = *goal_inner { + let span = self.get_expression_span(expr); + self.as_const_evaluator() + .cast_array(expr, goal_scalar, span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, span))? + } else { + let cast = crate::Expression::As { + expr, + kind: goal_scalar.kind, + convert: Some(goal_scalar.width), + }; + log::trace!("JIMB: Emitting {cast:?}"); + self.append_expression(cast, expr_span)? + }; + + Ok(converted) + } + + /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions. + pub fn try_automatic_conversions_slice( + &mut self, + exprs: &mut [Handle], + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?; + } + + Ok(()) + } + + /// Apply WGSL's automatic conversions to a vector constructor's arguments. + /// + /// When calling a vector constructor like `vec3(...)`, the parameters + /// can be a mix of scalars and vectors, with the latter being spread out to + /// contribute each of their components as a component of the new value. + /// When the element type is explicit, as with `` in the example above, + /// WGSL's automatic conversions should convert abstract scalar and vector + /// parameters to the constructor's required scalar type. + pub fn try_automatic_conversions_for_vector( + &mut self, + exprs: &mut [Handle], + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + use crate::proc::TypeResolution as Tr; + use crate::TypeInner as Ti; + let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar)); + + for (i, expr) in exprs.iter_mut().enumerate() { + // Keep the TypeResolution so we can get full type names + // in error messages. + let expr_resolution = super::resolve!(self, *expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + match *expr_inner { + Ti::Scalar(_) => { + *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?; + } + Ti::Vector { size, scalar: _ } => { + let goal_vector_res = Tr::Value(Ti::Vector { + size, + scalar: goal_scalar, + }); + *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?; + } + _ => { + let span = self.get_expression_span(*expr); + return Err(super::Error::InvalidConstructorComponentType( + span, i as i32, + )); + } + } + } + + Ok(()) + } + + /// Convert all expressions in `exprs` to a common scalar type. + /// + /// Note that the caller is responsible for making sure these + /// conversions are actually justified. This function simply + /// generates `As` expressions, regardless of whether they are + /// permitted WGSL automatic conversions. Callers intending to + /// implement automatic conversions need to determine for + /// themselves whether the casts we we generate are justified, + /// perhaps by calling `TypeInner::automatically_converts_to` or + /// `Scalar::automatic_conversion_join`. + pub fn convert_slice_to_common_scalar( + &mut self, + exprs: &mut [Handle], + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + let inner = super::resolve_inner!(self, *expr); + // Do nothing if `inner` doesn't even have leaf scalars; + // it's a type error that validation will catch. + if inner.scalar() != Some(goal) { + let cast = crate::Expression::As { + expr: *expr, + kind: goal.kind, + convert: Some(goal.width), + }; + let expr_span = self.get_expression_span(*expr); + *expr = self.append_expression(cast, expr_span)?; + } + } + + Ok(()) + } + + /// Return an expression for the concretized value of `expr`. + /// + /// If `expr` is already concrete, return it unchanged. + pub fn concretize( + &mut self, + mut expr: Handle, + ) -> Result, super::Error<'source>> { + let inner = super::resolve_inner!(self, expr); + if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) { + let concretized = scalar.concretize(); + if concretized != scalar { + let span = self.get_expression_span(expr); + expr = self + .as_const_evaluator() + .cast_array(expr, concretized, span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, span))?; + } + } + + Ok(expr) + } +} + +impl crate::TypeInner { + /// Determine whether `self` automatically converts to `goal`. + /// + /// If WGSL's automatic conversions (excluding the Load Rule) will + /// convert `self` to `goal`, then return a pair `(from, to)`, + /// where `from` and `to` are the scalar types of the leaf values + /// of `self` and `goal`. + /// + /// This function assumes that `self` and `goal` are different + /// types. Callers should first check whether any conversion is + /// needed at all. + /// + /// If the automatic conversions cannot convert `self` to `goal`, + /// return `None`. + fn automatically_converts_to( + &self, + goal: &Self, + types: &crate::UniqueArena, + ) -> Option<(crate::Scalar, crate::Scalar)> { + use crate::ScalarKind as Sk; + use crate::TypeInner as Ti; + + log::trace!("JIMB: automatically_converts_to: {self:?} -> {goal:?}"); + + // Automatic conversions only change the scalar type of a value's leaves + // (e.g., `vec4` to `vec4`), never the type + // constructors applied to those scalar types (e.g., never scalar to + // `vec4`, or `vec2` to `vec3`). So first we check that the type + // constructors match, extracting the leaf scalar types in the process. + let expr_scalar; + let goal_scalar; + match (self, goal) { + (&Ti::Scalar(expr), &Ti::Scalar(goal)) => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Vector { + size: expr_size, + scalar: expr, + }, + &Ti::Vector { + size: goal_size, + scalar: goal, + }, + ) if expr_size == goal_size => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Matrix { + rows: expr_rows, + columns: expr_columns, + scalar: expr, + }, + &Ti::Matrix { + rows: goal_rows, + columns: goal_columns, + scalar: goal, + }, + ) if expr_rows == goal_rows && expr_columns == goal_columns => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Array { + base: expr_base, + size: expr_size, + stride: _, + }, + &Ti::Array { + base: goal_base, + size: goal_size, + stride: _, + }, + ) if expr_size == goal_size => { + return types[expr_base] + .inner + .automatically_converts_to(&types[goal_base].inner, types); + } + _ => return None, + } + + match (expr_scalar.kind, goal_scalar.kind) { + (Sk::AbstractFloat, Sk::Float) => {} + (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {} + _ => return None, + } + + log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}"); + Some((expr_scalar, goal_scalar)) + } + + fn automatically_convertible_scalar( + &self, + types: &crate::UniqueArena, + ) -> Option { + use crate::TypeInner as Ti; + match *self { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { + Some(scalar) + } + Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), + Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } + | Ti::Struct { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => None, + } + } +} + +impl crate::Scalar { + /// Find the common type of `self` and `other` under WGSL's + /// automatic conversions. + /// + /// If there are any scalars to which WGSL's automatic conversions + /// will convert both `self` and `other`, return the best such + /// scalar. Otherwise, return `None`. + pub const fn automatic_conversion_join(self, other: Self) -> Option { + use crate::ScalarKind as Sk; + + match (self.kind, other.kind) { + // When the kinds match... + (Sk::AbstractFloat, Sk::AbstractFloat) + | (Sk::AbstractInt, Sk::AbstractInt) + | (Sk::Sint, Sk::Sint) + | (Sk::Uint, Sk::Uint) + | (Sk::Float, Sk::Float) + | (Sk::Bool, Sk::Bool) => { + if self.width == other.width { + // ... either no conversion is necessary ... + Some(self) + } else { + // ... or no conversion is possible. + // We never convert concrete to concrete, and + // abstract types should have only one size. + None + } + } + + // AbstractInt converts to AbstractFloat. + (Sk::AbstractFloat | Sk::AbstractInt, Sk::AbstractFloat | Sk::AbstractInt) => { + Some(Self { + kind: Sk::AbstractFloat, + width: crate::ABSTRACT_WIDTH, + }) + } + + // AbstractFloat converts to Float. + (Sk::AbstractFloat, Sk::Float) => Some(Self::float(other.width)), + (Sk::Float, Sk::AbstractFloat) => Some(Self::float(self.width)), + + // AbstractInt converts to concrete integer or float. + (Sk::AbstractInt, kind @ (Sk::Uint | Sk::Sint | Sk::Float)) => Some(Self { + kind, + width: other.width, + }), + (kind @ (Sk::Uint | Sk::Sint | Sk::Float), Sk::AbstractInt) => Some(Self { + kind, + width: self.width, + }), + + // AbstractFloat can't be reconciled with concrete integer types. + (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { + None + } + + // Nothing can be reconciled with `bool`. + (Sk::Bool, _) | (_, Sk::Bool) => None, + + // Different concrete types cannot be reconciled. + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None, + } + } + + const fn concretize(self) -> Self { + use crate::ScalarKind as Sk; + match self.kind { + Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self, + Sk::AbstractInt => Self::I32, + Sk::AbstractFloat => Self::F32, + } + } +} diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index a727d6379b5..b050ffc343d 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -11,6 +11,7 @@ use crate::proc::{ use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; mod construction; +mod conversion; /// Resolves the inner type of a given expression. /// @@ -66,6 +67,7 @@ macro_rules! resolve { &$ctx.typifier()[$expr] }}; } +pub(super) use resolve; /// State for constructing a `crate::Module`. pub struct GlobalContext<'source, 'temp, 'out> { @@ -903,29 +905,39 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::GlobalDeclKind::Const(ref c) => { let mut ectx = ctx.as_const(); - let init = self.expression(c.init, &mut ectx)?; - let inferred_type = ectx.register_type(init)?; - - let explicit_ty = - c.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) - .transpose()?; - - if let Some(explicit) = explicit_ty { - if explicit != inferred_type { - let gctx = ctx.module.to_ctx(); - return Err(Error::InitializationTypeMismatch { - name: c.name.span, - expected: explicit.to_wgsl(&gctx), - got: inferred_type.to_wgsl(&gctx), - }); - } + let mut init = self.expression_for_abstract(c.init, &mut ectx)?; + + let ty; + if let Some(explicit_ty) = c.ty { + let explicit_ty = + self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + init = ectx + .try_automatic_conversions(init, &explicit_ty_res, c.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: c.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + } else { + init = ectx.concretize(init)?; + ty = ectx.register_type(init)?; } let handle = ctx.module.constants.append( crate::Constant { name: Some(c.name.name.to_string()), r#override: crate::Override::None, - ty: inferred_type, + ty, init, }, span, @@ -951,6 +963,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } + // Constant evaluation may leave abstract-typed literals and + // compositions in expression arenas, so we need to compact the module + // to remove unused expressions and types. + crate::compact::compact(&mut module); + Ok(module) } @@ -1449,10 +1466,25 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } /// Lower `expr` and apply the Load Rule if possible. + /// + /// For the time being, this concretizes abstract values, to support + /// consumers that haven't been adapted to consume them yet. Consumers + /// prepared for abstract values can call [`expression_for_abstract`]. + /// + /// [`expression_for_abstract`]: Lowerer::expression_for_abstract fn expression( &mut self, expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let expr = self.expression_for_abstract(expr, ctx)?; + ctx.concretize(expr) + } + + fn expression_for_abstract( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let expr = self.expression_for_reference(expr, ctx)?; ctx.apply_load_rule(expr) @@ -1473,8 +1505,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), - ast::Literal::Number(_) => { - unreachable!("got abstract numeric type when not expected"); + ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i), + ast::Literal::Number(Number::AbstractFloat(f)) => { + crate::Literal::AbstractFloat(f) } ast::Literal::Bool(b) => crate::Literal::Bool(b), }; diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index 1cc5a0c438f..9fa9416f111 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -465,13 +465,13 @@ fn test_numbers() { sub_test( "0x123 0X123u 1u 123 0 0i 0x3f", &[ - Token::Number(Ok(Number::I32(291))), + Token::Number(Ok(Number::AbstractInt(291))), Token::Number(Ok(Number::U32(291))), Token::Number(Ok(Number::U32(1))), - Token::Number(Ok(Number::I32(123))), + Token::Number(Ok(Number::AbstractInt(123))), + Token::Number(Ok(Number::AbstractInt(0))), Token::Number(Ok(Number::I32(0))), - Token::Number(Ok(Number::I32(0))), - Token::Number(Ok(Number::I32(63))), + Token::Number(Ok(Number::AbstractInt(63))), ], ); // decimal floating point @@ -479,17 +479,17 @@ fn test_numbers() { "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h", &[ Token::Number(Ok(Number::F32(0.))), - Token::Number(Ok(Number::F32(1.))), - Token::Number(Ok(Number::F32(0.01))), - Token::Number(Ok(Number::F32(12.34))), + Token::Number(Ok(Number::AbstractFloat(1.))), + Token::Number(Ok(Number::AbstractFloat(0.01))), + Token::Number(Ok(Number::AbstractFloat(12.34))), Token::Number(Ok(Number::F32(0.))), Token::Number(Err(NumberError::UnimplementedF16)), - Token::Number(Ok(Number::F32(0.001))), - Token::Number(Ok(Number::F32(43.75))), + Token::Number(Ok(Number::AbstractFloat(0.001))), + Token::Number(Ok(Number::AbstractFloat(43.75))), Token::Number(Ok(Number::F32(16.))), - Token::Number(Ok(Number::F32(0.1875))), + Token::Number(Ok(Number::AbstractFloat(0.1875))), Token::Number(Err(NumberError::UnimplementedF16)), - Token::Number(Ok(Number::F32(0.12109375))), + Token::Number(Ok(Number::AbstractFloat(0.12109375))), Token::Number(Err(NumberError::UnimplementedF16)), ], ); @@ -635,7 +635,7 @@ fn double_floats() { Token::Number(Ok(Number::F64(0.0625))), Token::Number(Ok(Number::F64(0.0625))), Token::Number(Ok(Number::F64(10.0))), - Token::Number(Ok(Number::I32(10))), + Token::Number(Ok(Number::AbstractInt(10))), Token::Word("l"), ], ) @@ -646,13 +646,16 @@ fn test_tokens() { sub_test("id123_OK", &[Token::Word("id123_OK")]); sub_test( "92No", - &[Token::Number(Ok(Number::I32(92))), Token::Word("No")], + &[ + Token::Number(Ok(Number::AbstractInt(92))), + Token::Word("No"), + ], ); sub_test( "2u3o", &[ Token::Number(Ok(Number::U32(2))), - Token::Number(Ok(Number::I32(3))), + Token::Number(Ok(Number::AbstractInt(3))), Token::Word("o"), ], ); @@ -660,7 +663,7 @@ fn test_tokens() { "2.4f44po", &[ Token::Number(Ok(Number::F32(2.4))), - Token::Number(Ok(Number::I32(44))), + Token::Number(Ok(Number::AbstractInt(44))), Token::Word("po"), ], ); @@ -699,13 +702,13 @@ fn test_tokens() { &[ // The 'f' suffixes are taken as a hex digit: // the fractional part is 0x2f / 256. - Token::Number(Ok(Number::F32(1.0 + 0x2f as f32 / 256.0))), - Token::Number(Ok(Number::F32(1.0 + 0x2f as f32 / 256.0))), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("h"), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("H"), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("lf"), ], ) @@ -719,7 +722,7 @@ fn test_variable_decl() { Token::Attribute, Token::Word("group"), Token::Paren('('), - Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::AbstractInt(0))), Token::Paren(')'), Token::Word("var"), Token::Paren('<'), diff --git a/naga/src/front/wgsl/parse/number.rs b/naga/src/front/wgsl/parse/number.rs index 3178736990f..fde5e3cee6d 100644 --- a/naga/src/front/wgsl/parse/number.rs +++ b/naga/src/front/wgsl/parse/number.rs @@ -20,37 +20,11 @@ pub enum Number { F64(f64), } -impl Number { - /// Convert abstract numbers to a plausible concrete counterpart. - /// - /// Return concrete numbers unchanged. If the conversion would be - /// lossy, return an error. - fn abstract_to_concrete(self) -> Result { - match self { - Number::AbstractInt(num) => i32::try_from(num) - .map(Number::I32) - .map_err(|_| NumberError::NotRepresentable), - Number::AbstractFloat(num) => { - let num = num as f32; - if num.is_finite() { - Ok(Number::F32(num)) - } else { - Err(NumberError::NotRepresentable) - } - } - num => Ok(num), - } - } -} - // TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) { let (result, rest) = parse(input); - ( - Token::Number(result.and_then(Number::abstract_to_concrete)), - rest, - ) + (Token::Number(result), rest) } enum Kind { diff --git a/naga/src/front/wgsl/tests.rs b/naga/src/front/wgsl/tests.rs index 9e3ba2fab69..eb2f8a2eb36 100644 --- a/naga/src/front/wgsl/tests.rs +++ b/naga/src/front/wgsl/tests.rs @@ -76,7 +76,7 @@ fn parse_type_cast() { assert!(parse_str( " fn main() { - let x: vec2 = vec2(0); + let x: vec2 = vec2(0i, 0i); } ", ) @@ -313,7 +313,7 @@ fn parse_texture_load() { " var t: texture_3d; fn foo() { - let r: vec4 = textureLoad(t, vec3(0.0, 1.0, 2.0), 1); + let r: vec4 = textureLoad(t, vec3(0u, 1u, 2u), 1); } ", ) diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 6901ada20c8..51e447847c8 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -277,6 +277,17 @@ impl<'a> ConstantEvaluator<'a> { } } + pub fn to_ctx(&self) -> crate::proc::GlobalCtx { + crate::proc::GlobalCtx { + types: self.types, + constants: self.constants, + const_expressions: match self.function_local_data { + Some(ref data) => data.const_expressions, + None => self.expressions, + }, + } + } + fn check(&self, expr: Handle) -> Result<(), ConstantEvaluatorError> { if let Some(ref function_local_data) = self.function_local_data { if !function_local_data.expression_constness.is_const(expr) { @@ -1035,7 +1046,21 @@ impl<'a> ConstantEvaluator<'a> { return Err(ConstantEvaluatorError::InvalidCastArg) } }), - _ => return Err(ConstantEvaluatorError::InvalidCastArg), + Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal { + Literal::AbstractInt(v) => { + // Overflow is forbidden, but inexact conversions + // are fine. The range of f64 is far larger than + // that of i64, so we don't have to check anything + // here. + v as f64 + } + Literal::AbstractFloat(v) => v, + _ => return Err(ConstantEvaluatorError::InvalidCastArg), + }), + _ => { + log::debug!("Constant evaluator refused to convert value to {target:?}"); + return Err(ConstantEvaluatorError::InvalidCastArg); + } }; Expression::Literal(literal) } @@ -1085,6 +1110,66 @@ impl<'a> ConstantEvaluator<'a> { self.register_evaluated_expr(expr, span) } + /// Convert the scalar leaves of `expr` to `target`, handling arrays. + /// + /// `expr` must be a `Compose` expression whose type is a scalar, vector, + /// matrix, or nested arrays of such. + /// + /// This is basically the same as the [`cast`] method, except that that + /// should only handle Naga [`As`] expressions, which cannot convert arrays. + /// + /// Treat `span` as the location of the resulting expression. + /// + /// [`cast`]: ConstantEvaluator::cast + /// [`As`]: crate::Expression::As + pub fn cast_array( + &mut self, + expr: Handle, + target: crate::Scalar, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let Expression::Compose { ty, ref components } = self.expressions[expr] else { + return self.cast(expr, target, span); + }; + + let crate::TypeInner::Array { base: _, size, stride: _ } = self.types[ty].inner else { + return self.cast(expr, target, span); + }; + + let mut components = components.clone(); + for component in &mut components { + *component = self.cast_array(*component, target, span)?; + } + + let first = components + .first() + .ok_or(ConstantEvaluatorError::InvalidCastArg)?; + let new_base = match self.resolve_type(*first)? { + crate::proc::TypeResolution::Handle(ty) => ty, + crate::proc::TypeResolution::Value(inner) => { + self.types.insert(Type { name: None, inner }, span) + } + }; + let new_base_stride = self.types[new_base].inner.size(self.to_ctx()); + let new_array_ty = self.types.insert( + Type { + name: None, + inner: TypeInner::Array { + base: new_base, + size, + stride: new_base_stride, + }, + }, + span, + ); + + let compose = Expression::Compose { + ty: new_array_ty, + components, + }; + self.register_evaluated_expr(compose, span) + } + fn unary_op( &mut self, op: UnaryOperator, @@ -1339,6 +1424,28 @@ impl<'a> ConstantEvaluator<'a> { Ok(self.expressions.append(expr, span)) } } + + fn resolve_type( + &self, + expr: Handle, + ) -> Result { + use crate::proc::TypeResolution as Tr; + use crate::Expression as Ex; + let resolution = match self.expressions[expr] { + Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()), + Ex::Constant(c) => Tr::Handle(self.constants[c].ty), + Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty), + Ex::Splat { size, value } => { + let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else { + return Err(ConstantEvaluatorError::SplatScalarOnly); + }; + Tr::Value(TypeInner::Vector { scalar, size }) + } + _ => return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant), + }; + + Ok(resolution) + } } #[cfg(test)] diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index ba427dfda26..c82d60f0628 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -670,37 +670,23 @@ impl super::Validator { let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat => left_inner == right_inner, - Sk::Bool => false, + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, Ti::Matrix { .. } => left_inner == right_inner, _ => false, }, Bo::Divide | Bo::Modulo => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat => left_inner == right_inner, - Sk::Bool => false, + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, _ => false, }, Bo::Multiply => { let kind_allowed = match left_inner.scalar_kind() { - Some( - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat, - ) => true, - Some(Sk::Bool) | None => false, + Some(Sk::Uint | Sk::Sint | Sk::Float) => true, + Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false, }; let types_match = match (left_inner, right_inner) { // Straight scalar and mixed scalar/vector. @@ -776,12 +762,8 @@ impl super::Validator { Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint - | Sk::Sint - | Sk::Float - | Sk::AbstractInt - | Sk::AbstractFloat => left_inner == right_inner, - Sk::Bool => false, + Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -802,10 +784,8 @@ impl super::Validator { }, Bo::And | Bo::InclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Bool | Sk::Sint | Sk::Uint | Sk::AbstractInt => { - left_inner == right_inner - } - Sk::Float | Sk::AbstractFloat => false, + Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -814,8 +794,8 @@ impl super::Validator { }, Bo::ExclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Sint | Sk::Uint | Sk::AbstractInt => left_inner == right_inner, - Sk::Bool | Sk::Float | Sk::AbstractFloat => false, + Sk::Sint | Sk::Uint => left_inner == right_inner, + Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -843,10 +823,8 @@ impl super::Validator { } }; match base_scalar.kind { - Sk::Sint | Sk::Uint | Sk::AbstractInt => { - base_size.is_ok() && base_size == shift_size - } - Sk::Float | Sk::AbstractFloat | Sk::Bool => false, + Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false, } } }; diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f5da7d0764f..3b12e590676 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -112,7 +112,7 @@ pub enum FunctionError { InvalidStorePointer(Handle), #[error("The value {0:?} can not be stored")] InvalidStoreValue(Handle), - #[error("Store of {value:?} into {pointer:?} doesn't have matching types")] + #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")] InvalidStoreTypes { pointer: Handle, value: Handle, diff --git a/naga/tests/in/abstract-types.wgsl b/naga/tests/in/abstract-types.wgsl new file mode 100644 index 00000000000..f9f718bf010 --- /dev/null +++ b/naga/tests/in/abstract-types.wgsl @@ -0,0 +1,77 @@ +// i/x: type inferred / explicit +// vX/mX/aX: vector / matrix / array of X +// where X: u/i/f: u32 / i32 / f32 +// s: vector splat +// r: vector spread (vector arg to vector constructor) +// p: "partial" constructor (type parameter inferred) +// u/i/f/ai/af: u32 / i32 / f32 / abstract float / abstract integer as parameter +// _: just for alignment + +// Ensure that: +// - the inferred type is correct. +// - all parameters' types are considered. +// - all parameters are converted to the consensus type. + +const xvupaiai: vec2 = vec2(42, 43); +const xvfpaiai: vec2 = vec2(44, 45); + +const xvupuai: vec2 = vec2(42u, 43); +const xvupaiu: vec2 = vec2(42, 43u); + +const xvuuai: vec2 = vec2(42u, 43); +const xvuaiu: vec2 = vec2(42, 43u); + +const xmfpaiaiaiai: mat2x2 = mat2x2(1, 2, 3, 4); +const xmfpafaiaiai: mat2x2 = mat2x2(1.0, 2, 3, 4); +const xmfpaiafaiai: mat2x2 = mat2x2(1, 2.0, 3, 4); +const xmfpaiaiafai: mat2x2 = mat2x2(1, 2, 3.0, 4); +const xmfpaiaiaiaf: mat2x2 = mat2x2(1, 2, 3, 4.0); + +const ivispai = vec2(1); +const ivfspaf = vec2(1.0); +const ivis_ai = vec2(1); +const ivus_ai = vec2(1); +const ivfs_ai = vec2(1); +const ivfs_af = vec2(1.0); + +const iafafaf = array(1.0, 2.0); +const iafaiai = array(1, 2); + +const iafpafaf = array(1.0, 2.0); +const iafpaiaf = array(1, 2.0); +const iafpafai = array(1.0, 2); +const xafpafaf: array = array(1.0, 2.0); + +struct S { + f: f32, + i: i32, + u: u32, +} + +const s_f_i_u: S = S(1.0f, 1i, 1u); +const s_f_iai: S = S(1.0f, 1i, 1); +const s_fai_u: S = S(1.0f, 1, 1u); +const s_faiai: S = S(1.0f, 1, 1); +const saf_i_u: S = S(1.0, 1i, 1u); +const saf_iai: S = S(1.0, 1i, 1); +const safai_u: S = S(1.0, 1, 1u); +const safaiai: S = S(1.0, 1, 1); + +// Vector construction with spreads +const ivfr_f__f = vec3(vec2(1.0f, 2.0f), 3.0f); +const ivfr_f_af = vec3(vec2(1.0f, 2.0f), 3.0 ); +const ivfraf__f = vec3(vec2 (1.0 , 2.0 ), 3.0f); +const ivfraf_af = vec3(vec2 (1.0 , 2.0 ), 3.0 ); + +const ivf__fr_f = vec3(1.0f, vec2(2.0f, 3.0f)); +const ivf__fraf = vec3(1.0f, vec2 (2.0 , 3.0 )); +const ivf_afr_f = vec3(1.0 , vec2(2.0f, 3.0f)); +const ivf_afraf = vec3(1.0 , vec2 (2.0 , 3.0 )); + +const ivfr_f_ai = vec3(vec2(1.0f, 2.0f), 3 ); +const ivfrai__f = vec3(vec2 (1 , 2 ), 3.0f); +const ivfrai_ai = vec3(vec2 (1 , 2 ), 3 ); + +const ivf__frai = vec3(1.0f, vec2 (2 , 3 )); +const ivf_air_f = vec3(1 , vec2(2.0f, 3.0f)); +const ivf_airai = vec3(1 , vec2 (2 , 3 )); diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 70ea0c4bb54..0670534e90c 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -1629,6 +1629,14 @@ 1: "foo", }, body: [ + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 2, end: 3, @@ -2192,6 +2200,14 @@ ], result: None, ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 3, end: 4, diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 55d27c97eb2..0670534e90c 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -214,16 +214,6 @@ ), ), ), - ( - name: None, - inner: Vector( - size: Bi, - scalar: ( - kind: Float, - width: 4, - ), - ), - ), ( name: None, inner: Matrix( @@ -238,7 +228,7 @@ ( name: None, inner: Array( - base: 19, + base: 18, size: Constant(2), stride: 32, ), @@ -249,7 +239,7 @@ members: [ ( name: Some("am"), - ty: 20, + ty: 19, binding: None, offset: 0, ), @@ -267,14 +257,14 @@ ( name: None, inner: Pointer( - base: 22, + base: 21, space: Function, ), ), ( name: None, inner: Array( - base: 22, + base: 21, size: Constant(10), stride: 4, ), @@ -282,7 +272,7 @@ ( name: None, inner: Array( - base: 24, + base: 23, size: Constant(5), stride: 40, ), @@ -297,15 +287,6 @@ ), ), ), - ( - name: None, - inner: Pointer( - base: 3, - space: Storage( - access: ("LOAD | STORE"), - ), - ), - ), ( name: None, inner: Array( @@ -314,26 +295,6 @@ stride: 4, ), ), - ( - name: None, - inner: Vector( - size: Quad, - scalar: ( - kind: Sint, - width: 4, - ), - ), - ), - ( - name: None, - inner: Vector( - size: Tri, - scalar: ( - kind: Float, - width: 4, - ), - ), - ), ( name: None, inner: Pointer( @@ -344,7 +305,7 @@ ( name: None, inner: Array( - base: 26, + base: 25, size: Constant(2), stride: 16, ), @@ -352,7 +313,7 @@ ( name: None, inner: Pointer( - base: 32, + base: 28, space: Function, ), ), @@ -412,7 +373,7 @@ group: 0, binding: 3, )), - ty: 21, + ty: 20, init: None, ), ], @@ -438,32 +399,6 @@ 6, ], ), - Literal(I32(8)), - Literal(I32(2)), - Literal(I32(10)), - Literal(I32(2)), - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(1)), - Literal(I32(0)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(0)), - Literal(I32(3)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(10)), - Literal(I32(5)), - Literal(I32(5)), - Literal(I32(10)), - Literal(I32(5)), - Literal(I32(0)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(1)), ], functions: [ ( @@ -479,7 +414,7 @@ ( name: Some("t"), ty: 16, - init: Some(54), + init: Some(49), ), ], expressions: [ @@ -507,136 +442,131 @@ base: 9, index: 0, ), - Literal(I32(0)), AccessIndex( base: 10, index: 0, ), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(3), AccessIndex( - base: 14, + base: 13, index: 0, ), Load( pointer: 2, ), Access( - base: 15, - index: 16, + base: 14, + index: 15, ), Load( - pointer: 17, + pointer: 16, ), GlobalVariable(3), AccessIndex( - base: 19, + base: 18, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 20, + base: 19, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 22, + base: 20, index: 1, ), Load( - pointer: 24, + pointer: 21, ), GlobalVariable(3), AccessIndex( - base: 26, + base: 23, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 27, + base: 24, index: 0, ), Load( pointer: 2, ), Access( - base: 29, - index: 30, + base: 25, + index: 26, ), Load( - pointer: 31, + pointer: 27, ), GlobalVariable(3), AccessIndex( - base: 33, + base: 29, index: 0, ), Load( pointer: 2, ), Access( - base: 34, - index: 35, + base: 30, + index: 31, ), - Literal(I32(1)), AccessIndex( - base: 36, + base: 32, index: 1, ), Load( - pointer: 38, + pointer: 33, ), GlobalVariable(3), AccessIndex( - base: 40, + base: 35, index: 0, ), Load( pointer: 2, ), Access( - base: 41, - index: 42, + base: 36, + index: 37, ), Load( pointer: 2, ), Access( - base: 43, - index: 44, + base: 38, + index: 39, ), Load( - pointer: 45, + pointer: 40, ), Literal(F32(1.0)), Splat( size: Bi, - value: 47, + value: 42, ), Literal(F32(2.0)), Splat( size: Bi, - value: 49, + value: 44, ), Literal(F32(3.0)), Splat( size: Bi, - value: 51, + value: 46, ), Compose( ty: 15, components: [ - 48, - 50, - 52, + 43, + 45, + 47, ], ), Compose( ty: 16, components: [ - 53, + 48, ], ), LocalVariable(2), @@ -646,143 +576,138 @@ ), Binary( op: Add, - left: 57, - right: 56, + left: 52, + right: 51, ), AccessIndex( - base: 55, + base: 50, index: 0, ), Literal(F32(6.0)), Splat( size: Bi, - value: 60, + value: 55, ), Literal(F32(5.0)), Splat( size: Bi, - value: 62, + value: 57, ), Literal(F32(4.0)), Splat( size: Bi, - value: 64, + value: 59, ), Compose( ty: 15, components: [ - 61, - 63, - 65, + 56, + 58, + 60, ], ), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 67, + base: 62, index: 0, ), Literal(F32(9.0)), Splat( size: Bi, - value: 70, + value: 64, ), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 72, - index: 73, + base: 66, + index: 67, ), Literal(F32(90.0)), Splat( size: Bi, - value: 75, + value: 69, ), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 77, + base: 71, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 79, + base: 72, index: 1, ), Literal(F32(10.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 83, + base: 75, index: 0, ), Load( pointer: 2, ), Access( - base: 85, - index: 86, + base: 76, + index: 77, ), Literal(F32(20.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 89, - index: 90, + base: 80, + index: 81, ), - Literal(I32(1)), AccessIndex( - base: 91, + base: 82, index: 1, ), Literal(F32(30.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 95, - index: 96, + base: 85, + index: 86, ), Load( pointer: 2, ), Access( - base: 97, - index: 98, + base: 87, + index: 88, ), Literal(F32(40.0)), ], named_expressions: { 8: "l0", - 13: "l1", - 18: "l2", - 25: "l3", - 32: "l4", - 39: "l5", - 46: "l6", + 12: "l1", + 17: "l2", + 22: "l3", + 28: "l4", + 34: "l5", + 41: "l6", }, body: [ Emit(( @@ -802,160 +727,160 @@ end: 10, )), Emit(( - start: 11, - end: 13, + start: 10, + end: 12, )), Emit(( - start: 14, - end: 18, + start: 13, + end: 17, + )), + Emit(( + start: 18, + end: 19, )), Emit(( start: 19, end: 20, )), Emit(( - start: 21, + start: 20, end: 22, )), Emit(( start: 23, - end: 25, + end: 24, )), Emit(( - start: 26, - end: 27, + start: 24, + end: 28, )), Emit(( - start: 28, + start: 29, end: 32, )), Emit(( - start: 33, - end: 36, + start: 32, + end: 34, )), Emit(( - start: 37, - end: 39, + start: 35, + end: 41, )), Emit(( - start: 40, - end: 46, + start: 42, + end: 43, )), Emit(( - start: 47, - end: 48, + start: 44, + end: 45, )), Emit(( - start: 49, - end: 50, + start: 46, + end: 49, )), Emit(( start: 51, - end: 54, - )), - Emit(( - start: 56, - end: 58, + end: 53, )), Store( pointer: 2, - value: 58, + value: 53, ), Emit(( - start: 58, - end: 59, + start: 53, + end: 54, )), Emit(( - start: 60, - end: 61, + start: 55, + end: 56, )), Emit(( - start: 62, - end: 63, + start: 57, + end: 58, )), Emit(( - start: 64, - end: 66, + start: 59, + end: 61, )), Store( - pointer: 59, - value: 66, + pointer: 54, + value: 61, ), Emit(( - start: 66, - end: 67, + start: 61, + end: 62, )), Emit(( - start: 68, - end: 69, + start: 62, + end: 63, )), Emit(( - start: 70, - end: 71, + start: 64, + end: 65, )), Store( - pointer: 69, - value: 71, + pointer: 63, + value: 65, ), Emit(( - start: 71, - end: 74, + start: 65, + end: 68, )), Emit(( - start: 75, - end: 76, + start: 69, + end: 70, )), Store( - pointer: 74, - value: 76, + pointer: 68, + value: 70, ), Emit(( - start: 76, - end: 77, + start: 70, + end: 71, )), Emit(( - start: 78, - end: 79, + start: 71, + end: 72, )), Emit(( - start: 80, - end: 81, + start: 72, + end: 73, )), Store( - pointer: 81, - value: 82, + pointer: 73, + value: 74, ), Emit(( - start: 82, - end: 83, + start: 74, + end: 75, )), Emit(( - start: 84, - end: 87, + start: 75, + end: 78, )), Store( - pointer: 87, - value: 88, + pointer: 78, + value: 79, ), Emit(( - start: 88, - end: 91, + start: 79, + end: 82, )), Emit(( - start: 92, - end: 93, + start: 82, + end: 83, )), Store( - pointer: 93, - value: 94, + pointer: 83, + value: 84, ), Emit(( - start: 94, - end: 99, + start: 84, + end: 89, )), Store( - pointer: 99, - value: 100, + pointer: 89, + value: 90, ), Return( value: None, @@ -974,8 +899,8 @@ ), ( name: Some("t"), - ty: 21, - init: Some(65), + ty: 20, + init: Some(53), ), ], expressions: [ @@ -1003,157 +928,145 @@ base: 9, index: 0, ), - Literal(I32(0)), AccessIndex( base: 10, index: 0, ), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(5), AccessIndex( - base: 14, + base: 13, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 15, + base: 14, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 17, + base: 15, index: 0, ), Load( - pointer: 19, + pointer: 16, ), GlobalVariable(5), AccessIndex( - base: 21, + base: 18, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 22, + base: 19, index: 0, ), Load( pointer: 2, ), Access( - base: 24, - index: 25, + base: 20, + index: 21, ), Load( - pointer: 26, + pointer: 22, ), GlobalVariable(5), AccessIndex( - base: 28, + base: 24, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 29, + base: 25, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 31, + base: 26, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 33, + base: 27, index: 1, ), Load( - pointer: 35, + pointer: 28, ), GlobalVariable(5), AccessIndex( - base: 37, + base: 30, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 38, + base: 31, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 40, + base: 32, index: 0, ), Load( pointer: 2, ), Access( - base: 42, - index: 43, + base: 33, + index: 34, ), Load( - pointer: 44, + pointer: 35, ), GlobalVariable(5), AccessIndex( - base: 46, + base: 37, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 47, + base: 38, index: 0, ), Load( pointer: 2, ), Access( - base: 49, - index: 50, + base: 39, + index: 40, ), - Literal(I32(1)), AccessIndex( - base: 51, + base: 41, index: 1, ), Load( - pointer: 53, + pointer: 42, ), GlobalVariable(5), AccessIndex( - base: 55, + base: 44, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 56, + base: 45, index: 0, ), Load( pointer: 2, ), Access( - base: 58, - index: 59, + base: 46, + index: 47, ), Load( pointer: 2, ), Access( - base: 60, - index: 61, + base: 48, + index: 49, ), Load( - pointer: 62, + pointer: 50, ), - ZeroValue(20), + ZeroValue(19), Compose( - ty: 21, + ty: 20, components: [ - 64, + 52, ], ), LocalVariable(2), @@ -1163,190 +1076,178 @@ ), Binary( op: Add, - left: 68, - right: 67, + left: 56, + right: 55, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - ZeroValue(20), + ZeroValue(19), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 72, + base: 60, index: 0, ), Literal(F32(8.0)), Splat( size: Bi, - value: 75, + value: 62, ), Literal(F32(7.0)), Splat( size: Bi, - value: 77, + value: 64, ), Literal(F32(6.0)), Splat( size: Bi, - value: 79, + value: 66, ), Literal(F32(5.0)), Splat( size: Bi, - value: 81, + value: 68, ), Compose( - ty: 19, + ty: 18, components: [ - 76, - 78, - 80, - 82, + 63, + 65, + 67, + 69, ], ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 84, + base: 71, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 86, + base: 72, index: 0, ), Literal(F32(9.0)), Splat( size: Bi, - value: 89, + value: 74, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 91, + base: 76, index: 0, ), Load( pointer: 2, ), Access( - base: 93, - index: 94, + base: 77, + index: 78, ), Literal(F32(90.0)), Splat( size: Bi, - value: 96, + value: 80, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 98, + base: 82, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 100, + base: 83, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 102, + base: 84, index: 1, ), Literal(F32(10.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 106, + base: 87, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 108, + base: 88, index: 0, ), Load( pointer: 2, ), Access( - base: 110, - index: 111, + base: 89, + index: 90, ), Literal(F32(20.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 114, + base: 93, index: 0, ), Load( pointer: 2, ), Access( - base: 116, - index: 117, + base: 94, + index: 95, ), - Literal(I32(1)), AccessIndex( - base: 118, + base: 96, index: 1, ), Literal(F32(30.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 122, + base: 99, index: 0, ), Load( pointer: 2, ), Access( - base: 124, - index: 125, + base: 100, + index: 101, ), Load( pointer: 2, ), Access( - base: 126, - index: 127, + base: 102, + index: 103, ), Literal(F32(40.0)), ], named_expressions: { 8: "l0", - 13: "l1", - 20: "l2", - 27: "l3", - 36: "l4", - 45: "l5", - 54: "l6", - 63: "l7", + 12: "l1", + 17: "l2", + 23: "l3", + 29: "l4", + 36: "l5", + 43: "l6", + 51: "l7", }, body: [ Emit(( @@ -1366,31 +1267,43 @@ end: 10, )), Emit(( - start: 11, - end: 13, + start: 10, + end: 12, + )), + Emit(( + start: 13, + end: 14, )), Emit(( start: 14, end: 15, )), Emit(( - start: 16, + start: 15, end: 17, )), Emit(( start: 18, - end: 20, + end: 19, )), Emit(( - start: 21, - end: 22, + start: 19, + end: 23, )), Emit(( - start: 23, + start: 24, + end: 25, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 26, end: 27, )), Emit(( - start: 28, + start: 27, end: 29, )), Emit(( @@ -1398,11 +1311,11 @@ end: 31, )), Emit(( - start: 32, - end: 33, + start: 31, + end: 32, )), Emit(( - start: 34, + start: 32, end: 36, )), Emit(( @@ -1410,180 +1323,168 @@ end: 38, )), Emit(( - start: 39, - end: 40, + start: 38, + end: 41, )), Emit(( start: 41, - end: 45, + end: 43, )), Emit(( - start: 46, - end: 47, + start: 44, + end: 45, )), Emit(( - start: 48, + start: 45, end: 51, )), Emit(( start: 52, - end: 54, + end: 53, )), Emit(( start: 55, - end: 56, - )), - Emit(( - start: 57, - end: 63, - )), - Emit(( - start: 64, - end: 65, - )), - Emit(( - start: 67, - end: 69, + end: 57, )), Store( pointer: 2, - value: 69, + value: 57, ), Emit(( - start: 69, - end: 70, + start: 57, + end: 58, )), Store( - pointer: 70, - value: 71, + pointer: 58, + value: 59, ), Emit(( - start: 71, - end: 72, + start: 59, + end: 60, )), Emit(( - start: 73, - end: 74, + start: 60, + end: 61, )), Emit(( - start: 75, - end: 76, + start: 62, + end: 63, )), Emit(( - start: 77, - end: 78, + start: 64, + end: 65, )), Emit(( - start: 79, - end: 80, + start: 66, + end: 67, )), Emit(( - start: 81, - end: 83, + start: 68, + end: 70, )), Store( - pointer: 74, - value: 83, + pointer: 61, + value: 70, ), Emit(( - start: 83, - end: 84, + start: 70, + end: 71, )), Emit(( - start: 85, - end: 86, + start: 71, + end: 72, )), Emit(( - start: 87, - end: 88, + start: 72, + end: 73, )), Emit(( - start: 89, - end: 90, + start: 74, + end: 75, )), Store( - pointer: 88, - value: 90, + pointer: 73, + value: 75, ), Emit(( - start: 90, - end: 91, + start: 75, + end: 76, )), Emit(( - start: 92, - end: 95, + start: 76, + end: 79, )), Emit(( - start: 96, - end: 97, + start: 80, + end: 81, )), Store( - pointer: 95, - value: 97, + pointer: 79, + value: 81, ), Emit(( - start: 97, - end: 98, + start: 81, + end: 82, )), Emit(( - start: 99, - end: 100, + start: 82, + end: 83, )), Emit(( - start: 101, - end: 102, + start: 83, + end: 84, )), Emit(( - start: 103, - end: 104, + start: 84, + end: 85, )), Store( - pointer: 104, - value: 105, + pointer: 85, + value: 86, ), Emit(( - start: 105, - end: 106, + start: 86, + end: 87, )), Emit(( - start: 107, - end: 108, + start: 87, + end: 88, )), Emit(( - start: 109, - end: 112, + start: 88, + end: 91, )), Store( - pointer: 112, - value: 113, + pointer: 91, + value: 92, ), Emit(( - start: 113, - end: 114, + start: 92, + end: 93, )), Emit(( - start: 115, - end: 118, + start: 93, + end: 96, )), Emit(( - start: 119, - end: 120, + start: 96, + end: 97, )), Store( - pointer: 120, - value: 121, + pointer: 97, + value: 98, ), Emit(( - start: 121, - end: 122, + start: 98, + end: 99, )), Emit(( - start: 123, - end: 128, + start: 99, + end: 104, )), Store( - pointer: 128, - value: 129, + pointer: 104, + value: 105, ), Return( value: None, @@ -1595,12 +1496,12 @@ arguments: [ ( name: Some("foo"), - ty: 23, + ty: 22, binding: None, ), ], result: Some(( - ty: 22, + ty: 21, binding: None, )), local_variables: [], @@ -1628,25 +1529,23 @@ arguments: [ ( name: Some("a"), - ty: 25, + ty: 24, binding: None, ), ], result: Some(( - ty: 22, + ty: 21, binding: None, )), local_variables: [], expressions: [ FunctionArgument(0), - Literal(I32(4)), AccessIndex( base: 1, index: 4, ), - Literal(I32(9)), AccessIndex( - base: 3, + base: 2, index: 9, ), ], @@ -1655,15 +1554,15 @@ }, body: [ Emit(( - start: 2, - end: 3, + start: 1, + end: 2, )), Emit(( - start: 4, - end: 5, + start: 2, + end: 3, )), Return( - value: Some(5), + value: Some(3), ), ], ), @@ -1672,7 +1571,7 @@ arguments: [ ( name: Some("p"), - ty: 31, + ty: 27, binding: None, ), ], @@ -1700,7 +1599,7 @@ arguments: [ ( name: Some("foo"), - ty: 33, + ty: 29, binding: None, ), ], @@ -1719,7 +1618,7 @@ value: 4, ), Compose( - ty: 32, + ty: 28, components: [ 3, 5, @@ -1730,6 +1629,14 @@ 1: "foo", }, body: [ + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 2, end: 3, @@ -1764,7 +1671,7 @@ ), ], result: Some(( - ty: 26, + ty: 25, binding: Some(BuiltIn(Position( invariant: false, ))), @@ -1772,12 +1679,12 @@ local_variables: [ ( name: Some("foo"), - ty: 22, + ty: 21, init: Some(2), ), ( name: Some("c2"), - ty: 28, + ty: 26, init: None, ), ], @@ -1859,13 +1766,12 @@ base: 30, index: 5, ), - Literal(I32(0)), AccessIndex( base: 31, index: 0, ), AccessIndex( - base: 33, + base: 32, index: 0, ), CallResult(3), @@ -1878,13 +1784,13 @@ Literal(I32(4)), Literal(I32(5)), Compose( - ty: 28, + ty: 26, components: [ 27, + 35, 36, 37, 38, - 39, ], ), LocalVariable(2), @@ -1892,42 +1798,42 @@ Binary( op: Add, left: 1, - right: 42, + right: 41, ), Access( - base: 41, - index: 43, + base: 40, + index: 42, ), Literal(I32(42)), Access( - base: 41, + base: 40, index: 1, ), Load( - pointer: 46, + pointer: 45, ), - ZeroValue(25), + ZeroValue(24), CallResult(4), Splat( size: Quad, - value: 47, + value: 46, ), As( - expr: 50, + expr: 49, kind: Float, convert: Some(4), ), Binary( op: Multiply, left: 8, - right: 51, + right: 50, ), Literal(F32(2.0)), Compose( - ty: 26, + ty: 25, components: [ + 51, 52, - 53, ], ), ], @@ -1940,9 +1846,9 @@ 17: "b", 27: "a", 29: "c", - 34: "data_pointer", - 35: "foo_value", - 47: "value", + 33: "data_pointer", + 34: "foo_value", + 46: "value", }, body: [ Emit(( @@ -1996,57 +1902,57 @@ end: 31, )), Emit(( - start: 32, - end: 34, + start: 31, + end: 33, )), Call( function: 3, arguments: [ 3, ], - result: Some(35), + result: Some(34), ), Emit(( - start: 35, - end: 36, + start: 34, + end: 35, )), Emit(( - start: 39, - end: 40, + start: 38, + end: 39, )), Store( - pointer: 41, - value: 40, + pointer: 40, + value: 39, ), Emit(( - start: 42, - end: 44, + start: 41, + end: 43, )), Store( - pointer: 44, - value: 45, + pointer: 43, + value: 44, ), Emit(( - start: 45, - end: 47, + start: 44, + end: 46, )), Call( function: 4, arguments: [ - 48, + 47, ], - result: Some(49), + result: Some(48), ), Emit(( - start: 49, - end: 52, + start: 48, + end: 51, )), Emit(( - start: 53, - end: 54, + start: 52, + end: 53, )), Return( - value: Some(54), + value: Some(53), ), ], ), @@ -2060,7 +1966,7 @@ name: Some("foo_frag"), arguments: [], result: Some(( - ty: 26, + ty: 25, binding: Some(Location( location: 0, second_blend_source: false, @@ -2075,84 +1981,82 @@ base: 1, index: 0, ), - Literal(I32(1)), AccessIndex( base: 2, index: 1, ), AccessIndex( - base: 4, + base: 3, index: 2, ), Literal(F32(1.0)), GlobalVariable(2), AccessIndex( - base: 7, + base: 6, index: 0, ), Literal(F32(0.0)), Splat( size: Tri, - value: 9, + value: 8, ), Literal(F32(1.0)), Splat( size: Tri, - value: 11, + value: 10, ), Literal(F32(2.0)), Splat( size: Tri, - value: 13, + value: 12, ), Literal(F32(3.0)), Splat( size: Tri, - value: 15, + value: 14, ), Compose( ty: 6, components: [ - 10, - 12, - 14, - 16, + 9, + 11, + 13, + 15, ], ), GlobalVariable(2), AccessIndex( - base: 18, + base: 17, index: 4, ), Literal(U32(0)), Splat( size: Bi, - value: 20, + value: 19, ), Literal(U32(1)), Splat( size: Bi, - value: 22, + value: 21, ), Compose( ty: 12, components: [ - 21, - 23, + 20, + 22, ], ), GlobalVariable(2), AccessIndex( - base: 25, + base: 24, index: 5, ), - Literal(I32(1)), AccessIndex( - base: 26, + base: 25, index: 1, ), AccessIndex( - base: 28, + base: 26, index: 0, ), Literal(I32(1)), @@ -2161,7 +2065,7 @@ Literal(F32(0.0)), Splat( size: Quad, - value: 33, + value: 31, ), ], named_expressions: {}, @@ -2171,75 +2075,75 @@ end: 2, )), Emit(( - start: 3, - end: 5, + start: 2, + end: 4, )), Store( - pointer: 5, - value: 6, + pointer: 4, + value: 5, ), Emit(( - start: 7, - end: 8, + start: 6, + end: 7, )), Emit(( - start: 9, - end: 10, + start: 8, + end: 9, )), Emit(( - start: 11, - end: 12, + start: 10, + end: 11, )), Emit(( - start: 13, - end: 14, + start: 12, + end: 13, )), Emit(( - start: 15, - end: 17, + start: 14, + end: 16, )), Store( - pointer: 8, - value: 17, + pointer: 7, + value: 16, ), Emit(( - start: 18, - end: 19, + start: 17, + end: 18, )), Emit(( - start: 20, - end: 21, + start: 19, + end: 20, )), Emit(( - start: 22, - end: 24, + start: 21, + end: 23, )), Store( - pointer: 19, - value: 24, + pointer: 18, + value: 23, ), Emit(( - start: 25, - end: 26, + start: 24, + end: 25, )), Emit(( - start: 27, - end: 29, + start: 25, + end: 27, )), Store( - pointer: 29, - value: 30, + pointer: 27, + value: 28, ), Store( - pointer: 31, - value: 32, + pointer: 29, + value: 30, ), Emit(( - start: 33, - end: 34, + start: 31, + end: 32, )), Return( - value: Some(34), + value: Some(32), ), ], ), @@ -2261,7 +2165,7 @@ ), ( name: Some("arr"), - ty: 32, + ty: 28, init: Some(7), ), ], @@ -2279,7 +2183,7 @@ value: 5, ), Compose( - ty: 32, + ty: 28, components: [ 4, 6, @@ -2296,6 +2200,14 @@ ], result: None, ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 3, end: 4, diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index effde120a50..cfc3bfa0ee4 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -60,11 +60,7 @@ init: None, ), ], - const_expressions: [ - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(1)), - ], + const_expressions: [], functions: [ ( name: Some("collatz_iterations"), diff --git a/naga/tests/out/msl/abstract-types.msl b/naga/tests/out/msl/abstract-types.msl new file mode 100644 index 00000000000..a9c89f7a9c9 --- /dev/null +++ b/naga/tests/out/msl/abstract-types.msl @@ -0,0 +1,59 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct type_5 { + float inner[2]; +}; +struct S { + float f; + int i; + uint u; +}; +constant metal::uint2 xvupaiai = metal::uint2(42u, 43u); +constant metal::float2 xvfpaiai = metal::float2(44.0, 45.0); +constant metal::uint2 xvupuai = metal::uint2(42u, 43u); +constant metal::uint2 xvupaiu = metal::uint2(42u, 43u); +constant metal::uint2 xvuuai = metal::uint2(42u, 43u); +constant metal::uint2 xvuaiu = metal::uint2(42u, 43u); +constant metal::float2x2 xmfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::int2 ivispai = metal::int2(1); +constant metal::float2 ivfspaf = metal::float2(1.0); +constant metal::int2 ivis_ai = metal::int2(1); +constant metal::uint2 ivus_ai = metal::uint2(1u); +constant metal::float2 ivfs_ai = metal::float2(1.0); +constant metal::float2 ivfs_af = metal::float2(1.0); +constant type_5 iafafaf = type_5 {1.0, 2.0}; +constant type_5 iafaiai = type_5 {1.0, 2.0}; +constant type_5 iafpafaf = type_5 {1.0, 2.0}; +constant type_5 iafpaiaf = type_5 {1.0, 2.0}; +constant type_5 iafpafai = type_5 {1.0, 2.0}; +constant type_5 xafpafaf = type_5 {1.0, 2.0}; +constant S s_f_i_u = S {1.0, 1, 1u}; +constant S s_f_iai = S {1.0, 1, 1u}; +constant S s_fai_u = S {1.0, 1, 1u}; +constant S s_faiai = S {1.0, 1, 1u}; +constant S saf_i_u = S {1.0, 1, 1u}; +constant S saf_iai = S {1.0, 1, 1u}; +constant S safai_u = S {1.0, 1, 1u}; +constant S safaiai = S {1.0, 1, 1u}; +constant metal::float3 ivfr_f_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfr_f_af = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfraf_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfraf_af = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivf_fr_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_fraf = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_afr_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_afraf = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivfr_f_ai = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfrai_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfrai_ai = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivf_frai = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_air_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_airai = metal::float3(1.0, metal::float2(2.0, 3.0)); diff --git a/naga/tests/out/spv/abstract-types.spvasm b/naga/tests/out/spv/abstract-types.spvasm new file mode 100644 index 00000000000..207a04f5646 --- /dev/null +++ b/naga/tests/out/spv/abstract-types.spvasm @@ -0,0 +1,46 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 36 +OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpDecorate %10 ArrayStride 4 +OpMemberDecorate %12 0 Offset 0 +OpMemberDecorate %12 1 Offset 4 +OpMemberDecorate %12 2 Offset 8 +%2 = OpTypeVoid +%4 = OpTypeInt 32 0 +%3 = OpTypeVector %4 2 +%6 = OpTypeFloat 32 +%5 = OpTypeVector %6 2 +%7 = OpTypeMatrix %5 2 +%9 = OpTypeInt 32 1 +%8 = OpTypeVector %9 2 +%11 = OpConstant %4 2 +%10 = OpTypeArray %6 %11 +%12 = OpTypeStruct %6 %9 %4 +%13 = OpTypeVector %6 3 +%14 = OpConstant %4 42 +%15 = OpConstant %4 43 +%16 = OpConstantComposite %3 %14 %15 +%17 = OpConstant %6 44.0 +%18 = OpConstant %6 45.0 +%19 = OpConstantComposite %5 %17 %18 +%20 = OpConstant %6 1.0 +%21 = OpConstant %6 2.0 +%22 = OpConstantComposite %5 %20 %21 +%23 = OpConstant %6 3.0 +%24 = OpConstant %6 4.0 +%25 = OpConstantComposite %5 %23 %24 +%26 = OpConstantComposite %7 %22 %25 +%27 = OpConstant %9 1 +%28 = OpConstantComposite %8 %27 %27 +%29 = OpConstantComposite %5 %20 %20 +%30 = OpConstant %4 1 +%31 = OpConstantComposite %3 %30 %30 +%32 = OpConstantComposite %10 %20 %21 +%33 = OpConstantComposite %12 %20 %27 %30 +%34 = OpConstantComposite %13 %20 %21 %23 +%35 = OpConstantComposite %5 %21 %23 \ No newline at end of file diff --git a/naga/tests/out/spv/ray-query.spvasm b/naga/tests/out/spv/ray-query.spvasm index d96dbb315bc..23d5dd1baa1 100644 --- a/naga/tests/out/spv/ray-query.spvasm +++ b/naga/tests/out/spv/ray-query.spvasm @@ -66,10 +66,10 @@ OpMemberDecorate %18 0 Offset 0 %47 = OpConstantComposite %5 %27 %25 %27 %48 = OpConstant %4 4 %49 = OpConstant %4 255 -%50 = OpConstant %6 0.1 -%51 = OpConstant %6 100.0 -%52 = OpConstantComposite %5 %27 %27 %27 -%53 = OpConstantComposite %14 %48 %49 %50 %51 %52 %47 +%50 = OpConstantComposite %5 %27 %27 %27 +%51 = OpConstant %6 0.1 +%52 = OpConstant %6 100.0 +%53 = OpConstantComposite %14 %48 %49 %51 %52 %50 %47 %55 = OpTypePointer Function %13 %72 = OpConstant %4 1 %85 = OpTypePointer StorageBuffer %4 diff --git a/naga/tests/out/wgsl/abstract-types.wgsl b/naga/tests/out/wgsl/abstract-types.wgsl new file mode 100644 index 00000000000..4096f9ff025 --- /dev/null +++ b/naga/tests/out/wgsl/abstract-types.wgsl @@ -0,0 +1,52 @@ +struct S { + f: f32, + i: i32, + u: u32, +} + +const xvupaiai: vec2 = vec2(42u, 43u); +const xvfpaiai: vec2 = vec2(44.0, 45.0); +const xvupuai: vec2 = vec2(42u, 43u); +const xvupaiu: vec2 = vec2(42u, 43u); +const xvuuai: vec2 = vec2(42u, 43u); +const xvuaiu: vec2 = vec2(42u, 43u); +const xmfpaiaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpafaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiafaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiaiafai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiaiaiaf: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const ivispai: vec2 = vec2(1); +const ivfspaf: vec2 = vec2(1.0); +const ivis_ai: vec2 = vec2(1); +const ivus_ai: vec2 = vec2(1u); +const ivfs_ai: vec2 = vec2(1.0); +const ivfs_af: vec2 = vec2(1.0); +const iafafaf: array = array(1.0, 2.0); +const iafaiai: array = array(1.0, 2.0); +const iafpafaf: array = array(1.0, 2.0); +const iafpaiaf: array = array(1.0, 2.0); +const iafpafai: array = array(1.0, 2.0); +const xafpafaf: array = array(1.0, 2.0); +const s_f_i_u: S = S(1.0, 1, 1u); +const s_f_iai: S = S(1.0, 1, 1u); +const s_fai_u: S = S(1.0, 1, 1u); +const s_faiai: S = S(1.0, 1, 1u); +const saf_i_u: S = S(1.0, 1, 1u); +const saf_iai: S = S(1.0, 1, 1u); +const safai_u: S = S(1.0, 1, 1u); +const safaiai: S = S(1.0, 1, 1u); +const ivfr_f_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfr_f_af: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfraf_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfraf_af: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivf_fr_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_fraf: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_afr_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_afraf: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivfr_f_ai: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfrai_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfrai_ai: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivf_frai: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_air_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_airai: vec3 = vec3(1.0, vec2(2.0, 3.0)); + diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 4ad17f1a2a4..80cf87dc1e3 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -783,6 +783,10 @@ fn convert_wgsl() { "f64", Targets::SPIRV | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "abstract-types", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 99257457fb7..21785aaf3be 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -502,7 +502,7 @@ fn let_type_mismatch() { r#" const x: i32 = 1.0; "#, - r#"error: the type of `x` is expected to be `i32`, but got `f32` + r#"error: the type of `x` is expected to be `i32`, but got `{AbstractFloat}` ┌─ wgsl:2:19 │ 2 │ const x: i32 = 1.0; @@ -1996,9 +1996,12 @@ fn binding_array_non_struct() { #[test] fn compaction_preserves_spans() { let source = r#" - const a: i32 = -(-(-(-42i))); - const b: vec2 = vec2(42u, 43i); - "#; // ^^^^^^^^^^^^^^^^^^^ correct error span: 68..87 + fn f() { + var a: i32 = -(-(-(-42i))); + var x: i32; + x = 42u; + } + "#; // ^^^ correct error span: 95..98 let mut module = naga::front::wgsl::parse_str(source).expect("source ought to parse"); naga::compact::compact(&mut module); let err = naga::valid::Validator::new( @@ -2011,10 +2014,18 @@ fn compaction_preserves_spans() { // Ideally this would all just be a `matches!` with a big pattern, // but the `Span` API is full of opaque structs. let mut spans = err.spans(); - let first_span = spans.next().expect("error should have at least one span").0; + + // The first span is the whole function. + let _ = spans.next().expect("error should have at least one span"); + + // The second span is the assignment destination. + let dest_span = spans + .next() + .expect("error should have at least two spans") + .0; if !matches!( - first_span.to_range(), - Some(std::ops::Range { start: 68, end: 87 }) + dest_span.to_range(), + Some(std::ops::Range { start: 95, end: 98 }) ) { panic!("Error message has wrong span:\n\n{err:#?}"); }