Skip to content

Commit

Permalink
Hoist UniversalSize computation outside of kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 26, 2024
1 parent 65d0e30 commit a540321
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 15 deletions.
4 changes: 2 additions & 2 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ end
@assert prod((Nij, Nij, Nh_thread)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nh_thread))),$n_max_threads)"
return (; threads = (Nij, Nij, Nh_thread), blocks = (Nh_blocks,))
end
@inline function columnwise_universal_index()
@inline function columnwise_universal_index(us::UniversalSize)
(i, j, th) = CUDA.threadIdx()
(bh,) = CUDA.blockIdx()
h = th + (bh - 1) * CUDA.blockDim().z
Expand All @@ -207,7 +207,7 @@ end
@assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
end
@inline function multiple_field_solve_universal_index()
@inline function multiple_field_solve_universal_index(us::UniversalSize)
(i, j, iname) = CUDA.threadIdx()
(h,) = CUDA.blockIdx()
return (CartesianIndex((i, j, 1, 1, h)), iname)
Expand Down
8 changes: 4 additions & 4 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ NVTX.@annotate function multiple_field_solve!(

device = ClimaComms.device(x[first(names)])

args = (device, caches, xs, As, bs, x1, Val(Nnames))

us = UniversalSize(Fields.field_values(x1))
args = (device, caches, xs, As, bs, x1, us, Val(Nnames))

nitems = Ni * Nj * Nh * Nnames
threads = threads_via_occupancy(multiple_field_solve_kernel!, args)
n_max_threads = min(threads, nitems)
Expand Down Expand Up @@ -85,11 +85,11 @@ function multiple_field_solve_kernel!(
As,
bs,
x1,
us::UniversalSize,
::Val{Nnames},
) where {Nnames}
@inbounds begin
us = UniversalSize(Fields.field_values(x1))
(I, iname) = multiple_field_solve_universal_index()
(I, iname) = multiple_field_solve_universal_index(us)
if multiple_field_solve_is_valid_index(I, us)
(i, j, _, _, h) = I.I
generated_single_field_solve!(
Expand Down
2 changes: 1 addition & 1 deletion ext/cuda/matrix_fields_single_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function single_field_solve!(device::ClimaComms.CUDADevice, cache, x, A, b)
end

function single_field_solve_kernel!(device, cache, x, A, b, us)
I = columnwise_universal_index()
I = columnwise_universal_index(us)
if columnwise_is_valid_index(I, us)
(i, j, _, _, h) = I.I
_single_field_solve!(
Expand Down
13 changes: 8 additions & 5 deletions ext/cuda/operators_integral.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ function column_reduce_device!(
space,
) where {F, T}
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
us = UniversalSize(Fields.field_values(output))
args = (
single_column_reduce!,
f,
Expand All @@ -27,8 +28,8 @@ function column_reduce_device!(
strip_space(input, space),
init,
space,
us,
)
us = UniversalSize(Fields.field_values(output))
nitems = Ni * Nj * Nh
threads = threads_via_occupancy(bycolumn_kernel!, args)
n_max_threads = min(threads, nitems)
Expand All @@ -50,7 +51,8 @@ function column_accumulate_device!(
init,
space,
) where {F, T}
us = UniversalSize(Fields.field_values(output))
out_fv = Fields.field_values(output)
us = UniversalSize(out_fv)
args = (
single_column_accumulate!,
f,
Expand All @@ -59,8 +61,9 @@ function column_accumulate_device!(
strip_space(input, space),
init,
space,
us,
)
Ni, Nj, _, _, Nh = size(Fields.field_values(output))
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
nitems = Ni * Nj * Nh
threads = threads_via_occupancy(bycolumn_kernel!, args)
n_max_threads = min(threads, nitems)
Expand All @@ -81,12 +84,12 @@ bycolumn_kernel!(
input,
init,
space,
us::DataLayouts.UniversalSize,
) where {S, F, T} =
if space isa Spaces.FiniteDifferenceSpace
single_column_function!(f, transform, output, input, init, space)
else
I = columnwise_universal_index()
us = UniversalSize(Fields.field_values(output))
I = columnwise_universal_index(us)
if columnwise_is_valid_index(I, us)
(i, j, _, _, h) = I.I
single_column_function!(
Expand Down
6 changes: 3 additions & 3 deletions ext/cuda/operators_thomas_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import CUDA
using CUDA: @cuda
function column_thomas_solve!(::ClimaComms.CUDADevice, A, b)
us = UniversalSize(Fields.field_values(A))
args = (A, b)
args = (A, b, us)
Ni, Nj, _, _, Nh = size(Fields.field_values(A))
threads = threads_via_occupancy(thomas_algorithm_kernel!, args)
nitems = Ni * Nj * Nh
Expand All @@ -23,9 +23,9 @@ end
function thomas_algorithm_kernel!(
A::Fields.ExtrudedFiniteDifferenceField,
b::Fields.ExtrudedFiniteDifferenceField,
us::DataLayouts.UniversalSize,
)
I = columnwise_universal_index()
us = UniversalSize(Fields.field_values(A))
I = columnwise_universal_index(us)
if columnwise_is_valid_index(I, us)
(i, j, _, _, h) = I.I
thomas_algorithm!(Spaces.column(A, i, j, h), Spaces.column(b, i, j, h))
Expand Down
1 change: 1 addition & 0 deletions src/Operators/thomas_algorithm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ column_thomas_solve!(::ClimaComms.AbstractCPUDevice, A, b) =
thomas_algorithm_kernel!(
A::Fields.FiniteDifferenceField,
b::Fields.FiniteDifferenceField,
us::DataLayouts.UniversalSize,
) = thomas_algorithm!(A, b)

function thomas_algorithm!(
Expand Down

0 comments on commit a540321

Please sign in to comment.