From 9153b77acfa9755d71e036368385f5c557094ef3 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Mon, 22 Apr 2024 21:10:49 +0800 Subject: [PATCH] Fix size-1 `StructuredMatrix`'s broadcast. 1. A size-1 `StructuredMatrix` should behave like a scalar during broadcast. Thus its `fzero` should return its only element. (fix #54087) 2. But for a simple broadcast with only one `StructuredMatrix`, we can keep type stability, as the structure is "preserved" even in the size-1 case. Thus `count_structedmatrix` is added to check that. 3. `count_structedmatrix` is also used to keep `Bidiagonal` broadcast type-stable. (replace #54067) --- .../LinearAlgebra/src/structuredbroadcast.jl | 38 +++++++---- .../LinearAlgebra/test/structuredbroadcast.jl | 63 +++++++++++-------- 2 files changed, 63 insertions(+), 38 deletions(-) diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl index f2c35c8edcce4..aeb40ae34432c 100644 --- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -78,7 +78,7 @@ find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nest function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType} uplo = n > 0 ? find_uplo(bc) : 'U' n1 = max(n - 1, 0) - if uplo == 'T' + if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T' return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1)) end return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo) @@ -135,24 +135,36 @@ iszerodefined(::Type{<:Number}) = true iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T) iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T) -fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? 
iszero(v) : v == 0)) +count_structedmatrix(T, bc::Broadcasted) = sum(Base.Fix2(isa, T), Broadcast.cat_nested(bc); init = 0) + +function fzeropreserving(bc) + n = count_structedmatrix(StructuredMatrix, bc) + v = fzero(bc, Val(n==1)) + !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0) +end # Like sparse matrices, we assume that the zero-preservation property of a broadcasted # expression is stable. We can test the zero-preservability by applying the function # in cases where all other arguments are known scalars against a zero from the structured # matrix. If any non-structured matrix argument is not a known scalar, we give up. -fzero(x::Number) = x -fzero(::Type{T}) where T = T -fzero(r::Ref) = r[] -fzero(t::Tuple{Any}) = t[1] -fzero(S::StructuredMatrix) = zero(eltype(S)) -fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = haszero(T) ? zero(T)*I : missing -fzero(x) = missing -function fzero(bc::Broadcast.Broadcasted) - args = map(fzero, bc.args) - return any(ismissing, args) ? missing : bc.f(args...) +fzero(x::Number, ::Val) = x +fzero(::Type{T}, ::Val) where T = T +fzero(r::Union{Ref,AbstractArray{<:Any,0}}, ::Val) = r[] +fzero(t::Tuple{Any}, ::Val) = t[1] +# The check below is tricky, as a size-1 `StructuredMatrix` behaves like a scalar during broadcast. +# So we have to check its size if more than one broadcasted argument is a `StructuredMatrix`. +fzero(S::StructuredMatrix, ::Val{O}) where {O} = !O && isone(size(S, 1)) ? S[1, 1] : zero(eltype(S)) +fzero(S::StructuredMatrix{<:AbstractMatrix{T}}, ::Val{O}) where {T<:Number,O} = !O && isone(size(S, 1)) ? S[1, 1] : haszero(T) ? zero(T)*I : missing +fzero(x, ::Val) = missing +@inline function fzero(bc::Broadcast.Broadcasted, v::Val) + args = map(Base.Fix2(fzero, v), bc.args) + return anymissing(args) ? missing : bc.f(args...) 
end +# unroll the tuple recursion manually to keep type stability +anymissing(x::Tuple{Any,Vararg}) = anymissing(Base.tail(x)) +anymissing(::Tuple{Missing,Vararg}) = true +anymissing(::Tuple{}) = false -function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType} +Base.@constprop :aggressive function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType} inds = axes(bc) fzerobc = fzeropreserving(bc) if isstructurepreserving(bc) || (fzerobc && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular})) diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl index 3767fc10055f2..3dfced8b2dd3e 100644 --- a/stdlib/LinearAlgebra/test/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -17,10 +17,21 @@ using Test, LinearAlgebra M = Matrix(rand(N,N)) structuredarrays = (D, B, T, U, L, M) fstructuredarrays = map(Array, structuredarrays) + + # helper functions used to ensure that simple structured broadcast is type-stable + mul2(X) = (g(x) = X .* 2.0; @inferred(g(X))) + mult2(X) = (g(X) = X .* (2.0,); @inferred(g(X))) + mulinf(X) = (g(x) = x .* Inf; @inferred(g(X))) + lpow2(X) = (g(X) = X .^ 2; @inferred(g(X))) + lpow0(X) = (g(X) = X .^ 0; @inferred(g(X))) + pow2(X) = (g(x) = (two = 2; x.^two); @inferred(g(X))) + lpow_1(X) = (g(X) = X .^ -1; @inferred(g(X))) + powt2(X) = (g(X) = X .^ (2,); @inferred(g(X))) + for (X, fX) in zip(structuredarrays, fstructuredarrays) - @test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX)) + @test (Q = @inferred(broadcast(sin, X)); typeof(Q) == typeof(X) && Q == broadcast(sin, fX)) @test broadcast!(sin, Z, X) == broadcast(sin, fX) - @test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX)) + @test (Q = @inferred(broadcast(cos, X)); Q isa Matrix && Q == broadcast(cos, fX)) @test broadcast!(cos, Z, X) == broadcast(cos, fX) @test (Q = broadcast(*, s, X); typeof(Q) == typeof(X) && Q 
== broadcast(*, s, fX)) @test broadcast!(*, Z, s, X) == broadcast(*, s, fX) @@ -29,18 +40,12 @@ using Test, LinearAlgebra @test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX)) @test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX) - @test X .* 2.0 == X .* (2.0,) == fX .* 2.0 - @test X .* 2.0 isa typeof(X) - @test X .* (2.0,) isa typeof(X) - @test isequal(X .* Inf, fX .* Inf) + @test mul2(X)::typeof(X) == mult2(X)::typeof(X) == mult2(fX) + @test isequal(mulinf(X), mulinf(fX)) - two = 2 - @test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two - @test X .^ 2 isa typeof(X) - @test X .^ (2,) isa typeof(X) - @test X .^ two isa typeof(X) - @test X .^ 0 == fX .^ 0 - @test X .^ -1 == fX .^ -1 + @test lpow2(X)::typeof(X) == powt2(X)::typeof(X) == pow2(X)::typeof(X) == lpow2(fX) + @test lpow0(X) == lpow0(fX) + @test lpow_1(X) == lpow_1(fX) for (Y, fY) in zip(structuredarrays, fstructuredarrays) @test broadcast(+, X, Y) == broadcast(+, fX, fY) @@ -65,9 +70,9 @@ using Test, LinearAlgebra Ttris = typeof.((UpperTriangular(parent(UU)), LowerTriangular(parent(UU)))) funittriangulars = map(Array, unittriangulars) for (X, fX, Ttri) in zip(unittriangulars, funittriangulars, Ttris) - @test (Q = broadcast(sin, X); typeof(Q) == Ttri && Q == broadcast(sin, fX)) + @test (Q = @inferred(broadcast(sin, X)); typeof(Q) == Ttri && Q == broadcast(sin, fX)) @test broadcast!(sin, Z, X) == broadcast(sin, fX) - @test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX)) + @test (Q = @inferred(broadcast(cos, X)); Q isa Matrix && Q == broadcast(cos, fX)) @test broadcast!(cos, Z, X) == broadcast(cos, fX) @test (Q = broadcast(*, s, X); typeof(Q) == Ttri && Q == broadcast(*, s, fX)) @test broadcast!(*, Z, s, X) == broadcast(*, s, fX) @@ -76,18 +81,14 @@ using Test, LinearAlgebra @test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX)) @test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX) - @test X .* 2.0 == 
X .* (2.0,) == fX .* 2.0 - @test X .* 2.0 isa Ttri - @test X .* (2.0,) isa Ttri - @test isequal(X .* Inf, fX .* Inf) + @test mul2(X)::Ttri == mult2(X)::Ttri == mul2(fX) + @test isequal(mulinf(X), mulinf(fX)) two = 2 - @test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two - @test X .^ 2 isa typeof(X) # special cased, as isstructurepreserving - @test X .^ (2,) isa Ttri - @test X .^ two isa Ttri - @test X .^ 0 == fX .^ 0 - @test X .^ -1 == fX .^ -1 + @test lpow2(X)::typeof(X) == # special cased, as isstructurepreserving + powt2(X)::Ttri == pow2(X)::Ttri == lpow2(fX) + @test lpow0(X) == lpow0(fX) + @test lpow_1(X) == lpow_1(fX) for (Y, fY) in zip(unittriangulars, funittriangulars) @test broadcast(+, X, Y) == broadcast(+, fX, fY) @@ -338,4 +339,16 @@ end end end +@testset "Issue 54087: size-1 structured matrix's broadcast" begin + Ns = 1, 3 + D1, D2 = map(N->Diagonal(rand(N)), Ns) + B1, B2 = map(N->Bidiagonal(rand(N), rand(N - 1), :U), Ns) + T1, T2 = map(N->Tridiagonal(rand(N - 1), rand(N), rand(N - 1)), Ns) + Ss = [D1, D2, B1, B2, T1, T2] + MSs = Matrix.(Ss) + for ((S1, M1), (S2, M2)) in Iterators.product(zip(Ss, MSs), zip(Ss, MSs)) + @test S1 .+ S2 == M1 .+ M2 + end +end + end