Skip to content

Commit

Permalink
Make sparse map/broadcast work where the output eltype is not a concr…
Browse files Browse the repository at this point in the history
…ete subtype of Number.
  • Loading branch information
Sacha0 committed Dec 14, 2016
1 parent 1f8006d commit 7456d88
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
30 changes: 16 additions & 14 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1404,14 +1404,14 @@ sparse(S::UniformScaling, m::Integer, n::Integer=m) = speye_scaled(S.λ, m, n)
function map!{Tf,N}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_checksameshape(C, A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = !_broadcast_isnonzero(fofzeros)
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
_map_notzeropres!(f, fofzeros, C, A, Bs...)
end
function map{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_checksameshape(A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = !_broadcast_isnonzero(fofzeros)
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
entrytypeC = _broadcast_type(f, A, Bs...)
indextypeC = _promote_indtype(A, Bs...)
Expand All @@ -1429,14 +1429,14 @@ function broadcast!{Tf,N}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Var
_aresameshape(C, A, Bs...) && return map!(f, C, A, Bs...) # could avoid a second dims check in map
Base.Broadcast.check_broadcast_indices(indices(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = !_broadcast_isnonzero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end
function broadcast{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
_aresameshape(A, Bs...) && return map(f, A, Bs...) # could avoid a second dims check in map
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = !_broadcast_isnonzero(fofzeros)
indextypeC = _promote_indtype(A, Bs...)
entrytypeC = _broadcast_type(f, A, Bs...)
Cm, Cn = Base.to_shape(Base.Broadcast.broadcast_indices(A, Bs...))
Expand Down Expand Up @@ -1465,6 +1465,8 @@ _maxnnzfrom(Cm, Cn, A) = nnz(A) * div(Cm, A.m) * div(Cn, A.n)
@inline _unchecked_maxnnzbcres(Cm, Cn, As) = min(Cm * Cn, sum(_maxnnzfrom_each(Cm, Cn, As)))
@inline _checked_maxnnzbcres(Cm, Cn, As...) = Cm != 0 && Cn != 0 ? _unchecked_maxnnzbcres(Cm, Cn, As) : 0
_broadcast_type(f, As...) = Base._promote_op(f, Base.Broadcast.typestuple(As...))
@inline _broadcast_isnonzero{T<:Number}(x::T) = x != zero(T)
@inline _broadcast_isnonzero(x) = x != 0

# _map_zeropres!/_map_notzeropres! specialized for a single sparse matrix
"Stores only the nonzero entries of `map(f, Matrix(A))` in `C`."
Expand All @@ -1475,7 +1477,7 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC)
C.colptr[j] = Ck
for Ak in nzrange(A, j)
Cx = f(A.nzval[Ak])
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + nnz(A) - (Ak - 1)))
C.rowval[Ck] = A.rowval[Ak]
C.nzval[Ck] = Cx
Expand Down Expand Up @@ -1559,7 +1561,7 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::Sp
# cases are equally or more likely than the Ai < Bi and Bi < Ai cases. Hence
# the ordering of the conditional chain above differs from that in the
# corresponding broadcast code (below).
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + (nnz(A) - (Ak - 1)) + (nnz(B) - (Bk - 1))))
C.rowval[Ck] = Ci
C.nzval[Ck] = Cx
Expand Down Expand Up @@ -1660,7 +1662,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
# pattern) the Ai < Bi and Bi < Ai cases are equally or more likely than the
# Ai == Bi and termination cases. Hence the ordering of the conditional
# chain above differs from that in the corresponding map code.
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
C.rowval[Ck] = Ci
C.nzval[Ck] = Cx
Expand All @@ -1682,7 +1684,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
# B's jth column without storing every entry in C's jth column
while Bk < stopBk
Cx = f(Ax, B.nzval[Bk])
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
C.rowval[Ck] = B.rowval[Bk]
C.nzval[Ck] = Cx
Expand All @@ -1701,7 +1703,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
else
Cx = fvAzB
end
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
C.rowval[Ck] = Ci
C.nzval[Ck] = Cx
Expand All @@ -1723,7 +1725,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
# A's jth column without storing every entry in C's jth column
while Ak < stopAk
Cx = f(A.nzval[Ak], Bx)
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
C.rowval[Ck] = A.rowval[Ak]
C.nzval[Ck] = Cx
Expand All @@ -1742,7 +1744,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC,
else
Cx = fzAvB
end
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, A, B)))
C.rowval[Ck] = Ci
C.nzval[Ck] = Cx
Expand Down Expand Up @@ -1862,7 +1864,7 @@ function _map_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{SparseMatrix
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
vals, ks, rows = _fusedupdate_all(rowsentinel, activerow, rows, ks, stopks, As)
Cx = f(vals...)
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + min(length(C), _sumnnzs(As...)) - (sum(ks) - N)))
C.rowval[Ck] = activerow
C.nzval[Ck] = Cx
Expand Down Expand Up @@ -1982,7 +1984,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{Sparse
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
args, ks, rows = _fusedupdatebc_all(rowsentinel, activerow, rows, defargs, ks, stopks, As)
Cx = f(args...)
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, As)))
C.rowval[Ck] = activerow
C.nzval[Ck] = Cx
Expand All @@ -2003,7 +2005,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{Sparse
else
Cx = defaultCx
end
if Cx != zero(eltype(C))
if _broadcast_isnonzero(Cx)
Ck > spaceC && (spaceC = _expandstorage!(C, _unchecked_maxnnzbcres(C.m, C.n, As)))
C.rowval[Ck] = Ci
C.nzval[Ck] = Cx
Expand Down
11 changes: 11 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1825,3 +1825,14 @@ let
@test_throws DimensionMismatch broadcast(+, A, B, speye(N))
@test_throws DimensionMismatch broadcast!(+, X, A, B, speye(N))
end

# Test that sparse map and broadcast succeed where the output eltype isn't a
# concrete subtype of Number. (Issue #19561.)
let
intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x)
stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x)
@test map(intoneorfloatzero, speye(4)) == sparse(map(intoneorfloatzero, eye(4)))
@test map(stringorfloatzero, speye(4)) == sparse(map(stringorfloatzero, eye(4)))
@test broadcast(intoneorfloatzero, speye(4)) == sparse(broadcast(intoneorfloatzero, eye(4)))
@test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4)))
end

0 comments on commit 7456d88

Please sign in to comment.