Skip to content

Commit

Permalink
implicit flat fields
Browse files Browse the repository at this point in the history
  • Loading branch information
kmdeck committed Aug 5, 2024
1 parent 8fd60da commit c5aeb16
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 9 deletions.
3 changes: 1 addition & 2 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
} where {
V <: AbstractData{<:BandMatrixRow},
S <: Union{
Spaces.FiniteDifferenceSpace,
Spaces.ExtrudedFiniteDifferenceSpace,
Spaces.AbstractSpace,
Operators.PlaceholderSpace, # so that this can exist inside cuda kernels
},
}
Expand Down
23 changes: 19 additions & 4 deletions src/MatrixFields/matrix_shape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,34 @@ struct Square <: AbstractMatrixShape end
struct FaceToCenter <: AbstractMatrixShape end
struct CenterToFace <: AbstractMatrixShape end

matrix_shape(matrix_field) = matrix_shape(matrix_field, axes(matrix_field))

"""
matrix_shape(matrix_field, [matrix_space])
Returns either `Square()`, `FaceToCenter()`, or `CenterToFace()`, depending on
whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and
whether `matrix_space` is on cell centers or cell faces. By default,
Returns the matrix shape for a matrix field defined on the `matrix_space`. By default,
`matrix_space` is set to `axes(matrix_field)`.
When the matrix_space is a finite difference space (extruded or otherwise): the shape is
either `Square()`, `FaceToCenter()`, or `CenterToFace()`, depending on
whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and
whether `matrix_space` is on cell centers or cell faces.
When the matrix_space is a spectral element or point space: only a Square() shape is supported.
"""
matrix_shape(matrix_field, matrix_space = axes(matrix_field)) = _matrix_shape(
matrix_shape(matrix_field, matrix_space) = _matrix_shape(
eltype(outer_diagonals(eltype(matrix_field))),
matrix_space.staggering,
)

function matrix_shape(
matrix_field,
matrix_space::Union{Spaces.AbstractSpectralElementSpace, Spaces.PointSpace},
)
@assert eltype(matrix_field) <: DiagonalMatrixRow
Square()
end

_matrix_shape(::Type{Int}, _) = Square()
_matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellCenter) = FaceToCenter()
_matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellFace) = CenterToFace()
Expand Down
6 changes: 3 additions & 3 deletions src/MatrixFields/single_field_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ end

function _single_field_solve_col!(
::ClimaComms.AbstractCPUDevice,
cache::Fields.ColumnField,
x::Fields.ColumnField,
cache,
x,
A,
b::Fields.ColumnField,
b,
)
if A isa Fields.ColumnField
band_matrix_solve!(
Expand Down
66 changes: 66 additions & 0 deletions test/MatrixFields/flat_spaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
include("matrix_field_test_utils.jl")
import ClimaCore.MatrixFields: @name,
function point_space(::Type{FT}) where {FT}
comms_ctx = ClimaComms.SingletonCommsContext(comms_device)
coord = Geometry.ZPoint(FT(0))
return Spaces.PointSpace(comms_ctx, coord)
end


function se2d_space(::Type{FT}) where {FT}
comms_ctx = ClimaComms.SingletonCommsContext(comms_device)
domain_x = Domains.IntervalDomain(
Geometry.XPoint(FT(0)),
Geometry.XPoint(FT(1));
periodic = true,
)
domain_y = Domains.IntervalDomain(
Geometry.YPoint(FT(0)),
Geometry.YPoint(FT(1));
periodic = true,
)
plane = Domains.RectangleDomain(domain_x, domain_y)

mesh = Meshes.RectilinearMesh(plane, 1, 1)
grid_topology = Topologies.Topology2D(comms_ctx, mesh)
quad = Spaces.Quadratures.GLL{1 + 1}()
space = Spaces.SpectralElementSpace2D(grid_topology, quad)
return space
end


get_jac_type(
space::Union{Spaces.PointSpace, Spaces.SpectralElementSpace2D},
FT,
) = MatrixFields.DiagonalMatrixRow{FT}

function get_j_field(space, FT)
jac_type = get_jac_type(space, FT)
field = Fields.Field(jac_type, space)
fill!(parent(field), 1)
return field
end

implicit_vars = (@name(tmp.v1), @name(tmp.v2))
for FT in (Float32, Float64)
ps = point_space(FT)
ses = se2d_space(FT)
v1 = Fields.zeros(ps)
v2 = Fields.zeros(ses)
Y = Fields.FieldVector(; :tmp => (; :v1 => v1, :v2 => v2))
implicit_blocks = MatrixFields.unrolled_map(
var ->
(var, var) =>
get_j_field(axes(MatrixFields.get_field(Y, var)), FT),
implicit_vars,
)
matrix = MatrixFields.FieldMatrix(implicit_blocks...)
alg = MatrixFields.BlockDiagonalSolve()
solver = MatrixFields.FieldMatrixSolver(alg, matrix, Y)
b1 = random_field(FT, ps)
b2 = random_field(FT, ses)
x = similar(Y)
b = Fields.FieldVector(; :tmp => (; :v1 => b1, :v2 => b2))
MatrixFields.field_matrix_solve!(solver, x, matrix, b)
@test x == b
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ UnitTest("MatrixFields - non-scalar broadcasting (1)" ,"MatrixFields/matrix_fiel
UnitTest("MatrixFields - non-scalar broadcasting (2)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_2.jl"),
UnitTest("MatrixFields - non-scalar broadcasting (3)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl"),
UnitTest("MatrixFields - non-scalar broadcasting (4)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl"),
UnitTest("MatrixFields - flat spaces" ,"MatrixFields/flat_spaces.jl"),

# UnitTest("MatrixFields - matrix field broadcast" ,"MatrixFields/matrix_field_broadcasting.jl"), # too long
# UnitTest("MatrixFields - operator matrices" ,"MatrixFields/operator_matrices.jl"), # too long
# UnitTest("MatrixFields - field matrix solvers" ,"MatrixFields/field_matrix_solvers.jl"), # too long
Expand Down

0 comments on commit c5aeb16

Please sign in to comment.