diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index e83de5e39..b4e641832 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -75,7 +75,8 @@ function CommonSolve.solve( end function CommonSolve.solve( - prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, + prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, + alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] @@ -86,7 +87,8 @@ function CommonSolve.solve( p === nothing, alg, args...; prob.kwargs..., kwargs...) end -function simplenonlinearsolve_solve_up(prob::ImmutableNonlinearProblem, sensealg, u0, +function simplenonlinearsolve_solve_up( + prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) (u0_changed || p_changed) && (prob = remake(prob; u0, p)) return SciMLBase.__solve(prob, alg, args...; kwargs...) diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl index a18a1b6be..ebbb5f9f9 100644 --- a/lib/SimpleNonlinearSolve/src/raphson.jl +++ b/lib/SimpleNonlinearSolve/src/raphson.jl @@ -43,7 +43,7 @@ function SciMLBase.__solve( @bb xo = similar(x) fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? - safe_similar(fx) : nothing + safe_similar(fx) : fx jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x) J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache) diff --git a/lib/SimpleNonlinearSolve/src/trust_region.jl b/lib/SimpleNonlinearSolve/src/trust_region.jl index 47acc5437..32e7a6219 100644 --- a/lib/SimpleNonlinearSolve/src/trust_region.jl +++ b/lib/SimpleNonlinearSolve/src/trust_region.jl @@ -94,7 +94,7 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi @bb xo = copy(x) fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? - safe_similar(fx) : nothing + safe_similar(fx) : fx jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x) J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 946c10529..fbc3d3c23 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -183,10 +183,10 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras) end if extras isa AnalyticJacobian if SciMLBase.isinplace(prob) - prob.jac(J, x, prob.p) + prob.f.jac(J, x, prob.p) return J else - return prob.jac(x, prob.p) + return prob.f.jac(x, prob.p) end end if SciMLBase.isinplace(prob) diff --git a/lib/SimpleNonlinearSolve/test/core/forward_diff_tests.jl b/lib/SimpleNonlinearSolve/test/core/forward_diff_tests.jl index 8b1378917..0005796f9 100644 --- a/lib/SimpleNonlinearSolve/test/core/forward_diff_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/forward_diff_tests.jl @@ -1 +1,115 @@ +@testitem "ForwardDiff.jl Integration NonlinearLeastSquaresProblem" tags=[:core] begin + using ForwardDiff, FiniteDiff, SimpleNonlinearSolve, StaticArrays, LinearAlgebra, + Zygote, ReverseDiff + using DifferentiationInterface + const DI = DifferentiationInterface + + true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]) + + θ_true = [1.0, 0.1, 2.0, 0.5] + x = [-1.0, -0.5, 0.0, 0.5, 1.0] + y_target = true_function(x, θ_true) + + loss_function(θ, p) = true_function(p, θ) .- y_target + + loss_function_jac(θ, p) = ForwardDiff.jacobian(Base.Fix2(loss_function, p), θ) + + loss_function_vjp(v, θ, p) = reshape(vec(v)' * loss_function_jac(θ, p), size(θ)) + + function loss_function!(resid, θ, p) + ŷ = true_function(p, θ) + @. resid = ŷ - y_target + return + end + + function loss_function_jac!(J, θ, p) + J .= ForwardDiff.jacobian(θ -> loss_function(θ, p), θ) + return + end + + function loss_function_vjp!(vJ, v, θ, p) + vec(vJ) .= reshape(vec(v)' * loss_function_jac(θ, p), size(θ)) + return + end + + θ_init = θ_true .+ 0.1 + + @testset for alg in ( + SimpleGaussNewton(), + SimpleGaussNewton(; autodiff = AutoForwardDiff()), + SimpleGaussNewton(; autodiff = AutoFiniteDiff()), + SimpleGaussNewton(; autodiff = AutoReverseDiff()) + ) + function obj_1(p) + prob_oop = NonlinearLeastSquaresProblem{false}(loss_function, θ_init, p) + sol = solve(prob_oop, alg) + return sum(abs2, sol.u) + end + + function obj_2(p) + ff = NonlinearFunction{false}( + loss_function; resid_prototype = zeros(length(y_target))) + prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p) + sol = solve(prob_oop, alg) + return sum(abs2, sol.u) + end + + function obj_3(p) + ff = NonlinearFunction{false}(loss_function; vjp = loss_function_vjp) + prob_oop = NonlinearLeastSquaresProblem{false}(ff, θ_init, p) + sol = solve(prob_oop, alg) + return sum(abs2, sol.u) + end + + finitediff = DI.gradient(obj_1, AutoFiniteDiff(), x) + + fdiff1 = DI.gradient(obj_1, AutoForwardDiff(), x) + fdiff2 = DI.gradient(obj_2, AutoForwardDiff(), x) + fdiff3 = DI.gradient(obj_3, AutoForwardDiff(), x) + + @test finitediff≈fdiff1 atol=1e-5 + @test finitediff≈fdiff2 atol=1e-5 + @test finitediff≈fdiff3 atol=1e-5 + @test fdiff1 ≈ fdiff2 ≈ fdiff3 + + function obj_4(p) + prob_iip = NonlinearLeastSquaresProblem( + NonlinearFunction{true}( + loss_function!; resid_prototype = zeros(length(y_target))), + θ_init, + p) + sol = solve(prob_iip, alg) + return sum(abs2, sol.u) + end + + function obj_5(p) + ff = NonlinearFunction{true}( + loss_function!; resid_prototype = zeros(length(y_target)), + jac = loss_function_jac!) + prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p) + sol = solve(prob_iip, alg) + return sum(abs2, sol.u) + end + + function obj_6(p) + ff = NonlinearFunction{true}( + loss_function!; resid_prototype = zeros(length(y_target)), + vjp = loss_function_vjp!) + prob_iip = NonlinearLeastSquaresProblem(ff, θ_init, p) + sol = solve(prob_iip, alg) + return sum(abs2, sol.u) + end + + finitediff = DI.gradient(obj_4, AutoFiniteDiff(), x) + + fdiff4 = DI.gradient(obj_4, AutoForwardDiff(), x) + fdiff5 = DI.gradient(obj_5, AutoForwardDiff(), x) + fdiff6 = DI.gradient(obj_6, AutoForwardDiff(), x) + + @test finitediff≈fdiff4 atol=1e-5 + @test finitediff≈fdiff5 atol=1e-5 + @test finitediff≈fdiff6 atol=1e-5 + @test fdiff4 ≈ fdiff5 ≈ fdiff6 + end +end