From cec368ebb0f1666ec433391d316c2b99c4d4d292 Mon Sep 17 00:00:00 2001 From: kmdeck Date: Thu, 11 Jul 2024 17:41:19 -0700 Subject: [PATCH] changes to run for implicit flat fields --- src/MatrixFields/MatrixFields.jl | 2 ++ src/MatrixFields/matrix_shape.jl | 14 +++++++++++++- src/MatrixFields/single_field_solver.jl | 6 +++--- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 6d2630ec6e..7c0fc944e3 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -88,6 +88,8 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{ } where { V <: AbstractData{<:BandMatrixRow}, S <: Union{ + Spaces.SpectralElementSpace2D, + Spaces.PointSpace, Spaces.FiniteDifferenceSpace, Spaces.ExtrudedFiniteDifferenceSpace, Operators.PlaceholderSpace, # so that this can exist inside cuda kernels diff --git a/src/MatrixFields/matrix_shape.jl b/src/MatrixFields/matrix_shape.jl index 6ebe32a1c8..21434fc83c 100644 --- a/src/MatrixFields/matrix_shape.jl +++ b/src/MatrixFields/matrix_shape.jl @@ -3,6 +3,12 @@ struct Square <: AbstractMatrixShape end struct FaceToCenter <: AbstractMatrixShape end struct CenterToFace <: AbstractMatrixShape end +matrix_shape(matrix_field) = matrix_shape( + matrix_field, + axes(matrix_field), +) + + """ matrix_shape(matrix_field, [matrix_space]) @@ -11,11 +17,17 @@ whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and whether `matrix_space` is on cell centers or cell faces. By default, `matrix_space` is set to `axes(matrix_field)`. """ -matrix_shape(matrix_field, matrix_space = axes(matrix_field)) = _matrix_shape( +matrix_shape(matrix_field, matrix_space) = _matrix_shape( eltype(outer_diagonals(eltype(matrix_field))), matrix_space.staggering, ) +function matrix_shape(matrix_field, matrix_space::Spaces.AbstractSpectralElementSpace) + @assert eltype(matrix_field) <: DiagonalMatrixRow + Square() +end + + _matrix_shape(::Type{Int}, _) = Square() _matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellCenter) = FaceToCenter() _matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellFace) = CenterToFace() diff --git a/src/MatrixFields/single_field_solver.jl b/src/MatrixFields/single_field_solver.jl index f3712f28d5..e785c7cee0 100644 --- a/src/MatrixFields/single_field_solver.jl +++ b/src/MatrixFields/single_field_solver.jl @@ -98,10 +98,10 @@ end function _single_field_solve_col!( ::ClimaComms.AbstractCPUDevice, - cache::Fields.ColumnField, - x::Fields.ColumnField, + cache, + x, A, - b::Fields.ColumnField, + b, ) if A isa Fields.ColumnField band_matrix_solve!(