Skip to content

Commit

Permalink
Make __findmin type stable
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 22, 2024
1 parent 3925219 commit e6ff3aa
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.8.1"
version = "3.8.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
27 changes: 20 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,29 @@ LazyArrays.applied_axes(::typeof(__zero), x) = axes(x)
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

@inline __findmin_caches(f, caches) = __findmin(f get_fu, caches)
@inline __findmin_caches(f::F, caches) where {F} = __findmin(f get_fu, caches)

Check warning on line 97 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L97

Added line #L97 was not covered by tests
# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`)
@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x)
@inline function __findmin(f, x)
@generated function __findmin(f::F, x) where {F}
# JET shows dynamic dispatch if this is not written as a generated function
if F === typeof(DEFAULT_NORM)
return :(return __findmin_impl(Base.Fix1(maximum, abs), x))
end
return :(return __findmin_impl(f, x))

Check warning on line 104 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L104

Added line #L104 was not covered by tests
end
@inline @views function __findmin_impl(f::F, x) where {F}
idx = findfirst(Base.Fix2(!==, nothing), x)
# This is an internal function so we assume that inputs are consistent and there is
# atleast one non-`nothing` value
fx_idx = f(x[idx])
idx == length(x) && return fx_idx, idx
fmin = @closure xᵢ -> begin
xᵢ === nothing && return Inf
xᵢ === nothing && return oftype(fx_idx, Inf)
fx = f(xᵢ)
return ifelse(isnan(fx), Inf, fx)
return ifelse(isnan(fx), oftype(fx, Inf), fx)
end
return findmin(fmin, x)
x_min, x_min_idx = findmin(fmin, x[(idx + 1):length(x)])
x_min < fx_idx && return x_min, x_min_idx + idx
return fx_idx, idx

Check warning on line 119 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L119

Added line #L119 was not covered by tests
end

@inline __can_setindex(x) = can_setindex(x)
Expand All @@ -130,7 +143,7 @@ Statistics from the nonlinear equation solver about the solution process.
- nf: Number of function evaluations.
- njacs: Number of Jacobians created during the solve.
- nfactors: Number of factorzations of the jacobian required for the solve.
- nsolve: Number of linear solves `W\b` required for the solve.
- nsolve: Number of linear solves `W \\ b` required for the solve.
- nsteps: Total number of iterations for the nonlinear solver.
"""
struct ImmutableNLStats
Expand Down

0 comments on commit e6ff3aa

Please sign in to comment.