From 651d56e2593ec682220cfd74dcb784b2f9bc2fb8 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Thu, 24 Nov 2022 11:20:19 +0100 Subject: [PATCH] format --- src/NonlinearSolve.jl | 1 - src/ad.jl | 6 ++++-- src/raphson.jl | 2 +- src/utils.jl | 9 +++++++-- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 307ec8144..d82680481 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -44,7 +44,6 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64) for alg in (NewtonRaphson,) solve(prob, alg(), abstol = T(1e-2)) end - end end export NewtonRaphson diff --git a/src/ad.jl b/src/ad.jl index ada5a2862..a2b09d569 100644 --- a/src/ad.jl +++ b/src/ad.jl @@ -23,14 +23,16 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...) return sol, partials end -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip, +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, + iip, <:Dual{T, V, P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid; retcode = sol.retcode) end -function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip, +function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, + iip, <:AbstractArray{<:Dual{T, V, P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) diff --git a/src/raphson.jl b/src/raphson.jl index dec5c0f4e..d61878683 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -120,7 +120,7 @@ function perform_step!(cache::NewtonRaphsonCache{true}) # u = u - J \ fu linres = dolinsolve(alg.precs, linsolve, A = J, b = fu, linu = du1, - p = p, reltol = cache.abstol) + p = p, reltol = cache.abstol) cache.linsolve = linres.cache @. u = u - du1 f(fu, u, p) diff --git a/src/utils.jl b/src/utils.jl index 2651817fd..2bcc6334d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,7 +4,10 @@ @inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}} sqrt(real(sum(abs2, u)) / length(u)) end -@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {T <: Union{AbstractFloat, Complex}} +@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where { + T <: Union{ + AbstractFloat, + Complex}} sqrt(real(sum(abs2, u)) / length(u)) end @inline function DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray) @@ -28,7 +31,9 @@ function value_derivative(f::F, x::R) where {F, R} end # Todo: improve this dispatch -value_derivative(f::F, x::StaticArraysCore.SVector) where {F} = f(x), ForwardDiff.jacobian(f, x) +function value_derivative(f::F, x::StaticArraysCore.SVector) where {F} + f(x), ForwardDiff.jacobian(f, x) +end value(x) = x value(x::Dual) = ForwardDiff.value(x)