Skip to content

Commit

Permalink
extend adapt_structure for bcs and ops
Browse files Browse the repository at this point in the history
  • Loading branch information
juliasloan25 committed May 1, 2024
1 parent 6eb0f0d commit e557b21
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 20 deletions.
5 changes: 3 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -629,11 +629,11 @@ steps:

- label: "Unit: matrix multiplication recursion example (CPU)"
key: cpu_matrix_multiplication_recursion
command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/matrix_multiplication_recursion.jl"
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/matrix_multiplication_recursion.jl"

- label: "Unit: matrix multiplication recursion example (GPU)"
key: gpu_matrix_multiplication_recursion
command: "julia --color=yes --project=test test/MatrixFields/matrix_multiplication_recursion.jl"
command: "julia --color=yes --project=.buildkite test/MatrixFields/matrix_multiplication_recursion.jl"
soft_fail: true
agents:
slurm_gpus: 1
Expand Down Expand Up @@ -862,6 +862,7 @@ steps:
- label: "Unit: matrix field broadcasting (GPU)"
key: unit_matrix_field_broadcasting_gpu_non_scalar_3
command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl"
soft_fail: true
agents:
slurm_gpus: 1
slurm_mem: 10GB
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ ClimaCore.jl Release Notes
main
-------

- ![][badge-🐛bugfix] Extend adapt_structure for all operator and boundary
condition types. Also use `unrolled_map` in `multiply_matrix_at_index` to
avoid the recursive inference limit when compiling nested matrix operations.
PR [#1684](https://github.com/CliMA/ClimaCore.jl/pull/1684)
- ![][badge-🤖precisionΔ] ![][badge-💥breaking] `Remapper`s can now process
multiple `Field`s at the same time if created with some `buffer_lenght > 1`.
PR ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669))
Expand Down
6 changes: 3 additions & 3 deletions src/MatrixFields/matrix_multiplication.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ for `MultiplyColumnwiseBandMatrixField()`.
What follows is a derivation of the algorithm used by this operator with
single-column `Field`s. For `Field`s on multiple columns, the same computation
is done for each column.
In this derivation, we will use ``M_1`` and ``M_2`` to denote two
`ColumnwiseBandMatrixField`s, and we will use ``V`` to denote a regular
(vector-like) `Field`. For both ``M_1`` and ``M_2``, we will use the array-like
Expand Down Expand Up @@ -169,7 +169,7 @@ The values of ``i`` in this range are considered to be in the "interior" of the
operator, while those not in this range (for which we cannot make these
simplifications) are considered to be on the "boundary".
## 2.2 ``ld_{prod}`` and ``ud_{prod}``
## 2.2 ``ld_{prod}`` and ``ud_{prod}``
We only need to compute ``(M_1 ⋅ M_2)[i][d_{prod}]`` for values of ``d_{prod}``
that correspond to a nonempty sum in the interior, i.e, those for which
Expand Down Expand Up @@ -375,7 +375,7 @@ function multiply_matrix_at_index(loc, space, idx, hidx, matrix1, arg, bc)
# of as a map from boundary_modified_ld1 to boundary_modified_ud1. For
# simplicity, use zero padding for rows that are outside the matrix.
# Wrap the rows in a BandMatrixRow so that they can be easily indexed.
matrix2_rows = map((ld1:ud1...,)) do d
matrix2_rows = unrolled_map((ld1:ud1...,)) do d
# TODO: Use @propagate_inbounds_meta instead of @inline_meta.
Base.@_inline_meta
if isnothing(bc) ||
Expand Down
3 changes: 3 additions & 0 deletions src/MatrixFields/operator_matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ struct LazyOneArgFDOperatorMatrix{O <: OneArgFDOperator} <: AbstractLazyOperator
op::O
end

Adapt.adapt_structure(to, op::FDOperatorMatrix) =
FDOperatorMatrix(Adapt.adapt_structure(to, op.op))

# Since the operator matrix of a one-argument operator does not have any
# arguments, we need to use a lazy operator to add an argument.
replace_lazy_operator(space, lazy_op::LazyOneArgFDOperatorMatrix) =
Expand Down
42 changes: 27 additions & 15 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import ..Utilities: PlusHalf, half
import ..Utilities: PlusHalf, half, UnrolledFunctions

const AllFiniteDifferenceSpace =
Union{Spaces.FiniteDifferenceSpace, Spaces.ExtrudedFiniteDifferenceSpace}
Expand Down Expand Up @@ -240,12 +240,6 @@ Adapt.adapt_structure(to, sbc::StencilBroadcasted{Style}) where {Style} =
Adapt.adapt(to, sbc.axes),
)

function Adapt.adapt_structure(to, op::FiniteDifferenceOperator)
op
end



function Base.Broadcast.instantiate(sbc::StencilBroadcasted)
op = sbc.op
# recursively instantiate the arguments to allocate intermediate work arrays
Expand Down Expand Up @@ -2610,18 +2604,36 @@ Base.@propagate_inbounds function stencil_right_boundary(
stencil_interior(op, loc, space, idx - 1, hidx, arg)
end

function Adapt.adapt_structure(to, op::DivergenceF2C)
DivergenceF2C(map(bc -> Adapt.adapt_structure(to, bc), op.bcs))
end
"""
unionall_type(::Type{T})
function Adapt.adapt_structure(to, bc::SetValue)
SetValue(Adapt.adapt_structure(to, bc.val))
end
Extract the type of the input, and strip it of any type parameters.
"""
unionall_type(::Type{T}) where {T} = T.name.wrapper

function Adapt.adapt_structure(to, bc::SetDivergence)
SetDivergence(Adapt.adapt_structure(to, bc.val))
# Extend `adapt_structure` for all boundary conditions containing a `val` field.
function Adapt.adapt_structure(to, bc::AbstractBoundaryCondition)
if hasfield(typeof(bc), :val)
return unionall_type(typeof(bc))(Adapt.adapt_structure(to, bc.val))
else
return bc
end
end

# Extend `adapt_structure` for all operator types with boundary conditions.
function Adapt.adapt_structure(to, op::FiniteDifferenceOperator)
if hasfield(typeof(op), :bcs)
bcs_adapted = NamedTuple{keys(op.bcs)}(
UnrolledFunctions.unrolled_map(
bc -> Adapt.adapt_structure(to, bc),
values(op.bcs),
),
)
return unionall_type(typeof(op))(bcs_adapted)
else
return op
end
end

"""
D = DivergenceC2F(;boundaryname=boundarycondition...)
Expand Down

0 comments on commit e557b21

Please sign in to comment.