Skip to content

Commit

Permalink
Merge branch 'courant_improvements' into physics_package
Browse files Browse the repository at this point in the history
  • Loading branch information
tommbendall committed Aug 29, 2023
2 parents 5a87322 + ea8c497 commit 6dfb8d1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 13 deletions.
57 changes: 49 additions & 8 deletions gusto/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from firedrake import op2, assemble, dot, dx, Function, sqrt, \
TestFunction, TrialFunction, Constant, grad, inner, curl, \
LinearVariationalProblem, LinearVariationalSolver, FacetNormal, \
ds_b, ds_v, ds_t, dS_v, div, avg, jump, pi, \
ds_b, ds_v, ds_t, dS_h, dS_v, ds, dS, div, avg, jump, pi, \
TensorFunctionSpace, SpatialCoordinate, as_vector, \
Projector, Interpolator
from firedrake.assign import Assigner
Expand Down Expand Up @@ -235,15 +235,17 @@ def __call__(self):
class CourantNumber(DiagnosticField):
"""Dimensionless Courant number diagnostic field."""
name = "CourantNumber"

def __init__(self, velocity='u', name=None, to_dump=True, space=None,
method='interpolate', required_fields=()):
def __init__(self, velocity='u', component='whole', name=None, to_dump=True,

Check failure on line 238 in gusto/diagnostics.py

View workflow job for this annotation

GitHub Actions / Run linter

E301

gusto/diagnostics.py:238:5: E301 expected 1 blank line, found 0
space=None, method='interpolate', required_fields=()):
"""
Args:
velocity (str or :class:`ufl.Expr`, optional): the velocity field to
take the Courant number of. Can be a string referring to an
existing field, or an expression. If it is an expression, the
name argument is required. Defaults to 'u'.
component (str, optional): the component of the velocity to use for
calculating the Courant number. Valid values are "whole",
"horizontal" or "vertical". Defaults to "whole".
name (str, optional): the name to append to "CourantNumber" to form
the name of this diagnostic. This argument must be provided if
the velocity is an expression (rather than a string). Defaults
Expand All @@ -260,6 +262,11 @@ def __init__(self, velocity='u', name=None, to_dump=True, space=None,
are required for the computation of this diagnostic field.
Defaults to ().
"""
if component not in ["whole", "horizontal", "vertical"]:
raise ValueError(f'component arg {component} not valid. Allowed '
+ 'values are "whole", "horizontal" and "vertical"')
self.component = component

