Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: hessian #489

Merged
merged 2 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ NLSolvers = "0.5"
NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
NonlinearSolveBase = "1.2"
NonlinearSolveFirstOrder = "1"
NonlinearSolveQuasiNewton = "1"
NonlinearSolveSpectralMethods = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolveBase"
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.0"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
100 changes: 9 additions & 91 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ using CommonSolve: solve
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: mul!
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

Expand All @@ -20,11 +19,14 @@ function NonlinearSolveBase.additional_incompatible_backend_check(
end

Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
Utils.value(x::Dual) = ForwardDiff.value(x)
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
prob::Union{
IntervalNonlinearProblem, NonlinearProblem,
ImmutableNonlinearProblem, NonlinearLeastSquaresProblem
},
alg, args...; kwargs...
)
p = Utils.value(prob.p)
Expand All @@ -35,98 +37,14 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
end

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)

if uu isa Number
partials = sum(sumfun, zip(z, pp))
elseif p isa Number
partials = sumfun((z, pp))
else
partials = sum(sumfun, zip(eachcol(z), pp))
end

return sol, partials
end

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...
)
p = Utils.value(prob.p)
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
sol = solve(newprob, alg, args...; kwargs...)
uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
J = Utils.safe_similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
# For small problems, nesting ForwardDiff is actually quite fast
autodiff = length(uu) + length(sol.resid) ≥ 50 ?
NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) :
AutoForwardDiff()

if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
# Using `Constant` lead to dual ordering issues
ff = @closure (du, u) -> prob.f(du, u, p)
resid2 = copy(resid)
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
@. du *= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
v = prob.f(u, p)
# Using `Constant` lead to dual ordering issues
ff = Base.Fix2(prob.f, p)
res = only(DI.pullback(ff, autodiff, u, (v,)))
ArrayInterface.can_setindex(res) || return 2 .* res
@. res *= 2
return res
end
end
end
fn = prob isa NonlinearLeastSquaresProblem ?
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
Expand Down
4 changes: 2 additions & 2 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using ConcreteStructs: @concrete
using FastClosures: @closure
using Preferences: @load_preference, @set_preferences!

using ADTypes: ADTypes, AbstractADType, AutoSparse, NoSparsityDetector,
using ADTypes: ADTypes, AbstractADType, AutoSparse, AutoForwardDiff, NoSparsityDetector,
KnownJacobianSparsityDetector
using Adapt: WrappedArray
using ArrayInterface: ArrayInterface
Expand All @@ -25,7 +25,7 @@ using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using SymbolicIndexingInterface: SymbolicIndexingInterface

using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind
using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul!
using Markdown: @doc_str
using Printf: @printf

Expand Down
62 changes: 62 additions & 0 deletions lib/NonlinearSolveBase/src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,65 @@ end
is_finite_differences_backend(ad::AbstractADType) = false
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true

function nlls_generate_vjp_function(prob::NonlinearLeastSquaresProblem, sol, uu)
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f.vjp(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
return @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
J = Utils.safe_similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
return @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
# For small problems, nesting ForwardDiff is actually quite fast
autodiff = length(uu) + length(sol.resid) ≥ 50 ?
select_reverse_mode_autodiff(prob, nothing) : AutoForwardDiff()

if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
# Using `Constant` lead to dual ordering issues
ff = @closure (du, u) -> prob.f(du, u, p)
resid2 = copy(resid)
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
@. du *= 2
return nothing
end
else
return @closure (u, p) -> begin
v = prob.f(u, p)
# Using `Constant` lead to dual ordering issues
res = only(DI.pullback(Base.Fix2(prob.f, p), autodiff, u, (v,)))
ArrayInterface.can_setindex(res) || return 2 .* res
@. res *= 2
return res
end
end
end
end
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ function nonlinearsolve_forwarddiff_solve end
function nonlinearsolve_dual_solution end
function nonlinearsolve_∂f_∂p end
function nonlinearsolve_∂f_∂u end
function nlls_generate_vjp_function end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
Expand Down
12 changes: 7 additions & 5 deletions src/forward_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ function InternalAPI.reinit!(
end

for algType in ALL_SOLVER_TYPES
# XXX: Extend to DualNonlinearLeastSquaresProblem
@eval function SciMLBase.__init(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
Expand All @@ -64,10 +63,13 @@ end
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
sol = solve!(cache.cache)
prob = cache.prob

uu = sol.u
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)

fn = prob isa NonlinearLeastSquaresProblem ?
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)

z_arr = -Jᵤ \ Jₚ

Expand Down
94 changes: 94 additions & 0 deletions test/forward_ad_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,97 @@ end
end
end
end

@testitem "NLLS Hessian SciML/NonlinearSolve.jl#445" tags=[:core] begin
using ForwardDiff, FiniteDiff

function objfn(F, init, params)
th1, th2 = init
px, py, l1, l2 = params
F[1] = l1 * cos(th1) + l2 * cos(th1 + th2) - px
F[2] = l1 * sin(th1) + l2 * sin(th1 + th2) - py
return F
end

function solve_nlprob(pxpy)
px, py = pxpy
theta1 = pi / 4
theta2 = pi / 4
initial_guess = [theta1; theta2]
l1 = 60
l2 = 60
p = [px; py; l1; l2]
prob = NonlinearLeastSquaresProblem(
NonlinearFunction(objfn, resid_prototype = zeros(2)),
initial_guess, p
)
resu = solve(
prob,
reltol = 1e-12, abstol = 1e-12
)
th1, th2 = resu.u
cable1_base = [-90; 0; 0]
cable2_base = [-150; 0; 0]
cable3_base = [150; 0; 0]
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
(cable1_top[2] - cable1_base[2])^2)
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
(cable23_top[2] - cable2_base[2])^2)
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
(cable23_top[2] - cable3_base[2])^2)
return c1_length + c2_length + c3_length
end

grad1 = ForwardDiff.gradient(solve_nlprob, [34.0, 87.0])
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob, [34.0, 87.0])

@test grad1≈grad2 atol=1e-3

hess1 = ForwardDiff.hessian(solve_nlprob, [34.0, 87.0])
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob, [34.0, 87.0])

@test hess1≈hess2 atol=1e-3

function solve_nlprob_with_cache(pxpy)
px, py = pxpy
theta1 = pi / 4
theta2 = pi / 4
initial_guess = [theta1; theta2]
l1 = 60
l2 = 60
p = [px; py; l1; l2]
prob = NonlinearLeastSquaresProblem(
NonlinearFunction(objfn, resid_prototype = zeros(2)),
initial_guess, p
)
cache = init(prob; reltol = 1e-12, abstol = 1e-12)
resu = solve!(cache)
th1, th2 = resu.u
cable1_base = [-90; 0; 0]
cable2_base = [-150; 0; 0]
cable3_base = [150; 0; 0]
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
(cable1_top[2] - cable1_base[2])^2)
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
(cable23_top[2] - cable2_base[2])^2)
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
(cable23_top[2] - cable3_base[2])^2)
return c1_length + c2_length + c3_length
end

grad1 = ForwardDiff.gradient(solve_nlprob_with_cache, [34.0, 87.0])
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob_with_cache, [34.0, 87.0])

@test grad1≈grad2 atol=1e-3

hess1 = ForwardDiff.hessian(solve_nlprob_with_cache, [34.0, 87.0])
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob_with_cache, [34.0, 87.0])

@test hess1≈hess2 atol=1e-3
end
Loading