diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 46ccfcf631..66f7bf281d 100755 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -629,11 +629,11 @@ steps: - label: "Unit: matrix multiplication recursion example (CPU)" key: cpu_matrix_multiplication_recursion - command: "julia --color=yes --check-bounds=yes --project=test test/MatrixFields/matrix_multiplication_recursion.jl" + command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/matrix_multiplication_recursion.jl" - label: "Unit: matrix multiplication recursion example (GPU)" key: gpu_matrix_multiplication_recursion - command: "julia --color=yes --project=test test/MatrixFields/matrix_multiplication_recursion.jl" + command: "julia --color=yes --project=.buildkite test/MatrixFields/matrix_multiplication_recursion.jl" soft_fail: true agents: slurm_gpus: 1 @@ -862,6 +862,7 @@ steps: - label: "Unit: matrix field broadcasting (GPU)" key: unit_matrix_field_broadcasting_gpu_non_scalar_3 command: "julia --color=yes --check-bounds=yes --project=.buildkite test/MatrixFields/matrix_fields_broadcasting/test_non_scalar_3.jl" + soft_fail: true agents: slurm_gpus: 1 slurm_mem: 10GB diff --git a/NEWS.md b/NEWS.md index 98007489e1..17f080f41b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,10 @@ ClimaCore.jl Release Notes main ------- +- ![][badge-🐛bugfix] Extend adapt_structure for all operator and boundary + condition types. Also use `unrolled_map` in `multiply_matrix_at_index` to + avoid the recursive inference limit when compiling nested matrix operations. + PR [#1684](https://github.com/CliMA/ClimaCore.jl/pull/1684) - ![][badge-🤖precisionΔ] ![][badge-💥breaking] `Remapper`s can now process multiple `Field`s at the same time if created with some `buffer_lenght > 1`. PR ([#1669](https://github.com/CliMA/ClimaCore.jl/pull/1669)) diff --git a/src/MatrixFields/matrix_multiplication.jl b/src/MatrixFields/matrix_multiplication.jl index 5d46557dbe..80863a7555 100644 --- a/src/MatrixFields/matrix_multiplication.jl +++ b/src/MatrixFields/matrix_multiplication.jl @@ -8,7 +8,7 @@ for `MultiplyColumnwiseBandMatrixField()`. What follows is a derivation of the algorithm used by this operator with single-column `Field`s. For `Field`s on multiple columns, the same computation is done for each column. - + In this derivation, we will use ``M_1`` and ``M_2`` to denote two `ColumnwiseBandMatrixField`s, and we will use ``V`` to denote a regular (vector-like) `Field`. For both ``M_1`` and ``M_2``, we will use the array-like @@ -169,7 +169,7 @@ The values of ``i`` in this range are considered to be in the "interior" of the operator, while those not in this range (for which we cannot make these simplifications) are considered to be on the "boundary". -## 2.2 ``ld_{prod}`` and ``ud_{prod}`` +## 2.2 ``ld_{prod}`` and ``ud_{prod}`` We only need to compute ``(M_1 ⋅ M_2)[i][d_{prod}]`` for values of ``d_{prod}`` that correspond to a nonempty sum in the interior, i.e, those for which @@ -375,7 +375,7 @@ function multiply_matrix_at_index(loc, space, idx, hidx, matrix1, arg, bc) # of as a map from boundary_modified_ld1 to boundary_modified_ud1. For # simplicity, use zero padding for rows that are outside the matrix. # Wrap the rows in a BandMatrixRow so that they can be easily indexed. - matrix2_rows = map((ld1:ud1...,)) do d + matrix2_rows = unrolled_map((ld1:ud1...,)) do d # TODO: Use @propagate_inbounds_meta instead of @inline_meta. Base.@_inline_meta if isnothing(bc) || diff --git a/src/MatrixFields/operator_matrices.jl b/src/MatrixFields/operator_matrices.jl index 33303fbc6b..7eaa097968 100644 --- a/src/MatrixFields/operator_matrices.jl +++ b/src/MatrixFields/operator_matrices.jl @@ -94,6 +94,9 @@ struct LazyOneArgFDOperatorMatrix{O <: OneArgFDOperator} <: AbstractLazyOperator op::O end +Adapt.adapt_structure(to, op::FDOperatorMatrix) = + FDOperatorMatrix(Adapt.adapt_structure(to, op.op)) + # Since the operator matrix of a one-argument operator does not have any # arguments, we need to use a lazy operator to add an argument. replace_lazy_operator(space, lazy_op::LazyOneArgFDOperatorMatrix) = diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 9eadb98795..3b31dde3b2 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -1,4 +1,4 @@ -import ..Utilities: PlusHalf, half +import ..Utilities: PlusHalf, half, UnrolledFunctions const AllFiniteDifferenceSpace = Union{Spaces.FiniteDifferenceSpace, Spaces.ExtrudedFiniteDifferenceSpace} @@ -240,12 +240,6 @@ Adapt.adapt_structure(to, sbc::StencilBroadcasted{Style}) where {Style} = Adapt.adapt(to, sbc.axes), ) -function Adapt.adapt_structure(to, op::FiniteDifferenceOperator) - op -end - - - function Base.Broadcast.instantiate(sbc::StencilBroadcasted) op = sbc.op # recursively instantiate the arguments to allocate intermediate work arrays @@ -2610,18 +2604,36 @@ Base.@propagate_inbounds function stencil_right_boundary( stencil_interior(op, loc, space, idx - 1, hidx, arg) end -function Adapt.adapt_structure(to, op::DivergenceF2C) - DivergenceF2C(map(bc -> Adapt.adapt_structure(to, bc), op.bcs)) -end +""" + unionall_type(::Type{T}) -function Adapt.adapt_structure(to, bc::SetValue) - SetValue(Adapt.adapt_structure(to, bc.val)) -end +Extract the type of the input, and strip it of any type parameters. +""" +unionall_type(::Type{T}) where {T} = T.name.wrapper -function Adapt.adapt_structure(to, bc::SetDivergence) - SetDivergence(Adapt.adapt_structure(to, bc.val)) +# Extend `adapt_structure` for all boundary conditions containing a `val` field. +function Adapt.adapt_structure(to, bc::AbstractBoundaryCondition) + if hasfield(typeof(bc), :val) + return unionall_type(typeof(bc))(Adapt.adapt_structure(to, bc.val)) + else + return bc + end end +# Extend `adapt_structure` for all operator types with boundary conditions. +function Adapt.adapt_structure(to, op::FiniteDifferenceOperator) + if hasfield(typeof(op), :bcs) + bcs_adapted = NamedTuple{keys(op.bcs)}( + UnrolledFunctions.unrolled_map( + bc -> Adapt.adapt_structure(to, bc), + values(op.bcs), + ), + ) + return unionall_type(typeof(op))(bcs_adapted) + else + return op + end +end """ D = DivergenceC2F(;boundaryname=boundarycondition...)