Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Sep 26, 2024
1 parent 3e89d94 commit 6adeb1d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 11 deletions.
47 changes: 39 additions & 8 deletions ext/cuda/data_layouts_threadblock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,26 @@ end
# Bounds check for a column-wise kernel index: returns `true` when the fifth
# component of `I` lies in the closed range `1:DataLayouts.get_Nh(us)`.
# Only that component is inspected; the other four are not validated here.
@inline columnwise_is_valid_index(I::CI5, us::UniversalSize) =
    1 ≤ I[5] ≤ DataLayouts.get_Nh(us)

# Build a one-dimensional `(; threads, blocks)` partition for column-wise work
# over `us`.
#
# The number of work items is `Nij * Nij * Nh`, where `Nij` and `Nh` are the
# first and fifth entries of `DataLayouts.universal_size(us)`. `threads` is
# that count capped at `n_max_threads`, and `blocks` is the ceiling division of
# the count by `threads`, so `threads * blocks` covers every item (and may
# overshoot, hence the separate validity check on the linear index).
#
# NOTE(review): when the item count is zero, `threads` is zero and `cld`
# divides by zero — confirm callers never pass an empty layout.
@inline function columnwise_linear_partition(
    us::DataLayouts.UniversalSize,
    n_max_threads::Integer,
)
    Nij, _, _, _, Nh = DataLayouts.universal_size(us)
    n_work = Nij * Nij * Nh
    threads = min(n_work, n_max_threads)
    return (; threads, blocks = cld(n_work, threads))
end
# Return `(CI, i)` for the calling thread of a linearly partitioned
# column-wise kernel.
#
# `CI` is the `CartesianIndices` over `(1:Nij, 1:Nij, 1:Nh)`, with `Nij` and
# `Nh` taken from the first and fifth entries of
# `DataLayouts.universal_size(us)`. `i` is the linear thread index computed
# from the `x` components of `CUDA.blockIdx()`, `CUDA.blockDim()` and
# `CUDA.threadIdx()`.
#
# `i` is not range-checked here: it can exceed `length(CI)` in the last block,
# so validate it before using it to index `CI`.
@inline function columnwise_linear_universal_index(us)
    Nij, _, _, _, Nh = DataLayouts.universal_size(us)
    block_offset = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
    i_linear = block_offset + CUDA.threadIdx().x
    CI = CartesianIndices(map(Base.OneTo, (Nij, Nij, Nh)))
    return (CI, i_linear)
end
# Bounds check for a linear thread index: returns `true` when `i_linear` lies
# in the closed range `1:N`, where `N` is the total number of work items.
# Threads whose index falls outside this range should do no work.
@inline columnwise_linear_is_valid_index(i_linear::Integer, N::Integer) =
    1 ≤ i_linear ≤ N

##### Element-wise (e.g., limiters)
# TODO

Expand All @@ -223,16 +243,27 @@ end
Nnames,
)
(Nij, _, _, _, Nh) = DataLayouts.universal_size(us)
@assert prod((Nij, Nij, Nnames)) n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
# @assert prod((Nij, Nij, Nnames)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nij, Nij, Nnames))),$n_max_threads)"
# return (; threads = (Nij, Nij, Nnames), blocks = (Nh,))
nitems = prod((Nh, Nij, Nij, Nnames))
threads = min(nitems, n_max_threads)
blocks = cld(nitems, threads)
return (; threads, blocks)
end
@inline function multiple_field_solve_universal_index(us::UniversalSize)
(i, j, iname) = CUDA.threadIdx()
(h,) = CUDA.blockIdx()
return (CartesianIndex((i, j, 1, 1, h)), iname)
# Return `(CI, i)` for the calling thread of a linearly partitioned
# multiple-field-solve kernel.
#
# `CI` is the `CartesianIndices` over `(Nij, Nij, Nh, Nnames)`, i.e. its axes
# are ordered `(i, j, h, iname)`; `Nij` and `Nh` are the first and fifth
# entries of `DataLayouts.universal_size(us)` and `Nnames` comes from the
# `Val` type parameter. `i` is the linear thread index computed from the `x`
# components of `CUDA.blockIdx()`, `CUDA.blockDim()` and `CUDA.threadIdx()`.
#
# This replaces the earlier scheme that read `(i, j, iname)` directly from
# `CUDA.threadIdx()` and `h` from `CUDA.blockIdx()`.
#
# `i` is not range-checked here: validate it against `length(CI)` before
# indexing `CI` with it.
@inline function multiple_field_solve_universal_index(
    us::DataLayouts.UniversalSize,
    ::Val{Nnames},
) where {Nnames}
    Nij, _, _, _, Nh = DataLayouts.universal_size(us)
    block_offset = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x
    i_linear = block_offset + CUDA.threadIdx().x
    CI = CartesianIndices((Nij, Nij, Nh, Nnames))
    return (CI, i_linear)
end
# Bounds check on a 5-component index: returns `true` when the fifth component
# of `I` lies in the closed range `1:DataLayouts.get_Nh(us)`.
# Only that component is inspected; the other four are not validated here.
@inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
    1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
# @inline multiple_field_solve_is_valid_index(I::CI5, us::UniversalSize) =
# 1 ≤ I[5] ≤ DataLayouts.get_Nh(us)
# Bounds check for a linear thread index: returns `true` when `i_linear` lies
# in the closed range `1:N`, where `N` is the total number of work items.
# Threads whose index falls outside this range should do no work.
@inline multiple_field_solve_is_valid_index(i_linear::Integer, N::Integer) =
    1 ≤ i_linear ≤ N

##### spectral kernel partition
@inline function spectral_partition(
Expand Down
6 changes: 3 additions & 3 deletions ext/cuda/matrix_fields_multiple_field_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ function multiple_field_solve_kernel!(
::Val{Nnames},
) where {Nnames}
@inbounds begin
(I, iname) = multiple_field_solve_universal_index(us)
if multiple_field_solve_is_valid_index(I, us)
(i, j, _, _, h) = I.I
(CI, i_linear) = multiple_field_solve_universal_index(us, Val(Nnames))
if multiple_field_solve_is_valid_index(i_linear, prod(CI.I))
(i, j, _, _, h, iname) = CI.I
generated_single_field_solve!(
device,
caches,
Expand Down

0 comments on commit 6adeb1d

Please sign in to comment.