Skip to content

Commit

Permalink
Fix size-1 StructuredMatrix's broadcast.
Browse files Browse the repository at this point in the history
1. size-1 StructuredMatrix should behave like scalar during broadcast. Thus their `fzero` should return the only element.
(fix #54087)

2. But for simple broadcast with only one StructuredMatrix, we can keep stability as the structure is "preserved" even for size-1 case. Thus `count_structedmatrix` is added to check that.

3. `count_structedmatrix` is fused to keep `Bidiagonal` stability.
(replace JuliaLang#54067)
  • Loading branch information
N5N3 committed Apr 22, 2024
1 parent b5bfd83 commit 9153b77
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 38 deletions.
38 changes: 25 additions & 13 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nest
function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType}
uplo = n > 0 ? find_uplo(bc) : 'U'
n1 = max(n - 1, 0)
if uplo == 'T'
if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T'
return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1))
end
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo)
Expand Down Expand Up @@ -135,24 +135,36 @@ iszerodefined(::Type{<:Number}) = true
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)
iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T)

fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0))
count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0)

function fzeropreserving(bc)
n = count_structedmatrix(StructuredMatrix, bc)
v = fzero(bc, Val(n==1))
!ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0)
end
# Like sparse matrices, we assume that the zero-preservation property of a broadcasted
# expression is stable. We can test the zero-preservability by applying the function
# in cases where all other arguments are known scalars against a zero from the structured
# matrix. If any non-structured matrix argument is not a known scalar, we give up.
fzero(x::Number) = x
fzero(::Type{T}) where T = T
fzero(r::Ref) = r[]
fzero(t::Tuple{Any}) = t[1]
fzero(S::StructuredMatrix) = zero(eltype(S))
fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = haszero(T) ? zero(T)*I : missing
fzero(x) = missing
function fzero(bc::Broadcast.Broadcasted)
args = map(fzero, bc.args)
return any(ismissing, args) ? missing : bc.f(args...)
fzero(x::Number, ::Val) = x
fzero(::Type{T}, ::Val) where T = T
fzero(r::Union{Ref,AbstractArray{<:Any,0}}, ::Val) = r[]
fzero(t::Tuple{Any}, ::Val) = t[1]
# The check below is tricky as size-1 `StructuredMatrix`s behave like scalar during broadcast.
# So we have to check their size if there are more than 1 broadcasted arguments which <: StructuredMatrix.
fzero(S::StructuredMatrix, ::Val{O}) where {O} = !O && isone(size(S, 1)) ? S[1, 1] : zero(eltype(S))
fzero(S::StructuredMatrix{<:AbstractMatrix{T}}, ::Val{O}) where {T<:Number,O} = !O && isone(size(S, 1)) ? S[1, 1] : haszero(T) ? zero(T)*I : missing
fzero(x, ::Val) = missing
@inline function fzero(bc::Broadcast.Broadcasted, v::Val)
args = map(Base.Fix2(fzero, v), bc.args)
return anymissing(args) ? missing : bc.f(args...)
end
# force unroll to keep stability
anymissing(x::Tuple{Any,Vararg}) = anymissing(Base.tail(x))
anymissing(::Tuple{Missing,Vararg}) = true
anymissing(::Tuple{}) = false

function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType}
Base.@constprop :aggressive function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType}
inds = axes(bc)
fzerobc = fzeropreserving(bc)
if isstructurepreserving(bc) || (fzerobc && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular}))
Expand Down
63 changes: 38 additions & 25 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@ using Test, LinearAlgebra
M = Matrix(rand(N,N))
structuredarrays = (D, B, T, U, L, M)
fstructuredarrays = map(Array, structuredarrays)

# help functions used to ensure simple structured broadcast is stable
mul2(X) = (g(x) = X .* 2.0; @inferred(g(X)))
mult2(X) = (g(X) = X .* (2.0,); @inferred(g(X)))
mulinf(X) = (g(x) = x .* Inf; @inferred(g(X)))
lpow2(X) = (g(X) = X .^ 2; @inferred(g(X)))
lpow0(X) = (g(X) = X .^ 0; @inferred(g(X)))
pow2(X) = (g(x) = (two = 2; x.^two); @inferred(g(X)))
lpow_1(X) = (g(X) = X .^ -1; @inferred(g(X)))
powt2(X) = (g(X) = X .^ (2,); @inferred(g(X)))

