diff --git a/Project.toml b/Project.toml index fedd2a600..11af953aa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.61.0" +version = "1.61.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -20,6 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" +AxisArrays = "0.4.7" ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..6eb6128c9 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,6 +1,7 @@ module ChainRules using Adapt: adapt +using AxisArrays: AxisArray, AxisArrays using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 830571ecd..681c25203 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -131,6 +131,7 @@ allow `eltype(dy)`, nor does it work for many structured matrices. """ _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) _setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) +_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index e878dd061..423a7afe7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -128,6 +128,13 @@ end @test dx23[3] == dxfix[3] end + @testset "getindex(::AxisArray{<:Number})" begin + X = randn((2, 3)) + A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z]) + dA, back = rrule(getindex, A, [:a], [:x, :z]) + unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0] + end + @testset "second derivatives: ∇getindex" begin @eval using ChainRules: ∇getindex # Forward, scalar result diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..81bb4ee22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize using Adapt +using AxisArrays using Base.Broadcast: broadcastable using ChainRules using ChainRules: stack