Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow FiniteDiff propagation for scalar problems #409

Merged
merged 1 commit into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.10.0"
version = "3.10.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ makedocs(; sitename = "NonlinearSolve.jl",
doctest = false,
linkcheck = true,
linkcheck_ignore = ["https://twitter.com/ChrisRackauckas/status/1544743542094020615",
"https://link.springer.com/article/10.1007/s40096-020-00339-4"],
"https://link.springer.com/article/10.1007/s40096-020-00339-4"],
checkdocs = :exports,
warnonly = [:missing_docs],
plugins = [bib],
Expand Down
12 changes: 11 additions & 1 deletion src/globalization/line_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,17 @@
args...; internalnorm::IN = DEFAULT_NORM, kwargs...) where {F, IN}
T = promote_type(eltype(fu), eltype(u))
if u isa Number
    # Scalar problems: the gradient of the merit function reduces to f'(u) * fu,
    # computed via __value_derivative with a concrete forward-mode AD backend.
    autodiff = get_concrete_forward_ad(alg.autodiff, prob; check_forward_mode = true)
    if !(autodiff isa AutoForwardDiff ||
         autodiff isa AutoPolyesterForwardDiff ||
         autodiff isa AutoFiniteDiff)
        # Other backends are not properly supported for scalar problems, so we
        # fall back to finite differencing. Warn *before* overwriting `autodiff`
        # so the message reports the backend that was actually detected (the
        # original ordering always printed "Detected AutoFiniteDiff()").
        @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
               Detected $(autodiff). Falling back to AutoFiniteDiff."
        autodiff = AutoFiniteDiff()
    end
    grad_op = @closure (u, fu, p) -> last(__value_derivative(
        autodiff, Base.Fix2(f, p), u)) * fu
else
if SciMLBase.has_jvp(f)
if isinplace(prob)
Expand Down
23 changes: 19 additions & 4 deletions src/internal/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,20 @@
J, f, uf, fu, u, p, jac_cache, alg, 0, autodiff, vjp_autodiff, jvp_autodiff)
end

"""
    JacobianCache(prob, alg, f, ::Number, u::Number, p; autodiff = nothing, kwargs...)

Scalar specialization of the Jacobian cache: no Jacobian matrix is materialized;
the derivative is computed directly when the cache is called. Only forward-mode
ForwardDiff-family backends and finite differencing are supported — anything
else is replaced by `AutoFiniteDiff()` with a warning.
"""
function JacobianCache(
        prob, alg, f::F, ::Number, u::Number, p; autodiff = nothing, kwargs...) where {F}
    uf = JacobianWrapper{false}(f, p)
    autodiff = get_concrete_forward_ad(autodiff, prob; check_reverse_mode = false)
    if !(autodiff isa AutoForwardDiff ||
         autodiff isa AutoPolyesterForwardDiff ||
         autodiff isa AutoFiniteDiff)
        # Other backends are not properly supported for scalar problems, so we
        # fall back to finite differencing. Warn *before* overwriting `autodiff`
        # so the message reports the backend that was actually detected (the
        # original ordering always printed "Detected AutoFiniteDiff()").
        @warn "Scalar AD is supported only for AutoForwardDiff and AutoFiniteDiff. \
               Detected $(autodiff). Falling back to AutoFiniteDiff."
        autodiff = AutoFiniteDiff()
    end
    return JacobianCache{false}(
        u, f, uf, u, u, p, nothing, alg, 0, autodiff, nothing, nothing)
end

# Convenience overload: evaluate the cache at `u` (defaulting to the cached
# point) by forwarding to the 3-arg call with the stored `J` and `p`.
@inline (cache::JacobianCache)(u = cache.u) = cache(cache.J, u, cache.p)
Expand All @@ -115,7 +125,7 @@
end
# Scalar fast path: the "Jacobian" of a scalar map is just its derivative,
# evaluated with the backend selected at cache construction time.
function (cache::JacobianCache)(::Number, u, p = cache.p)
    cache.njacs += 1
    return last(__value_derivative(cache.autodiff, cache.uf, u))
