From f09aee529df84108aae5e4c77e99a53440574b61 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Nov 2024 16:40:36 +0530 Subject: [PATCH] feat: add lazy initialization to `remake` --- src/initialization.jl | 12 ++++++++++++ src/remake.jl | 11 ++++++++++- test/downstream/modelingtoolkit_remake.jl | 9 +++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/initialization.jl b/src/initialization.jl index 9f5274eae..eea553005 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -225,3 +225,15 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, return u0, p, success end + +function is_trivial_initialization(initdata::OverrideInitData) + state_values(initdata.initializeprob) === nothing +end + +function is_trivial_initialization(f::AbstractSciMLFunction) + has_initialization_data(f) && is_trivial_initialization(f.initialization_data) +end + +function is_trivial_initialization(prob::AbstractSciMLProblem) + is_trivial_initialization(prob.f) +end diff --git a/src/remake.jl b/src/remake.jl index 280957671..d008c2e3f 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -114,6 +114,7 @@ function remake(prob::ODEProblem; f = missing, interpret_symbolicmap = true, build_initializeprob = true, use_defaults = false, + lazy_initialization = !is_trivial_initialization(prob), _kwargs...) if tspan === missing tspan = prob.tspan @@ -170,13 +171,21 @@ function remake(prob::ODEProblem; f = missing, _f = ODEFunction{isinplace(prob), specialization(prob.f)}(f) end - if kwargs === missing + prob = if kwargs === missing ODEProblem{isinplace(prob)}( _f, newu0, tspan, newp, prob.problem_type; prob.kwargs..., _kwargs...) else ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...) end + + if !lazy_initialization + u0, p, _ = get_initial_values(prob, prob, prob.f, OverrideInit(), Val(isinplace(prob))) + @reset prob.u0 = u0 + @reset prob.p = p + end + + return prob end """ diff --git a/test/downstream/modelingtoolkit_remake.jl b/test/downstream/modelingtoolkit_remake.jl index 228a26dd6..6029a06ab 100644 --- a/test/downstream/modelingtoolkit_remake.jl +++ b/test/downstream/modelingtoolkit_remake.jl @@ -274,3 +274,12 @@ end @test_throws SciMLBase.CyclicDependencyError remake( prob; u0 = [x => 2y + p, y => q + 3], p = [p => x + y, q => p + 3]) end + +@testset "Lazy initialization" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p = missing [guess = 1.0] + @mtkbuild sys = ODESystem([D(x) ~ x, x + y ~ p], t) + prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + prob2 = remake(prob; u0 = [x => 2.0]) + @test prob2.ps[p] ≈ 3.0 +end