From d6f380a882c736ed7f82936c37e657fd36bffdc6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 17 Nov 2024 20:39:14 -0500 Subject: [PATCH 1/2] fix: aos_to_soa for all singleton dims --- ext/ArrayInterfaceReverseDiffExt.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ext/ArrayInterfaceReverseDiffExt.jl b/ext/ArrayInterfaceReverseDiffExt.jl index 37176824..3a000def 100644 --- a/ext/ArrayInterfaceReverseDiffExt.jl +++ b/ext/ArrayInterfaceReverseDiffExt.jl @@ -8,11 +8,8 @@ ArrayInterface.ismutable(T::Type{<:ReverseDiff.TrackedReal}) = false ArrayInterface.can_setindex(::Type{<:ReverseDiff.TrackedArray}) = false ArrayInterface.fast_scalar_indexing(::Type{<:ReverseDiff.TrackedArray}) = false function ArrayInterface.aos_to_soa(x::AbstractArray{<:ReverseDiff.TrackedReal, N}) where {N} - if length(x) > 1 - return reshape(reduce(vcat, x), size(x)) - else - return reduce(vcat,[x[1], x[1]])[1:1] - end + y = length(x) > 1 ? reduce(vcat, x) : reduce(vcat, [x[1], x[1]])[1:1] + return reshape(y, size(x)) end function ArrayInterface.restructure(x::Array, y::ReverseDiff.TrackedArray) From 1d931145ffb5765b82cb3021edb4b408866db8c2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 17 Nov 2024 20:41:59 -0500 Subject: [PATCH 2/2] test: singleton dims AoS --- test/ad.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/ad.jl b/test/ad.jl index bc3c3dd8..3c61873e 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,6 +1,9 @@ using ArrayInterface, ReverseDiff, Tracker, Test x = ReverseDiff.track([4.0]) @test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +x = reshape([ReverseDiff.track(rand(1, 1, 1))[1]], 1, 1, 1) +@test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray +@test ndims(ArrayInterface.aos_to_soa(x)) == 3 x = reduce(vcat, ReverseDiff.track([4.0,4.0])) @test ArrayInterface.aos_to_soa(x) isa ReverseDiff.TrackedArray x = [ReverseDiff.track([4.0])[1]]