Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Towards cleaner and more maintainable internals of NonlinearSolve.jl #203

Merged
merged 19 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions src/ad.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,61 @@
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
f = prob.f
p = value(prob.p)

u0 = value(prob.u0)
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
if p isa Number
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
else
f_p = ForwardDiff.gradient(Base.Fix1(f, uu), p)
end
f_p = scalar_nlsolve_∂f_∂p(f, uu, p)
f_x = scalar_nlsolve_∂f_∂u(f, uu, p)

Check warning on line 11 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L10-L11

Added lines #L10 - L11 were not covered by tests

z_arr = -inv(f_x) * f_p

Check warning on line 13 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L13

Added line #L13 was not covered by tests

f_x = ForwardDiff.derivative(Base.Fix2(f, p), uu)
pp = prob.p
sumfun = let f_x′ = -f_x
((fp, p),) -> (fp / f_x′) * ForwardDiff.partials(p)
sumfun = ((z, p),) -> [zᵢ * ForwardDiff.partials(p) for zᵢ in z]
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
if uu isa Number
partials = sum(sumfun, zip(z_arr, pp))

Check warning on line 18 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
else
partials = sum(sumfun, zip(eachcol(z_arr), pp))

Check warning on line 20 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L20

Added line #L20 was not covered by tests
end
partials = sum(sumfun, zip(f_p, pp))

return sol, partials
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
<:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},

Check warning on line 26 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L26

Added line #L26 was not covered by tests
iip, <:Dual{T, V, P}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)

Check warning on line 31 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L30-L31

Added lines #L30 - L31 were not covered by tests
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}, iip,
<:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray},

Check warning on line 34 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L34

Added line #L34 was not covered by tests
iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::AbstractNewtonAlgorithm, args...;
kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
sol.retcode)
dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p)
return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode)

Check warning on line 39 in src/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/ad.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
end

"""
    scalar_nlsolve_∂f_∂p(f, u, p)

Differentiate `f(u, p)` with respect to `p` while holding `u` fixed.

The ForwardDiff entry point is picked from the runtime kinds of `p` and `u`:

  - `p isa Number`: `ForwardDiff.derivative`
  - `p` not a `Number`, `u isa Number`: `ForwardDiff.gradient`
  - neither is a `Number`: `ForwardDiff.jacobian`

Returns whatever the selected ForwardDiff function returns.
"""
function scalar_nlsolve_∂f_∂p(f, u, p)
    # Partially apply `u` so the closure is a function of `p` only.
    f_of_p = Base.Fix1(f, u)
    if p isa Number
        return ForwardDiff.derivative(f_of_p, p)
    elseif u isa Number
        return ForwardDiff.gradient(f_of_p, p)
    else
        return ForwardDiff.jacobian(f_of_p, p)
    end
end

"""
    scalar_nlsolve_∂f_∂u(f, u, p)

Differentiate `f(u, p)` with respect to `u` while holding `p` fixed.

Uses `ForwardDiff.derivative` when `u isa Number` and `ForwardDiff.jacobian`
otherwise, and returns that function's result unchanged.
"""
function scalar_nlsolve_∂f_∂u(f, u, p)
    # Partially apply `p` so the closure is a function of `u` only.
    f_of_u = Base.Fix2(f, p)
    u isa Number && return ForwardDiff.derivative(f_of_u, u)
    return ForwardDiff.jacobian(f_of_u, u)
end

"""
    scalar_nlsolve_dual_soln(u::Number, partials, p)

Wrap the scalar solution `u` in a `Dual{T, V, P}`, using `partials[1]` as its
partials.

The third argument is unnamed and its value is never read; it is accepted only
so that `T`, `V` and `P` can be taken from its type (a `Dual{T, V, P}` or an
array of such duals).
"""
function scalar_nlsolve_dual_soln(u::Number, partials,
        ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
    DualT = Dual{T, V, P}
    # Only the first entry of `partials` is used for a scalar solution.
    u_partials = partials[1]
    return DualT(u, u_partials)
end

"""
    scalar_nlsolve_dual_soln(u::AbstractArray, partials, p)

Pair each element of `u` with the corresponding element of `partials` (iterated
in lockstep via `zip`) and wrap every pair in a `Dual{T, V, P}`.

The third argument is unnamed and its value is never read; it is accepted only
so that `T`, `V` and `P` can be taken from its type (a `Dual{T, V, P}` or an
array of such duals).
"""
function scalar_nlsolve_dual_soln(u::AbstractArray, partials,
        ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
    return [Dual{T, V, P}(uᵢ, pᵢ) for (uᵢ, pᵢ) in zip(u, partials)]
end
30 changes: 9 additions & 21 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,13 @@ end
@test (@ballocated solve!($cache)) ≤ 64
end

# FIXME: Even the previous tests were broken, but due to a typo in the tests they
# accidentally passed
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
end

Expand Down Expand Up @@ -101,11 +99,9 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) ≈ sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) ≈ sqrt.(p)

probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
@testset "ADType: $(autodiff) u0: $(u0)" for autodiff in (false, true,
@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(),
AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(probN, NewtonRaphson(; autodiff)).u .≈ sqrt(2.0))
end
Expand Down Expand Up @@ -149,8 +145,6 @@ end
@test (@ballocated solve!($cache)) ≤ 64
end

# FIXME: Even the previous tests were broken, but due to a typo in the tests they
# accidentally passed
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
p in 1.0:0.1:100.0

Expand All @@ -160,7 +154,7 @@ end
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) ≈ 1 / (2 * sqrt(p))
end

Expand Down Expand Up @@ -204,11 +198,9 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) ≈ sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) ≈ sqrt.(p)

probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
@testset "ADType: $(autodiff) u0: $(u0) radius_update_scheme: $(radius_update_scheme)" for autodiff in (false,
@testset "ADType: $(autodiff) u0: $(_nameof(u0)) radius_update_scheme: $(radius_update_scheme)" for autodiff in (false,
true, AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(), AutoSparseEnzyme()),
u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]),
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0]),
radius_update_scheme in radius_update_schemes

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -302,15 +294,13 @@ end
@test (@ballocated solve!($cache)) ≤ 64
end

# FIXME: Even the previous tests were broken, but due to a typo in the tests they
# accidentally passed
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test_broken ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p))
end

Expand All @@ -330,11 +320,9 @@ end
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], p) ≈
ForwardDiff.jacobian(t, p)

probN = NonlinearProblem(quadratic_f, @SVector[1.0, 1.0], 2.0)
@testset "ADType: $(autodiff) u0: $(u0)" for autodiff in (false, true,
@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(),
AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0])
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(probN, LevenbergMarquardt(; autodiff)).u .≈ sqrt(2.0))
end
Expand Down
Loading