Skip to content

Commit

Permalink
Make sparse map[!]/broadcast[!] work where the output eltype is not a…
Browse files Browse the repository at this point in the history
… concrete subtype of number.
  • Loading branch information
Sacha0 committed Dec 31, 2016
1 parent 6484f4c commit 72b38af
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
31 changes: 17 additions & 14 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,13 @@ map!{Tf,N}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMa
(_checksameshape(C, A, Bs...); _noshapecheck_map!(f, C, A, Bs...))
function _noshapecheck_map!{Tf,N}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N})
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
_map_notzeropres!(f, fofzeros, C, A, Bs...)
end
function _noshapecheck_map{Tf,N}(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N})
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = _iszero(fofzeros)
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
entrytypeC = Base.Broadcast._broadcast_type(Any, f, A, Bs...)
indextypeC = _promote_indtype(A, Bs...)
Expand All @@ -86,7 +86,7 @@ function broadcast!{Tf,N}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Varar
_aresameshape(C, A, Bs...) && return _noshapecheck_map!(f, C, A, Bs...)
Base.Broadcast.check_broadcast_indices(indices(C), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = _iszero(fofzeros)
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end
Expand All @@ -99,7 +99,7 @@ broadcast{Tf,N}(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) =
_diffshape_broadcast(f, A, Bs...)
function _diffshape_broadcast{Tf,N}(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N})
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = fofzeros == zero(fofzeros)
fpreszeros = _iszero(fofzeros)
indextypeC = _promote_indtype(A, Bs...)
entrytypeC = Base.Broadcast._broadcast_type(Any, f, A, Bs...)
shapeC = to_shape(Base.Broadcast.broadcast_indices(A, Bs...))
Expand All @@ -111,6 +111,9 @@ end
# helper functions for map[!]/broadcast[!] entry points (and related methods below)
@inline _sumnnzs(A) = nnz(A)
@inline _sumnnzs(A, Bs...) = nnz(A) + _sumnnzs(Bs...)
@inline _iszero(x) = x == 0
@inline _iszero(x::Number) = Base.iszero(x)
@inline _iszero(x::AbstractArray) = Base.iszero(x)
@inline _zeros_eltypes(A) = (zero(eltype(A)),)
@inline _zeros_eltypes(A, Bs...) = (zero(eltype(A)), _zeros_eltypes(Bs...)...)
@inline _promote_indtype(A) = indtype(A)
Expand Down Expand Up @@ -159,7 +162,7 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat)
setcolptr!(C, j, Ck)
for Ak in colrange(A, j)
Cx = f(storedvals(A)[Ak])
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, Ck + nnz(A) - (Ak - 1)))
storedinds(C)[Ck] = storedinds(A)[Ak]
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -240,7 +243,7 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B::Spar
# cases are equally or more likely than the Ai < Bi and Bi < Ai cases. Hence
# the ordering of the conditional chain above differs from that in the
# corresponding broadcast code (below).
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, Ck + (nnz(A) - (Ak - 1)) + (nnz(B) - (Bk - 1))))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -305,7 +308,7 @@ function _map_zeropres!{Tf,N}(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMa
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
vals, ks, rows = _fusedupdate_all(rowsentinel, activerow, rows, ks, stopks, As)
Cx = f(vals...)
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, Ck + min(length(C), _sumnnzs(As...)) - (sum(ks) - N)))
storedinds(C)[Ck] = activerow
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -461,7 +464,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B
# pattern) the Ai < Bi and Bi < Ai cases are equally or more likely than the
# Ai == Bi and termination cases. Hence the ordering of the conditional
# chain above differs from that in the corresponding map code.
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand All @@ -483,7 +486,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B
# B's jth column without storing every entry in C's jth column
while Bk < stopBk
Cx = f(Ax, storedvals(B)[Bk])
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = storedinds(B)[Bk]
storedvals(C)[Ck] = Cx
Expand All @@ -502,7 +505,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B
else
Cx = fvAzB
end
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand All @@ -524,7 +527,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B
# A's jth column without storing every entry in C's jth column
while Ak < stopAk
Cx = f(storedvals(A)[Ak], Bx)
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = storedinds(A)[Ak]
storedvals(C)[Ck] = Cx
Expand All @@ -543,7 +546,7 @@ function _broadcast_zeropres!{Tf}(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, B
else
Cx = fzAvB
end
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), A, B)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand Down Expand Up @@ -674,7 +677,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseVecOrMat, As::Vararg{SparseV
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
args, ks, rows = _fusedupdatebc_all(rowsentinel, activerow, rows, defargs, ks, stopks, As)
Cx = f(args...)
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), As)))
storedinds(C)[Ck] = activerow
storedvals(C)[Ck] = Cx
Expand All @@ -695,7 +698,7 @@ function _broadcast_zeropres!{Tf,N}(f::Tf, C::SparseVecOrMat, As::Vararg{SparseV
else
Cx = defaultCx
end
if Cx != zero(eltype(C))
if !_iszero(Cx)
Ck > spaceC && (spaceC = expandstorage!(C, _unchecked_maxnnzbcres(size(C), As)))
storedinds(C)[Ck] = Ci
storedvals(C)[Ck] = Cx
Expand Down
9 changes: 9 additions & 0 deletions test/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ end
end
end

@testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin
intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x)
stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x)
@test map(intoneorfloatzero, speye(4)) == sparse(map(intoneorfloatzero, eye(4)))
@test map(stringorfloatzero, speye(4)) == sparse(map(stringorfloatzero, eye(4)))
@test broadcast(intoneorfloatzero, speye(4)) == sparse(broadcast(intoneorfloatzero, eye(4)))
@test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4)))
end

# Older tests of sparse broadcast, now largely covered by the tests above
@testset "assorted tests of sparse broadcast over two input arguments" begin
N, p = 10, 0.3
Expand Down

0 comments on commit 72b38af

Please sign in to comment.