Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: change SCCNonlinearProblem fields #884

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions src/problems/nonlinear_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,28 +462,30 @@ Note that this example aliases the parameters together for a memory-reduced repr
* `probs`: the collection of problems to solve
* `explictfuns!`: the explicit functions for mutating the parameter set
"""
mutable struct SCCNonlinearProblem{uType, iip, P, E, I, Par} <:
mutable struct SCCNonlinearProblem{uType, iip, P, E, F <: NonlinearFunction{iip}, Par} <:
AbstractNonlinearProblem{uType, iip}
probs::P
explicitfuns!::E
full_index_provider::I
parameter_object::Par
# NonlinearFunction with `f = Returns(nothing)`
f::F
p::Par
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait what is this p? How does it relate to the actual p of the probs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This p is the parameter object for SII. Since the parameter object is not guaranteed to be identical for all problems, the user can provide one that SII will use and that prob.f.observed will codegen for. There's a flag if all problems share the same parameter object, which aliases prob.p and functions like set_parameter! take advantage of this.

parameters_alias::Bool

function SCCNonlinearProblem{P, E, I, Par}(
probs::P, funs::E, indp::I, pobj::Par, alias::Bool) where {P, E, I, Par}
function SCCNonlinearProblem{P, E, F, Par}(probs::P, funs::E, f::F, pobj::Par,
alias::Bool) where {P, E, F <: NonlinearFunction, Par}
u0 = mapreduce(
state_values, vcat, probs; init = similar(state_values(first(probs)), 0))
uType = typeof(u0)
new{uType, false, P, E, I, Par}(probs, funs, indp, pobj, alias)
new{uType, false, P, E, F, Par}(probs, funs, f, pobj, alias)
end
end

function SCCNonlinearProblem(probs, explicitfuns!, full_index_provider = nothing,
parameter_object = nothing, parameters_alias = false)
function SCCNonlinearProblem(probs, explicitfuns!, parameter_object = nothing,
parameters_alias = false; kwargs...)
f = NonlinearFunction{false}(Returns(nothing); kwargs...)
return SCCNonlinearProblem{typeof(probs), typeof(explicitfuns!),
typeof(full_index_provider), typeof(parameter_object)}(
probs, explicitfuns!, full_index_provider, parameter_object, parameters_alias)
typeof(f), typeof(parameter_object)}(
probs, explicitfuns!, f, parameter_object, parameters_alias)
end

function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
Expand All @@ -496,10 +498,10 @@ function Base.getproperty(prob::SCCNonlinearProblem, name::Symbol)
end

function SymbolicIndexingInterface.symbolic_container(prob::SCCNonlinearProblem)
prob.full_index_provider
prob.f
end
function SymbolicIndexingInterface.parameter_values(prob::SCCNonlinearProblem)
prob.parameter_object
prob.p
end
function SymbolicIndexingInterface.state_values(prob::SCCNonlinearProblem)
mapreduce(
Expand All @@ -516,8 +518,8 @@ function SymbolicIndexingInterface.set_state!(prob::SCCNonlinearProblem, val, id
end

function SymbolicIndexingInterface.set_parameter!(prob::SCCNonlinearProblem, val, idx)
if prob.parameter_object !== nothing
set_parameter!(prob.parameter_object, val, idx)
if prob.p !== nothing
set_parameter!(prob.p, val, idx)
prob.parameters_alias && return
end
for scc in prob.probs
Expand Down
14 changes: 6 additions & 8 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,7 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
if p !== missing && !parameters_alias && probs === missing
throw(ArgumentError("`parameters_alias` is `false` for the given `SCCNonlinearProblem`. Please provide the subproblems using the keyword `probs` with the parameters updated appropriately in each."))
end
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults,
indp = sys === missing ? prob.full_index_provider : sys)
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
if probs === missing
probs = prob.probs
end
Expand All @@ -547,11 +546,10 @@ function remake(prob::SCCNonlinearProblem; u0 = missing, p = missing, probs = mi
end
end
if sys === missing
sys = prob.full_index_provider
sys = prob.f.sys
end
return SCCNonlinearProblem{
typeof(probs), typeof(explicitfuns!), typeof(sys), typeof(newp)}(
probs, explicitfuns!, sys, newp, parameters_alias)
return SCCNonlinearProblem(
probs, explicitfuns!, newp, parameters_alias; sys)
end

function varmap_has_var(varmap, var)
Expand Down Expand Up @@ -784,11 +782,11 @@ end

function updated_u0_p(
prob, u0, p, t0 = nothing; interpret_symbolicmap = true,
use_defaults = false, indp = has_sys(prob.f) ? prob.f.sys : nothing)
use_defaults = false)
if u0 === missing && p === missing
return state_values(prob), parameter_values(prob)
end
if indp === nothing
if prob.f.sys === nothing

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this line causes SciML/Integrals.jl#259

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if interpret_symbolicmap && eltype(p) !== Union{} && eltype(p) <: Pair
throw(ArgumentError("This problem does not support symbolic maps with " *
"`remake`, i.e. it does not have a symbolic origin. Please use `remake`" *
Expand Down
12 changes: 6 additions & 6 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fullsys = complete(fullsys)
prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)
push!(syss, fullsys)
push!(probs, sccprob)

Expand Down Expand Up @@ -315,16 +315,16 @@ end
prob1 = NonlinearProblem(sys1, u0, p)
prob2 = NonlinearProblem(sys2, u0, prob1.p)
sccprob = SCCNonlinearProblem(
[prob1, prob2], [Returns(nothing), Returns(nothing)], fullsys, prob1.p, true)
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)

sccprob2 = remake(sccprob; u0 = 2ones(3))
@test state_values(sccprob2) ≈ 2ones(3)
@test sccprob2.probs[1].u0 ≈ 2ones(2)
@test sccprob2.probs[2].u0 ≈ 2ones(1)

sccprob3 = remake(sccprob; p = [σ => 2.0])
@test sccprob3.parameter_object === sccprob3.probs[1].p
@test sccprob3.parameter_object === sccprob3.probs[2].p
@test sccprob3.p === sccprob3.probs[1].p
@test sccprob3.p === sccprob3.probs[2].p

@test_throws ["parameters_alias", "SCCNonlinearProblem"] remake(
sccprob; parameters_alias = false, p = [σ => 2.0])
Expand All @@ -333,6 +333,6 @@ end
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
probs = [remake(prob1; p = [σ => 3.0]), prob2])
@test !sccprob4.parameters_alias
@test sccprob4.parameter_object !== sccprob4.probs[1].p
@test sccprob4.parameter_object !== sccprob4.probs[2].p
@test sccprob4.p !== sccprob4.probs[1].p
@test sccprob4.p !== sccprob4.probs[2].p
end
4 changes: 2 additions & 2 deletions test/downstream/problem_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ prob = SteadyStateProblem(osys, u0, ps)
prob = NonlinearProblem(model, [])
sccprob = SciMLBase.SCCNonlinearProblem([prob1, prob2, prob3],
SciMLBase.Void{Any}.([explicitfun1, explicitfun2, explicitfun3]),
model, copy(cache))
copy(cache); sys = model)

for sym in [u, u..., u[2] + u[3], p * u[1] + u[2]]
@test prob[sym] ≈ sccprob[sym]
Expand All @@ -384,7 +384,7 @@ prob = SteadyStateProblem(osys, u0, ps)
end
sccprob.ps[p] = 2.5
@test sccprob.ps[p] ≈ 2.5
@test sccprob.parameter_object[1] ≈ 2.5
@test sccprob.p[1] ≈ 2.5
for scc in sccprob.probs
@test parameter_values(scc)[1] ≈ 2.5
end
Expand Down
Loading