Skip to content

Commit

Permalink
refactor: simplify __wrapprecs (#465)
Browse files Browse the repository at this point in the history
* simplify `__wrapprecs`

previously the preconditioner being set was a very complicated `IdentityOperator`. Using a regular `IdentityOperator` means that solvers that don't support one of the preconditioners won't throw warnings since they mostly know that `IdentityOperator`s aren't real.

* fix

* fix: typo

* remove __init_ones

* remove imports

* Update src/NonlinearSolve.jl
  • Loading branch information
oscardssmith authored Sep 28, 2024
1 parent 736b4a3 commit d6d741b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 28 deletions.
6 changes: 3 additions & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ using LinearAlgebra: LinearAlgebra, ColumnNorm, Diagonal, I, LowerTriangular, Sy
UpperTriangular, axpy!, cond, diag, diagind, dot, issuccess, istril,
istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
InvPreconditioner, needs_concrete_A, AbstractFactorization,
using LinearSolve: LinearSolve, LUFactorization, QRFactorization,
needs_concrete_A, AbstractFactorization,
DefaultAlgorithmChoice, DefaultLinearSolver
using MaybeInplace: @bb
using Printf: @printf
using Preferences: Preferences, @load_preference, @set_preferences!
using RecursiveArrayTools: recursivecopy!, recursivefill!
using RecursiveArrayTools: recursivecopy!
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
AbstractSciMLOperator, _unwrap_val, isinplace, NLStats
using SciMLJacobianOperators: AbstractJacobianOperator, JacobianOperator, VecJacOperator,
Expand Down
25 changes: 7 additions & 18 deletions src/internal/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,13 @@ function LinearSolverCache(alg, linsolve, A, b, u; stats, kwargs...)
@bb u_ = copy(u_fixed)
linprob = LinearProblem(A, b; u0 = u_, kwargs...)

weight = __init_ones(u_fixed)
if __hasfield(alg, Val(:precs))
precs = alg.precs
Pl_, Pr_ = precs(A, nothing, u, ntuple(Returns(nothing), 6)...)
else
precs, Pl_, Pr_ = nothing, nothing, nothing
end
Pl, Pr = __wrapprecs(Pl_, Pr_, weight)
Pl, Pr = __wrapprecs(Pl_, Pr_, u)

# Unalias here, we will later use these as caches
lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr)
Expand Down Expand Up @@ -128,10 +127,8 @@ function (cache::LinearSolverCache)(;
b !== nothing && (cache.lincache.b = b)
linu !== nothing && __set_lincache_u!(cache, linu)

Plprev = cache.lincache.Pl isa ComposePreconditioner ? cache.lincache.Pl.outer :
cache.lincache.Pl
Prprev = cache.lincache.Pr isa ComposePreconditioner ? cache.lincache.Pr.outer :
cache.lincache.Pr
Plprev = cache.lincache.Pl
Prprev = cache.lincache.Pr

if cache.precs === nothing
_Pl, _Pr = nothing, nothing
Expand All @@ -141,10 +138,7 @@ function (cache::LinearSolverCache)(;
end

if (_Pl !== nothing || _Pr !== nothing)
_weight = weight === nothing ?
(cache.lincache.Pr isa Diagonal ? cache.lincache.Pr.diag :
cache.lincache.Pr.inner.diag) : weight
Pl, Pr = __wrapprecs(_Pl, _Pr, _weight)
Pl, Pr = __wrapprecs(_Pl, _Pr, linu)
cache.lincache.Pl = Pl
cache.lincache.Pr = Pr
end
Expand Down Expand Up @@ -242,14 +236,9 @@ function __set_lincache_A(lincache, new_A)
end
end

@inline function __wrapprecs(_Pl, _Pr, weight)
Pl = _Pl !== nothing ?
ComposePreconditioner(InvPreconditioner(Diagonal(_vec(weight))), _Pl) :
InvPreconditioner(Diagonal(_vec(weight)))

Pr = _Pr !== nothing ? ComposePreconditioner(Diagonal(_vec(weight)), _Pr) :
Diagonal(_vec(weight))

function __wrapprecs(_Pl, _Pr, u)
Pl = _Pl !== nothing ? _Pl : IdentityOperator(length(u))
Pr = _Pr !== nothing ? _Pr : IdentityOperator(length(u))
return Pl, Pr
end

Expand Down
7 changes: 0 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ end
@inline _restructure(y, x) = restructure(y, x)
@inline _restructure(y::Number, x::Number) = x

@inline function __init_ones(x)
w = similar(x)
recursivefill!(w, true)
return w
end
@inline __init_ones(x::StaticArray) = ones(typeof(x))

@inline __maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x
@inline function __maybe_unaliased(x::AbstractArray, alias::Bool)
# Spend time coping iff we will mutate the array
Expand Down

0 comments on commit d6d741b

Please sign in to comment.