diff --git a/Project.toml b/Project.toml index fedd2a600..c8e05e63b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "1.61.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -20,6 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" +AxisArrays = "0.4.7" ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..6eb6128c9 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,6 +1,7 @@ module ChainRules using Adapt: adapt +using AxisArrays: AxisArray, AxisArrays using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index d19619c39..681c25203 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -130,7 +130,8 @@ It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't allow `eltype(dy)`, nor does it work for many structured matrices. """ _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) +_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index e878dd061..423a7afe7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -128,6 +128,13 @@ end @test dx23[3] == dxfix[3] end + @testset "getindex(::AxisArray{<:Number})" begin + X = randn((2, 3)) + A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z]) + dA, back = rrule(getindex, A, [:a], [:x, :z]) + unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0] + end + @testset "second derivatives: ∇getindex" begin @eval using ChainRules: ∇getindex # Forward, scalar result diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..81bb4ee22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize using Adapt +using AxisArrays using Base.Broadcast: broadcastable using ChainRules using ChainRules: stack