Skip to content

Commit

Permalink
Use simpler recurse pattern in mult_field_solve
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Apr 5, 2024
1 parent 04c7ac6 commit 31e69ef
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 38 deletions.
68 changes: 31 additions & 37 deletions src/MatrixFields/multiple_field_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ function multiple_field_solve!(
end
end

import TuplesOfNTuples as ToNTs

function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1)
Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
names = matrix_row_keys(keys(A))
Expand All @@ -39,20 +37,12 @@ function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1)
b_tup = map(name -> ssb[name], names)
x1 = first(x_tup)

# These are non-uniform tuples, so let's use TuplesOfNTuples.jl
# to unroll these.
cache_tonts = ToNTs.TupleOfNTuples(cache_tup)
x_tonts = ToNTs.TupleOfNTuples(x_tup)
A_tonts = ToNTs.TupleOfNTuples(A_tup)
b_tonts = ToNTs.TupleOfNTuples(b_tup)
tups = (cache_tup, x_tup, A_tup, b_tup)

device = ClimaComms.device(x[first(names)])
CUDA.@cuda threads = nthreads blocks = nblocks multiple_field_solve_kernel!(
device,
cache_tonts,
x_tonts,
A_tonts,
b_tonts,
tups,
x1,
Val(Nnames),
)
Expand All @@ -71,12 +61,33 @@ end
column_A(A::UniformScaling, i, j, h) = A
column_A(A, i, j, h) = Spaces.column(A, i, j, h)

@inline function _recurse(js::Tuple, tups::Tuple, transform, device, i::Int)
if first(js) == i
tup_args = map(x -> transform(first(x)), tups)
_single_field_solve!(tup_args..., device)
end
_recurse(Base.tail(js), map(x -> Base.tail(x), tups), transform, device, i)
end

@inline _recurse(js::Tuple{}, tups::Tuple, transform, device, i::Int) = nothing

@inline function _recurse(
js::Tuple{Int},
tups::Tuple,
transform,
device,
i::Int,
)
if first(js) == i
tup_args = map(x -> transform(first(x)), tups)
_single_field_solve!(tup_args..., device)
end
return nothing
end

function multiple_field_solve_kernel!(
device::ClimaComms.CUDADevice,
caches::ToNTs.TupleOfNTuples,
xs::ToNTs.TupleOfNTuples,
As::ToNTs.TupleOfNTuples,
bs::ToNTs.TupleOfNTuples,
tups,
x1,
::Val{Nnames},
) where {Nnames}
Expand All @@ -93,27 +104,10 @@ function multiple_field_solve_kernel!(
CUDA.gridDim(),
)
if 1 i <= Ni && 1 j Nj && 1 h Nh && 1 iname Nnames
c1 = ToNTs.inner_dispatch(
_single_field_solve!,
caches,
iname,
ξ -> Spaces.column(ξ, i, j, h),
)
c2 = ToNTs.outer_dispatch(
c1,
xs,
iname,
ξ -> Spaces.column(ξ, i, j, h),
)
c3 = ToNTs.outer_dispatch(c2, As, iname, ξ -> column_A(ξ, i, j, h))
closure = ToNTs.outer_dispatch(
c3,
bs,
iname,
ξ -> Spaces.column(ξ, i, j, h),
)
closure(device)
# closure(device) calls

nt = ntuple-> ξ, Val(Nnames))
_recurse(nt, tups, ξ -> column_A(ξ, i, j, h), device, iname)
# _recurse effectively calls
# _single_field_solve!(
# Spaces.column(caches[iname], i, j, h),
# Spaces.column(xs[iname], i, j, h),
Expand Down
2 changes: 1 addition & 1 deletion test/MatrixFields/field_matrix_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
AnyFrameModule(Base.CoreLogging),
)
@test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
# @test_opt ignored_modules = ignored field_matrix_solve!(args...)
@test_opt ignored_modules = ignored field_matrix_solve!(args...)
@test_opt ignored_modules = ignored field_matrix_mul!(b, A, x)

using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0
Expand Down

0 comments on commit 31e69ef

Please sign in to comment.