From 205ae7d129a3d707199b7411b45fb945d1fefaee Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 7 Jul 2022 10:37:40 +0200 Subject: [PATCH] Complete size checks in `BLAS.[sy/he]mm!` (#45605) (cherry picked from commit da13d78f9f689e7d761e3c149462c0a2b0dad54f) --- stdlib/LinearAlgebra/src/blas.jl | 53 ++++++++++++++++++++------ stdlib/LinearAlgebra/test/blas.jl | 8 ++++ stdlib/LinearAlgebra/test/symmetric.jl | 6 +++ 3 files changed, 56 insertions(+), 11 deletions(-) diff --git a/stdlib/LinearAlgebra/src/blas.jl b/stdlib/LinearAlgebra/src/blas.jl index 272f0f57bdb29d..69bfb08b9430c0 100644 --- a/stdlib/LinearAlgebra/src/blas.jl +++ b/stdlib/LinearAlgebra/src/blas.jl @@ -1,5 +1,4 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license - """ Interface to BLAS subroutines. """ @@ -1509,11 +1508,27 @@ for (mfname, elty) in ((:dsymm_,:Float64), require_one_based_indexing(A, B, C) m, n = size(C) j = checksquare(A) - if j != (side == 'L' ? m : n) - throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)")) - end - if size(B,2) != n - throw(DimensionMismatch("B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) + M, N = size(B) + if side == 'L' + if j != m + throw(DimensionMismatch("A has first dimension $j but needs to match first dimension of C, $m")) + end + if N != n + throw(DimensionMismatch("B has second dimension $N but needs to match second dimension of C, $n")) + end + if j != M + throw(DimensionMismatch("A has second dimension $j but needs to match first dimension of B, $M")) + end + else + if j != n + throw(DimensionMismatch("B has second dimension $j but needs to match second dimension of C, $n")) + end + if N != j + throw(DimensionMismatch("A has second dimension $N but needs to match first dimension of B, $j")) + end + if M != m + throw(DimensionMismatch("A has first dimension $M but needs to match first dimension of C, $m")) + end end chkstride1(A) chkstride1(B) @@ -1582,11 +1597,27 @@ for (mfname, elty) in ((:zhemm_,:ComplexF64), require_one_based_indexing(A, B, C) m, n = size(C) j = checksquare(A) - if j != (side == 'L' ? m : n) - throw(DimensionMismatch("A has size $(size(A)), C has size ($m,$n)")) - end - if size(B,2) != n - throw(DimensionMismatch("B has second dimension $(size(B,2)) but needs to match second dimension of C, $n")) + M, N = size(B) + if side == 'L' + if j != m + throw(DimensionMismatch("A has first dimension $j but needs to match first dimension of C, $m")) + end + if N != n + throw(DimensionMismatch("B has second dimension $N but needs to match second dimension of C, $n")) + end + if j != M + throw(DimensionMismatch("A has second dimension $j but needs to match first dimension of B, $M")) + end + else + if j != n + throw(DimensionMismatch("B has second dimension $j but needs to match second dimension of C, $n")) + end + if N != j + throw(DimensionMismatch("A has second dimension $N but needs to match first dimension of B, $j")) + end + if M != m + throw(DimensionMismatch("A has first dimension $M but needs to match first dimension of C, $m")) + end end chkstride1(A) chkstride1(B) diff --git a/stdlib/LinearAlgebra/test/blas.jl b/stdlib/LinearAlgebra/test/blas.jl index aaebc50546ac3f..e144b64e1b7d3b 100644 --- a/stdlib/LinearAlgebra/test/blas.jl +++ b/stdlib/LinearAlgebra/test/blas.jl @@ -223,11 +223,19 @@ Random.seed!(100) @test_throws DimensionMismatch BLAS.symm('R','U',Cmn,Cnn) @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cmn) @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.symm!('L','U',one(elty),Asymm,Cmn,one(elty),Cnn) + @test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnm,one(elty),Cmn) + @test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.symm!('R','U',one(elty),Asymm,Cmn,one(elty),Cnn) if elty <: BlasComplex @test_throws DimensionMismatch BLAS.hemm('L','U',Cnm,Cnn) @test_throws DimensionMismatch BLAS.hemm('R','U',Cmn,Cnn) @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cmn) @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.hemm!('L','U',one(elty),Aherm,Cmn,one(elty),Cnn) + @test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnm,one(elty),Cmn) + @test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cnn,one(elty),Cnm) + @test_throws DimensionMismatch BLAS.hemm!('R','U',one(elty),Aherm,Cmn,one(elty),Cnn) end end end diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index 7d99dd32889fd1..9da68355d9b736 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -340,6 +340,9 @@ end C = zeros(eltya,n,n) @test Hermitian(aherm) * a ≈ aherm * a @test a * Hermitian(aherm) ≈ a * aherm + # rectangular multiplication + @test [a; a] * Hermitian(aherm) ≈ [a; a] * aherm + @test Hermitian(aherm) * [a a] ≈ aherm * [a a] @test Hermitian(aherm) * Hermitian(aherm) ≈ aherm*aherm @test_throws DimensionMismatch Hermitian(aherm) * Vector{eltya}(undef, n+1) LinearAlgebra.mul!(C,a,Hermitian(aherm)) @@ -348,6 +351,9 @@ end @test Symmetric(asym) * Symmetric(asym) ≈ asym*asym @test Symmetric(asym) * a ≈ asym * a @test a * Symmetric(asym) ≈ a * asym + # rectangular multiplication + @test Symmetric(asym) * [a a] ≈ asym * [a a] + @test [a; a] * Symmetric(asym) ≈ [a; a] * asym @test_throws DimensionMismatch Symmetric(asym) * Vector{eltya}(undef, n+1) LinearAlgebra.mul!(C,a,Symmetric(asym)) @test C ≈ a*asym