Skip to content

Commit

Permalink
Fix the tracing for NLLS
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 31, 2024
1 parent 63bd4a9 commit 94016b5
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 52 deletions.
2 changes: 1 addition & 1 deletion src/core/approximate_jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ function SciMLBase.__init(
update_rule_cache = __internal_init(
prob, alg.update_rule, J, fu, u, du; internalnorm)

trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du;
trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du;
uses_jacobian_inverse = Val(INV), kwargs...)

return ApproximateJacobianSolveCache{INV, GB, iip, maxtime !== nothing}(
Expand Down
2 changes: 1 addition & 1 deletion src/core/generalized_first_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ function SciMLBase.__init(
GB = :LineSearch
end

trace = init_nonlinearsolve_trace(alg, u, fu, ApplyArray(__zero, J), du; kwargs...)
trace = init_nonlinearsolve_trace(prob, alg, u, fu, ApplyArray(__zero, J), du; kwargs...)

return GeneralizedFirstOrderAlgorithmCache{iip, GB, maxtime !== nothing}(
fu, u, u_cache, p, du, J, alg, prob, jac_cache, descent_cache, linesearch_cache,
Expand Down
2 changes: 1 addition & 1 deletion src/core/spectral_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ function SciMLBase.__init(prob::AbstractNonlinearProblem, alg::GeneralizedDFSane

abstol, reltol, tc_cache = init_termination_cache(
prob, abstol, reltol, fu, u_cache, termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)
trace = init_nonlinearsolve_trace(prob, alg, u, fu, nothing, du; kwargs...)

if alg.σ_1 === nothing
σ_n = dot(u, u) / dot(u, fu)
Expand Down
81 changes: 49 additions & 32 deletions src/internal/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ for Tr in (:TraceMinimal, :TraceWithJacobianConditionNumber, :TraceAll)
end

# NonlinearSolve Tracing Utilities
@concrete struct NonlinearSolveTraceEntry
@concrete struct NonlinearSolveTraceEntry{nType}
iteration::Int
fnorm
stepnorm
Expand All @@ -63,19 +63,27 @@ end
δu
end

function __show_top_level(io::IO, entry::NonlinearSolveTraceEntry)
function __show_top_level(io::IO, entry::NonlinearSolveTraceEntry{nType}) where {nType}
if entry.condJ === nothing
@printf io "%-8s %-20s %-20s\n" "----" "-------------" "-----------"
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
if nType === :L2
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm"
else
@printf io "%-8s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm"
end
@printf io "%-8s %-20s %-20s\n" "----" "-------------" "-----------"
else
@printf io "%-8s %-20s %-20s %-20s\n" "----" "-------------" "-----------" "-------"
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
if nType === :L2
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) 2-norm" "Step 2-norm" "cond(J)"
else
@printf io "%-8s %-20s %-20s %-20s\n" "Iter" "f(u) inf-norm" "Step 2-norm" "cond(J)"
end
@printf io "%-8s %-20s %-20s %-20s\n" "----" "-------------" "-----------" "-------"
end
end

function Base.show(io::IO, entry::NonlinearSolveTraceEntry)
function Base.show(io::IO, entry::NonlinearSolveTraceEntry{nType}) where {nType}
entry.iteration == 0 && __show_top_level(io, entry)
if entry.iteration < 0
# Special case for final entry
Expand All @@ -89,25 +97,32 @@ function Base.show(io::IO, entry::NonlinearSolveTraceEntry)
return nothing
end

function NonlinearSolveTraceEntry(iteration, fu, δu)
return NonlinearSolveTraceEntry(
iteration, norm(fu, Inf), norm(δu, 2), nothing, nothing, nothing, nothing, nothing)
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu)
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
return NonlinearSolveTraceEntry{nType}(
iteration, fnorm, norm(δu, 2), nothing, nothing, nothing, nothing, nothing)
end

function NonlinearSolveTraceEntry(iteration, fu, δu, J)
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2),
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J)
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
return NonlinearSolveTraceEntry{nType}(iteration, fnorm, norm(δu, 2),
__cond(J), nothing, nothing, nothing, nothing)
end

