Skip to content

Commit

Permalink
Add test, fix existing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
simsurace committed Feb 8, 2024
1 parent 0cab7a5 commit 2aa8d26
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ version = "1.61.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -20,6 +21,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Adapt = "3.4.0, 4"
AxisArrays = "0.4.7"
ChainRulesCore = "1.20"
ChainRulesTestUtils = "1.5"
Compat = "3.46, 4.2"
Expand Down
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ChainRules

using Adapt: adapt
using AxisArrays: AxisArray, AxisArrays
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using ChainRulesCore
using Compat
Expand Down
3 changes: 2 additions & 1 deletion src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
allow `eltype(dy)`, nor does it work for many structured matrices.
"""
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false)
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors
Expand Down
7 changes: 7 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ end
@test dx23[3] == dxfix[3]
end

@testset "getindex(::AxisArray{<:Number})" begin
X = randn((2, 3))
A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z])
dA, back = rrule(getindex, A, [:a], [:x, :z])
unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0]
end

@testset "second derivatives: ∇getindex" begin
@eval using ChainRules: ∇getindex
# Forward, scalar result
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils
@nospecialize

using Adapt
using AxisArrays
using Base.Broadcast: broadcastable
using ChainRules
using ChainRules: stack
Expand Down

0 comments on commit 2aa8d26

Please sign in to comment.