Skip to content

Commit

Permalink
Merge #1436
Browse files Browse the repository at this point in the history
1436: Add FieldMatrix and linear solvers r=dennisYatunin a=dennisYatunin

## Purpose

Third PR of #1230. Adds an interface for specifying block matrices of matrix fields and solving linear systems with these matrices. This will replace what is currently in `ClimaAtmos.jl/src/prognostic_equations/implicit/schur_complement_W.jl`, generalizing it for implicit diffusion and implicit EDMF.

## Content

### Main Changes

- Add the `FieldName` struct, which is a singleton type that represents a chain of `getproperty` calls
  - Add the ``@name`` macro for constructing `FieldName`s, which checks whether its input expression is a syntactically valid chain of `getproperty` calls before calling the default constructor.
  - A `name` can be used to access a property or sub-property of an object `x` by calling `get_field(x, name)`.
  - An `internal_name` can be appended onto another `name` in order to access a property or sub-property of `get_field(x, name)`.
- Add the `FieldNameDict` struct, which maps each key in a set of `FieldVectorKeys` or `FieldMatrixKeys` (see below) to a `Field` or some other object.
  - There are currently four subtypes of `FieldNameDict`:
    - `FieldMatrix` (the only user-facing subtype), which maps `NTuple{2, FieldName}`s to `ColumnwiseBandMatrixField`s or multiples of `LinearAlgebra.I`
    - `FieldVectorView`, which maps `FieldName`s to `Field`s; this is used to wrap a `FieldVector` so that it can be used in conjunction with a `FieldMatrix`
    - `FieldVectorViewBroadcasted` and `FieldMatrixBroadcasted`, each of which can store unevaluated `Base.AbstractBroadcasted` objects, in addition to what `FieldVectorView` and `FieldMatrix` can already store
  - Supports standard `AbstractDict` functions like `keys` and `pairs`.
  - An individual block of a `FieldNameDict` can be accessed by calling `dict[key]`, and a range of blocks can be accessed by calling `dict[set]`, where `set` is a `FieldNameSet`.
  - Given a `FieldMatrix` `A`, a similar matrix that only contains identity matrix blocks can be constructed with `one(A)`.
  - `FieldNameDict`s can be used in broadcast expressions, which support the following operations:
    - `+`, `-`, or `*`, where each input is either a `FieldNameDict` or a `FieldVector`
    - `inv`, where the input is a diagonal `FieldMatrix`
  - The new methods for `Base.Broadcast.broadcasted` construct chains of `Field` broadcast expressions from `FieldNameDict` broadcast expressions on the fly, somewhat similarly to how broadcasting works for ClimaCore operators.
- Add the `FieldMatrixSolver` struct, which solves an equation of the form `A * x = b`, where `A` is a `FieldMatrix` and where `x` and `b` are `FieldVector`s.
  - Add the `field_matrix_solve!` function, which works just like `ldiv!(x, A, b)`, except that it also takes a `FieldMatrixSolver` as its first argument.
  - Add four `FieldMatrixSolverAlgorithm`s, which can be nested inside of each other to build up more specialized algorithms:
    - `BlockDiagonalSolve`, which runs a "single field solver" for each block of the block diagonal matrix `A`; the single field solver can handle the four types of blocks:
      - Multiples of `LinearAlgebra.I`
      - Diagonal `ColumnwiseBandMatrixField`s
      - Tri-diagonal `ColumnwiseBandMatrixField`s (implementation of the Thomas algorithm)
      - Penta-diagonal `ColumnwiseBandMatrixField`s (implementation of the PTRANS-I algorithm)
    - `BlockLowerTriangularSolve`, which uses forward substitution to solve the equation for a block lower triangular matrix `A`
    - `SchurComplementSolve`, which generalizes what is currently in ClimaAtmos's `schur_complement_W.jl` file to any block matrix `A` with a diagonal block in the top-left corner
    - `ApproximateFactorizationSolve`, which lets us use "operator splitting" to approximately solve the equation for a diagonally dominant block matrix `A`
- Add documentation for how to specify a `FieldMatrix` and use it in a linear solver, along with internal documentation for the new `FieldName`-based infrastructure.
- Add unit tests for correctness, type stability, and allocations, and run them on both CPUs and GPUs through CI.
  - Test each single field solver on both a cell-center and a cell-face field.
  - Test each `FieldMatrixSolverAlgorithm` on block diagonal, block lower triangular, and block dense matrices.
  - Test solvers with identical structures to what we will use in ClimaAtmos for the following examples:
    - Dry dycore with implicit acoustic waves
    - Dry dycore with implicit acoustic waves and diffusion
    - Dry dycore + prognostic EDMF with implicit acoustic waves and SGS fluxes
    - Moist dycore + prognostic EDMF + tracers with implicit acoustic waves and SGS fluxes

