Skip to content

Commit

Permalink
Add *(::Diagonal, ::Diagonal, ::Diagonal) (#49005) (#49007)
Browse files Browse the repository at this point in the history
(cherry picked from commit c37fc27)
  • Loading branch information
dlfivefifty authored and KristofferC committed Mar 30, 2023
1 parent 3f04640 commit 5ba487d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
6 changes: 6 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,12 @@ function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
return broadcast(*, Da.diag, A, permutedims(Db.diag))
end

function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
_muldiag_size_check(Da, Db)
_muldiag_size_check(Db, Dc)
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
end

# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
_muldiag!(out, D, V, alpha, beta)
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1133,4 +1133,15 @@ Base.size(::SMatrix1) = (1, 1)
@test C isa Matrix{SMatrix1{String}}
end

@testset "diagonal triple multiplication (#49005)" begin
n = 10
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n))) isa Diagonal
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n+1))))
@test_throws DimensionMismatch (*(Diagonal(ones(n)), Diagonal(1:n+1), Diagonal(ones(n+1))))
@test_throws DimensionMismatch (*(Diagonal(ones(n+1)), Diagonal(1:n), Diagonal(ones(n))))

# currently falls back to two-term *
@test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n)), Diagonal(1:n)) isa Diagonal
end

end # module TestDiagonal

0 comments on commit 5ba487d

Please sign in to comment.