Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make __findmin type stable #393

Merged
merged 1 commit into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.8.1"
version = "3.8.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
27 changes: 20 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,29 @@
@inline __is_complex(::Type{Complex}) = true
@inline __is_complex(::Type{T}) where {T} = false

@inline __findmin_caches(f, caches) = __findmin(f ∘ get_fu, caches)
@inline __findmin_caches(f::F, caches) where {F} = __findmin(f ∘ get_fu, caches)

Check warning on line 97 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L97

Added line #L97 was not covered by tests
# FIXME: DEFAULT_NORM makes an Array of NaNs not a NaN (atleast according to `isnan`)
@inline __findmin(::typeof(DEFAULT_NORM), x) = __findmin(Base.Fix1(maximum, abs), x)
@inline function __findmin(f, x)
@generated function __findmin(f::F, x) where {F}
# JET shows dynamic dispatch if this is not written as a generated function
if F === typeof(DEFAULT_NORM)
return :(return __findmin_impl(Base.Fix1(maximum, abs), x))
end
return :(return __findmin_impl(f, x))

Check warning on line 104 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L104

Added line #L104 was not covered by tests
end
@inline @views function __findmin_impl(f::F, x) where {F}
idx = findfirst(Base.Fix2(!==, nothing), x)
# This is an internal function so we assume that inputs are consistent and there is
# atleast one non-`nothing` value
fx_idx = f(x[idx])
idx == length(x) && return fx_idx, idx
fmin = @closure xᵢ -> begin
xᵢ === nothing && return Inf
xᵢ === nothing && return oftype(fx_idx, Inf)
fx = f(xᵢ)
return ifelse(isnan(fx), Inf, fx)
return ifelse(isnan(fx), oftype(fx, Inf), fx)
end
return findmin(fmin, x)
x_min, x_min_idx = findmin(fmin, x[(idx + 1):length(x)])
x_min < fx_idx && return x_min, x_min_idx + idx
return fx_idx, idx

Check warning on line 119 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L119

Added line #L119 was not covered by tests
end

@inline __can_setindex(x) = can_setindex(x)
Expand All @@ -130,7 +143,7 @@
- nf: Number of function evaluations.
- njacs: Number of Jacobians created during the solve.
- nfactors: Number of factorzations of the jacobian required for the solve.
- nsolve: Number of linear solves `W\b` required for the solve.
- nsolve: Number of linear solves `W \\ b` required for the solve.
- nsteps: Total number of iterations for the nonlinear solver.
"""
struct ImmutableNLStats
Expand Down
Loading