From ed987f2603fd96f5ff07f26189b160dd538b7d6e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Mon, 8 Jul 2024 14:48:38 +0530 Subject: [PATCH] Bidiagonal to Tridiagonal with immutable bands (#55059) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Using `similar` to generate the zero band necessarily allocates a mutable vector, which would lead to an error if the other bands are immutable. This PR changes this to use `zero` instead, which usually produces a vector of the same type. There are occasions where `zero(v)` produces a different type from `v`, so an extra conversion is added to obtain a zero vector of the same type. The following works after this: ```julia julia> using FillArrays, LinearAlgebra julia> n = 4; B = Bidiagonal(Fill(3, n), Fill(2, n-1), :U) 4×4 Bidiagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}: 3 2 ⋅ ⋅ ⋅ 3 2 ⋅ ⋅ ⋅ 3 2 ⋅ ⋅ ⋅ 3 julia> Tridiagonal(B) 4×4 Tridiagonal{Int64, Fill{Int64, 1, Tuple{Base.OneTo{Int64}}}}: 3 2 ⋅ ⋅ 0 3 2 ⋅ ⋅ 0 3 2 ⋅ ⋅ 0 3 julia> Tridiagonal{Float64}(B) 4×4 Tridiagonal{Float64, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}: 3.0 2.0 ⋅ ⋅ 0.0 3.0 2.0 ⋅ ⋅ 0.0 3.0 2.0 ⋅ ⋅ 0.0 3.0 ``` --- stdlib/LinearAlgebra/src/bidiag.jl | 3 ++- stdlib/LinearAlgebra/src/special.jl | 8 +++++--- stdlib/LinearAlgebra/test/bidiag.jl | 16 ++++++++++++++++ test/testhelpers/FillArrays.jl | 4 ++++ 4 files changed, 27 insertions(+), 4 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 5de286a2c335b..f0d04f121d48f 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -231,7 +231,8 @@ promote_rule(::Type{<:Matrix}, ::Type{<:Bidiagonal}) = Matrix function Tridiagonal{T}(A::Bidiagonal) where T dv = convert(AbstractVector{T}, A.dv) ev = convert(AbstractVector{T}, A.ev) - z = fill!(similar(ev), zero(T)) + # ensure that the types are identical, even if zero returns a different type + z = oftype(ev, zero(ev)) A.uplo == 'U' ? Tridiagonal(z, dv, ev) : Tridiagonal(ev, dv, z) end promote_rule(::Type{<:Tridiagonal{T}}, ::Type{<:Bidiagonal{S}}) where {T,S} = diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 6a8dd676bc41b..9633594574055 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -15,9 +15,11 @@ Diagonal(A::Bidiagonal) = Diagonal(A.dv) SymTridiagonal(A::Bidiagonal) = iszero(A.ev) ? SymTridiagonal(A.dv, A.ev) : throw(ArgumentError("matrix cannot be represented as SymTridiagonal")) -Tridiagonal(A::Bidiagonal) = - Tridiagonal(A.uplo == 'U' ? fill!(similar(A.ev), 0) : A.ev, A.dv, - A.uplo == 'U' ? A.ev : fill!(similar(A.ev), 0)) +function Tridiagonal(A::Bidiagonal) + # ensure that the types are identical, even if zero returns a different type + z = oftype(A.ev, zero(A.ev)) + Tridiagonal(A.uplo == 'U' ? z : A.ev, A.dv, A.uplo == 'U' ? A.ev : z) +end # conversions from SymTridiagonal to other special matrix types Diagonal(A::SymTridiagonal) = Diagonal(A.dv) diff --git a/stdlib/LinearAlgebra/test/bidiag.jl b/stdlib/LinearAlgebra/test/bidiag.jl index 416587332c46c..2380a93d90a74 100644 --- a/stdlib/LinearAlgebra/test/bidiag.jl +++ b/stdlib/LinearAlgebra/test/bidiag.jl @@ -933,4 +933,20 @@ end @test B[1,2] == B[Int8(1),UInt16(2)] == B[big(1), Int16(2)] end +@testset "conversion to Tridiagonal for immutable bands" begin + n = 4 + dv = FillArrays.Fill(3, n) + ev = FillArrays.Fill(2, n-1) + z = FillArrays.Fill(0, n-1) + dvf = FillArrays.Fill(Float64(3), n) + evf = FillArrays.Fill(Float64(2), n-1) + zf = FillArrays.Fill(Float64(0), n-1) + B = Bidiagonal(dv, ev, :U) + @test Tridiagonal{Int}(B) === Tridiagonal(B) === Tridiagonal(z, dv, ev) + @test Tridiagonal{Float64}(B) === Tridiagonal(zf, dvf, evf) + B = Bidiagonal(dv, ev, :L) + @test Tridiagonal{Int}(B) === Tridiagonal(B) === Tridiagonal(ev, dv, z) + @test Tridiagonal{Float64}(B) === Tridiagonal(evf, dvf, zf) +end + end # module TestBidiagonal diff --git a/test/testhelpers/FillArrays.jl b/test/testhelpers/FillArrays.jl index ee988e0f0aa9c..d3b8d74da7148 100644 --- a/test/testhelpers/FillArrays.jl +++ b/test/testhelpers/FillArrays.jl @@ -11,6 +11,8 @@ Base.size(F::Fill) = F.size Base.copy(F::Fill) = F +Base.AbstractArray{T,N}(F::Fill{<:Any,N}) where {T,N} = Fill(T(F.value), F.size) + @inline getindex_value(F::Fill) = F.value @inline function Base.getindex(F::Fill{<:Any,N}, i::Vararg{Int,N}) where {N} @@ -29,6 +31,8 @@ end F end +Base.zero(F::Fill) = Fill(zero(F.value), size(F)) + Base.show(io::IO, F::Fill) = print(io, "Fill($(F.value), $(F.size))") Base.show(io::IO, ::MIME"text/plain", F::Fill) = show(io, F)