Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge cumprod rules to fix ambig on julia 1.6 #784

Closed
wants to merge 3 commits into from
Closed

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Feb 9, 2024

Suggested by @devmotion in #781
should resolve same issue.
I am not sure which is cleaner

@oxinabox oxinabox requested a review from devmotion February 9, 2024 13:49
Comment on lines 343 to 354
end
,
@thunk project_x(if dims <= ndims(x)
vald = Val(Int(dims))
∇cumprod_dim(vald, x, dy, y)
if ndims(x) == 1
∇cumprod(x, dy, y)
else
vald = Val(Int(dims))
∇cumprod_dim(vald, x, dy, y)
end
else
dy
end)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
,
@thunk project_x(if dims <= ndims(x)
vald = Val(Int(dims))
∇cumprod_dim(vald, x, dy, y)
if ndims(x) == 1
∇cumprod(x, dy, y)
else
vald = Val(Int(dims))
∇cumprod_dim(vald, x, dy, y)
end
else
dy
end)
end,
@thunk project_x(
if dims <= ndims(x)
if ndims(x) == 1
∇cumprod(x, dy, y)
else
vald = Val(Int(dims))
∇cumprod_dim(vald, x, dy, y)
end
else
dy
end,

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'm slightly in favour of this PR, since it reduces the number of methods and cumprod(::AbstractVector) is forwarded to the keyword argument version: https://github.com/JuliaLang/julia/blob/63e95d47e692352fac5541b6ee835b573f4c4898/base/accumulate.jl#L227

Unfortunately, the diff is significantly larger than in #781 - and it requires a few additional changes to the tests since there's no rrule(cumprod, ::AbstractVector) anymore (so one has to set dims=1 explicitly in these tests).

end

function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer)
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it OK to put a default value for dims even though there's no default value in base? https://github.com/JuliaLang/julia/blob/63e95d47e692352fac5541b6ee835b573f4c4898/base/accumulate.jl#L193

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it is see:
#784 (comment)

@oxinabox
Copy link
Member Author

oxinabox commented Feb 9, 2024

its actually fine for the rrule to always default dims to 1. (So I have made that change)
Since:

  • in the vector case this is what is wanted and is the correct default
  • in the nonvector case then a keyword argument is required by the primal, so the default will never be hit.

@mcabbott
Copy link
Member

mcabbott commented Feb 9, 2024

in the nonvector case then a keyword argument is required by the primal, so the default will never be hit.

What does this mean? The primal computation is done inside the rrule, and does use the default.

I think that will mean gradient(x -> sum(abs, cumprod(x)), [1 2; 3 4.]) is not an error, whereas without gradient it would be.

That said, I have no memory of why there are two rrules. Besides this issue there might have been performance concerns?

@oxinabox
Copy link
Member Author

oxinabox commented Feb 9, 2024

I think that will mean gradient(x -> sum(abs, cumprod(x)), [1 2; 3 4.]) is not an error, whereas without gradient it would be.

That is a point.
I had been thinking we could rely on the primal computation being correct.
But I guess we can not.

We will have to do somehting more complex like set the rrule as dims=nothing
then if ndims(x)==1 set it to dims=1

@devmotion
Copy link
Member

We will have to do somehting more complex like set the rrule as dims=nothing

Can't we just omit the default value and adjust the tests?

@mcabbott
Copy link
Member

mcabbott commented Feb 9, 2024

My vote is that #781 is much simpler.

Besides worrying about error cases, I think this kind of change would want to be checked very carefully for type stability or performance regressions.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants