diff --git a/gusto/equations/prognostic_equations.py b/gusto/equations/prognostic_equations.py index 7369f5ee..b2df68e2 100644 --- a/gusto/equations/prognostic_equations.py +++ b/gusto/equations/prognostic_equations.py @@ -305,12 +305,13 @@ def add_tracers_to_prognostics(self, domain, active_tracers): name of the active tracer. """ - # Check if there are any conservatively transported tracers. - # If so, ensure that the reference density is indexed before this tracer. + # If there are any conservatively transported tracers, ensure + # that the reference density, if it is also an active tracer, + # is indexed earlier. for i in range(len(active_tracers) - 1): tracer = active_tracers[i] if tracer.transport_eqn == TransportEquationType.tracer_conservative: - ref_density = next(x for x in active_tracers if x.name == tracer.density_name) + ref_density = next((x for x in active_tracers if x.name == tracer.density_name), tracer) j = active_tracers.index(ref_density) if j > i: # Swap the indices of the tracer and the reference density diff --git a/gusto/time_discretisation/time_discretisation.py b/gusto/time_discretisation/time_discretisation.py index df108a61..35226e33 100644 --- a/gusto/time_discretisation/time_discretisation.py +++ b/gusto/time_discretisation/time_discretisation.py @@ -132,14 +132,22 @@ def setup(self, equation, apply_bcs=True, *active_labels): self.residual = equation.residual if self.field_name is not None and hasattr(equation, "field_names"): - self.idx = equation.field_names.index(self.field_name) - self.fs = equation.spaces[self.idx] - self.residual = self.residual.label_map( - lambda t: t.get(prognostic) == self.field_name, - lambda t: Term( - split_form(t.form)[self.idx].form, - t.labels), - drop) + if isinstance(self.field_name, list): + # Multiple fields are being solved for simultaneously. + # This enables conservative transport to be implemented with SIQN. + # Use the full mixed space for self.fs, with the + # field_name, residual, and BCs being set up later. + self.fs = equation.function_space + self.idx = None + else: + self.idx = equation.field_names.index(self.field_name) + self.fs = equation.spaces[self.idx] + self.residual = self.residual.label_map( + lambda t: t.get(prognostic) == self.field_name, + lambda t: Term( + split_form(t.form)[self.idx].form, + t.labels), + drop) else: self.field_name = equation.field_name @@ -152,6 +160,34 @@ def setup(self, equation, apply_bcs=True, *active_labels): self.residual = self.residual.label_map( lambda t: any(t.has_label(time_derivative, *active_labels)), map_if_false=drop) + if isinstance(self.field_name, list): + # Multiple fields are being solved for simultaneously. + # Keep all time derivative terms: + residual = self.residual.label_map( + lambda t: t.has_label(time_derivative), + map_if_false=drop) + + # Only keep active labels for prognostics in the list: + for subname in self.field_name: + field_residual = self.residual.label_map( + lambda t: t.get(prognostic) == subname, + map_if_false=drop) + + residual += field_residual.label_map( + lambda t: t.has_label(*active_labels), + map_if_false=drop) + + self.residual = residual + else: + self.residual = self.residual.label_map( + lambda t: any(t.has_label(time_derivative, *active_labels)), + map_if_false=drop) + + # Set the field name if using simultaneous transport. + if isinstance(self.field_name, list): + self.field_name = equation.field_name + + bcs = equation.bcs[self.field_name] self.evaluate_source = [] self.physics_names = [] @@ -244,9 +280,19 @@ def setup(self, equation, apply_bcs=True, *active_labels): if not apply_bcs: self.bcs = None elif self.wrapper is not None: - # Transfer boundary conditions onto test function space - self.bcs = [DirichletBC(self.fs, bc.function_arg, bc.sub_domain) - for bc in bcs] + if self.wrapper_name == 'mixed_options': + # Define new Dirichlet BCs on the wrapper-modified + # mixed function space. + self.bcs = [] + for idx, field_name in enumerate(self.equation.field_names): + for bc in equation.bcs[field_name]: + self.bcs.append(DirichletBC(self.fs.sub(idx), + bc.function_arg, + bc.sub_domain)) + else: + # Transfer boundary conditions onto test function space + self.bcs = [DirichletBC(self.fs, bc.function_arg, bc.sub_domain) + for bc in bcs] else: self.bcs = bcs diff --git a/gusto/timestepping/semi_implicit_quasi_newton.py b/gusto/timestepping/semi_implicit_quasi_newton.py index 76c517a7..46a575b9 100644 --- a/gusto/timestepping/semi_implicit_quasi_newton.py +++ b/gusto/timestepping/semi_implicit_quasi_newton.py @@ -115,6 +115,9 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, self.reference_update_freq = reference_update_freq self.to_update_ref_profile = False + # Flag for if we have simultaneous transport + self.simult = False + # default is to not offcentre transporting velocity but if it # is offcentred then use the same value as alpha self.alpha_u = Constant(alpha) if off_centred_u else Constant(0.5) @@ -148,15 +151,30 @@ def __init__(self, equation_set, io, transport_schemes, spatial_methods, self.transported_fields = [] for scheme in transport_schemes: assert scheme.nlevels == 1, "multilevel schemes not supported as part of this timestepping loop" - assert scheme.field_name in equation_set.field_names - self.active_transport.append((scheme.field_name, scheme)) - self.transported_fields.append(scheme.field_name) - # Check that there is a corresponding transport method - method_found = False - for method in spatial_methods: - if scheme.field_name == method.variable and method.term_label == transport: - method_found = True - assert method_found, f'No transport method found for variable {scheme.field_name}' + if isinstance(scheme.field_name, list): + # This means that multiple fields are being transported simultaneously + self.simult = True + for subfield in scheme.field_name: + assert subfield in equation_set.field_names + + # Check that there is a corresponding transport method for + # each field in the list + method_found = False + for method in spatial_methods: + if subfield == method.variable and method.term_label == transport: + method_found = True + assert method_found, f'No transport method found for variable {scheme.field_name}' + self.active_transport.append((scheme.field_name, scheme)) + else: + assert scheme.field_name in equation_set.field_names + + # Check that there is a corresponding transport method + method_found = False + for method in spatial_methods: + if scheme.field_name == method.variable and method.term_label == transport: + method_found = True + self.active_transport.append((scheme.field_name, scheme)) + assert method_found, f'No transport method found for variable {scheme.field_name}' self.diffusion_schemes = [] if diffusion_schemes is not None: @@ -240,7 +258,11 @@ def transporting_velocity(self): def setup_fields(self): """Sets up time levels n, star, p and np1""" self.x = TimeLevelFields(self.equation, 1) - self.x.add_fields(self.equation, levels=("star", "p", "after_slow", "after_fast")) + if self.simult is True: + # If there is any simultaneous transport, add a temporary field: + self.x.add_fields(self.equation, levels=("star", "p", "simult", "after_slow", "after_fast")) + else: + self.x.add_fields(self.equation, levels=("star", "p", "after_slow", "after_fast")) for aux_eqn, _ in self.auxiliary_equations_and_schemes: self.x.add_fields(aux_eqn) # Prescribed fields for auxiliary eqns should come from prognostics of @@ -339,6 +361,9 @@ def timestep(self): xrhs_phys = self.xrhs_phys dy = self.dy + if self.simult: + xsimult = self.x.simult + # Update reference profiles -------------------------------------------- self.update_reference_profiles() @@ -368,9 +393,20 @@ def timestep(self): self.io.log_courant(self.fields, 'transporting_velocity', message=f'transporting velocity, outer iteration {outer}') for name, scheme in self.active_transport: - logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}') - # transports a field from xstar and puts result in xp - self.transport_field(name, scheme, xstar, xp) + if isinstance(name, list): + # Transport multiple fields from xstar simultaneously. + # We transport the mixed function space from xstar to xsimult, then + # extract the updated fields and pass them to xp; this avoids overwriting + # any previously transported fields. + logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: ' + + f'Simultaneous transport of {name}') + self.transport_field(self.field_name, scheme, xstar, xsimult) + for field_name in name: + xp(field_name).assign(xsimult(field_name)) + else: + logger.info(f'Semi-implicit Quasi Newton: Transport {outer}: {name}') + # transports a field from xstar and puts result in xp + self.transport_field(name, scheme, xstar, xp) # Fast physics ----------------------------------------------------- x_after_fast(self.field_name).assign(xp(self.field_name))