-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Make it easier to extend broadcast! #24992
Conversation
base/broadcast.jl
Outdated
|
||
# special cases for "X .= ..." (broadcast!) assignments | ||
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x) | ||
broadcast!(f, X::AbstractArray, x::Number...) = (@inbounds for I in eachindex(X); X[I] = f(x...); end; X) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't quite sure how to handle this case, as it specializes on both the destination and the source args. What is the goal of this special case? Reducing compilation time or runtime?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Runtime
base/sparse/higherorderfns.jl
Outdated
@@ -93,7 +93,8 @@ end | |||
# (3) broadcast[!] entry points | |||
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A) | |||
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A) | |||
function broadcast!(f::Tf, C::SparseVecOrMat) where Tf | |||
|
|||
function broadcast!(f::Tf, C::SparseVecOrMat, ::Void) where Tf |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use ::Void
for the broadcast style here and in the next broadcast!
method since we're specializing on the destination only, as per the rules at the end of #24914 (comment).
@@ -106,14 +107,13 @@ function broadcast!(f::Tf, C::SparseVecOrMat) where Tf | |||
end | |||
return C | |||
end | |||
function broadcast!(f::Tf, C::SparseVecOrMat, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Turned this into a spbroadcast_args!
method.
base/sparse/higherorderfns.jl
Outdated
function broadcast!(f, dest::SparseVecOrMat, ::Void, A, Bs::Vararg{Any,N}) where N | ||
if isa(f, typeof(identity)) && N == 0 && isa(A, Number) | ||
return fill!(dest, A) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wasn't sure whether to also implement the copy!
optimization when the indices
match here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that method was mostly about ambiguity resolution, so prob not necessary.
I tried to mimic the previous behavior of sparse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall. Will need to update the docs too.
base/sparse/higherorderfns.jl
Outdated
function broadcast!(f, dest::SparseVecOrMat, ::Void, A, Bs::Vararg{Any,N}) where N | ||
if isa(f, typeof(identity)) && N == 0 && isa(A, Number) | ||
return fill!(dest, A) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that method was mostly about ambiguity resolution, so prob not necessary.
base/broadcast.jl
Outdated
_broadcast!(f, C, A, Bs...) | ||
broadcast!(f, dest, As...) = broadcast!(f, dest, combine_styles(As...), As...) | ||
broadcast!(f, dest, ::BroadcastStyle, As...) = broadcast!(f, dest, nothing, As...) | ||
@inline function broadcast!(f, C, ::Void, A, Bs::Vararg{Any,N}) where N |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't think about this before, but I wonder if we should consider the implications of specializing A, Bs::Vararg{Any,N}
vs just As::Vararg{Any,N}
. I believe those two have separate precedence for dispatch. Whatever we choose should be documented that way.
Might need some |
OK, thanks for your comments. I'll have more time to work on this after Wednesday. |
Not sure I follow --- sparse broadcast should only capture argument combinations where both the destination is |
Tests don't run due to a merge conflict. |
@Sacha0, that may be the case for julia/base/sparse/higherorderfns.jl Line 96 in 193f763
julia/base/sparse/higherorderfns.jl Line 109 in 193f763
julia/base/sparse/higherorderfns.jl Line 1017 in 193f763
Updates so far (sorry, should have commented as I pushed my latest commit):
I'll rebase and keep trying things today to get the allocation tests to pass. Any additional pointers on staying within the current limits of inference would be greatly appreciated. |
Prior to #23939, IIRC the entry points to sparse |
IIRC, this signature structure might've been necessary to correctly handle first arguments that are types. |
Again thanks for your efforts on this front @tkoolen! :) |
I think those implementations that retain the |
If you want to test for ambiguities, you can do this: using Test
detect_ambiguities(Base, Core) |
Might want |
@tkoolen also don't hesitate to ask for help. We definitely want this done for 1.0, and today is the deadline. |
Last I heard broadcast internals will be exempt from the stability guarantee in 1.0? Additionally, stretching this by a weekend should be alright (what with all the other work in progress that is yet to merge) :). |
84bb5d9
to
35638c7
Compare
I was aware of the deadline. I could certainly use help understanding what inference/compiler limits I'm triggering in the last cases where there are allocations, i.e.
I've just pushed my rebase (was hacking away at my local branch earlier). Feel free to open a PR against my branch if anybody has found a solution to the allocation problems. Or just give me some ideas for things to try. I had also come to the conclusion that passing along the |
Here's a smaller test (reduced from using Test
@testset begin
N, M, p = 10, 12, 0.3
mats = (sprand(N, M, p), sprand(N, 1, p), sprand(1, M, p), sprand(1, 1, 1.0), spzeros(1, 1))
vecs = (sprand(N, p), sprand(1, 1.0), spzeros(1))
tens = (mats..., vecs...)
Xo = tens[1]
X = ndims(Xo) == 1 ? SparseVector{Float32,Int32}(Xo) : SparseMatrixCSC{Float32,Int32}(Xo)
shapeX, fX = size(X), Array(X)
Y = tens[1]
Z = tens[1]
fY, fZ = Array(Y), Array(Z)
fQ = broadcast(+, fX, fY, fZ); Q = sparse(fQ)
broadcast!(+, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated
@test (@allocated broadcast!(+, Q, X, Y, Z)) == 0
end Do note that if I use a let block, i.e. let Q = Q, X = X, Y = Y, Z = Z
broadcast!(+, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated
@test (@allocated broadcast!(+, Q, X, Y, Z)) == 0
end then the test passes. I'm honestly kind of surprised that this passed before by the way, given these lines: julia/test/sparse/higherorderfns.jl Lines 216 to 220 in 9315ca0
|
Ran out of time for today. I'll take another look tomorrow. |
It seems to be a failure to specialize. I can get that particular test to pass with the following changes (I'm not sure they all are necessary, but they are sufficient): diff --git a/base/broadcast.jl b/base/broadcast.jl
index 8576824..f40f526 100644
--- a/base/broadcast.jl
+++ b/base/broadcast.jl
@@ -258,7 +258,7 @@ longest(::Tuple{}, ::Tuple{}) = ()
# combine_styles operates on values (arbitrarily many)
combine_styles(c) = result_style(BroadcastStyle(typeof(c)))
combine_styles(c1, c2) = result_style(combine_styles(c1), combine_styles(c2))
-combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
+@inline combine_styles(c1, c2, cs...) = result_style(combine_styles(c1), combine_styles(c2, cs...))
# result_style works on types (singletons and pairs), and leverages `BroadcastStyle`
result_style(s::BroadcastStyle) = s
@@ -442,8 +442,8 @@ Note that `dest` is only used to store the result, and does not supply
arguments to `f` unless it is also listed in the `As`,
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
"""
-@inline broadcast!(f, dest, As...) = broadcast!(f, dest, combine_styles(As...), As...)
-@inline broadcast!(f, dest, ::BroadcastStyle, As...) = broadcast!(f, dest, nothing, As...)
+@inline broadcast!(f::Tf, dest, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, combine_styles(As...), As...)
+@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...)
# Default behavior (separated out so that it can be called by users who want to extend broadcast!).
@inline function broadcast!(f, dest, ::Void, As::Vararg{Any, N}) where N
diff --git a/base/sparse/higherorderfns.jl b/base/sparse/higherorderfns.jl
index ba89f04..423ce10 100644
--- a/base/sparse/higherorderfns.jl
+++ b/base/sparse/higherorderfns.jl
@@ -1021,15 +1021,15 @@ function spbroadcast_args!(f::Tf, C, ::SPVM, A::SparseVecOrMat, Bs::Vararg{Spars
return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, C, A, Bs...)
end
-function spbroadcast_args!(f, dest, ::SPVM, mixedsrcargs::Vararg{Any,N}) where N
+function spbroadcast_args!(f::Tf, dest, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
return broadcast!(parevalf, dest, passedsrcargstup...)
end
-function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N
+function spbroadcast_args!(f::Tf, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
end
-function spbroadcast_args!(f, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where N
+function spbroadcast_args!(f::Tf, dest, ::Any, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# Fallback. From a performance perspective would it be best to densify?
Broadcast._broadcast!(f, dest, mixedsrcargs...)
end |
There's still a really strange inference failure though. Let's get a branch pushed and then see if we can get Jameson to take a look. |
Thanks, I pushed commit with those changes. |
Great. Any chance anyone not swamped with other things can take a look at the inference failure? (I'm moving today...bad timing.) Here's a test: julia> A = sprand(5, 5, 0.4);
julia> a = sprand(5, 0.4);
julia> X = similar(A);
julia> @code_warntype broadcast!(*, X, A, a, A, 1.0f0, 1.0f0, 1.0f0, A, a)
Variables:
f<optimized out>
dest::SparseMatrixCSC{Float64,Int64}
As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}
Body:
begin
SSAValue(14) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 1)::SparseMatrixCSC{Float64,Int64}
SSAValue(15) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 2)::SparseVector{Float64,Int64}
SSAValue(16) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 3)::SparseMatrixCSC{Float64,Int64}
SSAValue(17) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 4)::Float32
SSAValue(18) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 5)::Float32
SSAValue(19) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 6)::Float32
SSAValue(20) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 7)::SparseMatrixCSC{Float64,Int64}
SSAValue(21) = (Core.getfield)(As::Tuple{SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64},SparseMatrixCSC{Float64,Int64},Float32,Float32,Float32,SparseMatrixCSC{Float64,Int64},SparseVector{Float64,Int64}}, 8)::SparseVector{Float64,Int64}
# meta: location broadcast.jl broadcast! 443
SSAValue(24) = $(Expr(:invoke, MethodInstance for broadcast!(::typeof(*), ::SparseMatrixCSC{Float64,Int64}, ::Void, ::SparseMatrixCSC{Float64,Int64}, ::SparseVector{Float64,Int64}, ::SparseMatrixCSC{Float64,Int64}, ::Float32, ::Float32, ::Float32, ::SparseMatrixCSC{Float64,Int64}, ::SparseVector{Float64,Int64}), :(Base.Broadcast.broadcast!), *, :(dest), nothing, SSAValue(14), SSAValue(15), SSAValue(16), SSAValue(17), SSAValue(18), SSAValue(19), SSAValue(20), SSAValue(21)))::Any
# meta: pop location
return SSAValue(24)
end::Any The funny thing about it is that it knows which method it will call, and that method is inferrable: @code_warntype broadcast!(*, X, nothing, A, a, A, 1.0f0, 1.0f0, 1.0f0, A, a) # just adds the `nothing` Worst-case scenario I think we should mark those tests as broken and merge anyway, then fix. @tkoolen, can you add docs? See the |
Just reviewed more extensively. Overall these changes look great; thanks Re. the inference failure, my best guess is that the additional argument (the One broad comment: I imagined this mechanism enabling flattening of the dispatch layers in sparse |
780d5b5
to
26769f7
Compare
bd93724
to
761ca80
Compare
As long as you're dispatching on a specific |
Yep, agree. FYI: I'll be AFK for the next 48 hours or so. |
On this end, enabling graceful dispatch on both the destination and source arguments (edit: simultaneously, that is) was the perhaps primary impetus for this direction :). |
This passes tests locally for me. Let's get some performance data: @nanosoldier |
Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @ararslan |
LGTM. Do people want to see this go through CI? If so I'll squash on a separate branch and push as a new PR. The API will change further because of #23692 and its successor, but I think this is making progress in the right direction. So I'd say it's worth merging. |
I can squash as well if that's easier. |
If you can (and get rid of the ci skips), that would be great. |
140de09
to
94dcd82
Compare
Done. |
@inline broadcast!(f::Tf, dest, ::BroadcastStyle, As::Vararg{Any,N}) where {Tf,N} = broadcast!(f, dest, nothing, As...) | ||
|
||
# Default behavior (separated out so that it can be called by users who want to extend broadcast!). | ||
@inline function broadcast!(f, dest, ::Nothing, As::Vararg{Any, N}) where N |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@StefanKarpinski, thanks again for seeing the Void
-> Nothing
change through! So much more intuitive :).
base/sparse/higherorderfns.jl
Outdated
# Fallback. From a performance perspective would it be best to densify? | ||
Broadcast._broadcast!(f, dest, mixedsrcargs...) | ||
function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N} | ||
broadcast!(f, dest, nothing, map(_sparsifystructured, mixedsrcargs)...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The nothing
for BroadcastStyle
causes dispatch to the generic AbstractArray
broadcast implementation, correct? If so, this nothing
should disappear, as PromoteToSparse
argument combinations should go to sparse broadcast after promotion IIRC?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, when I originally suggested this there was a Nothing
method at the "top" of the sparse dispatch hierarchy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
base/sparse/higherorderfns.jl
Outdated
end | ||
function spbroadcast_args!(f, dest, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where N | ||
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...) | ||
broadcast!(parevalf, dest, nothing, passedsrcargstup...) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Likewise (https://github.com/JuliaLang/julia/pull/24992/files#r158561090) here, does this nothing
cause dispatch specifically to the generic AbstractArray
code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
94dcd82
to
fa3fe32
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fantastic! Much thanks for seeing this through @tkoolen! :)
Who else's sign-off do we need in order to merge this? Anyone? @timholy perhaps? |
Thanks, @tkoolen! |
Thank you very much for your help, @timholy. |
Note that no bounds should be placed on the types of `f` and `As...`. | ||
|
||
Second, if specialized `broadcast!` behavior is desired depending on the input types, | ||
you should write [binary broadcasting rules](@ref writing-binary-broadcasting-rules) to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where is the writing-binary-broadcasting-rules
section?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://docs.julialang.org/en/latest/manual/interfaces/#writing-binary-broadcasting-rules-1 (link seems to be working).
Not sure if I'm commenting on the right PR, but I think I am. I note that now I can't use I think maybe we need a different sentinel for the "default" |
Yes, this was noted in #25377 (comment), but if that doesn't get merged soon it should be fixed in the old code. I opened #26097 as a reminder. |
Fix #24914.
This implements @timholy's proposal from #24914 (comment). This passes all tests locally except some
@inferred
/@allocated
tests insparse/higherorderfns.jl
, specifically:julia/test/sparse/higherorderfns.jl
Line 239 in 5f929a4
julia/test/sparse/higherorderfns.jl
Line 244 in 5f929a4
julia/test/sparse/higherorderfns.jl
Line 329 in 5f929a4
julia/test/sparse/higherorderfns.jl
Line 363 in 5f929a4
I'm not exactly sure why this is the case, and could use some help.
To do:
sparse/higherorderfns.jl
inference/allocation tests to passbroadcast!(f, X::AbstractArray, x::Number...)
I'll add some line comments to better explain what I did.