Skip to content

Commit

Permalink
Fixing load time for TrustRegion
Browse files Browse the repository at this point in the history
  • Loading branch information
Deltadahl committed Jan 22, 2023
1 parent 6f92187 commit 518747a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
prob = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))

precompile_algs = if VERSION >= v"1.7"
(NewtonRaphson(),)
(NewtonRaphson(), TrustRegion(), LevenbergMarquardt())
else
(NewtonRaphson(),)
end
Expand Down
18 changes: 13 additions & 5 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,19 +232,27 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
max_trust_radius = alg.max_trust_radius
initial_trust_radius = alg.initial_trust_radius
if max_trust_radius == 0.0
max_trust_radius = max(norm(fu), maximum(u) - minimum(u))
max_trust_radius = convert(typeof(max_trust_radius),
max(norm(fu), maximum(u) - minimum(u)))
end
if initial_trust_radius == 0.0
initial_trust_radius = max_trust_radius / 11
end

loss_new = loss
H = ArrayInterfaceCore.undefmatrix(u)
g = zero(fu)
shrink_counter = 0
step_size = zero(u)
fu_new = zero(fu)
make_new_J = true
r = loss

return TrustRegionCache{iip}(f, alg, u, fu, p, uf, linsolve, J, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob, initial_trust_radius,
max_trust_radius, loss, loss, H, zero(fu), 0, zero(u),
u_tmp, zero(fu), true,
loss)
max_trust_radius, loss, loss_new, H, g, shrink_counter,
step_size, u_tmp, fu_new, make_new_J, r)
end

function perform_step!(cache::TrustRegionCache{true})
Expand Down Expand Up @@ -293,7 +301,7 @@ function perform_step!(cache::TrustRegionCache{false})
end

function trust_region_step!(cache::TrustRegionCache)
@unpack fu_new, u_tmp, step_size, g, H, loss, alg, max_trust_r = cache
@unpack fu_new, step_size, g, H, loss, alg, max_trust_r = cache
cache.loss_new = get_loss(fu_new)

# Compute the ratio of the actual reduction to the predicted reduction.
Expand Down

0 comments on commit 518747a

Please sign in to comment.