From dd06d2752743c7e5ceb35f6a82601d6bc7bd1f8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Sat, 28 May 2022 00:08:02 +0100 Subject: [PATCH] glsl-in: Fix matrix multiplication check The previous check compared rows to rows and columns to columns but multiplication of matrices only needs the columns of the left matrix to be equal to the rows of the right matrix. --- src/front/glsl/context.rs | 8 ++++++-- tests/in/glsl/expressions.frag | 4 ++++ tests/out/wgsl/expressions-frag.wgsl | 13 +++++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index fe20df4101..7a2fcc0747 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -617,9 +617,13 @@ impl Context { width: right_width, }, ) => { + let additive_check = + left_columns != right_columns || left_rows != right_rows; + let multiplicative_check = left_columns != right_rows; + // Check that the two arguments have the same dimensions - if left_columns != right_columns - || left_rows != right_rows + if (multiplicative_check && op == BinaryOperator::Multiply) + || (additive_check && op != BinaryOperator::Multiply) || left_width != right_width { parser.errors.push(Error { diff --git a/tests/in/glsl/expressions.frag b/tests/in/glsl/expressions.frag index 8dd07c6525..acf0ea9213 100644 --- a/tests/in/glsl/expressions.frag +++ b/tests/in/glsl/expressions.frag @@ -128,6 +128,10 @@ void ternary(bool a) { uint nested = a ? (a ? (a ? 2u : 3) : 4u) : 5; } +void testMatrixMultiplication(mat4x3 a, mat4x4 b) { + mat4x3 c = a * b; +} + out vec4 o_color; void main() { privatePointer(global); diff --git a/tests/out/wgsl/expressions-frag.wgsl b/tests/out/wgsl/expressions-frag.wgsl index b4364e597c..7cf1eace0e 100644 --- a/tests/out/wgsl/expressions-frag.wgsl +++ b/tests/out/wgsl/expressions-frag.wgsl @@ -356,6 +356,19 @@ fn ternary(a_20: bool) { return; } +fn testMatrixMultiplication(a_22: mat4x3, b_18: mat4x4) { + var a_23: mat4x3; + var b_19: mat4x4; + var c_2: mat4x3; + + a_23 = a_22; + b_19 = b_18; + let _e5 = a_23; + let _e6 = b_19; + c_2 = (_e5 * _e6); + return; +} + fn main_1() { var local_5: f32;