-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
merge cumprod rules to fix ambig on julia 1.6 #784
Conversation
end | ||
, | ||
@thunk project_x(if dims <= ndims(x) | ||
vald = Val(Int(dims)) | ||
∇cumprod_dim(vald, x, dy, y) | ||
if ndims(x) == 1 | ||
∇cumprod(x, dy, y) | ||
else | ||
vald = Val(Int(dims)) | ||
∇cumprod_dim(vald, x, dy, y) | ||
end | ||
else | ||
dy | ||
end) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[JuliaFormatter] reported by reviewdog 🐶
end | |
, | |
@thunk project_x(if dims <= ndims(x) | |
vald = Val(Int(dims)) | |
∇cumprod_dim(vald, x, dy, y) | |
if ndims(x) == 1 | |
∇cumprod(x, dy, y) | |
else | |
vald = Val(Int(dims)) | |
∇cumprod_dim(vald, x, dy, y) | |
end | |
else | |
dy | |
end) | |
end, | |
@thunk project_x( | |
if dims <= ndims(x) | |
if ndims(x) == 1 | |
∇cumprod(x, dy, y) | |
else | |
vald = Val(Int(dims)) | |
∇cumprod_dim(vald, x, dy, y) | |
end | |
else | |
dy | |
end, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I'm slightly in favour of this PR, since it reduces the number of methods and cumprod(::AbstractVector)
is forwarded to the keyword argument version: https://github.com/JuliaLang/julia/blob/63e95d47e692352fac5541b6ee835b573f4c4898/base/accumulate.jl#L227
Unfortunately, the diff is significantly larger than in #781 - and it requires a few additional changes to the tests since there's no rrule(cumprod, ::AbstractVector)
anymore (so one has to set dims=1
explicitly in these tests).
src/rulesets/Base/mapreduce.jl
Outdated
end | ||
|
||
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer) | ||
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it OK to put a default value for dims
even though there's no default value in base? https://github.com/JuliaLang/julia/blob/63e95d47e692352fac5541b6ee835b573f4c4898/base/accumulate.jl#L193
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, it is see:
#784 (comment)
its actually fine for the
|
What does this mean? The primal computation is done inside the rrule, and does use the default. I think that will mean That said, I have no memory of why there are two rrules. Besides this issue there might have been performance concerns? |
That is a point. We will have to do somehting more complex like set the rrule as |
Can't we just omit the default value and adjust the tests? |
My vote is that #781 is much simpler. Besides worrying about error cases, I think this kind of change would want to be checked very carefully for type stability or performance regressions. |
Suggested by @devmotion in #781
should resolve same issue.
I am not sure which is cleaner