# Work out whether to take Courant number from field or expression
if type(velocity) is str:
# Default name should just be CourantNumber
Expand All @@ -269,6 +276,8 @@ def __init__(self, velocity='u', name=None, to_dump=True, space=None,
self.name = 'CourantNumber_'+velocity
else:
self.name = 'CourantNumber_'+name
if component != 'whole':
self.name += '_'+component
else:
if name is None:
raise ValueError('CourantNumber diagnostic: if provided '
Expand All @@ -291,21 +300,53 @@ def setup(self, domain, state_fields):
state_fields (:class:`StateFields`): the model's field container.
"""

# set up area computation
V = domain.spaces("DG0", "DG", 0)
test = TestFunction(V)
self.area = Function(V)
assemble(test*dx, tensor=self.area)
cell_volume = Function(V)
self.cell_flux = Function(V)

# Calculate cell volumes
One = Function(V).assign(1)
assemble(One*test*dx, tensor=cell_volume)

# Get the velocity that is being used
if type(self.velocity) is str:
u = state_fields(self.velocity)
else:
u = self.velocity

self.expr = sqrt(dot(u, u))/sqrt(self.area)*domain.dt
# Determine the component of the velocity
if self.component == "whole":
u_expr = u
elif self.component == "vertical":
u_expr = dot(u, domain.k)*domain.k
elif self.component == "horizontal":
u_expr = u - dot(u, domain.k)*domain.k

# Work out which facet integrals to use
if domain.mesh.extruded:
dS_calc = dS_v + dS_h
ds_calc = ds_v + ds_t + ds_b
else:
dS_calc = dS
ds_calc = ds

# Set up form for DG flux
n = FacetNormal(domain.mesh)
un = 0.5*(inner(-u_expr, n) + abs(inner(-u_expr, n)))
self.cell_flux_form = 2*avg(un*test)*dS_calc + un*test*ds_calc

# Final Courant number expression
self.expr = self.cell_flux *domain.dt / cell_volume

Check failure on line 340 in gusto/diagnostics.py

View workflow job for this annotation

GitHub Actions / Run linter

E225

gusto/diagnostics.py:340:37: E225 missing whitespace around operator

super().setup(domain, state_fields)

def compute(self):
"""Compute the diagnostic field from the current state."""

assemble(self.cell_flux_form, tensor=self.cell_flux)
super().compute()


class Gradient(DiagnosticField):
"""Diagnostic for computing the gradient of fields."""
Expand Down
18 changes: 14 additions & 4 deletions gusto/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,18 @@ def log_parameters(self, equation):
logger.info("Physical parameters that take non-default values:")
logger.info(", ".join("%s: %s" % (k, float(v)) for (k, v) in vars(equation.parameters).items()))

def setup_log_courant(self, state_fields, name='u', expression=None):
def setup_log_courant(self, state_fields, name='u', component="whole",
expression=None):
"""
Sets up Courant number diagnostics to be logged.
Args:
state_fields (:class:`StateFields`): the model's field container.
name (str, optional): the name of the field to log the Courant
number of. Defaults to 'u'.
component (str, optional): the component of the velocity to use for
calculating the Courant number. Valid values are "whole",
"horizontal" or "vertical". Defaults to "whole".
expression (:class:`ufl.Expr`, optional): expression of velocity
field to take Courant number of. Defaults to None, in which case
the "name" argument must correspond to an existing field.
Expand All @@ -258,29 +262,35 @@ def setup_log_courant(self, state_fields, name='u', expression=None):
# Set up diagnostic if it hasn't already been
if courant_name not in diagnostic_names and 'u' in state_fields._field_names:
if expression is None:
diagnostic = CourantNumber(to_dump=False)
diagnostic = CourantNumber(to_dump=False, component=component)
elif expression is not None:
diagnostic = CourantNumber(velocity=expression, name=courant_name, to_dump=False)
diagnostic = CourantNumber(velocity=expression, component=component,
name=courant_name, to_dump=False)

self.diagnostic_fields.append(diagnostic)
diagnostic.setup(self.domain, state_fields)
self.diagnostics.register(diagnostic.name)

def log_courant(self, state_fields, name='u', message=None):
def log_courant(self, state_fields, name='u', component="whole", message=None):
"""
Logs the maximum Courant number value.
Args:
state_fields (:class:`StateFields`): the model's field container.
name (str, optional): the name of the field to log the Courant
number of. Defaults to 'u'.
component (str, optional): the component of the velocity to use for
calculating the Courant number. Valid values are "whole",
"horizontal" or "vertical". Defaults to "whole".
message (str, optional): an extra message to be logged. Defaults to
None.
"""

if self.output.log_courant and 'u' in state_fields._field_names:
diagnostic_names = [diagnostic.name for diagnostic in self.diagnostic_fields]
courant_name = 'CourantNumber' if name == 'u' else 'CourantNumber_'+name
if component != 'whole':
courant_name += '_'+component
courant_idx = diagnostic_names.index(courant_name)
courant_diagnostic = self.diagnostic_fields[courant_idx]
courant_diagnostic.compute()
Expand Down
8 changes: 7 additions & 1 deletion gusto/timeloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def run(self, t, tmax, pick_up=False):
# Set up diagnostics, which may set up some fields necessary to pick up
self.io.setup_diagnostics(self.fields)
self.io.setup_log_courant(self.fields)
if self.equation.domain.mesh.extruded:
self.io.setup_log_courant(self.fields, component='horizontal')
self.io.setup_log_courant(self.fields, component='vertical')
if self.transporting_velocity != "prognostic":
self.io.setup_log_courant(self.fields, name='transporting_velocity',
expression=self.transporting_velocity)
Expand All @@ -176,6 +179,9 @@ def run(self, t, tmax, pick_up=False):
self.x.update()

self.io.log_courant(self.fields)
if self.equation.domain.mesh.extruded:
self.io.log_courant(self.fields, component='horizontal', message='horizontal')
self.io.log_courant(self.fields, component='vertical', message='vertical')

self.timestep()

Expand Down Expand Up @@ -558,7 +564,7 @@ def timestep(self):

with timed_stage("Transport"):
self.io.log_courant(self.fields, 'transporting_velocity',
f'transporting velocity, outer iteration {k}')
message=f'transporting velocity, outer iteration {k}')
for name, scheme in self.active_transport:
# transports a field from xstar and puts result in xp
scheme.apply(xp(name), xstar(name))
Expand Down

0 comments on commit 6dfb8d1

Please sign in to comment.