Skip to content

Commit

Permalink
fix: fix several adjoints, copy and zero methods for VoA
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 17, 2024
1 parent 030923c commit f1e9526
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 16 deletions.
26 changes: 21 additions & 5 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ end
Colon, BitArray, AbstractArray{Bool}}...)
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = VectorOfArray([zero(x) for (x, j) in zip(VA.u, 1:length(VA))])
Δ′[i, j...] = Δ
if isempty(j)
Δ′.u[i] = Δ

Check warning on line 54 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L54

Added line #L54 was not covered by tests
else
Δ′[i, j...] = Δ
end
(Δ′, nothing, map(_ -> nothing, j)...)
end
VA[i, j...], AbstractVectorOfArray_getindex_adjoint
Expand Down Expand Up @@ -104,13 +108,25 @@ end
end

@adjoint function Base.Array(VA::AbstractVectorOfArray)
Array(VA),
y -> (Array(y),)
adj = let VA=VA
function Array_adjoint(y)
VA = copy(VA)
VA .= y
return (VA,)
end
end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
view(A, I...),
y -> (view(y, I...), ntuple(_ -> nothing, length(I))...)
adj = let A = A, I = I
function view_adjoint(y)
A = zero(A)
view(A, I...) .= y
return (A, map(_ -> nothing, I)...)
end
end
view(A, I...), adj
end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
Expand Down
38 changes: 27 additions & 11 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ function DiffEqArray(vec::AbstractVector{T},
p,
sys)
end
function DiffEqArray(vec::AbstractVector{VT},

Check warning on line 163 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L163

Added line #L163 was not covered by tests
ts::AbstractVector,
::NTuple{N, Int},
p = nothing,
sys = nothing) where {T, N, VT <: AbstractArray{T, N}}
DiffEqArray{eltype(T), N, typeof(vec), typeof(ts), typeof(p), typeof(sys)}(vec,

Check warning on line 168 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L168

Added line #L168 was not covered by tests
ts,
p,
sys)
end
# Assume that the first element is representative of all other elements

function DiffEqArray(vec::AbstractVector,
Expand Down Expand Up @@ -466,19 +476,25 @@ end
tuples(VA::DiffEqArray) = tuple.(VA.t, VA.u)

# Growing the array simply adds to the container vector
function Base.copy(VA::AbstractDiffEqArray)
typeof(VA)(copy(VA.u),
copy(VA.t),
(VA.p === nothing) ? nothing : copy(VA.p),
(VA.sys === nothing) ? nothing : copy(VA.sys))
function _copyfield(VA, fname)
if fname == :u
copy(VA.u)
elseif fname == :t
copy(VA.t)
else
getfield(VA, fname)
end
end
function Base.copy(VA::AbstractVectorOfArray)
typeof(VA)((_copyfield(VA, fname) for fname in fieldnames(typeof(VA)))...)
end
Base.copy(VA::AbstractVectorOfArray) = typeof(VA)(copy(VA.u))

Base.zero(VA::AbstractVectorOfArray) = VectorOfArray(Base.zero.(VA.u))

function Base.zero(VA::AbstractDiffEqArray)
u = Base.zero.(VA.u)
DiffEqArray(u, VA.t, parameter_values(VA), symbolic_container(VA))
function Base.zero(VA::AbstractVectorOfArray)
val = copy(VA)
for i in eachindex(VA.u)
val.u[i] = zero(VA[i])
end
return val
end

Base.sizehint!(VA::AbstractVectorOfArray{T, N}, i) where {T, N} = sizehint!(VA.u, i)
Expand Down

0 comments on commit f1e9526

Please sign in to comment.