Skip to content

Commit

Permalink
Replace A[ct]_(mul|ldiv|rdiv)_B[ct][!] defs in base/sparse/sparsevect…
Browse files Browse the repository at this point in the history
…or.jl with de-jazzed passthroughs.
  • Loading branch information
Sacha0 committed Dec 7, 2017
1 parent dc06d0d commit 3a506bc
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 33 deletions.
35 changes: 35 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,41 @@ end
Ac_mul_Bc(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(Adjoint(A), Adjoint(B))
end

# A[ct]_(mul|ldiv|rdiv)_B[ct][!] methods from base/sparse/sparsevector.jl, to deprecate
for isunittri in (true, false), islowertri in (true, false)
unitstr = isunittri ? "Unit" : ""
halfstr = islowertri ? "Lower" : "Upper"
tritype = :(Base.LinAlg.$(Symbol(unitstr, halfstr, "Triangular")))
@eval Base.SparseArrays begin
using Base.LinAlg: Adjoint, Transpose
At_ldiv_B(A::$tritype{TA,<:AbstractMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Transpose(A), b)
At_ldiv_B(A::$tritype{TA,<:StridedMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Transpose(A), b)
At_ldiv_B(A::$tritype, b::SparseVector) = \(Transpose(A), b)
Ac_ldiv_B(A::$tritype{TA,<:AbstractMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Adjoint(A), b)
Ac_ldiv_B(A::$tritype{TA,<:StridedMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number} = \(Adjoint(A), b)
Ac_ldiv_B(A::$tritype, b::SparseVector) = \(Adjoint(A), b)
A_ldiv_B!(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector) = ldiv!(A, b)
At_ldiv_B!(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector) = ldiv!(Transpose(A), b)
Ac_ldiv_B!(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector) = ldiv!(Adjoint(A), b)
end
end
@eval Base.SparseArrays begin
using Base.LinAlg: Adjoint, Transpose
Ac_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = *(Adjoint(A), x)
At_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) = *(Transpose(A), x)
Ac_mul_B!::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, Adjoint(A), x, β, y)
Ac_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, Adjoint(A), x)
At_mul_B!::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, Transpose(A), x, β, y)
At_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, Transpose(A), x)
A_mul_B!::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, A, x, β, y)
A_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, A, x)
At_mul_B!::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, Transpose(A), x, β, y)
At_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, Transpose(A), x)
At_mul_B(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx} = *(Transpose(A), x)
A_mul_B!::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector) = mul!(α, A, x, β, y)
A_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} = mul!(y, A, x)
end

# issue #24822
@deprecate_binding Display AbstractDisplay

Expand Down
81 changes: 48 additions & 33 deletions base/sparse/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

### Common definitions

using Base.LinAlg: Adjoint, Transpose
import Base: scalarmax, scalarmin, sort, find, findnz
import Base.LinAlg: promote_to_array_type, promote_to_arrays_

Expand Down Expand Up @@ -1571,10 +1572,10 @@ function (*)(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
A_mul_B!(y, A, x)
end

A_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
mul!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
A_mul_B!(one(Tx), A, x, zero(Ty), y)

function A_mul_B!::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector)
function mul!::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector)
m, n = size(A)
length(x) == n && length(y) == m || throw(DimensionMismatch())
m == 0 && return y
Expand All @@ -1600,18 +1601,20 @@ end

# At_mul_B

