-
-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Add DFSane method #214
Conversation
Started implementation of OOP solver but this doesn't work (error: using NonlinearSolve
using Random
Random.seed!(123)
function f!(du, u, p)
@. du .= u .* u .- p
return nothing
end
f = (u, p) -> u .* u .- p
n_test = 10
u0 = rand(n_test)
p = rand(n_test) .* 5
prob_iip = NonlinearProblem{true}(f!, u0, p);
prob_oop = NonlinearProblem{false}(f, u0, p);
alg = NonlinearSolve.DFSane()
sol = solve(prob_iip, alg) # works
sol = solve(prob_oop, alg) # doesn't work |
Stacktrace: ERROR: UndefVarError: `f` not defined
Stacktrace:
[1] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; alias_u0::Bool, maxiters::Int64, abstol::Float64, internalnorm::Function, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:118
[2] __init(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/dfsane.jl:88
[3] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:455
[4] init_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:433
[5] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:505
[6] init_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:475
[7] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:468
[8] init(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:459
[9] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:32
[10] __solve(::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, ::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ NonlinearSolve ~/Desktop/PrincetonCourses/MIT/NonlinearSolve.jl/src/NonlinearSolve.jl:29
[11] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; merge_callbacks::Bool, kwargshandle::DiffEqBase.KeywordArgError, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:539
[12] solve_call(_prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:509
[13] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:1008
[14] solve_up(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, sensealg::Nothing, u0::Vector{Float64}, p::Vector{Float64}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:973
[15] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{true}, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:967
[16] solve(prob::NonlinearProblem{Vector{Float64}, false, Vector{Float64}, NonlinearFunction{false, SciMLBase.FullSpecialize, var"#7#8", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardNonlinearProblem}, args::DFSane{Float32, NonlinearSolve.var"#28#30"})
@ DiffEqBase ~/.julia/packages/DiffEqBase/JehiA/src/solve.jl:957
[17] top-level scope
@ ~/Desktop/PrincetonCourses/MIT/dfsane_test/mwe_oop.jl:21 |
src/dfsane.jl
Outdated
f(dx, x) = prob.f(dx, x, p) | ||
f(fuₙ₋₁, uₙ₋₁) | ||
|
||
else | ||
f(x) = prob.f(x, p) | ||
fuₙ₋₁ = f(uₙ₋₁) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The error is because f
is being overwritten due to branching. Changing the name might fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing the name does sort of fix it. But I'm sorry, I really don't understand why this happens. I never had this problem before. Do you have some quick reference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You just cannot define standard functions in branches, make them anonymous.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Codecov Report
@@ Coverage Diff @@
## master #214 +/- ##
===========================================
+ Coverage 0.00% 86.11% +86.11%
===========================================
Files 13 14 +1
Lines 1054 1203 +149
===========================================
+ Hits 0 1036 +1036
+ Misses 1054 167 -887
... and 12 files with indirect coverage changes 📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
How close are we from getting this in? |
Tests are added! Everything works except for that ForwardDiff fails in some cases, see this MWE: using NonlinearSolve
using FiniteDiff, ForwardDiff
quadratic_f(u, p) = u .* u .- p
function benchmark_nlsolve_oop(f, u0, p=2.0)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, DFSane(), abstol=1e-9)
end
broken_forwarddiff = [3.0, 4.0, 81.0]
for p in broken_forwarddiff
analytical_derivative = 1 / (2 * sqrt(p))
forward_diff = abs(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p))
finite_diff = abs(FiniteDiff.finite_difference_derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p))
println("p = $p, Analytical: $analytical_derivative, ForwardDiff: $forward_diff, FiniteDiff: $finite_diff")
end Which prints out:
|
That's fine. We shouldn't ForwardDiff the solver anyways. Someone should handle that separately. |
Specifically #245 |
This PR adds a DFSane solver, similar to the ones in SimpleNonlinearSolve, here and here.
The implementation in this PR improves on the SimpleNonlinearSolve version by adding a cached solver with non allocating iterations.
Checklist: