Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make construction of type stable TP BVProblem easier #518

Merged
merged 3 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SciMLBase"
uuid = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "2.0.7"
version = "2.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
20 changes: 16 additions & 4 deletions src/problems/bvp_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""
$(TYPEDEF)
"""
struct TwoPointBVProblem end
struct TwoPointBVProblem{iip} end # The iip is needed to make type stable construction easier

@doc doc"""

Expand Down Expand Up @@ -112,7 +112,7 @@
p = NullParameters(); problem_type=nothing, kwargs...) where {iip, TP}
_tspan = promote_tspan(tspan)
warn_paramtype(p)
prob_type = TP ? TwoPointBVProblem() : StandardBVProblem()
prob_type = TP ? TwoPointBVProblem{iip}() : StandardBVProblem()

Check warning on line 115 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L115

Added line #L115 was not covered by tests
# Needed to ensure that `problem_type` doesn't get passed in kwargs
if problem_type === nothing
problem_type = prob_type
Expand Down Expand Up @@ -144,16 +144,28 @@
# But we need it for function calls like TwoPointBVProblem{iip}(...) = ...
struct TwoPointBVPFunction{iip} end

@inline TwoPointBVPFunction(args...; kwargs...) = BVPFunction(args...; kwargs..., twopoint=true)
@inline function TwoPointBVPFunction(args...; kwargs...)
return BVPFunction(args...; kwargs..., twopoint = Val(true))
end
@inline function TwoPointBVPFunction{iip}(args...; kwargs...) where {iip}
return BVPFunction{iip}(args...; kwargs..., twopoint=true)
return BVPFunction{iip}(args...; kwargs..., twopoint = Val(true))

Check warning on line 151 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L151

Added line #L151 was not covered by tests
end

function TwoPointBVProblem{iip}(f, bc, u0, tspan, p = NullParameters();

Check warning on line 154 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L154

Added line #L154 was not covered by tests
bcresid_prototype=nothing, kwargs...) where {iip}
return TwoPointBVProblem(TwoPointBVPFunction{iip}(f, bc; bcresid_prototype), u0, tspan,

Check warning on line 156 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L156

Added line #L156 was not covered by tests
p; kwargs...)
end
function TwoPointBVProblem(f, bc, u0, tspan, p = NullParameters();
bcresid_prototype=nothing, kwargs...)
return TwoPointBVProblem(TwoPointBVPFunction(f, bc; bcresid_prototype), u0, tspan, p;
kwargs...)
end
function TwoPointBVProblem{iip}(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,

Check warning on line 164 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L164

Added line #L164 was not covered by tests
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
return BVProblem{iip}(f, f.bc, u0, tspan, p; kwargs...)

Check warning on line 167 in src/problems/bvp_problems.jl

View check run for this annotation

Codecov / codecov/patch

src/problems/bvp_problems.jl#L166-L167

Added lines #L166 - L167 were not covered by tests
end
function TwoPointBVProblem(f::AbstractBVPFunction{iip, twopoint}, u0, tspan,
p = NullParameters(); kwargs...) where {iip, twopoint}
@assert twopoint "`TwoPointBVProblem` can only be used with a `TwoPointBVPFunction`. Instead of using `BVPFunction`, use `TwoPointBVPFunction` or pass a kwarg `twopoint=true` during the construction of the `BVPFunction`."
Expand Down
61 changes: 61 additions & 0 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,67 @@
end
end

"""
remake(prob::BVProblem; f = missing, u0 = missing, tspan = missing,
p = missing, kwargs = missing, problem_type = missing, _kwargs...)

Remake the given `BVProblem`.
"""
function remake(prob::BVProblem; f = missing, bc = missing, u0 = missing, tspan = missing,

Check warning on line 129 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L129

Added line #L129 was not covered by tests
p = missing, kwargs = missing, problem_type = missing, _kwargs...)
if tspan === missing
tspan = prob.tspan

Check warning on line 132 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L131-L132

Added lines #L131 - L132 were not covered by tests
end

if p === missing && u0 === missing
p, u0 = prob.p, prob.u0

Check warning on line 136 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
else # at least one of them has a value
if p === missing
p = prob.p

Check warning on line 139 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L138-L139

Added lines #L138 - L139 were not covered by tests
end
if u0 === missing
u0 = prob.u0

Check warning on line 142 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L141-L142

Added lines #L141 - L142 were not covered by tests
end
end

iip = isinplace(prob)

Check warning on line 146 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L146

Added line #L146 was not covered by tests

if problem_type === missing
problem_type = prob.problem_type

Check warning on line 149 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L148-L149

Added lines #L148 - L149 were not covered by tests
end

twopoint = problem_type isa TwoPointBVProblem

Check warning on line 152 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L152

Added line #L152 was not covered by tests

if bc === missing
bc = prob.bc

Check warning on line 155 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L154-L155

Added lines #L154 - L155 were not covered by tests
end

if f === missing
_f = prob.f
elseif f isa BVPFunction
_f = f
bc = f.bc
elseif specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
_f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_iip(f,

Check warning on line 166 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L158-L166

Added lines #L158 - L166 were not covered by tests
(u0, u0, p, ptspan[1])), bc; prob.f.bcresid_prototype)
else
_f = BVPFunction{iip, FunctionWrapperSpecialize, twopoint}(wrapfun_oop(f,

Check warning on line 169 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L169

Added line #L169 was not covered by tests
(u0, p, ptspan[1])), bc; prob.f.bcresid_prototype)
end
else
_f = BVPFunction{isinplace(prob), specialization(prob.f), twopoint}(f, bc;

Check warning on line 173 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L173

Added line #L173 was not covered by tests
prob.f.bcresid_prototype)
end

if kwargs === missing
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, prob.kwargs..., _kwargs...)

Check warning on line 178 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L177-L178

Added lines #L177 - L178 were not covered by tests
else
BVProblem{iip}(_f, bc, u0, tspan, p; problem_type, kwargs...)

Check warning on line 180 in src/remake.jl

View check run for this annotation

Codecov / codecov/patch

src/remake.jl#L180

Added line #L180 was not covered by tests
end
end

"""
remake(prob::SDEProblem; f = missing, u0 = missing, tspan = missing,
p = missing, noise = missing, noise_rate_prototype = missing,
Expand Down
9 changes: 5 additions & 4 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4064,12 +4064,13 @@
end
end

function BVPFunction{iip}(f, bc; twopoint::Bool=false, kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, twopoint}(f, bc; kwargs...)
function BVPFunction{iip}(f, bc; twopoint::Union{Val, Bool}=Val(false),

Check warning on line 4067 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4067

Added line #L4067 was not covered by tests
kwargs...) where {iip}
BVPFunction{iip, FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...)

Check warning on line 4069 in src/scimlfunctions.jl

View check run for this annotation

Codecov / codecov/patch

src/scimlfunctions.jl#L4069

Added line #L4069 was not covered by tests
end
BVPFunction{iip}(f::BVPFunction, bc; kwargs...) where {iip} = f
function BVPFunction(f, bc; twopoint::Bool=false, kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, twopoint}(f, bc; kwargs...)
function BVPFunction(f, bc; twopoint::Union{Val, Bool}=Val(false), kwargs...)
BVPFunction{isinplace(f, 4), FullSpecialize, _unwrap_val(twopoint)}(f, bc; kwargs...)
end
BVPFunction(f::BVPFunction; kwargs...) = f

Expand Down
Loading