From 77cade8edfba21ac69d81b108115703d7fb0748c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 26 Dec 2023 18:07:56 -0500 Subject: [PATCH 1/2] Add ForwardDiff Inplace Overloads --- lib/SimpleNonlinearSolve/Project.toml | 2 +- lib/SimpleNonlinearSolve/src/ad.jl | 95 ++++++++++-------- .../src/nlsolve/halley.jl | 9 +- lib/SimpleNonlinearSolve/src/utils.jl | 3 + lib/SimpleNonlinearSolve/test/basictests.jl | 98 ------------------- lib/SimpleNonlinearSolve/test/forward_ad.jl | 93 ++++++++++++++++++ lib/SimpleNonlinearSolve/test/runtests.jl | 3 +- 7 files changed, 160 insertions(+), 143 deletions(-) create mode 100644 lib/SimpleNonlinearSolve/test/forward_ad.jl diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index 124e21010..7a7ac60d2 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -1,7 +1,7 @@ name = "SimpleNonlinearSolve" uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7" authors = ["SciML"] -version = "1.1.0" +version = "1.2.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/lib/SimpleNonlinearSolve/src/ad.jl b/lib/SimpleNonlinearSolve/src/ad.jl index d4cbcf744..f6f5f5895 100644 --- a/lib/SimpleNonlinearSolve/src/ad.jl +++ b/lib/SimpleNonlinearSolve/src/ad.jl @@ -1,21 +1,24 @@ -function scalar_nlsolve_ad(prob, alg, args...; kwargs...) - f = prob.f +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, + iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P, iip} + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) + return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, + sol.original) +end + +function __nlsolve_ad(prob::NonlinearProblem{uType, iip}, alg, args...; + kwargs...) where {uType, iip} p = value(prob.p) - if prob isa IntervalNonlinearProblem - tspan = value.(prob.tspan) - newprob = IntervalNonlinearProblem(f, tspan, p; prob.kwargs...) - else - u0 = value(prob.u0) - newprob = NonlinearProblem(f, u0, p; prob.kwargs...) - end + newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...) sol = solve(newprob, alg, args...; kwargs...) uu = sol.u - f_p = scalar_nlsolve_∂f_∂p(f, uu, p) - f_x = scalar_nlsolve_∂f_∂u(f, uu, p) + f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, p) + f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, p) - z_arr = -inv(f_x) * f_p + z_arr = -f_x \ f_p pp = prob.p sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z) @@ -30,49 +33,57 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...) return sol, partials end -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray}, - false, <:Dual{T, V, P}}, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; - kwargs...) where {T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p) - return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode) -end - -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector, <:AbstractArray}, - false, <:AbstractArray{<:Dual{T, V, P}}}, - alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) where {T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p) - return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode) -end - -function scalar_nlsolve_∂f_∂p(f, u, p) - ff = p isa Number ? ForwardDiff.derivative : - (u isa Number ? ForwardDiff.gradient : ForwardDiff.jacobian) - return ff(Base.Fix1(f, u), p) +@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F} + if isinplace(prob) + __f = p -> begin + du = similar(u, promote_type(eltype(u), eltype(p))) + f(du, u, p) + return du + end + else + __f = Base.Fix1(f, u) + end + if p isa Number + return __reshape(ForwardDiff.derivative(__f, p), :, 1) + elseif u isa Number + return __reshape(ForwardDiff.gradient(__f, p), 1, :) + else + return ForwardDiff.jacobian(__f, p) + end end -function scalar_nlsolve_∂f_∂u(f, u, p) - ff = u isa Number ? ForwardDiff.derivative : ForwardDiff.jacobian - return ff(Base.Fix2(f, p), u) +@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F} + if isinplace(prob) + du = similar(u) + __f = (du, u) -> f(du, u, p) + ForwardDiff.jacobian(__f, du, u) + else + __f = Base.Fix2(f, p) + if u isa Number + return ForwardDiff.derivative(__f, u) + else + return ForwardDiff.jacobian(__f, u) + end + end end -function scalar_nlsolve_dual_soln(u::Number, partials, +@inline function __nlsolve_dual_soln(u::Number, partials, ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} return Dual{T, V, P}(u, partials) end -function scalar_nlsolve_dual_soln(u::AbstractArray, partials, +@inline function __nlsolve_dual_soln(u::AbstractArray, partials, ::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P} - return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, partials)) + _partials = _restructure(u, partials) + return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials)) end # avoid ambiguities for Alg in [Bisection] @eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, <:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p) + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, left = Dual{T, V, P}(sol.left, partials), right = Dual{T, V, P}(sol.right, partials)) @@ -80,8 +91,8 @@ for Alg in [Bisection] @eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} - sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = scalar_nlsolve_dual_soln(sol.u, partials, prob.p) + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, left = Dual{T, V, P}(sol.left, partials), right = Dual{T, V, P}(sol.right, partials)) diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl b/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl index 50f7d38d0..491a340e3 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl @@ -55,7 +55,14 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; setindex_trait(x) === CannotSetindex() && (A = dfx) # Factorize Once and Reuse - dfx_fact = factorize(dfx) + dfx_fact = if dfx isa Number + dfx + else + fact = lu(dfx; check = false) + !issuccess(fact) && return build_solution(prob, alg, x, fx; + retcode = ReturnCode.Unstable) + fact + end aᵢ = dfx_fact \ _vec(fx) A_ = _vec(A) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 4fb620a53..b3018f5ac 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -381,3 +381,6 @@ end return AutoFiniteDiff() end end + +@inline __reshape(x::Number, args...) = x +@inline __reshape(x::AbstractArray, args...) = reshape(x, args...) diff --git a/lib/SimpleNonlinearSolve/test/basictests.jl b/lib/SimpleNonlinearSolve/test/basictests.jl index 7b4e7bbc8..e43a90761 100644 --- a/lib/SimpleNonlinearSolve/test/basictests.jl +++ b/lib/SimpleNonlinearSolve/test/basictests.jl @@ -64,36 +64,6 @@ const TERMINATION_CONDITIONS = [ autodiff = AutoForwardDiff())) == 0 end - @testset "[OOP] Immutable AD" begin - for p in [1.0, 100.0] - @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) - res_true = sqrt(p) - all(res.u .≈ res_true) - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) - end - end - - @testset "[OOP] Scalar AD" begin - for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) - res_true = sqrt(p) - res.u ≈ res_true - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, - p) ≈ 1 / (2 * sqrt(p)) - end - end - - t = (p) -> [sqrt(p[2] / p[1])] - p = [0.9, 50.0] - @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) - @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], - p) ≈ ForwardDiff.jacobian(t, p) - @testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) @@ -124,36 +94,6 @@ end autodiff = AutoForwardDiff())) == 0 end - @testset "[OOP] Immutable AD" begin - for p in [1.0, 100.0] - @test begin - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) - res_true = sqrt(p) - all(res.u .≈ res_true) - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p) ≈ 1 / (2 * sqrt(p)) - end - end - - @testset "[OOP] Scalar AD" begin - for p in 1.0:0.1:100.0 - @test begin - res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) - res_true = sqrt(p) - res.u ≈ res_true - end - @test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, - p) ≈ 1 / (2 * sqrt(p)) - end - end - - t = (p) -> [sqrt(p[2] / p[1])] - p = [0.9, 50.0] - @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) - @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], - p) ≈ ForwardDiff.jacobian(t, p) - @testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) @@ -195,44 +135,6 @@ end @test (@ballocated $(benchmark_nlsolve_oop)($quadratic_f, 1.0, 2.0)) == allocs end - @testset "[OOP] Immutable AD" begin - for p in [1.0, 100.0] - res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p) - - if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res) - @test_broken all(abs.(res) .≈ sqrt(p)) - @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p)) ≈ 1 / (2 * sqrt(p)) - else - @test all(abs.(res) .≈ sqrt(p)) - @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - @SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p))) - end - end - end - - @testset "[OOP] Scalar AD" begin - for p in 1.0:0.1:100.0 - res = benchmark_nlsolve_oop(quadratic_f, 1.0, p) - - if any(x -> isnan(x), res) - @test_broken abs(res.u) ≈ sqrt(p) - @test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - 1.0, p).u, p)) ≈ 1 / (2 * sqrt(p)) - else - @test abs(res.u) ≈ sqrt(p) - @test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, - 1.0, p).u, p)), 1 / (2 * sqrt(p))) - end - end - end - - t = (p) -> [sqrt(p[2] / p[1])] - p = [0.9, 50.0] - @test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u ≈ sqrt(p[2] / p[1]) - @test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], - p) ≈ ForwardDiff.jacobian(t, p) - @testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) diff --git a/lib/SimpleNonlinearSolve/test/forward_ad.jl b/lib/SimpleNonlinearSolve/test/forward_ad.jl new file mode 100644 index 000000000..717c222df --- /dev/null +++ b/lib/SimpleNonlinearSolve/test/forward_ad.jl @@ -0,0 +1,93 @@ +using ForwardDiff, SimpleNonlinearSolve, StaticArrays, Test, LinearAlgebra + +test_f!(du, u, p) = (@. du = u^2 - p) +test_f(u, p) = (@. u^2 - p) + +jacobian_f(::Number, p) = 1 / (2 * √p) +jacobian_f(::Number, p::Number) = 1 / (2 * √p) +jacobian_f(u, p::Number) = one.(u) .* (1 / (2 * √p)) +jacobian_f(u, p::AbstractArray) = diagm(vec(@. 1 / (2 * √p))) + +function solve_with(::Val{mode}, u, alg) where {mode} + f = if mode === :iip + solve_iip(p) = solve(NonlinearProblem(test_f!, u, p), alg).u + elseif mode === :oop + solve_oop(p) = solve(NonlinearProblem(test_f, u, p), alg).u + end + return f +end + +__compatible(::Any, ::Val{:oop}) = true +__compatible(::Number, ::Val{:iip}) = false +__compatible(::AbstractArray, ::Val{:iip}) = true +__compatible(::StaticArray, ::Val{:iip}) = false + +__compatible(::Any, ::Number) = true +__compatible(::Number, ::AbstractArray) = false +__compatible(u::AbstractArray, p::AbstractArray) = size(u) == size(p) + +__compatible(u::Number, ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) = true +function __compatible(u::AbstractArray, + ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) + true +end +function __compatible(u::StaticArray, + ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) + true +end + +function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, + ::Val{:iip}) + true +end +function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, + ::Val{:oop}) + true +end +__compatible(::SimpleHalley, ::Val{:iip}) = false + +@testset "ForwardDiff.jl Integration: $(alg)" for alg in (SimpleNewtonRaphson(), + SimpleTrustRegion(), SimpleHalley(), SimpleBroyden(), SimpleKlement(), SimpleDFSane()) + us = (2.0, @SVector[1.0, 1.0], [1.0, 1.0], ones(2, 2), @SArray ones(2, 2)) + + @testset "Scalar AD" begin + for p in 1.0:0.1:100.0, u0 in us, mode in (:iip, :oop) + __compatible(u0, alg) || continue + __compatible(u0, Val(mode)) || continue + __compatible(alg, Val(mode)) || continue + + sol = solve(NonlinearProblem(test_f, u0, p), alg) + if SciMLBase.successful_retcode(sol) + gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p)) + gs_true = abs.(jacobian_f(u0, p)) + if !(isapprox(gs, gs_true, atol = 1e-5)) + @show sol.retcode, sol.u + @error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true + else + @test abs.(gs)≈abs.(gs_true) atol=1e-5 + end + end + end + end + + @testset "Jacobian" begin + for u0 in us, p in ([2.0, 1.0], [2.0 1.0; 3.0 4.0]), mode in (:iip, :oop) + __compatible(u0, p) || continue + __compatible(u0, alg) || continue + __compatible(u0, Val(mode)) || continue + __compatible(alg, Val(mode)) || continue + + sol = solve(NonlinearProblem(test_f, u0, p), alg) + if SciMLBase.successful_retcode(sol) + gs = abs.(ForwardDiff.jacobian(solve_with(Val{mode}(), u0, alg), p)) + gs_true = abs.(jacobian_f(u0, p)) + if !(isapprox(gs, gs_true, atol = 1e-5)) + @show sol.retcode, sol.u + @error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_jacobian=gs true_jacobian=gs_true + else + @test abs.(gs)≈abs.(gs_true) atol=1e-5 + end + end + end + end +end diff --git a/lib/SimpleNonlinearSolve/test/runtests.jl b/lib/SimpleNonlinearSolve/test/runtests.jl index cc4cd70b3..6cb730bc7 100644 --- a/lib/SimpleNonlinearSolve/test/runtests.jl +++ b/lib/SimpleNonlinearSolve/test/runtests.jl @@ -4,7 +4,8 @@ const GROUP = get(ENV, "GROUP", "All") @time @testset "SimpleNonlinearSolve.jl" begin if GROUP == "All" || GROUP == "Core" - @time @safetestset "Basic Tests + Some AD" include("basictests.jl") + @time @safetestset "Basic Tests" include("basictests.jl") + @time @safetestset "Forward AD" include("forward_ad.jl") @time @safetestset "Matrix Resizing Tests" include("matrix_resizing_tests.jl") @time @safetestset "Least Squares Tests" include("least_squares.jl") @time @safetestset "23 Test Problems" include("23_test_problems.jl") From f66e91385dd2b8008b2d493c47d01673fc7b61fa Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 26 Dec 2023 19:01:44 -0500 Subject: [PATCH 2/2] Fix tests --- lib/SimpleNonlinearSolve/src/ad.jl | 47 +++++++++---------- .../src/nlsolve/halley.jl | 2 +- lib/SimpleNonlinearSolve/test/forward_ad.jl | 23 +++------ 3 files changed, 30 insertions(+), 42 deletions(-) diff --git a/lib/SimpleNonlinearSolve/src/ad.jl b/lib/SimpleNonlinearSolve/src/ad.jl index f6f5f5895..574904bcc 100644 --- a/lib/SimpleNonlinearSolve/src/ad.jl +++ b/lib/SimpleNonlinearSolve/src/ad.jl @@ -7,10 +7,30 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray} sol.original) end -function __nlsolve_ad(prob::NonlinearProblem{uType, iip}, alg, args...; - kwargs...) where {uType, iip} +# Handle Ambiguities +for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder) + @eval begin + function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, + alg::$(algType), args...; kwargs...) where {uType, T, V, P, iip} + sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) + dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) + return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, + sol.stats, sol.original, left = Dual{T, V, P}(sol.left, partials), + right = Dual{T, V, P}(sol.right, partials)) + end + end +end + +function __nlsolve_ad(prob, alg, args...; kwargs...) p = value(prob.p) - newprob = NonlinearProblem(prob.f, value(prob.u0), p; prob.kwargs...) + if prob isa IntervalNonlinearProblem + tspan = value.(prob.tspan) + newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...) + else + u0 = value(prob.u0) + newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...) + end sol = solve(newprob, alg, args...; kwargs...) @@ -77,24 +97,3 @@ end _partials = _restructure(u, partials) return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, _partials)) end - -# avoid ambiguities -for Alg in [Bisection] - @eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:Dual{T, V, P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P} - sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) - return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, - left = Dual{T, V, P}(sol.left, partials), - right = Dual{T, V, P}(sol.right, partials)) - end - @eval function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:AbstractArray{<:Dual{T, V, P}}}, alg::$Alg, args...; - kwargs...) where {uType, iip, T, V, P} - sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) - return SciMLBase.build_solution(prob, alg, dual_soln, sol.resid; sol.retcode, - left = Dual{T, V, P}(sol.left, partials), - right = Dual{T, V, P}(sol.right, partials)) - end -end diff --git a/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl b/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl index 491a340e3..44877e097 100644 --- a/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl +++ b/lib/SimpleNonlinearSolve/src/nlsolve/halley.jl @@ -71,7 +71,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...; @bb Aaᵢ = A × aᵢ @bb A .*= -1 - bᵢ = dfx_fact \ Aaᵢ + bᵢ = dfx_fact \ _vec(Aaᵢ) cᵢ_ = _vec(cᵢ) @bb @. cᵢ_ = (aᵢ * aᵢ) / (-aᵢ + (T(0.5) * bᵢ)) diff --git a/lib/SimpleNonlinearSolve/test/forward_ad.jl b/lib/SimpleNonlinearSolve/test/forward_ad.jl index 717c222df..f545ccb0c 100644 --- a/lib/SimpleNonlinearSolve/test/forward_ad.jl +++ b/lib/SimpleNonlinearSolve/test/forward_ad.jl @@ -1,4 +1,5 @@ using ForwardDiff, SimpleNonlinearSolve, StaticArrays, Test, LinearAlgebra +import SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm test_f!(du, u, p) = (@. du = u^2 - p) test_f(u, p) = (@. u^2 - p) @@ -26,24 +27,12 @@ __compatible(::Any, ::Number) = true __compatible(::Number, ::AbstractArray) = false __compatible(u::AbstractArray, p::AbstractArray) = size(u) == size(p) -__compatible(u::Number, ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) = true -function __compatible(u::AbstractArray, - ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) - true -end -function __compatible(u::StaticArray, - ::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm) - true -end +__compatible(u::Number, ::AbstractSimpleNonlinearSolveAlgorithm) = true +__compatible(u::AbstractArray, ::AbstractSimpleNonlinearSolveAlgorithm) = true +__compatible(u::StaticArray, ::AbstractSimpleNonlinearSolveAlgorithm) = true -function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, - ::Val{:iip}) - true -end -function __compatible(::SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, - ::Val{:oop}) - true -end +__compatible(::AbstractSimpleNonlinearSolveAlgorithm, ::Val{:iip}) = true +__compatible(::AbstractSimpleNonlinearSolveAlgorithm, ::Val{:oop}) = true __compatible(::SimpleHalley, ::Val{:iip}) = false @testset "ForwardDiff.jl Integration: $(alg)" for alg in (SimpleNewtonRaphson(),