Skip to content

Commit

Permalink
Fix (l/r)mul! with Diagonal/Bidiagonal (#55052)
Browse files Browse the repository at this point in the history
Currently, `rmul!(A::AbstractMatirx, D::Diagonal)` calls `mul!(A, A,
D)`, but this isn't a valid call, as `mul!` assumes no aliasing between
the destination and the matrices to be multiplied. As a consequence,
```julia
julia> B = Bidiagonal(rand(4), rand(3), :L)
4×4 Bidiagonal{Float64, Vector{Float64}}:
 0.476892   ⋅         ⋅         ⋅ 
 0.353756  0.139188   ⋅         ⋅ 
  ⋅        0.685839  0.309336   ⋅ 
  ⋅         ⋅        0.369038  0.304273

julia> D = Diagonal(rand(size(B,2)));

julia> rmul!(B, D)
4×4 Bidiagonal{Float64, Vector{Float64}}:
 0.0   ⋅    ⋅    ⋅ 
 0.0  0.0   ⋅    ⋅ 
  ⋅   0.0  0.0   ⋅ 
  ⋅    ⋅   0.0  0.0

julia> B
4×4 Bidiagonal{Float64, Vector{Float64}}:
 0.0   ⋅    ⋅    ⋅ 
 0.0  0.0   ⋅    ⋅ 
  ⋅   0.0  0.0   ⋅ 
  ⋅    ⋅   0.0  0.0
```
This is clearly nonsense, and happens because the internal `_mul!`
function assumes that it can safely overwrite the destination with zeros
before carrying out the multiplication. This is fixed in this PR by
using broadcasting instead. The current implementation is generally
equally performant, albeit occasionally with a minor allocation arising
from `reshape`ing an `Array`.

A similar problem also exists in `l/rmul!` with `Bidiaognal`, but that's
a little harder to fix while remaining equally performant.
  • Loading branch information
jishnub authored Jul 11, 2024
1 parent faf17eb commit 262b40a
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 4 deletions.
72 changes: 70 additions & 2 deletions stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,76 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) =
@stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta))

lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul())
rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul())
# B .= A * B
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in axes(B,2)
for i in axes(ev,1)
B[i,k] = dv[i] * B[i,k] + ev[i] * B[i+1,k]
end
B[end,k] = dv[end] * B[end,k]
end
else
for k in axes(B,2)
for i in reverse(axes(dv,1)[2:end])
B[i,k] = dv[i] * B[i,k] + ev[i-1] * B[i-1,k]
end
B[1,k] = dv[1] * B[1,k]
end
end
return B
end
# B .= D * B
function lmul!(D::Diagonal, B::Bidiagonal)
_muldiag_size_check(D, B)
(; dv, ev) = B
isL = B.uplo == 'L'
dv[1] = D.diag[1] * dv[1]
for i in axes(ev,1)
ev[i] = D.diag[i + isL] * ev[i]
dv[i+1] = D.diag[i+1] * dv[i+1]
end
return B
end
# B .= B * A
function rmul!(B::AbstractMatrix, A::Bidiagonal)
_muldiag_size_check(A, B)
(; dv, ev) = A
if A.uplo == 'U'
for k in reverse(axes(dv,1)[2:end])
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k-1] * ev[k-1]
end
end
for i in axes(B,1)
B[i,1] *= dv[1]
end
else
for k in axes(ev,1)
for i in axes(B,1)
B[i,k] = B[i,k] * dv[k] + B[i,k+1] * ev[k]
end
end
for i in axes(B,1)
B[i,end] *= dv[end]
end
end
return B
end
# B .= B * D
function rmul!(B::Bidiagonal, D::Diagonal)
_muldiag_size_check(B, D)
(; dv, ev) = B
isU = B.uplo == 'U'
dv[1] *= D.diag[1]
for i in axes(ev,1)
ev[i] *= D.diag[i + isU]
dv[i+1] *= D.diag[i+1]
end
return B
end

function check_A_mul_B!_sizes(C, A, B)
mA, nA = size(A)
Expand Down
45 changes: 43 additions & 2 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,49 @@ function (*)(D::Diagonal, V::AbstractVector)
return D.diag .* V
end

rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D)
lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B)
function rmul!(A::AbstractMatrix, D::Diagonal)
_muldiag_size_check(A, D)
for I in CartesianIndices(A)
row, col = Tuple(I)
@inbounds A[row, col] *= D.diag[col]
end
return A
end
# T .= T * D
function rmul!(T::Tridiagonal, D::Diagonal)
_muldiag_size_check(T, D)
(; dl, d, du) = T
d[1] *= D.diag[1]
for i in axes(dl,1)
dl[i] *= D.diag[i]
du[i] *= D.diag[i+1]
d[i+1] *= D.diag[i+1]
end
return T
end

