diff --git a/Project.toml b/Project.toml index e926c8d1f..c809c13c2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SciMLBase" uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462" authors = ["Chris Rackauckas and contributors"] -version = "2.63.1" +version = "2.64.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/problems/nonlinear_problems.jl b/src/problems/nonlinear_problems.jl index 58241f289..43c6e5f9c 100644 --- a/src/problems/nonlinear_problems.jl +++ b/src/problems/nonlinear_problems.jl @@ -462,7 +462,64 @@ Note that this example aliases the parameters together for a memory-reduced repr * `probs`: the collection of problems to solve * `explictfuns!`: the explicit functions for mutating the parameter set """ -mutable struct SCCNonlinearProblem{P, E} +mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <: + AbstractNonlinearProblem{uType, iip} probs::P - explictfuns!::E + explicitfuns!::E + full_index_provider::I + parameter_object::Par + parameters_alias::Bool + + function SCCNonlinearProblem{P, E, I, Par}( + probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par} + u0 = mapreduce(state_values, vcat, probs) + uType = typeof(u0) + new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias) + end +end + +function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing, + parameter_object = nothing, parameters_alias = false) + return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!), + typeof(full_index_provider), typeof(parameter_object)}( + probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias) +end + +function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol) + if name == :explictfuns! + return getfield(prob, :explicitfuns!) + elseif name == :ps + return ParameterIndexingProxy(prob) + end + return getfield(prob, name) +end + +function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem) + prob.full_index_provider +end +function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem) + prob.parameter_object +end +function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem) + mapreduce(state_values, vcat, prob.probs) +end + +function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx) + for scc in prob.probs + svals = state_values(scc) + checkbounds(Bool, svals, idx) && return set_state!(scc, val, idx) + idx -= length(svals) + end + throw(BoundsError(state_values(prob), idx)) +end + +function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx) + if prob.parameter_object !== nothing + set_parameter!(prob.parameter_object, val, idx) + prob.parameters_alias && return + end + for scc in prob.probs + is_parameter(scc, idx) || continue + set_parameter!(scc, val, idx) + end end diff --git a/src/problems/problem_utils.jl b/src/problems/problem_utils.jl index 15462e8b1..3c10f5fdf 100644 --- a/src/problems/problem_utils.jl +++ b/src/problems/problem_utils.jl @@ -60,7 +60,7 @@ function Base.show(io::IO, mime::MIME"text/plain", A::AbstractNonlinearProblem) summary(io, A) println(io) print(io, "u0: ") - show(io, mime, A.u0) + show(io, mime, state_values(A)) end function Base.show(io::IO, mime::MIME"text/plain", A::IntervalNonlinearProblem) diff --git a/src/remake.jl b/src/remake.jl index 750413a1f..280957671 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -125,8 +125,8 @@ function remake(prob::ODEProblem; f = missing, if f === missing if build_initializeprob - initialization_data = remake_initialization_data( - prob.f.sys, prob.f, u0, tspan[1], p) + initialization_data = remake_initialization_data_compat_wrapper( + prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp) else initialization_data = nothing end @@ -203,16 +203,32 @@ function remake_initializeprob(sys, scimlfn, u0, t0, p) end """ - remake_initialization_data(sys, scimlfn, u0, t0, p) + $(TYPEDSIGNATURES) + +Wrapper around `remake_initialization_data` for backward compatibility when `newu0` and +`newp` were not arguments. +""" +function remake_initialization_data_compat_wrapper(sys, scimlfn, u0, t0, p, newu0, newp) + if hasmethod(remake_initialization_data, + Tuple{typeof(sys), typeof(scimlfn), typeof(u0), typeof(t0), typeof(p)}) + remake_initialization_data(sys, scimlfn, u0, t0, p) + else + remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp) + end +end + +""" + remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp) Re-create the initialization data present in the function `scimlfn`, using the -associated system `sys` and the user provided new values of `u0`, initial time `t0` and -`p`. By default, this calls `remake_initializeprob` for backward compatibility and -attempts to construct an `OverrideInitData` from the result. +associated system `sys`, the user provided new values of `u0`, initial time `t0`, +user-provided `p`, new u0 vector `newu0` and new parameter object `newp`. By default, +this calls `remake_initializeprob` for backward compatibility and attempts to construct +an `OverrideInitData` from the result. Note that `u0` or `p` may be `missing` if the user does not provide a value for them. """ -function remake_initialization_data(sys, scimlfn, u0, t0, p) +function remake_initialization_data(sys, scimlfn, u0, t0, p, newu0, newp) return reconstruct_initialization_data( nothing, remake_initializeprob(sys, scimlfn, u0, t0, p)...) end diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index ebf0dda85..b036b0b12 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -38,7 +38,9 @@ ModelingToolkitStandardLibrary = "2.7" NonlinearSolve = "2, 3, 4" Optimization = "4" OptimizationOptimJL = "0.4" +OptimizationMOI = "0.5" OrdinaryDiffEq = "6.33" +PartialFunctions = "1" Plots = "1.40" RecursiveArrayTools = "3" SciMLBase = "2" diff --git a/test/downstream/adjoints.jl b/test/downstream/adjoints.jl index 327172ef7..4e75e19b6 100644 --- a/test/downstream/adjoints.jl +++ b/test/downstream/adjoints.jl @@ -68,7 +68,7 @@ gs_ts, = Zygote.gradient(sol) do sol sum(sum.(sol[[lorenz1.x, lorenz2.x], :])) end -@test_broken all(map(x -> x == true_grad_vecsym, gs_ts)) +@test all(map(x -> x == true_grad_vecsym, gs_ts)) # BatchedInterface AD @variables x(t)=1.0 y(t)=1.0 z(t)=1.0 w(t)=1.0 diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index c5dee2938..7e68b127b 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -292,3 +292,100 @@ prob = SteadyStateProblem(osys, u0, ps) getsym(prob, [:X, :X2])(prob) == [0.1, 0.2] @test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) == getsym(prob, (:X, :X2))(prob) == (0.1, 0.2) + +@testset "SCCNonlinearProblem" begin + # TODO: Rewrite this example when the MTK codegen is merged + + function fullf!(du, u, p) + du[1] = cos(u[2]) - u[1] + du[2] = sin(u[1] + u[2]) + u[2] + du[3] = 2u[4] + u[3] + p[1] + du[4] = u[5]^2 + u[4] + du[5] = u[3]^2 + u[5] + du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8] + du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8] + du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8] + end + @variables u[1:8]=zeros(8) [irreducible = true] + u2 = collect(u) + @parameters p = 1.0 + eqs = Any[0 for _ in 1:8] + fullf!(eqs, u, [p]) + @named model = NonlinearSystem(0 .~ eqs, [u...], [p]) + model = complete(model; split = false) + + cache = zeros(4) + cache[1] = 1.0 + + function f1!(du, u, p) + du[1] = cos(u[2]) - u[1] + du[2] = sin(u[1] + u[2]) + u[2] + end + explicitfun1(cache, sols) = nothing + + f1!(eqs, u2[1:2], [p]) + @named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p]) + subsys1 = complete(subsys1; split = false) + prob1 = NonlinearProblem( + NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1), + zeros(2), copy(cache)) + + function f2!(du, u, p) + du[1] = 2u[2] + u[1] + p[1] + du[2] = u[3]^2 + u[2] + du[3] = u[1]^2 + u[3] + end + explicitfun2(cache, sols) = nothing + + f2!(eqs, u2[3:5], [p]) + @named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p]) + subsys2 = complete(subsys2; split = false) + prob2 = NonlinearProblem( + NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2), + zeros(3), copy(cache)) + + function f3!(du, u, p) + du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3] + du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3] + du[3] = p[4] + +u[1] - u[2] - u[3] + end + function explicitfun3(cache, sols) + cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3] + cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3] + cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] + + 6.0sols[2][3] + end + + @parameters tmpvar[1:3] + f3!(eqs, u2[6:8], [p, tmpvar...]) + @named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...]) + subsys3 = complete(subsys3; split = false) + prob3 = NonlinearProblem( + NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3), + zeros(3), copy(cache)) + + prob = NonlinearProblem(model, []) + sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3], + SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]), + model, copy(cache)) + + for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]] + @test prob[sym] ≈ sccprob[sym] + end + + for sym in [p, 2p + 1] + @test prob.ps[sym] ≈ sccprob.ps[sym] + end + + for (i, sym) in enumerate([u[1], u[3], u[6]]) + sccprob[sym] = 0.5i + @test sccprob[sym] ≈ 0.5i + @test sccprob.probs[i].u0[1] ≈ 0.5i + end + sccprob.ps[p] = 2.5 + @test sccprob.ps[p] ≈ 2.5 + @test sccprob.parameter_object[1] ≈ 2.5 + for scc in sccprob.probs + @test parameter_values(scc)[1] ≈ 2.5 + end +end