### Internal Chagnes

- Add a collection of "unrolled functions", whose return values can be inferred during compilation if their input values are all singleton types.
  - These are all implemented as combinations of `unrolled_zip`, `unrolled_map`, and `unrolled_foldl`.
  - Several of these need to have their recursion limits disabled for the unit tests to be type stable.
- Add the `FieldNameTree` struct, which stores every `FieldName` that can be used to access `x` with `get_field(x, name)`.
  - A `name` can be checked for validity by calling `has_subtree_at_name(tree, name)`.
  - The children of `name` (the `FieldName`s that can be used to access the properties of `get_field(x, name)`) can be obtained by calling `child_names(name, tree)`.
- Add the `FieldNameSet` struct, which stores a set of `FieldVectorKeys` (each of which is a `FieldName`) or a set of `FieldMatrixKeys` (each of which is an `NTuple{2, FieldName}`).
  - Roughly equivalent to the built-in `KeySet` for `AbstractDict`s, but specialized for `FieldNameDict`s.
  - Supports standard `AbstractSet` functions like `union` and `setdiff`, as well as custom functions like `set_complement` and `matrix_product_keys`.
  - Handles overlaps between `FieldName`s (that is, situations where one property of `x` lies inside another property of `x`) by storing a `FieldNameTree` that contains all available `FieldName`s.
- Disable the recursion limits for several functions used to manipulate `FieldName`s, `FieldNameTree`s, and `FieldNameSet`s, as this is necessary in order for the unit tests to be type stable.
- Remove the methods for `RecursiveApply.rmul` that specialize on `Number`, which is also necessary in order for the unit tests to be type stable.
  - These methods are no longer required, now that #1454 has been merged in.
- Add support for calling `inv` on `BandMatrixRow`s.
- Qualify the use of `CUDA.`@allowscalar`,` per Charlie's suggestion.
- Fix some type instabilities in `matrix_field_test_utils.jl`.
- Remove an unused variable name in `operator_matrices.jl`.


Co-authored-by: Dennis Yatunin <[email protected]>
  • Loading branch information
bors[bot] and dennisYatunin authored Oct 11, 2023
2 parents 88c3f78 + e96451e commit 10f2ae3
Show file tree
Hide file tree
Showing 17 changed files with 2,803 additions and 29 deletions.
17 changes: 17 additions & 0 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,23 @@ steps:
slurm_gpus: 1
slurm_mem: 40GB

- label: "Unit: field names"
key: unit_field_names
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/field_names.jl"

- label: "Unit: field matrix solvers (CPU)"
key: unit_field_matrix_solvers_cpu
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/field_matrix_solvers.jl"
agents:
slurm_mem: 40GB

- label: "Unit: field matrix solvers (GPU)"
key: unit_field_matrix_solvers_gpu
command: "julia --color=yes --project=test test/MatrixFields/field_matrix_solvers.jl"
agents:
slurm_gpus: 1
slurm_mem: 40GB

- group: "Unit: Hypsography"
steps:

