Skip to content

Commit

Permalink
feat: add lazy initialization to remake
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 3, 2024
1 parent a4fd6d8 commit 651899b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,15 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,

return u0, p, success
end

function is_trivial_initialization(initdata::OverrideInitData)
state_values(initdata.initializeprob) === nothing
end

function is_trivial_initialization(f::AbstractSciMLFunction)
has_initialization_data(f) && is_trivial_initialization(f.initialization_data)
end

function is_trivial_initialization(prob::AbstractSciMLProblem)
is_trivial_initialization(prob.f)
end
17 changes: 16 additions & 1 deletion src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ function remake(prob::ODEProblem; f = missing,
interpret_symbolicmap = true,
build_initializeprob = true,
use_defaults = false,
lazy_initialization = nothing,
_kwargs...)
if tspan === missing
tspan = prob.tspan
Expand All @@ -123,6 +124,8 @@ function remake(prob::ODEProblem; f = missing,

iip = isinplace(prob)

initialization_data = prob.f.initialization_data

if f === missing
if build_initializeprob
initialization_data = remake_initialization_data_compat_wrapper(
Expand Down Expand Up @@ -170,13 +173,25 @@ function remake(prob::ODEProblem; f = missing,
_f = ODEFunction{isinplace(prob), specialization(prob.f)}(f)
end

if kwargs === missing
prob = if kwargs === missing
ODEProblem{isinplace(prob)}(
_f, newu0, tspan, newp, prob.problem_type; prob.kwargs...,
_kwargs...)
else
ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...)
end

if lazy_initialization === nothing
lazy_initialization = !is_trivial_initialization(initialization_data)
end
if !lazy_initialization
u0, p, _ = get_initial_values(
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
@reset prob.u0 = u0
@reset prob.p = p
end

return prob
end

"""
Expand Down
9 changes: 9 additions & 0 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,12 @@ end
@test sccprob4.parameter_object !== sccprob4.probs[1].p
@test sccprob4.parameter_object !== sccprob4.probs[2].p
end

@testset "Lazy initialization" begin
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
@parameters p=missing [guess = 1.0]
@mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t)
prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
prob2 = remake(prob; u0 = [x => 2.0])
@test prob2.ps[p] 3.0
end

0 comments on commit 651899b

Please sign in to comment.