function At_mul_B(A::StridedMatrix{Ta}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
function *(transA::Transpose{<:Any,<:StridedMatrix{Ta}}, x::AbstractSparseVector{Tx}) where {Ta,Tx}
A = transA.parent
m, n = size(A)
length(x) == m || throw(DimensionMismatch())
Ty = promote_type(Ta, Tx)
y = Vector{Ty}(uninitialized, n)
At_mul_B!(y, A, x)
end

At_mul_B!(y::StridedVector{Ty}, A::StridedMatrix, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
At_mul_B!(one(Tx), A, x, zero(Ty), y)
mul!(y::StridedVector{Ty}, transA::Transpose{<:Any,<:StridedMatrix}, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
(A = transA.parent; At_mul_B!(one(Tx), A, x, zero(Ty), y))

function At_mul_B!::Number, A::StridedMatrix, x::AbstractSparseVector, β::Number, y::StridedVector)
function mul!::Number, transA::Transpose{<:Any,<:StridedMatrix}, x::AbstractSparseVector, β::Number, y::StridedVector)
A = transA.parent
m, n = size(A)
length(x) == m && length(y) == n || throw(DimensionMismatch())
n == 0 && return y
Expand Down Expand Up @@ -1666,10 +1669,10 @@ end

# A_mul_B

A_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
mul!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
A_mul_B!(one(Tx), A, x, zero(Ty), y)

function A_mul_B!::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector)
function mul!::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector)
m, n = size(A)
length(x) == n && length(y) == m || throw(DimensionMismatch())
m == 0 && return y
Expand Down Expand Up @@ -1699,17 +1702,17 @@ end

# At_mul_B

