diff --git a/src/stats.jl b/src/stats.jl index f3ee705..537ac87 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -65,6 +65,7 @@ abstract type AbstractExecutionStats end """ GenericExecutionStats(nlp; ...) + GenericExecutionStats{T, S, V, Tsp}(;...) A GenericExecutionStats is a struct for storing the output information of solvers. It contains the following fields: @@ -100,7 +101,9 @@ the field value as reliable. The `reset!()` method marks all fields as unreliable. -`nlp` is mandatory to set default optional fields. +`nlp` is highly recommended to set default optional fields. +If it is not provided, the function `reset!(stats, nlp)` should be called before `solve!`. + All other variables can be input as keyword arguments. Notice that `GenericExecutionStats` does not compute anything, it simply stores. @@ -129,6 +132,44 @@ mutable struct GenericExecutionStats{T, S, V, Tsp} <: AbstractExecutionStats solver_specific::Dict{Symbol, Tsp} end +function GenericExecutionStats{T, S, V, Tsp}(; + status::Symbol = :unknown, + solution::S = S(), + objective::T = T(Inf), + dual_feas::T = T(Inf), + primal_feas::T = T(Inf), + multipliers::S = S(), + multipliers_L::V = V(), + multipliers_U::V = V(), + iter::Int = -1, + elapsed_time::Real = Inf, + solver_specific::Dict{Symbol, Tsp} = Dict{Symbol, Any}(), +) where {T, S, V, Tsp} + return GenericExecutionStats{T, S, V, Tsp}( + false, + status, + false, + solution, + false, + objective, + false, + dual_feas, + false, + primal_feas, + false, + multipliers, + false, + multipliers_L, + multipliers_U, + false, + iter, + false, + elapsed_time, + false, + solver_specific, + ) +end + function GenericExecutionStats( nlp::AbstractNLPModel{T, S}; status::Symbol = :unknown, @@ -171,9 +212,12 @@ end """ reset!(stats::GenericExecutionStats) + reset!(stats::GenericExecutionStats, nlp::AbstractNLPModel) Reset the internal flags of `stats` to `false` to Indicate that the contents should not be trusted. +If an `AbstractNLPModel` is also provided, +the pre-allocated vectors are adjusted to the problem size. """ function NLPModels.reset!(stats::GenericExecutionStats) stats.status_reliable = false @@ -189,6 +233,15 @@ function NLPModels.reset!(stats::GenericExecutionStats) stats end +function NLPModels.reset!(stats::GenericExecutionStats{T, S}, nlp::AbstractNLPModel{T, S}) where {T, S} + stats.solution = similar(nlp.meta.x0) + stats.multipliers = similar(nlp.meta.y0) + stats.multipliers_L = similar(nlp.meta.y0, has_bounds(nlp) ? nlp.meta.nvar : 0) + stats.multipliers_U = similar(nlp.meta.y0, has_bounds(nlp) ? nlp.meta.nvar : 0) + reset!(stats) + stats +end + """ set_status!(stats::GenericExecutionStats, status::Symbol) diff --git a/test/test_stats.jl b/test/test_stats.jl index 66078ee..fd7da40 100644 --- a/test/test_stats.jl +++ b/test/test_stats.jl @@ -55,6 +55,15 @@ function test_stats() @test typeof(stats.objective) == T @test typeof(stats.dual_feas) == T @test typeof(stats.primal_feas) == T + + S = Vector{T} + stats = GenericExecutionStats{T, S, S, Any}() + set_status!(stats, :first_order) + @test stats.status == :first_order + @test stats.status_reliable + @test typeof(stats.objective) == T + @test typeof(stats.dual_feas) == T + @test typeof(stats.primal_feas) == T end end @@ -62,6 +71,9 @@ function test_stats() stats = GenericExecutionStats(nlp) @test_throws Exception set_status!(stats, :bad) @test_throws Exception GenericExecutionStats(:unkwown, nlp, bad = true) + + stats = GenericExecutionStats{Float64, Vector{Float64}, Vector{Float64}, Any}() + @test_throws Exception set_status!(stats, :bad) end @testset "Testing Dummy Solver with multi-precision" begin @@ -94,6 +106,19 @@ function test_stats() @test eltype(stats.multipliers) == T @test eltype(stats.multipliers_L) == T @test eltype(stats.multipliers_U) == T + + stats = GenericExecutionStats{T, Vector{T}, Vector{T}, Any}() + reset!(stats, nlp) + with_logger(NullLogger()) do + solve!(solver, nlp, stats) + end + @test typeof(stats.objective) == T + @test typeof(stats.dual_feas) == T + @test typeof(stats.primal_feas) == T + @test eltype(stats.solution) == T + @test eltype(stats.multipliers) == T + @test eltype(stats.multipliers_L) == T + @test eltype(stats.multipliers_U) == T end end