diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl index 7d1f0d33f6578..45b859e259ade 100644 --- a/base/sparse/higherorderfns.jl +++ b/base/sparse/higherorderfns.jl @@ -863,26 +863,26 @@ end return broadcast!(parevalf, dest, passedsrcargstup...) end # capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and -# broadcast scalar arguments (mixedargs), and returns a function (parevalf) and a reduced -# argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs -# in their orginal order, and such that the result of broadcast(g, passedargstup...) is -# broadcast(f, mixedargs...) -capturescalars(f, mixedargs) = +# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially +# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse +# vectors/matrices in mixedargs in their orginal order, and such that the result of +# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...) +@inline capturescalars(f, mixedargs) = capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...) # Recursion cases for capturescalars -capturescalars(f, passedargstup, scalararg, mixedargs...) = +@inline capturescalars(f, passedargstup, scalararg, mixedargs...) = capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...) -capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) = +@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) = capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...) -passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...)) -capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...)) +@inline passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...)) +@inline capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...)) # Base cases for capturescalars -capturescalars(f, passedargstup, scalararg) = +@inline capturescalars(f, passedargstup, scalararg) = (capturelastscalar(f, scalararg), passedargstup) -capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) = +@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) = (passlastnonscalar(f), (passedargstup..., nonscalararg)) -passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),)) -capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,)) +@inline passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),)) +@inline capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,)) # NOTE: The following two method definitions work around #19096. broadcast{Tf,T}(f::Tf, ::Type{T}, A::SparseMatrixCSC) = broadcast(y -> f(T, y), A) diff --git a/test/sparse/higherorderfns.jl b/test/sparse/higherorderfns.jl index 3ef68a73690cb..7afbdb41008a0 100644 --- a/test/sparse/higherorderfns.jl +++ b/test/sparse/higherorderfns.jl @@ -193,6 +193,7 @@ end end end + @testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x) stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x) @@ -202,10 +203,10 @@ end @test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4))) end -@testset "broadcast over combinations of scalars and sparse vectors/matrices" begin - N, M, p = 10, 12, 0.3 +@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin + N, M, p = 10, 12, 0.5 elT = Float64 - s = elT(2.0) + s = Float32(2.0) V = sprand(elT, N, p) A = sprand(elT, N, M, p) fV, fA = Array(V), Array(A) @@ -235,8 +236,29 @@ end ((s, spargsl..., s, s, spargsr...), (s, dargsl..., s, s, dargsr...)), ((spargsl..., s, s, spargsr..., s), (dargsl..., s, s, dargsr..., s)), ((spargsl..., s, s, s, spargsr...), (dargsl..., s, s, s, dargsr...)), ) + # test broadcast entry point @test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...)) @test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT}) + # test broadcast! entry point + fX = broadcast(*, sparseargs...); X = sparse(fX) + @test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...)) + @test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT}) + X = sparse(fX) # reset / warmup for @allocated test + @test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0 + # This test (and the analog below) fails for three reasons: + # (1) In all cases, generating the closures that capture the scalar arguments + # results in allocation, not sure why. + # (2) In some cases, though _broadcast_eltype (which wraps _return_type) + # consistently provides the correct result eltype when passed the closure + # that incorporates the scalar arguments to broadcast (and, with #19667, + # is inferable, so the overall return type from broadcast is inferred), + # in some cases inference seems unable to determine the return type of + # direct calls to that closure. This issue causes variables in both the + # broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and + # the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have + # inferred type Any, resulting in allocation and lackluster performance. + # (3) The sparseargs... splat in the call above allocates a bit, but of course + # that issue is negligible and perhaps could be accounted for in the test. end end # test combinations at the limit of inference (eight arguments net) @@ -248,8 +270,16 @@ end ((s, V, A, s, V, A, s, A), (s, fV, fA, s, fV, fA, s, fA)), # three scalars, five sparse vectors/matrices ((V, A, V, s, A, V, A, s), (fV, fA, fV, s, fA, fV, fA, s)), # two scalars, six sparse vectors/matrices ((V, A, V, A, s, V, A, V), (fV, fA, fV, fA, s, fV, fA, fV)) ) # one scalar, seven sparse vectors/matrices + # test broadcast entry point @test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...)) @test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT}) + # test broadcast! entry point + fX = broadcast(*, sparseargs...); X = sparse(fX) + @test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...)) + @test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT}) + X = sparse(fX) # reset / warmup for @allocated test + @test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0 + # please see the note a few lines above re. this @test_broken end end