From e96a4a9b7df4e09e48e5b077dc249c5a7e39276e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Apr 2024 11:27:39 -0400 Subject: [PATCH] Allow FiniteDiff propagation for scalar problems --- Project.toml | 2 +- docs/make.jl | 2 +- src/internal/jacobian.jl | 23 +++++++++++++++++++---- src/internal/operators.jl | 14 ++++++++++---- test/misc/polyalg_tests.jl | 21 +++++++++++++++------ 5 files changed, 46 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index f6c2785d3..99a50b257 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NonlinearSolve" uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" authors = ["SciML"] -version = "3.10.0" +version = "3.10.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/make.jl b/docs/make.jl index 54b7a6424..622cebbfe 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -19,7 +19,7 @@ makedocs(; sitename = "NonlinearSolve.jl", doctest = false, linkcheck = true, linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615", - "https://link.springer.com/article/10.1007/s40096-020-00339-4"], + "https://link.springer.com/article/10.1007/s40096-020-00339-4"], checkdocs = :exports, warnonly = [:missing_docs], plugins = [bib], diff --git a/src/internal/jacobian.jl b/src/internal/jacobian.jl index 3dfb904ee..221bc5d62 100644 --- a/src/internal/jacobian.jl +++ b/src/internal/jacobian.jl @@ -97,10 +97,20 @@ function JacobianCache( J, f, uf, fu, u, p, jac_cache, alg, 0, autodiff, vjp_autodiff, jvp_autodiff) end -function JacobianCache(prob, alg, f::F, ::Number, u::Number, p; kwargs...) where {F} +function JacobianCache( + prob, alg, f::F, ::Number, u::Number, p; autodiff = nothing, kwargs...) where {F} uf = JacobianWrapper{false}(f, p) + autodiff = get_concrete_forward_ad(autodiff, prob; check_reverse_mode = false) + if !(autodiff isa AutoForwardDiff || + autodiff isa AutoPolyesterForwardDiff || + autodiff isa AutoFiniteDiff) + autodiff = AutoFiniteDiff() + # Other cases are not properly supported so we fallback to finite differencing + @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \ + Detected $(autodiff). Falling back to AutoFiniteDiff." + end return JacobianCache{false}( - u, f, uf, u, u, p, nothing, alg, 0, nothing, nothing, nothing) + u, f, uf, u, u, p, nothing, alg, 0, autodiff, nothing, nothing) end @inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p) @@ -115,7 +125,7 @@ function (cache::JacobianCache)(J::JacobianOperator, u, p = cache.p) end function (cache::JacobianCache)(::Number, u, p = cache.p) # Scalar cache.njacs += 1 - J = last(__value_derivative(cache.uf, u)) + J = last(__value_derivative(cache.autodiff, cache.uf, u)) return J end # Compute the Jacobian @@ -181,12 +191,17 @@ end end end -@inline function __value_derivative(f::F, x::R) where {F, R} +@inline function __value_derivative( + ::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, f::F, x::R) where {F, R} T = typeof(ForwardDiff.Tag(f, R)) out = f(ForwardDiff.Dual{T}(x, one(x))) return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out) end +@inline function __value_derivative(ad::AutoFiniteDiff, f::F, x::R) where {F, R} + return f(x), FiniteDiff.finite_difference_derivative(f, x, ad.fdtype) +end + @inline function __scalar_jacvec(f::F, x::R, v::V) where {F, R, V} T = typeof(ForwardDiff.Tag(f, R)) out = f(ForwardDiff.Dual{T}(x, v)) diff --git a/src/internal/operators.jl b/src/internal/operators.jl index a34565a9d..0c8737040 100644 --- a/src/internal/operators.jl +++ b/src/internal/operators.jl @@ -62,8 +62,11 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = elseif SciMLBase.has_vjp(f) f.vjp elseif u isa Number # Ignore vjp directives - if ForwardDiff.can_dual(typeof(u)) - @closure (v, u, p) -> last(__value_derivative(uf, u)) * v + if ForwardDiff.can_dual(typeof(u)) && (vjp_autodiff === nothing || + vjp_autodiff isa AutoForwardDiff || + vjp_autodiff isa AutoPolyesterForwardDiff) + # VJP is same as VJP for scalars + @closure (v, u, p) -> last(__scalar_jacvec(uf, u, v)) else @closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v end @@ -92,8 +95,11 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff = elseif SciMLBase.has_jvp(f) f.jvp elseif u isa Number # Ignore jvp directives - if ForwardDiff.can_dual(typeof(u)) - @closure (v, u, p) -> last(__scalar_jacvec(uf, u, v)) * v + # Only ForwardDiff if user didn't override + if ForwardDiff.can_dual(typeof(u)) && (jvp_autodiff === nothing || + jvp_autodiff isa AutoForwardDiff || + jvp_autodiff isa AutoPolyesterForwardDiff) + @closure (v, u, p) -> last(__scalar_jacvec(uf, u, v)) else @closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v end diff --git a/test/misc/polyalg_tests.jl b/test/misc/polyalg_tests.jl index 74016d75f..666db02ab 100644 --- a/test/misc/polyalg_tests.jl +++ b/test/misc/polyalg_tests.jl @@ -86,13 +86,22 @@ end # Uses the `__solve` function @test_throws MethodError solve(probN; abstol = 1e-9) @test_throws MethodError solve(probN, RobustMultiNewton(); abstol = 1e-9) - solver = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9) - @test SciMLBase.successful_retcode(solver) - solver = solve( + sol = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9) + @test SciMLBase.successful_retcode(sol) + sol = solve( probN, FastShortcutNonlinearPolyalg(; autodiff = AutoFiniteDiff()); abstol = 1e-9) - @test SciMLBase.successful_retcode(solver) - solver = solve(probN, custom_polyalg; abstol = 1e-9) - @test SciMLBase.successful_retcode(solver) + @test SciMLBase.successful_retcode(sol) + sol = solve(probN, custom_polyalg; abstol = 1e-9) + @test SciMLBase.successful_retcode(sol) + + quadratic_f(u::Float64, p) = u^2 - p + + prob = NonlinearProblem(quadratic_f, 2.0, 4.0) + + @test_throws MethodError solve(prob) + @test_throws MethodError solve(prob, RobustMultiNewton()) + sol = solve(prob, RobustMultiNewton(; autodiff = AutoFiniteDiff())) + @test SciMLBase.successful_retcode(sol) end @testitem "Simple Scalar Problem #187" begin