This appears to be caused by the `frule` definition at https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/mapreduce.jl#L9. I believe it may be a junk method. Not only does it assume you can pass `dims=:` (which `sum` doesn't accept for a `Base.Generator`), it also assumes that `sum(ẋ; dims=dims)` makes sense. That appears to be incorrect: `ẋ` is going to be a `Tangent`, so this `sum` isn't actually summing the contents of the `Tangent` but is instead operating on the outer object (I could be misunderstanding, though).
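The `dims=:` point can be checked in plain Julia, with no ChainRules or Diffractor involved. This is a minimal sketch, and the values in `xs` and `gen` are arbitrary. Arrays accept `dims=:`, while `sum` over a `Base.Generator` goes through the generic iterator fallback, which (as far as I can tell on Julia 1.x) has no `dims` keyword.

```julia
xs = [1.0, 2.0, 3.0]
gen = (x^2 for x in xs)   # a Base.Generator

# Arrays accept `dims`; `dims=:` means "reduce over all dimensions" and returns a scalar.
@assert sum(xs; dims=:) == 6.0

# Plain `sum` over the generator works fine...
@assert sum(gen) == 14.0

# ...but passing `dims=:` for a Generator throws a MethodError, because the
# generic iterator fallback (`mapfoldl`) does not accept a `dims` keyword.
err = try
    sum(gen; dims=:)
    nothing
catch e
    e
end
@assert err isa MethodError
```

So a rule that always forwards `dims=dims` to `sum` can only work for inputs where `sum` itself takes that keyword.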
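I tried adding methods like

```julia
import ChainRulesCore
# Attempted override; `sum(ẋ)` sums the generator's tangent object (suspected cause of the wrong results).
function ChainRulesCore.frule((_, ẋ), ::typeof(sum), x::Base.Generator)
    return sum(x), sum(ẋ)
end
```

but that gave incorrect answers, I think because of the `sum(ẋ)`.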
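Here is why summing the tangent alone can't be right, worked out in plain Julia with no AD library. For `s(xs) = sum(f(x) for x in xs)`, the directional derivative along `ẋs` is the sum of `f′(x) * ẋ` over the elements, so it depends on the mapped function `f`. Summing only the input tangents ignores `f` entirely. `f`, `df`, and `pushforward_sum` are hypothetical names for this sketch, and `df` is a hand-written derivative rather than anything Diffractor or ChainRules computes.

```julia
# Hypothetical mapped function and its hand-written derivative.
f(x) = x^2
df(x) = 2x

# Pushforward of s(xs) = sum(f(x) for x in xs) along the tangent ẋs:
# the primal is the sum of f over the elements, and the tangent is
# the sum of each element's pushforward f'(x) * ẋ.
function pushforward_sum(f, df, xs, ẋs)
    primal = sum(f(x) for x in xs)
    tangent = sum(df(x) * ẋ for (x, ẋ) in zip(xs, ẋs))
    return primal, tangent
end

xs = [1.0, 2.0, 3.0]
ẋs = [1.0, 1.0, 1.0]

primal, tangent = pushforward_sum(f, df, xs, ẋs)
naive = sum(ẋs)   # what "just sum the tangent" gives here

# Sanity check against a central finite difference along ẋs.
h = 1e-6
fd = (sum(f(x) for x in xs .+ h .* ẋs) - sum(f(x) for x in xs .- h .* ẋs)) / (2h)
```

Here `tangent` is 12.0 (matching the finite difference), while `naive` is 3.0. Any correct forward rule for `sum` over a `Generator` has to produce the equivalent of `tangent`. That means pushing the tangent through the generator's function element by element, rather than summing the `Tangent` object.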
Instead, I found that if I just ran `Base.delete_method` on the `frule(::Tuple{…}, ::typeof(sum), x::Base.Generator{…}; dims::Function)` method, I got the right results.
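`Base.delete_method` removes a single `Method` from a function's method table, and dispatch then falls back to the next most specific method. The sketch below shows that mechanism on a toy function. `rule` is a hypothetical stand-in, not ChainRules' `frule`. Applied to the real `frule`, the same idea means locating the offending method (e.g. via `methods` or `which`) and deleting it. Forward mode then presumably no longer hits that `sum` rule and differentiates through the iteration instead. The deletion lasts for the whole session, so this is a workaround rather than a fix.

```julia
# Hypothetical toy function with a generic method and a Generator-specific one.
rule(x) = :generic
rule(x::Base.Generator) = :generator_specific

gen = (x^2 for x in 1:3)
before = rule(gen)        # dispatches to the Generator-specific method

# Look up the exact Method this call dispatches to, then remove it.
m = which(rule, (typeof(gen),))
Base.delete_method(m)

after = rule(gen)         # now falls back to the generic method
```

Run these as separate top-level statements (e.g. in the REPL or a script). Code that is already executing in an older world keeps seeing the deleted method.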
I think this is a ChainRules.jl problem, not a Diffractor.jl problem, but I'm not 100% sure. Here's an MWE: