diff --git a/src/Fields/Fields.jl b/src/Fields/Fields.jl
index 9206145f6a..6ced97883a 100644
--- a/src/Fields/Fields.jl
+++ b/src/Fields/Fields.jl
@@ -55,6 +55,11 @@ Adapt.adapt_structure(to, field::Field) = Field(
 const PointField{V, S} =
     Field{V, S} where {V <: AbstractData, S <: Spaces.PointSpace}
 
+# TODO: do we need to make this distinction? What about inside CUDA kernels,
+# where the space is replaced with a PlaceholderSpace?
+const PointDataField{V, S} =
+    Field{V, S} where {V <: DataLayouts.DataF, S <: Spaces.AbstractSpace}
+
 # Spectral Element Field
 const SpectralElementField{V, S} = Field{
     V,
diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl
index 304bbc3899..653920546b 100644
--- a/src/MatrixFields/MatrixFields.jl
+++ b/src/MatrixFields/MatrixFields.jl
@@ -50,6 +50,7 @@ import BandedMatrices: BandedMatrix, band, _BandedMatrix
 import RecursiveArrayTools: recursive_bottom_eltype
 import KrylovKit
 import ClimaComms
+import Adapt
 
 import ..Utilities: PlusHalf, half
 import ..RecursiveApply:
@@ -86,6 +87,7 @@ const ColumnwiseBandMatrixField{V, S} = Fields.Field{
     S <: Union{
         Spaces.FiniteDifferenceSpace,
         Spaces.ExtrudedFiniteDifferenceSpace,
+        Operators.PlaceholderSpace, # so that this can exist inside CUDA kernels
     },
 }
 
@@ -99,6 +101,7 @@ include("field_name.jl")
 include("field_name_set.jl")
 include("field_name_dict.jl")
 include("single_field_solver.jl")
+include("multiple_field_solver.jl")
 include("field_matrix_solver.jl")
 include("field_matrix_iterative_solver.jl")
 
diff --git a/src/MatrixFields/field_matrix_solver.jl b/src/MatrixFields/field_matrix_solver.jl
index 86b585820c..39132352c5 100644
--- a/src/MatrixFields/field_matrix_solver.jl
+++ b/src/MatrixFields/field_matrix_solver.jl
@@ -248,9 +248,13 @@ function check_field_matrix_solver(::BlockDiagonalSolve, _, A, b)
 end
 
 run_field_matrix_solver!(::BlockDiagonalSolve, cache, x, A, b) =
-    foreach(matrix_row_keys(keys(A))) do name
-        single_field_solve!(cache[name], x[name], A[name, name], b[name])
-    end
+    multiple_field_solve!(cache, x, A, b)
+
+# This may be helpful for debugging:
+# run_field_matrix_solver!(::BlockDiagonalSolve, cache, x, A, b) =
+#     foreach(matrix_row_keys(keys(A))) do name
+#         single_field_solve!(cache[name], x[name], A[name, name], b[name])
+#     end
 
 """
     BlockLowerTriangularSolve(names₁...; [alg₁], [alg₂])
diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl
index 516a7d9f2c..8e71ba7604 100644
--- a/src/MatrixFields/field_name_dict.jl
+++ b/src/MatrixFields/field_name_dict.jl
@@ -88,6 +88,22 @@ function Base.show(io::IO, dict::FieldNameDict)
     end
 end
 
+function Operators.strip_space(dict::FieldNameDict)
+    vals = unrolled_map(values(dict)) do val
+        if val isa Fields.Field
+            Fields.Field(Fields.field_values(val), Operators.PlaceholderSpace())
+        else
+            val
+        end
+    end
+    FieldNameDict(keys(dict), vals)
+end
+
+function Adapt.adapt_structure(to, dict::FieldNameDict)
+    vals = unrolled_map(v -> Adapt.adapt_structure(to, v), values(dict))
+    FieldNameDict(keys(dict), vals)
+end
+
 Base.keys(dict::FieldNameDict) = dict.keys
 Base.values(dict::FieldNameDict) = dict.entries
 
diff --git a/src/MatrixFields/multiple_field_solver.jl b/src/MatrixFields/multiple_field_solver.jl
new file mode 100644
index 0000000000..e99a01cfc1
--- /dev/null
+++ b/src/MatrixFields/multiple_field_solver.jl
@@ -0,0 +1,121 @@
+# TODO: Can different A's have different matrix styles?
+#       If so, how can we fuse/parallelize the solves?
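+#
+# For illustration only (hypothetical field names and entries): a
+# block-diagonal system passed to this solver may mix entry types, e.g.
+#
+#     A = MatrixFields.FieldMatrix(
+#         (@name(c.ρ), @name(c.ρ)) => I,          # UniformScaling block
+#         (@name(f.w), @name(f.w)) => ᶠtridiag,   # ColumnwiseBandMatrixField
+#     )
+#
+# so each diagonal block may require a different single-column solve.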
+
+# First, dispatch based on the device associated with the first x:
+function multiple_field_solve!(cache, x, A, b)
+    name1 = first(matrix_row_keys(keys(A)))
+    x1 = x[name1]
+    multiple_field_solve!(ClimaComms.device(axes(x1)), cache, x, A, b, x1)
+end
+
+# TODO: fuse/parallelize
+function multiple_field_solve!(
+    ::ClimaComms.AbstractCPUDevice,
+    cache,
+    x,
+    A,
+    b,
+    x1,
+)
+    foreach(matrix_row_keys(keys(A))) do name
+        single_field_solve!(cache[name], x[name], A[name, name], b[name])
+    end
+end
+
+function multiple_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b, x1)
+    Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
+    names = matrix_row_keys(keys(A))
+    Nnames = length(names)
+    nthreads, nblocks = Topologies._configure_threadblock(Ni * Nj * Nh * Nnames)
+    # Strip the space from each field so that it can be used inside the CUDA
+    # kernel (see the PlaceholderSpace comment in MatrixFields.jl):
+    sscache = Operators.strip_space(cache)
+    ssx = Operators.strip_space(x)
+    ssA = Operators.strip_space(A)
+    ssb = Operators.strip_space(b)
+    cache_tup = map(name -> sscache[name], names)
+    x_tup = map(name -> ssx[name], names)
+    A_tup = map(name -> ssA[name, name], names)
+    b_tup = map(name -> ssb[name], names)
+    x1 = first(x_tup) # use the stripped-space x1 inside the kernel
+
+    tups = (cache_tup, x_tup, A_tup, b_tup)
+
+    device = ClimaComms.device(x[first(names)])
+    CUDA.@cuda threads = nthreads blocks = nblocks multiple_field_solve_kernel!(
+        device,
+        tups,
+        x1,
+        Val(Nnames),
+    )
+end
+
+function get_ijhn(Ni, Nj, Nh, Nnames, blockIdx, threadIdx, blockDim, gridDim)
+    tidx = (blockIdx.x - 1) * blockDim.x + threadIdx.x
+    (i, j, h, n) = if 1 ≤ tidx ≤ prod((Ni, Nj, Nh, Nnames))
+        CartesianIndices((1:Ni, 1:Nj, 1:Nh, 1:Nnames))[tidx].I
+    else
+        (-1, -1, -1, -1)
+    end
+    return (i, j, h, n)
+end
+
+column_A(A::UniformScaling, i, j, h) = A
+column_A(A, i, j, h) = Spaces.column(A, i, j, h)
+
+@inline function _recurse(js::Tuple, tups::Tuple, transform, device, i::Int)
+    if first(js) == i
+        tup_args = map(x -> transform(first(x)), tups)
+        _single_field_solve!(tup_args..., device)
+    end
+    _recurse(Base.tail(js), map(x -> Base.tail(x), tups), transform, device, i)
+end
+
+@inline _recurse(js::Tuple{}, tups::Tuple, transform, device, i::Int) = nothing
+
+@inline function _recurse(
+    js::Tuple{Int},
+    tups::Tuple,
+    transform,
+    device,
+    i::Int,
+)
+    if first(js) == i
+        tup_args = map(x -> transform(first(x)), tups)
+        _single_field_solve!(tup_args..., device)
+    end
+    return nothing
+end
+
+function multiple_field_solve_kernel!(
+    device::ClimaComms.CUDADevice,
+    tups,
+    x1,
+    ::Val{Nnames},
+) where {Nnames}
+    @inbounds begin
+        Ni, Nj, _, _, Nh = size(Fields.field_values(x1))
+        (i, j, h, iname) = get_ijhn(
+            Ni,
+            Nj,
+            Nh,
+            Nnames,
+            CUDA.blockIdx(),
+            CUDA.threadIdx(),
+            CUDA.blockDim(),
+            CUDA.gridDim(),
+        )
+        if 1 ≤ i ≤ Ni && 1 ≤ j ≤ Nj && 1 ≤ h ≤ Nh && 1 ≤ iname ≤ Nnames
+            nt = ntuple(ξ -> ξ, Val(Nnames))
+            _recurse(nt, tups, ξ -> column_A(ξ, i, j, h), device, iname)
+            # _recurse effectively calls
+            # _single_field_solve!(
+            #     Spaces.column(caches[iname], i, j, h),
+            #     Spaces.column(xs[iname], i, j, h),
+            #     column_A(As[iname], i, j, h),
+            #     Spaces.column(bs[iname], i, j, h),
+            #     device,
+            # )
+        end
+    end
+    return nothing
+end
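+
+# For intuition: `_recurse` acts as an unrolled `getindex` over the
+# heterogeneous tuples in `tups`, so that only the entries whose position
+# matches the runtime index `iname` are solved, while every call remains
+# type-stable (a naive `tups[iname]` would not be). A standalone sketch of
+# the same trick (hypothetical helper, not used by this file):
+#
+#     unrolled_apply(f, t::Tuple{}, i, k) = nothing
+#     function unrolled_apply(f, t::Tuple, i, k = 1)
+#         k == i && f(first(t))
+#         return unrolled_apply(f, Base.tail(t), i, k + 1)
+#     end
+#
+#     unrolled_apply(println, (1, 2.0, "three"), 3)  # prints "three"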
diff --git a/src/MatrixFields/single_field_solver.jl b/src/MatrixFields/single_field_solver.jl
index 330a742c07..b3af8ddf39 100644
--- a/src/MatrixFields/single_field_solver.jl
+++ b/src/MatrixFields/single_field_solver.jl
@@ -47,11 +47,17 @@ single_field_solve!(cache, x, A::ColumnwiseBandMatrixField, b) =
     single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)
 
 single_field_solve!(::ClimaComms.AbstractCPUDevice, cache, x, A, b) =
-    _single_field_solve!(cache, x, A, b)
+    _single_field_solve!(ClimaComms.device(axes(A)), cache, x, A, b)
+
+# single_field_solve!(::ClimaComms.CUDADevice, ...) is no longer exercised,
+# but it may be helpful for debugging due to its simplicity, so we leave it
+# here for now.
 function single_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b)
     Ni, Nj, _, _, Nh = size(Fields.field_values(A))
     nthreads, nblocks = Topologies._configure_threadblock(Ni * Nj * Nh)
+    device = ClimaComms.device(A)
     CUDA.@cuda always_inline = true threads = nthreads blocks = nblocks single_field_solve_kernel!(
+        device,
         cache,
         x,
         A,
@@ -59,12 +65,13 @@ function single_field_solve!(::ClimaComms.CUDADevice, cache, x, A, b)
 end
 
-function single_field_solve_kernel!(cache, x, A, b)
+function single_field_solve_kernel!(device, cache, x, A, b)
     idx = CUDA.threadIdx().x + (CUDA.blockIdx().x - 1) * CUDA.blockDim().x
     Ni, Nj, _, _, Nh = size(Fields.field_values(A))
     if idx <= Ni * Nj * Nh
         i, j, h = Topologies._get_idx((Ni, Nj, Nh), idx)
         _single_field_solve!(
+            device,
             Spaces.column(cache, i, j, h),
             Spaces.column(x, i, j, h),
             Spaces.column(A, i, j, h),
@@ -80,22 +87,103 @@ single_field_solve_kernel!(
     b::Fields.ColumnField,
 ) = _single_field_solve!(cache, x, A, b)
 
-_single_field_solve!(cache, x, A, b) =
+# CPU path (on the GPU, Spaces.column has already been applied to the arguments)
+_single_field_solve!(device::ClimaComms.AbstractCPUDevice, cache, x, A, b) =
     Fields.bycolumn(axes(A)) do colidx
-        _single_field_solve!(cache[colidx], x[colidx], A[colidx], b[colidx])
+        _single_field_solve_col!(
+            ClimaComms.device(axes(A)),
+            cache[colidx],
+            x[colidx],
+            A[colidx],
+            b[colidx],
+        )
     end
+
+function _single_field_solve_col!(
+    ::ClimaComms.AbstractCPUDevice,
+    cache::Fields.ColumnField,
+    x::Fields.ColumnField,
+    A,
+    b::Fields.ColumnField,
+)
+    if A isa Fields.ColumnField
+        band_matrix_solve!(
+            eltype(A),
+            unzip_tuple_field_values(Fields.field_values(cache)),
+            Fields.field_values(x),
+            unzip_tuple_field_values(Fields.field_values(A.entries)),
+            Fields.field_values(b),
+        )
+    elseif A isa UniformScaling
+        x .= inv(A.λ) .* b
+    else
+        error("Unhandled A type: $(typeof(A))")
+    end
+end
+
+# Called by `_recurse` in multiple_field_solver.jl, which passes `device` as
+# the last argument:
 _single_field_solve!(
+    cache::Fields.Field,
+    x::Fields.Field,
+    A::Union{Fields.Field, UniformScaling},
+    b::Fields.Field,
+    dev::ClimaComms.CUDADevice,
+) = _single_field_solve!(dev, cache, x, A, b)
+
+_single_field_solve!(
+    cache::Fields.Field,
+    x::Fields.Field,
+    A::Union{Fields.Field, UniformScaling},
+    b::Fields.Field,
+    dev::ClimaComms.AbstractCPUDevice,
+) = _single_field_solve_col!(dev, cache, x, A, b)
+
+function _single_field_solve!(
+    ::ClimaComms.CUDADevice,
     cache::Fields.ColumnField,
     x::Fields.ColumnField,
     A::Fields.ColumnField,
     b::Fields.ColumnField,
-) = band_matrix_solve!(
-    eltype(A),
-    unzip_tuple_field_values(Fields.field_values(cache)),
-    Fields.field_values(x),
-    unzip_tuple_field_values(Fields.field_values(A.entries)),
-    Fields.field_values(b),
 )
+    band_matrix_solve!(
+        eltype(A),
+        unzip_tuple_field_values(Fields.field_values(cache)),
+        Fields.field_values(x),
+        unzip_tuple_field_values(Fields.field_values(A.entries)),
+        Fields.field_values(b),
+    )
+end
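+
+# When A = λ * I, solving A * x = b amounts to x = inv(λ) * b entry by entry.
+# Note that `⊠` is `RecursiveApply.rmul`, which multiplies recursively through
+# nested element types (tuples/NamedTuples of components), where a plain `*`
+# may not apply. For a scalar eltype, the loop below reduces to (sketch):
+#
+#     for i in 1:n
+#         x_data[i] = b_data[i] / A.λ
+#     end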
+function _single_field_solve!(
+    ::ClimaComms.CUDADevice,
+    cache::Fields.ColumnField,
+    x::Fields.ColumnField,
+    A::UniformScaling,
+    b::Fields.ColumnField,
+)
+    x_data = Fields.field_values(x)
+    b_data = Fields.field_values(b)
+    n = length(x_data)
+    @inbounds for i in 1:n
+        x_data[i] = inv(A.λ) ⊠ b_data[i]
+    end
+end
+
+# A PointDataField stores a single value, so the UniformScaling solve is a
+# single recursive multiplication:
+function _single_field_solve!(
+    ::ClimaComms.CUDADevice,
+    cache::Fields.PointDataField,
+    x::Fields.PointDataField,
+    A::UniformScaling,
+    b::Fields.PointDataField,
+)
+    x_data = Fields.field_values(x)
+    b_data = Fields.field_values(b)
+    @inbounds x_data[] = inv(A.λ) ⊠ b_data[]
+end
 
 unzip_tuple_field_values(data) =
     ntuple(i -> data.:($i), Val(length(propertynames(data))))
 
diff --git a/test/MatrixFields/matrix_field_test_utils.jl b/test/MatrixFields/matrix_field_test_utils.jl
index 3512133d52..d42764fd86 100644
--- a/test/MatrixFields/matrix_field_test_utils.jl
+++ b/test/MatrixFields/matrix_field_test_utils.jl
@@ -41,9 +41,11 @@ macro benchmark(expression)
     end
 end
 
-const comms_device = ClimaComms.device()
-const using_cuda = comms_device isa ClimaComms.CUDADevice
-const ignore_cuda = using_cuda ? (AnyFrameModule(CUDA),) : ()
+comms_device = ClimaComms.device()
+# comms_device = ClimaComms.CPUSingleThreaded()
+@show comms_device
+using_cuda = comms_device isa ClimaComms.CUDADevice
+ignore_cuda = using_cuda ? (AnyFrameModule(CUDA),) : ()
 
 # Test the allocating and non-allocating versions of a field broadcast against
 # a reference non-allocating implementation. Ensure that they are performant,
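
A minimal sketch of how the BlockDiagonalSolve path modified above is reached
through the public MatrixFields API (hypothetical field names and matrix
entries; `I` is LinearAlgebra's UniformScaling):

    b = Fields.FieldVector(; c = ρ_field, f = w_field)
    A = MatrixFields.FieldMatrix(
        (@name(c), @name(c)) => I,
        (@name(f), @name(f)) => ᶠtridiagonal_matrix_field,
    )
    solver = MatrixFields.FieldMatrixSolver(MatrixFields.BlockDiagonalSolve(), A, b)
    x = similar(b)
    MatrixFields.field_matrix_solve!(solver, x, A, b)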