Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify __wrapprecs #465

Merged
merged 6 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ using LinearAlgebra: LinearAlgebra, ColumnNorm, Diagonal, I, LowerTriangular, Sy
UpperTriangular, axpy!, cond, diag, diagind, dot, issuccess, istril,
istriu, lu, mul!, norm, pinv, tril!, triu!
using LineSearches: LineSearches
using LinearSolve: LinearSolve, LUFactorization, QRFactorization, ComposePreconditioner,
InvPreconditioner, needs_concrete_A, AbstractFactorization,
using LinearSolve: LinearSolve, LUFactorization, QRFactorization,
needs_concrete_A, AbstractFactorization,
DefaultAlgorithmChoice, DefaultLinearSolver
using MaybeInplace: @bb
using Printf: @printf
using Preferences: Preferences, @load_preference, @set_preferences!
using RecursiveArrayTools: recursivecopy!, recursivefill!
using RecursiveArrayTools: recursivecopy!,
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
using SciMLBase: AbstractNonlinearAlgorithm, JacobianWrapper, AbstractNonlinearProblem,
AbstractSciMLOperator, _unwrap_val, isinplace, NLStats
using SciMLJacobianOperators: AbstractJacobianOperator, JacobianOperator, VecJacOperator,
Expand Down
25 changes: 7 additions & 18 deletions src/internal/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,13 @@ function LinearSolverCache(alg, linsolve, A, b, u; stats, kwargs...)
@bb u_ = copy(u_fixed)
linprob = LinearProblem(A, b; u0 = u_, kwargs...)

weight = __init_ones(u_fixed)
oscardssmith marked this conversation as resolved.
Show resolved Hide resolved
if __hasfield(alg, Val(:precs))
precs = alg.precs
Pl_, Pr_ = precs(A, nothing, u, ntuple(Returns(nothing), 6)...)
else
precs, Pl_, Pr_ = nothing, nothing, nothing
end
Pl, Pr = __wrapprecs(Pl_, Pr_, weight)
Pl, Pr = __wrapprecs(Pl_, Pr_, u)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that correct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is creating an operator using only the length of u

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

specifically, NonlinearSolve never actually used the weight variable.


# Unalias here, we will later use these as caches
lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr)
Expand Down Expand Up @@ -128,10 +127,8 @@ function (cache::LinearSolverCache)(;
b !== nothing && (cache.lincache.b = b)
linu !== nothing && __set_lincache_u!(cache, linu)

Plprev = cache.lincache.Pl isa ComposePreconditioner ? cache.lincache.Pl.outer :
cache.lincache.Pl
Prprev = cache.lincache.Pr isa ComposePreconditioner ? cache.lincache.Pr.outer :
cache.lincache.Pr
Plprev = cache.lincache.Pl
Prprev = cache.lincache.Pr

if cache.precs === nothing
_Pl, _Pr = nothing, nothing
Expand All @@ -141,10 +138,7 @@ function (cache::LinearSolverCache)(;
end

if (_Pl !== nothing || _Pr !== nothing)
_weight = weight === nothing ?
(cache.lincache.Pr isa Diagonal ? cache.lincache.Pr.diag :
cache.lincache.Pr.inner.diag) : weight
Pl, Pr = __wrapprecs(_Pl, _Pr, _weight)
Pl, Pr = __wrapprecs(_Pl, _Pr, linu)
cache.lincache.Pl = Pl
cache.lincache.Pr = Pr
end
Expand Down Expand Up @@ -242,14 +236,9 @@ function __set_lincache_A(lincache, new_A)
end
end

@inline function __wrapprecs(_Pl, _Pr, weight)
Pl = _Pl !== nothing ?
ComposePreconditioner(InvPreconditioner(Diagonal(_vec(weight))), _Pl) :
InvPreconditioner(Diagonal(_vec(weight)))

Pr = _Pr !== nothing ? ComposePreconditioner(Diagonal(_vec(weight)), _Pr) :
Diagonal(_vec(weight))

function __wrapprecs(_Pl, _Pr, u)
Pl = _Pl !== nothing ? _Pl : IdentityOperator(length(u))
Pr = _Pr !== nothing ? _Pr : IdentityOperator(length(u))
return Pl, Pr
end

Expand Down
7 changes: 0 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,6 @@ end
@inline _restructure(y, x) = restructure(y, x)
@inline _restructure(y::Number, x::Number) = x

@inline function __init_ones(x)
w = similar(x)
recursivefill!(w, true)
return w
end
@inline __init_ones(x::StaticArray) = ones(typeof(x))

@inline __maybe_unaliased(x::Union{Number, SArray}, ::Bool) = x
@inline function __maybe_unaliased(x::AbstractArray, alias::Bool)
# Spend time coping iff we will mutate the array
Expand Down
Loading