At_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
At_mul_B!(one(Tx), A, x, zero(Ty), y)
mul!(y::StridedVector{Ty}, transA::Transpose{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
(A = transA.parent; At_mul_B!(one(Tx), A, x, zero(Ty), y))

At_mul_B!::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) =
_At_or_Ac_mul_B!(*, α, A, x, β, y)
mul!::Number, transA::Transpose{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector, β::Number, y::StridedVector) =
(A = transA.parent; _At_or_Ac_mul_B!(*, α, A, x, β, y))

Ac_mul_B!(y::StridedVector{Ty}, A::SparseMatrixCSC, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
Ac_mul_B!(one(Tx), A, x, zero(Ty), y)
mul!(y::StridedVector{Ty}, adjA::Adjoint{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector{Tx}) where {Tx,Ty} =
(A = adjA.parent; Ac_mul_B!(one(Tx), A, x, zero(Ty), y))

Ac_mul_B!::Number, A::SparseMatrixCSC, x::AbstractSparseVector, β::Number, y::StridedVector) =
_At_or_Ac_mul_B!(dot, α, A, x, β, y)
mul!::Number, adjA::Adjoint{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector, β::Number, y::StridedVector) =
(A = adjA.parent; _At_or_Ac_mul_B!(dot, α, A, x, β, y))

function _At_or_Ac_mul_B!(tfun::Function,
α::Number, A::SparseMatrixCSC, x::AbstractSparseVector,
Expand Down Expand Up @@ -1747,11 +1750,11 @@ function *(A::SparseMatrixCSC, x::AbstractSparseVector)
_dense2sparsevec(y, initcap)
end

At_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) =
_At_or_Ac_mul_B(*, A, x)
*(transA::Transpose{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector) =
(A = transA.parent; _At_or_Ac_mul_B(*, A, x))

Ac_mul_B(A::SparseMatrixCSC, x::AbstractSparseVector) =
_At_or_Ac_mul_B(dot, A, x)
*(adjA::Adjoint{<:Any,<:SparseMatrixCSC}, x::AbstractSparseVector) =
(A = adjA.parent; _At_or_Ac_mul_B(dot, A, x))

function _At_or_Ac_mul_B(tfun::Function, A::SparseMatrixCSC{TvA,TiA}, x::AbstractSparseVector{TvX,TiX}) where {TvA,TiA,TvX,TiX}
m, n = size(A)
Expand Down Expand Up @@ -1797,13 +1800,16 @@ for isunittri in (true, false), islowertri in (true, false)
tritype = :(Base.LinAlg.$(Symbol(unitstr, halfstr, "Triangular")))

# build out-of-place left-division operations
for (istrans, func, ipfunc) in (
(false, :(\), :(A_ldiv_B!)),
(true, :(At_ldiv_B), :(At_ldiv_B!)),
(true, :(Ac_ldiv_B), :(Ac_ldiv_B!)) )
for (istrans, func, ipfunc, applyxform, xform) in (
(false, :(\), :(A_ldiv_B!), false, :None),
(true, :(At_ldiv_B), :(At_ldiv_B!), true, :Transpose),
(true, :(Ac_ldiv_B), :(Ac_ldiv_B!), true, :Adjoint) )

# broad method where elements are Numbers
@eval function ($func)(A::$tritype{TA,<:AbstractMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number}
xformtritype = applyxform ? :($xform{<:TA,<:$tritype{<:Any,<:AbstractMatrix}}) :
:($tritype{<:TA,<:AbstractMatrix})
@eval function \(xformA::$xformtritype, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number}
A = $(applyxform ? :(xformA.parent) : :(xformA) )
TAb = $(isunittri ?
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
Expand All @@ -1812,7 +1818,10 @@ for isunittri in (true, false), islowertri in (true, false)

# faster method requiring good view support of the
# triangular matrix type. hence the StridedMatrix restriction.
@eval function ($func)(A::$tritype{TA,<:StridedMatrix}, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number}
xformtritype = applyxform ? :($xform{<:TA,<:$tritype{<:Any,<:StridedMatrix}}) :
:($tritype{<:TA,<:StridedMatrix})
@eval function \(xformA::$xformtritype, b::SparseVector{Tb}) where {TA<:Number,Tb<:Number}
A = $(applyxform ? :(xformA.parent) : :(xformA) )
TAb = $(isunittri ?
:(typeof(zero(TA)*zero(Tb) + zero(TA)*zero(Tb))) :
:(typeof((zero(TA)*zero(Tb) + zero(TA)*zero(Tb))/one(TA))) )
Expand All @@ -1832,19 +1841,26 @@ for isunittri in (true, false), islowertri in (true, false)
end

# fallback where elements are not Numbers
@eval ($func)(A::$tritype, b::SparseVector) = ($ipfunc)(A, copy(b))
xformtritype = applyxform ? :($xform{<:Any,<:$tritype}) : :($tritype)
@eval function \(xformA::$xformtritype, b::SparseVector)
A = $(applyxform ? :(xformA.parent) : :(xformA) )
($ipfunc)(A, copy(b))
end
end

# build in-place left-division operations
for (istrans, func) in (
(false, :(A_ldiv_B!)),
(true, :(At_ldiv_B!)),
(true, :(Ac_ldiv_B!)) )
for (istrans, func, applyxform, xform) in (
(false, :(A_ldiv_B!), false, :None),
(true, :(At_ldiv_B!), true, :Transpose),
(true, :(Ac_ldiv_B!), true, :Adjoint) )
xformtritype = applyxform ? :($xform{<:Any,<:$tritype{<:Any,<:StridedMatrix}}) :
:($tritype{<:Any,<:StridedMatrix})

# the generic in-place left-division methods handle these cases, but
# we can achieve greater efficiency where the triangular matrix provides
# good view support. hence the StridedMatrix restriction.
@eval function ($func)(A::$tritype{<:Any,<:StridedMatrix}, b::SparseVector)
@eval function ldiv!(xformA::$xformtritype, b::SparseVector)
A = $(applyxform ? :(xformA.parent) : :(xformA) )
# If b has no nonzero entries, the result is necessarily zero and this call
# reduces to a no-op. If b has nonzero entries, then...
if nnz(b) != 0
Expand All @@ -1863,7 +1879,6 @@ for isunittri in (true, false), islowertri in (true, false)
nzrangeviewbnz = view(b.nzval, nzrange .- (b.nzind[1] - 1))
nzrangeviewA = $tritype(view(A.data, nzrange, nzrange))
($func)(nzrangeviewA, nzrangeviewbnz)
# could strip any miraculous zeros here perhaps
end
b
end
Expand Down

0 comments on commit 3a506bc

Please sign in to comment.