end
# Compute the Jacobian
Expand Down Expand Up @@ -181,12 +191,17 @@
end
end

# Evaluate `f(x)` and `f'(x)` in a single pass by pushing one ForwardDiff dual
# number (with unit perturbation) through `f`.
@inline function __value_derivative(
        ::Union{AutoForwardDiff, AutoPolyesterForwardDiff}, f::F, x::R) where {F, R}
    tag = typeof(ForwardDiff.Tag(f, R))
    dual_out = f(ForwardDiff.Dual{tag}(x, one(x)))
    return ForwardDiff.value(dual_out), ForwardDiff.extract_derivative(tag, dual_out)
end

# Finite-difference fallback: primal value from a direct call, derivative
# estimated by FiniteDiff using the fdtype carried by the backend object.
@inline function __value_derivative(ad::AutoFiniteDiff, f::F, x::R) where {F, R}
    fx = f(x)
    dfdx = FiniteDiff.finite_difference_derivative(f, x, ad.fdtype)
    return fx, dfdx
end

@inline function __scalar_jacvec(f::F, x::R, v::V) where {F, R, V}
T = typeof(ForwardDiff.Tag(f, R))
out = f(ForwardDiff.Dual{T}(x, v))
Expand Down
14 changes: 10 additions & 4 deletions src/internal/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,11 @@
elseif SciMLBase.has_vjp(f)
f.vjp
elseif u isa Number # Ignore vjp directives
    # For a scalar problem the Jacobian is 1x1, so the VJP coincides with the
    # JVP. Use ForwardDiff only when the user did not request another backend.
    use_dual = ForwardDiff.can_dual(typeof(u)) && (vjp_autodiff === nothing ||
               vjp_autodiff isa AutoForwardDiff ||
               vjp_autodiff isa AutoPolyesterForwardDiff)
    if use_dual
        @closure (v, u, p) -> last(__scalar_jacvec(uf, u, v))
    else
        @closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v
    end
Expand Down Expand Up @@ -92,8 +95,11 @@
elseif SciMLBase.has_jvp(f)
f.jvp
elseif u isa Number # Ignore jvp directives
    # Use ForwardDiff only when the user did not explicitly override the
    # jvp backend with something other than the ForwardDiff family.
    use_dual = ForwardDiff.can_dual(typeof(u)) && (jvp_autodiff === nothing ||
               jvp_autodiff isa AutoForwardDiff ||
               jvp_autodiff isa AutoPolyesterForwardDiff)
    if use_dual
        @closure (v, u, p) -> last(__scalar_jacvec(uf, u, v))
    else
        @closure (v, u, p) -> FiniteDiff.finite_difference_derivative(uf, u) * v
    end
Expand Down
21 changes: 15 additions & 6 deletions test/misc/polyalg_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,22 @@ end
# Uses the `__solve` function
@test_throws MethodError solve(probN; abstol = 1e-9)
@test_throws MethodError solve(probN, RobustMultiNewton(); abstol = 1e-9)
sol = solve(probN, RobustMultiNewton(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
@test SciMLBase.successful_retcode(sol)
sol = solve(
    probN, FastShortcutNonlinearPolyalg(; autodiff = AutoFiniteDiff()); abstol = 1e-9)
@test SciMLBase.successful_retcode(sol)
sol = solve(probN, custom_polyalg; abstol = 1e-9)
@test SciMLBase.successful_retcode(sol)

# Scalar problem: defaults still throw (no AD chosen), but explicitly
# requesting finite differences must now work for Number-typed `u`.
quadratic_f(u::Float64, p) = u^2 - p

prob = NonlinearProblem(quadratic_f, 2.0, 4.0)

@test_throws MethodError solve(prob)
@test_throws MethodError solve(prob, RobustMultiNewton())
sol = solve(prob, RobustMultiNewton(; autodiff = AutoFiniteDiff()))
@test SciMLBase.successful_retcode(sol)
end

@testitem "Simple Scalar Problem #187" begin
Expand Down
Loading