Skip to content

Commit

Permalink
changes to run for implicit flat fields
Browse files Browse the repository at this point in the history
  • Loading branch information
kmdeck committed Jul 26, 2024
1 parent 95fd5b2 commit cec368e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
} where {
V <: AbstractData{<:BandMatrixRow},
S <: Union{
Spaces.SpectralElementSpace2D,
Spaces.PointSpace,
Spaces.FiniteDifferenceSpace,
Spaces.ExtrudedFiniteDifferenceSpace,
Operators.PlaceholderSpace, # so that this can exist inside cuda kernels
Expand Down
14 changes: 13 additions & 1 deletion src/MatrixFields/matrix_shape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ struct Square <: AbstractMatrixShape end
struct FaceToCenter <: AbstractMatrixShape end
struct CenterToFace <: AbstractMatrixShape end

matrix_shape(matrix_field) = matrix_shape(
matrix_field,
axes(matrix_field),
)


"""
matrix_shape(matrix_field, [matrix_space])
Expand All @@ -11,11 +17,17 @@ whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and
whether `matrix_space` is on cell centers or cell faces. By default,
`matrix_space` is set to `axes(matrix_field)`.
"""
matrix_shape(matrix_field, matrix_space = axes(matrix_field)) = _matrix_shape(
matrix_shape(matrix_field, matrix_space) = _matrix_shape(
eltype(outer_diagonals(eltype(matrix_field))),
matrix_space.staggering,
)

function matrix_shape(matrix_field, matrix_space::Spaces.AbstractSpectralElementSpace)
@assert eltype(matrix_field) <: DiagonalMatrixRow
Square()
end


_matrix_shape(::Type{Int}, _) = Square()
_matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellCenter) = FaceToCenter()
_matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellFace) = CenterToFace()
Expand Down
6 changes: 3 additions & 3 deletions src/MatrixFields/single_field_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ end

function _single_field_solve_col!(
::ClimaComms.AbstractCPUDevice,
cache::Fields.ColumnField,
x::Fields.ColumnField,
cache,
x,
A,
b::Fields.ColumnField,
b,
)
if A isa Fields.ColumnField
band_matrix_solve!(
Expand Down

0 comments on commit cec368e

Please sign in to comment.