diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 2e30d27a5..30a6c0b57 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -22,7 +22,7 @@ import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX import ADTypes: ADTypes, AbstractADType -import Accessors: @set, @reset, @delete +import Accessors: @set, @reset, @delete, @insert using Expronicon.ADT: @match using Reexport diff --git a/src/initialization.jl b/src/initialization.jl index ddf9a96d0..50dd35f44 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -124,11 +124,13 @@ function evaluate_f( return _evaluate_f(integrator, f, isinplace, integrator.du, u, p, t) end -function evaluate_f(integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t) +function evaluate_f( + integrator::AbstractDDEIntegrator, prob::AbstractDDEProblem, f, isinplace, u, p, t) return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) end -function evaluate_f(integrator::AbstractSDDEIntegrator, prob::AbstractSDDEProblem, f, isinplace, u, p, t) +function evaluate_f(integrator::AbstractSDDEIntegrator, + prob::AbstractSDDEProblem, f, isinplace, u, p, t) return _evaluate_f(integrator, f, isinplace, u, get_history_function(integrator), p, t) end diff --git a/src/remake.jl b/src/remake.jl index 53e1e2911..31d013964 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -99,6 +99,105 @@ function remake( _remake_internal(prob; kwargs..., p) end +""" + $(TYPEDSIGNATURES) + +A utility function which merges two `NamedTuple`s `a` and `b`, assuming that the +keys of `a` are a subset of those of `b`. Values in `b` take priority over those +in `a`, except if they are `nothing`. Keys not present in `a` are assumed to have +a value of `nothing`. +""" +function _similar_namedtuple_merge_ignore_nothing(a::NamedTuple, b::NamedTuple) + ks = fieldnames(typeof(b)) + return NamedTuple{ks}(ntuple(Val(length(ks))) do i + something(get(b, ks[i], nothing), get(a, ks[i], nothing), Some(nothing)) + end) +end + +""" + remake(func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...) + +`remake` the given `func`. Return an `AbstractSciMLFunction` of the same kind, `isinplace` and +`specialization` as `func`. Retain the properties of `func`, except those that are overridden +by keyword arguments. For stochastic functions (e.g. `SDEFunction`) the `g` keyword argument +is used to override `func.g`. For split functions (e.g. `SplitFunction`) the `f2` keyword +argument is used to override `func.f2`, and `f` is used for `func.f1`. If +`f isa AbstractSciMLFunction` and `func` is not a split function, properties of `f` will +override those of `func` (but not ones provided via keyword arguments). Properties of `f` that +are `nothing` will fall back to those in `func` (unless provided via keyword arguments). If +`f` is a different type of `AbstractSciMLFunction` from `func`, the returned function will be +of the kind of `f` unless `func` is a split function. If `func` is a split function, `f` and +`f2` will be wrapped in the appropriate `AbstractSciMLFunction` type with the same `isinplace` +and `specialization` as `func`. +""" +function remake( + func::AbstractSciMLFunction; f = missing, g = missing, f2 = missing, kwargs...) + # retain iip and spec of original function + iip = isinplace(func) + spec = specialization(func) + # retain properties of original function + props = getproperties(func) + + if f === missing || is_split_function(func) + # if no `f` is provided, create the same type of SciMLFunction + T = parameterless_type(func) + f = isdefined(func, :f) ? func.f : func.f1 + elseif f isa AbstractSciMLFunction + # if `f` is a SciMLFunction, create that type + T = parameterless_type(f) + # properties of `f` take priority over those in the existing `func` + # ignore properties of `f` which are `nothing` but present in `func` + props = _similar_namedtuple_merge_ignore_nothing(props, getproperties(f)) + f = isdefined(f, :f) ? f.f : f.f1 + else + # if `f` is provided but not a SciMLFunction, create the same type + T = parameterless_type(func) + end + + # minor hack to avoid breaking MTK, since prior to ~9.57 in `remake_initialization_data` + # it creates a `NonlinearFunction` inside a `NonlinearFunction`. Just recursively unwrap + # in this case and forget about properties. + while !is_split_function(T) && f isa AbstractSciMLFunction + f = isdefined(f, :f) ? f.f : f.f1 + end + + props = @delete props.f + props = @delete props.f1 + + args = (f,) + if is_split_function(T) + # for DynamicalSDEFunction and SplitFunction + if isdefined(props, :cache) + props = @insert props._func_cache = props.cache + props = @delete props.cache + end + + # `f1` and `f2` are wrapped in another SciMLFunction, unless they're + # already wrapped in the appropriate type or are an `AbstractSciMLOperator` + if !(f isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)}) + f = split_function_f_wrapper(T){iip, spec}(f) + end + # For SplitFunction + # we don't do the same thing as `g`, because for SDEs `g` is + # stored in the problem as well, whereas for Split ODEs etc + # f2 is a part of the function. Thus, if the user provides + # a SciMLFunction for `f` which contains `f2` we use that. + f2 = coalesce(f2, get(props, :f2, missing), func.f2) + if !(f2 isa Union{AbstractSciMLOperator, split_function_f_wrapper(T)}) + f2 = split_function_f_wrapper(T){iip, spec}(f2) + end + props = @delete props.f2 + args = (args..., f2) + end + if isdefined(func, :g) + # For SDEs/SDDEs where `g` is not a keyword + g = coalesce(g, func.g) + props = @delete props.g + args = (args..., g) + end + T{iip, spec}(args...; props..., kwargs...) +end + """ remake(prob::ODEProblem; f = missing, u0 = missing, tspan = missing, p = missing, kwargs = missing, _kwargs...) @@ -135,53 +234,26 @@ function remake(prob::ODEProblem; f = missing, initialization_data = nothing end - if f === missing - if specialization(prob.f) === FunctionWrapperSpecialize - ptspan = promote_tspan(tspan) - if iip - _f = ODEFunction{iip, FunctionWrapperSpecialize}( - wrapfun_iip( - unwrapped_f(prob.f.f), - (newu0, newu0, newp, - ptspan[1])); initialization_data) - else - _f = ODEFunction{iip, FunctionWrapperSpecialize}( - wrapfun_oop( - unwrapped_f(prob.f.f), - (newu0, newp, - ptspan[1])); initialization_data) - end - else - _f = prob.f - if __has_initialization_data(_f) - props = getproperties(_f) - @reset props.initialization_data = initialization_data - props = values(props) - _f = parameterless_type(_f){iip, specialization(_f), map(typeof, props)...}(props...) - end - end - elseif f isa AbstractODEFunction - _f = f - elseif specialization(prob.f) === FunctionWrapperSpecialize + f = coalesce(f, prob.f) + f = remake(prob.f; f, initialization_data) + + if specialization(f) === FunctionWrapperSpecialize ptspan = promote_tspan(tspan) if iip - _f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f, - (newu0, newu0, newp, - ptspan[1]))) + f = remake( + f; f = wrapfun_iip(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1]))) else - _f = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f, - (newu0, newp, ptspan[1]))) + f = remake( + f; f = wrapfun_oop(unwrapped_f(f.f), (newu0, newu0, newp, ptspan[1]))) end - else - _f = ODEFunction{isinplace(prob), specialization(prob.f)}(f) end prob = if kwargs === missing - ODEProblem{isinplace(prob)}( - _f, newu0, tspan, newp, prob.problem_type; prob.kwargs..., + ODEProblem{iip}( + f, newu0, tspan, newp, prob.problem_type; prob.kwargs..., _kwargs...) else - ODEProblem{isinplace(prob)}(_f, newu0, tspan, newp, prob.problem_type; kwargs...) + ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...) end if lazy_initialization === nothing @@ -395,42 +467,6 @@ function remake(prob::SDEProblem; return prob end -""" - remake(func::SDEFunction; f = missing, g = missing, - mass_matrix = missing, analytic = missing, kwargs...) - -Remake the given `SDEFunction`. -""" -function remake(func::Union{SDEFunction, SDDEFunction}; - f = missing, - g = missing, - mass_matrix = missing, - analytic = missing, - sys = missing, - kwargs...) - props = getproperties(func) - props = @delete props.f - props = @delete props.g - @reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix) - @reset props.analytic = coalesce(analytic, func.analytic) - @reset props.sys = coalesce(sys, func.sys) - - if f === missing - f = func.f - end - - if g === missing - g = func.g - end - - if f isa AbstractSciMLFunction - f = f.f - end - - T = func isa SDEFunction ? SDEFunction : SDDEFunction - return T{isinplace(func)}(f, g; props..., kwargs...) -end - function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, tspan = missing, p = missing, constant_lags = missing, dependent_lags = missing, order_discontinuity_t0 = missing, @@ -497,28 +533,6 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing, return prob end -function remake(func::DDEFunction; - f = missing, - mass_matrix = missing, - analytic = missing, - sys = missing, - kwargs...) - props = getproperties(func) - props = @delete props.f - @reset props.mass_matrix = coalesce(mass_matrix, func.mass_matrix) - @reset props.analytic = coalesce(analytic, func.analytic) - @reset props.sys = coalesce(sys, func.sys) - - if f === missing - f = func.f - end - if f isa AbstractSciMLFunction - f = f.f - end - - return DDEFunction{isinplace(func)}(f; props..., kwargs...) -end - function remake(prob::SDDEProblem; f = missing, g = missing, @@ -706,6 +720,7 @@ function remake(prob::NonlinearProblem; initialization_data = nothing end + f = coalesce(f, prob.f) f = remake(prob.f; f, initialization_data) if problem_type === missing @@ -737,22 +752,6 @@ function remake(prob::NonlinearProblem; return prob end -function remake(func::NonlinearFunction; - f = missing, - kwargs...) - props = getproperties(func) - props = @delete props.f - - if f === missing - f = func.f - end - if f isa AbstractSciMLFunction - f = f.f - end - - return NonlinearFunction{isinplace(func)}(f; props..., kwargs...) -end - """ remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p = missing, kwargs = missing, _kwargs...) @@ -775,6 +774,7 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p initialization_data = nothing end + f = coalesce(f, prob.f) f = remake(prob.f; f, initialization_data) prob = if kwargs === missing diff --git a/src/scimlfunctions.jl b/src/scimlfunctions.jl index c18868c8f..c1fde4650 100644 --- a/src/scimlfunctions.jl +++ b/src/scimlfunctions.jl @@ -4601,6 +4601,20 @@ has_Wfact_t(f::JacobianWrapper) = has_Wfact_t(f.f) has_paramjac(f::JacobianWrapper) = has_paramjac(f.f) has_colorvec(f::JacobianWrapper) = has_colorvec(f.f) +is_split_function(x) = is_split_function(typeof(x)) +is_split_function(::Type) = false +function is_split_function(::Type{T}) where {T <: Union{ + SplitFunction, SplitSDEFunction, DynamicalODEFunction, + DynamicalDDEFunction, DynamicalSDEFunction}} + true +end + +split_function_f_wrapper(::Type{<:SplitFunction}) = ODEFunction +split_function_f_wrapper(::Type{<:SplitSDEFunction}) = SDEFunction +split_function_f_wrapper(::Type{<:DynamicalODEFunction}) = ODEFunction +split_function_f_wrapper(::Type{<:DynamicalDDEFunction}) = DDEFunction +split_function_f_wrapper(::Type{<:DynamicalSDEFunction}) = DDEFunction + ######### Additional traits islinear(::AbstractDiffEqFunction) = false diff --git a/test/remake_tests.jl b/test/remake_tests.jl index 0a9b156be..3fdf797ff 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -372,3 +372,14 @@ end prob = ODEProblem(ODEFunction(foo; sys), [1.5, 2.5], (0.0, 1.0), [3.5, 4.5]) @test_nowarn remake(prob; u0 = [:x => nothing], p = [:a => nothing]) end + +@testset "retain properties of `SciMLFunction` passed to `remake`" begin + u0 = [1.0; 2.0; 3.0] + p = [10.0, 20.0, 30.0] + sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) + fn = NonlinearFunction(nllorenz!; sys, resid_prototype = zeros(Float64, 3)) + prob = NonlinearProblem(fn, u0, p) + fn2 = NonlinearFunction(nllorenz!; resid_prototype = zeros(Float32, 3)) + prob2 = remake(prob; f = fn2) + @test prob2.f.resid_prototype isa Vector{Float32} +end