diff --git a/Project.toml b/Project.toml
index 6753ed3ee..6488517f2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -20,6 +20,7 @@ Reexport = "0.2"
 Setfield = "0.7"
 StaticArrays = "0.11, 0.12"
 UnPack = "0.1, 1.0"
+julia = "1"
 
 [extras]
 BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl
index dbb6e0cb4..d7a894d5c 100644
--- a/src/NonlinearSolve.jl
+++ b/src/NonlinearSolve.jl
@@ -3,6 +3,7 @@ module NonlinearSolve
   using Reexport
   using UnPack: @unpack
   using FiniteDiff, ForwardDiff
+  using ForwardDiff: Dual
   using Setfield
   using StaticArrays
   using RecursiveArrayTools
diff --git a/src/scalar.jl b/src/scalar.jl
index 343543afd..ccdc61ab1 100644
--- a/src/scalar.jl
+++ b/src/scalar.jl
@@ -19,6 +19,64 @@ function solve(prob::NonlinearProblem{<:Number}, ::NewtonRaphson, args...; xatol
   return NewtonSolution(x, MAXITERS_EXCEED)
 end
 
+"""
+    scalar_nlsolve_ad(prob, alg, args...; kwargs...)
+
+Solve `prob` with `alg` after stripping `ForwardDiff.Dual` wrappers from `prob.u0` and
+`prob.p`, then compute the partials of the root with respect to `prob.p` from the
+implicit function theorem: for `f(u, p) = 0`, `du/dp = -f_p / f_x`.
+
+Return `(sol, partials)`: the solution of the Dual-free problem and the partials to
+attach to its root.
+"""
+function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
+  f = prob.f
+  p = value(prob.p)
+  u0 = value(prob.u0)
+
+  newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
+  sol = solve(newprob, alg, args...; kwargs...)
+
+  uu = getsolution(sol)
+  # Derivatives of `f` at the root: with respect to the parameter(s), then to `u`.
+  if p isa Number
+    f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
+  else
+    f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
+  end
+
+  f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
+  pp = prob.p
+  # Each parameter contributes `(-f_p / f_x)` times its own incoming partials.
+  sumfun = let f_x′ = -f_x
+    ((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
+  end
+  partials = sum(sumfun, zip(f_p, pp))
+  return sol, partials
+end
+
+# Dual-valued parameters: solve on plain values, then re-attach the partials.
+function solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
+  sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
+  return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
+end
+function solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
+  sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
+  return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
+end
+
+# avoid ambiguities
+for Alg in [Bisection, Falsi]
+  @eval function solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
+    sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
+    return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
+  end
+  @eval function solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
+    sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
+    return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
+  end
+end
+
 function solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
   f = Base.Fix2(prob.f, prob.p)
   left, right = prob.u0
diff --git a/src/types.jl b/src/types.jl
index c04012944..9f1b03ee9 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -78,3 +78,7 @@ function sync_residuals!(solver::BracketingImmutableSolver)
   @set! solver.fr = solver.f(solver.right, solver.p)
   solver
 end
+
+# Root estimate stored in a solution; bracketing solutions report their left bound.
+getsolution(sol::NewtonSolution) = sol.u
+getsolution(sol::BracketingSolution) = sol.left
diff --git a/src/utils.jl b/src/utils.jl
index c13635fd5..061ee3838 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -216,7 +216,7 @@ Move `x` one floating point towards x0.
 function prevfloat_tdir(x::T, x0::T, x1::T)::T where {T}
   x1 > x0 ? prevfloat(x) : nextfloat(x)
 end
-  
+
 function nextfloat_tdir(x::T, x0::T, x1::T)::T where {T}
   x1 > x0 ? nextfloat(x) : prevfloat(x)
 end
@@ -234,3 +234,8 @@ function value_derivative(f::F, x::R) where {F,R}
   out = f(ForwardDiff.Dual{T}(x, one(x)))
   ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
 end
+
+# Strip `Dual` wrappers (elementwise for arrays); other inputs are returned unchanged.
+value(x) = x
+value(x::Dual) = ForwardDiff.value(x)
+value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
diff --git a/test/runtests.jl b/test/runtests.jl
index fc53f0904..80699284d 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -57,18 +57,41 @@ end
 
 f, u0 = (u, p) -> u * u - p, 1.0
 g = function (p)
-    probN = NonlinearProblem{false}(f, u0, p)
+    probN = NonlinearProblem{false}(f, oftype(p, u0), p)
     sol = solve(probN, NewtonRaphson())
     return sol.u
 end
 
-@test_broken ForwardDiff.derivative(g, 1.0) ≈ 0.5
+@test ForwardDiff.derivative(g, 1.0) ≈ 0.5
 
 for p in 1.1:0.1:100.0
     @test g(p) ≈ sqrt(p)
     @test ForwardDiff.derivative(g, p) ≈ 1/(2*sqrt(p))
 end
 
+f, u0 = (u, p) -> p[1] * u * u - p[2], (1.0, 100.0)
+t = (p) -> [sqrt(p[2] / p[1])]
+p = [0.9, 50.0]
+for alg in [Bisection(), Falsi()]
+    global g, p
+    g = function (p)
+        probN = NonlinearProblem{false}(f, u0, p)
+        sol = solve(probN, alg)
+        return [sol.left]
+    end
+
+    @test g(p) ≈ [sqrt(p[2] / p[1])]
+    @test ForwardDiff.jacobian(g, p) ≈ ForwardDiff.jacobian(t, p)
+end
+
+gnewton = function (p)
+    probN = NonlinearProblem{false}(f, 0.5, p)
+    sol = solve(probN, NewtonRaphson())
+    return [sol.u]
+end
+@test gnewton(p) ≈ [sqrt(p[2] / p[1])]
+@test ForwardDiff.jacobian(gnewton, p) ≈ ForwardDiff.jacobian(t, p)
+
 # Error Checks
 
 f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]