Skip to content

Commit

Permalink
add and implement augmentation with vorticity transport
Browse files Browse the repository at this point in the history
  • Loading branch information
tommbendall committed Nov 18, 2024
1 parent cdd2982 commit 50ef06b
Show file tree
Hide file tree
Showing 6 changed files with 285 additions and 36 deletions.
9 changes: 6 additions & 3 deletions gusto/core/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@

from gusto.core.coordinates import Coordinates
from gusto.core.function_spaces import Spaces, check_degree_args
from firedrake import (Constant, SpatialCoordinate, sqrt, CellNormal, cross,
inner, grad, VectorFunctionSpace, Function, FunctionSpace,
perp)
from firedrake import (
Constant, SpatialCoordinate, sqrt, CellNormal, cross, inner, grad,
VectorFunctionSpace, Function, FunctionSpace, perp, curl
)
import numpy as np


Expand Down Expand Up @@ -113,12 +114,14 @@ def __init__(self, mesh, dt, family, degree=None,
V = VectorFunctionSpace(mesh, "DG", sphere_degree)
self.outward_normals = Function(V).interpolate(CellNormal(mesh))
self.perp = lambda u: cross(self.outward_normals, u)
self.divperp = lambda u: inner(self.outward_normals, curl(u))
else:
kvec = [0.0]*dim
kvec[dim-1] = 1.0
self.k = Constant(kvec)
if dim == 2:
self.perp = perp
self.divperp = lambda u: -u[0].dx(1) + u[1].dx(0)

# -------------------------------------------------------------------- #
# Construct information relating to height/radius
Expand Down
1 change: 1 addition & 0 deletions gusto/spatial_methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from gusto.spatial_methods.diffusion_methods import * # noqa
from gusto.spatial_methods.transport_methods import * # noqa
from gusto.spatial_methods.limiters import * # noqa
from gusto.spatial_methods.augmentation import * # noqa
159 changes: 159 additions & 0 deletions gusto/spatial_methods/augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from abc import ABCMeta
from firedrake import (
MixedFunctionSpace, Function, TestFunctions, split, inner, dx, grad,
LinearVariationalProblem, LinearVariationalSolver, lhs, rhs, dot,
ds_b, ds_v, ds_t, ds, FacetNormal, TestFunction, TrialFunction,
transpose, nabla_grad, outer, dS, dS_h, dS_v, sign, jump, div,
Constant, sqrt, cross, curl
)
from firedrake.fml import subject
from gusto import (
time_derivative, transport, transporting_velocity, TransportEquationType,
logger
)


class Augmentation(object, metaclass=ABCMeta):
"""
Augments an equation with another equation to be solved simultaneously.
"""


class VorticityTransport(Augmentation):
"""
Solves the transport of a velocity field, simultaneously with the vorticity.
"""

### An argument to time discretisation or spatial method??

Check failure on line 27 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

E266

gusto/spatial_methods/augmentation.py:27:5: E266 too many leading '#' for block comment
# TODO: this all needs to be generalised

def __init__(self, domain, V_vel, V_vort, transpose_commutator=True,
supg=False, min_dx=None):

self.fs = MixedFunctionSpace((V_vel, V_vort))
self.X = Function(self.fs)
self.tests = TestFunctions(self.fs)

u = Function(V_vel)
F, Z = split(self.X)
test_F, test_Z = self.tests

if hasattr(domain.mesh, "_base_mesh"):
self.ds = ds_b + ds_t + ds_v
self.dS = dS_v + dS_h
else:
self.ds = ds
self.dS = dS

n = FacetNormal(domain.mesh)
sign_u = 0.5*(sign(dot(u, n)) + 1)
upw = lambda f: (sign_u('+')*f('+') + sign_u('-')*f('-'))

if domain.mesh.topological_dimension() == 2:
mix_test = test_F - domain.perp(grad(test_Z))
F_cross_u = Z*domain.perp(u)
elif domain.mesh.topological_dimension == 3:
mix_test = test_F - curl(test_Z)
F_cross_u = cross(Z, u)

time_deriv_form = inner(F, test_F)*dx + inner(Z, test_Z)*dx

# Standard vector invariant transport form -----------------------------
transport_form = (
# vorticity term
inner(mix_test, F_cross_u)*dx
+ inner(n, test_Z*Z*u)*self.ds
# 0.5*grad(v . F)
- 0.5 * div(mix_test) * inner(u, F)*dx
+ 0.5 * inner(mix_test, n) * inner(u, F)*self.ds
)

# Communtator of tranpose gradient terms -------------------------------
# This is needed for general vector transport
if transpose_commutator:
u_dot_nabla_F = dot(u, transpose(nabla_grad(F)))
transport_form += (
- inner(n, test_Z*domain.perp(u_dot_nabla_F))*self.ds
# + 0.5*grad(F).v
- 0.5 * dot(F, div(outer(u, mix_test)))*dx
+ 0.5 * inner(mix_test('+'), n('+'))*dot(jump(u), upw(F))*self.dS
# - 0.5*grad(v).F
+ 0.5 * dot(u, div(outer(F, mix_test)))*dx
- 0.5 * inner(mix_test('+'), n('+'))*dot(jump(F), upw(u))*self.dS
)

# SUPG terms -----------------------------------------------------------
# Add the vorticity residual to the transported vorticity,
# which damps enstrophy
if supg:
if min_dx is not None:
lamda = Constant(0.5)

Check failure on line 90 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

F841

gusto/spatial_methods/augmentation.py:90:17: F841 local variable 'lamda' is assigned to but never used
#TODO: decide on expression here

Check failure on line 91 in gusto/spatial_methods/augmentation.py

View workflow job for this annotation

GitHub Actions / Run linter

E265

gusto/spatial_methods/augmentation.py:91:17: E265 block comment should start with '# '
# tau = 0.5 / (lamda/domain.dt + sqrt(dot(u, u))/Constant(min_dx))
tau = 0.5*domain.dt*(1.0 + sqrt(dot(u, u))*domain.dt/Constant(min_dx))
else:
tau = 0.5*domain.dt

dxqp = dx(degree=3)

if domain.mesh.topological_dimension() == 2:
time_deriv_form -= inner(mix_test, tau*Z*domain.perp(u)/domain.dt)*dxqp
transport_form -= inner(
mix_test, tau*domain.perp(u)*domain.divperp(Z*domain.perp(u))
)*dxqp
if transpose_commutator:
transport_form -= inner(
mix_test,
tau*domain.perp(u)*domain.divperp(u_dot_nabla_F)
)*dxqp
elif domain.mesh.topological_dimension() == 3:
time_deriv_form -= inner(mix_test, tau*cross(Z, u)/domain.dt)*dxqp
transport_form -= inner(
mix_test, tau*cross(curl(Z*u), u)
)*dxqp
if transpose_commutator:
transport_form -= inner(
mix_test,
tau*cross(curl(u_dot_nabla_F), u)
)*dxqp

residual = (
time_derivative(time_deriv_form)
+ transport(
transport_form, TransportEquationType.vector_invariant
)
)
residual = transporting_velocity(residual, u)

self.residual = subject(residual, self.X)

self.x_in = Function(self.fs)
self.Z_in = Function(V_vort)
self.x_out = Function(self.fs)

vort_test = TestFunction(V_vort)
vort_trial = TrialFunction(V_vort)

F_in, _ = split(self.x_in)

eqn = (
inner(vort_trial, vort_test)*dx
+ inner(domain.perp(grad(vort_test)), F_in)*dx
+ vort_test*inner(n, domain.perp(F_in))*self.ds
)
problem = LinearVariationalProblem(
lhs(eqn), rhs(eqn), self.Z_in, constant_jacobian=True
)
self.solver = LinearVariationalSolver(problem)

def pre_apply(self, x_in):
self.x_in.subfunctions[0].assign(x_in)

def post_apply(self, x_out):
x_out.assign(self.x_out.subfunctions[0])

def update(self, x_in_mixed):
self.x_in.subfunctions[0].assign(x_in_mixed.subfunctions[0])
logger.info('Vorticity solve')
self.solver.solve()
self.x_in.subfunctions[1].assign(self.Z_in)
35 changes: 21 additions & 14 deletions gusto/time_discretisation/explicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class ExplicitRungeKutta(ExplicitTimeDiscretisation):
def __init__(self, domain, butcher_matrix, field_name=None,
fixed_subcycles=None, subcycle_by_courant=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, limiter=None, options=None):
solver_parameters=None, limiter=None, options=None,
augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand Down Expand Up @@ -123,7 +124,8 @@ def __init__(self, domain, butcher_matrix, field_name=None,
fixed_subcycles=fixed_subcycles,
subcycle_by_courant=subcycle_by_courant,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
limiter=limiter, options=options,
augmentation=augmentation)
self.butcher_matrix = butcher_matrix
self.nbutcher = int(np.shape(self.butcher_matrix)[0])
self.rk_formulation = rk_formulation
Expand Down Expand Up @@ -210,7 +212,7 @@ def lhs(self):
if self.rk_formulation == RungeKuttaFormulation.increment:
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x_out, self.idx),
map_if_true=replace_subject(self.x_out, old_idx=self.idx, new_idx=self.new_idx),
map_if_false=drop)

