diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index c6175a52e..00424ffbc 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -643,8 +643,8 @@ Same as `check_error` but also set solution's return code """ function check_error!(integrator::DEIntegrator) code = check_error(integrator) + integrator.sol = solution_new_retcode(integrator.sol, code) if code != ReturnCode.Success - integrator.sol = solution_new_retcode(integrator.sol, code) postamble!(integrator) end return code diff --git a/test/integrator_tests.jl b/test/integrator_tests.jl index 23fa9fb8f..dc46747ac 100644 --- a/test/integrator_tests.jl +++ b/test/integrator_tests.jl @@ -1,8 +1,11 @@ using SciMLBase -mutable struct DummySolution - retcode::Any + +struct DummySolution + retcode::SciMLBase.ReturnCode.T end +SciMLBase.solution_new_retcode(::DummySolution, code) = DummySolution(code) + mutable struct DummyIntegrator{Alg, IIP, U, T} <: SciMLBase.DEIntegrator{Alg, IIP, U, T} uprev::U tprev::T @@ -46,6 +49,9 @@ function SciMLBase.done(integrator::DummyIntegrator) integrator.t > 10 end +SciMLBase.check_error(::DummyIntegrator) = ReturnCode.Success +SciMLBase.postamble!(::DummyIntegrator) = nothing + integrator = DummyIntegrator() @test step_dt!(integrator, 1.5) == 2 @test step_dt!(integrator, 1.5, true) == 1.5 @@ -62,3 +68,14 @@ for (uprev, tprev, u, t) in intervals(DummyIntegrator()) end @test eltype(collect(intervals(DummyIntegrator()))) == Tuple{Vector{Float64}, Float64, Vector{Float64}, Float64} + +@test integrator.sol.retcode == ReturnCode.Default +@test check_error(integrator) == ReturnCode.Success +@test integrator.sol.retcode == ReturnCode.Default +@test SciMLBase.check_error!(integrator) == ReturnCode.Success +@test integrator.sol.retcode == ReturnCode.Success + +let + integrator = DummyIntegrator() + @test 0 == @allocated SciMLBase.check_error!(integrator) +end