From 3a506bc9bef41d66819c8bac7eb13d89597ee089 Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Wed, 6 Dec 2017 14:36:57 -0800 Subject: [PATCH] Replace A[ct]_(mul|ldiv|rdiv)_B[ct][!] defs in base/sparse/sparsevector.jl with de-jazzed passthroughs. --- base/deprecated.jl | 35 ++++++++++++++++ base/sparse/sparsevector.jl | 81 ++++++++++++++++++++++--------------- 2 files changed, 83 insertions(+), 33 deletions(-) diff --git a/base/deprecated.jl b/base/deprecated.jl index ec7ff0a8035300..92ba401eb48041 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -2734,6 +2734,41 @@ end Ac_mul_Bc(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(Adjoint(A), Adjoint(B)) end +# A[ct]_(mul|ldiv|rdiv)_B[ct][!] methods from base/sparse/sparsevector.jl, to deprecate +for isunittri in (true, false), islowertri in (true, false) + unitstr = isunittri ? "Unit" : "" + halfstr = islowertri ? "Lower" : "Upper" + tritype = :(Base.LinAlg.$(Symbol(unitstr, halfstr, "Triangular"))) + @eval Base.SparseArrays begin + using Base.LinAlg: Adjoint, Transpose + At_ldiv_B(A::$tritype{TA,<:AbstractMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Transpose(A), b) + At_ldiv_B(A::$tritype{TA,<:StridedMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Transpose(A), b) + At_ldiv_B(A::$tritype, b::SparseVector) = \(Transpose(A), b) + Ac_ldiv_B(A::$tritype{TA,<:AbstractMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Adjoint(A), b) + Ac_ldiv_B(A::$tritype{TA,<:StridedMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Adjoint(A), b) + Ac_ldiv_B(A::$tritype, b::SparseVector) = \(Adjoint(A), b) + A_ldiv_B!(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector) = ldiv!(A, b) + At_ldiv_B!(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector) = ldiv!(Transpose(A), b) + Ac_ldiv_B!(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector) = ldiv!(Adjoint(A), b) + end +end +@eval Base.SparseArrays begin + using Base.LinAlg: Adjoint, Transpose + Ac_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = *(Adjoint(A), x) + At_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = *(Transpose(A), x) + Ac_mul_B!(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, Adjoint(A), x, β, y) + Ac_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, Adjoint(A), x) + At_mul_B!(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, Transpose(A), x, β, y) + At_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, Transpose(A), x) + A_mul_B!(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, A, x, β, y) + A_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, A, x) + At_mul_B!(α::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, Transpose(A), x, β, y) + At_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, Transpose(A), x) + At_mul_B(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx} = *(Transpose(A), x) + A_mul_B!(α::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, A, x, β, y) + A_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, A, x) +end + # issue #24822 @deprecate_binding Display AbstractDisplay diff --git a/base/sparse/sparsevector.jl b/base/sparse/sparsevector.jl index 76e09d34ce39a4..74b8ed30c0bf67 100644 --- a/base/sparse/sparsevector.jl +++ b/base/sparse/sparsevector.jl @@ -2,6 +2,7 @@ ### Common definitions +using Base.LinAlg: Adjoint, Transpose import Base: scalarmax, scalarmin, sort, find, findnz import Base.LinAlg: promote_to_array_type, promote_to_arrays_ @@ -1571,10 +1572,10 @@ function (*)(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx} A_mul_B!(y, A, x) end -A_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} = +mul!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} = A_mul_B!(one(Tx), A, x, zero(Ty), y) -function A_mul_B!(α::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector) +function mul!(α::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector) m, n = size(A) length(x) == n && length(y) == m || throw(DimensionMismatch()) m == 0 && return y @@ -1600,7 +1601,8 @@ end # At_mul_B -function At_mul_B(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx} +function *(transA::Transpose{<:Any,<:StridedMatrix{Ta}}, x::AbstractSparseVector{Tx}) where {Ta,Tx} + A = transA.parent m, n = size(A) length(x) == m || throw(DimensionMismatch()) Ty = promote_type(Ta, Tx) @@ -1608,10 +1610,11 @@ function At_mul_B(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,T At_mul_B!(y, A, x) end -At_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} = - At_mul_B!(one(Tx), A, x, zero(Ty), y) +mul!(y::StridedVector{Ty}, transA::Transpose{<:Any,<:StridedMatrix}, x::AbstractSparseVector{Tx}) where {Tx,Ty} = + (A = transA.parent; At_mul_B!(one(Tx), A, x, zero(Ty), y)) -function At_mul_B!(α::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector) +function mul!(α::Number, transA::Transpose{<:Any,<:StridedMatrix}, x::AbstractSparseVector, β::Number, y::StridedVector) + A = transA.parent m, n = size(A) length(x) == m && length(y) == n || throw(DimensionMismatch()) n == 0 && return y @@ -1666,10 +1669,10 @@ end # A_mul_B -A_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = +mul!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = A_mul_B!(one(Tx), A, x, zero(Ty), y) -function A_mul_B!(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) +function mul!(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) m, n = size(A) length(x) == n && length(y) == m || throw(DimensionMismatch()) m == 0 && return y @@ -1699,17 +1702,17 @@ end # At_mul_B -At_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = - At_mul_B!(one(Tx), A, x, zero(Ty), y) +mul!(y::StridedVector{Ty}, transA::Transpose{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector{Tx}) where {Tx,Ty} = + (A = transA.parent; At_mul_B!(one(Tx), A, x, zero(Ty), y)) -At_mul_B!(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = - _At_or_Ac_mul_B!(*, α, A, x, β, y) +mul!(α::Number, transA::Transpose{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector, β::Number, y::StridedVector) = + (A = transA.parent; _At_or_Ac_mul_B!(*, α, A, x, β, y)) -Ac_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = - Ac_mul_B!(one(Tx), A, x, zero(Ty), y) +mul!(y::StridedVector{Ty}, adjA::Adjoint{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector{Tx}) where {Tx,Ty} = + (A = adjA.parent; Ac_mul_B!(one(Tx), A, x, zero(Ty), y)) -Ac_mul_B!(α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = - _At_or_Ac_mul_B!(dot, α, A, x, β, y) +mul!(α::Number, adjA::Adjoint{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector, β::Number, y::StridedVector) = + (A = adjA.parent; _At_or_Ac_mul_B!(dot, α, A, x, β, y)) function _At_or_Ac_mul_B!(tfun::Function, α::Number, A::SparseMatrixCSC, x::AbstractSparseVector, @@ -1747,11 +1750,11 @@ function *(A::SparseMatrixCSC, x::AbstractSparseVector) _dense2sparsevec(y, initcap) end -At_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = - _At_or_Ac_mul_B(*, A, x) +*(transA::Transpose{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector) = + (A = transA.parent; _At_or_Ac_mul_B(*, A, x)) -Ac_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = - _At_or_Ac_mul_B(dot, A, x) +*(adjA::Adjoint{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector) = + (A = adjA.parent; _At_or_Ac_mul_B(dot, A, x)) function _At_or_Ac_mul_B(tfun::Function, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX}) where {TvA,TiA,TvX,TiX} m, n = size(A) @@ -1797,13 +1800,16 @@ for isunittri in (true, false), islowertri in (true, false) tritype = :(Base.LinAlg.$(Symbol(unitstr, halfstr, "Triangular"))) # build out-of-place left-division operations - for (istrans, func, ipfunc) in ( - (false, :(\), :(A_ldiv_B!)), - (true, :(At_ldiv_B), :(At_ldiv_B!)), - (true, :(Ac_ldiv_B), :(Ac_ldiv_B!)) ) + for (istrans, func, ipfunc, applyxform, xform) in ( + (false, :(\), :(A_ldiv_B!), false, :None), + (true, :(At_ldiv_B), :(At_ldiv_B!), true, :Transpose), + (true, :(Ac_ldiv_B), :(Ac_ldiv_B!), true, :Adjoint) ) # broad method where elements are Numbers - @eval function ($func)(A::$tritype{TA,<:AbstractMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} + xformtritype = applyxform ? :($xform{<:TA,<:$tritype{<:Any,<:AbstractMatrix}}) : + :($tritype{<:TA,<:AbstractMatrix}) + @eval function \(xformA::$xformtritype, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} + A = $(applyxform ? :(xformA.parent) : :(xformA) ) TAb = $(isunittri ? :(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) : :(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) ) @@ -1812,7 +1818,10 @@ for isunittri in (true, false), islowertri in (true, false) # faster method requiring good view support of the # triangular matrix type. hence the StridedMatrix restriction. - @eval function ($func)(A::$tritype{TA,<:StridedMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} + xformtritype = applyxform ? :($xform{<:TA,<:$tritype{<:Any,<:StridedMatrix}}) : + :($tritype{<:TA,<:StridedMatrix}) + @eval function \(xformA::$xformtritype, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} + A = $(applyxform ? :(xformA.parent) : :(xformA) ) TAb = $(isunittri ? :(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) : :(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) ) @@ -1832,19 +1841,26 @@ for isunittri in (true, false), islowertri in (true, false) end # fallback where elements are not Numbers - @eval ($func)(A::$tritype, b::SparseVector) = ($ipfunc)(A, copy(b)) + xformtritype = applyxform ? :($xform{<:Any,<:$tritype}) : :($tritype) + @eval function \(xformA::$xformtritype, b::SparseVector) + A = $(applyxform ? :(xformA.parent) : :(xformA) ) + ($ipfunc)(A, copy(b)) + end end # build in-place left-division operations - for (istrans, func) in ( - (false, :(A_ldiv_B!)), - (true, :(At_ldiv_B!)), - (true, :(Ac_ldiv_B!)) ) + for (istrans, func, applyxform, xform) in ( + (false, :(A_ldiv_B!), false, :None), + (true, :(At_ldiv_B!), true, :Transpose), + (true, :(Ac_ldiv_B!), true, :Adjoint) ) + xformtritype = applyxform ? :($xform{<:Any,<:$tritype{<:Any,<:StridedMatrix}}) : + :($tritype{<:Any,<:StridedMatrix}) # the generic in-place left-division methods handle these cases, but # we can achieve greater efficiency where the triangular matrix provides # good view support. hence the StridedMatrix restriction. - @eval function ($func)(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector) + @eval function ldiv!(xformA::$xformtritype, b::SparseVector) + A = $(applyxform ? :(xformA.parent) : :(xformA) ) # If b has no nonzero entries, the result is necessarily zero and this call # reduces to a no-op. If b has nonzero entries, then... if nnz(b) != 0 @@ -1863,7 +1879,6 @@ for isunittri in (true, false), islowertri in (true, false) nzrangeviewbnz = view(b.nzval, nzrange .- (b.nzind[1] - 1)) nzrangeviewA = $tritype(view(A.data, nzrange, nzrange)) ($func)(nzrangeviewA, nzrangeviewbnz) - # could strip any miraculous zeros here perhaps end b end