Skip to content

Commit

Permalink
Use two-arg similar in _setindex_zero
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Feb 8, 2024
1 parent 1d779d2 commit 348fb86
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ version = "1.61.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand Down Expand Up @@ -43,6 +42,7 @@ SuiteSparse = "1"
julia = "1.6"

[extras]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -52,4 +52,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
test = ["AxisArrays", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
1 change: 0 additions & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module ChainRules

using Adapt: adapt
using AxisArrays: AxisArray, AxisArrays
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using ChainRulesCore
using Compat
Expand Down
12 changes: 4 additions & 8 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,18 @@ Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2)
This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`,
and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what
`∇getindex` does next.
It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
allow `eltype(dy)`, nor does it work for many structured matrices.
"""
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
_setindex_zero(x::AxisArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), AxisArrays.axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false)
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors
T = Union{typeof(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
return fill!(similar(x, T), ZeroTangent())
end
function _setindex_zero(x::AbstractArray, dy, inds...)
T = Union{eltype(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
return fill!(similar(x, T), ZeroTangent())
end
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)

Expand Down

0 comments on commit 348fb86

Please sign in to comment.