Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith authored Sep 25, 2024
1 parent 0e8c353 commit 80b7b2b
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions src/internal/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,13 @@ function LinearSolverCache(alg, linsolve, A, b, u; stats, kwargs...)
@bb u_ = copy(u_fixed)
linprob = LinearProblem(A, b; u0 = u_, kwargs...)

weight = __init_ones(u_fixed)
if __hasfield(alg, Val(:precs))
precs = alg.precs
Pl_, Pr_ = precs(A, nothing, u, ntuple(Returns(nothing), 6)...)
else
precs, Pl_, Pr_ = nothing, nothing, nothing
end
Pl, Pr = __wrapprecs(Pl_, Pr_, weight)
Pl, Pr = __wrapprecs(Pl_, Pr_, u)

# Unalias here, we will later use these as caches
lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr)
Expand Down Expand Up @@ -128,10 +127,8 @@ function (cache::LinearSolverCache)(;
b !== nothing && (cache.lincache.b = b)
linu !== nothing && __set_lincache_u!(cache, linu)

Plprev = cache.lincache.Pl isa ComposePreconditioner ? cache.lincache.Pl.outer :
cache.lincache.Pl
Prprev = cache.lincache.Pr isa ComposePreconditioner ? cache.lincache.Pr.outer :
cache.lincache.Pr
Plprev = cache.lincache.Pl
Prprev = cache.lincache.Pr

if cache.precs === nothing
_Pl, _Pr = nothing, nothing
Expand All @@ -141,10 +138,7 @@ function (cache::LinearSolverCache)(;
end

if (_Pl !== nothing || _Pr !== nothing)
_weight = weight === nothing ?
(cache.lincache.Pr isa Diagonal ? cache.lincache.Pr.diag :
cache.lincache.Pr.inner.diag) : weight
Pl, Pr = __wrapprecs(_Pl, _Pr, _weight)
Pl, Pr = __wrapprecs(_Pl, _Pr, linu)
cache.lincache.Pl = Pl
cache.lincache.Pr = Pr
end
Expand Down Expand Up @@ -242,9 +236,9 @@ function __set_lincache_A(lincache, new_A)
end
end

function __wrapprecs(_Pl, _Pr, weight)
Pl = _Pl !== nothing ?= _Pl : IdentityOperator(length(weight))
Pr = _Pr !== nothing ? _Pr : IdentityOperator(length(weight))
function __wrapprecs(_Pl, _Pr, u)
Pl = _Pl !== nothing ?= _Pl : IdentityOperator(length(u))
Pr = _Pr !== nothing ? _Pr : IdentityOperator(length(u))
return Pl, Pr
end

Expand Down

0 comments on commit 80b7b2b

Please sign in to comment.