diff --git a/CHANGELOG.md b/CHANGELOG.md index f12dd95f7f..00450c43a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,8 +102,44 @@ Passing an owned value `window` to `Surface` will return a `Surface<'static>`. S - Introduce a new `Scalar` struct type for use in Naga's IR, and update all frontend, middle, and backend code appropriately. By @jimblandy in [#4673](https://github.com/gfx-rs/wgpu/pull/4673). - Add more metal keywords. By @fornwall in [#4707](https://github.com/gfx-rs/wgpu/pull/4707). -- Implement WGSL abstract types (by @jimblandy): - - Add a new `naga::Literal` variant, `I64`, for signed 64-bit literals. [#4711](https://github.com/gfx-rs/wgpu/pull/4711) +- Add partial support for WGSL abstract types (@jimblandy in [#4743](https://github.com/gfx-rs/wgpu/pull/4743)). + + Abstract types make numeric literals easier to use, by + automatically converting literals and other constant expressions + from abstract numeric types to concrete types when safe and + necessary. For example, to build a vector of floating-point + numbers, Naga previously made you write: + + vec3(1.0, 2.0, 3.0) + + With this change, you can now simply write: + + vec3(1, 2, 3) + + Even though the literals are abstract integers, Naga recognizes + that it is safe and necessary to convert them to `f32` values in + order to build the vector. You can also use abstract values as + initializers for global constants, like this: + + const unit_x: vec2 = vec2(1, 0); + + The literals `1` and `0` are abstract integers, and the expression + `vec2(1, 0)` is an abstract vector. However, Naga recognizes that + it can convert that to the concrete type `vec2` to satisfy + the given type of `unit_x`. + + The WGSL specification permits abstract integers and + floating-point values in almost all contexts, but Naga's support + for this is still incomplete. Many WGSL operators and builtin + functions are specified to produce abstract results when applied + to abstract inputs, but for now Naga simply concretizes them all + before applying the operation. We will expand Naga's abstract type + support in subsequent pull requests. + + As part of this work, the public types `naga::ScalarKind` and + `naga::Literal` now have new variants, `AbstractInt` and `AbstractFloat`. + +- Add a new `naga::Literal` variant, `I64`, for signed 64-bit literals. [#4711](https://github.com/gfx-rs/wgpu/pull/4711) - Emit and init `struct` member padding always. By @ErichDonGubler in [#4701](https://github.com/gfx-rs/wgpu/pull/4701). diff --git a/naga/Cargo.toml b/naga/Cargo.toml index 20b5bd5d25..0cacc493ac 100644 --- a/naga/Cargo.toml +++ b/naga/Cargo.toml @@ -31,7 +31,7 @@ deserialize = ["serde", "bitflags/serde", "indexmap/serde"] arbitrary = ["dep:arbitrary", "bitflags/arbitrary", "indexmap/arbitrary"] spv-in = ["petgraph", "spirv"] spv-out = ["spirv"] -wgsl-in = ["hexf-parse", "unicode-xid"] +wgsl-in = ["hexf-parse", "unicode-xid", "compact"] wgsl-out = [] hlsl-out = [] compact = [] diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index cd3075f70a..e1dc906630 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2451,6 +2451,11 @@ impl<'a, W: Write> Writer<'a, W> { crate::Literal::I64(_) => { return Err(Error::Custom("GLSL has no 64-bit integer type".into())); } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } } } Expression::Constant(handle) => { @@ -3555,6 +3560,9 @@ impl<'a, W: Write> Writer<'a, W> { (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => { write!(self.out, "bool")? } + + (Sk::AbstractInt | Sk::AbstractFloat, _, _) + | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(), }; write!(self.out, "(")?; @@ -4117,6 +4125,11 @@ impl<'a, W: Write> Writer<'a, W> { crate::ScalarKind::Uint => write!(self.out, "0u")?, crate::ScalarKind::Float => write!(self.out, "0.0")?, crate::ScalarKind::Sint => write!(self.out, "0")?, + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".to_string(), + )) + } } Ok(()) @@ -4345,6 +4358,9 @@ const fn glsl_scalar(scalar: crate::Scalar) -> Result, Err prefix: "b", full: "bool", }, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::UnsupportedScalar(scalar)); + } }) } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index da17c35704..b6918ddc42 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -10,7 +10,7 @@ impl crate::ScalarKind { Self::Float => "asfloat", Self::Sint => "asint", Self::Uint => "asuint", - Self::Bool => unreachable!(), + Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(), } } } @@ -30,6 +30,9 @@ impl crate::Scalar { _ => Err(Error::UnsupportedScalar(self)), }, crate::ScalarKind::Bool => Ok("bool"), + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + Err(Error::UnsupportedScalar(self)) + } } } } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 24d54fc0e5..0dd60c6ad7 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2040,6 +2040,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::I64(value) => write!(self.out, "{}L", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } }, Expression::Constant(handle) => { let constant = &module.constants[handle]; diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 17154c3cd5..f900add71e 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -338,6 +338,10 @@ impl crate::Scalar { kind: Sk::Bool, width: _, } => "bool", + Self { + kind: Sk::AbstractInt | Sk::AbstractFloat, + width: _, + } => unreachable!(), } } } @@ -1275,6 +1279,9 @@ impl Writer { crate::Literal::Bool(value) => { write!(self.out, "{value}")?; } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Validation); + } }, crate::Expression::Constant(handle) => { let constant = &module.constants[handle]; diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index df6ecd00ff..84f8581521 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1175,7 +1175,7 @@ impl<'w> BlockContext<'w> { let op = match src_scalar.kind { Sk::Sint | Sk::Uint => spirv::Op::INotEqual, Sk::Float => spirv::Op::FUnordNotEqual, - Sk::Bool => unreachable!(), + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(), }; let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?; diff --git a/naga/src/back/spv/image.rs b/naga/src/back/spv/image.rs index fb9d44e7f0..460c906d47 100644 --- a/naga/src/back/spv/image.rs +++ b/naga/src/back/spv/image.rs @@ -334,6 +334,10 @@ impl<'w> BlockContext<'w> { (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => { unreachable!("we don't allow bool or float for array index") } + (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _) + | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => { + unreachable!("abstract types should never reach backends") + } }; let reconciled_array_index_id = if let Some(cast) = cast { let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value { diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index ef0532b2ea..4db86c93a7 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -824,6 +824,9 @@ impl Writer { Instruction::type_float(id, bits) } Sk::Bool => Instruction::type_bool(id), + Sk::AbstractInt | Sk::AbstractFloat => { + unreachable!("abstract types should never reach the backend"); + } } } @@ -1184,6 +1187,9 @@ impl Writer { } crate::Literal::Bool(true) => Instruction::constant_true(type_id, id), crate::Literal::Bool(false) => Instruction::constant_false(type_id, id), + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + unreachable!("Abstract types should not appear in IR presented to backends"); + } }; instruction.to_words(&mut self.logical_layout.declarations); @@ -1591,6 +1597,11 @@ impl Writer { | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Bool => true, Sk::Float => false, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::Validation( + "Abstract types should not appear in IR presented to backends", + )) + } }, _ => false, }; diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 10da339968..a356c7e2ad 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1099,6 +1099,11 @@ impl Writer { crate::Literal::I64(_) => { return Err(Error::Custom("unsupported i64 literal".to_string())); } + crate::Literal::AbstractInt(_) | crate::Literal::AbstractFloat(_) => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".into(), + )); + } } } Expression::Constant(handle) => { diff --git a/naga/src/front/glsl/types.rs b/naga/src/front/glsl/types.rs index f0a2705ad2..e87d76fffc 100644 --- a/naga/src/front/glsl/types.rs +++ b/naga/src/front/glsl/types.rs @@ -205,7 +205,7 @@ pub const fn type_power(scalar: Scalar) -> Option { ScalarKind::Uint => 1, ScalarKind::Float if scalar.width == 4 => 2, ScalarKind::Float => 3, - ScalarKind::Bool => return None, + ScalarKind::Bool | ScalarKind::AbstractInt | ScalarKind::AbstractFloat => return None, }) } diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index f5acbe2d65..dc10124680 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -251,6 +251,12 @@ pub enum Error<'a> { ExpectedPositiveArrayLength(Span), MissingWorkgroupSize(Span), ConstantEvaluatorError(ConstantEvaluatorError, Span), + AutoConversion { + dest_span: Span, + dest_type: String, + source_span: Span, + source_type: String, + }, } impl<'a> Error<'a> { @@ -712,6 +718,20 @@ impl<'a> Error<'a> { )], notes: vec![], }, + Error::AutoConversion { dest_span, ref dest_type, source_span, ref source_type } => ParseError { + message: format!("automatic conversions cannot convert `{source_type}` to `{dest_type}`"), + labels: vec![ + ( + dest_span, + format!("a value of type {dest_type} is required here").into(), + ), + ( + source_span, + format!("this expression has type {source_type}").into(), + ) + ], + notes: vec![], + } } } } diff --git a/naga/src/front/wgsl/lower/construction.rs b/naga/src/front/wgsl/lower/construction.rs index fafda793c0..c7e4106460 100644 --- a/naga/src/front/wgsl/lower/construction.rs +++ b/naga/src/front/wgsl/lower/construction.rs @@ -5,6 +5,7 @@ use crate::{Handle, Span}; use crate::front::wgsl::error::Error; use crate::front::wgsl::lower::{ExpressionContext, Lowerer}; +use crate::front::wgsl::Scalar; /// A cooked form of `ast::ConstructorType` that uses Naga types whenever /// possible. @@ -80,7 +81,6 @@ enum Components<'a> { Many { components: Vec>, spans: Vec, - first_component_ty_inner: &'a crate::TypeInner, }, } @@ -116,13 +116,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { components: &[Handle>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { + use crate::proc::TypeResolution as Tr; + let constructor_h = self.constructor(constructor, ctx)?; let components = match *components { [] => Components::None, [component] => { let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; + let component = self.expression_for_abstract(component, ctx)?; let ty_inner = super::resolve_inner!(ctx, component); Components::One { @@ -131,30 +133,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ty_inner, } } - [component, ref rest @ ..] => { - let span = ctx.ast_expressions.get_span(component); - let component = self.expression(component, ctx)?; - - let components = std::iter::once(Ok(component)) - .chain( - rest.iter() - .map(|&component| self.expression(component, ctx)), - ) + ref ast_components @ [_, _, ..] => { + let components = ast_components + .iter() + .map(|&expr| self.expression_for_abstract(expr, ctx)) .collect::>()?; - let spans = std::iter::once(span) - .chain( - rest.iter() - .map(|&component| ctx.ast_expressions.get_span(component)), - ) + let spans = ast_components + .iter() + .map(|&expr| ctx.ast_expressions.get_span(expr)) .collect(); - let first_component_ty_inner = super::resolve_inner!(ctx, component); - - Components::Many { - components, - spans, - first_component_ty_inner, + for &component in &components { + ctx.grow_types(component)?; } + + Components::Many { components, spans } } }; @@ -163,7 +156,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { // above can have mutable access to the type arena. let constructor = constructor_h.borrow_inner(ctx.module); - let expr = match (components, constructor) { + let expr; + match (components, constructor) { // Empty constructor (Components::None, dst_ty) => match dst_ty { Constructor::Type((result_ty, _)) => { @@ -186,11 +180,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .. }, Constructor::Type((_, &crate::TypeInner::Scalar(scalar))), - ) => crate::Expression::As { - expr: component, - kind: scalar.kind, - convert: Some(scalar.width), - }, + ) => { + expr = crate::Expression::As { + expr: component, + kind: scalar.kind, + convert: Some(scalar.width), + }; + } // Vector conversion (vector -> vector) ( @@ -206,11 +202,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { scalar: dst_scalar, }, )), - ) if dst_size == src_size => crate::Expression::As { - expr: component, - kind: dst_scalar.kind, - convert: Some(dst_scalar.width), - }, + ) if dst_size == src_size => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } // Vector conversion (vector -> vector) - partial ( @@ -247,11 +245,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { scalar: dst_scalar, }, )), - ) if dst_columns == src_columns && dst_rows == src_rows => crate::Expression::As { - expr: component, - kind: crate::ScalarKind::Float, - convert: Some(dst_scalar.width), - }, + ) if dst_columns == src_columns && dst_rows == src_rows => { + expr = crate::Expression::As { + expr: component, + kind: dst_scalar.kind, + convert: Some(dst_scalar.width), + }; + } // Matrix conversion (matrix -> matrix) - partial ( @@ -284,69 +284,99 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .. }, Constructor::PartialVector { size }, - ) => crate::Expression::Splat { - size, - value: component, - }, + ) => { + expr = crate::Expression::Splat { + size, + value: component, + }; + } // Vector constructor (splat) ( Components::One { - component, - ty_inner: &crate::TypeInner::Scalar(src_scalar), + mut component, + ty_inner: &crate::TypeInner::Scalar(_), .. }, - Constructor::Type(( - _, - &crate::TypeInner::Vector { - size, - scalar: dst_scalar, - }, - )), - ) if dst_scalar == src_scalar => crate::Expression::Splat { - size, - value: component, - }, + Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), + ) => { + ctx.convert_slice_to_common_scalar(std::slice::from_mut(&mut component), scalar)?; + expr = crate::Expression::Splat { + size, + value: component, + }; + } - // Vector constructor (by elements) + // Vector constructor (by elements), partial ( Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar(scalar) | &crate::TypeInner::Vector { scalar, .. }, - .. + mut components, + spans, }, Constructor::PartialVector { size }, - ) - | ( - Components::Many { - components, - first_component_ty_inner: - &crate::TypeInner::Scalar { .. } | &crate::TypeInner::Vector { .. }, - .. - }, - Constructor::Type((_, &crate::TypeInner::Vector { size, scalar })), ) => { - let inner = crate::TypeInner::Vector { size, scalar }; + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + let inner = consensus_scalar.to_inner_vector(size); let ty = ctx.ensure_type_exists(inner); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; + } + + // Vector constructor (by elements), full type given + ( + Components::Many { mut components, .. }, + Constructor::Type((ty, &crate::TypeInner::Vector { scalar, .. })), + ) => { + ctx.try_automatic_conversions_for_vector(&mut components, scalar, ty_span)?; + expr = crate::Expression::Compose { ty, components }; } - // Matrix constructor (by elements) + // Matrix constructor (by elements), partial ( Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar(scalar), - .. + mut components, + spans, }, Constructor::PartialMatrix { columns, rows }, - ) - | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Scalar { .. }, - .. - }, + ) if components.len() == columns as usize * rows as usize => { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + // We actually only accept floating-point elements. + let consensus_scalar = consensus_scalar + .automatic_conversion_combine(crate::Scalar::ABSTRACT_FLOAT) + .unwrap_or(consensus_scalar); + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + let vec_ty = ctx.ensure_type_exists(consensus_scalar.to_inner_vector(rows)); + + let components = components + .chunks(rows as usize) + .map(|vec_components| { + ctx.append_expression( + crate::Expression::Compose { + ty: vec_ty, + components: Vec::from(vec_components), + }, + Default::default(), + ) + }) + .collect::, _>>()?; + + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by elements), type given + ( + Components::Many { mut components, .. }, Constructor::Type(( _, &crate::TypeInner::Matrix { @@ -355,9 +385,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { scalar, }, )), - ) => { - let vec_ty = - ctx.ensure_type_exists(crate::TypeInner::Vector { scalar, size: rows }); + ) if components.len() == columns as usize * rows as usize => { + let element = Tr::Value(crate::TypeInner::Scalar(scalar)); + ctx.try_automatic_conversions_slice(&mut components, &element, ty_span)?; + let vec_ty = ctx.ensure_type_exists(scalar.to_inner_vector(rows)); let components = components .chunks(rows as usize) @@ -377,44 +408,74 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rows, scalar, }); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; } - // Matrix constructor (by columns) + // Matrix constructor (by columns), partial ( Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { scalar, .. }, - .. + mut components, + spans, }, Constructor::PartialMatrix { columns, rows }, - ) - | ( - Components::Many { - components, - first_component_ty_inner: &crate::TypeInner::Vector { .. }, - .. - }, + ) => { + let consensus_scalar = + automatic_conversion_consensus(&components, ctx).map_err(|index| { + Error::InvalidConstructorComponentType(spans[index], index as i32) + })?; + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { + columns, + rows, + scalar: consensus_scalar, + }); + expr = crate::Expression::Compose { ty, components }; + } + + // Matrix constructor (by columns), type given + ( + Components::Many { mut components, .. }, Constructor::Type(( - _, + ty, &crate::TypeInner::Matrix { - columns, + columns: _, rows, scalar, }, )), ) => { - let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix { - columns, - rows, - scalar, - }); - crate::Expression::Compose { ty, components } + let component_ty = crate::TypeInner::Vector { size: rows, scalar }; + ctx.try_automatic_conversions_slice( + &mut components, + &Tr::Value(component_ty), + span, + )?; + expr = crate::Expression::Compose { ty, components }; } // Array constructor - infer type (components, Constructor::PartialArray) => { - let components = components.into_components_vec(); + let mut components = components.into_components_vec(); + if let Ok(consensus_scalar) = automatic_conversion_consensus(&components, ctx) { + // Note that this will *not* necessarily convert all the + // components to the same type! The `automatic_conversion_consensus` + // function only considers the parameters' leaf scalar + // types; the parameters themselves could be any mix of + // vectors, matrices, and scalars. + // + // But *if* it is possible for this array construction + // expression to be well-typed at all, then all the + // parameters must have the same type constructors (vec, + // matrix, scalar) applied to their leaf scalars, so + // reconciling their scalars is always the right thing to + // do. And if this array construction is not well-typed, + // these conversions will not make it so, and we can let + // validation catch the error. + ctx.convert_slice_to_common_scalar(&mut components, consensus_scalar)?; + } else { + // There's no consensus scalar. Emit the `Compose` + // expression anyway, and let validation catch the problem. + } let base = ctx.register_type(components[0])?; @@ -430,19 +491,34 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; let ty = ctx.ensure_type_exists(inner); - crate::Expression::Compose { ty, components } + expr = crate::Expression::Compose { ty, components }; + } + + // Array constructor, explicit type + (components, Constructor::Type((ty, &crate::TypeInner::Array { base, .. }))) => { + let mut components = components.into_components_vec(); + ctx.try_automatic_conversions_slice(&mut components, &Tr::Handle(base), span)?; + expr = crate::Expression::Compose { ty, components }; } - // Array or Struct constructor + // Struct constructor ( components, - Constructor::Type(( - ty, - &crate::TypeInner::Array { .. } | &crate::TypeInner::Struct { .. }, - )), + Constructor::Type((ty, &crate::TypeInner::Struct { ref members, .. })), ) => { - let components = components.into_components_vec(); - crate::Expression::Compose { ty, components } + let mut components = components.into_components_vec(); + let struct_ty_span = ctx.module.types.get_span(ty); + + // Make a vector of the members' type handles in advance, to + // avoid borrowing `members` from `ctx` while we generate + // new code. + let members: Vec> = members.iter().map(|m| m.ty).collect(); + + for (component, &ty) in components.iter_mut().zip(&members) { + *component = + ctx.try_automatic_conversions(*component, &Tr::Handle(ty), struct_ty_span)?; + } + expr = crate::Expression::Compose { ty, components }; } // ERRORS @@ -466,22 +542,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Err(Error::UnexpectedComponents(span)); } - // Parameters are of the wrong type for vector or matrix constructor - ( - Components::Many { spans, .. }, - Constructor::Type(( - _, - &crate::TypeInner::Vector { .. } | &crate::TypeInner::Matrix { .. }, - )) - | Constructor::PartialVector { .. } - | Constructor::PartialMatrix { .. }, - ) => { - return Err(Error::InvalidConstructorComponentType(spans[0], 0)); - } - // Other types can't be constructed _ => return Err(Error::TypeNotConstructible(ty_span)), - }; + } let expr = ctx.append_expression(expr, span)?; Ok(expr) @@ -546,3 +609,50 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(handle) } } + +/// Find the consensus scalar of `components` under WGSL's automatic +/// conversions. +/// +/// If `components` can all be converted to any common scalar via +/// WGSL's automatic conversions, return the best such scalar. +/// +/// The `components` slice must not be empty. All elements' types must +/// have been resolved. +/// +/// If `components` are definitely not acceptable as arguments to such +/// constructors, return `Err(i)`, where `i` is the index in +/// `components` of some problematic argument. +/// +/// This function doesn't fully type-check the arguments - it only +/// considers their leaf scalar types. This means it may return `Ok` +/// even when the Naga validator will reject the resulting +/// construction expression later. +fn automatic_conversion_consensus( + components: &[Handle], + ctx: &ExpressionContext<'_, '_, '_>, +) -> Result { + let types = &ctx.module.types; + let mut inners = components + .iter() + .map(|&c| ctx.typifier()[c].inner_with(types)); + log::debug!( + "wgsl automatic_conversion_consensus: {:?}", + inners + .clone() + .map(|inner| inner.to_wgsl(&ctx.module.to_ctx())) + .collect::>() + ); + let mut best = inners.next().unwrap().scalar().ok_or(0_usize)?; + for (inner, i) in inners.zip(1..) { + let scalar = inner.scalar().ok_or(i)?; + match best.automatic_conversion_combine(scalar) { + Some(new_best) => { + best = new_best; + } + None => return Err(i), + } + } + + log::debug!(" consensus: {:?}", best.to_wgsl()); + Ok(best) +} diff --git a/naga/src/front/wgsl/lower/conversion.rs b/naga/src/front/wgsl/lower/conversion.rs new file mode 100644 index 0000000000..0819c13a66 --- /dev/null +++ b/naga/src/front/wgsl/lower/conversion.rs @@ -0,0 +1,365 @@ +//! WGSL's automatic conversions for abstract types. + +use crate::{Handle, Span}; + +impl<'source, 'temp, 'out> super::ExpressionContext<'source, 'temp, 'out> { + /// Try to use WGSL's automatic conversions to convert `expr` to `goal_ty`. + /// + /// If no conversions are necessary, return `expr` unchanged. + /// + /// If automatic conversions cannot convert `expr` to `goal_ty`, return an + /// [`AutoConversion`] error. + /// + /// Although the Load Rule is one of the automatic conversions, this + /// function assumes it has already been applied if appropriate, as + /// indicated by the fact that the Rust type of `expr` is not `Typed<_>`. + /// + /// [`AutoConversion`]: super::Error::AutoConversion + pub fn try_automatic_conversions( + &mut self, + expr: Handle, + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result, super::Error<'source>> { + let expr_span = self.get_expression_span(expr); + // Keep the TypeResolution so we can get type names for + // structs in error messages. + let expr_resolution = super::resolve!(self, expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + let goal_inner = goal_ty.inner_with(types); + + // If `expr` already has the requested type, we're done. + if expr_inner.equivalent(goal_inner, types) { + return Ok(expr); + } + + let (_expr_scalar, goal_scalar) = + match expr_inner.automatically_converts_to(goal_inner, types) { + Some(scalars) => scalars, + None => { + let gctx = &self.module.to_ctx(); + let source_type = expr_resolution.to_wgsl(gctx); + let dest_type = goal_ty.to_wgsl(gctx); + + return Err(super::Error::AutoConversion { + dest_span: goal_span, + dest_type, + source_span: expr_span, + source_type, + }); + } + }; + + let converted = if let crate::TypeInner::Array { .. } = *goal_inner { + let span = self.get_expression_span(expr); + self.as_const_evaluator() + .cast_array(expr, goal_scalar, span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, span))? + } else { + let cast = crate::Expression::As { + expr, + kind: goal_scalar.kind, + convert: Some(goal_scalar.width), + }; + self.append_expression(cast, expr_span)? + }; + + Ok(converted) + } + + /// Try to convert `exprs` to `goal_ty` using WGSL's automatic conversions. + pub fn try_automatic_conversions_slice( + &mut self, + exprs: &mut [Handle], + goal_ty: &crate::proc::TypeResolution, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + *expr = self.try_automatic_conversions(*expr, goal_ty, goal_span)?; + } + + Ok(()) + } + + /// Apply WGSL's automatic conversions to a vector constructor's arguments. + /// + /// When calling a vector constructor like `vec3(...)`, the parameters + /// can be a mix of scalars and vectors, with the latter being spread out to + /// contribute each of their components as a component of the new value. + /// When the element type is explicit, as with `` in the example above, + /// WGSL's automatic conversions should convert abstract scalar and vector + /// parameters to the constructor's required scalar type. + pub fn try_automatic_conversions_for_vector( + &mut self, + exprs: &mut [Handle], + goal_scalar: crate::Scalar, + goal_span: Span, + ) -> Result<(), super::Error<'source>> { + use crate::proc::TypeResolution as Tr; + use crate::TypeInner as Ti; + let goal_scalar_res = Tr::Value(Ti::Scalar(goal_scalar)); + + for (i, expr) in exprs.iter_mut().enumerate() { + // Keep the TypeResolution so we can get full type names + // in error messages. + let expr_resolution = super::resolve!(self, *expr); + let types = &self.module.types; + let expr_inner = expr_resolution.inner_with(types); + + match *expr_inner { + Ti::Scalar(_) => { + *expr = self.try_automatic_conversions(*expr, &goal_scalar_res, goal_span)?; + } + Ti::Vector { size, scalar: _ } => { + let goal_vector_res = Tr::Value(Ti::Vector { + size, + scalar: goal_scalar, + }); + *expr = self.try_automatic_conversions(*expr, &goal_vector_res, goal_span)?; + } + _ => { + let span = self.get_expression_span(*expr); + return Err(super::Error::InvalidConstructorComponentType( + span, i as i32, + )); + } + } + } + + Ok(()) + } + + /// Convert all expressions in `exprs` to a common scalar type. + /// + /// Note that the caller is responsible for making sure these + /// conversions are actually justified. This function simply + /// generates `As` expressions, regardless of whether they are + /// permitted WGSL automatic conversions. Callers intending to + /// implement automatic conversions need to determine for + /// themselves whether the casts we we generate are justified, + /// perhaps by calling `TypeInner::automatically_converts_to` or + /// `Scalar::automatic_conversion_combine`. + pub fn convert_slice_to_common_scalar( + &mut self, + exprs: &mut [Handle], + goal: crate::Scalar, + ) -> Result<(), super::Error<'source>> { + for expr in exprs.iter_mut() { + let inner = super::resolve_inner!(self, *expr); + // Do nothing if `inner` doesn't even have leaf scalars; + // it's a type error that validation will catch. + if inner.scalar() != Some(goal) { + let cast = crate::Expression::As { + expr: *expr, + kind: goal.kind, + convert: Some(goal.width), + }; + let expr_span = self.get_expression_span(*expr); + *expr = self.append_expression(cast, expr_span)?; + } + } + + Ok(()) + } + + /// Return an expression for the concretized value of `expr`. + /// + /// If `expr` is already concrete, return it unchanged. + pub fn concretize( + &mut self, + mut expr: Handle, + ) -> Result, super::Error<'source>> { + let inner = super::resolve_inner!(self, expr); + if let Some(scalar) = inner.automatically_convertible_scalar(&self.module.types) { + let concretized = scalar.concretize(); + if concretized != scalar { + let span = self.get_expression_span(expr); + expr = self + .as_const_evaluator() + .cast_array(expr, concretized, span) + .map_err(|err| super::Error::ConstantEvaluatorError(err, span))?; + } + } + + Ok(expr) + } +} + +impl crate::TypeInner { + /// Determine whether `self` automatically converts to `goal`. + /// + /// If WGSL's automatic conversions (excluding the Load Rule) will + /// convert `self` to `goal`, then return a pair `(from, to)`, + /// where `from` and `to` are the scalar types of the leaf values + /// of `self` and `goal`. + /// + /// This function assumes that `self` and `goal` are different + /// types. Callers should first check whether any conversion is + /// needed at all. + /// + /// If the automatic conversions cannot convert `self` to `goal`, + /// return `None`. + fn automatically_converts_to( + &self, + goal: &Self, + types: &crate::UniqueArena, + ) -> Option<(crate::Scalar, crate::Scalar)> { + use crate::ScalarKind as Sk; + use crate::TypeInner as Ti; + + // Automatic conversions only change the scalar type of a value's leaves + // (e.g., `vec4` to `vec4`), never the type + // constructors applied to those scalar types (e.g., never scalar to + // `vec4`, or `vec2` to `vec3`). So first we check that the type + // constructors match, extracting the leaf scalar types in the process. + let expr_scalar; + let goal_scalar; + match (self, goal) { + (&Ti::Scalar(expr), &Ti::Scalar(goal)) => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Vector { + size: expr_size, + scalar: expr, + }, + &Ti::Vector { + size: goal_size, + scalar: goal, + }, + ) if expr_size == goal_size => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Matrix { + rows: expr_rows, + columns: expr_columns, + scalar: expr, + }, + &Ti::Matrix { + rows: goal_rows, + columns: goal_columns, + scalar: goal, + }, + ) if expr_rows == goal_rows && expr_columns == goal_columns => { + expr_scalar = expr; + goal_scalar = goal; + } + ( + &Ti::Array { + base: expr_base, + size: expr_size, + stride: _, + }, + &Ti::Array { + base: goal_base, + size: goal_size, + stride: _, + }, + ) if expr_size == goal_size => { + return types[expr_base] + .inner + .automatically_converts_to(&types[goal_base].inner, types); + } + _ => return None, + } + + match (expr_scalar.kind, goal_scalar.kind) { + (Sk::AbstractFloat, Sk::Float) => {} + (Sk::AbstractInt, Sk::Sint | Sk::Uint | Sk::AbstractFloat | Sk::Float) => {} + _ => return None, + } + + log::trace!(" okay: expr {expr_scalar:?}, goal {goal_scalar:?}"); + Some((expr_scalar, goal_scalar)) + } + + fn automatically_convertible_scalar( + &self, + types: &crate::UniqueArena, + ) -> Option { + use crate::TypeInner as Ti; + match *self { + Ti::Scalar(scalar) | Ti::Vector { scalar, .. } | Ti::Matrix { scalar, .. } => { + Some(scalar) + } + Ti::Array { base, .. } => types[base].inner.automatically_convertible_scalar(types), + Ti::Atomic(_) + | Ti::Pointer { .. } + | Ti::ValuePointer { .. } + | Ti::Struct { .. } + | Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => None, + } + } +} + +impl crate::Scalar { + /// Find the common type of `self` and `other` under WGSL's + /// automatic conversions. + /// + /// If there are any scalars to which WGSL's automatic conversions + /// will convert both `self` and `other`, return the best such + /// scalar. Otherwise, return `None`. + pub const fn automatic_conversion_combine(self, other: Self) -> Option { + use crate::ScalarKind as Sk; + + match (self.kind, other.kind) { + // When the kinds match... + (Sk::AbstractFloat, Sk::AbstractFloat) + | (Sk::AbstractInt, Sk::AbstractInt) + | (Sk::Sint, Sk::Sint) + | (Sk::Uint, Sk::Uint) + | (Sk::Float, Sk::Float) + | (Sk::Bool, Sk::Bool) => { + if self.width == other.width { + // ... either no conversion is necessary ... + Some(self) + } else { + // ... or no conversion is possible. + // We never convert concrete to concrete, and + // abstract types should have only one size. + None + } + } + + // AbstractInt converts to AbstractFloat. + (Sk::AbstractFloat, Sk::AbstractInt) => Some(self), + (Sk::AbstractInt, Sk::AbstractFloat) => Some(other), + + // AbstractFloat converts to Float. + (Sk::AbstractFloat, Sk::Float) => Some(other), + (Sk::Float, Sk::AbstractFloat) => Some(self), + + // AbstractInt converts to concrete integer or float. + (Sk::AbstractInt, Sk::Uint | Sk::Sint | Sk::Float) => Some(other), + (Sk::Uint | Sk::Sint | Sk::Float, Sk::AbstractInt) => Some(self), + + // AbstractFloat can't be reconciled with concrete integer types. + (Sk::AbstractFloat, Sk::Uint | Sk::Sint) | (Sk::Uint | Sk::Sint, Sk::AbstractFloat) => { + None + } + + // Nothing can be reconciled with `bool`. + (Sk::Bool, _) | (_, Sk::Bool) => None, + + // Different concrete types cannot be reconciled. + (Sk::Sint | Sk::Uint | Sk::Float, Sk::Sint | Sk::Uint | Sk::Float) => None, + } + } + + const fn concretize(self) -> Self { + use crate::ScalarKind as Sk; + match self.kind { + Sk::Sint | Sk::Uint | Sk::Float | Sk::Bool => self, + Sk::AbstractInt => Self::I32, + Sk::AbstractFloat => Self::F32, + } + } +} diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index a727d6379b..b050ffc343 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -11,6 +11,7 @@ use crate::proc::{ use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; mod construction; +mod conversion; /// Resolves the inner type of a given expression. /// @@ -66,6 +67,7 @@ macro_rules! resolve { &$ctx.typifier()[$expr] }}; } +pub(super) use resolve; /// State for constructing a `crate::Module`. pub struct GlobalContext<'source, 'temp, 'out> { @@ -903,29 +905,39 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::GlobalDeclKind::Const(ref c) => { let mut ectx = ctx.as_const(); - let init = self.expression(c.init, &mut ectx)?; - let inferred_type = ectx.register_type(init)?; - - let explicit_ty = - c.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx)) - .transpose()?; - - if let Some(explicit) = explicit_ty { - if explicit != inferred_type { - let gctx = ctx.module.to_ctx(); - return Err(Error::InitializationTypeMismatch { - name: c.name.span, - expected: explicit.to_wgsl(&gctx), - got: inferred_type.to_wgsl(&gctx), - }); - } + let mut init = self.expression_for_abstract(c.init, &mut ectx)?; + + let ty; + if let Some(explicit_ty) = c.ty { + let explicit_ty = + self.resolve_ast_type(explicit_ty, &mut ectx.as_global())?; + let explicit_ty_res = crate::proc::TypeResolution::Handle(explicit_ty); + init = ectx + .try_automatic_conversions(init, &explicit_ty_res, c.name.span) + .map_err(|error| match error { + Error::AutoConversion { + dest_span: _, + dest_type, + source_span: _, + source_type, + } => Error::InitializationTypeMismatch { + name: c.name.span, + expected: dest_type, + got: source_type, + }, + other => other, + })?; + ty = explicit_ty; + } else { + init = ectx.concretize(init)?; + ty = ectx.register_type(init)?; } let handle = ctx.module.constants.append( crate::Constant { name: Some(c.name.name.to_string()), r#override: crate::Override::None, - ty: inferred_type, + ty, init, }, span, @@ -951,6 +963,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } + // Constant evaluation may leave abstract-typed literals and + // compositions in expression arenas, so we need to compact the module + // to remove unused expressions and types. + crate::compact::compact(&mut module); + Ok(module) } @@ -1449,10 +1466,25 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } /// Lower `expr` and apply the Load Rule if possible. + /// + /// For the time being, this concretizes abstract values, to support + /// consumers that haven't been adapted to consume them yet. Consumers + /// prepared for abstract values can call [`expression_for_abstract`]. + /// + /// [`expression_for_abstract`]: Lowerer::expression_for_abstract fn expression( &mut self, expr: Handle>, ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let expr = self.expression_for_abstract(expr, ctx)?; + ctx.concretize(expr) + } + + fn expression_for_abstract( + &mut self, + expr: Handle>, + ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let expr = self.expression_for_reference(expr, ctx)?; ctx.apply_load_rule(expr) @@ -1473,8 +1505,9 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::Literal::Number(Number::I32(i)) => crate::Literal::I32(i), ast::Literal::Number(Number::U32(u)) => crate::Literal::U32(u), ast::Literal::Number(Number::F64(f)) => crate::Literal::F64(f), - ast::Literal::Number(_) => { - unreachable!("got abstract numeric type when not expected"); + ast::Literal::Number(Number::AbstractInt(i)) => crate::Literal::AbstractInt(i), + ast::Literal::Number(Number::AbstractFloat(f)) => { + crate::Literal::AbstractFloat(f) } ast::Literal::Bool(b) => crate::Literal::Bool(b), }; diff --git a/naga/src/front/wgsl/parse/lexer.rs b/naga/src/front/wgsl/parse/lexer.rs index 1cc5a0c438..9fa9416f11 100644 --- a/naga/src/front/wgsl/parse/lexer.rs +++ b/naga/src/front/wgsl/parse/lexer.rs @@ -465,13 +465,13 @@ fn test_numbers() { sub_test( "0x123 0X123u 1u 123 0 0i 0x3f", &[ - Token::Number(Ok(Number::I32(291))), + Token::Number(Ok(Number::AbstractInt(291))), Token::Number(Ok(Number::U32(291))), Token::Number(Ok(Number::U32(1))), - Token::Number(Ok(Number::I32(123))), + Token::Number(Ok(Number::AbstractInt(123))), + Token::Number(Ok(Number::AbstractInt(0))), Token::Number(Ok(Number::I32(0))), - Token::Number(Ok(Number::I32(0))), - Token::Number(Ok(Number::I32(63))), + Token::Number(Ok(Number::AbstractInt(63))), ], ); // decimal floating point @@ -479,17 +479,17 @@ fn test_numbers() { "0.e+4f 01. .01 12.34 .0f 0h 1e-3 0xa.fp+2 0x1P+4f 0X.3 0x3p+2h 0X1.fp-4 0x3.2p+2h", &[ Token::Number(Ok(Number::F32(0.))), - Token::Number(Ok(Number::F32(1.))), - Token::Number(Ok(Number::F32(0.01))), - Token::Number(Ok(Number::F32(12.34))), + Token::Number(Ok(Number::AbstractFloat(1.))), + Token::Number(Ok(Number::AbstractFloat(0.01))), + Token::Number(Ok(Number::AbstractFloat(12.34))), Token::Number(Ok(Number::F32(0.))), Token::Number(Err(NumberError::UnimplementedF16)), - Token::Number(Ok(Number::F32(0.001))), - Token::Number(Ok(Number::F32(43.75))), + Token::Number(Ok(Number::AbstractFloat(0.001))), + Token::Number(Ok(Number::AbstractFloat(43.75))), Token::Number(Ok(Number::F32(16.))), - Token::Number(Ok(Number::F32(0.1875))), + Token::Number(Ok(Number::AbstractFloat(0.1875))), Token::Number(Err(NumberError::UnimplementedF16)), - Token::Number(Ok(Number::F32(0.12109375))), + Token::Number(Ok(Number::AbstractFloat(0.12109375))), Token::Number(Err(NumberError::UnimplementedF16)), ], ); @@ -635,7 +635,7 @@ fn double_floats() { Token::Number(Ok(Number::F64(0.0625))), Token::Number(Ok(Number::F64(0.0625))), Token::Number(Ok(Number::F64(10.0))), - Token::Number(Ok(Number::I32(10))), + Token::Number(Ok(Number::AbstractInt(10))), Token::Word("l"), ], ) @@ -646,13 +646,16 @@ fn test_tokens() { sub_test("id123_OK", &[Token::Word("id123_OK")]); sub_test( "92No", - &[Token::Number(Ok(Number::I32(92))), Token::Word("No")], + &[ + Token::Number(Ok(Number::AbstractInt(92))), + Token::Word("No"), + ], ); sub_test( "2u3o", &[ Token::Number(Ok(Number::U32(2))), - Token::Number(Ok(Number::I32(3))), + Token::Number(Ok(Number::AbstractInt(3))), Token::Word("o"), ], ); @@ -660,7 +663,7 @@ fn test_tokens() { "2.4f44po", &[ Token::Number(Ok(Number::F32(2.4))), - Token::Number(Ok(Number::I32(44))), + Token::Number(Ok(Number::AbstractInt(44))), Token::Word("po"), ], ); @@ -699,13 +702,13 @@ fn test_tokens() { &[ // The 'f' suffixes are taken as a hex digit: // the fractional part is 0x2f / 256. - Token::Number(Ok(Number::F32(1.0 + 0x2f as f32 / 256.0))), - Token::Number(Ok(Number::F32(1.0 + 0x2f as f32 / 256.0))), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.0 + 0x2f as f64 / 256.0))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("h"), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("H"), - Token::Number(Ok(Number::F32(1.125))), + Token::Number(Ok(Number::AbstractFloat(1.125))), Token::Word("lf"), ], ) @@ -719,7 +722,7 @@ fn test_variable_decl() { Token::Attribute, Token::Word("group"), Token::Paren('('), - Token::Number(Ok(Number::I32(0))), + Token::Number(Ok(Number::AbstractInt(0))), Token::Paren(')'), Token::Word("var"), Token::Paren('<'), diff --git a/naga/src/front/wgsl/parse/number.rs b/naga/src/front/wgsl/parse/number.rs index 3178736990..fde5e3cee6 100644 --- a/naga/src/front/wgsl/parse/number.rs +++ b/naga/src/front/wgsl/parse/number.rs @@ -20,37 +20,11 @@ pub enum Number { F64(f64), } -impl Number { - /// Convert abstract numbers to a plausible concrete counterpart. - /// - /// Return concrete numbers unchanged. If the conversion would be - /// lossy, return an error. - fn abstract_to_concrete(self) -> Result { - match self { - Number::AbstractInt(num) => i32::try_from(num) - .map(Number::I32) - .map_err(|_| NumberError::NotRepresentable), - Number::AbstractFloat(num) => { - let num = num as f32; - if num.is_finite() { - Ok(Number::F32(num)) - } else { - Err(NumberError::NotRepresentable) - } - } - num => Ok(num), - } - } -} - // TODO: when implementing Creation-Time Expressions, remove the ability to match the minus sign pub(in crate::front::wgsl) fn consume_number(input: &str) -> (Token<'_>, &str) { let (result, rest) = parse(input); - ( - Token::Number(result.and_then(Number::abstract_to_concrete)), - rest, - ) + (Token::Number(result), rest) } enum Kind { diff --git a/naga/src/front/wgsl/tests.rs b/naga/src/front/wgsl/tests.rs index 9e3ba2fab6..eb2f8a2eb3 100644 --- a/naga/src/front/wgsl/tests.rs +++ b/naga/src/front/wgsl/tests.rs @@ -76,7 +76,7 @@ fn parse_type_cast() { assert!(parse_str( " fn main() { - let x: vec2 = vec2(0); + let x: vec2 = vec2(0i, 0i); } ", ) @@ -313,7 +313,7 @@ fn parse_texture_load() { " var t: texture_3d; fn foo() { - let r: vec4 = textureLoad(t, vec3(0.0, 1.0, 2.0), 1); + let r: vec4 = textureLoad(t, vec3(0u, 1u, 2u), 1); } ", ) diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index cdfa1f0b1f..c8331ace09 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -140,6 +140,8 @@ impl crate::Scalar { crate::ScalarKind::Uint => "u", crate::ScalarKind::Float => "f", crate::ScalarKind::Bool => return "bool".to_string(), + crate::ScalarKind::AbstractInt => return "{AbstractInt}".to_string(), + crate::ScalarKind::AbstractFloat => return "{AbstractFloat}".to_string(), }; format!("{}{}", prefix, self.width * 8) } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index e140ad6aef..b27ebc6764 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -300,6 +300,9 @@ use serde::Serialize; /// Width of a boolean type, in bytes. pub const BOOL_WIDTH: Bytes = 1; +/// Width of abstract types, in bytes. +pub const ABSTRACT_WIDTH: Bytes = 8; + /// Hash map that is faster but not resilient to DoS attacks. pub type FastHashMap = rustc_hash::FxHashMap; /// Hash set that is faster but not resilient to DoS attacks. @@ -470,6 +473,16 @@ pub enum ScalarKind { Float, /// Boolean type. Bool, + + /// WGSL abstract integer type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractInt, + + /// Abstract floating-point type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractFloat, } /// Characteristics of a scalar type. @@ -871,6 +884,8 @@ pub enum Literal { I32(i32), I64(i64), Bool(bool), + AbstractInt(i64), + AbstractFloat(f64), } #[derive(Debug, PartialEq)] diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index e3c07f9e16..51e447847c 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -167,6 +167,15 @@ pub enum ConstantEvaluatorError { NotImplemented(String), #[error("{0} operation overflowed")] Overflow(String), + #[error( + "the concrete type `{to_type}` cannot represent the abstract value `{value}` accurately" + )] + AutomaticConversionLossy { + value: String, + to_type: &'static str, + }, + #[error("abstract floating-point values cannot be automatically converted to integers")] + AutomaticConversionFloatToInt { to_type: &'static str }, #[error("Division by zero")] DivisionByZero, #[error("Remainder by zero")] @@ -268,6 +277,17 @@ impl<'a> ConstantEvaluator<'a> { } } + pub fn to_ctx(&self) -> crate::proc::GlobalCtx { + crate::proc::GlobalCtx { + types: self.types, + constants: self.constants, + const_expressions: match self.function_local_data { + Some(ref data) => data.const_expressions, + None => self.expressions, + }, + } + } + fn check(&self, expr: Handle) -> Result<(), ConstantEvaluatorError> { if let Some(ref function_local_data) = self.function_local_data { if !function_local_data.expression_constness.is_const(expr) { @@ -979,6 +999,8 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(_) | Literal::I64(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } + Literal::AbstractInt(v) => i32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => i32::try_from_abstract(v)?, }), Sc::U32 => Literal::U32(match literal { Literal::I32(v) => v as u32, @@ -988,6 +1010,8 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(_) | Literal::I64(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } + Literal::AbstractInt(v) => u32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => u32::try_from_abstract(v)?, }), Sc::F32 => Literal::F32(match literal { Literal::I32(v) => v as f32, @@ -997,25 +1021,46 @@ impl<'a> ConstantEvaluator<'a> { Literal::F64(_) | Literal::I64(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } + Literal::AbstractInt(v) => f32::try_from_abstract(v)?, + Literal::AbstractFloat(v) => f32::try_from_abstract(v)?, }), Sc::F64 => Literal::F64(match literal { Literal::I32(v) => v as f64, Literal::U32(v) => v as f64, Literal::F32(v) => v as f64, - Literal::Bool(v) => v as u32 as f64, Literal::F64(v) => v, + Literal::Bool(v) => v as u32 as f64, Literal::I64(_) => return Err(ConstantEvaluatorError::InvalidCastArg), + Literal::AbstractInt(v) => f64::try_from_abstract(v)?, + Literal::AbstractFloat(v) => f64::try_from_abstract(v)?, }), Sc::BOOL => Literal::Bool(match literal { Literal::I32(v) => v != 0, Literal::U32(v) => v != 0, Literal::F32(v) => v != 0.0, Literal::Bool(v) => v, - Literal::F64(_) | Literal::I64(_) => { + Literal::F64(_) + | Literal::I64(_) + | Literal::AbstractInt(_) + | Literal::AbstractFloat(_) => { return Err(ConstantEvaluatorError::InvalidCastArg) } }), - _ => return Err(ConstantEvaluatorError::InvalidCastArg), + Sc::ABSTRACT_FLOAT => Literal::AbstractFloat(match literal { + Literal::AbstractInt(v) => { + // Overflow is forbidden, but inexact conversions + // are fine. The range of f64 is far larger than + // that of i64, so we don't have to check anything + // here. + v as f64 + } + Literal::AbstractFloat(v) => v, + _ => return Err(ConstantEvaluatorError::InvalidCastArg), + }), + _ => { + log::debug!("Constant evaluator refused to convert value to {target:?}"); + return Err(ConstantEvaluatorError::InvalidCastArg); + } }; Expression::Literal(literal) } @@ -1065,6 +1110,66 @@ impl<'a> ConstantEvaluator<'a> { self.register_evaluated_expr(expr, span) } + /// Convert the scalar leaves of `expr` to `target`, handling arrays. + /// + /// `expr` must be a `Compose` expression whose type is a scalar, vector, + /// matrix, or nested arrays of such. + /// + /// This is basically the same as the [`cast`] method, except that that + /// should only handle Naga [`As`] expressions, which cannot convert arrays. + /// + /// Treat `span` as the location of the resulting expression. + /// + /// [`cast`]: ConstantEvaluator::cast + /// [`As`]: crate::Expression::As + pub fn cast_array( + &mut self, + expr: Handle, + target: crate::Scalar, + span: Span, + ) -> Result, ConstantEvaluatorError> { + let Expression::Compose { ty, ref components } = self.expressions[expr] else { + return self.cast(expr, target, span); + }; + + let crate::TypeInner::Array { base: _, size, stride: _ } = self.types[ty].inner else { + return self.cast(expr, target, span); + }; + + let mut components = components.clone(); + for component in &mut components { + *component = self.cast_array(*component, target, span)?; + } + + let first = components + .first() + .ok_or(ConstantEvaluatorError::InvalidCastArg)?; + let new_base = match self.resolve_type(*first)? { + crate::proc::TypeResolution::Handle(ty) => ty, + crate::proc::TypeResolution::Value(inner) => { + self.types.insert(Type { name: None, inner }, span) + } + }; + let new_base_stride = self.types[new_base].inner.size(self.to_ctx()); + let new_array_ty = self.types.insert( + Type { + name: None, + inner: TypeInner::Array { + base: new_base, + size, + stride: new_base_stride, + }, + }, + span, + ); + + let compose = Expression::Compose { + ty: new_array_ty, + components, + }; + self.register_evaluated_expr(compose, span) + } + fn unary_op( &mut self, op: UnaryOperator, @@ -1319,6 +1424,28 @@ impl<'a> ConstantEvaluator<'a> { Ok(self.expressions.append(expr, span)) } } + + fn resolve_type( + &self, + expr: Handle, + ) -> Result { + use crate::proc::TypeResolution as Tr; + use crate::Expression as Ex; + let resolution = match self.expressions[expr] { + Ex::Literal(ref literal) => Tr::Value(literal.ty_inner()), + Ex::Constant(c) => Tr::Handle(self.constants[c].ty), + Ex::ZeroValue(ty) | Ex::Compose { ty, .. } => Tr::Handle(ty), + Ex::Splat { size, value } => { + let Tr::Value(TypeInner::Scalar(scalar)) = self.resolve_type(value)? else { + return Err(ConstantEvaluatorError::SplatScalarOnly); + }; + Tr::Value(TypeInner::Vector { scalar, size }) + } + _ => return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant), + }; + + Ok(resolution) + } } #[cfg(test)] @@ -1828,3 +1955,92 @@ mod tests { } } } + +/// Trait for conversions of abstract values to concrete types. +trait TryFromAbstract: Sized { + /// Convert an abstract literal `value` to `Self`. + /// + /// Since Naga's `AbstractInt` and `AbstractFloat` exist to support + /// WGSL, we follow WGSL's conversion rules here: + /// + /// - WGSL §6.1.2. Conversion Rank says that automatic conversions + /// to integers are either lossless or an error. + /// + /// - WGSL §14.6.4 Floating Point Conversion says that conversions + /// to floating point in constant expressions and override + /// expressions are errors if the value is out of range for the + /// destination type, but rounding is okay. + /// + /// [`AbstractInt`]: crate::Literal::AbstractInt + /// [`Float`]: crate::Literal::Float + fn try_from_abstract(value: T) -> Result; +} + +impl TryFromAbstract for i32 { + fn try_from_abstract(value: i64) -> Result { + i32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "i32", + }) + } +} + +impl TryFromAbstract for u32 { + fn try_from_abstract(value: i64) -> Result { + u32::try_from(value).map_err(|_| ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "u32", + }) + } +} + +impl TryFromAbstract for f32 { + fn try_from_abstract(value: i64) -> Result { + let f = value as f32; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f32` is roughly ±3.4 × 10³⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract for f32 { + fn try_from_abstract(value: f64) -> Result { + let f = value as f32; + if f.is_infinite() { + return Err(ConstantEvaluatorError::AutomaticConversionLossy { + value: format!("{value:?}"), + to_type: "f32", + }); + } + Ok(f) + } +} + +impl TryFromAbstract for f64 { + fn try_from_abstract(value: i64) -> Result { + let f = value as f64; + // The range of `i64` is roughly ±18 × 10¹⁸, whereas the range of + // `f64` is roughly ±1.8 × 10³⁰⁸, so there's no opportunity for + // overflow here. + Ok(f) + } +} + +impl TryFromAbstract for f64 { + fn try_from_abstract(value: f64) -> Result { + Ok(value) + } +} + +impl TryFromAbstract for i32 { + fn try_from_abstract(_: f64) -> Result { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "i32" }) + } +} + +impl TryFromAbstract for u32 { + fn try_from_abstract(_: f64) -> Result { + Err(ConstantEvaluatorError::AutomaticConversionFloatToInt { to_type: "u32" }) + } +} diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 4f2f5c705d..e375bb1af3 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -71,7 +71,11 @@ impl From for super::ScalarKind { impl super::ScalarKind { pub const fn is_numeric(self) -> bool { match self { - crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float => true, + crate::ScalarKind::Sint + | crate::ScalarKind::Uint + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => true, crate::ScalarKind::Bool => false, } } @@ -102,6 +106,14 @@ impl super::Scalar { kind: crate::ScalarKind::Bool, width: crate::BOOL_WIDTH, }; + pub const ABSTRACT_INT: Self = Self { + kind: crate::ScalarKind::AbstractInt, + width: crate::ABSTRACT_WIDTH, + }; + pub const ABSTRACT_FLOAT: Self = Self { + kind: crate::ScalarKind::AbstractFloat, + width: crate::ABSTRACT_WIDTH, + }; /// Construct a float `Scalar` with the given width. /// @@ -144,7 +156,7 @@ impl Eq for crate::Literal {} impl std::hash::Hash for crate::Literal { fn hash(&self, hasher: &mut H) { match *self { - Self::F64(v) => { + Self::F64(v) | Self::AbstractFloat(v) => { hasher.write_u8(0); v.to_bits().hash(hasher); } @@ -164,7 +176,7 @@ impl std::hash::Hash for crate::Literal { hasher.write_u8(4); v.hash(hasher); } - Self::I64(v) => { + Self::I64(v) | Self::AbstractInt(v) => { hasher.write_u8(5); v.hash(hasher); } @@ -198,7 +210,8 @@ impl crate::Literal { match *self { Self::F64(_) | Self::I64(_) => 8, Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, - Self::Bool(_) => 1, + Self::Bool(_) => crate::BOOL_WIDTH, + Self::AbstractInt(_) | Self::AbstractFloat(_) => crate::ABSTRACT_WIDTH, } } pub const fn scalar(&self) -> crate::Scalar { @@ -209,6 +222,8 @@ impl crate::Literal { Self::I32(_) => crate::Scalar::I32, Self::I64(_) => crate::Scalar::I64, Self::Bool(_) => crate::Scalar::BOOL, + Self::AbstractInt(_) => crate::Scalar::ABSTRACT_INT, + Self::AbstractFloat(_) => crate::Scalar::ABSTRACT_FLOAT, } } pub const fn scalar_kind(&self) -> crate::ScalarKind { @@ -222,6 +237,10 @@ impl crate::Literal { pub const POINTER_SPAN: u32 = 4; impl super::TypeInner { + /// Return the scalar type of `self`. + /// + /// If `inner` is a scalar, vector, or matrix type, return + /// its scalar type. Otherwise, return `None`. pub const fn scalar(&self) -> Option { use crate::TypeInner as Ti; match *self { diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 1f57c55441..c82d60f062 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -671,7 +671,7 @@ impl super::Validator { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, - Sk::Bool => false, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, Ti::Matrix { .. } => left_inner == right_inner, _ => false, @@ -679,14 +679,14 @@ impl super::Validator { Bo::Divide | Bo::Modulo => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, - Sk::Bool => false, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, _ => false, }, Bo::Multiply => { let kind_allowed = match left_inner.scalar_kind() { Some(Sk::Uint | Sk::Sint | Sk::Float) => true, - Some(Sk::Bool) | None => false, + Some(Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat) | None => false, }; let types_match = match (left_inner, right_inner) { // Straight scalar and mixed scalar/vector. @@ -763,7 +763,7 @@ impl super::Validator { match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, - Sk::Bool => false, + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -785,7 +785,7 @@ impl super::Validator { Bo::And | Bo::InclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, - Sk::Float => false, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -795,7 +795,7 @@ impl super::Validator { Bo::ExclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { Sk::Sint | Sk::Uint => left_inner == right_inner, - Sk::Bool | Sk::Float => false, + Sk::Bool | Sk::Float | Sk::AbstractInt | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -824,7 +824,7 @@ impl super::Validator { }; match base_scalar.kind { Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, - Sk::Float | Sk::Bool => false, + Sk::Float | Sk::AbstractInt | Sk::AbstractFloat | Sk::Bool => false, } } }; diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f5da7d0764..3b12e59067 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -112,7 +112,7 @@ pub enum FunctionError { InvalidStorePointer(Handle), #[error("The value {0:?} can not be stored")] InvalidStoreValue(Handle), - #[error("Store of {value:?} into {pointer:?} doesn't have matching types")] + #[error("The type of {value:?} doesn't match the type stored in {pointer:?}")] InvalidStoreTypes { pointer: Handle, value: Handle, diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 53462fe801..1e3e03fe19 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -143,6 +143,9 @@ pub enum WidthError { #[error("64-bit integers are not yet supported")] Unsupported64Bit, + + #[error("Abstract types may only appear in constant expressions")] + Abstract, } // Only makes sense if `flags.contains(HOST_SHAREABLE)` @@ -248,6 +251,9 @@ impl super::Validator { } scalar.width == 4 } + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(WidthError::Abstract); + } }; if good { Ok(()) @@ -325,7 +331,10 @@ impl super::Validator { } Ti::Atomic(crate::Scalar { kind, width }) => { let good = match kind { - crate::ScalarKind::Bool | crate::ScalarKind::Float => false, + crate::ScalarKind::Bool + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => false, crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, }; if !good { diff --git a/naga/tests/in/abstract-types.wgsl b/naga/tests/in/abstract-types.wgsl new file mode 100644 index 0000000000..0cf20b5d26 --- /dev/null +++ b/naga/tests/in/abstract-types.wgsl @@ -0,0 +1,81 @@ +// i/x: type inferred / explicit +// vX/mX/aX: vector / matrix / array of X +// where X: u/i/f: u32 / i32 / f32 +// s: vector splat +// r: vector spread (vector arg to vector constructor) +// p: "partial" constructor (type parameter inferred) +// u/i/f/ai/af: u32 / i32 / f32 / abstract float / abstract integer as parameter +// _: just for alignment + +// Ensure that: +// - the inferred type is correct. +// - all parameters' types are considered. +// - all parameters are converted to the consensus type. + +const xvupaiai: vec2 = vec2(42, 43); +const xvfpaiai: vec2 = vec2(44, 45); + +const xvupuai: vec2 = vec2(42u, 43); +const xvupaiu: vec2 = vec2(42, 43u); + +const xvuuai: vec2 = vec2(42u, 43); +const xvuaiu: vec2 = vec2(42, 43u); + +const xmfpaiaiaiai: mat2x2 = mat2x2(1, 2, 3, 4); +const xmfpafaiaiai: mat2x2 = mat2x2(1.0, 2, 3, 4); +const xmfpaiafaiai: mat2x2 = mat2x2(1, 2.0, 3, 4); +const xmfpaiaiafai: mat2x2 = mat2x2(1, 2, 3.0, 4); +const xmfpaiaiaiaf: mat2x2 = mat2x2(1, 2, 3, 4.0); + +const imfpaiaiaiai = mat2x2(1, 2, 3, 4); +const imfpafaiaiai = mat2x2(1.0, 2, 3, 4); +const imfpafafafaf = mat2x2(1.0, 2.0, 3.0, 4.0); + +const ivispai = vec2(1); +const ivfspaf = vec2(1.0); +const ivis_ai = vec2(1); +const ivus_ai = vec2(1); +const ivfs_ai = vec2(1); +const ivfs_af = vec2(1.0); + +const iafafaf = array(1.0, 2.0); +const iafaiai = array(1, 2); + +const iafpafaf = array(1.0, 2.0); +const iafpaiaf = array(1, 2.0); +const iafpafai = array(1.0, 2); +const xafpafaf: array = array(1.0, 2.0); + +struct S { + f: f32, + i: i32, + u: u32, +} + +const s_f_i_u: S = S(1.0f, 1i, 1u); +const s_f_iai: S = S(1.0f, 1i, 1); +const s_fai_u: S = S(1.0f, 1, 1u); +const s_faiai: S = S(1.0f, 1, 1); +const saf_i_u: S = S(1.0, 1i, 1u); +const saf_iai: S = S(1.0, 1i, 1); +const safai_u: S = S(1.0, 1, 1u); +const safaiai: S = S(1.0, 1, 1); + +// Vector construction with spreads +const ivfr_f__f = vec3(vec2(1.0f, 2.0f), 3.0f); +const ivfr_f_af = vec3(vec2(1.0f, 2.0f), 3.0 ); +const ivfraf__f = vec3(vec2 (1.0 , 2.0 ), 3.0f); +const ivfraf_af = vec3(vec2 (1.0 , 2.0 ), 3.0 ); + +const ivf__fr_f = vec3(1.0f, vec2(2.0f, 3.0f)); +const ivf__fraf = vec3(1.0f, vec2 (2.0 , 3.0 )); +const ivf_afr_f = vec3(1.0 , vec2(2.0f, 3.0f)); +const ivf_afraf = vec3(1.0 , vec2 (2.0 , 3.0 )); + +const ivfr_f_ai = vec3(vec2(1.0f, 2.0f), 3 ); +const ivfrai__f = vec3(vec2 (1 , 2 ), 3.0f); +const ivfrai_ai = vec3(vec2 (1 , 2 ), 3 ); + +const ivf__frai = vec3(1.0f, vec2 (2 , 3 )); +const ivf_air_f = vec3(1 , vec2(2.0f, 3.0f)); +const ivf_airai = vec3(1 , vec2 (2 , 3 )); diff --git a/naga/tests/out/ir/access.compact.ron b/naga/tests/out/ir/access.compact.ron index 70ea0c4bb5..0670534e90 100644 --- a/naga/tests/out/ir/access.compact.ron +++ b/naga/tests/out/ir/access.compact.ron @@ -1629,6 +1629,14 @@ 1: "foo", }, body: [ + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 2, end: 3, @@ -2192,6 +2200,14 @@ ], result: None, ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 3, end: 4, diff --git a/naga/tests/out/ir/access.ron b/naga/tests/out/ir/access.ron index 55d27c97eb..0670534e90 100644 --- a/naga/tests/out/ir/access.ron +++ b/naga/tests/out/ir/access.ron @@ -214,16 +214,6 @@ ), ), ), - ( - name: None, - inner: Vector( - size: Bi, - scalar: ( - kind: Float, - width: 4, - ), - ), - ), ( name: None, inner: Matrix( @@ -238,7 +228,7 @@ ( name: None, inner: Array( - base: 19, + base: 18, size: Constant(2), stride: 32, ), @@ -249,7 +239,7 @@ members: [ ( name: Some("am"), - ty: 20, + ty: 19, binding: None, offset: 0, ), @@ -267,14 +257,14 @@ ( name: None, inner: Pointer( - base: 22, + base: 21, space: Function, ), ), ( name: None, inner: Array( - base: 22, + base: 21, size: Constant(10), stride: 4, ), @@ -282,7 +272,7 @@ ( name: None, inner: Array( - base: 24, + base: 23, size: Constant(5), stride: 40, ), @@ -297,15 +287,6 @@ ), ), ), - ( - name: None, - inner: Pointer( - base: 3, - space: Storage( - access: ("LOAD | STORE"), - ), - ), - ), ( name: None, inner: Array( @@ -314,26 +295,6 @@ stride: 4, ), ), - ( - name: None, - inner: Vector( - size: Quad, - scalar: ( - kind: Sint, - width: 4, - ), - ), - ), - ( - name: None, - inner: Vector( - size: Tri, - scalar: ( - kind: Float, - width: 4, - ), - ), - ), ( name: None, inner: Pointer( @@ -344,7 +305,7 @@ ( name: None, inner: Array( - base: 26, + base: 25, size: Constant(2), stride: 16, ), @@ -352,7 +313,7 @@ ( name: None, inner: Pointer( - base: 32, + base: 28, space: Function, ), ), @@ -412,7 +373,7 @@ group: 0, binding: 3, )), - ty: 21, + ty: 20, init: None, ), ], @@ -438,32 +399,6 @@ 6, ], ), - Literal(I32(8)), - Literal(I32(2)), - Literal(I32(10)), - Literal(I32(2)), - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(1)), - Literal(I32(0)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(0)), - Literal(I32(3)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(10)), - Literal(I32(5)), - Literal(I32(5)), - Literal(I32(10)), - Literal(I32(5)), - Literal(I32(0)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(2)), - Literal(I32(1)), ], functions: [ ( @@ -479,7 +414,7 @@ ( name: Some("t"), ty: 16, - init: Some(54), + init: Some(49), ), ], expressions: [ @@ -507,136 +442,131 @@ base: 9, index: 0, ), - Literal(I32(0)), AccessIndex( base: 10, index: 0, ), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(3), AccessIndex( - base: 14, + base: 13, index: 0, ), Load( pointer: 2, ), Access( - base: 15, - index: 16, + base: 14, + index: 15, ), Load( - pointer: 17, + pointer: 16, ), GlobalVariable(3), AccessIndex( - base: 19, + base: 18, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 20, + base: 19, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 22, + base: 20, index: 1, ), Load( - pointer: 24, + pointer: 21, ), GlobalVariable(3), AccessIndex( - base: 26, + base: 23, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 27, + base: 24, index: 0, ), Load( pointer: 2, ), Access( - base: 29, - index: 30, + base: 25, + index: 26, ), Load( - pointer: 31, + pointer: 27, ), GlobalVariable(3), AccessIndex( - base: 33, + base: 29, index: 0, ), Load( pointer: 2, ), Access( - base: 34, - index: 35, + base: 30, + index: 31, ), - Literal(I32(1)), AccessIndex( - base: 36, + base: 32, index: 1, ), Load( - pointer: 38, + pointer: 33, ), GlobalVariable(3), AccessIndex( - base: 40, + base: 35, index: 0, ), Load( pointer: 2, ), Access( - base: 41, - index: 42, + base: 36, + index: 37, ), Load( pointer: 2, ), Access( - base: 43, - index: 44, + base: 38, + index: 39, ), Load( - pointer: 45, + pointer: 40, ), Literal(F32(1.0)), Splat( size: Bi, - value: 47, + value: 42, ), Literal(F32(2.0)), Splat( size: Bi, - value: 49, + value: 44, ), Literal(F32(3.0)), Splat( size: Bi, - value: 51, + value: 46, ), Compose( ty: 15, components: [ - 48, - 50, - 52, + 43, + 45, + 47, ], ), Compose( ty: 16, components: [ - 53, + 48, ], ), LocalVariable(2), @@ -646,143 +576,138 @@ ), Binary( op: Add, - left: 57, - right: 56, + left: 52, + right: 51, ), AccessIndex( - base: 55, + base: 50, index: 0, ), Literal(F32(6.0)), Splat( size: Bi, - value: 60, + value: 55, ), Literal(F32(5.0)), Splat( size: Bi, - value: 62, + value: 57, ), Literal(F32(4.0)), Splat( size: Bi, - value: 64, + value: 59, ), Compose( ty: 15, components: [ - 61, - 63, - 65, + 56, + 58, + 60, ], ), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 67, + base: 62, index: 0, ), Literal(F32(9.0)), Splat( size: Bi, - value: 70, + value: 64, ), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 72, - index: 73, + base: 66, + index: 67, ), Literal(F32(90.0)), Splat( size: Bi, - value: 75, + value: 69, ), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 77, + base: 71, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 79, + base: 72, index: 1, ), Literal(F32(10.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 83, + base: 75, index: 0, ), Load( pointer: 2, ), Access( - base: 85, - index: 86, + base: 76, + index: 77, ), Literal(F32(20.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 89, - index: 90, + base: 80, + index: 81, ), - Literal(I32(1)), AccessIndex( - base: 91, + base: 82, index: 1, ), Literal(F32(30.0)), AccessIndex( - base: 55, + base: 50, index: 0, ), Load( pointer: 2, ), Access( - base: 95, - index: 96, + base: 85, + index: 86, ), Load( pointer: 2, ), Access( - base: 97, - index: 98, + base: 87, + index: 88, ), Literal(F32(40.0)), ], named_expressions: { 8: "l0", - 13: "l1", - 18: "l2", - 25: "l3", - 32: "l4", - 39: "l5", - 46: "l6", + 12: "l1", + 17: "l2", + 22: "l3", + 28: "l4", + 34: "l5", + 41: "l6", }, body: [ Emit(( @@ -802,160 +727,160 @@ end: 10, )), Emit(( - start: 11, - end: 13, + start: 10, + end: 12, )), Emit(( - start: 14, - end: 18, + start: 13, + end: 17, + )), + Emit(( + start: 18, + end: 19, )), Emit(( start: 19, end: 20, )), Emit(( - start: 21, + start: 20, end: 22, )), Emit(( start: 23, - end: 25, + end: 24, )), Emit(( - start: 26, - end: 27, + start: 24, + end: 28, )), Emit(( - start: 28, + start: 29, end: 32, )), Emit(( - start: 33, - end: 36, + start: 32, + end: 34, )), Emit(( - start: 37, - end: 39, + start: 35, + end: 41, )), Emit(( - start: 40, - end: 46, + start: 42, + end: 43, )), Emit(( - start: 47, - end: 48, + start: 44, + end: 45, )), Emit(( - start: 49, - end: 50, + start: 46, + end: 49, )), Emit(( start: 51, - end: 54, - )), - Emit(( - start: 56, - end: 58, + end: 53, )), Store( pointer: 2, - value: 58, + value: 53, ), Emit(( - start: 58, - end: 59, + start: 53, + end: 54, )), Emit(( - start: 60, - end: 61, + start: 55, + end: 56, )), Emit(( - start: 62, - end: 63, + start: 57, + end: 58, )), Emit(( - start: 64, - end: 66, + start: 59, + end: 61, )), Store( - pointer: 59, - value: 66, + pointer: 54, + value: 61, ), Emit(( - start: 66, - end: 67, + start: 61, + end: 62, )), Emit(( - start: 68, - end: 69, + start: 62, + end: 63, )), Emit(( - start: 70, - end: 71, + start: 64, + end: 65, )), Store( - pointer: 69, - value: 71, + pointer: 63, + value: 65, ), Emit(( - start: 71, - end: 74, + start: 65, + end: 68, )), Emit(( - start: 75, - end: 76, + start: 69, + end: 70, )), Store( - pointer: 74, - value: 76, + pointer: 68, + value: 70, ), Emit(( - start: 76, - end: 77, + start: 70, + end: 71, )), Emit(( - start: 78, - end: 79, + start: 71, + end: 72, )), Emit(( - start: 80, - end: 81, + start: 72, + end: 73, )), Store( - pointer: 81, - value: 82, + pointer: 73, + value: 74, ), Emit(( - start: 82, - end: 83, + start: 74, + end: 75, )), Emit(( - start: 84, - end: 87, + start: 75, + end: 78, )), Store( - pointer: 87, - value: 88, + pointer: 78, + value: 79, ), Emit(( - start: 88, - end: 91, + start: 79, + end: 82, )), Emit(( - start: 92, - end: 93, + start: 82, + end: 83, )), Store( - pointer: 93, - value: 94, + pointer: 83, + value: 84, ), Emit(( - start: 94, - end: 99, + start: 84, + end: 89, )), Store( - pointer: 99, - value: 100, + pointer: 89, + value: 90, ), Return( value: None, @@ -974,8 +899,8 @@ ), ( name: Some("t"), - ty: 21, - init: Some(65), + ty: 20, + init: Some(53), ), ], expressions: [ @@ -1003,157 +928,145 @@ base: 9, index: 0, ), - Literal(I32(0)), AccessIndex( base: 10, index: 0, ), Load( - pointer: 12, + pointer: 11, ), GlobalVariable(5), AccessIndex( - base: 14, + base: 13, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 15, + base: 14, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 17, + base: 15, index: 0, ), Load( - pointer: 19, + pointer: 16, ), GlobalVariable(5), AccessIndex( - base: 21, + base: 18, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 22, + base: 19, index: 0, ), Load( pointer: 2, ), Access( - base: 24, - index: 25, + base: 20, + index: 21, ), Load( - pointer: 26, + pointer: 22, ), GlobalVariable(5), AccessIndex( - base: 28, + base: 24, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 29, + base: 25, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 31, + base: 26, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 33, + base: 27, index: 1, ), Load( - pointer: 35, + pointer: 28, ), GlobalVariable(5), AccessIndex( - base: 37, + base: 30, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 38, + base: 31, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 40, + base: 32, index: 0, ), Load( pointer: 2, ), Access( - base: 42, - index: 43, + base: 33, + index: 34, ), Load( - pointer: 44, + pointer: 35, ), GlobalVariable(5), AccessIndex( - base: 46, + base: 37, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 47, + base: 38, index: 0, ), Load( pointer: 2, ), Access( - base: 49, - index: 50, + base: 39, + index: 40, ), - Literal(I32(1)), AccessIndex( - base: 51, + base: 41, index: 1, ), Load( - pointer: 53, + pointer: 42, ), GlobalVariable(5), AccessIndex( - base: 55, + base: 44, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 56, + base: 45, index: 0, ), Load( pointer: 2, ), Access( - base: 58, - index: 59, + base: 46, + index: 47, ), Load( pointer: 2, ), Access( - base: 60, - index: 61, + base: 48, + index: 49, ), Load( - pointer: 62, + pointer: 50, ), - ZeroValue(20), + ZeroValue(19), Compose( - ty: 21, + ty: 20, components: [ - 64, + 52, ], ), LocalVariable(2), @@ -1163,190 +1076,178 @@ ), Binary( op: Add, - left: 68, - right: 67, + left: 56, + right: 55, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - ZeroValue(20), + ZeroValue(19), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 72, + base: 60, index: 0, ), Literal(F32(8.0)), Splat( size: Bi, - value: 75, + value: 62, ), Literal(F32(7.0)), Splat( size: Bi, - value: 77, + value: 64, ), Literal(F32(6.0)), Splat( size: Bi, - value: 79, + value: 66, ), Literal(F32(5.0)), Splat( size: Bi, - value: 81, + value: 68, ), Compose( - ty: 19, + ty: 18, components: [ - 76, - 78, - 80, - 82, + 63, + 65, + 67, + 69, ], ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 84, + base: 71, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 86, + base: 72, index: 0, ), Literal(F32(9.0)), Splat( size: Bi, - value: 89, + value: 74, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 91, + base: 76, index: 0, ), Load( pointer: 2, ), Access( - base: 93, - index: 94, + base: 77, + index: 78, ), Literal(F32(90.0)), Splat( size: Bi, - value: 96, + value: 80, ), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 98, + base: 82, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 100, + base: 83, index: 0, ), - Literal(I32(1)), AccessIndex( - base: 102, + base: 84, index: 1, ), Literal(F32(10.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 106, + base: 87, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 108, + base: 88, index: 0, ), Load( pointer: 2, ), Access( - base: 110, - index: 111, + base: 89, + index: 90, ), Literal(F32(20.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 114, + base: 93, index: 0, ), Load( pointer: 2, ), Access( - base: 116, - index: 117, + base: 94, + index: 95, ), - Literal(I32(1)), AccessIndex( - base: 118, + base: 96, index: 1, ), Literal(F32(30.0)), AccessIndex( - base: 66, + base: 54, index: 0, ), - Literal(I32(0)), AccessIndex( - base: 122, + base: 99, index: 0, ), Load( pointer: 2, ), Access( - base: 124, - index: 125, + base: 100, + index: 101, ), Load( pointer: 2, ), Access( - base: 126, - index: 127, + base: 102, + index: 103, ), Literal(F32(40.0)), ], named_expressions: { 8: "l0", - 13: "l1", - 20: "l2", - 27: "l3", - 36: "l4", - 45: "l5", - 54: "l6", - 63: "l7", + 12: "l1", + 17: "l2", + 23: "l3", + 29: "l4", + 36: "l5", + 43: "l6", + 51: "l7", }, body: [ Emit(( @@ -1366,31 +1267,43 @@ end: 10, )), Emit(( - start: 11, - end: 13, + start: 10, + end: 12, + )), + Emit(( + start: 13, + end: 14, )), Emit(( start: 14, end: 15, )), Emit(( - start: 16, + start: 15, end: 17, )), Emit(( start: 18, - end: 20, + end: 19, )), Emit(( - start: 21, - end: 22, + start: 19, + end: 23, )), Emit(( - start: 23, + start: 24, + end: 25, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 26, end: 27, )), Emit(( - start: 28, + start: 27, end: 29, )), Emit(( @@ -1398,11 +1311,11 @@ end: 31, )), Emit(( - start: 32, - end: 33, + start: 31, + end: 32, )), Emit(( - start: 34, + start: 32, end: 36, )), Emit(( @@ -1410,180 +1323,168 @@ end: 38, )), Emit(( - start: 39, - end: 40, + start: 38, + end: 41, )), Emit(( start: 41, - end: 45, + end: 43, )), Emit(( - start: 46, - end: 47, + start: 44, + end: 45, )), Emit(( - start: 48, + start: 45, end: 51, )), Emit(( start: 52, - end: 54, + end: 53, )), Emit(( start: 55, - end: 56, - )), - Emit(( - start: 57, - end: 63, - )), - Emit(( - start: 64, - end: 65, - )), - Emit(( - start: 67, - end: 69, + end: 57, )), Store( pointer: 2, - value: 69, + value: 57, ), Emit(( - start: 69, - end: 70, + start: 57, + end: 58, )), Store( - pointer: 70, - value: 71, + pointer: 58, + value: 59, ), Emit(( - start: 71, - end: 72, + start: 59, + end: 60, )), Emit(( - start: 73, - end: 74, + start: 60, + end: 61, )), Emit(( - start: 75, - end: 76, + start: 62, + end: 63, )), Emit(( - start: 77, - end: 78, + start: 64, + end: 65, )), Emit(( - start: 79, - end: 80, + start: 66, + end: 67, )), Emit(( - start: 81, - end: 83, + start: 68, + end: 70, )), Store( - pointer: 74, - value: 83, + pointer: 61, + value: 70, ), Emit(( - start: 83, - end: 84, + start: 70, + end: 71, )), Emit(( - start: 85, - end: 86, + start: 71, + end: 72, )), Emit(( - start: 87, - end: 88, + start: 72, + end: 73, )), Emit(( - start: 89, - end: 90, + start: 74, + end: 75, )), Store( - pointer: 88, - value: 90, + pointer: 73, + value: 75, ), Emit(( - start: 90, - end: 91, + start: 75, + end: 76, )), Emit(( - start: 92, - end: 95, + start: 76, + end: 79, )), Emit(( - start: 96, - end: 97, + start: 80, + end: 81, )), Store( - pointer: 95, - value: 97, + pointer: 79, + value: 81, ), Emit(( - start: 97, - end: 98, + start: 81, + end: 82, )), Emit(( - start: 99, - end: 100, + start: 82, + end: 83, )), Emit(( - start: 101, - end: 102, + start: 83, + end: 84, )), Emit(( - start: 103, - end: 104, + start: 84, + end: 85, )), Store( - pointer: 104, - value: 105, + pointer: 85, + value: 86, ), Emit(( - start: 105, - end: 106, + start: 86, + end: 87, )), Emit(( - start: 107, - end: 108, + start: 87, + end: 88, )), Emit(( - start: 109, - end: 112, + start: 88, + end: 91, )), Store( - pointer: 112, - value: 113, + pointer: 91, + value: 92, ), Emit(( - start: 113, - end: 114, + start: 92, + end: 93, )), Emit(( - start: 115, - end: 118, + start: 93, + end: 96, )), Emit(( - start: 119, - end: 120, + start: 96, + end: 97, )), Store( - pointer: 120, - value: 121, + pointer: 97, + value: 98, ), Emit(( - start: 121, - end: 122, + start: 98, + end: 99, )), Emit(( - start: 123, - end: 128, + start: 99, + end: 104, )), Store( - pointer: 128, - value: 129, + pointer: 104, + value: 105, ), Return( value: None, @@ -1595,12 +1496,12 @@ arguments: [ ( name: Some("foo"), - ty: 23, + ty: 22, binding: None, ), ], result: Some(( - ty: 22, + ty: 21, binding: None, )), local_variables: [], @@ -1628,25 +1529,23 @@ arguments: [ ( name: Some("a"), - ty: 25, + ty: 24, binding: None, ), ], result: Some(( - ty: 22, + ty: 21, binding: None, )), local_variables: [], expressions: [ FunctionArgument(0), - Literal(I32(4)), AccessIndex( base: 1, index: 4, ), - Literal(I32(9)), AccessIndex( - base: 3, + base: 2, index: 9, ), ], @@ -1655,15 +1554,15 @@ }, body: [ Emit(( - start: 2, - end: 3, + start: 1, + end: 2, )), Emit(( - start: 4, - end: 5, + start: 2, + end: 3, )), Return( - value: Some(5), + value: Some(3), ), ], ), @@ -1672,7 +1571,7 @@ arguments: [ ( name: Some("p"), - ty: 31, + ty: 27, binding: None, ), ], @@ -1700,7 +1599,7 @@ arguments: [ ( name: Some("foo"), - ty: 33, + ty: 29, binding: None, ), ], @@ -1719,7 +1618,7 @@ value: 4, ), Compose( - ty: 32, + ty: 28, components: [ 3, 5, @@ -1730,6 +1629,14 @@ 1: "foo", }, body: [ + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 2, end: 3, @@ -1764,7 +1671,7 @@ ), ], result: Some(( - ty: 26, + ty: 25, binding: Some(BuiltIn(Position( invariant: false, ))), @@ -1772,12 +1679,12 @@ local_variables: [ ( name: Some("foo"), - ty: 22, + ty: 21, init: Some(2), ), ( name: Some("c2"), - ty: 28, + ty: 26, init: None, ), ], @@ -1859,13 +1766,12 @@ base: 30, index: 5, ), - Literal(I32(0)), AccessIndex( base: 31, index: 0, ), AccessIndex( - base: 33, + base: 32, index: 0, ), CallResult(3), @@ -1878,13 +1784,13 @@ Literal(I32(4)), Literal(I32(5)), Compose( - ty: 28, + ty: 26, components: [ 27, + 35, 36, 37, 38, - 39, ], ), LocalVariable(2), @@ -1892,42 +1798,42 @@ Binary( op: Add, left: 1, - right: 42, + right: 41, ), Access( - base: 41, - index: 43, + base: 40, + index: 42, ), Literal(I32(42)), Access( - base: 41, + base: 40, index: 1, ), Load( - pointer: 46, + pointer: 45, ), - ZeroValue(25), + ZeroValue(24), CallResult(4), Splat( size: Quad, - value: 47, + value: 46, ), As( - expr: 50, + expr: 49, kind: Float, convert: Some(4), ), Binary( op: Multiply, left: 8, - right: 51, + right: 50, ), Literal(F32(2.0)), Compose( - ty: 26, + ty: 25, components: [ + 51, 52, - 53, ], ), ], @@ -1940,9 +1846,9 @@ 17: "b", 27: "a", 29: "c", - 34: "data_pointer", - 35: "foo_value", - 47: "value", + 33: "data_pointer", + 34: "foo_value", + 46: "value", }, body: [ Emit(( @@ -1996,57 +1902,57 @@ end: 31, )), Emit(( - start: 32, - end: 34, + start: 31, + end: 33, )), Call( function: 3, arguments: [ 3, ], - result: Some(35), + result: Some(34), ), Emit(( - start: 35, - end: 36, + start: 34, + end: 35, )), Emit(( - start: 39, - end: 40, + start: 38, + end: 39, )), Store( - pointer: 41, - value: 40, + pointer: 40, + value: 39, ), Emit(( - start: 42, - end: 44, + start: 41, + end: 43, )), Store( - pointer: 44, - value: 45, + pointer: 43, + value: 44, ), Emit(( - start: 45, - end: 47, + start: 44, + end: 46, )), Call( function: 4, arguments: [ - 48, + 47, ], - result: Some(49), + result: Some(48), ), Emit(( - start: 49, - end: 52, + start: 48, + end: 51, )), Emit(( - start: 53, - end: 54, + start: 52, + end: 53, )), Return( - value: Some(54), + value: Some(53), ), ], ), @@ -2060,7 +1966,7 @@ name: Some("foo_frag"), arguments: [], result: Some(( - ty: 26, + ty: 25, binding: Some(Location( location: 0, second_blend_source: false, @@ -2075,84 +1981,82 @@ base: 1, index: 0, ), - Literal(I32(1)), AccessIndex( base: 2, index: 1, ), AccessIndex( - base: 4, + base: 3, index: 2, ), Literal(F32(1.0)), GlobalVariable(2), AccessIndex( - base: 7, + base: 6, index: 0, ), Literal(F32(0.0)), Splat( size: Tri, - value: 9, + value: 8, ), Literal(F32(1.0)), Splat( size: Tri, - value: 11, + value: 10, ), Literal(F32(2.0)), Splat( size: Tri, - value: 13, + value: 12, ), Literal(F32(3.0)), Splat( size: Tri, - value: 15, + value: 14, ), Compose( ty: 6, components: [ - 10, - 12, - 14, - 16, + 9, + 11, + 13, + 15, ], ), GlobalVariable(2), AccessIndex( - base: 18, + base: 17, index: 4, ), Literal(U32(0)), Splat( size: Bi, - value: 20, + value: 19, ), Literal(U32(1)), Splat( size: Bi, - value: 22, + value: 21, ), Compose( ty: 12, components: [ - 21, - 23, + 20, + 22, ], ), GlobalVariable(2), AccessIndex( - base: 25, + base: 24, index: 5, ), - Literal(I32(1)), AccessIndex( - base: 26, + base: 25, index: 1, ), AccessIndex( - base: 28, + base: 26, index: 0, ), Literal(I32(1)), @@ -2161,7 +2065,7 @@ Literal(F32(0.0)), Splat( size: Quad, - value: 33, + value: 31, ), ], named_expressions: {}, @@ -2171,75 +2075,75 @@ end: 2, )), Emit(( - start: 3, - end: 5, + start: 2, + end: 4, )), Store( - pointer: 5, - value: 6, + pointer: 4, + value: 5, ), Emit(( - start: 7, - end: 8, + start: 6, + end: 7, )), Emit(( - start: 9, - end: 10, + start: 8, + end: 9, )), Emit(( - start: 11, - end: 12, + start: 10, + end: 11, )), Emit(( - start: 13, - end: 14, + start: 12, + end: 13, )), Emit(( - start: 15, - end: 17, + start: 14, + end: 16, )), Store( - pointer: 8, - value: 17, + pointer: 7, + value: 16, ), Emit(( - start: 18, - end: 19, + start: 17, + end: 18, )), Emit(( - start: 20, - end: 21, + start: 19, + end: 20, )), Emit(( - start: 22, - end: 24, + start: 21, + end: 23, )), Store( - pointer: 19, - value: 24, + pointer: 18, + value: 23, ), Emit(( - start: 25, - end: 26, + start: 24, + end: 25, )), Emit(( - start: 27, - end: 29, + start: 25, + end: 27, )), Store( - pointer: 29, - value: 30, + pointer: 27, + value: 28, ), Store( - pointer: 31, - value: 32, + pointer: 29, + value: 30, ), Emit(( - start: 33, - end: 34, + start: 31, + end: 32, )), Return( - value: Some(34), + value: Some(32), ), ], ), @@ -2261,7 +2165,7 @@ ), ( name: Some("arr"), - ty: 32, + ty: 28, init: Some(7), ), ], @@ -2279,7 +2183,7 @@ value: 5, ), Compose( - ty: 32, + ty: 28, components: [ 4, 6, @@ -2296,6 +2200,14 @@ ], result: None, ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), Emit(( start: 3, end: 4, diff --git a/naga/tests/out/ir/collatz.ron b/naga/tests/out/ir/collatz.ron index effde120a5..cfc3bfa0ee 100644 --- a/naga/tests/out/ir/collatz.ron +++ b/naga/tests/out/ir/collatz.ron @@ -60,11 +60,7 @@ init: None, ), ], - const_expressions: [ - Literal(I32(0)), - Literal(I32(0)), - Literal(I32(1)), - ], + const_expressions: [], functions: [ ( name: Some("collatz_iterations"), diff --git a/naga/tests/out/msl/abstract-types.msl b/naga/tests/out/msl/abstract-types.msl new file mode 100644 index 0000000000..af4de25cc6 --- /dev/null +++ b/naga/tests/out/msl/abstract-types.msl @@ -0,0 +1,62 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct type_5 { + float inner[2]; +}; +struct S { + float f; + int i; + uint u; +}; +constant metal::uint2 xvupaiai = metal::uint2(42u, 43u); +constant metal::float2 xvfpaiai = metal::float2(44.0, 45.0); +constant metal::uint2 xvupuai = metal::uint2(42u, 43u); +constant metal::uint2 xvupaiu = metal::uint2(42u, 43u); +constant metal::uint2 xvuuai = metal::uint2(42u, 43u); +constant metal::uint2 xvuaiu = metal::uint2(42u, 43u); +constant metal::float2x2 xmfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiafaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiaiafai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 xmfpaiaiaiaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 imfpaiaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 imfpafaiaiai = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::float2x2 imfpafafafaf = metal::float2x2(metal::float2(1.0, 2.0), metal::float2(3.0, 4.0)); +constant metal::int2 ivispai = metal::int2(1); +constant metal::float2 ivfspaf = metal::float2(1.0); +constant metal::int2 ivis_ai = metal::int2(1); +constant metal::uint2 ivus_ai = metal::uint2(1u); +constant metal::float2 ivfs_ai = metal::float2(1.0); +constant metal::float2 ivfs_af = metal::float2(1.0); +constant type_5 iafafaf = type_5 {1.0, 2.0}; +constant type_5 iafaiai = type_5 {1.0, 2.0}; +constant type_5 iafpafaf = type_5 {1.0, 2.0}; +constant type_5 iafpaiaf = type_5 {1.0, 2.0}; +constant type_5 iafpafai = type_5 {1.0, 2.0}; +constant type_5 xafpafaf = type_5 {1.0, 2.0}; +constant S s_f_i_u = S {1.0, 1, 1u}; +constant S s_f_iai = S {1.0, 1, 1u}; +constant S s_fai_u = S {1.0, 1, 1u}; +constant S s_faiai = S {1.0, 1, 1u}; +constant S saf_i_u = S {1.0, 1, 1u}; +constant S saf_iai = S {1.0, 1, 1u}; +constant S safai_u = S {1.0, 1, 1u}; +constant S safaiai = S {1.0, 1, 1u}; +constant metal::float3 ivfr_f_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfr_f_af = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfraf_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfraf_af = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivf_fr_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_fraf = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_afr_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_afraf = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivfr_f_ai = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfrai_f = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivfrai_ai = metal::float3(metal::float2(1.0, 2.0), 3.0); +constant metal::float3 ivf_frai = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_air_f = metal::float3(1.0, metal::float2(2.0, 3.0)); +constant metal::float3 ivf_airai = metal::float3(1.0, metal::float2(2.0, 3.0)); diff --git a/naga/tests/out/spv/abstract-types.spvasm b/naga/tests/out/spv/abstract-types.spvasm new file mode 100644 index 0000000000..207a04f564 --- /dev/null +++ b/naga/tests/out/spv/abstract-types.spvasm @@ -0,0 +1,46 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 36 +OpCapability Shader +OpCapability Linkage +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpDecorate %10 ArrayStride 4 +OpMemberDecorate %12 0 Offset 0 +OpMemberDecorate %12 1 Offset 4 +OpMemberDecorate %12 2 Offset 8 +%2 = OpTypeVoid +%4 = OpTypeInt 32 0 +%3 = OpTypeVector %4 2 +%6 = OpTypeFloat 32 +%5 = OpTypeVector %6 2 +%7 = OpTypeMatrix %5 2 +%9 = OpTypeInt 32 1 +%8 = OpTypeVector %9 2 +%11 = OpConstant %4 2 +%10 = OpTypeArray %6 %11 +%12 = OpTypeStruct %6 %9 %4 +%13 = OpTypeVector %6 3 +%14 = OpConstant %4 42 +%15 = OpConstant %4 43 +%16 = OpConstantComposite %3 %14 %15 +%17 = OpConstant %6 44.0 +%18 = OpConstant %6 45.0 +%19 = OpConstantComposite %5 %17 %18 +%20 = OpConstant %6 1.0 +%21 = OpConstant %6 2.0 +%22 = OpConstantComposite %5 %20 %21 +%23 = OpConstant %6 3.0 +%24 = OpConstant %6 4.0 +%25 = OpConstantComposite %5 %23 %24 +%26 = OpConstantComposite %7 %22 %25 +%27 = OpConstant %9 1 +%28 = OpConstantComposite %8 %27 %27 +%29 = OpConstantComposite %5 %20 %20 +%30 = OpConstant %4 1 +%31 = OpConstantComposite %3 %30 %30 +%32 = OpConstantComposite %10 %20 %21 +%33 = OpConstantComposite %12 %20 %27 %30 +%34 = OpConstantComposite %13 %20 %21 %23 +%35 = OpConstantComposite %5 %21 %23 \ No newline at end of file diff --git a/naga/tests/out/spv/ray-query.spvasm b/naga/tests/out/spv/ray-query.spvasm index d96dbb315b..23d5dd1baa 100644 --- a/naga/tests/out/spv/ray-query.spvasm +++ b/naga/tests/out/spv/ray-query.spvasm @@ -66,10 +66,10 @@ OpMemberDecorate %18 0 Offset 0 %47 = OpConstantComposite %5 %27 %25 %27 %48 = OpConstant %4 4 %49 = OpConstant %4 255 -%50 = OpConstant %6 0.1 -%51 = OpConstant %6 100.0 -%52 = OpConstantComposite %5 %27 %27 %27 -%53 = OpConstantComposite %14 %48 %49 %50 %51 %52 %47 +%50 = OpConstantComposite %5 %27 %27 %27 +%51 = OpConstant %6 0.1 +%52 = OpConstant %6 100.0 +%53 = OpConstantComposite %14 %48 %49 %51 %52 %50 %47 %55 = OpTypePointer Function %13 %72 = OpConstant %4 1 %85 = OpTypePointer StorageBuffer %4 diff --git a/naga/tests/out/wgsl/abstract-types.wgsl b/naga/tests/out/wgsl/abstract-types.wgsl new file mode 100644 index 0000000000..538be44df7 --- /dev/null +++ b/naga/tests/out/wgsl/abstract-types.wgsl @@ -0,0 +1,55 @@ +struct S { + f: f32, + i: i32, + u: u32, +} + +const xvupaiai: vec2 = vec2(42u, 43u); +const xvfpaiai: vec2 = vec2(44.0, 45.0); +const xvupuai: vec2 = vec2(42u, 43u); +const xvupaiu: vec2 = vec2(42u, 43u); +const xvuuai: vec2 = vec2(42u, 43u); +const xvuaiu: vec2 = vec2(42u, 43u); +const xmfpaiaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpafaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiafaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiaiafai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const xmfpaiaiaiaf: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const imfpaiaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const imfpafaiaiai: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const imfpafafafaf: mat2x2 = mat2x2(vec2(1.0, 2.0), vec2(3.0, 4.0)); +const ivispai: vec2 = vec2(1); +const ivfspaf: vec2 = vec2(1.0); +const ivis_ai: vec2 = vec2(1); +const ivus_ai: vec2 = vec2(1u); +const ivfs_ai: vec2 = vec2(1.0); +const ivfs_af: vec2 = vec2(1.0); +const iafafaf: array = array(1.0, 2.0); +const iafaiai: array = array(1.0, 2.0); +const iafpafaf: array = array(1.0, 2.0); +const iafpaiaf: array = array(1.0, 2.0); +const iafpafai: array = array(1.0, 2.0); +const xafpafaf: array = array(1.0, 2.0); +const s_f_i_u: S = S(1.0, 1, 1u); +const s_f_iai: S = S(1.0, 1, 1u); +const s_fai_u: S = S(1.0, 1, 1u); +const s_faiai: S = S(1.0, 1, 1u); +const saf_i_u: S = S(1.0, 1, 1u); +const saf_iai: S = S(1.0, 1, 1u); +const safai_u: S = S(1.0, 1, 1u); +const safaiai: S = S(1.0, 1, 1u); +const ivfr_f_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfr_f_af: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfraf_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfraf_af: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivf_fr_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_fraf: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_afr_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_afraf: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivfr_f_ai: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfrai_f: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivfrai_ai: vec3 = vec3(vec2(1.0, 2.0), 3.0); +const ivf_frai: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_air_f: vec3 = vec3(1.0, vec2(2.0, 3.0)); +const ivf_airai: vec3 = vec3(1.0, vec2(2.0, 3.0)); + diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 4ad17f1a2a..35370e7f07 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -275,7 +275,13 @@ fn check_targets( let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) .validate(module) - .unwrap_or_else(|_| panic!("Naga module validation failed on test '{}'", name.display())); + .unwrap_or_else(|err| { + panic!( + "Naga module validation failed on test `{}`:\n{:?}", + name.display(), + err + ); + }); #[cfg(feature = "compact")] let info = { @@ -292,10 +298,11 @@ fn check_targets( naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) .validate(module) - .unwrap_or_else(|_| { + .unwrap_or_else(|err| { panic!( - "Post-compaction module validation failed on test '{}'", - name.display() + "Post-compaction module validation failed on test '{}':\n<{:?}", + name.display(), + err, ) }) }; @@ -783,6 +790,10 @@ fn convert_wgsl() { "f64", Targets::SPIRV | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "abstract-types", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index 99257457fb..56ca313464 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -209,11 +209,14 @@ fn constructor_parameter_type_mismatch() { _ = mat2x2(array(0, 1), vec2(2, 3)); } "#, - r#"error: invalid type for constructor component at index [0] - ┌─ wgsl:3:33 + r#"error: automatic conversions cannot convert `array<{AbstractInt}, 2>` to `vec2` + ┌─ wgsl:3:21 │ 3 │ _ = mat2x2(array(0, 1), vec2(2, 3)); - │ ^^^^^^^^^^^ invalid component type + │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + │ │ │ + │ │ this expression has type array<{AbstractInt}, 2> + │ a value of type vec2 is required here "#, ); @@ -502,7 +505,7 @@ fn let_type_mismatch() { r#" const x: i32 = 1.0; "#, - r#"error: the type of `x` is expected to be `i32`, but got `f32` + r#"error: the type of `x` is expected to be `i32`, but got `{AbstractFloat}` ┌─ wgsl:2:19 │ 2 │ const x: i32 = 1.0; @@ -829,6 +832,22 @@ fn matrix_with_bad_type() { ); } +#[test] +fn matrix_constructor_inferred() { + check( + r#" + const m: mat2x2 = mat2x2(vec2(0), vec2(1)); + "#, + r#"error: the type of `m` is expected to be `mat2x2`, but got `mat2x2` + ┌─ wgsl:2:19 + │ +2 │ const m: mat2x2 = mat2x2(vec2(0), vec2(1)); + │ ^ definition of `m` + +"#, + ); +} + /// Check the result of validating a WGSL program against a pattern. /// /// Unless you are generating code programmatically, the @@ -1996,9 +2015,12 @@ fn binding_array_non_struct() { #[test] fn compaction_preserves_spans() { let source = r#" - const a: i32 = -(-(-(-42i))); - const b: vec2 = vec2(42u, 43i); - "#; // ^^^^^^^^^^^^^^^^^^^ correct error span: 68..87 + fn f() { + var a: i32 = -(-(-(-42i))); + var x: i32; + x = 42u; + } + "#; // ^^^ correct error span: 95..98 let mut module = naga::front::wgsl::parse_str(source).expect("source ought to parse"); naga::compact::compact(&mut module); let err = naga::valid::Validator::new( @@ -2011,10 +2033,18 @@ fn compaction_preserves_spans() { // Ideally this would all just be a `matches!` with a big pattern, // but the `Span` API is full of opaque structs. let mut spans = err.spans(); - let first_span = spans.next().expect("error should have at least one span").0; + + // The first span is the whole function. + let _ = spans.next().expect("error should have at least one span"); + + // The second span is the assignment destination. + let dest_span = spans + .next() + .expect("error should have at least two spans") + .0; if !matches!( - first_span.to_range(), - Some(std::ops::Range { start: 68, end: 87 }) + dest_span.to_range(), + Some(std::ops::Range { start: 95, end: 98 }) ) { panic!("Error message has wrong span:\n\n{err:#?}"); } diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index b523bf9a73..caa23592c7 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -560,7 +560,9 @@ impl Resource { } naga::ScalarKind::Sint => wgt::TextureSampleType::Sint, naga::ScalarKind::Uint => wgt::TextureSampleType::Uint, - naga::ScalarKind::Bool => unreachable!(), + naga::ScalarKind::AbstractInt + | naga::ScalarKind::AbstractFloat + | naga::ScalarKind::Bool => unreachable!(), }, view_dimension, multisampled: multi,