From 0cab7a5f28f559fdba3ecc5032189af172ef203f Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 8 Feb 2024 22:19:09 +0100 Subject: [PATCH 1/3] Remove third argument to `similar` --- src/rulesets/Base/indexing.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 830571ecd..d19619c39 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -130,7 +130,7 @@ It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't allow `eltype(dy)`, nor does it work for many structured matrices. """ _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors From 2aa8d26d91d7a50e7706ca4e4b5e34dd37b09c79 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 8 Feb 2024 22:52:05 +0100 Subject: [PATCH 2/3] Add test, fix existing tests --- Project.toml | 2 ++ src/ChainRules.jl | 1 + src/rulesets/Base/indexing.jl | 3 ++- test/rulesets/Base/indexing.jl | 7 +++++++ test/runtests.jl | 1 + 5 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fedd2a600..c8e05e63b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "1.61.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -20,6 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" +AxisArrays = "0.4.7" ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 6d33a22e7..6eb6128c9 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,6 +1,7 @@ module ChainRules using Adapt: adapt +using AxisArrays: AxisArray, AxisArrays using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index d19619c39..681c25203 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -130,7 +130,8 @@ It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't allow `eltype(dy)`, nor does it work for many structured matrices. """ _setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) +_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index e878dd061..423a7afe7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -128,6 +128,13 @@ end @test dx23[3] == dxfix[3] end + @testset "getindex(::AxisArray{<:Number})" begin + X = randn((2, 3)) + A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z]) + dA, back = rrule(getindex, A, [:a], [:x, :z]) + unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0] + end + @testset "second derivatives: ∇getindex" begin @eval using ChainRules: ∇getindex # Forward, scalar result diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..81bb4ee22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize using Adapt +using AxisArrays using Base.Broadcast: broadcastable using ChainRules using ChainRules: stack From 1d779d27e409658da883d6731be55809a85e0a31 Mon Sep 17 00:00:00 2001 From: Simone Carlo Surace Date: Thu, 8 Feb 2024 22:55:02 +0100 Subject: [PATCH 3/3] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c8e05e63b..11af953aa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.61.0" +version = "1.61.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"