From 5ba487da5d4895139d1b2b49b436bf62fa93fb13 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 16 Mar 2023 10:53:06 +0000 Subject: [PATCH] Add *(::Diagonal, ::Diagonal, ::Diagonal) (#49005) (#49007) (cherry picked from commit c37fc2798d9bb7349ff8eadb350ae68cf17cee61) --- stdlib/LinearAlgebra/src/diagonal.jl | 6 ++++++ stdlib/LinearAlgebra/test/diagonal.jl | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 291233ebe2e6a..6364169b2ba8a 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -355,6 +355,12 @@ function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal) return broadcast(*, Da.diag, A, permutedims(Db.diag)) end +function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal) + _muldiag_size_check(Da, Db) + _muldiag_size_check(Db, Dc) + return Diagonal(Da.diag .* Db.diag .* Dc.diag) +end + # Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat @inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) = _muldiag!(out, D, V, alpha, beta) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index 83a2a896e736c..130a66ea0a1d5 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1133,4 +1133,15 @@ Base.size(::SMatrix1) = (1, 1) @test C isa Matrix{SMatrix1{String}} end +@testset "diagonal triple multiplication (#49005)" begin + n = 10 + @test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n))) isa Diagonal + @test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n+1)))) + @test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n+1), Diagonal(ones(n+1)))) + @test_throws DimensionMismatch (*(Diagonal(ones(n+1)), Diagonal(1:n), Diagonal(ones(n)))) + + # currently falls back to two-term * + @test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n)), Diagonal(1:n)) isa Diagonal +end + end # module TestDiagonal