Skip to content

Commit

Permalink
Allow FiniteDiff propagation for scalar problems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 16, 2024
1 parent e295922 commit e96a4a9
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.10.0"
version = "3.10.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ makedocs(; sitename = "NonlinearSolve.jl",
doctest = false,
linkcheck = true,
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615",
"https://link.springer.com/article/10.1007/s40096-020-00339-4"],
"https://link.springer.com/article/10.1007/s40096-020-00339-4"],
checkdocs = :exports,
warnonly = [:missing_docs],
plugins = [bib],
Expand Down
23 changes: 19 additions & 4 deletions src/internal/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,20 @@ function JacobianCache(
J, f, uf, fu, u, p, jac_cache, alg, 0, autodiff, vjp_autodiff, jvp_autodiff)
end

function JacobianCache(prob, alg, f::F, ::Number, u::Number, p; kwargs...) where {F}
function JacobianCache(
prob, alg, f::F, ::Number, u::Number, p; autodiff = nothing, kwargs...) where {F}
uf = JacobianWrapper{false}(f, p)
autodiff = get_concrete_forward_ad(autodiff, prob; check_reverse_mode = false)
if !(autodiff isa AutoForwardDiff ||
autodiff isa AutoPolyesterForwardDiff ||
autodiff isa AutoFiniteDiff)
autodiff = AutoFiniteDiff()

Check warning on line 107 in src/internal/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/jacobian.jl#L107

Added line #L107 was not covered by tests
# Other cases are not properly supported so we fallback to finite differencing
@warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \

Check warning on line 109 in src/internal/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/jacobian.jl#L109

Added line #L109 was not covered by tests
Detected $(autodiff). Falling back to AutoFiniteDiff."
end
return JacobianCache{false}(
u, f, uf, u, u, p, nothing, alg, 0, nothing, nothing, nothing)
u, f, uf, u, u, p, nothing, alg, 0, autodiff, nothing, nothing)
end

@inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
Expand All @@ -115,7 +125,7 @@ function (cache::JacobianCache)(J::JacobianOperator, u, p = cache.p)
end
function (cache::JacobianCache)(::Number, u, p = cache.p) # Scalar
cache.njacs += 1
J = last(__value_derivative(cache.uf, u))
J = last(__value_derivative(cache.autodiff, cache.uf, u))
return J
end
# Compute the Jacobian
Expand Down Expand Up @@ -181,12 +191,17 @@ end
end
end

@inline function __value_derivative(f::F, x::R) where {F, R}
@inline function __value_derivative(
::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, f::F, x::R) where {F, R}
T = typeof(ForwardDiff.Tag(f, R))
out = f(ForwardDiff.Dual{T}(x, one(x)))
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
end

@inline function __value_derivative(ad::AutoFiniteDiff, f::F, x::R) where {F, R}
return f(x), FiniteDiff.finite_difference_derivative(f, x, ad.fdtype)

Check warning on line 202 in src/internal/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/jacobian.jl#L201-L202

Added lines #L201 - L202 were not covered by tests
end

@inline function __scalar_jacvec(f::F, x::R, v::V) where {F, R, V}
T = typeof(ForwardDiff.Tag(f, R))
out = f(ForwardDiff.Dual{T}(x, v))
Expand Down
14 changes: 10 additions & 4 deletions src/internal/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
elseif SciMLBase.has_vjp(f)
f.vjp
elseif u isa Number # Ignore vjp directives
if ForwardDiff.can_dual(typeof(u))
@closure (v, u, p) -> last(__value_derivative(uf, u)) * v
if ForwardDiff.can_dual(typeof(u)) && (vjp_autodiff === nothing ||

Check warning on line 65 in src/internal/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/operators.jl#L65

Added line #L65 was not covered by tests
vjp_autodiff isa AutoForwardDiff ||
vjp_autodiff isa AutoPolyesterForwardDiff)
# VJP is same as VJP for scalars
@closure (v, u, p) -> last(__scalar_jacvec(uf, u, v))

Check warning on line 69 in src/internal/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/operators.jl#L69

Added line #L69 was not covered by tests
else
@closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v
end
Expand Down Expand Up @@ -92,8 +95,11 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
elseif SciMLBase.has_jvp(f)
f.jvp
elseif u isa Number # Ignore jvp directives
if ForwardDiff.can_dual(typeof(u))
@closure (v, u, p) -> last(__scalar_jacvec(uf, u, v)) * v
# Only ForwardDiff if user didn't override
if ForwardDiff.can_dual(typeof(u)) && (jvp_autodiff === nothing ||

Check warning on line 99 in src/internal/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/operators.jl#L99

Added line #L99 was not covered by tests
jvp_autodiff isa AutoForwardDiff ||
jvp_autodiff isa AutoPolyesterForwardDiff)
@closure (v, u, p) -> last(__scalar_jacvec(uf, u, v))

Check warning on line 102 in src/internal/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/operators.jl#L102

Added line #L102 was not covered by tests
else
@closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v
end
Expand Down
21 changes: 15 additions & 6 deletions test/misc/polyalg_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,22 @@ end
# Uses the `__solve` function
@test_throws MethodError solve(probN; abstol = 1e-9)
@test_throws MethodError solve(probN, RobustMultiNewton(); abstol = 1e-9)
solver = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
solver = solve(
sol = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
@test SciMLBase.successful_retcode(sol)
sol = solve(
probN, FastShortcutNonlinearPolyalg(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
solver = solve(probN, custom_polyalg; abstol = 1e-9)
@test SciMLBase.successful_retcode(solver)
@test SciMLBase.successful_retcode(sol)
sol = solve(probN, custom_polyalg; abstol = 1e-9)
@test SciMLBase.successful_retcode(sol)

quadratic_f(u::Float64, p) = u^2 - p

prob = NonlinearProblem(quadratic_f, 2.0, 4.0)

@test_throws MethodError solve(prob)
@test_throws MethodError solve(prob, RobustMultiNewton())
sol = solve(prob, RobustMultiNewton(; autodiff = AutoFiniteDiff()))
@test SciMLBase.successful_retcode(sol)
end

@testitem "Simple Scalar Problem #187" begin
Expand Down

0 comments on commit e96a4a9

Please sign in to comment.