Skip to content

Commit

Permalink
Make sparse operations less dependent on inference
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Dec 15, 2016
1 parent 1362268 commit 8dab405
Showing 1 changed file with 51 additions and 39 deletions.
90 changes: 51 additions & 39 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1401,52 +1401,65 @@ sparse(S::UniformScaling, m::Integer, n::Integer=m) = speye_scaled(S.λ, m, n)
## map/map! and broadcast/broadcast! over sparse matrices

# map/map! entry points
function map!{Tf,N}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
function map!{F,N}(f::F, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_checksameshape(C, A, Bs...)
return _map!(f, C, A, Bs...)
end
@inline function _map!(f, C, A, Bs::Vararg)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
_map_notzeropres!(f, fofzeros, C, A, Bs...)
end
function map{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
function map{F,N}(f::F, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_checksameshape(A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
return _map(f, A, Bs...)
end
@inline function _map(f, A, Bs::Vararg)
entrytypeC = _broadcast_type(f, A, Bs...)
indextypeC = _promote_indtype(A, Bs...)
Ccolptr = Vector{indextypeC}(A.n + 1)
Crowval = Vector{indextypeC}(maxnnzC)
Cnzval = Vector{entrytypeC}(maxnnzC)
C = SparseMatrixCSC(A.m, A.n, Ccolptr, Crowval, Cnzval)
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
_map_notzeropres!(f, fofzeros, C, A, Bs...)
if isleaftype(entrytypeC)
indextypeC = _promote_indtype(A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
Ccolptr = Vector{indextypeC}(A.n + 1)
Crowval = Vector{indextypeC}(maxnnzC)
Cnzval = Vector{entrytypeC}(maxnnzC)
C = SparseMatrixCSC(A.m, A.n, Ccolptr, Crowval, Cnzval)
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
_map_notzeropres!(f, fofzeros, C, A, Bs...)
end
return sparse(collect(Base.Generator(f, A, Bs...)))
end
# broadcast/broadcast! entry points
broadcast{Tf}(f::Tf, A::SparseMatrixCSC) = map(f, A)
broadcast!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC) = map!(f, C, A)
function broadcast!{Tf,N}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_aresameshape(C, A, Bs...) && return map!(f, C, A, Bs...) # could avoid a second dims check in map
broadcast{F}(f::F, A::SparseMatrixCSC) = map(f, A)
broadcast!{F}(f::F, C::SparseMatrixCSC, A::SparseMatrixCSC) = map!(f, C, A)
function broadcast!{F,N}(f::F, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_aresameshape(C, A, Bs...) && return _map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(indices(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end
function broadcast{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_aresameshape(A, Bs...) && return map(f, A, Bs...) # could avoid a second dims check in map
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
indextypeC = _promote_indtype(A, Bs...)
function broadcast{F,N}(f::F, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_aresameshape(A, Bs...) && return _map(f, A, Bs...)
entrytypeC = _broadcast_type(f, A, Bs...)
Cm, Cn = Base.to_shape(Base.Broadcast.broadcast_indices(A, Bs...))
maxnnzC = fpreszeros ? _checked_maxnnzbcres(Cm, Cn, A, Bs...) : (Cm * Cn)
Ccolptr = Vector{indextypeC}(Cn + 1)
Crowval = Vector{indextypeC}(maxnnzC)
Cnzval = Vector{entrytypeC}(maxnnzC)
C = SparseMatrixCSC(Cm, Cn, Ccolptr, Crowval, Cnzval)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
shape = Base.Broadcast.broadcast_indices(A, Bs...)
if isleaftype(entrytypeC)
indextypeC = _promote_indtype(A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
Cm, Cn = Base.to_shape(shape)
maxnnzC = fpreszeros ? _checked_maxnnzbcres(Cm, Cn, A, Bs...) : (Cm * Cn)
Ccolptr = Vector{indextypeC}(Cn + 1)
Crowval = Vector{indextypeC}(maxnnzC)
Cnzval = Vector{entrytypeC}(maxnnzC)
C = SparseMatrixCSC(Cm, Cn, Ccolptr, Crowval, Cnzval)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end
return sparse(Base.Broadcast.broadcast_t(f, Any, shape, CartesianRange(shape), A, Bs...))
end
# map/map! and broadcast/broadcast! entry point helper functions
@inline _sumnnzs(A) = nnz(A)
Expand All @@ -1468,7 +1481,7 @@ _broadcast_type(f, As...) = Base._promote_op(f, Base.Broadcast.typestuple(As...)

# _map_zeropres!/_map_notzeropres! specialized for a single sparse matrix
"Stores only the nonzero entries of `map(f, Matrix(A))` in `C`."
function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC)
function _map_zeropres!(f, C::SparseMatrixCSC, A::SparseMatrixCSC)
spaceC = min(length(C.rowval), length(C.nzval))
Ck = 1
@inbounds for j in 1:C.n
Expand All @@ -1491,7 +1504,7 @@ end
Densifies `C`, storing `fillvalue` in place of each unstored entry in `A` and
`f(A[i,j])` in place of each stored entry `A[i,j]` in `A`.
"""
function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMatrixCSC)
function _map_notzeropres!(f, fillvalue, C::SparseMatrixCSC, A::SparseMatrixCSC)
# Build dense matrix structure in C, expanding storage if necessary
_densestructure!(C)
# Populate values
Expand Down Expand Up @@ -2281,14 +2294,13 @@ round{To}(::Type{To}, A::SparseMatrixCSC) = round.(To, A)
# TODO: These seven functions should probably be reimplemented in terms of sparse map
# when a better sparse map exists. (And vectorized min, max, &, |, and xor should be
# deprecated in favor of compact-broadcast syntax.)
_checksameshape(A, B) = size(A) == size(B) || throw(DimensionMismatch("size(A) must match size(B)"))
(+)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(+, A, B))
(-)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(-, A, B))
min(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(min, A, B))
max(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(max, A, B))
(&)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(&, A, B))
(|)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(|, A, B))
xor(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(xor, A, B))
(+)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(+, A, B)
(-)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(-, A, B)
min(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(min, A, B)
max(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(max, A, B)
(&)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(&, A, B)
(|)(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(|, A, B)
xor(A::SparseMatrixCSC, B::SparseMatrixCSC) = broadcast(xor, A, B)

(.+)(A::SparseMatrixCSC, B::Number) = Array(A) .+ B
( +)(A::SparseMatrixCSC, B::Array ) = Array(A) + B
Expand Down

0 comments on commit 8dab405

Please sign in to comment.