Expand Down
18 changes: 18 additions & 0 deletions docs/src/matrix_fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ MultiplyColumnwiseBandMatrixField
operator_matrix
```

# Linear Solvers

```@docs
FieldMatrixSolverAlgorithm
FieldMatrixSolver
field_matrix_solve!
BlockDiagonalSolve
BlockLowerTriangularSolve
SchurComplementSolve
ApproximateFactorizationSolve
```

## Internals

```@docs
Expand All @@ -39,6 +51,12 @@ matrix_shape
column_axes
AbstractLazyOperator
replace_lazy_operator
FieldName
@name
FieldNameTree
FieldNameSet
FieldNameDict
field_vector_view
```

## Utilities
Expand Down
52 changes: 41 additions & 11 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
MatrixFields
This module adds support for defining and manipulating `Field`s that represent
matrices. Specifically, it specifies the `BandMatrixRow` type, which can be used
matrices. Specifically, it adds the `BandMatrixRow` type, which can be used
to store the entries of a band matrix. A `Field` of `BandMatrixRow`s on a
`FiniteDifferenceSpace` can be interpreted as a band matrix by vertically
concatenating the `BandMatrixRow`s. Similarly, a `Field` of `BandMatrixRow`s on
Expand All @@ -19,46 +19,60 @@ for them:
- Integration with `RecursiveApply`, e.g., the entries of `matrix_field` can be
`Tuple`s or `NamedTuple`s instead of single values, which allows
`matrix_field` to represent multiple band matrices at the same time
- Integration with `Operators`, e.g., the `matrix_field` that gets applied to
the argument of any `FiniteDifferenceOperator` `op` can be obtained using
the `FiniteDifferenceOperator` `operator_matrix(op)`
- Conversions to native array types, e.g., `field2arrays(matrix_field)` can
convert each column of `matrix_field` into a `BandedMatrix` from
`BandedMatrices.jl`
- Custom printing, e.g., `matrix_field` gets displayed as a `BandedMatrix`,
specifically, as the `BandedMatrix` that corresponds to its first column
This module also adds support for defining and manipulating sparse block
matrices of `Field`s. Specifically, it adds the `FieldMatrix` type, which is a
dictionary that maps pairs of `FieldName`s to `ColumnwiseBandMatrixField`s or
multiples of `LinearAlgebra.I`. This comes with the following functionality:
- Addition and subtraction, e.g., `@. field_matrix1 + field_matrix2`
- Matrix-vector multiplication, e.g., `@. field_matrix * field_vector`
- Matrix-matrix multiplication, e.g., `@. field_matrix1 * field_matrix2`
- Integration with `RecursiveApply`, e.g., the entries of `field_matrix` can be
specified either as matrix `Field`s of `Tuple`s or `NamedTuple`s, or as
separate matrix `Field`s of single values
- The ability to solve linear equations using `FieldMatrixSolver`, which is a
generalization of `ldiv!` that is designed to optimize solver performance
"""
module MatrixFields

import CUDA: @allowscalar
import LinearAlgebra: UniformScaling, Adjoint, AdjointAbsVec
import CUDA
import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec, mul!, inv
import StaticArrays: SMatrix, SVector
import BandedMatrices: BandedMatrix, band, _BandedMatrix
import ClimaComms

import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
import ..RecursiveApply: , ,
import ..DataLayouts: AbstractData
import ..Geometry
import ..Spaces
import ..Fields
import ..Operators

export
export DiagonalMatrixRow,
BidiagonalMatrixRow,
TridiagonalMatrixRow,
QuaddiagonalMatrixRow,
PentadiagonalMatrixRow
export FieldVectorKeys, FieldVectorView, FieldVectorViewBroadcasted
export FieldMatrixKeys, FieldMatrix, FieldMatrixBroadcasted
export , FieldMatrixSolver, field_matrix_solve!

# Types that are teated as single values when using matrix fields.
const SingleValue =
Union{Number, Geometry.AxisTensor, Geometry.AdjointAxisTensor}

include("band_matrix_row.jl")
include("rmul_with_projection.jl")
include("matrix_shape.jl")
include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")

const ColumnwiseBandMatrixField{V, S} = Fields.Field{
V,
Expand All @@ -71,6 +85,19 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
},
}

include("rmul_with_projection.jl")
include("matrix_shape.jl")
include("matrix_multiplication.jl")
include("lazy_operators.jl")
include("operator_matrices.jl")
include("field2arrays.jl")
include("unrolled_functions.jl")
include("field_name.jl")
include("field_name_set.jl")
include("field_name_dict.jl")
include("field_matrix_solver.jl")
include("single_field_solver.jl")

function Base.show(io::IO, field::ColumnwiseBandMatrixField)
print(io, eltype(field), "-valued Field")
if eltype(eltype(field)) <: Number
Expand All @@ -82,7 +109,10 @@ function Base.show(io::IO, field::ColumnwiseBandMatrixField)
end
column_field = Fields.column(field, 1, 1, 1)
io = IOContext(io, :compact => true, :limit => true)
@allowscalar Base.print_array(io, column_field2array_view(column_field))
CUDA.@allowscalar Base.print_array(
io,
column_field2array_view(column_field),
)
else
# When a BandedMatrix with non-number entries is printed, it currently
# either prints in an illegible format (e.g., if it has AxisTensor or
Expand Down
6 changes: 6 additions & 0 deletions src/MatrixFields/band_matrix_row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,9 @@ Base.:*(value::SingleValue, row::BandMatrixRow) =

Base.:/(row::BandMatrixRow, value::Number) =
map(entry -> rdiv(entry, value), row)

inv(row::DiagonalMatrixRow) = DiagonalMatrixRow(inv(row[0]))
inv(::BandMatrixRow{ld, bw}) where {ld, bw} = error(
"The inverse of a matrix with $bw diagonals is (usually) a dense matrix, \
so it cannot be represented using BandMatrixRows",
)
4 changes: 2 additions & 2 deletions src/MatrixFields/field2arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ function column_field2array(field::Fields.FiniteDifferenceField)
last_row = matrix_d < n_cols - n_rows ? n_rows : n_cols - matrix_d

diagonal_data_view = view(diagonal_data, first_row:last_row)
@allowscalar copyto!(matrix_diagonal, diagonal_data_view)
CUDA.@allowscalar copyto!(matrix_diagonal, diagonal_data_view)
end
return matrix
else # field represents a vector
return @allowscalar Array(column_field2array_view(field))
return CUDA.@allowscalar Array(column_field2array_view(field))
end
end

Expand Down
Loading

0 comments on commit 10f2ae3

Please sign in to comment.