Skip to content

Commit

Permalink
Better BigFloat support
Browse files Browse the repository at this point in the history
Signed-off-by: ErikQQY <[email protected]>
  • Loading branch information
ErikQQY committed Aug 19, 2024
1 parent 8dec0c9 commit 86e6720
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ for (probType, pType) in ((:NonlinearProblem, :NLS), (:NonlinearLeastSquaresProb
alias_u0 = false # If immutable don't care about aliasing
end
u0 = prob.u0
u0_aliased = alias_u0 ? __similar(u0) : u0
u0_aliased = alias_u0 ? zero(u0) : u0
end]
for i in 1:N
cur_sol = sol_syms[i]
Expand Down
8 changes: 4 additions & 4 deletions src/globalization/line_search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function __internal_init(
deriv_op = nothing
elseif SciMLBase.has_jvp(f)
if isinplace(prob)
jvp_cache = __similar(fu)
jvp_cache = zero(fu)
deriv_op = @closure (du, u, fu, p) -> begin
f.jvp(jvp_cache, du, u, p)
dot(fu, jvp_cache)
Expand All @@ -135,7 +135,7 @@ function __internal_init(
end
elseif SciMLBase.has_vjp(f)
if isinplace(prob)
vjp_cache = __similar(u)
vjp_cache = zero(u)
deriv_op = @closure (du, u, fu, p) -> begin
f.vjp(vjp_cache, fu, u, p)
dot(du, vjp_cache)
Expand All @@ -149,7 +149,7 @@ function __internal_init(
alg.autodiff, prob; check_reverse_mode = true)
vjp_op = VecJacOperator(prob, fu, u; autodiff)
if isinplace(prob)
vjp_cache = __similar(u)
vjp_cache = zero(u)
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(vjp_cache, fu, u, p))
else
deriv_op = @closure (du, u, fu, p) -> dot(du, vjp_op(fu, u, p))
Expand All @@ -159,7 +159,7 @@ function __internal_init(
alg.autodiff, prob; check_forward_mode = true)
jvp_op = JacVecOperator(prob, fu, u; autodiff)
if isinplace(prob)
jvp_cache = __similar(fu)
jvp_cache = zero(fu)
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(jvp_cache, du, u, p))
else
deriv_op = @closure (du, u, fu, p) -> dot(fu, jvp_op(du, u, p))
Expand Down
4 changes: 2 additions & 2 deletions src/internal/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
function evaluate_f(prob::AbstractNonlinearProblem{uType, iip}, u) where {uType, iip}
(; f, u0, p) = prob
if iip
fu = f.resid_prototype === nothing ? __similar(u) :
fu = f.resid_prototype === nothing ? zero(u) :
promote_type(eltype(u), eltype(f.resid_prototype)).(f.resid_prototype)
f(fu, u, p)
else
Expand Down Expand Up @@ -156,7 +156,7 @@ function __construct_extension_f(prob::AbstractNonlinearProblem; alias_u0::Bool

𝐅 = if force_oop === True && applicable(𝐟, u0, u0)
_resid = resid isa Number ? [resid] : _vec(resid)
du = _vec(__similar(_resid))
du = _vec(zero(_resid))
@closure u -> begin
𝐟(du, u)
return du
Expand Down
2 changes: 1 addition & 1 deletion src/internal/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ function JacobianCache(prob, alg, f::F, fu_, u, p; stats, autodiff = nothing,
__similar(fu, promote_type(eltype(fu), eltype(u)), length(fu), length(u)) :
copy(f.jac_prototype)
elseif f.jac_prototype === nothing
__init_bigfloat_array!!(init_jacobian(
zero(init_jacobian(
jac_cache; preserve_immutable = Val(true)))
else
f.jac_prototype
Expand Down
16 changes: 8 additions & 8 deletions src/internal/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
@closure (v, u, p) -> auto_vecjac(uf, u, v)
elseif vjp_autodiff isa AutoFiniteDiff
if iip
cache1 = __similar(fu)
cache2 = __similar(fu)
cache1 = zero(fu)
cache2 = zero(fu)
@closure (Jv, v, u, p) -> num_vecjac!(Jv, uf, u, v, cache1, cache2)
else
@closure (v, u, p) -> num_vecjac(uf, __mutable(u), v)
Expand Down Expand Up @@ -106,17 +106,17 @@ function JacobianOperator(prob::AbstractNonlinearProblem, fu, u; jvp_autodiff =
if iip
# FIXME: Technically we should propagate the tag but ignoring that for now
cache1 = Dual{typeof(ForwardDiff.Tag(uf, eltype(u))), eltype(u),
1}.(__similar(u), ForwardDiff.Partials.(tuple.(u)))
1}.(zero(u), ForwardDiff.Partials.(tuple.(u)))
cache2 = Dual{typeof(ForwardDiff.Tag(uf, eltype(fu))), eltype(fu),
1}.(__similar(fu), ForwardDiff.Partials.(tuple.(fu)))
1}.(zero(fu), ForwardDiff.Partials.(tuple.(fu)))
@closure (Jv, v, u, p) -> auto_jacvec!(Jv, uf, u, v, cache1, cache2)
else
@closure (v, u, p) -> auto_jacvec(uf, u, v)
end
elseif jvp_autodiff isa AutoFiniteDiff
if iip
cache1 = __similar(fu)
cache2 = __similar(u)
cache1 = zero(fu)
cache2 = zero(u)
@closure (Jv, v, u, p) -> num_jacvec!(Jv, uf, u, v, cache1, cache2)
else
@closure (v, u, p) -> num_jacvec(uf, u, v)
Expand Down Expand Up @@ -162,15 +162,15 @@ end
function (op::JacobianOperator{vjp, iip})(v, u, p) where {vjp, iip}
if vjp
if iip
res = __similar(op.output_cache)
res = zero(op.output_cache)
op.vjp_op(res, v, u, p)
return res
else
return op.vjp_op(v, u, p)
end
else
if iip
res = __similar(op.output_cache)
res = zero(op.output_cache)
op.jvp_op(res, v, u, p)
return res
else
Expand Down
10 changes: 1 addition & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,5 @@ end

function __similar(x, args...; kwargs...)
y = similar(x, args...; kwargs...)
return __init_bigfloat_array!!(y)
end

function __init_bigfloat_array!!(x)
if ArrayInterface.can_setindex(x)
eltype(x) <: BigFloat && fill!(x, BigFloat(0))
return x
end
return x
return zero(y)
end

0 comments on commit 86e6720

Please sign in to comment.