Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make getindex rule work for AxisArrays #779

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.61.0"
version = "1.61.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO ChainRules should not depend on AxisArrays.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have any insight on the merits for or against this. But what is your suggestion?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAICT it has been a general policy to not accept such dependencies, see e.g. JuliaArrays/FillArrays.jl#153 (comment)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An extension to FillArrays is also out of the question?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR was before extensions existed, so I have been thinking for a while one should try again with an extension. I managed to get in an extension on PDMats recently, so I think it seems likely that it would be approved.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making that FillArrays PR into an extension there would be great.

ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -20,6 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Adapt = "3.4.0, 4"
AxisArrays = "0.4.7"
ChainRulesCore = "1.20"
ChainRulesTestUtils = "1.5"
Compat = "3.46, 4.2"
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ChainRules

using Adapt: adapt
using AxisArrays: AxisArray, AxisArrays
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using ChainRulesCore
using Compat
Expand Down
1 change: 1 addition & 0 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ allow `eltype(dy)`, nor does it work for many structured matrices.
"""
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
Comment on lines 132 to 133
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for why these are not just

Suggested change
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false)

AFAICT this would also fix the AxisArrays problem.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks existing tests. The problem is if you don't pass the axes, then you don't get a dense array.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which tests are broken? The two-arg method is even advised in the Julia docs: https://docs.julialang.org/en/v1/manual/methods/#Building-a-similar-type-with-a-different-type-parameter

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 3-arg one removes structured matrices like Symmetric, iirc

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should just add special cases for these? At first glance, it doesn't seem very desirable to remove structure (as the AxisArrays case shows).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the ideal situation is for axes to return more information in AxisArrays, alla mcabbott/AxisKeys.jl#6 , since the relevant properties do belong to individual axes, not to the whole (like Symmetric). But we ran out of energy to fix things.

_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false)
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors
Expand Down
7 changes: 7 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ end
@test dx23[3] == dxfix[3]
end

@testset "getindex(::AxisArray{<:Number})" begin
X = randn((2, 3))
A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z])
dA, back = rrule(getindex, A, [:a], [:x, :z])
unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0]
end

@testset "second derivatives: ∇getindex" begin
@eval using ChainRules: ∇getindex
# Forward, scalar result
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils
@nospecialize

using Adapt
using AxisArrays
using Base.Broadcast: broadcastable
using ChainRules
using ChainRules: stack
Expand Down
Loading