From 0acd84746f909ba7cbecc779d905c7e5b3886a47 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 20:58:55 -0400 Subject: [PATCH 1/3] Fix Tracker with restructure --- ext/ArrayInterfaceTrackerExt.jl | 7 +++++++ test/ad.jl | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/ext/ArrayInterfaceTrackerExt.jl b/ext/ArrayInterfaceTrackerExt.jl index d2d4e2ce..6d26f410 100644 --- a/ext/ArrayInterfaceTrackerExt.jl +++ b/ext/ArrayInterfaceTrackerExt.jl @@ -9,4 +9,11 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x) +function ArrayInterface.restructure(x::Array, y::TrackedArray) + reshape(y, Base.size(x)...) +end +function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal}) + reshape(y, Base.size(x)...) +end + end # module diff --git a/test/ad.jl b/test/ad.jl index 7c29c8dd..c1a207b7 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -18,3 +18,12 @@ x = Tracker.TrackedArray([4.0,4.0]) x = reduce(vcat, Tracker.TrackedArray([4.0,4.0])) x = [x[1],x[2]] @test ArrayInterface.aos_to_soa(x) isa Tracker.TrackedArray + +x = rand(4) +y = Tracker.TrackedReal.(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Array +@test eltype(ArrayInterface.restructure(x, y)) <: Tracker.TrackedReal +@test size(ArrayInterface.restructure(x, y)) == (4,) +y = Tracker.TrackedArray(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray +@test size(ArrayInterface.restructure(x, y)) == (4,) \ No newline at end of file From c3acb74ea96540abec7644a7d8bd9c26d0dae644 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 21:00:10 -0400 Subject: [PATCH 2/3] Namespace --- ext/ArrayInterfaceTrackerExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ArrayInterfaceTrackerExt.jl b/ext/ArrayInterfaceTrackerExt.jl index 6d26f410..5723d9f1 100644 --- a/ext/ArrayInterfaceTrackerExt.jl +++ b/ext/ArrayInterfaceTrackerExt.jl @@ -9,7 +9,7 @@ ArrayInterface.can_setindex(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:Tracker.TrackedArray}) = false ArrayInterface.aos_to_soa(x::AbstractArray{<:Tracker.TrackedReal,N}) where {N} = Tracker.collect(x) -function ArrayInterface.restructure(x::Array, y::TrackedArray) +function ArrayInterface.restructure(x::Array, y::Tracker.TrackedArray) reshape(y, Base.size(x)...) end function ArrayInterface.restructure(x::Array, y::Array{<:Tracker.TrackedReal}) From 00cc8c6e747c0ec2bf7fa53147e4fff241bee0ee Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sat, 31 Aug 2024 21:28:13 -0400 Subject: [PATCH 3/3] Add ReverseDiff --- ext/ArrayInterfaceReverseDiffExt.jl | 4 ++++ test/ad.jl | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/ext/ArrayInterfaceReverseDiffExt.jl b/ext/ArrayInterfaceReverseDiffExt.jl index ed544e48..37176824 100644 --- a/ext/ArrayInterfaceReverseDiffExt.jl +++ b/ext/ArrayInterfaceReverseDiffExt.jl @@ -15,4 +15,8 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N end end +function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray) + reshape(y, Base.size(x)...) +end + end # module diff --git a/test/ad.jl b/test/ad.jl index c1a207b7..bc3c3dd8 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -26,4 +26,13 @@ y = Tracker.TrackedReal.(rand(2,2)) @test size(ArrayInterface.restructure(x, y)) == (4,) y = Tracker.TrackedArray(rand(2,2)) @test ArrayInterface.restructure(x, y) isa Tracker.TrackedArray -@test size(ArrayInterface.restructure(x, y)) == (4,) \ No newline at end of file +@test size(ArrayInterface.restructure(x, y)) == (4,) + +x = rand(4) +y = ReverseDiff.track(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa ReverseDiff.TrackedArray +@test size(ArrayInterface.restructure(x, y)) == (4,) +y = ReverseDiff.track.(rand(2,2)) +@test ArrayInterface.restructure(x, y) isa Array +@test eltype(ArrayInterface.restructure(x, y)) <: ReverseDiff.TrackedReal +@test size(ArrayInterface.restructure(x, y)) == (4,)