Skip to content

Commit

Permalink
fix the matrices in the linear solvers between reference state updates
Browse files Browse the repository at this point in the history
  • Loading branch information
JHopeCollins committed Nov 6, 2024
1 parent fed1bdc commit 50dd3ba
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions gusto/solvers/linear_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ def L_tr(f):
rhobar_avg = Function(Vtrace)
exnerbar_avg = Function(Vtrace)

rho_avg_prb = LinearVariationalProblem(a_tr, L_tr(rhobar), rhobar_avg)
exner_avg_prb = LinearVariationalProblem(a_tr, L_tr(exnerbar), exnerbar_avg)
rho_avg_prb = LinearVariationalProblem(a_tr, L_tr(rhobar), rhobar_avg,
constant_jacobian=True)
exner_avg_prb = LinearVariationalProblem(a_tr, L_tr(exnerbar), exnerbar_avg,
constant_jacobian=True)

self.rho_avg_solver = LinearVariationalSolver(rho_avg_prb,
solver_parameters=cg_ilu_parameters,
Expand Down Expand Up @@ -334,7 +336,8 @@ def L_tr(f):
# Function for the hybridized solutions
self.urhol0 = Function(M)

hybridized_prb = LinearVariationalProblem(aeqn, Leqn, self.urhol0)
hybridized_prb = LinearVariationalProblem(aeqn, Leqn, self.urhol0,
constant_jacobian=True)
hybridized_solver = LinearVariationalSolver(hybridized_prb,
solver_parameters=self.solver_parameters,
options_prefix='ImplicitSolver')
Expand All @@ -360,7 +363,8 @@ def L_tr(f):
theta_eqn = gamma*(theta - theta_in
+ dot(k, self.u_hdiv)*dot(k, grad(thetabar))*beta_t)*dx

theta_problem = LinearVariationalProblem(lhs(theta_eqn), rhs(theta_eqn), self.theta)
theta_problem = LinearVariationalProblem(lhs(theta_eqn), rhs(theta_eqn), self.theta,
constant_jacobian=True)
self.theta_solver = LinearVariationalSolver(theta_problem,
solver_parameters=cg_ilu_parameters,
options_prefix='thetabacksubstitution')
Expand Down Expand Up @@ -389,6 +393,9 @@ def update_reference_profiles(self):
logger.info('Compressible linear solver: Exner average solve')
self.exner_avg_solver.solve()

self.hybridized_solver.invalidate_jacobian()
self.theta_solver.invalidate_jacobian()

@timed_function("Gusto:LinearSolve")
def solve(self, xrhs, dy):
"""
Expand Down Expand Up @@ -532,7 +539,8 @@ def V(u):
bcs = [DirichletBC(M.sub(0), bc.function_arg, bc.sub_domain) for bc in self.equations.bcs['u']]

# Solver for u, p
up_problem = LinearVariationalProblem(aeqn, Leqn, self.up, bcs=bcs)
up_problem = LinearVariationalProblem(aeqn, Leqn, self.up, bcs=bcs,
constant_jacobian=True)

# Provide callback for the nullspace of the trace system
def trace_nullsp(T):
Expand All @@ -555,7 +563,8 @@ def trace_nullsp(T):

b_problem = LinearVariationalProblem(lhs(b_eqn),
rhs(b_eqn),
self.b)
self.b,
constant_jacobian=True)
self.b_solver = LinearVariationalSolver(b_problem)

# Log residuals on hybridized solver
Expand Down Expand Up @@ -683,7 +692,8 @@ def _setup_solver(self):
bcs = [DirichletBC(M.sub(0), bc.function_arg, bc.sub_domain) for bc in self.equations.bcs['u']]

# Solver for u, D
uD_problem = LinearVariationalProblem(aeqn, Leqn, self.uD, bcs=bcs)
uD_problem = LinearVariationalProblem(aeqn, Leqn, self.uD, bcs=bcs,
constant_jacobian=True)

# Provide callback for the nullspace of the trace system
def trace_nullsp(T):
Expand All @@ -704,7 +714,8 @@ def trace_nullsp(T):

b_problem = LinearVariationalProblem(lhs(b_eqn),
rhs(b_eqn),
self.b)
self.b,
constant_jacobian=True)
self.b_solver = LinearVariationalSolver(b_problem)

# Log residuals on hybridized solver
Expand Down Expand Up @@ -805,7 +816,8 @@ def __init__(self, equation, alpha):
bcs = [DirichletBC(W.sub(0), bc.function_arg, bc.sub_domain) for bc in equation.bcs['u']]
problem = LinearVariationalProblem(aeqn.form,
action(Leqn.form, self.xrhs),
self.dy, bcs=bcs)
self.dy, bcs=bcs,
constant_jacobian=True)

self.solver = LinearVariationalSolver(problem,
solver_parameters=self.solver_parameters,
Expand Down Expand Up @@ -899,7 +911,8 @@ def _setup_solver(self):
bcs = [DirichletBC(M.sub(0), bc.function_arg, bc.sub_domain) for bc in self.equations.bcs['u']]

# Solver for u, D
uD_problem = LinearVariationalProblem(aeqn, Leqn, self.uD, bcs=bcs)
uD_problem = LinearVariationalProblem(aeqn, Leqn, self.uD, bcs=bcs,
constant_jacobian=True)

# Provide callback for the nullspace of the trace system
def trace_nullsp(T):
Expand Down

0 comments on commit 50dd3ba

Please sign in to comment.