for (X, fX) in zip(structuredarrays, fstructuredarrays)
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@test (Q = @inferred(broadcast(sin, X)); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@test broadcast!(sin, Z, X) == broadcast(sin, fX)
@test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX))
@test (Q = @inferred(broadcast(cos, X)); Q isa Matrix && Q == broadcast(cos, fX))
@test broadcast!(cos, Z, X) == broadcast(cos, fX)
@test (Q = broadcast(*, s, X); typeof(Q) == typeof(X) && Q == broadcast(*, s, fX))
@test broadcast!(*, Z, s, X) == broadcast(*, s, fX)
Expand All @@ -29,18 +40,12 @@ using Test, LinearAlgebra
@test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX))
@test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX)

@test X .* 2.0 == X .* (2.0,) == fX .* 2.0
@test X .* 2.0 isa typeof(X)
@test X .* (2.0,) isa typeof(X)
@test isequal(X .* Inf, fX .* Inf)
@test mul2(X)::typeof(X) == mult2(X)::typeof(X) == mult2(fX)
@test isequal(mulinf(X), mulinf(fX))

two = 2
@test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two
@test X .^ 2 isa typeof(X)
@test X .^ (2,) isa typeof(X)
@test X .^ two isa typeof(X)
@test X .^ 0 == fX .^ 0
@test X .^ -1 == fX .^ -1
@test lpow2(X)::typeof(X) == powt2(X)::typeof(X) == pow2(X)::typeof(X) == lpow2(fX)
@test lpow0(X) == lpow0(fX)
@test lpow_1(X) == lpow_1(fX)

for (Y, fY) in zip(structuredarrays, fstructuredarrays)
@test broadcast(+, X, Y) == broadcast(+, fX, fY)
Expand All @@ -65,9 +70,9 @@ using Test, LinearAlgebra
Ttris = typeof.((UpperTriangular(parent(UU)), LowerTriangular(parent(UU))))
funittriangulars = map(Array, unittriangulars)
for (X, fX, Ttri) in zip(unittriangulars, funittriangulars, Ttris)
@test (Q = broadcast(sin, X); typeof(Q) == Ttri && Q == broadcast(sin, fX))
@test (Q = @inferred(broadcast(sin, X)); typeof(Q) == Ttri && Q == broadcast(sin, fX))
@test broadcast!(sin, Z, X) == broadcast(sin, fX)
@test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX))
@test (Q = @inferred(broadcast(cos, X)); Q isa Matrix && Q == broadcast(cos, fX))
@test broadcast!(cos, Z, X) == broadcast(cos, fX)
@test (Q = broadcast(*, s, X); typeof(Q) == Ttri && Q == broadcast(*, s, fX))
@test broadcast!(*, Z, s, X) == broadcast(*, s, fX)
Expand All @@ -76,18 +81,14 @@ using Test, LinearAlgebra
@test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX))
@test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX)

@test X .* 2.0 == X .* (2.0,) == fX .* 2.0
@test X .* 2.0 isa Ttri
@test X .* (2.0,) isa Ttri
@test isequal(X .* Inf, fX .* Inf)
@test mul2(X)::Ttri == mult2(X)::Ttri == mul2(fX)
@test isequal(mulinf(X), mulinf(fX))

two = 2
@test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two
@test X .^ 2 isa typeof(X) # special cased, as isstructurepreserving
@test X .^ (2,) isa Ttri
@test X .^ two isa Ttri
@test X .^ 0 == fX .^ 0
@test X .^ -1 == fX .^ -1
@test lpow2(X)::typeof(X) == # special cased, as isstructurepreserving
powt2(X)::Ttri == pow2(X)::Ttri == lpow2(fX)
@test lpow0(X) == lpow0(fX)
@test lpow_1(X) == lpow_1(fX)

for (Y, fY) in zip(unittriangulars, funittriangulars)
@test broadcast(+, X, Y) == broadcast(+, fX, fY)
Expand Down Expand Up @@ -338,4 +339,16 @@ end
end
end

@testset "Issue 54087: size-1 structured matrix's broadcast" begin
Ns = 1, 3
D1, D2 = map(N->Diagonal(rand(N)), Ns)
B1, B2 = map(N->Bidiagonal(rand(N), rand(N - 1), :U), Ns)
T1, T2 = map(N->Tridiagonal(rand(N - 1), rand(N), rand(N - 1)), Ns)
Ss = [D1, D2, B1, B2, T1, T2]
MSs = Matrix.(Ss)
for ((S1, M1), (S2, M2)) in Iterators.product(zip(Ss, MSs), zip(Ss, MSs))
@test S1 .+ S2 == M1 .+ M2
end
end

end

0 comments on commit 9153b77

Please sign in to comment.