Skip to content

Commit

Permalink
Bidiagonal to Tridiagonal with immutable bands (#55059)
Browse files Browse the repository at this point in the history
Using `similar` to generate the zero band necessarily allocates a
mutable vector, which would lead to an error if the other bands are
immutable. This PR changes this to use `zero` instead, which usually
produces a vector of the same type. There are occasions where `zero(v)`
produces a different type from `v`, so an extra conversion is added to
obtain a zero vector of the same type.
The following works after this:
```julia
julia> using FillArrays, LinearAlgebra

julia> n = 4; B = Bidiagonal(Fill(3, n), Fill(2, n-1), :U)
4×4 Bidiagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}:
 3  2  ⋅  ⋅
 ⋅  3  2  ⋅
 ⋅  ⋅  3  2
 ⋅  ⋅  ⋅  3

julia> Tridiagonal(B)
4×4 Tridiagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}:
 3  2  ⋅  ⋅
 0  3  2  ⋅
 ⋅  0  3  2
 ⋅  ⋅  0  3

julia> Tridiagonal{Float64}(B)
4×4 Tridiagonal{Float64, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}:
 3.0  2.0   ⋅    ⋅ 
 0.0  3.0  2.0   ⋅ 
  ⋅   0.0  3.0  2.0
  ⋅    ⋅   0.0  3.0
```
  • Loading branch information
jishnub authored Jul 8, 2024
1 parent 23dabef commit ed987f2
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 4 deletions.
3 changes: 2 additions & 1 deletion stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ promote_rule(::Type{<:Matrix}, ::Type{<:Bidiagonal}) = Matrix
function Tridiagonal{T}(A::Bidiagonal) where T
dv = convert(AbstractVector{T}, A.dv)
ev = convert(AbstractVector{T}, A.ev)
z = fill!(similar(ev), zero(T))
# ensure that the types are identical, even if zero returns a different type
z = oftype(ev, zero(ev))
A.uplo == 'U' ? Tridiagonal(z, dv, ev) : Tridiagonal(ev, dv, z)
end
promote_rule(::Type{<:Tridiagonal{T}}, ::Type{<:Bidiagonal{S}}) where {T,S} =
Expand Down
8 changes: 5 additions & 3 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ Diagonal(A::Bidiagonal) = Diagonal(A.dv)
SymTridiagonal(A::Bidiagonal) =
iszero(A.ev) ? SymTridiagonal(A.dv, A.ev) :
throw(ArgumentError("matrix cannot be represented as SymTridiagonal"))
Tridiagonal(A::Bidiagonal) =
Tridiagonal(A.uplo == 'U' ? fill!(similar(A.ev), 0) : A.ev, A.dv,
A.uplo == 'U' ? A.ev : fill!(similar(A.ev), 0))
function Tridiagonal(A::Bidiagonal)
# ensure that the types are identical, even if zero returns a different type
z = oftype(A.ev, zero(A.ev))
Tridiagonal(A.uplo == 'U' ? z : A.ev, A.dv, A.uplo == 'U' ? A.ev : z)
end

# conversions from SymTridiagonal to other special matrix types
Diagonal(A::SymTridiagonal) = Diagonal(A.dv)
Expand Down
16 changes: 16 additions & 0 deletions stdlib/LinearAlgebra/test/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -933,4 +933,20 @@ end
@test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)]
end

@testset "conversion to Tridiagonal for immutable bands" begin
n = 4
dv = FillArrays.Fill(3, n)
ev = FillArrays.Fill(2, n-1)
z = FillArrays.Fill(0, n-1)
dvf = FillArrays.Fill(Float64(3), n)
evf = FillArrays.Fill(Float64(2), n-1)
zf = FillArrays.Fill(Float64(0), n-1)
B = Bidiagonal(dv, ev, :U)
@test Tridiagonal{Int}(B) === Tridiagonal(B) === Tridiagonal(z, dv, ev)
@test Tridiagonal{Float64}(B) === Tridiagonal(zf, dvf, evf)
B = Bidiagonal(dv, ev, :L)
@test Tridiagonal{Int}(B) === Tridiagonal(B) === Tridiagonal(ev, dv, z)
@test Tridiagonal{Float64}(B) === Tridiagonal(evf, dvf, zf)
end

end # module TestBidiagonal
4 changes: 4 additions & 0 deletions test/testhelpers/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Base.size(F::Fill) = F.size

Base.copy(F::Fill) = F

Base.AbstractArray{T,N}(F::Fill{<:Any,N}) where {T,N} = Fill(T(F.value), F.size)

@inline getindex_value(F::Fill) = F.value

@inline function Base.getindex(F::Fill{<:Any,N}, i::Vararg{Int,N}) where {N}
Expand All @@ -29,6 +31,8 @@ end
F
end

Base.zero(F::Fill) = Fill(zero(F.value), size(F))

Base.show(io::IO, F::Fill) = print(io, "Fill($(F.value), $(F.size))")
Base.show(io::IO, ::MIME"text/plain", F::Fill) = show(io, F)

Expand Down

0 comments on commit ed987f2

Please sign in to comment.