Can we get rid of auto-broadcasting of 0D arrays for activations? #608
I don't like the auto-broadcast either, but here we are. The built-in opt-out is this function, which perhaps Reactant needs to know how to handle anyway:

```julia
julia> Base.Broadcast.broadcast_preserving_zero_d(sin, fill(pi/2))
0-dimensional Array{Float64, 0}:
1.0

julia> Base.Broadcast.broadcast_preserving_zero_d(sin, [0, pi/2])
2-element Vector{Float64}:
 0.0
 1.0
```
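For context, the reason this opt-out function exists in Base is that plain broadcasting unwraps a 0-dimensional result into a bare scalar, dropping the container. A minimal sketch of the contrast:

```julia
x = fill(pi/2)    # 0-dimensional Array{Float64, 0}

# Plain broadcast over a 0-D array collapses the result to a scalar:
y = sin.(x)       # a Float64, the 0-D container is gone

# broadcast_preserving_zero_d re-wraps the result, keeping it 0-D:
z = Base.Broadcast.broadcast_preserving_zero_d(sin, x)
```

Here `y isa Float64` while `z isa Array{Float64,0}`; only the latter round-trips the array type.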
This will still cause issues, right? I want the 0D case to be forwarded to the original call without any broadcasting. For example, for relu I want …
But what this function is fed is some special fake 0D array which Reactant invents? My hope is that it can also be made to understand that …
I found a solution that would do it:

```julia
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, :σ))
    @eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
        return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
    end
end
```
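The mechanism this loop relies on is `invoke`, which calls a specific (here, the generic `Tuple{Any}`) method even when a more specialized one would normally be dispatched to. A self-contained toy illustration:

```julia
# A generic fallback and a specialized method:
f(x) = "generic"
f(x::Int) = "specialized"

f(1)                      # normal dispatch picks the Int method: "specialized"
invoke(f, Tuple{Any}, 1)  # forces the generic fallback: "generic"
```

In the loop above, this forwards the 0-D `TracedRArray` to each activation's generic scalar definition, bypassing the broadcasting method entirely.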
Correct. Reactant doesn't have a "Number" type, so we treat 0D arrays as scalars.
With a rework of the Reactant scalar handling, this is now fixed without using …
(Ideally, I don't think it should be auto-broadcasting in the first place.) But if we just get rid of 0-D array broadcasting, that solves our problem over at EnzymeAD/Reactant.jl#54.
(See NNlib.jl/src/activations.jl, lines 752 to 755 at ba29c90.)
Essentially, Reactant needs to treat scalars as 0D `TracedRArray`s, but that causes a recursion loop and, expectedly, the IR contains an `unreachable` (EnzymeAD/Reactant.jl#54 (comment)). This means the only way we can support NNlib activations is to manually copy over all the code for the activation functions.
Now, I know there isn't a general way to opt out of the broadcasting for 0-D arrays, but we can just define the broadcasting for N = 1..10 and hope no one is using an 11+-dimensional tensor.
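A minimal sketch of that idea, using a hypothetical elementwise op `myop` rather than the actual NNlib definitions: the broadcasting methods are generated only for dimensionalities 1 through 10, so a 0-D array falls through to the generic method and is never broadcast.

```julia
# Hypothetical elementwise op whose generic method works on both
# scalars and 0-D arrays without any broadcasting:
myop(x) = 2x

# Generate one broadcasting method per dimensionality N = 1..10;
# 0-D arrays do not match any of these and hit the generic method:
for N in 1:10
    @eval myop(x::AbstractArray{T,$N}) where {T} = myop.(x)
end

myop([1, 2, 3])   # 1-D: broadcasts elementwise to [2, 4, 6]
myop(fill(3))     # 0-D: generic method, result stays a 0-D array
```

The trade-off is exactly the one noted above: this enumerates dimensionalities rather than opting out in general, so arrays with more than 10 dimensions would silently take the generic path.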