diff --git a/src/algorithms/trust_region.jl b/src/algorithms/trust_region.jl index 89c4d8f5d..95e04cef3 100644 --- a/src/algorithms/trust_region.jl +++ b/src/algorithms/trust_region.jl @@ -24,9 +24,16 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000, shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1, - max_shrink_times::Int = 32, vjp_autodiff = nothing, autodiff = nothing) + max_shrink_times::Int = 32, autodiff = nothing, vjp_autodiff = nothing) descent = Dogleg(; linsolve, precs) - forward_ad = autodiff isa ADTypes.AbstractForwardMode ? autodiff : nothing + if autodiff isa Union{ADTypes.AbstractForwardMode, ADTypes.AbstractFiniteDifferencesMode} + forward_ad = autodiff + else + forward_ad = nothing + end + if isnothing(vjp_autodiff) && autodiff isa ADTypes.AbstractFiniteDifferencesMode + vjp_autodiff = autodiff + end trustregion = GenericTrustRegionScheme(; method = radius_update_scheme, step_threshold, shrink_threshold, expand_threshold, shrink_factor, expand_factor, reverse_ad = vjp_autodiff, forward_ad) diff --git a/test/misc/polyalgs.jl b/test/misc/polyalgs.jl index e36c066fc..047ba0933 100644 --- a/test/misc/polyalgs.jl +++ b/test/misc/polyalgs.jl @@ -93,11 +93,14 @@ end maxiters = 10) end +no_ad_fast = FastShortcutNonlinearPolyalg(autodiff=AutoFiniteDiff()) +no_ad_robust = RobustMultiNewton(autodiff=AutoFiniteDiff()) +no_ad_algs = Set([no_ad_fast, no_ad_robust, no_ad_fast.algs..., no_ad_robust.algs...]) @testset "[IIP] no AD" begin f_iip = Base.Experimental.@opaque (du, u, p) -> du .= u .* u .- p - u0 = [0.0] + u0 = [0.5] prob = NonlinearProblem(f_iip, u0, 1.0) - for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff())] + for alg in no_ad_algs sol = solve(prob, alg) @test isapprox(only(sol.u), 1.0) @test SciMLBase.successful_retcode(sol.retcode) @@ -106,9 +109,9 @@ end @testset "[OOP] no AD" begin f_oop = Base.Experimental.@opaque (u, p) -> u .* u .- p - u0 = [0.0] + u0 = [0.5] prob = NonlinearProblem{false}(f_oop, u0, 1.0) - for alg in [RobustMultiNewton(autodiff = AutoFiniteDiff())] + for alg in no_ad_algs sol = solve(prob, alg) @test isapprox(only(sol.u), 1.0) @test SciMLBase.successful_retcode(sol.retcode)