Skip to content

Commit

Permalink
Replace A[ct]_(mul|ldiv|rdiv)_B[ct][!] defs in base/sparse/linalg.jl …
Browse files Browse the repository at this point in the history
…with de-jazzed passthroughs.
  • Loading branch information
Sacha0 committed Dec 7, 2017
1 parent e3d1d2f commit dc06d0d
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 83 deletions.
38 changes: 38 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2696,6 +2696,44 @@ end
At_ldiv_B!
end

# A[ct]_(mul|ldiv|rdiv)_B[ct][!] methods from base/sparse/linalg.jl, to deprecate
@eval Base.SparseArrays begin
using Base.LinAlg: Adjoint, Transpose
Ac_ldiv_B(A::SparseMatrixCSC, B::RowVector) = \(Adjoint(A), B)
At_ldiv_B(A::SparseMatrixCSC, B::RowVector) = \(Transpose(A), B)
Ac_ldiv_B(A::SparseMatrixCSC, B::AbstractVecOrMat) = \(Adjoint(A), B)
At_ldiv_B(A::SparseMatrixCSC, B::AbstractVecOrMat) = \(Transpose(A), B)
A_rdiv_Bc!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where {T} = rdiv!(A, Adjoint(D))
A_rdiv_Bt!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where {T} = rdiv!(A, Transpose(D))
A_rdiv_B!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where {T} = rdiv!(A, D)
A_ldiv_B!(L::LowerTriangular{T,<:SparseMatrixCSCUnion{T}}, B::StridedVecOrMat) where {T} = ldiv!(L, B)
A_ldiv_B!(U::UpperTriangular{T,<:SparseMatrixCSCUnion{T}}, B::StridedVecOrMat) where {T} = ldiv!(U, B)
A_mul_Bt(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = *(A, Transpose(B))
A_mul_Bc(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = *(A, Adjoint(B))
At_mul_B(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = *(Transpose(A), B)
Ac_mul_B(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = *(Adjoint(A), B)
At_mul_Bt(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = *(Transpose(A), Transpose(B))
Ac_mul_Bc(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = *(Adjoint(A), Adjoint(B))
A_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = mul!(C, A, B)
Ac_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = mul!(C, Adjoint(A), B)
At_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = mul!(C, Transpose(A), B)
A_mul_B!::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) = mul!(α, A, B, β, C)
A_mul_B(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx}) where {TA,S,Tx} = *(A, x)
A_mul_B(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) where {TA,S,Tx} = *(A, B)
Ac_mul_B!::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) = mul!(α, Adjoint(A), B, β, C)
Ac_mul_B(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx}) where {TA,S,Tx} = *(Adjoint(A), x)
Ac_mul_B(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) where {TA,S,Tx} = *(Adjoint(A), B)
At_mul_B!::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) = mul!(α, Transpose(A), B, β, C)
At_mul_B(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx}) where {TA,S,Tx} = *(Transpose(A), x)
At_mul_B(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) where {TA,S,Tx} = *(Transpose(A), B)
A_mul_Bt(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(A, Transpose(B))
A_mul_Bc(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(A, Adjoint(B))
At_mul_B(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(Transpose(A), B)
Ac_mul_B(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(Adjoint(A),B)
At_mul_Bt(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(Transpose(A), Transpose(B))
Ac_mul_Bc(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} = *(Adjoint(A), Adjoint(B))
end

# issue #24822
@deprecate_binding Display AbstractDisplay

Expand Down
223 changes: 140 additions & 83 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

using Base.LinAlg: Adjoint, Transpose
import Base.LinAlg: checksquare

## sparse matrix multiplication

function (*)(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB}
(*)(sppromote(A, B)...)
end
for f in (:A_mul_Bt, :A_mul_Bc,
:At_mul_B, :Ac_mul_B,
:At_mul_Bt, :Ac_mul_Bc)
@eval begin
function ($f)(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB}
($f)(sppromote(A, B)...)
end
end
end
*(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} =
*(sppromote(A, B)...)
*(A::SparseMatrixCSC{TvA,TiA}, transB::Transpose{<:Any,<:SparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(B = transB.parent; A_mul_Bt(sppromote(A, B)...))
*(A::SparseMatrixCSC{TvA,TiA}, adjB::Adjoint{<:Any,<:SparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(B = adjB.parent; A_mul_Bc(sppromote(A, B)...))
*(transA::Transpose{<:Any,<:SparseMatrixCSC{TvA,TiA}}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} =
(A = transA.parent; At_mul_B(sppromote(A, B)...))
*(adjA::Adjoint{<:Any,<:SparseMatrixCSC{TvA,TiA}}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB} =
(A = adjA.parent; Ac_mul_B(sppromote(A, B)...))
*(transA::Transpose{<:Any,<:SparseMatrixCSC{TvA,TiA}}, transB::Transpose{<:Any,<:SparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(A = transA.parent; B = transB.parent; At_mul_Bt(sppromote(A, B)...))
*(adjA::Adjoint{<:Any,<:SparseMatrixCSC{TvA,TiA}}, adjB::Adjoint{<:Any,<:SparseMatrixCSC{TvB,TiB}}) where {TvA,TiA,TvB,TiB} =
(A = adjA.parent; B = adjB.parent; Ac_mul_Bc(sppromote(A, B)...))

function sppromote(A::SparseMatrixCSC{TvA,TiA}, B::SparseMatrixCSC{TvB,TiB}) where {TvA,TiA,TvB,TiB}
Tv = promote_type(TvA, TvB)
Expand All @@ -27,60 +30,90 @@ end

# In matrix-vector multiplication, the correct orientation of the vector is assumed.

for (f, op, transp) in ((:A_mul_B, :identity, false),
(:Ac_mul_B, :adjoint, true),
(:At_mul_B, :transpose, true))
@eval begin
function $(Symbol(f,:!))(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
if $transp
A.n == size(C, 1) || throw(DimensionMismatch())
A.m == size(B, 1) || throw(DimensionMismatch())
else
A.n == size(B, 1) || throw(DimensionMismatch())
A.m == size(C, 1) || throw(DimensionMismatch())
end
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = A.nzval
rv = A.rowval
if β != 1
β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C)))
end
for k = 1:size(C, 2)
for col = 1:A.n
if $transp
tmp = zero(eltype(C))
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
tmp += $(op)(nzv[j])*B[rv[j],k]
end
C[col,k] += α*tmp
else
αxj = α*B[col,k]
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
C[rv[j], k] += nzv[j]*αxj
end
end
end
function mul!::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
A.n == size(B, 1) || throw(DimensionMismatch())
A.m == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = A.nzval
rv = A.rowval
if β != 1
β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C)))
end
for k = 1:size(C, 2)
for col = 1:A.n
αxj = α*B[col,k]
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
C[rv[j], k] += nzv[j]*αxj
end
C
end

function $(f)(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx}) where {TA,S,Tx}
T = promote_type(TA, Tx)
$(Symbol(f,:!))(one(T), A, x, zero(T), similar(x, T, A.n))
end
C
end
*(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx}) where {TA,S,Tx} =
(T = promote_type(TA, Tx); A_mul_B!(one(T), A, x, zero(T), similar(x, T, A.m)))
*(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) where {TA,S,Tx} =
(T = promote_type(TA, Tx); A_mul_B!(one(T), A, B, zero(T), similar(B, T, (A.m, size(B, 2)))))

function mul!::Number, adjA::Adjoint{<:Any,<:SparseMatrixCSC}, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
A = adjA.parent
A.n == size(C, 1) || throw(DimensionMismatch())
A.m == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = A.nzval
rv = A.rowval
if β != 1
β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C)))
end
for k = 1:size(C, 2)
for col = 1:A.n
tmp = zero(eltype(C))
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
tmp += adjoint(nzv[j])*B[rv[j],k]
end
C[col,k] += α*tmp
end
function $(f)(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) where {TA,S,Tx}
T = promote_type(TA, Tx)
$(Symbol(f,:!))(one(T), A, B, zero(T), similar(B, T, (A.n, size(B, 2))))
end
C
end
*(adjA::Adjoint{<:Any,<:SparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx} =
(A = adjA.parent; T = promote_type(TA, Tx); Ac_mul_B!(one(T), A, x, zero(T), similar(x, T, A.n)))
*(adjA::Adjoint{<:Any,<:SparseMatrixCSC{TA,S}}, B::StridedMatrix{Tx}) where {TA,S,Tx} =
(A = adjA.parent; T = promote_type(TA, Tx); Ac_mul_B!(one(T), A, B, zero(T), similar(B, T, (A.n, size(B, 2)))))

function mul!::Number, transA::Transpose{<:Any,<:SparseMatrixCSC}, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
A = transA.parent
A.n == size(C, 1) || throw(DimensionMismatch())
A.m == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = A.nzval
rv = A.rowval
if β != 1
β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C)))
end
for k = 1:size(C, 2)
for col = 1:A.n
tmp = zero(eltype(C))
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
tmp += transpose(nzv[j])*B[rv[j],k]
end
C[col,k] += α*tmp
end
end
C
end
*(transA::Transpose{<:Any,<:SparseMatrixCSC{TA,S}}, x::StridedVector{Tx}) where {TA,S,Tx} =
(A = transA.parent; T = promote_type(TA, Tx); At_mul_B!(one(T), A, x, zero(T), similar(x, T, A.n)))
*(transA::Transpose{<:Any,<:SparseMatrixCSC{TA,S}}, B::StridedMatrix{Tx}) where {TA,S,Tx} =
(A = transA.parent; T = promote_type(TA, Tx); At_mul_B!(one(T), A, B, zero(T), similar(B, T, (A.n, size(B, 2)))))