function lmul!(D::Diagonal, B::AbstractVecOrMat)
_muldiag_size_check(D, B)
for I in CartesianIndices(B)
row = I[1]
@inbounds B[I] = D.diag[row] * B[I]
end
return B
end

# in-place multiplication with a diagonal
# T .= D * T
function lmul!(D::Diagonal, T::Tridiagonal)
_muldiag_size_check(D, T)
(; dl, d, du) = T
d[1] = D.diag[1] * d[1]
for i in axes(dl,1)
dl[i] = D.diag[i+1] * dl[i]
du[i] = D.diag[i] * du[i]
d[i+1] = D.diag[i+1] * d[i+1]
end
return T
end

function __muldiag!(out, D::Diagonal, B, _add::MulAddMul{ais1,bis0}) where {ais1,bis0}
require_one_based_indexing(out, B)
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,41 @@ end
@test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)]
end

@testset "rmul!/lmul! with banded matrices" begin
dv, ev = rand(4), rand(3)
for A in (Bidiagonal(dv, ev, :U), Bidiagonal(dv, ev, :L))
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(dv, ev, :U),
Bidiagonal(dv, ev, :L),
Diagonal(dv)
)
@test_throws ArgumentError rmul!(B, A)
@test_throws ArgumentError lmul!(A, B)
end
D = Diagonal(dv)
@test rmul!(copy(A), D) A * D
@test lmul!(D, copy(A)) D * A
end
@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
for uplo in (:L, :U)
B = Bidiagonal(fill(S32, 4), fill(S32, 3), uplo)
D = Diagonal(fill(S22, size(B,2)))
@test rmul!(copy(B), D) B * D
D = Diagonal(fill(S33, size(B,1)))
@test lmul!(D, copy(B)) D * B
end

B = Bidiagonal(fill(S33, 4), fill(S33, 3), :U)
D = Diagonal(fill(S32, 4))
@test lmul!(B, Array(D)) B * D
B = Bidiagonal(fill(S22, 4), fill(S22, 3), :U)
@test rmul!(Array(D), B) D * B
end
end

@testset "conversion to Tridiagonal for immutable bands" begin
n = 4
dv = FillArrays.Fill(3, n)
Expand Down
13 changes: 13 additions & 0 deletions stdlib/LinearAlgebra/test/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1322,4 +1322,17 @@ end
@test M == D
end

@testset "rmul!/lmul! with banded matrices" begin
@testset "$(nameof(typeof(B)))" for B in (
Bidiagonal(rand(4), rand(3), :L),
Tridiagonal(rand(3), rand(4), rand(3))
)
BA = Array(B)
D = Diagonal(rand(size(B,1)))
DA = Array(D)
@test rmul!(copy(B), D) B * D BA * DA
@test lmul!(D, copy(B)) D * B DA * BA
end
end

end # module TestDiagonal
19 changes: 19 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -892,4 +892,23 @@ end
end
end

@testset "rmul!/lmul! with banded matrices" begin
dl, d, du = rand(3), rand(4), rand(3)
A = Tridiagonal(dl, d, du)
D = Diagonal(d)
@test rmul!(copy(A), D) A * D
@test lmul!(D, copy(A)) D * A

@testset "non-commutative" begin
S32 = SizedArrays.SizedArray{(3,2)}(rand(3,2))
S33 = SizedArrays.SizedArray{(3,3)}(rand(3,3))
S22 = SizedArrays.SizedArray{(2,2)}(rand(2,2))
T = Tridiagonal(fill(S32,3), fill(S32, 4), fill(S32, 3))
D = Diagonal(fill(S22, size(T,2)))
@test rmul!(copy(T), D) T * D
D = Diagonal(fill(S33, size(T,1)))
@test lmul!(D, copy(T)) D * T
end
end

end # module TestTridiagonal
3 changes: 3 additions & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ Base.first(::SOneTo) = 1
Base.last(r::SOneTo) = length(r)
Base.show(io::IO, r::SOneTo) = print(io, "SOneTo(", length(r), ")")

Broadcast.axistype(a::Base.OneTo, s::SOneTo) = s
Broadcast.axistype(s::SOneTo, a::Base.OneTo) = s

struct SizedArray{SZ,T,N,A<:AbstractArray} <: AbstractArray{T,N}
data::A
function SizedArray{SZ}(data::AbstractArray{T,N}) where {SZ,T,N}
Expand Down

0 comments on commit 262b40a

Please sign in to comment.