return l.form
Expand All @@ -220,7 +222,7 @@ def lhs(self):
for stage in range(self.nStages):
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.field_i[stage+1], self.idx),
map_if_true=replace_subject(self.field_i[stage+1], old_idx=self.idx, new_idx=self.new_idx),
map_if_false=drop)
lhs_list.append(l)

Expand All @@ -229,7 +231,7 @@ def lhs(self):
if self.rk_formulation == RungeKuttaFormulation.linear:
l = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x1, self.idx),
map_if_true=replace_subject(self.x1, old_idx=self.idx, new_idx=self.new_idx),
map_if_false=drop)

return l.form
Expand All @@ -246,7 +248,7 @@ def rhs(self):
if self.rk_formulation == RungeKuttaFormulation.increment:
r = self.residual.label_map(
all_terms,
map_if_true=replace_subject(self.x1, old_idx=self.idx))
map_if_true=replace_subject(self.x1, old_idx=self.idx, new_idx=self.new_idx))

r = r.label_map(
lambda t: t.has_label(time_derivative),
Expand All @@ -273,7 +275,7 @@ def rhs(self):
for stage in range(self.nStages):
r = self.residual.label_map(
all_terms,
map_if_true=replace_subject(self.field_i[0], old_idx=self.idx))
map_if_true=replace_subject(self.field_i[0], old_idx=self.idx, new_idx=self.new_idx))

