From 31e69efe212afcff92356278d87c671db29a790a Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 5 Apr 2024 09:06:40 -0400 Subject: [PATCH] Use simpler recurse pattern in mult_field_solve --- src/MatrixFields/multiple_field_solver.jl | 68 +++++++++++------------ test/MatrixFields/field_matrix_solvers.jl | 2 +- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/src/MatrixFields/multiple_field_solver.jl b/src/MatrixFields/multiple_field_solver.jl index 0a4da8bd8e..e99a01cfc1 100644 --- a/src/MatrixFields/multiple_field_solver.jl +++ b/src/MatrixFields/multiple_field_solver.jl @@ -22,8 +22,6 @@ function multiple_field_solve!( end end -import TuplesOfNTuples as ToNTs - function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1) Ni, Nj, _, _, Nh = size(Fields.field_values(x1)) names = matrix_row_keys(keys(A)) @@ -39,20 +37,12 @@ function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1) b_tup = map(name -> ssb[name], names) x1 = first(x_tup) - # These are non-uniform tuples, so let's use TuplesOfNTuples.jl - # to unroll these. - cache_tonts = ToNTs.TupleOfNTuples(cache_tup) - x_tonts = ToNTs.TupleOfNTuples(x_tup) - A_tonts = ToNTs.TupleOfNTuples(A_tup) - b_tonts = ToNTs.TupleOfNTuples(b_tup) + tups = (cache_tup, x_tup, A_tup, b_tup) device = ClimaComms.device(x[first(names)]) CUDA.@cuda threads = nthreads blocks = nblocks multiple_field_solve_kernel!( device, - cache_tonts, - x_tonts, - A_tonts, - b_tonts, + tups, x1, Val(Nnames), ) @@ -71,12 +61,33 @@ end column_A(A::UniformScaling, i, j, h) = A column_A(A, i, j, h) = Spaces.column(A, i, j, h) +@inline function _recurse(js::Tuple, tups::Tuple, transform, device, i::Int) + if first(js) == i + tup_args = map(x -> transform(first(x)), tups) + _single_field_solve!(tup_args..., device) + end + _recurse(Base.tail(js), map(x -> Base.tail(x), tups), transform, device, i) +end + +@inline _recurse(js::Tuple{}, tups::Tuple, transform, device, i::Int) = nothing + +@inline function _recurse( + js::Tuple{Int}, + tups::Tuple, + transform, + device, + i::Int, +) + if first(js) == i + tup_args = map(x -> transform(first(x)), tups) + _single_field_solve!(tup_args..., device) + end + return nothing +end + function multiple_field_solve_kernel!( device::ClimaComms.CUDADevice, - caches::ToNTs.TupleOfNTuples, - xs::ToNTs.TupleOfNTuples, - As::ToNTs.TupleOfNTuples, - bs::ToNTs.TupleOfNTuples, + tups, x1, ::Val{Nnames}, ) where {Nnames} @@ -93,27 +104,10 @@ function multiple_field_solve_kernel!( CUDA.gridDim(), ) if 1 ≤ i <= Ni && 1 ≤ j ≤ Nj && 1 ≤ h ≤ Nh && 1 ≤ iname ≤ Nnames - c1 = ToNTs.inner_dispatch( - _single_field_solve!, - caches, - iname, - ξ -> Spaces.column(ξ, i, j, h), - ) - c2 = ToNTs.outer_dispatch( - c1, - xs, - iname, - ξ -> Spaces.column(ξ, i, j, h), - ) - c3 = ToNTs.outer_dispatch(c2, As, iname, ξ -> column_A(ξ, i, j, h)) - closure = ToNTs.outer_dispatch( - c3, - bs, - iname, - ξ -> Spaces.column(ξ, i, j, h), - ) - closure(device) - # closure(device) calls + + nt = ntuple(ξ -> ξ, Val(Nnames)) + _recurse(nt, tups, ξ -> column_A(ξ, i, j, h), device, iname) + # _recurse effectively calls # _single_field_solve!( # Spaces.column(caches[iname], i, j, h), # Spaces.column(xs[iname], i, j, h), diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl index c25e1fed05..e5740e0d6a 100644 --- a/test/MatrixFields/field_matrix_solvers.jl +++ b/test/MatrixFields/field_matrix_solvers.jl @@ -55,7 +55,7 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false) AnyFrameModule(Base.CoreLogging), ) @test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b) - # @test_opt ignored_modules = ignored field_matrix_solve!(args...) + @test_opt ignored_modules = ignored field_matrix_solve!(args...) @test_opt ignored_modules = ignored field_matrix_mul!(b, A, x) using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0