# For compatibility with dense multiplication API. Should be deleted when dense multiplication
# API is updated to follow BLAS API.
A_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = A_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C)
Ac_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = Ac_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C)
At_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = At_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C)

mul!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) =
A_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C)
mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:SparseMatrixCSC}, B::StridedVecOrMat) =
(A = adjA.parent; Ac_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C))
mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:SparseMatrixCSC}, B::StridedVecOrMat) =
(A = transA.parent; At_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C))

function (*)(X::StridedMatrix{TX}, A::SparseMatrixCSC{TvA,TiA}) where {TX,TvA,TiA}
mX, nX = size(X)
Expand All @@ -107,18 +140,18 @@ end
# http://dl.acm.org/citation.cfm?id=355796

(*)(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = spmatmul(A,B)
for (f, opA, opB) in ((:A_mul_Bt, :identity, :transpose),
(:A_mul_Bc, :identity, :adjoint),
(:At_mul_B, :transpose, :identity),
(:Ac_mul_B, :adjoint, :identity),
(:At_mul_Bt, :transpose, :transpose),
(:Ac_mul_Bc, :adjoint, :adjoint))
@eval begin
function ($f)(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
spmatmul(($opA)(A), ($opB)(B))
end
end
end
*(A::SparseMatrixCSC{Tv,Ti}, transB::Transpose{<:Any,<:SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} =
(B = transB.parent; spmatmul(A, transpose(B)))
*(A::SparseMatrixCSC{Tv,Ti}, adjB::Adjoint{<:Any,<:SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} =
(B = adjB.parent; spmatmul(A, adjoint(B)))
*(transA::Transpose{<:Any,<:SparseMatrixCSC{Tv,Ti}}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
(A = transA.parent; spmatmul(tranpsose(A), B))
*(adjA::Adjoint{<:Any,<:SparseMatrixCSC{Tv,Ti}}, B::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} =
(A = adjA.parent; spmatmul(adjoint(A), B))
*(transA::Transpose{<:Any,<:SparseMatrixCSC{Tv,Ti}}, transB::Transpose{<:Any,<:SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} =
(A = transA.parent; B = transB.parent; spmatmul(transpose(A), transpose(B)))
*(adjA::Adjoint{<:Any,<:SparseMatrixCSC{Tv,Ti}}, adjB::Adjoint{<:Any,<:SparseMatrixCSC{Tv,Ti}}) where {Tv,Ti} =
(A = adjA.parent; B = adjB.parent; spmatmul(adjoint(A), adjoint(B)))

function spmatmul(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti};
sortindices::Symbol = :sortcols) where {Tv,Ti}
Expand Down Expand Up @@ -268,13 +301,13 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat)
B
end

A_ldiv_B!(L::LowerTriangular{T,<:SparseMatrixCSCUnion{T}}, B::StridedVecOrMat) where {T} = fwdTriSolve!(L.data, B)
A_ldiv_B!(U::UpperTriangular{T,<:SparseMatrixCSCUnion{T}}, B::StridedVecOrMat) where {T} = bwdTriSolve!(U.data, B)
ldiv!(L::LowerTriangular{T,<:SparseMatrixCSCUnion{T}}, B::StridedVecOrMat) where {T} = fwdTriSolve!(L.data, B)
ldiv!(U::UpperTriangular{T,<:SparseMatrixCSCUnion{T}}, B::StridedVecOrMat) where {T} = bwdTriSolve!(U.data, B)

(\)(L::LowerTriangular{T,<:SparseMatrixCSCUnion{T}}, B::SparseMatrixCSC) where {T} = A_ldiv_B!(L, Array(B))
(\)(U::UpperTriangular{T,<:SparseMatrixCSCUnion{T}}, B::SparseMatrixCSC) where {T} = A_ldiv_B!(U, Array(B))

function A_rdiv_B!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where T
function rdiv!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where T
dd = D.diag
if (k = length(dd)) A.n
throw(DimensionMismatch("size(A, 2)=$(A.n) should be size(D, 1)=$k"))
Expand All @@ -292,8 +325,10 @@ function A_rdiv_B!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where T
A
end

A_rdiv_Bc!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where {T} = A_rdiv_B!(A, conj(D))
A_rdiv_Bt!(A::SparseMatrixCSC{T}, D::Diagonal{T}) where {T} = A_rdiv_B!(A, D)
rdiv!(A::SparseMatrixCSC{T}, adjD::Adjoint{<:Any,<:Diagonal{T}}) where {T} =
(D = adjD.parent; A_rdiv_B!(A, conj(D)))
rdiv!(A::SparseMatrixCSC{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
(D = transD.parent; A_rdiv_B!(A, D))

## triu, tril

Expand Down Expand Up @@ -883,29 +918,51 @@ end
scale!(A::SparseMatrixCSC, b::Number) = (scale!(A.nzval, b); A)
scale!(b::Number, A::SparseMatrixCSC) = (scale!(b, A.nzval); A)

for f in (:\, :Ac_ldiv_B, :At_ldiv_B)
function \(A::SparseMatrixCSC, B::AbstractVecOrMat)
m, n = size(A)
if m == n
if istril(A)
if istriu(A)
return \(Diagonal(Vector(diag(A))), B)
else
return \(LowerTriangular(A), B)
end
elseif istriu(A)
return \(UpperTriangular(A), B)
end
if ishermitian(A)
return \(Hermitian(A), B)
end
return \(lufact(A), B)
else
return \(qrfact(A), B)
end
end
\(::SparseMatrixCSC, ::RowVector) = throw(DimensionMismatch("Cannot left-divide matrix by transposed vector"))
for xform in (:Adjoint, :Transpose)
@eval begin
function ($f)(A::SparseMatrixCSC, B::AbstractVecOrMat)
function \(xformA::($xform){<:Any,<:SparseMatrixCSC}, B::AbstractVecOrMat)
A = xformA.parent
m, n = size(A)
if m == n
if istril(A)
if istriu(A)
return ($f)(Diagonal(Vector(diag(A))), B)
return \(($xform)(Diagonal(Vector(diag(A)))), B)
else
return ($f)(LowerTriangular(A), B)
return \(($xform)(LowerTriangular(A)), B)
end
elseif istriu(A)
return ($f)(UpperTriangular(A), B)
return \(($xform)(UpperTriangular(A)), B)
end
if ishermitian(A)
return ($f)(Hermitian(A), B)
return \(($xform)(Hermitian(A)), B)
end
return ($f)(lufact(A), B)
return \(($xform)(lufact(A)), B)
else
return ($f)(qrfact(A), B)
return \(($xform)(qrfact(A)), B)
end
end
($f)(::SparseMatrixCSC, ::RowVector) = throw(DimensionMismatch("Cannot left-divide matrix by transposed vector"))
\(::($xform){<:Any,<:SparseMatrixCSC}, ::RowVector) = throw(DimensionMismatch("Cannot left-divide matrix by transposed vector"))
end
end

Expand Down

0 comments on commit dc06d0d

Please sign in to comment.