Skip to content

Commit

Permalink
Add support for BigFloat
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 25, 2024
1 parent 05fb063 commit 2376418
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 28 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ CUDA = "5.2"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.149.0"
Enzyme = "0.12"
ExplicitImports = "1.4.4"
ExplicitImports = "1.5"
FastBroadcast = "0.2.8, 0.3"
FastClosures = "0.3.2"
FastLevenbergMarquardt = "0.1"
Expand All @@ -79,7 +79,7 @@ LineSearches = "7.2"
LinearAlgebra = "1.10"
LinearSolve = "2.30"
MINPACK = "1.2"
MaybeInplace = "0.1.1"
MaybeInplace = "0.1.3"
ModelingToolkit = "9.13.0"
NLSolvers = "0.5"
NLsolve = "4.5"
Expand Down
7 changes: 4 additions & 3 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,13 +229,14 @@ function __show_cache(io::IO, cache::AbstractNonlinearSolveCache, indent = 0)
__show_algorithm(io, cache.alg,
(" "^(indent + 4)) * "alg = " * string(get_name(cache.alg)), indent + 4)

ustr = sprint(show, get_u(cache); context=(:compact=>true, :limit=>true))
ustr = sprint(show, get_u(cache); context = (:compact => true, :limit => true))
println(io, ",\n" * (" "^(indent + 4)) * "u = $(ustr),")

residstr = sprint(show, get_fu(cache); context=(:compact=>true, :limit=>true))
residstr = sprint(show, get_fu(cache); context = (:compact => true, :limit => true))
println(io, (" "^(indent + 4)) * "residual = $(residstr),")

normstr = sprint(show, norm(get_fu(cache), Inf); context=(:compact=>true, :limit=>true))
normstr = sprint(
show, norm(get_fu(cache), Inf); context = (:compact => true, :limit => true))
println(io, (" "^(indent + 4)) * "inf-norm(residual) = $(normstr),")

println(io, " "^(indent + 4) * "nsteps = ", cache.stats.nsteps, ",")
Expand Down
6 changes: 3 additions & 3 deletions src/algorithms/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ end

function BroydenLowRankJacobian(fu, u; threshold::Int = 10, alpha = true)
T = promote_type(eltype(u), eltype(fu))
U = similar(fu, T, length(fu), threshold)
Vᵀ = similar(u, T, length(u), threshold)
cache = similar(u, T, threshold)
U = __similar(fu, T, length(fu), threshold)
Vᵀ = __similar(u, T, length(u), threshold)
cache = __similar(u, T, threshold)
return BroydenLowRankJacobian{T}(U, Vᵀ, 0, cache, T(alpha))
end

