Skip to content

Commit

Permalink
implement simultaneous transport with SIQN
Browse files Browse the repository at this point in the history
  • Loading branch information
ta440 committed Nov 28, 2024
1 parent e138b5e commit bcde555
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 27 deletions.
7 changes: 4 additions & 3 deletions gusto/equations/prognostic_equations.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,13 @@ def add_tracers_to_prognostics(self, domain, active_tracers):
name of the active tracer.
"""

# Check if there are any conservatively transported tracers.
# If so, ensure that the reference density is indexed before this tracer.
# If there are any conservatively transported tracers, ensure
# that the reference density, if it is also an active tracer,
# is indexed earlier.
for i in range(len(active_tracers) - 1):
tracer = active_tracers[i]
if tracer.transport_eqn == TransportEquationType.tracer_conservative:
ref_density = next(x for x in active_tracers if x.name == tracer.density_name)
ref_density = next((x for x in active_tracers if x.name == tracer.density_name), tracer)
j = active_tracers.index(ref_density)
if j > i:
# Swap the indices of the tracer and the reference density
Expand Down
68 changes: 57 additions & 11 deletions gusto/time_discretisation/time_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,22 @@ def setup(self, equation, apply_bcs=True, *active_labels):
self.residual = equation.residual

if self.field_name is not None and hasattr(equation, "field_names"):
self.idx = equation.field_names.index(self.field_name)
self.fs = equation.spaces[self.idx]
self.residual = self.residual.label_map(
lambda t: t.get(prognostic) == self.field_name,
lambda t: Term(
split_form(t.form)[self.idx].form,
t.labels),
drop)
if isinstance(self.field_name, list):
# Multiple fields are being solved for simultaneously.
# This enables conservative transport to be implemented with SIQN.
# Use the full mixed space for self.fs, with the
# field_name, residual, and BCs being set up later.
self.fs = equation.function_space
self.idx = None
else:
self.idx = equation.field_names.index(self.field_name)
self.fs = equation.spaces[self.idx]
self.residual = self.residual.label_map(
lambda t: t.get(prognostic) == self.field_name,
lambda t: Term(
split_form(t.form)[self.idx].form,
t.labels),
drop)

else:
self.field_name = equation.field_name
Expand All @@ -152,6 +160,34 @@ def setup(self, equation, apply_bcs=True, *active_labels):
self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, *active_labels)),
map_if_false=drop)
if isinstance(self.field_name, list):
# Multiple fields are being solved for simultaneously.
# Keep all time derivative terms:
residual = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_false=drop)

# Only keep active labels for prognostics in the list:
for subname in self.field_name:
field_residual = self.residual.label_map(
lambda t: t.get(prognostic) == subname,
map_if_false=drop)

residual += field_residual.label_map(
lambda t: t.has_label(*active_labels),
map_if_false=drop)

self.residual = residual
else:
self.residual = self.residual.label_map(
lambda t: any(t.has_label(time_derivative, *active_labels)),
map_if_false=drop)

# Set the field name if using simultaneous transport.
if isinstance(self.field_name, list):
self.field_name = equation.field_name

bcs = equation.bcs[self.field_name]

self.evaluate_source = []
self.physics_names = []
Expand Down Expand Up @@ -244,9 +280,19 @@ def setup(self, equation, apply_bcs=True, *active_labels):
if not apply_bcs:
self.bcs = None
elif self.wrapper is not None:
# Transfer boundary conditions onto test function space
self.bcs = [DirichletBC(self.fs, bc.function_arg, bc.sub_domain)
for bc in bcs]
if self.wrapper_name == 'mixed_options':
# Define new Dirichlet BCs on the wrapper-modified
# mixed function space.
self.bcs = []
for idx, field_name in enumerate(self.equation.field_names):
for bc in equation.bcs[field_name]:
self.bcs.append(DirichletBC(self.fs.sub(idx),
bc.function_arg,
bc.sub_domain))
else:
# Transfer boundary conditions onto test function space
self.bcs = [DirichletBC(self.fs, bc.function_arg, bc.sub_domain)
for bc in bcs]
else:
self.bcs = bcs

Expand Down
62 changes: 49 additions & 13 deletions gusto/timestepping/semi_implicit_quasi_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.reference_update_freq = reference_update_freq
self.to_update_ref_profile = False

# Flag for if we have simultaneous transport
self.simult = False

# default is to not offcentre transporting velocity but if it
# is offcentred then use the same value as alpha
self.alpha_u = Constant(alpha) if off_centred_u else Constant(0.5)
Expand Down Expand Up @@ -148,15 +151,30 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods,
self.transported_fields = []
for scheme in transport_schemes:
assert scheme.nlevels == 1, "multilevel schemes not supported as part of this timestepping loop"
assert scheme.field_name in equation_set.field_names
self.active_transport.append((scheme.field_name, scheme))
self.transported_fields.append(scheme.field_name)
# Check that there is a corresponding transport method
method_found = False
for method in spatial_methods:
if scheme.field_name == method.variable and method.term_label == transport:
method_found = True
assert method_found, f'No transport method found for variable {scheme.field_name}'
if isinstance(scheme.field_name, list):
# This means that multiple fields are being transported simultaneously
self.simult = True
for subfield in scheme.field_name:
assert subfield in equation_set.field_names

# Check that there is a corresponding transport method for
# each field in the list
method_found = False
for method in spatial_methods:
if subfield == method.variable and method.term_label == transport:
method_found = True
assert method_found, f'No transport method found for variable {scheme.field_name}'
self.active_transport.append((scheme.field_name, scheme))
else:
assert scheme.field_name in equation_set.field_names

# Check that there is a corresponding transport method
method_found = False
for method in spatial_methods:
if scheme.field_name == method.variable and method.term_label == transport:
method_found = True
self.active_transport.append((scheme.field_name, scheme))
assert method_found, f'No transport method found for variable {scheme.field_name}'

self.diffusion_schemes = []
if diffusion_schemes is not None:
Expand Down Expand Up @@ -240,7 +258,11 @@ def transporting_velocity(self):
def setup_fields(self):
"""Sets up time levels n, star, p and np1"""
self.x = TimeLevelFields(self.equation, 1)
self.x.add_fields(self.equation, levels=("star", "p", "after_slow", "after_fast"))
if self.simult is True:
# If there is any simultaneous transport, add a temporary field:
self.x.add_fields(self.equation, levels=("star", "p", "simult", "after_slow", "after_fast"))
else:
self.x.add_fields(self.equation, levels=("star", "p", "after_slow", "after_fast"))
for aux_eqn, _ in self.auxiliary_equations_and_schemes:
self.x.add_fields(aux_eqn)
# Prescribed fields for auxiliary eqns should come from prognostics of
Expand Down Expand Up @@ -339,6 +361,9 @@ def timestep(self):
xrhs_phys = self.xrhs_phys
dy = self.dy

if self.simult:
xsimult = self.x.simult

# Update reference profiles --------------------------------------------
self.update_reference_profiles()

Expand Down Expand Up @@ -368,9 +393,20 @@ def timestep(self):
self.io.log_courant(self.fields, 'transporting_velocity',
message=f'transporting velocity, outer iteration {outer}')
for name, scheme in self.active_transport:
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}')
# transports a field from xstar and puts result in xp
self.transport_field(name, scheme, xstar, xp)
if isinstance(name, list):
# Transport multiple fields from xstar simultaneously.
# We transport the mixed function space from xstar to xsimult, then
# extract the updated fields and pass them to xp; this avoids overwriting
# any previously transported fields.
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: '
+ f'Simultaneous transport of {name}')
self.transport_field(self.field_name, scheme, xstar, xsimult)
for field_name in name:
xp(field_name).assign(xsimult(field_name))
else:
logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}')
# transports a field from xstar and puts result in xp
self.transport_field(name, scheme, xstar, xp)

# Fast physics -----------------------------------------------------
x_after_fast(self.field_name).assign(xp(self.field_name))
Expand Down

0 comments on commit bcde555

Please sign in to comment.