From 316ec449593c49e826b14b0f6f2ed6ef8569e23b Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Fri, 27 Dec 2024 10:04:23 -0800 Subject: [PATCH] compiles --- naga/src/proc/builtins.rs | 433 ++++++++++++++++++++++ naga/src/proc/mod.rs | 2 + naga/src/valid/expression.rs | 675 ++--------------------------------- 3 files changed, 465 insertions(+), 645 deletions(-) create mode 100644 naga/src/proc/builtins.rs diff --git a/naga/src/proc/builtins.rs b/naga/src/proc/builtins.rs new file mode 100644 index 0000000000..798371b5df --- /dev/null +++ b/naga/src/proc/builtins.rs @@ -0,0 +1,433 @@ +//! A database of built-in functions, with argument and return types. + +/// TODO: What about overloads? +pub struct Builtin { + pub arguments: &'static [Argument], + pub return_type: ReturnType, +} + +#[derive(Copy, Clone, Debug)] +pub enum Argument { + /// The set of types permitted for this argument. + Set(TypeSet), + + /// This argument must have the same type as the prior argument at the given index. + Match(usize), +} + +impl Argument { + pub fn constrain(&self, i: usize, actual: &crate::TypeInner, typesets: &mut [TypeSet]) { + let convertible_from = TypeSet::convertible_from(actual); + match *self { + Argument::Set(set) => { + typesets[i] = set.intersection(convertible_from); + } + Argument::Match(j) => { + typesets[j] = typesets[j].intersection(convertible_from); + typesets[i] = typesets[j]; + } + } + } +} + +#[derive(Copy, Clone, Debug)] +pub enum ReturnType { + /// The same type as the argument at the given index, + Match(usize), + + /// The scalar element type of the argument at the given index. + MatchScalar(usize), +} + +/// A set of types, for describing builtin function arguments. +/// +/// This represents the set of types formed by applying every constructor in +/// `constructors` to every scalar in `scalars`. +/// +/// This representation can't implement `union` on itself, which is +/// too bad, but it can implement `intersection`, which is what we +/// actually need for progressively constraining types as we check a +/// given call. +#[derive(Copy, Clone, Debug)] +pub struct TypeSet { + /// The set of type constructors to apply. + constructors: ConstructorSet, + + /// The set of scalars to apply them to. + scalars: ScalarSet, +} + +impl TypeSet { + pub fn intersection(self, right: Self) -> Self { + TypeSet { + constructors: self.constructors & right.constructors, + scalars: self.scalars & right.scalars, + } + } + + pub const fn empty() -> Self { + TypeSet { + constructors: ConstructorSet::empty(), + scalars: ScalarSet::empty(), + } + } + + pub fn contains(self, right: Self) -> bool { + self.constructors.contains(right.constructors) + && self.scalars.contains(right.scalars) + } + + /// Return the type set containing only `inner`. + pub fn singleton(inner: &crate::TypeInner) -> Self { + let Some(scalar) = inner.scalar() else { + return Self::empty(); + }; + + TypeSet { + constructors: ConstructorSet::singleton(inner), + scalars: ScalarSet::singleton(scalar), + } + } + + /// Return the set of types that `inner` can be automatically converted to. + pub fn convertible_from(inner: &crate::TypeInner) -> Self { + let Some(scalar) = inner.scalar() else { + return Self::empty(); + }; + + TypeSet { + constructors: ConstructorSet::singleton(inner), + scalars: ScalarSet::convertible_from(scalar), + } + } +} + +bitflags::bitflags! { + /// A set of scalar types. + /// + /// This represents a set of [`Scalar`] types. + /// + /// [`Scalar`]: crate::Scalar + #[derive(Copy, Clone, Debug)] + struct ScalarSet: u16 { + const ABSTRACT_FLOAT = 1 << 0; + const ABSTRACT_INT = 1 << 1; + const SINT_32 = 1 << 2; + const SINT_64 = 1 << 3; + const UINT_32 = 1 << 4; + const UINT_64 = 1 << 5; + const FLOAT_16 = 1 << 6; + const FLOAT_32 = 1 << 7; + const FLOAT_64 = 1 << 8; + const BOOL = 1 << 9; + + const NUMERIC = Self::ABSTRACT_FLOAT.bits() + | Self::ABSTRACT_INT.bits() + | Self::SINT_32.bits() + | Self::SINT_64.bits() + | Self::UINT_32.bits() + | Self::UINT_64.bits() + | Self::FLOAT_16.bits() + | Self::FLOAT_32.bits() + | Self::FLOAT_64.bits() + ; + + const ANY_FLOAT = Self::ABSTRACT_FLOAT.bits() + | Self::FLOAT_16.bits() + | Self::FLOAT_32.bits() + | Self::FLOAT_64.bits() + ; + + const ANY_INTEGER = Self::ABSTRACT_INT.bits() + | Self::SINT_32.bits() + | Self::SINT_64.bits() + | Self::UINT_32.bits() + | Self::UINT_64.bits() + ; + + const CONCRETE_INTEGER = Self::SINT_32.bits() + | Self::SINT_64.bits() + | Self::UINT_32.bits() + | Self::UINT_64.bits() + ; + } + + /// A set of type constructors. + #[derive(Copy, Clone, Debug)] + struct ConstructorSet: u16 { + const SCALAR = 1 << 0; + const VEC2 = 1 << 1; + const VEC3 = 1 << 2; + const VEC4 = 1 << 3; + const MAT2X2 = 1 << 4; + const MAT2X3 = 1 << 5; + const MAT2X4 = 1 << 6; + const MAT3X2 = 1 << 7; + const MAT3X3 = 1 << 8; + const MAT3X4 = 1 << 9; + const MAT4X2 = 1 << 10; + const MAT4X3 = 1 << 11; + const MAT4X4 = 1 << 12; + + const VECN = Self::VEC2.bits() + | Self::VEC3.bits() + | Self::VEC4.bits(); + } +} + +impl ScalarSet { + /// Return the set of scalars containing only `scalar`. + fn singleton(scalar: crate::Scalar) -> Self { + use crate::Scalar as Sc; + use crate::ScalarKind as Sk; + match scalar { + Sc { kind : Sk::Sint, width: 4 } => Self::SINT_32, + Sc { kind : Sk::Sint, width: 8 } => Self::SINT_64, + Sc { kind : Sk::Uint, width: 4 } => Self::UINT_32, + Sc { kind : Sk::Uint, width: 8 } => Self::UINT_64, + Sc { kind : Sk::Float, width: 2 } => Self::FLOAT_16, + Sc { kind : Sk::Float, width: 4 } => Self::FLOAT_32, + Sc { kind : Sk::Float, width: 8 } => Self::FLOAT_64, + Sc::BOOL => Self::BOOL, + Sc::ABSTRACT_INT => Self::ABSTRACT_INT, + Sc::ABSTRACT_FLOAT => Self::ABSTRACT_FLOAT, + _ => Self::empty(), + } + } + + /// Return the set of scalars to which `scalar` can be automatically + /// converted. + fn convertible_from(scalar: crate::Scalar) -> Self { + use crate::Scalar as Sc; + use crate::ScalarKind as Sk; + match scalar { + Sc { kind : Sk::Sint, width: 4 } => Self::SINT_32, + Sc { kind : Sk::Sint, width: 8 } => Self::SINT_64, + Sc { kind : Sk::Uint, width: 4 } => Self::UINT_32, + Sc { kind : Sk::Uint, width: 8 } => Self::UINT_64, + Sc { kind : Sk::Float, width: 2 } => Self::FLOAT_16, + Sc { kind : Sk::Float, width: 4 } => Self::FLOAT_32, + Sc { kind : Sk::Float, width: 8 } => Self::FLOAT_64, + Sc::BOOL => Self::BOOL, + Sc::ABSTRACT_INT => Self::ANY_INTEGER | Self::ANY_FLOAT, + Sc::ABSTRACT_FLOAT => Self::ANY_FLOAT, + _ => Self::empty(), + } + } +} + +impl ConstructorSet { + /// Return the single-member set containing `inner`'s constructor. + fn singleton(inner: &crate::TypeInner) -> ConstructorSet { + use crate::TypeInner as Ti; + use crate::VectorSize as Vs; + match *inner { + Ti::Scalar(_) => Self::SCALAR, + Ti::Vector { size: Vs::Bi, scalar: _ } => Self::VEC2, + Ti::Vector { size: Vs::Tri, scalar: _ } => Self::VEC3, + Ti::Vector { size: Vs::Quad, scalar: _ } => Self::VEC4, + Ti::Matrix { columns: Vs::Bi, rows: Vs::Bi, scalar: _ } => Self::MAT2X2, + Ti::Matrix { columns: Vs::Bi, rows: Vs::Tri, scalar: _ } => Self::MAT2X3, + Ti::Matrix { columns: Vs::Bi, rows: Vs::Quad, scalar: _ } => Self::MAT2X4, + Ti::Matrix { columns: Vs::Tri, rows: Vs::Bi, scalar: _ } => Self::MAT3X2, + Ti::Matrix { columns: Vs::Tri, rows: Vs::Tri, scalar: _ } => Self::MAT3X3, + Ti::Matrix { columns: Vs::Tri, rows: Vs::Quad, scalar: _ } => Self::MAT3X4, + Ti::Matrix { columns: Vs::Quad, rows: Vs::Bi, scalar: _ } => Self::MAT4X2, + Ti::Matrix { columns: Vs::Quad, rows: Vs::Tri, scalar: _ } => Self::MAT4X3, + Ti::Matrix { columns: Vs::Quad, rows: Vs::Quad, scalar: _ } => Self::MAT4X4, + _ => Self::empty(), + } + } +} + +/// Nicer notation for [`Builtin`] values. +macro_rules! b { + ( ( $( $arg:tt ),* ) -> $ret:tt ) => { + { + // I don't know why it's necessary to declare the static. + // It seems like Rust should figure this out for me. + static B: &'static Builtin = &Builtin { + arguments: &[ $( argument! ( $arg ) ),* ], + return_type: return_type! ( $ret ), + }; + B + } + }; +} + +macro_rules! argument { + ( arg0 ) => { Argument::Match(0) }; + ( $ts:tt ) => { Argument::Set( typeset! ( $ts ) ) }; +} + +/// Nicer notation for [`TypeSet`] values. +macro_rules! typeset { + ( ( vecN or scalar < $scalar:ident > ) ) => ( + TypeSet { + constructors: ConstructorSet::VECN.union(ConstructorSet::SCALAR), + scalars: s!($scalar), + } + ); + ( ( vecN < $scalar:ident > ) ) => ( + TypeSet { + constructors: ConstructorSet::VECN, + scalars: s!($scalar), + } + ); + ( $scalar:ident ) => ( + TypeSet { + constructors: ConstructorSet::SCALAR, + scalars: s!($scalar), + } + ); +} + +/// Nicer notation for [`ScalarSet`] values. +macro_rules! s { + ( numeric ) => ( ScalarSet::NUMERIC ); + ( float ) => ( ScalarSet::ANY_FLOAT ); +} + +/// Nicer notation for [`ReturnType`] values. +macro_rules! return_type { + ( scalar arg0 ) => { ReturnType::MatchScalar(0) }; + ( arg0 ) => { ReturnType::Match(0) }; +} + +impl crate::MathFunction { + #[allow(clippy::todo)] + pub fn argument_info(self) -> &'static Builtin { + use crate::MathFunction as Mf; + match self { + // Functions of one numeric argument, extended element-wise to + // vectors. + Mf::Abs => b!(((vecN or scalar) ) -> arg0), + + // Functions of two numeric arguments, extended element-wise to + // vectors. + Mf::Max | Mf::Min => b!(((vecN or scalar), arg0) -> arg0), + + // Functions of one floating-point argument, extended + // element-wise to vectors. + Mf::Acos + | Mf::Acosh + | Mf::Asin + | Mf::Asinh + | Mf::Atan + | Mf::Atanh + | Mf::Ceil + | Mf::Cos + | Mf::Cosh + | Mf::Degrees + | Mf::Exp + | Mf::Exp2 + | Mf::Floor + | Mf::Fract + | Mf::InverseSqrt + | Mf::Length + | Mf::Log + | Mf::Log2 + | Mf::Radians + | Mf::Round + | Mf::Saturate + | Mf::Sin + | Mf::Sinh + | Mf::Sqrt + | Mf::Tan + | Mf::Tanh + | Mf::Trunc => b!(((vecN or scalar)) -> arg0), + + // Functions of two floating-point arguments, extended element-wise + // to vectors. + Mf::Atan2 => b!(((vecN or scalar), arg0) -> arg0), + +/* + // Functions of three floating-point arguments, extended + // element-wise to vectors. + Mf::Clamp => &Builtin { + arguments: &[ + TypeSet::ScalarOrVectorN(ScalarSet::AnyFloat), + TypeSet::Match(0), + TypeSet::Match(0), + ], + return_type: ReturnType::Match(0), + }, + + // Bitwise functions of one 32-bit integer argument, extended + // element-wise to vectors. + Mf::CountLeadingZeros + | Mf::CountOneBits + | Mf::CountTrailingZeros + | Mf::ReverseBits + | Mf::FirstLeadingBit + | Mf::FirstTrailingBit => &Builtin { + arguments: &[TypeSet::ScalarOrVectorN(ScalarSet::ConcreteInteger(4))], + return_type: ReturnType::Match(0), + }, + Mf::Cross => &Builtin { + arguments: &[ + TypeSet::Vector { + size: crate::VectorSize::Tri, + scalar: ScalarSet::AnyFloat, + }, + TypeSet::Match(0), + ], + return_type: ReturnType::Match(0), + }, + Mf::Determinant => &Builtin { + arguments: &[TypeSet::Matrix(ScalarSet::AnyFloat)], + return_type: ReturnType::MatchScalar(0), + }, + Mf::Distance => &Builtin { + arguments: &[ + TypeSet::ScalarOrVectorN(ScalarSet::AnyFloat), + TypeSet::Match(0), + ], + return_type: ReturnType::MatchScalar(0), + }, + Mf::Dot => todo!(), + Mf::ExtractBits => todo!(), + Mf::FaceForward => todo!(), + Mf::Fma => todo!(), + Mf::Frexp => todo!(), + Mf::InsertBits => todo!(), + Mf::Inverse => todo!(), + Mf::Ldexp => todo!(), + Mf::Mix => todo!(), + Mf::Modf => todo!(), + Mf::Normalize => todo!(), + Mf::Outer => todo!(), + Mf::Pack2x16float => todo!(), + Mf::Pack2x16snorm => todo!(), + Mf::Pack2x16unorm => todo!(), + Mf::Pack4x8snorm => todo!(), + Mf::Pack4x8unorm => todo!(), + Mf::Pack4xI8 => todo!(), + Mf::Pack4xU8 => todo!(), + Mf::Pow => todo!(), + Mf::QuantizeToF16 => todo!(), + Mf::Radians => todo!(), + Mf::Reflect => todo!(), + Mf::Refract => todo!(), + Mf::ReverseBits => todo!(), + Mf::Round => todo!(), + Mf::Sign => todo!(), + Mf::SmoothStep => todo!(), + Mf::Sqrt => todo!(), + Mf::Step => todo!(), + Mf::Transpose => todo!(), + Mf::Trunc => todo!(), + Mf::Unpack2x16float => todo!(), + Mf::Unpack2x16snorm => todo!(), + Mf::Unpack2x16unorm => todo!(), + Mf::Unpack4x8snorm => todo!(), + Mf::Unpack4x8unorm => todo!(), + Mf::Unpack4xI8 => todo!(), + Mf::Unpack4xU8 => todo!(), + */ + _ => todo!(), + } + } +} + + diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 76698fd102..f5324ccf32 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -2,6 +2,7 @@ [`Module`](super::Module) processing functionality. */ +mod builtins; mod constant_evaluator; mod emitter; pub mod index; @@ -10,6 +11,7 @@ mod namer; mod terminator; mod typifier; +pub use builtins::{Argument, Builtin, ReturnType, TypeSet}; pub use constant_evaluator::{ ConstantEvaluator, ConstantEvaluatorError, ExpressionKind, ExpressionKindTracker, }; diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 5f3c0a819c..f30e09e39b 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -997,657 +997,42 @@ impl super::Validator { } => { use crate::MathFunction as Mf; + let info = fun.argument_info(); + let actuals: &[_] = match (info.arguments.len(), arg1, arg2, arg3) { + (1, None, None, None) => &[arg], + (2, Some(arg1), None, None) => &[arg, arg1], + (3, Some(arg1), Some(arg2), None) => &[arg, arg1, arg2], + (4, Some(arg1), Some(arg2), Some(arg3)) => &[arg, arg1, arg2, arg3], + _ => return Err(ExpressionError::WrongArgumentCount(fun)), + }; + let resolve = |arg| &resolver[arg]; - let arg_ty = resolve(arg); - let arg1_ty = arg1.map(resolve); - let arg2_ty = arg2.map(resolve); - let arg3_ty = arg3.map(resolve); - match fun { - Mf::Abs => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - let good = match *arg_ty { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { - scalar.kind != Sk::Bool - } - _ => false, - }; - if !good { - return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); - } - } - Mf::Min | Mf::Max => { - let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), None, None) => ty1, - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - let good = match *arg_ty { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { - scalar.kind != Sk::Bool - } - _ => false, - }; - if !good { - return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - } - Mf::Clamp => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), Some(ty2), None) => (ty1, ty2), - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - let good = match *arg_ty { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { - scalar.kind != Sk::Bool - } - _ => false, - }; - if !good { - return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - if arg2_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg2.unwrap(), - )); - } - } - Mf::Saturate - | Mf::Cos - | Mf::Cosh - | Mf::Sin - | Mf::Sinh - | Mf::Tan - | Mf::Tanh - | Mf::Acos - | Mf::Asin - | Mf::Atan - | Mf::Asinh - | Mf::Acosh - | Mf::Atanh - | Mf::Radians - | Mf::Degrees - | Mf::Ceil - | Mf::Floor - | Mf::Round - | Mf::Fract - | Mf::Trunc - | Mf::Exp - | Mf::Exp2 - | Mf::Log - | Mf::Log2 - | Mf::Length - | Mf::Sqrt - | Mf::InverseSqrt => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } - if scalar.kind == Sk::Float => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } + let actual_types: &[_] = match actuals { + &[arg0] => &[resolve(arg0)], + &[arg0, arg1] => &[resolve(arg0), resolve(arg1)], + &[arg0, arg1, arg2] => &[resolve(arg0), resolve(arg1), resolve(arg2)], + &[arg0, arg1, arg2, arg3] => { + &[resolve(arg0), resolve(arg1), resolve(arg2), resolve(arg3)] } - Mf::Sign => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Float | Sk::Sint, - .. - }) - | Ti::Vector { - scalar: - Sc { - kind: Sk::Float | Sk::Sint, - .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - Mf::Atan2 | Mf::Pow | Mf::Distance | Mf::Step => { - let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), None, None) => ty1, - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - match *arg_ty { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } - if scalar.kind == Sk::Float => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - } - Mf::Modf | Mf::Frexp => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - if !matches!(*arg_ty, - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } - if scalar.kind == Sk::Float) - { - return Err(ExpressionError::InvalidArgumentType(fun, 1, arg)); - } - } - Mf::Ldexp => { - let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), None, None) => ty1, - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - let size0 = match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Float, .. - }) => None, - Ti::Vector { - scalar: - Sc { - kind: Sk::Float, .. - }, - size, - } => Some(size), - _ => { - return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); - } - }; - let good = match *arg1_ty { - Ti::Scalar(Sc { kind: Sk::Sint, .. }) if size0.is_none() => true, - Ti::Vector { - size, - scalar: Sc { kind: Sk::Sint, .. }, - } if Some(size) == size0 => true, - _ => false, - }; - if !good { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - } - Mf::Dot => { - let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), None, None) => ty1, - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - match *arg_ty { - Ti::Vector { - scalar: - Sc { - kind: Sk::Float | Sk::Sint | Sk::Uint, - .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - } - Mf::Outer | Mf::Reflect => { - let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), None, None) => ty1, - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - match *arg_ty { - Ti::Vector { - scalar: - Sc { - kind: Sk::Float, .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - } - Mf::Cross => { - let arg1_ty = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), None, None) => ty1, - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - match *arg_ty { - Ti::Vector { - scalar: - Sc { - kind: Sk::Float, .. - }, - size: crate::VectorSize::Tri, - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - } - Mf::Refract => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), Some(ty2), None) => (ty1, ty2), - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - - match *arg_ty { - Ti::Vector { - scalar: - Sc { - kind: Sk::Float, .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } + _ => unreachable!(), + }; - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } + // Look at all the arguments and determine what set of + // types would be permitted for each one. + let mut arg_typesets = &[crate::proc::TypeSet::empty(); 4][.. info.arguments.len()]; + for (i, (formal, &actual)) in info.arguments.iter().zip(actual_types).enumerate() { + formal.constrain(i, actual.inner_with(&module.types), &mut arg_typesets); + } - match (arg_ty, arg2_ty) { - ( - &Ti::Vector { - scalar: - Sc { - width: vector_width, - .. - }, - .. - }, - &Ti::Scalar(Sc { - width: scalar_width, - kind: Sk::Float, - }), - ) if vector_width == scalar_width => {} - _ => { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg2.unwrap(), - )) - } - } - } - Mf::Normalize => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Vector { - scalar: - Sc { - kind: Sk::Float, .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - Mf::FaceForward | Mf::Fma | Mf::SmoothStep => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), Some(ty2), None) => (ty1, ty2), - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Float, .. - }) - | Ti::Vector { - scalar: - Sc { - kind: Sk::Float, .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - if arg2_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg2.unwrap(), - )); - } - } - Mf::Mix => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), Some(ty2), None) => (ty1, ty2), - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - let arg_width = match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Float, - width, - }) - | Ti::Vector { - scalar: - Sc { - kind: Sk::Float, - width, - }, - .. - } => width, - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - }; - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - // the last argument can always be a scalar - match *arg2_ty { - Ti::Scalar(Sc { - kind: Sk::Float, - width, - }) if width == arg_width => {} - _ if arg2_ty == arg_ty => {} - _ => { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg2.unwrap(), - )); - } - } - } - Mf::Inverse | Mf::Determinant => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - let good = match *arg_ty { - Ti::Matrix { columns, rows, .. } => columns == rows, - _ => false, - }; - if !good { - return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)); - } - } - Mf::Transpose => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Matrix { .. } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - Mf::QuantizeToF16 => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Float, - width: 4, - }) - | Ti::Vector { - scalar: - Sc { - kind: Sk::Float, - width: 4, - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276 - Mf::CountLeadingZeros - | Mf::CountTrailingZeros - | Mf::CountOneBits - | Mf::ReverseBits - | Mf::FirstLeadingBit - | Mf::FirstTrailingBit => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Sint | Sk::Uint => { - if scalar.width != 4 { - return Err(ExpressionError::UnsupportedWidth( - fun, - scalar.kind, - scalar.width, - )); - } - } - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - }, - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - Mf::InsertBits => { - let (arg1_ty, arg2_ty, arg3_ty) = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), Some(ty2), Some(ty3)) => (ty1, ty2, ty3), - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Sint | Sk::Uint, - .. - }) - | Ti::Vector { - scalar: - Sc { - kind: Sk::Sint | Sk::Uint, - .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - if arg1_ty != arg_ty { - return Err(ExpressionError::InvalidArgumentType( - fun, - 1, - arg1.unwrap(), - )); - } - match *arg2_ty { - Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} - _ => { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg2.unwrap(), - )) - } - } - match *arg3_ty { - Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} - _ => { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg3.unwrap(), - )) - } - } - // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276 - for &arg in [arg_ty, arg1_ty, arg2_ty, arg3_ty].iter() { - match *arg { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { - if scalar.width != 4 { - return Err(ExpressionError::UnsupportedWidth( - fun, - scalar.kind, - scalar.width, - )); - } - } - _ => {} - } - } - } - Mf::ExtractBits => { - let (arg1_ty, arg2_ty) = match (arg1_ty, arg2_ty, arg3_ty) { - (Some(ty1), Some(ty2), None) => (ty1, ty2), - _ => return Err(ExpressionError::WrongArgumentCount(fun)), - }; - match *arg_ty { - Ti::Scalar(Sc { - kind: Sk::Sint | Sk::Uint, - .. - }) - | Ti::Vector { - scalar: - Sc { - kind: Sk::Sint | Sk::Uint, - .. - }, - .. - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - match *arg1_ty { - Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} - _ => { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg1.unwrap(), - )) - } - } - match *arg2_ty { - Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} - _ => { - return Err(ExpressionError::InvalidArgumentType( - fun, - 2, - arg2.unwrap(), - )) - } - } - // Remove once fixed https://github.com/gfx-rs/wgpu/issues/5276 - for &arg in [arg_ty, arg1_ty, arg2_ty].iter() { - match *arg { - Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => { - if scalar.width != 4 { - return Err(ExpressionError::UnsupportedWidth( - fun, - scalar.kind, - scalar.width, - )); - } - } - _ => {} - } - } - } - Mf::Pack2x16unorm | Mf::Pack2x16snorm | Mf::Pack2x16float => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Vector { - size: crate::VectorSize::Bi, - scalar: - Sc { - kind: Sk::Float, .. - }, - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - Mf::Pack4x8snorm | Mf::Pack4x8unorm => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Vector { - size: crate::VectorSize::Quad, - scalar: - Sc { - kind: Sk::Float, .. - }, - } => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - mf @ (Mf::Pack4xI8 | Mf::Pack4xU8) => { - let scalar_kind = match mf { - Mf::Pack4xI8 => Sk::Sint, - Mf::Pack4xU8 => Sk::Uint, - _ => unreachable!(), - }; - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Vector { - size: crate::VectorSize::Quad, - scalar: Sc { kind, .. }, - } if kind == scalar_kind => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } - } - Mf::Unpack2x16float - | Mf::Unpack2x16snorm - | Mf::Unpack2x16unorm - | Mf::Unpack4x8snorm - | Mf::Unpack4x8unorm - | Mf::Unpack4xI8 - | Mf::Unpack4xU8 => { - if arg1_ty.is_some() || arg2_ty.is_some() || arg3_ty.is_some() { - return Err(ExpressionError::WrongArgumentCount(fun)); - } - match *arg_ty { - Ti::Scalar(Sc { kind: Sk::Uint, .. }) => {} - _ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)), - } + for (i, (allowed, actual)) in arg_typesets.iter().zip(actual_types).enumerate() { + let actual = crate::proc::TypeSet::singleton(actual.inner_with(&module.types)); + if !allowed.contains(actual) { + return Err(ExpressionError::InvalidArgumentType( + fun, i as u32, actuals[i], + )); } } + ShaderStages::all() } E::As {