From c5aeb16f9810e81afbae48e4534dc553d9a6582b Mon Sep 17 00:00:00 2001 From: kmdeck Date: Thu, 11 Jul 2024 17:41:19 -0700 Subject: [PATCH] implicit flat fields --- src/MatrixFields/MatrixFields.jl | 3 +- src/MatrixFields/matrix_shape.jl | 23 +++++++-- src/MatrixFields/single_field_solver.jl | 6 +-- test/MatrixFields/flat_spaces.jl | 66 +++++++++++++++++++++++++ test/runtests.jl | 2 + 5 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 test/MatrixFields/flat_spaces.jl diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 6d2630ec6e..64f06402a2 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -88,8 +88,7 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{ } where { V <: AbstractData{<:BandMatrixRow}, S <: Union{ - Spaces.FiniteDifferenceSpace, - Spaces.ExtrudedFiniteDifferenceSpace, + Spaces.AbstractSpace, Operators.PlaceholderSpace, # so that this can exist inside cuda kernels }, } diff --git a/src/MatrixFields/matrix_shape.jl b/src/MatrixFields/matrix_shape.jl index 6ebe32a1c8..020d9424f1 100644 --- a/src/MatrixFields/matrix_shape.jl +++ b/src/MatrixFields/matrix_shape.jl @@ -3,19 +3,34 @@ struct Square <: AbstractMatrixShape end struct FaceToCenter <: AbstractMatrixShape end struct CenterToFace <: AbstractMatrixShape end +matrix_shape(matrix_field) = matrix_shape(matrix_field, axes(matrix_field)) + """ matrix_shape(matrix_field, [matrix_space]) -Returns either `Square()`, `FaceToCenter()`, or `CenterToFace()`, depending on -whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and -whether `matrix_space` is on cell centers or cell faces. By default, +Returns the matrix shape for a matrix field defined on the `matrix_space`. By default, `matrix_space` is set to `axes(matrix_field)`. + +When the matrix_space is a finite difference space (extruded or otherwise): the shape is +either `Square()`, `FaceToCenter()`, or `CenterToFace()`, depending on +whether the diagonal indices of `matrix_field` are `Int`s or `PlusHalf`s and +whether `matrix_space` is on cell centers or cell faces. + +When the matrix_space is a spectral element or point space: only a Square() shape is supported. """ -matrix_shape(matrix_field, matrix_space = axes(matrix_field)) = _matrix_shape( +matrix_shape(matrix_field, matrix_space) = _matrix_shape( eltype(outer_diagonals(eltype(matrix_field))), matrix_space.staggering, ) +function matrix_shape( + matrix_field, + matrix_space::Union{Spaces.AbstractSpectralElementSpace, Spaces.PointSpace}, +) + @assert eltype(matrix_field) <: DiagonalMatrixRow + Square() +end + _matrix_shape(::Type{Int}, _) = Square() _matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellCenter) = FaceToCenter() _matrix_shape(::Type{PlusHalf{Int}}, ::Spaces.CellFace) = CenterToFace() diff --git a/src/MatrixFields/single_field_solver.jl b/src/MatrixFields/single_field_solver.jl index f3712f28d5..e785c7cee0 100644 --- a/src/MatrixFields/single_field_solver.jl +++ b/src/MatrixFields/single_field_solver.jl @@ -98,10 +98,10 @@ end function _single_field_solve_col!( ::ClimaComms.AbstractCPUDevice, - cache::Fields.ColumnField, - x::Fields.ColumnField, + cache, + x, A, - b::Fields.ColumnField, + b, ) if A isa Fields.ColumnField band_matrix_solve!( diff --git a/test/MatrixFields/flat_spaces.jl b/test/MatrixFields/flat_spaces.jl new file mode 100644 index 0000000000..c4ad1a730d --- /dev/null +++ b/test/MatrixFields/flat_spaces.jl @@ -0,0 +1,66 @@ +include("matrix_field_test_utils.jl") +import ClimaCore.MatrixFields: @name, ⋅ +function point_space(::Type{FT}) where {FT} + comms_ctx = ClimaComms.SingletonCommsContext(comms_device) + coord = Geometry.ZPoint(FT(0)) + return Spaces.PointSpace(comms_ctx, coord) +end + + +function se2d_space(::Type{FT}) where {FT} + comms_ctx = ClimaComms.SingletonCommsContext(comms_device) + domain_x = Domains.IntervalDomain( + Geometry.XPoint(FT(0)), + Geometry.XPoint(FT(1)); + periodic = true, + ) + domain_y = Domains.IntervalDomain( + Geometry.YPoint(FT(0)), + Geometry.YPoint(FT(1)); + periodic = true, + ) + plane = Domains.RectangleDomain(domain_x, domain_y) + + mesh = Meshes.RectilinearMesh(plane, 1, 1) + grid_topology = Topologies.Topology2D(comms_ctx, mesh) + quad = Spaces.Quadratures.GLL{1 + 1}() + space = Spaces.SpectralElementSpace2D(grid_topology, quad) + return space +end + + +get_jac_type( + space::Union{Spaces.PointSpace, Spaces.SpectralElementSpace2D}, + FT, +) = MatrixFields.DiagonalMatrixRow{FT} + +function get_j_field(space, FT) + jac_type = get_jac_type(space, FT) + field = Fields.Field(jac_type, space) + fill!(parent(field), 1) + return field +end + +implicit_vars = (@name(tmp.v1), @name(tmp.v2)) +for FT in (Float32, Float64) + ps = point_space(FT) + ses = se2d_space(FT) + v1 = Fields.zeros(ps) + v2 = Fields.zeros(ses) + Y = Fields.FieldVector(; :tmp => (; :v1 => v1, :v2 => v2)) + implicit_blocks = MatrixFields.unrolled_map( + var -> + (var, var) => + get_j_field(axes(MatrixFields.get_field(Y, var)), FT), + implicit_vars, + ) + matrix = MatrixFields.FieldMatrix(implicit_blocks...) + alg = MatrixFields.BlockDiagonalSolve() + solver = MatrixFields.FieldMatrixSolver(alg, matrix, Y) + b1 = random_field(FT, ps) + b2 = random_field(FT, ses) + x = similar(Y) + b = Fields.FieldVector(; :tmp => (; :v1 => b1, :v2 => b2)) + MatrixFields.field_matrix_solve!(solver, x, matrix, b) + @test x == b +end diff --git a/test/runtests.jl b/test/runtests.jl index f794aa12dd..851558f337 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,6 +81,8 @@ UnitTest("MatrixFields - non-scalar broadcasting (1)" ,"MatrixFields/matrix_fiel UnitTest("MatrixFields - non-scalar broadcasting (2)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_2.jl"), UnitTest("MatrixFields - non-scalar broadcasting (3)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl"), UnitTest("MatrixFields - non-scalar broadcasting (4)" ,"MatrixFields/matrix_fields_broadcasting/test_non_scalar_4.jl"), +UnitTest("MatrixFields - flat spaces" ,"MatrixFields/flat_spaces.jl"), + # UnitTest("MatrixFields - matrix field broadcast" ,"MatrixFields/matrix_field_broadcasting.jl"), # too long # UnitTest("MatrixFields - operator matrices" ,"MatrixFields/operator_matrices.jl"), # too long # UnitTest("MatrixFields - field matrix solvers" ,"MatrixFields/field_matrix_solvers.jl"), # too long