From 65ff1c4567d3b485684bc84b3826fe1668298122 Mon Sep 17 00:00:00 2001 From: Julia Sloan Date: Thu, 18 Apr 2024 14:30:33 -0700 Subject: [PATCH] fixes for land imp solver --- src/MatrixFields/matrix_multiplication.jl | 6 +-- src/Operators/finitedifference.jl | 45 +++++++++++++++-------- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/src/MatrixFields/matrix_multiplication.jl b/src/MatrixFields/matrix_multiplication.jl index 5d46557dbe..80863a7555 100644 --- a/src/MatrixFields/matrix_multiplication.jl +++ b/src/MatrixFields/matrix_multiplication.jl @@ -8,7 +8,7 @@ for `MultiplyColumnwiseBandMatrixField()`. What follows is a derivation of the algorithm used by this operator with single-column `Field`s. For `Field`s on multiple columns, the same computation is done for each column. - + In this derivation, we will use ``M_1`` and ``M_2`` to denote two `ColumnwiseBandMatrixField`s, and we will use ``V`` to denote a regular (vector-like) `Field`. For both ``M_1`` and ``M_2``, we will use the array-like @@ -169,7 +169,7 @@ The values of ``i`` in this range are considered to be in the "interior" of the operator, while those not in this range (for which we cannot make these simplifications) are considered to be on the "boundary". -## 2.2 ``ld_{prod}`` and ``ud_{prod}`` +## 2.2 ``ld_{prod}`` and ``ud_{prod}`` We only need to compute ``(M_1 ⋅ M_2)[i][d_{prod}]`` for values of ``d_{prod}`` that correspond to a nonempty sum in the interior, i.e, those for which @@ -375,7 +375,7 @@ function multiply_matrix_at_index(loc, space, idx, hidx, matrix1, arg, bc) # of as a map from boundary_modified_ld1 to boundary_modified_ud1. For # simplicity, use zero padding for rows that are outside the matrix. # Wrap the rows in a BandMatrixRow so that they can be easily indexed. - matrix2_rows = map((ld1:ud1...,)) do d + matrix2_rows = unrolled_map((ld1:ud1...,)) do d # TODO: Use @propagate_inbounds_meta instead of @inline_meta. Base.@_inline_meta if isnothing(bc) || diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 2a2043dd27..52ac0ee327 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -1,4 +1,4 @@ -import ..Utilities: PlusHalf, half +import ..Utilities: PlusHalf, half, UnrolledFunctions const AllFiniteDifferenceSpace = Union{Spaces.FiniteDifferenceSpace, Spaces.ExtrudedFiniteDifferenceSpace} @@ -242,12 +242,6 @@ Adapt.adapt_structure(to, sbc::StencilBroadcasted{Style}) where {Style} = Adapt.adapt(to, sbc.axes), ) -function Adapt.adapt_structure(to, op::FiniteDifferenceOperator) - op -end - - - function Base.Broadcast.instantiate(sbc::StencilBroadcasted) op = sbc.op # recursively instantiate the arguments to allocate intermediate work arrays @@ -2612,18 +2606,39 @@ Base.@propagate_inbounds function stencil_right_boundary( stencil_interior(op, loc, space, idx - 1, hidx, arg) end -function Adapt.adapt_structure(to, op::DivergenceF2C) - DivergenceF2C(map(bc -> Adapt.adapt_structure(to, bc), op.bcs)) -end +""" + Adapt.adapt_structure(to, bc::AbstractBoundaryCondition) -function Adapt.adapt_structure(to, bc::SetValue) - SetValue(Adapt.adapt_structure(to, bc.val)) +Extend `adapt_structure` for all boundary conditions containing a `val` field. +By default, `adapt_structure` will do nothing for BCs without a `val` field. +""" +function Adapt.adapt_structure(to, bc::AbstractBoundaryCondition) + if hasfield(typeof(bc), :val) + return typeof(bc).name.wrapper(Adapt.adapt_structure(to, bc.val)) + else + return bc + end end -function Adapt.adapt_structure(to, bc::SetDivergence) - SetDivergence(Adapt.adapt_structure(to, bc.val)) -end +""" + Adapt.adapt_structure(to, op::FiniteDifferenceOperator) +Extend `adapt_structure` for all operator types. Recursively adapt the boundary +conditions of the operator. +""" +function Adapt.adapt_structure(to, op::FiniteDifferenceOperator) + if hasfield(typeof(op), :bcs) + bcs_adapted = NamedTuple{keys(op.bcs)}( + UnrolledFunctions.unrolled_map( + bc -> Adapt.adapt_structure(to, bc), + values(op.bcs), + ), + ) + return typeof(op).name.wrapper(bcs_adapted) + else + return op + end +end """ D = DivergenceC2F(;boundaryname=boundarycondition...)