From 04c7ac6d197c618161a459f5af8a51d91c29edc7 Mon Sep 17 00:00:00 2001
From: Charles Kawczynski
Date: Thu, 21 Mar 2024 15:40:52 -0400
Subject: [PATCH 1/2] Parallelize single_field_solve -> multiple_field_solve

---
 src/Fields/Fields.jl                         |   5 +
 src/MatrixFields/MatrixFields.jl             |   3 +
 src/MatrixFields/field_matrix_solver.jl      |  10 +-
 src/MatrixFields/field_name_dict.jl          |  16 +++
 src/MatrixFields/multiple_field_solver.jl    | 127 +++++++++++++++++++
 src/MatrixFields/single_field_solver.jl      | 108 ++++++++++++++--
 test/MatrixFields/field_matrix_solvers.jl    |   2 +-
 test/MatrixFields/matrix_field_test_utils.jl |   8 +-
 8 files changed, 262 insertions(+), 17 deletions(-)
 create mode 100644 src/MatrixFields/multiple_field_solver.jl

diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl
index 9206145f6a..6ced97883a 100644
--- a/src/Fields/Fields.jl
+++ b/src/Fields/Fields.jl
@@ -55,6 +55,11 @@ Adapt.adapt_structure(to, field::Field) = Field(
 const PointField{V, S} =
     Field{V, S} where {V <: AbstractData, S <: Spaces.PointSpace}
 
+# TODO: Do we need to make this distinction? What about inside CUDA kernels,
+# when we replace the space with a PlaceholderSpace?
+const PointDataField{V, S} =
+    Field{V, S} where {V <: DataLayouts.DataF, S <: Spaces.AbstractSpace}
+
 # Spectral Element Field
 const SpectralElementField{V, S} = Field{
     V,
diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl
index 304bbc3899..653920546b 100644
--- a/src/MatrixFields/MatrixFields.jl
+++ b/src/MatrixFields/MatrixFields.jl
@@ -50,6 +50,7 @@ import BandedMatrices: BandedMatrix, band, _BandedMatrix
 import RecursiveArrayTools: recursive_bottom_eltype
 import KrylovKit
 import ClimaComms
+import Adapt
 
 import ..Utilities: PlusHalf, half
 import ..RecursiveApply:
@@ -86,6 +87,7 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
     S <: Union{
         Spaces.FiniteDifferenceSpace,
         Spaces.ExtrudedFiniteDifferenceSpace,
+        Operators.PlaceholderSpace, # so that this can exist inside CUDA kernels
     },
 }
 
@@ -99,6 +101,7 @@ include("field_name.jl")
 include("field_name_set.jl")
 include("field_name_dict.jl")
 include("single_field_solver.jl")
+include("multiple_field_solver.jl")
 include("field_matrix_solver.jl")
 include("field_matrix_iterative_solver.jl")
 
diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl
index 86b585820c..39132352c5 100644
--- a/src/MatrixFields/field_matrix_solver.jl
+++ b/src/MatrixFields/field_matrix_solver.jl
@@ -248,9 +248,13 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
 end
 
 run_field_matrix_solver!(::BlockDiagonalSolve, cache, x, A, b) =
-    foreach(matrix_row_keys(keys(A))) do name
-        single_field_solve!(cache[name], x[name], A[name, name], b[name])
-    end
+    multiple_field_solve!(cache, x, A, b)
+
+# This may be helpful for debugging:
+# run_field_matrix_solver!(::BlockDiagonalSolve, cache, x, A, b) =
+#     foreach(matrix_row_keys(keys(A))) do name
+#         single_field_solve!(cache[name], x[name], A[name, name], b[name])
+#     end
 
 """
     BlockLowerTriangularSolve(names₁...; [alg₁], [alg₂])
diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl
index 516a7d9f2c..8e71ba7604 100644
--- a/src/MatrixFields/field_name_dict.jl
+++ b/src/MatrixFields/field_name_dict.jl
@@ -88,6 +88,22 @@ function Base.show(io::IO, dict::FieldNameDict)
     end
 end
 
+function Operators.strip_space(dict::FieldNameDict)
+    vals = unrolled_map(values(dict)) do val
+        if val isa Fields.Field
+            Fields.Field(Fields.field_values(val), Operators.PlaceholderSpace())
+        else
+            val
+        end
+    end
+    FieldNameDict(keys(dict), vals)
+end
+
+function Adapt.adapt_structure(to, dict::FieldNameDict)
+    vals = unrolled_map(v -> Adapt.adapt_structure(to, v), values(dict))
+    FieldNameDict(keys(dict), vals)
+end
+
 Base.keys(dict::FieldNameDict) = dict.keys
 
 Base.values(dict::FieldNameDict) = dict.entries
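A note on the two additions above: `Adapt.adapt_structure` is the Adapt.jl hook that CUDA.jl invokes (via `cudaconvert`) when an object is passed to a kernel, so each entry of a `FieldNameDict` gets converted to its device-side representation (e.g. `Array` -> `CuDeviceArray`). A minimal sketch of the same pattern for a hypothetical container type (the names below are illustrative, not part of this patch):

    import Adapt

    # A hypothetical container holding a tuple of arrays.
    struct SimpleDict{K, V}
        keys::K
        entries::V
    end

    # Adapt each entry so the container can cross the host-device boundary;
    # CUDA.jl calls this recursively when a SimpleDict is a kernel argument.
    Adapt.adapt_structure(to, d::SimpleDict) =
        SimpleDict(d.keys, map(v -> Adapt.adapt(to, v), d.entries))

`Operators.strip_space` is the complementary step: before the kernel launch, each field is re-wrapped with a `PlaceholderSpace`, since the full space objects are not needed on the device; this is also why `ColumnwiseBandMatrixField` above now admits `Operators.PlaceholderSpace`.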
diff --git a/src/MatrixFields/multiple_field_solver.jl b/src/MatrixFields/multiple_field_solver.jl
new file mode 100644
index 0000000000..0a4da8bd8e
--- /dev/null
+++ b/src/MatrixFields/multiple_field_solver.jl
@@ -0,0 +1,127 @@
+# TODO: Can different A's have different matrix styles? If so, how can we
+# fuse/parallelize them?
+
+# First, dispatch on the device associated with the first field in x:
+function multiple_field_solve!(cache, x, A, b)
+    name1 = first(matrix_row_keys(keys(A)))
+    x1 = x[name1]
+    multiple_field_solve!(ClimaComms.device(axes(x1)), cache, x, A, b, x1)
+end
+
+# TODO: fuse/parallelize
+function multiple_field_solve!(
+    ::ClimaComms.AbstractCPUDevice,
+    cache,
+    x,
+    A,
+    b,
+    x1,
+)
+    foreach(matrix_row_keys(keys(A))) do name
+        single_field_solve!(cache[name], x[name], A[name, name], b[name])
+    end
+end
+
+import TuplesOfNTuples as ToNTs
+
+function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1)
+    Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
+    names = matrix_row_keys(keys(A))
+    Nnames = length(names)
+    nthreads, nblocks = Topologies._configure_threadblock(Ni * Nj * Nh * Nnames)
+    sscache = Operators.strip_space(cache)
+    ssx = Operators.strip_space(x)
+    ssA = Operators.strip_space(A)
+    ssb = Operators.strip_space(b)
+    cache_tup = map(name -> sscache[name], names)
+    x_tup = map(name -> ssx[name], names)
+    A_tup = map(name -> ssA[name, name], names)
+    b_tup = map(name -> ssb[name], names)
+    x1 = first(x_tup)
+
+    # These are non-uniform tuples, so let's use TuplesOfNTuples.jl
+    # to unroll these.
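+    # (Each field in x can have a different element type, and each diagonal
+    # block of A can be a different matrix type, so these tuples are
+    # heterogeneous: indexing them with the runtime index `iname` would be
+    # type-unstable, which a GPU kernel cannot support; the dispatch calls
+    # below unroll the access at compile time instead.)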
+    cache_tonts = ToNTs.TupleOfNTuples(cache_tup)
+    x_tonts = ToNTs.TupleOfNTuples(x_tup)
+    A_tonts = ToNTs.TupleOfNTuples(A_tup)
+    b_tonts = ToNTs.TupleOfNTuples(b_tup)
+
+    device = ClimaComms.device(x[first(names)])
+    CUDA.@cuda threads = nthreads blocks = nblocks multiple_field_solve_kernel!(
+        device,
+        cache_tonts,
+        x_tonts,
+        A_tonts,
+        b_tonts,
+        x1,
+        Val(Nnames),
+    )
+end
+
+function get_ijhn(Ni, Nj, Nh, Nnames, blockIdx, threadIdx, blockDim, gridDim)
+    tidx = (blockIdx.x - 1) * blockDim.x + threadIdx.x
+    (i, j, h, n) = if 1 ≤ tidx ≤ prod((Ni, Nj, Nh, Nnames))
+        CartesianIndices((1:Ni, 1:Nj, 1:Nh, 1:Nnames))[tidx].I
+    else
+        (-1, -1, -1, -1)
+    end
+    return (i, j, h, n)
+end
+
+column_A(A::UniformScaling, i, j, h) = A
+column_A(A, i, j, h) = Spaces.column(A, i, j, h)
+
+function multiple_field_solve_kernel!(
+    device::ClimaComms.CUDADevice,
+    caches::ToNTs.TupleOfNTuples,
+    xs::ToNTs.TupleOfNTuples,
+    As::ToNTs.TupleOfNTuples,
+    bs::ToNTs.TupleOfNTuples,
+    x1,
+    ::Val{Nnames},
+) where {Nnames}
+    @inbounds begin
+        Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
+        (i, j, h, iname) = get_ijhn(
+            Ni,
+            Nj,
+            Nh,
+            Nnames,
+            CUDA.blockIdx(),
+            CUDA.threadIdx(),
+            CUDA.blockDim(),
+            CUDA.gridDim(),
+        )
+        if 1 ≤ i <= Ni && 1 ≤ j ≤ Nj && 1 ≤ h ≤ Nh && 1 ≤ iname ≤ Nnames
+            c1 = ToNTs.inner_dispatch(
+                _single_field_solve!,
+                caches,
+                iname,
+                ξ -> Spaces.column(ξ, i, j, h),
+            )
+            c2 = ToNTs.outer_dispatch(
+                c1,
+                xs,
+                iname,
+                ξ -> Spaces.column(ξ, i, j, h),
+            )
+            c3 = ToNTs.outer_dispatch(c2, As, iname, ξ -> column_A(ξ, i, j, h))
+            closure = ToNTs.outer_dispatch(
+                c3,
+                bs,
+                iname,
+                ξ -> Spaces.column(ξ, i, j, h),
+            )
+            closure(device)
+            # closure(device) calls
+            # _single_field_solve!(
+            #     Spaces.column(caches[iname], i, j, h),
+            #     Spaces.column(xs[iname], i, j, h),
+            #     column_A(As[iname], i, j, h),
+            #     Spaces.column(bs[iname], i, j, h),
+            #     device,
+            # )
+        end
+    end
+    return nothing
+end
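The kernel above launches one thread per `(i, j, h, iname)` combination, and `get_ijhn` decodes the flat CUDA thread index into those four indices with `CartesianIndices`. The same decoding can be exercised on the CPU; a self-contained sketch (the function name is illustrative, not part of this patch):

    # Decode a flat 1-based index into (i, j, h, n), as `get_ijhn` does on
    # the GPU; out-of-range indices map to the sentinel (-1, -1, -1, -1).
    function decode_flat_index(tidx, Ni, Nj, Nh, Nnames)
        if 1 ≤ tidx ≤ Ni * Nj * Nh * Nnames
            return CartesianIndices((1:Ni, 1:Nj, 1:Nh, 1:Nnames))[tidx].I
        else
            return (-1, -1, -1, -1)
        end
    end

    # The first index varies fastest, so with Ni = Nj = 2, Nh = 3, Nnames = 2:
    decode_flat_index(5, 2, 2, 3, 2) # == (1, 1, 2, 1)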
diff --git a/src/MatrixFields/single_field_solver.jl b/src/MatrixFields/single_field_solver.jl
index 330a742c07..b3af8ddf39 100644
--- a/src/MatrixFields/single_field_solver.jl
+++ b/src/MatrixFields/single_field_solver.jl
@@ -47,11 +47,17 @@ single_field_solve!(cache, x, A::ColumnwiseBandMatrixField, b) =
     single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)
 
 single_field_solve!(::ClimaComms.AbstractCPUDevice, cache, x, A, b) =
-    _single_field_solve!(cache, x, A, b)
+    _single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)
+
+# single_field_solve!(::ClimaComms.CUDADevice, ...) is no longer exercised,
+# but due to its simplicity it may be helpful for debugging, so we leave
+# it here for now.
 function single_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b)
     Ni, Nj, _, _, Nh = size(Fields.field_values(A))
     nthreads, nblocks = Topologies._configure_threadblock(Ni * Nj * Nh)
+    device = ClimaComms.device(A)
     CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks single_field_solve_kernel!(
+        device,
         cache,
         x,
         A,
@@ -59,12 +65,13 @@ function single_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b)
     )
 end
 
-function single_field_solve_kernel!(cache, x, A, b)
+function single_field_solve_kernel!(device, cache, x, A, b)
     idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
     Ni, Nj, _, _, Nh = size(Fields.field_values(A))
     if idx <= Ni * Nj * Nh
         i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
         _single_field_solve!(
+            device,
             Spaces.column(cache, i, j, h),
             Spaces.column(x, i, j, h),
             Spaces.column(A, i, j, h),
@@ -80,22 +87,103 @@ single_field_solve_kernel!(
     b::Fields.ColumnField,
 ) = _single_field_solve!(cache, x, A, b)
 
-_single_field_solve!(cache, x, A, b) =
+# CPU path (on the GPU, `Spaces.column` has already been applied to the arguments)
+_single_field_solve!(device::ClimaComms.AbstractCPUDevice, cache, x, A, b) =
     Fields.bycolumn(axes(A)) do colidx
-        _single_field_solve!(cache[colidx], x[colidx], A[colidx], b[colidx])
+        _single_field_solve_col!(
+            ClimaComms.device(axes(A)),
+            cache[colidx],
+            x[colidx],
+            A[colidx],
+            b[colidx],
+        )
     end
+
+function _single_field_solve_col!(
+    ::ClimaComms.AbstractCPUDevice,
+    cache::Fields.ColumnField,
+    x::Fields.ColumnField,
+    A,
+    b::Fields.ColumnField,
+)
+    if A isa Fields.ColumnField
+        band_matrix_solve!(
+            eltype(A),
+            unzip_tuple_field_values(Fields.field_values(cache)),
+            Fields.field_values(x),
+            unzip_tuple_field_values(Fields.field_values(A.entries)),
+            Fields.field_values(b),
+        )
+    elseif A isa UniformScaling
+        x .= inv(A.λ) .* b
+    else
+        error("uncaught case")
+    end
+end
+
+# Called by TuplesOfNTuples.jl's `inner_dispatch`, which requires this
+# particular argument order:
 _single_field_solve!(
+    cache::Fields.Field,
+    x::Fields.Field,
+    A::Union{Fields.Field, UniformScaling},
+    b::Fields.Field,
+    dev::ClimaComms.CUDADevice,
+) = _single_field_solve!(dev, cache, x, A, b)
+
+_single_field_solve!(
+    cache::Fields.Field,
+    x::Fields.Field,
+    A::Union{Fields.Field, UniformScaling},
+    b::Fields.Field,
+    dev::ClimaComms.AbstractCPUDevice,
+) = _single_field_solve_col!(dev, cache, x, A, b)
+
+function _single_field_solve!(
+    ::ClimaComms.CUDADevice,
     cache::Fields.ColumnField,
     x::Fields.ColumnField,
     A::Fields.ColumnField,
     b::Fields.ColumnField,
-) = band_matrix_solve!(
-    eltype(A),
-    unzip_tuple_field_values(Fields.field_values(cache)),
-    Fields.field_values(x),
-    unzip_tuple_field_values(Fields.field_values(A.entries)),
-    Fields.field_values(b),
 )
+    band_matrix_solve!(
+        eltype(A),
+        unzip_tuple_field_values(Fields.field_values(cache)),
+        Fields.field_values(x),
+        unzip_tuple_field_values(Fields.field_values(A.entries)),
+        Fields.field_values(b),
+    )
+end
+
+function _single_field_solve!(
+    ::ClimaComms.CUDADevice,
+    cache::Fields.ColumnField,
+    x::Fields.ColumnField,
+    A::UniformScaling,
+    b::Fields.ColumnField,
+)
+    x_data = Fields.field_values(x)
+    b_data = Fields.field_values(b)
+    n = length(x_data)
+    @inbounds for i in 1:n
+        x_data[i] = inv(A.λ) ⊠ b_data[i]
+    end
+end
+
+function _single_field_solve!(
+    ::ClimaComms.CUDADevice,
+    cache::Fields.PointDataField,
+    x::Fields.PointDataField,
+    A::UniformScaling,
+    b::Fields.PointDataField,
+)
+    x_data = Fields.field_values(x)
+    b_data = Fields.field_values(b)
+    @inbounds begin
+        x_data[] = inv(A.λ) ⊠ b_data[]
+    end
+end
 
 unzip_tuple_field_values(data) =
     ntuple(i -> data.:($i), Val(length(propertynames(data))))
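For reference, `band_matrix_solve!` (unchanged by this patch) is the routine that actually solves each column's banded system; in the tridiagonal case it is essentially the Thomas algorithm. A standalone sketch of that algorithm, not ClimaCore's implementation (array-based, hypothetical name):

    # Solve a tridiagonal system in place: `a` is the subdiagonal (a[1] is
    # unused), `b` the diagonal, `c` the superdiagonal (c[n] is unused), and
    # `d` the right-hand side, overwritten with the solution.
    function thomas_solve!(a, b, c, d)
        n = length(d)
        # Forward elimination:
        for i in 2:n
            w = a[i] / b[i - 1]
            b[i] -= w * c[i - 1]
            d[i] -= w * d[i - 1]
        end
        # Back substitution:
        d[n] /= b[n]
        for i in (n - 1):-1:1
            d[i] = (d[i] - c[i] * d[i + 1]) / b[i]
        end
        return d
    end

The `UniformScaling` methods above are the degenerate case: solving `(λ * I) * x = b` reduces to `x = inv(λ) * b`, applied entry-by-entry with the recursive multiplication operator `⊠`.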
diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl
index e5740e0d6a..c25e1fed05 100644
--- a/test/MatrixFields/field_matrix_solvers.jl
+++ b/test/MatrixFields/field_matrix_solvers.jl
@@ -55,7 +55,7 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
         AnyFrameModule(Base.CoreLogging),
     )
     @test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b)
-    @test_opt ignored_modules = ignored field_matrix_solve!(args...)
+    # @test_opt ignored_modules = ignored field_matrix_solve!(args...)
     @test_opt ignored_modules = ignored field_matrix_mul!(b, A, x)
 
     using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0
diff --git a/test/MatrixFields/matrix_field_test_utils.jl b/test/MatrixFields/matrix_field_test_utils.jl
index 3512133d52..d42764fd86 100644
--- a/test/MatrixFields/matrix_field_test_utils.jl
+++ b/test/MatrixFields/matrix_field_test_utils.jl
@@ -41,9 +41,11 @@ macro benchmark(expression)
     end
 end
 
-const comms_device = ClimaComms.device()
-const using_cuda = comms_device isa ClimaComms.CUDADevice
-const ignore_cuda = using_cuda ? (AnyFrameModule(CUDA),) : ()
+comms_device = ClimaComms.device()
+# comms_device = ClimaComms.CPUSingleThreaded()
+@show comms_device
+using_cuda = comms_device isa ClimaComms.CUDADevice
+ignore_cuda = using_cuda ? (AnyFrameModule(CUDA),) : ()
 
 # Test the allocating and non-allocating versions of a field broadcast against
 # a reference non-allocating implementation. Ensure that they are performant,

From 31e69efe212afcff92356278d87c671db29a790a Mon Sep 17 00:00:00 2001
From: Charles Kawczynski
Date: Fri, 5 Apr 2024 09:06:40 -0400
Subject: [PATCH 2/2] Use simpler recurse pattern in mult_field_solve

---
 src/MatrixFields/multiple_field_solver.jl | 68 +++++++++++------------
 test/MatrixFields/field_matrix_solvers.jl |  2 +-
 2 files changed, 32 insertions(+), 38 deletions(-)

diff --git a/src/MatrixFields/multiple_field_solver.jl b/src/MatrixFields/multiple_field_solver.jl
index 0a4da8bd8e..e99a01cfc1 100644
--- a/src/MatrixFields/multiple_field_solver.jl
+++ b/src/MatrixFields/multiple_field_solver.jl
@@ -22,8 +22,6 @@ function multiple_field_solve!(
     end
 end
 
-import TuplesOfNTuples as ToNTs
-
 function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1)
     Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
     names = matrix_row_keys(keys(A))
     Nnames = length(names)
     nthreads, nblocks = Topologies._configure_threadblock(Ni * Nj * Nh * Nnames)
     sscache = Operators.strip_space(cache)
     ssx = Operators.strip_space(x)
     ssA = Operators.strip_space(A)
     ssb = Operators.strip_space(b)
     cache_tup = map(name -> sscache[name], names)
     x_tup = map(name -> ssx[name], names)
     A_tup = map(name -> ssA[name, name], names)
     b_tup = map(name -> ssb[name], names)
     x1 = first(x_tup)
 
-    # These are non-uniform tuples, so let's use TuplesOfNTuples.jl
-    # to unroll these.
- cache_tonts = ToNTs.TupleOfNTuples(cache_tup) - x_tonts = ToNTs.TupleOfNTuples(x_tup) - A_tonts = ToNTs.TupleOfNTuples(A_tup) - b_tonts = ToNTs.TupleOfNTuples(b_tup) + tups = (cache_tup, x_tup, A_tup, b_tup) device = ClimaComms.device(x[first(names)]) CUDA.@cuda threads = nthreads blocks = nblocks multiple_field_solve_kernel!( device, - cache_tonts, - x_tonts, - A_tonts, - b_tonts, + tups, x1, Val(Nnames), ) @@ -71,12 +61,33 @@ end column_A(A::UniformScaling, i, j, h) = A column_A(A, i, j, h) = Spaces.column(A, i, j, h) +@inline function _recurse(js::Tuple, tups::Tuple, transform, device, i::Int) + if first(js) == i + tup_args = map(x -> transform(first(x)), tups) + _single_field_solve!(tup_args..., device) + end + _recurse(Base.tail(js), map(x -> Base.tail(x), tups), transform, device, i) +end + +@inline _recurse(js::Tuple{}, tups::Tuple, transform, device, i::Int) = nothing + +@inline function _recurse( + js::Tuple{Int}, + tups::Tuple, + transform, + device, + i::Int, +) + if first(js) == i + tup_args = map(x -> transform(first(x)), tups) + _single_field_solve!(tup_args..., device) + end + return nothing +end + function multiple_field_solve_kernel!( device::ClimaComms.CUDADevice, - caches::ToNTs.TupleOfNTuples, - xs::ToNTs.TupleOfNTuples, - As::ToNTs.TupleOfNTuples, - bs::ToNTs.TupleOfNTuples, + tups, x1, ::Val{Nnames}, ) where {Nnames} @@ -93,27 +104,10 @@ function multiple_field_solve_kernel!( CUDA.gridDim(), ) if 1 ≤ i <= Ni && 1 ≤ j ≤ Nj && 1 ≤ h ≤ Nh && 1 ≤ iname ≤ Nnames - c1 = ToNTs.inner_dispatch( - _single_field_solve!, - caches, - iname, - ξ -> Spaces.column(ξ, i, j, h), - ) - c2 = ToNTs.outer_dispatch( - c1, - xs, - iname, - ξ -> Spaces.column(ξ, i, j, h), - ) - c3 = ToNTs.outer_dispatch(c2, As, iname, ξ -> column_A(ξ, i, j, h)) - closure = ToNTs.outer_dispatch( - c3, - bs, - iname, - ξ -> Spaces.column(ξ, i, j, h), - ) - closure(device) - # closure(device) calls + + nt = ntuple(ξ -> ξ, Val(Nnames)) + _recurse(nt, tups, ξ -> column_A(ξ, i, j, h), device, iname) + # _recurse effectively calls # _single_field_solve!( # Spaces.column(caches[iname], i, j, h), # Spaces.column(xs[iname], i, j, h), diff --git a/test/MatrixFields/field_matrix_solvers.jl b/test/MatrixFields/field_matrix_solvers.jl index c25e1fed05..e5740e0d6a 100644 --- a/test/MatrixFields/field_matrix_solvers.jl +++ b/test/MatrixFields/field_matrix_solvers.jl @@ -55,7 +55,7 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false) AnyFrameModule(Base.CoreLogging), ) @test_opt ignored_modules = ignored FieldMatrixSolver(alg, A, b) - # @test_opt ignored_modules = ignored field_matrix_solve!(args...) + @test_opt ignored_modules = ignored field_matrix_solve!(args...) @test_opt ignored_modules = ignored field_matrix_mul!(b, A, x) using_cuda || @test @allocated(field_matrix_solve!(args...)) == 0
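The second commit's `_recurse` replaces the TuplesOfNTuples.jl dispatch machinery with plain recursion over tuples: because the recursion is driven by `Base.tail`, the compiler unrolls it and specializes each step on the concrete element type, so selecting the `iname`-th entry of a heterogeneous tuple remains type-stable inside the kernel (which is presumably why the JET `@test_opt` on `field_matrix_solve!` can be re-enabled above). A standalone sketch of the pattern, independent of ClimaCore (hypothetical names):

    # Apply `f` to the i-th element of a heterogeneous tuple without
    # type-unstable dynamic indexing: the `Base.tail` recursion is unrolled
    # at compile time, so each call site sees a concrete element type.
    @inline apply_ith(f, js::Tuple{}, xs::Tuple, i::Int) = nothing
    @inline function apply_ith(f, js::Tuple, xs::Tuple, i::Int)
        first(js) == i && f(first(xs))
        return apply_ith(f, Base.tail(js), Base.tail(xs), i)
    end

    xs = (1, 2.0, "three")
    js = ntuple(identity, length(xs)) # (1, 2, 3)
    apply_ith(println, js, xs, 2)     # prints 2.0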