function NonlinearSolveTraceEntry(iteration, fu, δu, J, u)
return NonlinearSolveTraceEntry(iteration, norm(fu, Inf), norm(δu, 2), __cond(J),
function NonlinearSolveTraceEntry(prob::AbstractNonlinearProblem, iteration, fu, δu, J, u)
nType = ifelse(prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
fnorm = prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
return NonlinearSolveTraceEntry{nType}(iteration, fnorm, norm(δu, 2), __cond(J),
__copy(J), __copy(u), __copy(fu), __copy(δu))
end

# Trace state built by `init_nonlinearsolve_trace`. `show_trace` and `store_trace` are
# carried as type parameters, so `update_trace!` can read them as static parameters
# and return early when both are false.
@concrete struct NonlinearSolveTrace{
show_trace, store_trace, Tr <: AbstractNonlinearSolveTraceLevel}
# Result of `__init_trace_history`: `nothing` unless `store_trace` is true, in which
# case it is a vector of `NonlinearSolveTraceEntry` that `update_trace!` pushes to.
history
# Selects which `__trace_entry` method builds each entry; `update_trace!` also reads
# its `print_frequency` and `store_frequency` fields.
trace_level::Tr
# Problem being solved. Entries report the 2-norm of `fu` when this is a
# `NonlinearLeastSquaresProblem` and the inf-norm otherwise.
prob
end

function reset!(trace::NonlinearSolveTrace)
Expand All @@ -123,61 +138,63 @@ function Base.show(io::IO, trace::NonlinearSolveTrace)
return nothing
end

function init_nonlinearsolve_trace(alg, u, fu, J, δu; show_trace::Val = Val(false),
function init_nonlinearsolve_trace(prob, alg, u, fu, J, δu; show_trace::Val = Val(false),
trace_level::AbstractNonlinearSolveTraceLevel = TraceMinimal(),
store_trace::Val = Val(false), uses_jac_inverse = Val(false), kwargs...)
return init_nonlinearsolve_trace(
alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse)
prob, alg, show_trace, trace_level, store_trace, u, fu, J, δu, uses_jac_inverse)
end

function init_nonlinearsolve_trace(
alg, ::Val{show_trace}, trace_level::AbstractNonlinearSolveTraceLevel,
::Val{store_trace}, u, fu, J, δu,
::Val{uses_jac_inverse}) where {show_trace, store_trace, uses_jac_inverse}
function init_nonlinearsolve_trace(prob::AbstractNonlinearProblem, alg, ::Val{show_trace},
trace_level::AbstractNonlinearSolveTraceLevel, ::Val{store_trace}, u, fu, J,
δu, ::Val{uses_jac_inverse}) where {show_trace, store_trace, uses_jac_inverse}
if show_trace
print("\nAlgorithm: ")
Base.printstyled(alg, "\n\n"; color = :green, bold = true)
end
J_ = uses_jac_inverse ? (trace_level isa TraceMinimal ? J : __safe_inv(J)) : J
history = __init_trace_history(
Val{show_trace}(), trace_level, Val{store_trace}(), u, fu, J_, δu)
return NonlinearSolveTrace{show_trace, store_trace}(history, trace_level)
prob, Val{show_trace}(), trace_level, Val{store_trace}(), u, fu, J_, δu)
return NonlinearSolveTrace{show_trace, store_trace}(history, trace_level, prob)
end

function __init_trace_history(::Val{show_trace}, trace_level, ::Val{store_trace},
u, fu, J, δu) where {show_trace, store_trace}
function __init_trace_history(
prob::AbstractNonlinearProblem, ::Val{show_trace}, trace_level,
::Val{store_trace}, u, fu, J, δu) where {show_trace, store_trace}
!store_trace && !show_trace && return nothing
entry = __trace_entry(trace_level, 0, u, fu, J, δu)
entry = __trace_entry(prob, trace_level, 0, u, fu, J, δu)
show_trace && show(entry)
store_trace && return NonlinearSolveTraceEntry[entry]
return nothing
end

function __trace_entry(::TraceMinimal, iter, u, fu, J, δu, α = 1)
return NonlinearSolveTraceEntry(iter, fu, δu .* α)
function __trace_entry(prob, ::TraceMinimal, iter, u, fu, J, δu, α = 1)
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α)
end
function __trace_entry(::TraceWithJacobianConditionNumber, iter, u, fu, J, δu, α = 1)
return NonlinearSolveTraceEntry(iter, fu, δu .* α, J)
function __trace_entry(prob, ::TraceWithJacobianConditionNumber, iter, u, fu, J, δu, α = 1)
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α, J)
end
function __trace_entry(::TraceAll, iter, u, fu, J, δu, α = 1)
return NonlinearSolveTraceEntry(iter, fu, δu .* α, J, u)
function __trace_entry(prob, ::TraceAll, iter, u, fu, J, δu, α = 1)
return NonlinearSolveTraceEntry(prob, iter, fu, δu .* α, J, u)
end

function update_trace!(trace::NonlinearSolveTrace{ShT, StT}, iter, u, fu, J, δu,
α = 1; last::Val{L} = Val(false)) where {ShT, StT, L}
!StT && !ShT && return nothing

