Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Nov 24, 2022
1 parent 4bd17e1 commit 651d56e
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
1 change: 0 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
for alg in (NewtonRaphson,)
solve(prob, alg(), abstol = T(1e-2))
end

end end

export NewtonRaphson
Expand Down
6 changes: 4 additions & 2 deletions src/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return sol, partials
end

function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip,
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:Dual{T, V, P}}, alg::NewtonRaphson,
args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
return SciMLBase.build_solution(prob, alg, Dual{T, V, P}(sol.u, partials), sol.resid;
retcode = sol.retcode)
end
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector}, iip,
function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, StaticArraysCore.SVector},
iip,
<:AbstractArray{<:Dual{T, V, P}}},
alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ function perform_step!(cache::NewtonRaphsonCache{true})

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve, A = J, b = fu, linu = du1,
p = p, reltol = cache.abstol)
p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du1
f(fu, u, p)
Expand Down
9 changes: 7 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
@inline function DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
sqrt(real(sum(abs2, u)) / length(u))
end
@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {T <: Union{AbstractFloat, Complex}}
@inline function DEFAULT_NORM(u::StaticArraysCore.StaticArray{T}) where {
T <: Union{
AbstractFloat,
Complex}}
sqrt(real(sum(abs2, u)) / length(u))
end
@inline function DEFAULT_NORM(u::RecursiveArrayTools.AbstractVectorOfArray)
Expand All @@ -28,7 +31,9 @@ function value_derivative(f::F, x::R) where {F, R}
end

# Todo: improve this dispatch
value_derivative(f::F, x::StaticArraysCore.SVector) where {F} = f(x), ForwardDiff.jacobian(f, x)
function value_derivative(f::F, x::StaticArraysCore.SVector) where {F}
f(x), ForwardDiff.jacobian(f, x)
end

value(x) = x
value(x::Dual) = ForwardDiff.value(x)
Expand Down

0 comments on commit 651d56e

Please sign in to comment.