From 8d923f629334770ba77f74ebb3f13e47918c579f Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Fri, 18 Oct 2024 14:41:11 -0400 Subject: [PATCH] Refactor dss_load_perimeter_data Refactor dss_local Remove more use of DataLayout internals in DSS --- ext/cuda/data_layouts_threadblock.jl | 15 +++ ext/cuda/topologies_dss.jl | 186 +++++++++++++++------------ src/Topologies/dss.jl | 8 +- 3 files changed, 126 insertions(+), 83 deletions(-) diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl index 02a0aeff6c..9e289d7cfa 100644 --- a/ext/cuda/data_layouts_threadblock.jl +++ b/ext/cuda/data_layouts_threadblock.jl @@ -170,6 +170,21 @@ end ##### Custom partitions ##### +##### linear partition +@inline function linear_partition(nitems::Integer, n_max_threads::Integer) + threads = min(nitems, n_max_threads) + blocks = cld(nitems, threads) + return (; threads, blocks) +end +@inline function linear_universal_index(us::UniversalSize) + i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x + inds = DataLayouts.universal_size(us) + CI = CartesianIndices(map(x -> Base.OneTo(x), inds)) + return (CI, i) +end +@inline linear_is_valid_index(i::Integer, us::UniversalSize) = + 1 ≤ i ≤ DataLayouts.get_N(us) + ##### Column-wise @inline function columnwise_partition( us::DataLayouts.UniversalSize, diff --git a/ext/cuda/topologies_dss.jl b/ext/cuda/topologies_dss.jl index 6c36d64071..b51a857551 100644 --- a/ext/cuda/topologies_dss.jl +++ b/ext/cuda/topologies_dss.jl @@ -1,4 +1,5 @@ import ClimaCore: DataLayouts, Topologies, Spaces, Fields +import ClimaCore.DataLayouts: getindex_field, setindex_field! using CUDA import ClimaCore.Topologies import ClimaCore.Topologies: perimeter_vertex_node_index @@ -21,14 +22,15 @@ function Topologies.dss_load_perimeter_data!( perimeter::Topologies.Perimeter2D, ) (; perimeter_data) = dss_buffer - nitems = prod(DataLayouts.farray_size(perimeter_data)) - nthreads, nblocks = _configure_threadblock(nitems) + nitems = prod(parent(perimeter_data)) args = (perimeter_data, data, perimeter) + threads = threads_via_occupancy(dss_load_perimeter_data_kernel!, args) + p = linear_partition(nitems, threads) auto_launch!( dss_load_perimeter_data_kernel!, args; - threads_s = (nthreads), - blocks_s = (nblocks), + threads_s = p.threads, + blocks_s = p.blocks, ) return nothing end @@ -39,17 +41,16 @@ function dss_load_perimeter_data_kernel!( perimeter::Topologies.Perimeter2D{Nq}, ) where {Nq} gidx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x - (nlevels, _, nfidx, nelems) = - sizep = DataLayouts.farray_size(perimeter_data) # size of perimeter data array - sized = (nlevels, Nq, Nq, nfidx, nelems) # size of data - pperimeter_data = parent(perimeter_data) - pdata = parent(data) + (nperimeter, _, _, nlevels, nelems) = size(perimeter_data) + nfidx = DataLayouts.ncomponents(perimeter_data) + sizep = (nlevels, nperimeter, nfidx, nelems) # specialize on VIJFH data + CI = CartesianIndex if gidx ≤ prod(sizep) (level, p, fidx, elem) = cart_ind(sizep, gidx).I (ip, jp) = perimeter[p] - data_idx = linear_ind(sized, (level, ip, jp, fidx, elem)) - pperimeter_data[level, p, fidx, elem] = pdata[data_idx] + val = getindex_field(data, CI(ip, jp, fidx, level, elem)) + setindex_field!(perimeter_data, val, CI(p, 1, fidx, level, elem)) end return nothing end @@ -61,14 +62,15 @@ function Topologies.dss_unload_perimeter_data!( perimeter, ) (; perimeter_data) = dss_buffer - nitems = prod(DataLayouts.farray_size(perimeter_data)) - nthreads, nblocks = _configure_threadblock(nitems) + nitems = prod(parent(perimeter_data)) args = (data, perimeter_data, perimeter) + threads = threads_via_occupancy(dss_unload_perimeter_data_kernel!, args) + p = linear_partition(nitems, threads) auto_launch!( dss_unload_perimeter_data_kernel!, args; - threads_s = (nthreads), - blocks_s = (nblocks), + threads_s = p.threads, + blocks_s = p.blocks, ) return nothing end @@ -79,17 +81,16 @@ function dss_unload_perimeter_data_kernel!( perimeter::Topologies.Perimeter2D{Nq}, ) where {Nq} gidx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x - (nlevels, nperimeter, nfidx, nelems) = - sizep = DataLayouts.farray_size(perimeter_data) # size of perimeter data array - sized = (nlevels, Nq, Nq, nfidx, nelems) # size of data - pperimeter_data = parent(perimeter_data) - pdata = parent(data) + (nperimeter, _, _, nlevels, nelems) = size(perimeter_data) + nfidx = DataLayouts.ncomponents(perimeter_data) + sizep = (nlevels, nperimeter, nfidx, nelems) # specialize on VIJFH data + CI = CartesianIndex if gidx ≤ prod(sizep) (level, p, fidx, elem) = cart_ind(sizep, gidx).I (ip, jp) = perimeter[p] - data_idx = linear_ind(sized, (level, ip, jp, fidx, elem)) - pdata[data_idx] = pperimeter_data[level, p, fidx, elem] + val = getindex_field(perimeter_data, CI(p, 1, fidx, level, elem)) + setindex_field!(data, val, CI(ip, jp, fidx, level, elem)) end return nothing end @@ -103,11 +104,9 @@ function Topologies.dss_local!( nlocalvertices = length(topology.local_vertex_offset) - 1 nlocalfaces = length(topology.interior_faces) if (nlocalvertices + nlocalfaces) > 0 - (nlevels, nperimeter, nfid, nelems) = - DataLayouts.farray_size(perimeter_data) - + (nperimeter, _, _, nlevels, nelems) = size(perimeter_data) + nfid = DataLayouts.ncomponents(perimeter_data) nitems = nlevels * nfid * (nlocalfaces + nlocalvertices) - nthreads, nblocks = _configure_threadblock(nitems) args = ( perimeter_data, topology.local_vertices, @@ -115,11 +114,13 @@ function Topologies.dss_local!( topology.interior_faces, perimeter, ) + threads = threads_via_occupancy(dss_local_kernel!, args) + p = linear_partition(nitems, threads) auto_launch!( dss_local_kernel!, args; - threads_s = (nthreads), - blocks_s = (nblocks), + threads_s = p.threads, + blocks_s = p.blocks, ) end return nothing @@ -136,9 +137,9 @@ function dss_local_kernel!( gidx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x nlocalvertices = length(local_vertex_offset) - 1 nlocalfaces = length(interior_faces) - pperimeter_data = parent(perimeter_data) - FT = eltype(pperimeter_data) - (nlevels, nperimeter, nfidx, _) = DataLayouts.farray_size(perimeter_data) + (nperimeter, _, _, nlevels, _) = size(perimeter_data) + nfidx = DataLayouts.ncomponents(perimeter_data) + CI = CartesianIndex if gidx ≤ nlevels * nfidx * nlocalvertices # local vertices sizev = (nlevels, nfidx, nlocalvertices) (level, fidx, vertexid) = cart_ind(sizev, gidx).I @@ -148,12 +149,17 @@ function dss_local_kernel!( for idx in st:(en - 1) (lidx, vert) = local_vertices[idx] ip = perimeter_vertex_node_index(vert) - sum_data += pperimeter_data[level, ip, fidx, lidx] + sum_data += + getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx)) end for idx in st:(en - 1) (lidx, vert) = local_vertices[idx] ip = perimeter_vertex_node_index(vert) - pperimeter_data[level, ip, fidx, lidx] = sum_data + setindex_field!( + perimeter_data, + sum_data, + CI(ip, 1, fidx, level, lidx), + ) end elseif gidx ≤ nlevels * nfidx * (nlocalvertices + nlocalfaces) # interior faces nfacedof = div(nperimeter - 4, 4) @@ -168,11 +174,13 @@ function dss_local_kernel!( for i in 1:nfacedof ip1 = inc1 == 1 ? first1 + i - 1 : first1 - i + 1 ip2 = inc2 == 1 ? first2 + i - 1 : first2 - i + 1 + idx1 = CI(ip1, 1, fidx, level, lidx1) + idx2 = CI(ip2, 1, fidx, level, lidx2) val = - pperimeter_data[level, ip1, fidx, lidx1] + - pperimeter_data[level, ip2, fidx, lidx2] - pperimeter_data[level, ip1, fidx, lidx1] = val - pperimeter_data[level, ip2, fidx, lidx2] = val + getindex_field(perimeter_data, idx1) + + getindex_field(perimeter_data, idx2) + setindex_field!(perimeter_data, val, idx1) + setindex_field!(perimeter_data, val, idx2) end end @@ -311,22 +319,22 @@ function Topologies.dss_local_ghost!( ) nghostvertices = length(topology.ghost_vertex_offset) - 1 if nghostvertices > 0 - (nlevels, nperimeter, nfid, nelems) = - DataLayouts.farray_size(perimeter_data) - max_threads = 256 + (_, _, _, nlevels, _) = size(perimeter_data) + nfid = DataLayouts.ncomponents(perimeter_data) nitems = nlevels * nfid * nghostvertices - nthreads, nblocks = _configure_threadblock(nitems) args = ( perimeter_data, topology.ghost_vertices, topology.ghost_vertex_offset, perimeter, ) + threads = threads_via_occupancy(dss_local_ghost_kernel!, args) + p = linear_partition(nitems, threads) auto_launch!( dss_local_ghost_kernel!, args; - threads_s = (nthreads), - blocks_s = (nblocks), + threads_s = p.threads, + blocks_s = p.blocks, ) end return nothing @@ -339,9 +347,10 @@ function dss_local_ghost_kernel!( perimeter::Topologies.Perimeter2D, ) gidx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x - pperimeter_data = parent(perimeter_data) - FT = eltype(pperimeter_data) - (nlevels, nperimeter, nfidx, _) = DataLayouts.farray_size(perimeter_data) + FT = eltype(parent(perimeter_data)) + (nperimeter, _, _, nlevels, _) = size(perimeter_data) + nfidx = DataLayouts.ncomponents(perimeter_data) + CI = CartesianIndex nghostvertices = length(ghost_vertex_offset) - 1 if gidx ≤ nlevels * nfidx * nghostvertices sizev = (nlevels, nfidx, nghostvertices) @@ -353,14 +362,19 @@ function dss_local_ghost_kernel!( isghost, lidx, vert = ghost_vertices[idx] if !isghost ip = perimeter_vertex_node_index(vert) - sum_data += pperimeter_data[level, ip, fidx, lidx] + sum_data += + getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx)) end end for idx in st:(en - 1) isghost, lidx, vert = ghost_vertices[idx] if !isghost ip = perimeter_vertex_node_index(vert) - pperimeter_data[level, ip, fidx, lidx] = sum_data + setindex_field!( + perimeter_data, + sum_data, + CI(ip, 1, fidx, level, lidx), + ) end end end @@ -373,18 +387,19 @@ function Topologies.fill_send_buffer!( synchronize = true, ) (; perimeter_data, send_buf_idx, send_data) = dss_buffer - (nlevels, nperimeter, nfid, nelems) = - DataLayouts.farray_size(perimeter_data) + (nperimeter, _, _, nlevels, nelems) = size(perimeter_data) + nfid = DataLayouts.ncomponents(perimeter_data) nsend = size(send_buf_idx, 1) if nsend > 0 nitems = nsend * nlevels * nfid - nthreads, nblocks = _configure_threadblock(nitems) args = (send_data, send_buf_idx, perimeter_data, Val(nsend)) + threads = threads_via_occupancy(fill_send_buffer_kernel!, args) + p = linear_partition(nitems, threads) auto_launch!( fill_send_buffer_kernel!, args; - threads_s = (nthreads), - blocks_s = (nblocks), + threads_s = p.threads, + blocks_s = p.blocks, ) if synchronize CUDA.synchronize(; blocking = true) # CUDA MPI uses a separate stream. This will synchronize across streams @@ -400,18 +415,16 @@ function fill_send_buffer_kernel!( ::Val{nsend}, ) where {FT <: AbstractFloat, I <: Int, nsend} gidx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x - (nlevels, _, nfid, nelems) = DataLayouts.farray_size(perimeter_data) - pperimeter_data = parent(perimeter_data) - #sizet = (nsend, nlevels, nfid) + (_, _, _, nlevels, nelems) = size(perimeter_data) + nfid = DataLayouts.ncomponents(perimeter_data) sizet = (nlevels, nfid, nsend) - #if gidx ≤ nsend * nlevels * nfid if gidx ≤ nlevels * nfid * nsend - #(isend, level, fidx) = cart_ind(sizet, gidx).I (level, fidx, isend) = cart_ind(sizet, gidx).I lidx = send_buf_idx[isend, 1] ip = send_buf_idx[isend, 2] idx = level + ((fidx - 1) + (isend - 1) * nfid) * nlevels - send_data[idx] = pperimeter_data[level, ip, fidx, lidx] + send_data[idx] = + getindex_field(perimeter_data, CI(ip, 1, fidx, level, lidx)) end return nothing end @@ -421,18 +434,19 @@ function Topologies.load_from_recv_buffer!( dss_buffer::Topologies.DSSBuffer, ) (; perimeter_data, recv_buf_idx, recv_data) = dss_buffer - (nlevels, nperimeter, nfid, nelems) = - DataLayouts.farray_size(perimeter_data) + (nperimeter, _, _, nlevels, nelems) = size(perimeter_data) + nfid = DataLayouts.ncomponents(perimeter_data) nrecv = size(recv_buf_idx, 1) if nrecv > 0 nitems = nrecv * nlevels * nfid - nthreads, nblocks = _configure_threadblock(nitems) args = (perimeter_data, recv_data, recv_buf_idx, Val(nrecv)) + threads = threads_via_occupancy(load_from_recv_buffer_kernel!, args) + p = linear_partition(nitems, threads) auto_launch!( load_from_recv_buffer_kernel!, args; - threads_s = (nthreads), - blocks_s = (nblocks), + threads_s = p.threads, + blocks_s = p.blocks, ) end return nothing @@ -446,17 +460,22 @@ function load_from_recv_buffer_kernel!( ) where {FT <: AbstractFloat, I <: Int, nrecv} gidx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x pperimeter_data = parent(perimeter_data) - (nlevels, _, nfid, nelems) = DataLayouts.farray_size(perimeter_data) - #sizet = (nrecv, nlevels, nfid) + (_, _, _, nlevels, nelems) = size(perimeter_data) + nfid = DataLayouts.ncomponents(perimeter_data) sizet = (nlevels, nfid, nrecv) - #if gidx ≤ nrecv * nlevels * nfid + CI = CartesianIndex if gidx ≤ nlevels * nfid * nrecv - #(irecv, level, fidx) = cart_ind(sizet, gidx).I (level, fidx, irecv) = cart_ind(sizet, gidx).I lidx = recv_buf_idx[irecv, 1] ip = recv_buf_idx[irecv, 2] idx = level + ((fidx - 1) + (irecv - 1) * nfid) * nlevels - CUDA.@atomic pperimeter_data[level, ip, fidx, lidx] += recv_data[idx] + ci = CI(ip, 1, fidx, level, lidx) + # CUDA.@atomic has limited support, so + # let's use the methods in DataLayouts + # to allow this to work: + s = DataLayouts.singleton(perimeter_data) + ci_data = CartesianIndex(DataLayouts.to_data_specific_field(s, ci.I)) + CUDA.@atomic pperimeter_data[ci_data] += getindex_field(perimeter_data, ci) end return nothing end @@ -470,9 +489,9 @@ function Topologies.dss_ghost!( ) nghostvertices = length(topology.ghost_vertex_offset) - 1 if nghostvertices > 0 - (nlevels, _, nfidx, _) = DataLayouts.farray_size(perimeter_data) + (_, _, _, nlevels, _) = size(perimeter_data) + nfidx = DataLayouts.ncomponents(perimeter_data) nitems = nlevels * nfidx * nghostvertices - nthreads, nblocks = _configure_threadblock(nitems) args = ( perimeter_data, topology.ghost_vertices, @@ -480,11 +499,13 @@ function Topologies.dss_ghost!( topology.repr_ghost_vertex, perimeter, ) + threads = threads_via_occupancy(dss_ghost_kernel!, args) + p = linear_partition(nitems, threads) auto_launch!( dss_ghost_kernel!, args; - threads_s = (nthreads), - blocks_s = (nblocks), + threads_s = p.threads, + blocks_s = p.blocks, ) end return nothing @@ -497,25 +518,32 @@ function dss_ghost_kernel!( repr_ghost_vertex, perimeter::Topologies.Perimeter2D, ) - pperimeter_data = parent(perimeter_data) - FT = eltype(pperimeter_data) + FT = eltype(parent(perimeter_data)) gidx = threadIdx().x + (blockIdx().x - Int32(1)) * blockDim().x - (nlevels, _, nfidx, _) = DataLayouts.farray_size(perimeter_data) + (_, _, _, nlevels, _) = size(perimeter_data) + nfidx = DataLayouts.ncomponents(perimeter_data) nghostvertices = length(ghost_vertex_offset) - 1 - + CI = CartesianIndex if gidx ≤ nlevels * nfidx * nghostvertices (level, fidx, ghostvertexidx) = cart_ind((nlevels, nfidx, nghostvertices), gidx).I idxresult, lvertresult = repr_ghost_vertex[ghostvertexidx] ipresult = perimeter_vertex_node_index(lvertresult) - result = pperimeter_data[level, ipresult, fidx, idxresult] + result = getindex_field( + perimeter_data, + CI(ipresult, 1, fidx, level, idxresult), + ) st, en = ghost_vertex_offset[ghostvertexidx], ghost_vertex_offset[ghostvertexidx + 1] for vertexidx in st:(en - 1) isghost, eidx, lvert = ghost_vertices[vertexidx] if !isghost ip = perimeter_vertex_node_index(lvert) - pperimeter_data[level, ip, fidx, eidx] = result + setindex_field!( + perimeter_data, + result, + CI(ip, 1, fidx, level, eidx), + ) end end end diff --git a/src/Topologies/dss.jl b/src/Topologies/dss.jl index 6c17fc6be2..ea6eff0713 100644 --- a/src/Topologies/dss.jl +++ b/src/Topologies/dss.jl @@ -579,14 +579,13 @@ function fill_send_buffer!( (; perimeter_data, send_buf_idx, send_data) = dss_buffer (Np, _, _, Nv, nelems) = size(perimeter_data) Nf = DataLayouts.ncomponents(perimeter_data) - pdata = parent(perimeter_data) nsend = size(send_buf_idx, 1) ctr = 1 @inbounds for i in 1:nsend lidx = send_buf_idx[i, 1] ip = send_buf_idx[i, 2] for f in 1:Nf, v in 1:Nv - send_data[ctr] = pdata[v, ip, f, lidx] + send_data[ctr] = getindex_field(data, CI(ip, 1, f, v, lidx)) ctr += 1 end end @@ -608,14 +607,15 @@ function load_from_recv_buffer!( (; perimeter_data, recv_buf_idx, recv_data) = dss_buffer (Np, _, _, Nv, nelems) = size(perimeter_data) Nf = DataLayouts.ncomponents(perimeter_data) - pdata = parent(perimeter_data) nrecv = size(recv_buf_idx, 1) ctr = 1 @inbounds for i in 1:nrecv lidx = recv_buf_idx[i, 1] ip = recv_buf_idx[i, 2] for f in 1:Nf, v in 1:Nv - pdata[v, ip, f, lidx] += recv_data[ctr] + ci = CI(ip, 1, f, v, lidx) + val = getindex_field(data, ci) + recv_data[ctr] + setindex_field!(data, val, ci) ctr += 1 end end