Expand Down
6 changes: 1 addition & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
alias_u0 = false # If immutable don't care about aliasing
end
u0 = prob.u0
if alias_u0
u0_aliased = similar(u0)
else
u0_aliased = u0 # Irrelevant
end
u0_aliased = alias_u0 ? __similar(u0) : u0
end]
for i in 1:N
cur_sol = sol_syms[i]
Expand Down
4 changes: 2 additions & 2 deletions src/globalization/line_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function __internal_init(
else
if SciMLBase.has_jvp(f)
if isinplace(prob)
g_cache = similar(u)
g_cache = __similar(u)

Check warning on line 118 in src/globalization/line_search.jl

View check run for this annotation

Codecov / codecov/patch

src/globalization/line_search.jl#L118

Added line #L118 was not covered by tests
grad_op = @closure (u, fu, p) -> f.vjp(g_cache, fu, u, p)
else
grad_op = @closure (u, fu, p) -> f.vjp(fu, u, p)
Expand All @@ -125,7 +125,7 @@ function __internal_init(
alg.autodiff, prob; check_reverse_mode = true)
vjp_op = VecJacOperator(prob, fu, u; autodiff)
if isinplace(prob)
g_cache = similar(u)
g_cache = __similar(u)

Check warning on line 128 in src/globalization/line_search.jl

View check run for this annotation

Codecov / codecov/patch

src/globalization/line_search.jl#L128

Added line #L128 was not covered by tests
grad_op = @closure (u, fu, p) -> vjp_op(g_cache, fu, u, p)
else
grad_op = @closure (u, fu, p) -> vjp_op(fu, u, p)
Expand Down
2 changes: 1 addition & 1 deletion src/internal/approximate_initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function __internal_init(
@assert length(u)==length(fu) "Diagonal Jacobian Structure must be square!"
J = one.(_vec(fu)) .* α
else
J_ = similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
J_ = __similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u))
J = alg.structure(__make_identity!!(J_, α); alias = true)
end
return InitializedApproximateJacobianCache(
Expand Down
4 changes: 2 additions & 2 deletions src/internal/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
function evaluate_f(prob::AbstractNonlinearProblem{uType, iip}, u) where {uType, iip}
(; f, u0, p) = prob
if iip
fu = f.resid_prototype === nothing ? similar(u) :
fu = f.resid_prototype === nothing ? __similar(u) :
promote_type(eltype(u), eltype(f.resid_prototype)).(f.resid_prototype)
f(fu, u, p)
else
Expand Down Expand Up @@ -154,7 +154,7 @@ function __construct_extension_f(prob::AbstractNonlinearProblem; alias_u0::Bool

𝐅 = if force_oop === True && applicable(𝐟, u0, u0)
_resid = resid isa Number ? [resid] : _vec(resid)
du = _vec(similar(_resid))
du = _vec(__similar(_resid))

Check warning on line 157 in src/internal/helpers.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/helpers.jl#L157

Added line #L157 was not covered by tests
@closure u -> begin
𝐟(du, u)
return du
Expand Down
5 changes: 3 additions & 2 deletions src/internal/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,11 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
else
if has_analytic_jac
f.jac_prototype === nothing ?
similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
copy(f.jac_prototype)
elseif f.jac_prototype === nothing
init_jacobian(jac_cache; preserve_immutable = Val(true))
__init_bigfloat_array!!(init_jacobian(
jac_cache; preserve_immutable = Val(true)))
else
f.jac_prototype
end
Expand Down
16 changes: 8 additions & 8 deletions src/internal/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
@closure (v, u, p) -> auto_vecjac(uf, u, v)
elseif vjp_autodiff isa AutoFiniteDiff
if iip
cache1 = similar(fu)
cache2 = similar(fu)
cache1 = __similar(fu)
cache2 = __similar(fu)
@closure (Jv, v, u, p) -> num_vecjac!(Jv, uf, u, v, cache1, cache2)
else
@closure (v, u, p) -> num_vecjac(uf, __mutable(u), v)
Expand Down Expand Up @@ -106,17 +106,17 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
if iip
# FIXME: Technically we should propagate the tag but ignoring that for now
cache1 = Dual{typeof(ForwardDiff.Tag(uf, eltype(u))), eltype(u),
1}.(similar(u), ForwardDiff.Partials.(tuple.(u)))
1}.(__similar(u), ForwardDiff.Partials.(tuple.(u)))
cache2 = Dual{typeof(ForwardDiff.Tag(uf, eltype(fu))), eltype(fu),
1}.(similar(fu), ForwardDiff.Partials.(tuple.(fu)))
1}.(__similar(fu), ForwardDiff.Partials.(tuple.(fu)))
@closure (Jv, v, u, p) -> auto_jacvec!(Jv, uf, u, v, cache1, cache2)
else
@closure (v, u, p) -> auto_jacvec(uf, u, v)
end
elseif jvp_autodiff isa AutoFiniteDiff
if iip
cache1 = similar(fu)
cache2 = similar(u)
cache1 = __similar(fu)
cache2 = __similar(u)

Check warning on line 119 in src/internal/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/operators.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
@closure (Jv, v, u, p) -> num_jacvec!(Jv, uf, u, v, cache1, cache2)
else
@closure (v, u, p) -> num_jacvec(uf, u, v)
Expand Down Expand Up @@ -162,15 +162,15 @@ end
function (op::JacobianOperator{vjp, iip})(v, u, p) where {vjp, iip}
if vjp
if iip
res = similar(op.output_cache)
res = __similar(op.output_cache)

Check warning on line 165 in src/internal/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/operators.jl#L165

Added line #L165 was not covered by tests
op.vjp_op(res, v, u, p)
return res
else
return op.vjp_op(v, u, p)
end
else
if iip
res = similar(op.output_cache)
res = __similar(op.output_cache)

Check warning on line 173 in src/internal/operators.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/operators.jl#L173

Added line #L173 was not covered by tests
op.jvp_op(res, v, u, p)
return res
else
Expand Down
13 changes: 13 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,16 @@ function __reinit_internal!(stats::NLStats)
stats.njacs = 0
stats.nsolve = 0

Check warning on line 161 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L156-L161

Added lines #L156 - L161 were not covered by tests
end

function __similar(x, args...; kwargs...)
y = similar(x, args...; kwargs...)
return __init_bigfloat_array!!(y)
end

function __init_bigfloat_array!!(x)
if ArrayInterface.can_setindex(x)
eltype(x) <: BigFloat && fill!(x, BigFloat(0))
return x
end
return x
end

0 comments on commit 2376418

Please sign in to comment.