diff --git a/src/valid/function.rs b/src/valid/function.rs index 151952750e..b928ff5172 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -86,6 +86,8 @@ pub enum FunctionError { }, #[error("Argument '{name}' at index {index} has a type that can't be passed into functions.")] InvalidArgumentType { index: usize, name: String }, + #[error("The function's given return type cannot be returned from functions")] + NonConstructibleReturnType, #[error("Argument '{name}' at index {index} is a pointer of space {space:?}, which can't be passed into functions.")] InvalidArgumentPointerSpace { index: usize, @@ -894,6 +896,17 @@ impl super::Validator { } } + #[cfg(feature = "validate")] + if let Some(ref result) = fun.result { + if !self.types[result.ty.index()] + .flags + .contains(super::TypeFlags::CONSTRUCTIBLE) + { + return Err(FunctionError::NonConstructibleReturnType + .with_span_handle(result.ty, &module.types)); + } + } + self.valid_expression_set.clear(); self.valid_expression_list.clear(); for (handle, expr) in fun.expressions.iter() { diff --git a/src/valid/type.rs b/src/valid/type.rs index 71af43e99c..0312c21432 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -52,6 +52,15 @@ bitflags::bitflags! { /// This type can be passed as a function argument. const ARGUMENT = 0x40; + + /// A WGSL [constructible] type. + /// + /// The constructible types are scalars, vectors, matrices, fixed-size + /// arrays of constructible types, and structs whose members are all + /// constructible. + /// + /// [constructible]: https://gpuweb.github.io/gpuweb/wgsl/#constructible + const CONSTRUCTIBLE = 0x80; } } @@ -237,6 +246,7 @@ impl super::Validator { | TypeFlags::SIZED | TypeFlags::COPY | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE | shareable, width as u32, ) @@ -257,6 +267,7 @@ impl super::Validator { | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE | shareable, count * (width as u32), ) @@ -275,7 +286,8 @@ impl super::Validator { | TypeFlags::SIZED | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE - | TypeFlags::ARGUMENT, + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE, count * (width as u32), ) } @@ -467,7 +479,7 @@ impl super::Validator { return Err(TypeError::NonPositiveArrayLength(const_handle)); } - TypeFlags::SIZED | TypeFlags::ARGUMENT + TypeFlags::SIZED | TypeFlags::ARGUMENT | TypeFlags::CONSTRUCTIBLE } crate::ArraySize::Dynamic => { // Non-SIZED types may only appear as the last element of a structure. @@ -495,7 +507,8 @@ impl super::Validator { | TypeFlags::COPY | TypeFlags::HOST_SHAREABLE | TypeFlags::IO_SHAREABLE - | TypeFlags::ARGUMENT, + | TypeFlags::ARGUMENT + | TypeFlags::CONSTRUCTIBLE, 1, ); ti.uniform_layout = Ok(Some(UNIFORM_MIN_ALIGNMENT)); diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index db7da613c7..74f9b621f0 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -1029,6 +1029,43 @@ fn invalid_functions() { }) if function_name == "unacceptable_ptr_space" && argument_name == "arg" } + + check_validation! { + " + struct AFloat { + said_float: f32 + }; + @group(0) @binding(0) + var float: AFloat; + + fn return_pointer() -> ptr { + return &float.said_float; + } + ": + Err(naga::valid::ValidationError::Function { + name: function_name, + error: naga::valid::FunctionError::NonConstructibleReturnType, + .. + }) + if function_name == "return_pointer" + } + + check_validation! { + " + @group(0) @binding(0) + var atom: atomic; + + fn return_atomic() -> atomic { + return atom; + } + ": + Err(naga::valid::ValidationError::Function { + name: function_name, + error: naga::valid::FunctionError::NonConstructibleReturnType, + .. + }) + if function_name == "return_atomic" + } } #[test]