diff --git a/src/problems/nonlinear_problems.jl b/src/problems/nonlinear_problems.jl index 58241f289..43c6e5f9c 100644 --- a/src/problems/nonlinear_problems.jl +++ b/src/problems/nonlinear_problems.jl @@ -462,7 +462,64 @@ Note that this example aliases the parameters together for a memory-reduced repr * `probs`: the collection of problems to solve * `explictfuns!`: the explicit functions for mutating the parameter set """ -mutable struct SCCNonlinearProblem{P, E} +mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <: + AbstractNonlinearProblem{uType, iip} probs::P - explictfuns!::E + explicitfuns!::E + full_index_provider::I + parameter_object::Par + parameters_alias::Bool + + function SCCNonlinearProblem{P, E, I, Par}( + probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par} + u0 = mapreduce(state_values, vcat, probs) + uType = typeof(u0) + new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias) + end +end + +function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing, + parameter_object = nothing, parameters_alias = false) + return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!), + typeof(full_index_provider), typeof(parameter_object)}( + probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias) +end + +function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol) + if name == :explictfuns! + return getfield(prob, :explicitfuns!) + elseif name == :ps + return ParameterIndexingProxy(prob) + end + return getfield(prob, name) +end + +function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem) + prob.full_index_provider +end +function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem) + prob.parameter_object +end +function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem) + mapreduce(state_values, vcat, prob.probs) +end + +function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx) + for scc in prob.probs + svals = state_values(scc) + checkbounds(Bool, svals, idx) && return set_state!(scc, val, idx) + idx -= length(svals) + end + throw(BoundsError(state_values(prob), idx)) +end + +function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx) + if prob.parameter_object !== nothing + set_parameter!(prob.parameter_object, val, idx) + prob.parameters_alias && return + end + for scc in prob.probs + is_parameter(scc, idx) || continue + set_parameter!(scc, val, idx) + end end diff --git a/src/problems/problem_utils.jl b/src/problems/problem_utils.jl index 15462e8b1..3c10f5fdf 100644 --- a/src/problems/problem_utils.jl +++ b/src/problems/problem_utils.jl @@ -60,7 +60,7 @@ function Base.show(io::IO, mime::MIME"text/plain", A::AbstractNonlinearProblem) summary(io, A) println(io) print(io, "u0: ") - show(io, mime, A.u0) + show(io, mime, state_values(A)) end function Base.show(io::IO, mime::MIME"text/plain", A::IntervalNonlinearProblem) diff --git a/test/downstream/problem_interface.jl b/test/downstream/problem_interface.jl index c5dee2938..7e68b127b 100644 --- a/test/downstream/problem_interface.jl +++ b/test/downstream/problem_interface.jl @@ -292,3 +292,100 @@ prob = SteadyStateProblem(osys, u0, ps) getsym(prob, [:X, :X2])(prob) == [0.1, 0.2] @test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) == getsym(prob, (:X, :X2))(prob) == (0.1, 0.2) + +@testset "SCCNonlinearProblem" begin + # TODO: Rewrite this example when the MTK codegen is merged + + function fullf!(du, u, p) + du[1] = cos(u[2]) - u[1] + du[2] = sin(u[1] + u[2]) + u[2] + du[3] = 2u[4] + u[3] + p[1] + du[4] = u[5]^2 + u[4] + du[5] = u[3]^2 + u[5] + du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8] + du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8] + du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8] + end + @variables u[1:8]=zeros(8) [irreducible = true] + u2 = collect(u) + @parameters p = 1.0 + eqs = Any[0 for _ in 1:8] + fullf!(eqs, u, [p]) + @named model = NonlinearSystem(0 .~ eqs, [u...], [p]) + model = complete(model; split = false) + + cache = zeros(4) + cache[1] = 1.0 + + function f1!(du, u, p) + du[1] = cos(u[2]) - u[1] + du[2] = sin(u[1] + u[2]) + u[2] + end + explicitfun1(cache, sols) = nothing + + f1!(eqs, u2[1:2], [p]) + @named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p]) + subsys1 = complete(subsys1; split = false) + prob1 = NonlinearProblem( + NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1), + zeros(2), copy(cache)) + + function f2!(du, u, p) + du[1] = 2u[2] + u[1] + p[1] + du[2] = u[3]^2 + u[2] + du[3] = u[1]^2 + u[3] + end + explicitfun2(cache, sols) = nothing + + f2!(eqs, u2[3:5], [p]) + @named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p]) + subsys2 = complete(subsys2; split = false) + prob2 = NonlinearProblem( + NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2), + zeros(3), copy(cache)) + + function f3!(du, u, p) + du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3] + du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3] + du[3] = p[4] + +u[1] - u[2] - u[3] + end + function explicitfun3(cache, sols) + cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3] + cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3] + cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] + + 6.0sols[2][3] + end + + @parameters tmpvar[1:3] + f3!(eqs, u2[6:8], [p, tmpvar...]) + @named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...]) + subsys3 = complete(subsys3; split = false) + prob3 = NonlinearProblem( + NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3), + zeros(3), copy(cache)) + + prob = NonlinearProblem(model, []) + sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3], + SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]), + model, copy(cache)) + + for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]] + @test prob[sym] ≈ sccprob[sym] + end + + for sym in [p, 2p + 1] + @test prob.ps[sym] ≈ sccprob.ps[sym] + end + + for (i, sym) in enumerate([u[1], u[3], u[6]]) + sccprob[sym] = 0.5i + @test sccprob[sym] ≈ 0.5i + @test sccprob.probs[i].u0[1] ≈ 0.5i + end + sccprob.ps[p] = 2.5 + @test sccprob.ps[p] ≈ 2.5 + @test sccprob.parameter_object[1] ≈ 2.5 + for scc in sccprob.probs + @test parameter_values(scc)[1] ≈ 2.5 + end +end