Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow linsolve to be \ #358

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "3.5.0"
version = "3.5.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 2 additions & 1 deletion src/internal/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ use this functionality unless it can't be avoided (like in [`LevenbergMarquardt`

# Extension Algorithm Helpers
function __test_termination_condition(termination_condition, alg)
termination_condition !== AbsNormTerminationMode && termination_condition !== nothing &&
!(termination_condition isa AbsNormTerminationMode) &&
termination_condition !== nothing &&
error("`$(alg)` does not support termination conditions!")
end

Expand Down
16 changes: 6 additions & 10 deletions src/internal/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,11 @@
cache.nfactors = 0
end

@inline function LinearSolverCache(alg, linsolve, A::Number, b::Number, u; kwargs...)
return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0)
end
@inline function LinearSolverCache(alg, ::Nothing, A::SMatrix, b, u; kwargs...)
# Default handling for SArrays caching in LinearSolve is not the best. Override it here
return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0)
end
@inline function LinearSolverCache(alg, linsolve, A::Diagonal, b, u; kwargs...)
return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0)
end
function LinearSolverCache(alg, linsolve, A, b, u; kwargs...)
if (A isa Number && b isa Number) || (linsolve === nothing && A isa SMatrix) ||
(A isa Diagonal) || (linsolve isa typeof(\))
return LinearSolverCache(nothing, nothing, A, b, nothing, 0, 0)
end
@bb b_ = copy(b)
@bb u_ = copy(u)
linprob = LinearProblem(A, b_; u0 = u_, kwargs...)
Expand Down Expand Up @@ -193,3 +187,5 @@
@inline __needs_square_A(::Nothing, ::Number) = false
@inline __needs_square_A(::Nothing, _) = false
@inline __needs_square_A(linsolve, _) = LinearSolve.needs_square_A(linsolve)
@inline __needs_square_A(::typeof(\), _) = false
@inline __needs_square_A(::typeof(\), ::Number) = false # Ambiguity Fix

Check warning on line 191 in src/internal/linear_solve.jl

View check run for this annotation

Codecov / codecov/patch

src/internal/linear_solve.jl#L191

Added line #L191 was not covered by tests
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ end
end

@inline __needs_concrete_A(::Nothing) = false
@inline __needs_concrete_A(::typeof(\)) = true
@inline __needs_concrete_A(linsolve) = needs_concrete_A(linsolve)

@inline __maybe_mutable(x, ::AutoSparseEnzyme) = __mutable(x)
Expand Down
6 changes: 3 additions & 3 deletions test/core/rootfind.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ const TERMINATION_CONDITIONS = [
]

@testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES())
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES(), \)
ad isa AutoZygote && continue
if prec === :Random
prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
Expand Down Expand Up @@ -139,7 +139,7 @@ end
RadiusUpdateSchemes.NLsolve, RadiusUpdateSchemes.Hei, RadiusUpdateSchemes.Yuan,
RadiusUpdateSchemes.Fan, RadiusUpdateSchemes.Bastin]
u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
linear_solvers = [nothing, LUFactorization(), KrylovJL_GMRES()]
linear_solvers = [nothing, LUFactorization(), KrylovJL_GMRES(), \]

@testset "[OOP] u0: $(typeof(u0)) radius_update_scheme: $(radius_update_scheme) linear_solver: $(linsolve)" for u0 in u0s,
radius_update_scheme in radius_update_schemes, linsolve in linear_solvers
Expand Down Expand Up @@ -471,7 +471,7 @@ end
precs = [NonlinearSolve.DEFAULT_PRECS, :Random]

@testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES())
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES(), \)
ad isa AutoZygote && continue
if prec === :Random
prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
Expand Down
Loading