if L
entry = NonlinearSolveTraceEntry(
-1, norm(fu, Inf), NaN32, nothing, nothing, nothing, nothing, nothing)
nType = ifelse(trace.prob isa NonlinearLeastSquaresProblem, :L2, :Inf)
fnorm = trace.prob isa NonlinearLeastSquaresProblem ? norm(fu, 2) : norm(fu, Inf)
entry = NonlinearSolveTraceEntry{nType}(
-1, fnorm, NaN32, nothing, nothing, nothing, nothing, nothing)
ShT && show(entry)
return trace
end

show_now = ShT && (mod1(iter, trace.trace_level.print_frequency) == 1)
store_now = StT && (mod1(iter, trace.trace_level.store_frequency) == 1)
(show_now || store_now) &&
(entry = __trace_entry(trace.trace_level, iter, u, fu, J, δu, α))
(entry = __trace_entry(trace.prob, trace.trace_level, iter, u, fu, J, δu, α))
store_now && push!(trace.history, entry)
show_now && show(entry)
return trace
Expand Down
1 change: 0 additions & 1 deletion test/core/forward_ad_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ end
gs = abs.(ForwardDiff.derivative(solve_with(Val{mode}(), u0, alg), p))
gs_true = abs.(jacobian_f(u0, p))
if !(isapprox(gs, gs_true, atol = 1e-5))
@show sol.retcode, sol.u
@error "ForwardDiff Failed for u0=$(u0) and p=$(p) with $(alg)" forwardiff_gradient=gs true_gradient=gs_true
else
@test abs.(gs)≈abs.(gs_true) atol=1e-5
Expand Down
17 changes: 9 additions & 8 deletions test/core/nlls_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ using Reexport
true_function(x, θ) = @. θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4])
true_function(y, x, θ) = (@. y = θ[1] * exp(θ[2] * x) * cos(θ[3] * x + θ[4]))

θ_true = [1.0, 0.1, 2.0, 0.5]
const θ_true = [1.0, 0.1, 2.0, 0.5]

x = [-1.0, -0.5, 0.0, 0.5, 1.0]
const x = [-1.0, -0.5, 0.0, 0.5, 1.0]

y_target = true_function(x, θ_true)
const y_target = true_function(x, θ_true)

function loss_function(θ, p)
ŷ = true_function(p, θ)
Expand All @@ -23,7 +23,7 @@ function loss_function(resid, θ, p)
return resid
end

θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1
const θ_init = θ_true .+ randn!(StableRNG(0), similar(θ_true)) * 0.1

solvers = []
for linsolve in [nothing, LUFactorization(), KrylovJL_GMRES(), KrylovJL_LSMR()]
Expand Down Expand Up @@ -56,9 +56,9 @@ end
nlls_problems = [prob_oop, prob_iip]

for prob in nlls_problems, solver in solvers
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-6)
@test SciMLBase.successful_retcode(sol)
@test maximum(abs, sol.resid) < 1e-6
@test norm(sol.resid, 2) < 1e-6
end
end

Expand Down Expand Up @@ -90,8 +90,9 @@ end
x)]

for prob in probs, solver in solvers
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
@test maximum(abs, sol.resid) < 1e-6
sol = solve(prob, solver; maxiters = 10000, abstol = 1e-6)
@test SciMLBase.successful_retcode(sol)
@test norm(sol.resid, 2) < 1e-6
end
end

Expand Down
16 changes: 8 additions & 8 deletions test/core/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ end
@test all(solve(probN, NewtonRaphson(; autodiff)).u .≈ sqrt(2.0))
end

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -238,7 +238,7 @@ end
end
end

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -324,7 +324,7 @@ end
end
end

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -395,7 +395,7 @@ end
end
end

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -462,7 +462,7 @@ end
sqrt(2.0))
end

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -514,7 +514,7 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false), Broyden()) ≈ sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true), Broyden()) ≈ sqrt.(p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -563,7 +563,7 @@ end
@test nlprob_iterator_interface(quadratic_f, p, Val(false), Klement()) ≈ sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true), Klement()) ≈ sqrt.(p)

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down Expand Up @@ -613,7 +613,7 @@ end
@test nlprob_iterator_interface(
quadratic_f!, p, Val(true), LimitedMemoryBroyden())≈sqrt.(p) atol=1e-2

@testset "Termination condition: $(termination_condition) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
@testset "Termination condition: $(_nameof(termination_condition)) u0: $(_nameof(u0))" for termination_condition in TERMINATION_CONDITIONS,
u0 in (1.0, [1.0, 1.0])

probN = NonlinearProblem(quadratic_f, u0, 2.0)
Expand Down

0 comments on commit 94016b5

Please sign in to comment.