Skip to content

Commit

Permalink
Merge pull request #125 from CCsimon123/master
Browse files Browse the repository at this point in the history
Bug fix for TrustRegion when iip=true.
  • Loading branch information
ChrisRackauckas authored Jan 19, 2023
2 parents 9049a3c + 69350aa commit c737bd2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
20 changes: 14 additions & 6 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
max_trust_radius, loss, loss, H, fu, 0, u, u_tmp, fu, true,
max_trust_radius, loss, loss, H, zero(fu), 0, zero(u),
u_tmp, zero(fu), true,
loss)
end

Expand Down Expand Up @@ -307,10 +308,7 @@ function trust_region_step!(cache::TrustRegionCache)
cache.shrink_counter = 0
end
if r > alg.step_threshold

# Take the step.
cache.u = u_tmp
cache.fu = fu_new
take_step!(cache)
cache.loss = cache.loss_new

# Update the trust region radius.
Expand All @@ -324,7 +322,7 @@ function trust_region_step!(cache::TrustRegionCache)
cache.make_new_J = false
end

if iszero(cache.fu) || cache.internalnorm(cache.step_size) < cache.abstol
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end
end
Expand Down Expand Up @@ -356,6 +354,16 @@ function dogleg!(cache::TrustRegionCache)
cache.step_size = δsd + τ * N_sd
end

function take_step!(cache::TrustRegionCache{true})
cache.u .= cache.u_tmp
cache.fu .= cache.fu_new
end

function take_step!(cache::TrustRegionCache{false})
cache.u = cache.u_tmp
cache.fu = cache.fu_new
end

function SciMLBase.solve!(cache::TrustRegionCache)
while !cache.force_stop && cache.iter < cache.maxiters &&
cache.shrink_counter < cache.alg.max_shrink_times
Expand Down
24 changes: 12 additions & 12 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ u0 = [1.0, 1.0]

sol = benchmark_immutable(ff, cu0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_mutable(ff, u0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
@test abs(sol.u * sol.u - 2) < 1e-9

# @test (@ballocated benchmark_immutable(ff, cu0)) < 200
# @test (@ballocated benchmark_mutable(ff, cu0)) < 200
Expand All @@ -59,7 +59,7 @@ u0 = [1.0, 1.0]

sol = benchmark_inplace(ffiip, u0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(ffiip, u0)
Expand Down Expand Up @@ -160,13 +160,13 @@ u0 = [1.0, 1.0]

sol = benchmark_immutable(ff, cu0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_mutable(ff, u0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
sol = benchmark_scalar(sf, csu0)
@test sol.retcode === ReturnCode.Success
@test sol.u * sol.u - 2 < 1e-9
@test abs(sol.u * sol.u - 2) < 1e-9

function benchmark_inplace(f, u0)
probN = NonlinearProblem{true}(f, u0)
Expand All @@ -181,7 +181,7 @@ u0 = [1.0, 1.0]

sol = benchmark_inplace(ffiip, u0)
@test sol.retcode === ReturnCode.Success
@test all(sol.u .* sol.u .- 2 .< 1e-9)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

u0 = [1.0, 1.0]
probN = NonlinearProblem{true}(ffiip, u0)
Expand Down Expand Up @@ -263,7 +263,7 @@ f = (u, p) -> 0.010000000000000002 .+
0.0011552453009332421u .- p
g = function (p)
probN = NonlinearProblem{false}(f, u0, p)
sol = solve(probN, TrustRegion())
sol = solve(probN, TrustRegion(), abstol = 1e-10)
return sol.u
end
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
Expand Down Expand Up @@ -295,7 +295,7 @@ for options in list_of_options
expand_factor = options[7],
max_shrink_times = options[8])

probN = NonlinearProblem(f, u0, p)
sol = solve(probN, alg)
@test all(f(u, p) .< 1e-10)
probN = NonlinearProblem{false}(f, u0, p)
sol = solve(probN, alg, abstol = 1e-10)
@test all(abs.(f(u, p)) .< 1e-10)
end

0 comments on commit c737bd2

Please sign in to comment.