From d751a7bac27dc4ab01c7b42fd5df0fc4618d37b6 Mon Sep 17 00:00:00 2001 From: sriharshakandala Date: Wed, 18 Dec 2024 11:45:37 -0800 Subject: [PATCH] Swap CUDA grid dimensions for some partitions --- ext/cuda/data_layouts_threadblock.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ext/cuda/data_layouts_threadblock.jl b/ext/cuda/data_layouts_threadblock.jl index 6ff4967855..91cd0191fb 100644 --- a/ext/cuda/data_layouts_threadblock.jl +++ b/ext/cuda/data_layouts_threadblock.jl @@ -55,11 +55,11 @@ function is_valid_index end Nv_thread = min(Int(fld(n_max_threads, Nij * Nij)), Nv) Nv_blocks = cld(Nv, Nv_thread) @assert prod((Nv_thread, Nij, Nij)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Nij, Nij))),$n_max_threads)" - return (; threads = (Nv_thread, Nij, Nij), blocks = (Nv_blocks, Nh)) + return (; threads = (Nv_thread, Nij, Nij), blocks = (Nh, Nv_blocks)) end @inline function universal_index(::Union{DataLayouts.VIJFH, DataLayouts.VIJHF}) (tv, i, j) = CUDA.threadIdx() - (bv, h) = CUDA.blockIdx() + (h, bv) = CUDA.blockIdx() v = tv + (bv - 1) * CUDA.blockDim().x return CartesianIndex((i, j, 1, v, h)) end @@ -152,11 +152,11 @@ end Nv_thread = min(Int(fld(n_max_threads, Ni)), Nv) Nv_blocks = cld(Nv, Nv_thread) @assert prod((Nv_thread, Ni)) ≤ n_max_threads "threads,n_max_threads=($(prod((Nv_thread, Ni))),$n_max_threads)" - return (; threads = (Nv_thread, Ni), blocks = (Nv_blocks, Nh)) + return (; threads = (Nv_thread, Ni), blocks = (Nh, Nv_blocks)) end @inline function universal_index(::Union{DataLayouts.VIFH, DataLayouts.VIHF}) (tv, i) = CUDA.threadIdx() - (bv, h) = CUDA.blockIdx() + (h, bv) = CUDA.blockIdx() v = tv + (bv - 1) * CUDA.blockDim().x return CartesianIndex((i, 1, 1, v, h)) end