Make `getindex` rule work for AxisArrays #779
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
IMO ChainRules should not depend on AxisArrays.
I don't have any insight on the merits for or against this. But what is your suggestion?
AFAICT it has been a general policy to not accept such dependencies, see e.g. JuliaArrays/FillArrays.jl#153 (comment)
An extension to FillArrays is also out of the question?
The PR predates package extensions, so I have been thinking for a while that one should try again with an extension. I recently managed to get an extension into PDMats, so it seems likely that it would be approved.
Making that FillArrays PR into an extension there would be great.
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
Is there a reason why these are not just
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false)
AFAICT this would also fix the AxisArrays problem.
This breaks existing tests. The problem is that if you don't pass the axes, you don't get a dense array.
Which tests are broken? The two-arg method is even advised in the Julia docs: https://docs.julialang.org/en/v1/manual/methods/#Building-a-similar-type-with-a-different-type-parameter
The 3-arg one removes structured matrices like Symmetric, iirc
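The difference between the two forms of `similar` can be checked with a small sketch using only the `LinearAlgebra` standard library. It is an illustration, not the ChainRules test suite: the matrix `S` and the variable names are made up here, and the off-diagonal write stands in for what a `getindex` gradient has to do.

```julia
using LinearAlgebra

S = Symmetric([1.0 2.0; 2.0 3.0])

# Two-argument form: the Symmetric wrapper is kept.
a = similar(S, Float64)

# Three-argument form with axes: the wrapper is dropped and a plain Matrix comes back.
b = similar(S, Float64, axes(S))

# A gradient for S[1, 2] has to write into one off-diagonal slot.
# The dense buffer accepts that write.
dense = fill!(b, false)
dense[1, 2] = 1.0

# The structured buffer rejects it, because a single off-diagonal
# entry of a Symmetric matrix cannot be set on its own.
threw = try
    a[1, 2] = 1.0
    false
catch err
    err isa ArgumentError
end
```

Here `a` is a `Symmetric` and `b` is a `Matrix{Float64}`, which is presumably why the three-argument call was used: it trades structure for a buffer that supports arbitrary `setindex!`.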
Maybe we should just add special cases for these? At first glance, it doesn't seem very desirable to remove structure (as the AxisArrays case shows).
IMO the ideal situation is for `axes` to return more information in AxisArrays, à la mcabbott/AxisKeys.jl#6, since the relevant properties do belong to individual axes, not to the whole (like Symmetric). But we ran out of energy to fix things.
What I thought should work is
If the gradient is another AxisArray (as proposed here), then one question is what meaning its axis vectors have. E.g. if I differentiate
Is this something that people do (label axes with floats)?
Certainly some people want AxisArrays to be a sort of DataFrame, with labels just identifying columns. I think the structure is more generally useful for storing anything which varies along an axis, and passing this along, never using it to replace indexing. For instance, keeping a probability vector associated with eachcol(matrix). No idea really which is more common.

The current behaviour with Zygote is this, which I think means you won't silently get wrong answers, but you may get errors if you wish to combine the two functions:

julia> gradient(x -> x[1], AxisArray([1,2,3.0], aux=[4,5,6.0]))  # natural
([1.0, 0.0, 0.0],)
julia> gradient(x -> AxisArrays.axes(x)[1][1], AxisArray([1,2,3.0], aux=[4,5,6.0]))  # structural
((data = nothing, axes = ((val = [1.0, 0.0, 0.0],),)),)

AxisKeys acquired some projection rules which seem to be designed only for the first case:
It would seem to me that the first use case, where the index is not differentiable (strings, symbols) is the more pressing one to tackle, as it is clear what the right behavior should be (i.e. identical to indexing with integers). Could one of the maintainers please make a final decision as to whether this PR is going to be rejected based on the additional dependency? In that case, and if we can't make #780 work, I will open a PR for an extension over at AxisArrays.
ChainRules doesn't define rules for any packages, only the standard library. This isn't a whole If there's an easy way to tweak a rule to play nicely, that's OK, and there are some tests involving packages. But most things should be handled by packages depending on ChainRulesCore, or by package extensions. Making AxisArrays opt out of the rule might work, as it must ultimately index the parent array, and probably that's the right time to call this rule? I'm still a little surprised that
Due to the distinction between `Base.axes` and `AxisArrays.axes`, the existing rules did not work for `AxisArrays`.
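How two functions can share the name `axes` can be sketched with a toy wrapper type. Everything in the module below (`ToyAxes`, `Labeled`, the `labels` field) is hypothetical and only stands in for a package that defines its own `axes`; it is not the AxisArrays implementation.

```julia
module ToyAxes  # hypothetical stand-in for a package with its own `axes`

struct Labeled{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::A
    labels::NTuple{N,Vector{Symbol}}
end

# Minimal AbstractArray interface, forwarded to the wrapped data.
Base.size(x::Labeled) = size(x.data)
Base.getindex(x::Labeled{T,N}, i::Vararg{Int,N}) where {T,N} = x.data[i...]

# `axes` is not imported from Base here, so this defines a new
# function `ToyAxes.axes` rather than adding a method to `Base.axes`.
axes(x::Labeled) = x.labels

end # module

x = ToyAxes.Labeled([1.0, 2.0, 3.0], ([:a, :b, :c],))

base_axes = Base.axes(x)    # generic fallback: index ranges only
toy_axes = ToyAxes.axes(x)  # the package's own notion of axes: labels
```

Code written against `Base.axes`, such as a generic rule that calls `similar(x, T, axes(x))`, only sees the index ranges, so any information carried by the package's own `axes` is lost along the way.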