Skip to content

Commit

Permalink
Define a linear partition, and use in FD stencils
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 26, 2024
1 parent 8a5e564 commit 3e89d94
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
19 changes: 19 additions & 0 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,25 @@ end
##### Custom partitions
#####

##### linear partition
@inline function linear_partition(
us::DataLayouts.UniversalSize,
n_max_threads::Integer,
)
nitems = prod(DataLayouts.universal_size(us))
threads = min(nitems, n_max_threads)
blocks = cld(nitems, threads)
return (; threads, blocks)
end
@inline function linear_universal_index(us::UniversalSize)
i = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
inds = DataLayouts.universal_size(us)
CI = CartesianIndices(map(x -> Base.OneTo(x), inds))
return (CI, i)
end
@inline linear_is_valid_index(i::Integer, us::UniversalSize) =
1 i DataLayouts.get_N(us)

##### Column-wise
@inline function columnwise_partition(
us::DataLayouts.UniversalSize,
Expand Down
11 changes: 6 additions & 5 deletions ext/cuda/operators_finite_difference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ function Base.copyto!(
bounds = Operators.window_bounds(space, bc)
out_fv = Fields.field_values(out)
us = DataLayouts.UniversalSize(out_fv)
nitems = prod(DataLayouts.universal_size(us))
args =
(strip_space(out, space), strip_space(bc, space), axes(out), bounds, us)

threads = threads_via_occupancy(copyto_stencil_kernel!, args)
n_max_threads = min(threads, get_N(us))
p = partition(out_fv, n_max_threads)
n_max_threads = min(threads, nitems)
p = linear_partition(us, n_max_threads)

auto_launch!(
copyto_stencil_kernel!,
Expand All @@ -40,9 +41,9 @@ import ClimaCore.DataLayouts: get_N, get_Nv, get_Nij, get_Nij, get_Nh

function copyto_stencil_kernel!(out, bc, space, bds, us)
@inbounds begin
out_fv = Fields.field_values(out)
I = universal_index(out_fv)
if is_valid_index(out_fv, I, us)
(CI, i_linear) = linear_universal_index(us)
if linear_is_valid_index(i_linear, us)
I = CI[i_linear]
(li, lw, rw, ri) = bds
(i, j, _, v, h) = I.I
hidx = (i, j, h)
Expand Down
3 changes: 1 addition & 2 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ is excluded and is returned as 1.
Statically returns `prod((Ni, Nj, Nv, Nh))`
"""
@inline get_N(::UniversalSize{Ni, Nj, Nv, Nh}) where {Ni, Nj, Nv, Nh} =
prod((Ni, Nj, Nv, Nh))
@inline get_N(us::UniversalSize) = prod(universal_size(us))

"""
get_Nv(::UniversalSize)
Expand Down

0 comments on commit 3e89d94

Please sign in to comment.