r = r.label_map(
lambda t: t.has_label(time_derivative),
Expand All @@ -284,7 +286,7 @@ def rhs(self):
r_i = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=drop,
map_if_false=replace_subject(self.field_i[i], old_idx=self.idx)
map_if_false=replace_subject(self.field_i[i], old_idx=self.idx, new_idx=self.new_idx)
)

r -= self.butcher_matrix[stage, i]*self.dt*r_i
Expand All @@ -297,8 +299,8 @@ def rhs(self):

r = self.residual.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x0, old_idx=self.idx),
map_if_false=replace_subject(self.field_rhs, old_idx=self.idx)
map_if_true=replace_subject(self.x0, old_idx=self.idx, new_idx=self.new_idx),
map_if_false=replace_subject(self.field_rhs, old_idx=self.idx, new_idx=self.new_idx)
)
r = r.label_map(
lambda t: t.has_label(time_derivative),
Expand All @@ -325,8 +327,8 @@ def rhs(self):
)
r_all_but_last = r_all_but_last.label_map(
lambda t: t.has_label(time_derivative),
map_if_true=replace_subject(self.x0, old_idx=self.idx),
map_if_false=replace_subject(self.field_rhs, old_idx=self.idx)
map_if_true=replace_subject(self.x0, old_idx=self.idx, new_idx=self.new_idx),
map_if_false=replace_subject(self.field_rhs, old_idx=self.idx, new_idx=self.new_idx)
)
r_all_but_last = r_all_but_last.label_map(
lambda t: t.has_label(time_derivative),
Expand Down Expand Up @@ -468,6 +470,9 @@ def apply_cycle(self, x_out, x_in):
x_out (:class:`Function`): the output field to be computed.
"""

if self.augmentation is not None:
self.augmentation.update(x_in)

# TODO: is this limiter application necessary?
if self.limiter is not None:
self.limiter.apply(x_in)
Expand Down Expand Up @@ -546,7 +551,8 @@ def __init__(
self, domain, field_name=None,
fixed_subcycles=None, subcycle_by_courant=None,
rk_formulation=RungeKuttaFormulation.increment,
solver_parameters=None, limiter=None, options=None
solver_parameters=None, limiter=None, options=None,
augmentation=None
):
"""
Args:
Expand Down Expand Up @@ -586,7 +592,8 @@ def __init__(
subcycle_by_courant=subcycle_by_courant,
rk_formulation=rk_formulation,
solver_parameters=solver_parameters,
limiter=limiter, options=options)
limiter=limiter, options=options,
augmentation=augmentation)


class RK4(ExplicitRungeKutta):
Expand Down
8 changes: 4 additions & 4 deletions gusto/time_discretisation/implicit_runge_kutta.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ImplicitRungeKutta(TimeDiscretisation):
# ---------------------------------------------------------------------------

def __init__(self, domain, butcher_matrix, field_name=None,
solver_parameters=None, options=None,):
solver_parameters=None, options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -75,7 +75,7 @@ def __init__(self, domain, butcher_matrix, field_name=None,
"""
super().__init__(domain, field_name=field_name,
solver_parameters=solver_parameters,
options=options)
options=options, augmentation=augmentation)
self.butcher_matrix = butcher_matrix
self.nStages = int(np.shape(self.butcher_matrix)[1])

Expand Down Expand Up @@ -165,7 +165,7 @@ class ImplicitMidpoint(ImplicitRungeKutta):
y^(n+1) = y^n + dt*k0 \n
"""
def __init__(self, domain, field_name=None, solver_parameters=None,
options=None):
options=None, augmentation=None):
"""
Args:
domain (:class:`Domain`): the model's domain object, containing the
Expand All @@ -182,7 +182,7 @@ def __init__(self, domain, field_name=None, solver_parameters=None,
butcher_matrix = np.array([[0.5], [1.]])
super().__init__(domain, butcher_matrix, field_name,
solver_parameters=solver_parameters,
options=options)
options=options, augmentation=augmentation)


class QinZhang(ImplicitRungeKutta):
Expand Down
Loading

0 comments on commit 50ef06b

Please sign in to comment.