diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 73b78b00f..4582eb1f4 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -42,6 +42,8 @@ struct SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode end end +@inline set_termination_mode_internalnorm(mode, ::F) where {F} = mode + @inline __norm_type(::typeof(Base.Fix2(norm, Inf))) = :Inf @inline __norm_type(::typeof(Base.Fix1(maximum, abs))) = :Inf @inline __norm_type(::typeof(Base.Fix2(norm, 2))) = :L2 @@ -98,6 +100,11 @@ for name in (:Norm, :RelNorm, :AbsNorm) $(struct_name)(f::F = nothing) where {F} = new{__norm_type(f), F}(f) end + + @inline function set_termination_mode_internalnorm( + ::$(struct_name), internalnorm::F) where {F} + return $(struct_name)(internalnorm) + end end end @@ -143,6 +150,13 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest) patience_objective_multiplier, min_max_factor, max_stalled_steps) end end + + @inline function set_termination_mode_internalnorm( + mode::$(struct_name), internalnorm::F) where {F} + return $(struct_name)(internalnorm; mode.protective_threshold, + mode.patience_steps, mode.patience_objective_multiplier, + mode.min_max_factor, mode.max_stalled_steps) + end end end diff --git a/src/utils.jl b/src/utils.jl index f7e698cc3..fa4488e8b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -83,7 +83,10 @@ end @inline function __norm_op(::typeof(Base.Fix2(norm, 2)), op::F, x, y) where {F} if __fast_scalar_indexing(x, y) - return sqrt(sum(@closure((xᵢyᵢ)->(op(xᵢ, yᵢ)^2)), zip(x, y))) + return sqrt(sum(@closure((xᵢyᵢ)->begin + xᵢ, yᵢ = xᵢyᵢ + return op(xᵢ, yᵢ)^2 + end), zip(x, y))) else return sqrt(mapreduce(@closure((xᵢ, yᵢ)->(op(xᵢ, yᵢ)^2)), +, x, y)) end @@ -104,7 +107,8 @@ end @inline function __add_and_norm(::Nothing, x, y) Base.depwarn("Not specifying the internal norm of termination conditions has been \ - deprecated. Using inf-norm currently.", :__add_and_norm) + deprecated. Using inf-norm currently.", + :__add_and_norm) return __maximum_abs(+, x, y) end @inline __add_and_norm(::typeof(Base.Fix1(maximum, abs)), x, y) = __maximum_abs(+, x, y) @@ -113,7 +117,8 @@ end @inline function __apply_termination_internalnorm(::Nothing, u) Base.depwarn("Not specifying the internal norm of termination conditions has been \ - deprecated. Using inf-norm currently.", :__apply_termination_internalnorm) + deprecated. Using inf-norm currently.", + :__apply_termination_internalnorm) return __apply_termination_internalnorm(Base.Fix1(maximum, abs), u) end @inline __apply_termination_internalnorm(f::F, u) where {F} = f(u) diff --git a/test/termination_conditions.jl b/test/termination_conditions.jl index 3403262f0..70a4711c9 100644 --- a/test/termination_conditions.jl +++ b/test/termination_conditions.jl @@ -1,4 +1,4 @@ -using BenchmarkTools, DiffEqBase, Test +using BenchmarkTools, DiffEqBase, LinearAlgebra, Test du = rand(4) u = rand(4) @@ -6,14 +6,17 @@ uprev = rand(4) const TERMINATION_CONDITIONS = [ SteadyStateDiffEqTerminationMode(), SimpleNonlinearSolveTerminationMode(), - NormTerminationMode(), RelTerminationMode(), RelNormTerminationMode(), + RelTerminationMode(), NormTerminationMode(), RelNormTerminationMode(), AbsTerminationMode(), AbsNormTerminationMode(), RelSafeTerminationMode(), AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode() ] @testset "Termination Conditions: Allocations" begin @testset "Mode: $(tcond)" for tcond in TERMINATION_CONDITIONS - @test (@ballocated DiffEqBase.check_convergence($tcond, $du, $u, $uprev, 1e-3, - 1e-3)) == 0 + for nfn in (Base.Fix1(maximum, abs), Base.Fix2(norm, 2), Base.Fix2(norm, Inf)) + tcond = DiffEqBase.set_termination_mode_internalnorm(tcond, nfn) + @test (@ballocated DiffEqBase.check_convergence($tcond, $du, $u, $uprev, 1e-3, + 1e-3)) == 0 + end end end