Skip to content

Commit

Permalink
Merge pull request #877 from AayushSabharwal/as/scc-nlprob-sii
Browse files Browse the repository at this point in the history
feat: implement SII for `SCCNonlinearProblem`
  • Loading branch information
ChrisRackauckas authored Nov 26, 2024
2 parents e3846c8 + 069ffd2 commit bbc413b
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 3 deletions.
61 changes: 59 additions & 2 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,64 @@ Note that this example aliases the parameters together for a memory-reduced repr
* `probs`: the collection of problems to solve
* `explictfuns!`: the explicit functions for mutating the parameter set
"""
mutable struct SCCNonlinearProblem{P, E}
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
AbstractNonlinearProblem{uType, iip}
probs::P
explictfuns!::E
explicitfuns!::E
full_index_provider::I
parameter_object::Par
parameters_alias::Bool

function SCCNonlinearProblem{P, E, I, Par}(
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
u0 = mapreduce(state_values, vcat, probs)
uType = typeof(u0)
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
end
end

function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
parameter_object = nothing, parameters_alias = false)
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
typeof(full_index_provider), typeof(parameter_object)}(
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
end

function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
if name == :explictfuns!
return getfield(prob, :explicitfuns!)
elseif name == :ps
return ParameterIndexingProxy(prob)
end
return getfield(prob, name)
end

function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
prob.full_index_provider
end
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
prob.parameter_object
end
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
mapreduce(state_values, vcat, prob.probs)
end

function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, idx)
for scc in prob.probs
svals = state_values(scc)
checkbounds(Bool, svals, idx) && return set_state!(scc, val, idx)
idx -= length(svals)
end
throw(BoundsError(state_values(prob), idx))
end

function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
if prob.parameter_object !== nothing
set_parameter!(prob.parameter_object, val, idx)
prob.parameters_alias && return
end
for scc in prob.probs
is_parameter(scc, idx) || continue
set_parameter!(scc, val, idx)
end
end
2 changes: 1 addition & 1 deletion src/problems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function Base.show(io::IO, mime::MIME"text/plain", A::AbstractNonlinearProblem)
summary(io, A)
println(io)
print(io, "u0: ")
show(io, mime, A.u0)
show(io, mime, state_values(A))
end

function Base.show(io::IO, mime::MIME"text/plain", A::IntervalNonlinearProblem)
Expand Down
97 changes: 97 additions & 0 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,100 @@ prob = SteadyStateProblem(osys, u0, ps)
getsym(prob, [:X, :X2])(prob) == [0.1, 0.2]
@test getsym(prob, (X, X2))(prob) == getsym(prob, (osys.X, osys.X2))(prob) ==
getsym(prob, (:X, :X2))(prob) == (0.1, 0.2)

@testset "SCCNonlinearProblem" begin
# TODO: Rewrite this example when the MTK codegen is merged

function fullf!(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
du[3] = 2u[4] + u[3] + p[1]
du[4] = u[5]^2 + u[4]
du[5] = u[3]^2 + u[5]
du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8]
du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8]
du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8]
end
@variables u[1:8]=zeros(8) [irreducible = true]
u2 = collect(u)
@parameters p = 1.0
eqs = Any[0 for _ in 1:8]
fullf!(eqs, u, [p])
@named model = NonlinearSystem(0 .~ eqs, [u...], [p])
model = complete(model; split = false)

cache = zeros(4)
cache[1] = 1.0

function f1!(du, u, p)
du[1] = cos(u[2]) - u[1]
du[2] = sin(u[1] + u[2]) + u[2]
end
explicitfun1(cache, sols) = nothing

f1!(eqs, u2[1:2], [p])
@named subsys1 = NonlinearSystem(0 .~ eqs[1:2], [u2[1:2]...], [p])
subsys1 = complete(subsys1; split = false)
prob1 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f1!; sys = subsys1),
zeros(2), copy(cache))

function f2!(du, u, p)
du[1] = 2u[2] + u[1] + p[1]
du[2] = u[3]^2 + u[2]
du[3] = u[1]^2 + u[3]
end
explicitfun2(cache, sols) = nothing

f2!(eqs, u2[3:5], [p])
@named subsys2 = NonlinearSystem(0 .~ eqs[1:3], [u2[3:5]...], [p])
subsys2 = complete(subsys2; split = false)
prob2 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f2!; sys = subsys2),
zeros(3), copy(cache))

function f3!(du, u, p)
du[1] = p[2] + 2.0u[1] + 2.5u[2] + 1.5u[3]
du[2] = p[3] + 4.0u[1] - 1.5u[2] + 1.5u[3]
du[3] = p[4] + +u[1] - u[2] - u[3]
end
function explicitfun3(cache, sols)
cache[2] = sols[1][1] + sols[1][2] + sols[2][1] + sols[2][2] + sols[2][3]
cache[3] = sols[1][1] + sols[1][2] + sols[2][1] + 2.0sols[2][2] + sols[2][3]
cache[4] = sols[1][1] + 2.0sols[1][2] + 3.0sols[2][1] + 5.0sols[2][2] +
6.0sols[2][3]
end

@parameters tmpvar[1:3]
f3!(eqs, u2[6:8], [p, tmpvar...])
@named subsys3 = NonlinearSystem(0 .~ eqs[1:3], [u2[6:8]...], [p, tmpvar...])
subsys3 = complete(subsys3; split = false)
prob3 = NonlinearProblem(
NonlinearFunction{true, SciMLBase.NoSpecialize}(f3!; sys = subsys3),
zeros(3), copy(cache))

prob = NonlinearProblem(model, [])
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
model, copy(cache))

for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
@test prob[sym] sccprob[sym]
end

for sym in [p, 2p + 1]
@test prob.ps[sym] sccprob.ps[sym]
end

for (i, sym) in enumerate([u[1], u[3], u[6]])
sccprob[sym] = 0.5i
@test sccprob[sym] 0.5i
@test sccprob.probs[i].u0[1] 0.5i
end
sccprob.ps[p] = 2.5
@test sccprob.ps[p] 2.5
@test sccprob.parameter_object[1] 2.5
for scc in sccprob.probs
@test parameter_values(scc)[1] 2.5
end
end

0 comments on commit bbc413b

Please sign in to comment.