Skip to content

Commit

Permalink
Update SciMLBaseZygoteExt.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas authored Oct 28, 2023
1 parent a44c420 commit 9cb8c14
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ module SciMLBaseZygoteExt

using Zygote: pullback
using ZygoteRules: @adjoint
using SciMLBase: ODESolution, issymbollike, sym_to_index, remake, getobserved
import ZygoteRules
using SciMLBase: EnsembleSolution, ODESolution, issymbollike, sym_to_index, remake, getobserved

# This method resolves the ambiguity with the pullback defined in
# RecursiveArrayToolsZygoteExt
Expand Down Expand Up @@ -55,4 +56,25 @@ end
VA[sym, j], ODESolution_getindex_pullback
end

ZygoteRules.@adjoint function DiffEqBase.EnsembleSolution(sim, time, converged)
out = EnsembleSolution(sim, time, converged)
function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
arrarr = [[p̄[ntuple(x -> Colon(), Val(N - 2))..., j, i]
for j in 1:size(p̄)[end - 1]] for i in 1:size(p̄)[end]]
(EnsembleSolution(arrarr, 0.0, true), nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
(EnsembleSolution(p̄, 0.0, true), nothing, nothing)
end
function EnsembleSolution_adjoint(p̄::EnsembleSolution)
(p̄, nothing, nothing)
end
out, EnsembleSolution_adjoint
end

ZygoteRules.@adjoint function ZygoteRules.literal_getproperty(sim::EnsembleSolution,
::Val{:u})
sim.u, p̄ -> (EnsembleSolution(p̄, 0.0, true),)
end

end

0 comments on commit 9cb8c14

Please sign in to comment.