From 81c86f32fcacf0433782e6334aa8cf1b32cbe945 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 30 Oct 2023 16:08:34 -0400 Subject: [PATCH] Should make tests pass --- Manifest.toml | 2 +- src/NonlinearSolve.jl | 1 + src/dfsane.jl | 1 + src/utils.jl | 14 ++++++-------- test/basictests.jl | 15 ++------------- 5 files changed, 11 insertions(+), 22 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 7a9729bd6..4fb69b6ae 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -164,7 +164,7 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" [[deps.DiffEqBase]] deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"] -git-tree-sha1 = "4e661d0beddac31da05e71b79afd769232622de8" +git-tree-sha1 = "0ab52aef95c5cc71e9a8c9d26919ce1f7fb472fa" repo-rev = "ap/tstable_termination" repo-url = "https://github.com/SciML/DiffEqBase.jl" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 676afab7d..ef7eca963 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -57,6 +57,7 @@ end get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1 set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu) get_u(cache::AbstractNonlinearSolveCache) = cache.u +set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u) function SciMLBase.solve!(cache::AbstractNonlinearSolveCache) while not_terminated(cache) diff --git a/src/dfsane.jl b/src/dfsane.jl index b0de39c49..f78d40411 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -98,6 +98,7 @@ end get_fu(cache::DFSaneCache) = cache.fuₙ set_fu!(cache::DFSaneCache, fu) = (cache.fuₙ = fu) get_u(cache::DFSaneCache) = cache.uₙ +set_u!(cache::DFSaneCache, u) = (cache.uₙ = u) function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, diff --git a/src/utils.jl b/src/utils.jl index 1977735f4..2777c93e4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -230,8 +230,7 @@ function check_and_update!(tc_cache, cache, fu, u, uprev, if isinplace(cache) cache.prob.f(get_fu(cache), u, cache.prob.p) else - cache.u = u - set_fu!(cache, cache.prob.f(cache.u, cache.prob.p)) + set_fu!(cache, cache.prob.f(u, cache.prob.p)) end cache.force_stop = true end @@ -252,8 +251,7 @@ function check_and_update!(tc_cache, cache, fu, u, uprev, if isinplace(cache) cache.prob.f(get_fu(cache), u, cache.prob.p) else - cache.u = u - set_fu!(cache, cache.prob.f(cache.u, cache.prob.p)) + set_fu!(cache, cache.prob.f(u, cache.prob.p)) end cache.force_stop = true end @@ -271,11 +269,11 @@ function check_and_update!(tc_cache, cache, fu, u, uprev, cache.retcode = ReturnCode.Unstable end if isinplace(cache) - copyto!(u, tc_cache.u) - cache.prob.f(get_fu(cache), u, cache.prob.p) + copyto!(get_u(cache), tc_cache.u) + cache.prob.f(get_fu(cache), get_u(cache), cache.prob.p) else - cache.u = tc_cache.u - set_fu!(cache, cache.prob.f(cache.u, cache.prob.p)) + set_u!(cache, tc_cache.u) + set_fu!(cache, cache.prob.f(get_u(cache), cache.prob.p)) end cache.force_stop = true end diff --git a/test/basictests.jl b/test/basictests.jl index 6602edfb6..24ee5c831 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -453,7 +453,6 @@ end end @testset "[OOP] [Immutable AD]" begin - broken_forwarddiff = [3.0, 4.0, 81.0] for p in 1.1:0.1:100.0 res = abs.(benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p).u) @@ -461,9 +460,6 @@ end @test_broken all(res .≈ sqrt(p)) @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) - elseif p in broken_forwarddiff - @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) else @test all(res .≈ sqrt(p)) @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, @@ -473,7 +469,6 @@ end end @testset "[OOP] [Scalar AD]" begin - broken_forwarddiff = [3.0, 4.0, 81.0] for p in 1.1:0.1:100.0 res = abs(benchmark_nlsolve_oop(quadratic_f, 1.0, p).u) @@ -481,9 +476,6 @@ end @test_broken res ≈ sqrt(p) @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) ≈ 1 / (2 * sqrt(p)) - elseif p in broken_forwarddiff - @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - 1.0, p).u, p)) ≈ 1 / (2 * sqrt(p)) else @test res ≈ sqrt(p) @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, @@ -549,7 +541,6 @@ end probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0) sol = solve(probN, alg, abstol = 1e-11) - println(abs.(quadratic_f(sol.u, 2.0))) @test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10) end end @@ -644,13 +635,11 @@ end function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip} probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin]) - cache = init(probN, - PseudoTransient(alpha_initial = 10.0); - maxiters = 100, + cache = init(probN, PseudoTransient(alpha_initial = 10.0); maxiters = 100, abstol = 1e-10) sols = zeros(length(p_range)) for (i, p) in enumerate(p_range) - reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha_new = 10.0) + reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha = 10.0) sol = solve!(cache) sols[i] = iip ? sol.u[1] : sol.u end