diff --git a/include/kernels/block_transpose.cuh b/include/kernels/block_transpose.cuh index 153c2c9695..91aaec4d63 100644 --- a/include/kernels/block_transpose.cuh +++ b/include/kernels/block_transpose.cuh @@ -47,6 +47,15 @@ namespace quda constexpr BlockTransposeKernel(const Arg &arg) : arg(arg) { } static constexpr const char *filename() { return KERNEL_FILE; } + struct CacheDims { + static constexpr dim3 dims(dim3 block) + { + block.x += 1; + block.z = 1; + return block; + } + }; + /** @brief Transpose between the two different orders of batched colorspinor fields: - B: nVec -> spatial/N -> spin/color -> N, where N is for that in floatN @@ -60,7 +69,7 @@ namespace quda int parity = parity_color / Arg::nColor; using color_spinor_t = ColorSpinor; - SharedMemoryCache cache({target::block_dim().x + 1, target::block_dim().y, 1}); + SharedMemoryCache cache; int x_offset = target::block_dim().x * target::block_idx().x; int v_offset = target::block_dim().y * target::block_idx().y; diff --git a/include/kernels/coarse_op_kernel.cuh b/include/kernels/coarse_op_kernel.cuh index b63bf7f435..03f9d4b75e 100644 --- a/include/kernels/coarse_op_kernel.cuh +++ b/include/kernels/coarse_op_kernel.cuh @@ -10,6 +10,7 @@ #include #include #include +#include namespace quda { @@ -1387,14 +1388,20 @@ namespace quda { }; template <> struct storeCoarseSharedAtomic_impl { + template + using CacheT = complex[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4] + [Arg::coarseSpin][Arg::coarseSpin]; + template using Cache = SharedMemoryCache, DimsStatic<2, 1, 1>>; + template inline __device__ void operator()(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, int parity, const Pack &pack, const Arg &arg) { using real = typename Arg::Float; using TileType = typename Arg::vuvTileType; const int dim_index = arg.dim_index % arg.Y_atomic.geometry; - __shared__ complex X[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin]; - __shared__ complex Y[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin]; + Cache cache; + auto &X = cache.data()[0]; + auto &Y = cache.data()[1]; int x_ = coarse_x_cb % arg.aggregates_per_block; int tx = virtualThreadIdx(arg); @@ -1416,7 +1423,7 @@ namespace quda { } } - __syncthreads(); + cache.sync(); #pragma unroll for (int i = 0; i < TileType::M; i++) { @@ -1445,7 +1452,7 @@ namespace quda { } } - __syncthreads(); + cache.sync(); if (tx < Arg::coarseSpin*Arg::coarseSpin && (parity == 0 || arg.parity_flip == 1) ) { diff --git a/include/kernels/color_spinor_pack.cuh b/include/kernels/color_spinor_pack.cuh index ff1a7969c4..38100fff50 100644 --- a/include/kernels/color_spinor_pack.cuh +++ b/include/kernels/color_spinor_pack.cuh @@ -172,17 +172,24 @@ namespace quda { }; template <> struct site_max { + template struct CacheDims { + static constexpr int Ms = spins_per_thread(Arg::nSpin); + static constexpr int Mc = colors_per_thread(Arg::nColor); + static constexpr int color_spin_threads = (Arg::nSpin / Ms) * (Arg::nColor / Mc); + static constexpr dim3 dims(dim3 block) + { + // pad the shared block size to avoid bank conflicts for native ordering + if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size(); + block.y = color_spin_threads; // state the y block since we know it at compile time + return block; + } + }; + template __device__ inline auto operator()(typename Arg::real thread_max, Arg &) { using real = 
typename Arg::real; - constexpr int Ms = spins_per_thread(Arg::nSpin); - constexpr int Mc = colors_per_thread(Arg::nColor); - constexpr int color_spin_threads = (Arg::nSpin/Ms) * (Arg::nColor/Mc); - auto block = target::block_dim(); - // pad the shared block size to avoid bank conflicts for native ordering - if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size(); - block.y = color_spin_threads; // state the y block since we know it at compile time - SharedMemoryCache cache(block); + constexpr int color_spin_threads = CacheDims::color_spin_threads; + SharedMemoryCache> cache; cache.save(thread_max); cache.sync(); real this_site_max = static_cast(0); diff --git a/include/kernels/dslash_clover_helper.cuh b/include/kernels/dslash_clover_helper.cuh index 00b61a9b3e..4cf6f1a311 100644 --- a/include/kernels/dslash_clover_helper.cuh +++ b/include/kernels/dslash_clover_helper.cuh @@ -203,7 +203,7 @@ namespace quda { Mat A = arg.clover(x_cb, clover_parity, chirality); - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; half_fermion in_chi[n_flavor]; // flavor array of chirally projected fermion #pragma unroll diff --git a/include/kernels/dslash_coarse.cuh b/include/kernels/dslash_coarse.cuh index 6d74bcda41..744a74f874 100644 --- a/include/kernels/dslash_coarse.cuh +++ b/include/kernels/dslash_coarse.cuh @@ -301,7 +301,7 @@ namespace quda { template <> struct dim_collapse { template __device__ __host__ inline void operator()(T &out, int dir, int dim, const Arg &arg) { - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; // only need to write to shared memory if not master thread if (dim > 0 || dir) cache.save(out); diff --git a/include/kernels/dslash_domain_wall_m5.cuh b/include/kernels/dslash_domain_wall_m5.cuh index bab21d4c11..dfce18bede 100644 --- a/include/kernels/dslash_domain_wall_m5.cuh +++ b/include/kernels/dslash_domain_wall_m5.cuh @@ -220,7 +220,7 @@ namespace quda if (mobius_m5::use_half_vector()) { // if using shared-memory caching then load spinor field for my site into cache typedef ColorSpinor HalfVector; - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; { // forwards direction constexpr int proj_dir = dagger ? +1 : -1; @@ -271,7 +271,7 @@ namespace quda } else { // use_half_vector // if using shared-memory caching then load spinor field for my site into cache - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; if (shared) { if (sync) { cache.sync(); } cache.save(in); @@ -377,7 +377,7 @@ namespace quda const auto inv = arg.inv; // if using shared-memory caching then load spinor field for my site into cache - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; if (shared) { // cache.save(arg.in(s_ * arg.volume_4d_cb + x_cb, parity)); if (sync) { cache.sync(); } @@ -436,7 +436,7 @@ namespace quda Vector out; if (mobius_m5::use_half_vector()) { - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; { // first do R constexpr int proj_dir = dagger ? 
-1 : +1; @@ -495,7 +495,7 @@ namespace quda out += l.reconstruct(4, proj_dir); } } else { // use_half_vector - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; if (shared) { if (sync) { cache.sync(); } cache.save(in); diff --git a/include/kernels/dslash_mobius_eofa.cuh b/include/kernels/dslash_mobius_eofa.cuh index f5e0a5c8ac..3d62ec3923 100644 --- a/include/kernels/dslash_mobius_eofa.cuh +++ b/include/kernels/dslash_mobius_eofa.cuh @@ -107,7 +107,7 @@ namespace quda using real = typename Arg::real; typedef ColorSpinor Vector; - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; Vector out; cache.save(arg.in(s * arg.volume_4d_cb + x_cb, parity)); @@ -185,7 +185,7 @@ namespace quda typedef ColorSpinor Vector; const auto sherman_morrison = arg.sherman_morrison; - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; cache.save(arg.in(s * arg.volume_4d_cb + x_cb, parity)); cache.sync(); diff --git a/include/kernels/dslash_ndeg_twisted_clover.cuh b/include/kernels/dslash_ndeg_twisted_clover.cuh index 108f8c5e84..cb8bc61ad7 100644 --- a/include/kernels/dslash_ndeg_twisted_clover.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover.cuh @@ -72,7 +72,7 @@ namespace quda // apply the chiral and flavor twists // use consistent load order across s to ensure better cache locality Vector x = arg.x(my_flavor_idx, my_spinor_parity); - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; cache.save(x); x.toRel(); // switch to chiral basis diff --git a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh index bdbff30817..ebd8f71da6 100644 --- a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh @@ -91,7 +91,7 @@ namespace quda int chirality = flavor; // relabel flavor as chirality - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; enum swizzle_direction { FORWARDS = 0, diff --git a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh index 98e72eb61a..8bab3d5623 100644 --- a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh @@ -95,7 +95,7 @@ namespace quda } if (!dagger || Arg::asymmetric) { // apply A^{-1} to D*in - SharedMemoryCache cache(target::block_dim()); + SharedMemoryCache cache; if (isComplete(arg, coord) && active) { // to apply the preconditioner we need to put "out" in shared memory so the other flavor can access it cache.save(out); diff --git a/include/kernels/gauge_stout.cuh b/include/kernels/gauge_stout.cuh index 45cc176d88..2a512e77e3 100644 --- a/include/kernels/gauge_stout.cuh +++ b/include/kernels/gauge_stout.cuh @@ -6,6 +6,7 @@ #include #include #include +#include namespace quda { @@ -134,8 +135,8 @@ namespace quda } Link U, Q; - SharedMemoryCache Stap(target::block_dim()); - SharedMemoryCache Rect(target::block_dim(), sizeof(Link)); + ThreadLocalCache Stap; + ThreadLocalCache Rect; // offset by Stap type to ensure non-overlapping allocations // This function gets stap = S_{mu,nu} i.e., the staple of length 3, // and the 1x2 and 2x1 rectangles of length 5. 
From the following paper: diff --git a/include/kernels/gauge_utils.cuh b/include/kernels/gauge_utils.cuh index 48c7e6c1cc..ded8c9377a 100644 --- a/include/kernels/gauge_utils.cuh +++ b/include/kernels/gauge_utils.cuh @@ -19,6 +19,7 @@ namespace quda // matrix+matrix = 18 floating-point ops // => Total number of floating point ops per function call // dims * (2*18 + 4*198) = dims*828 + using computeStapleOps = thread_array; template __host__ __device__ inline void computeStaple(const Arg &arg, const int *x, const Int *X, const int parity, const int nu, Staple &staple, const int dir_ignore) { @@ -94,6 +95,7 @@ namespace quda // matrix+matrix = 18 floating-point ops // => Total number of floating point ops per function call // dims * (8*18 + 28*198) = dims*5688 + using computeStapleRectangleOps = thread_array; template __host__ __device__ inline void computeStapleRectangle(const Arg &arg, const int *x, const Int *X, const int parity, const int nu, Staple &staple, Rectangle &rectangle, const int dir_ignore) diff --git a/include/kernels/gauge_wilson_flow.cuh b/include/kernels/gauge_wilson_flow.cuh index 327f7c7eb0..457d93beab 100644 --- a/include/kernels/gauge_wilson_flow.cuh +++ b/include/kernels/gauge_wilson_flow.cuh @@ -4,6 +4,7 @@ #include #include #include +#include namespace quda { @@ -71,8 +72,8 @@ namespace quda // This function gets stap = S_{mu,nu} i.e., the staple of length 3, // and the 1x2 and 2x1 rectangles of length 5. From the following paper: // https://arxiv.org/abs/0801.1165 - SharedMemoryCache Stap(target::block_dim()); - SharedMemoryCache Rect(target::block_dim(), sizeof(Link)); // offset to ensure non-overlapping allocations + ThreadLocalCache Stap; + ThreadLocalCache Rect; // offset by Stap type to ensure non-overlapping allocations computeStapleRectangle(arg, x, arg.E, parity, dir, Stap, Rect, Arg::wflow_dim); Z = arg.coeff1x1 * static_cast(Stap) + arg.coeff2x1 * static_cast(Rect); break; diff --git a/include/kernels/hisq_paths_force.cuh b/include/kernels/hisq_paths_force.cuh index cac909bf8a..a16f7783eb 100644 --- a/include/kernels/hisq_paths_force.cuh +++ b/include/kernels/hisq_paths_force.cuh @@ -4,7 +4,7 @@ #include #include #include -#include +#include namespace quda { @@ -272,7 +272,7 @@ namespace quda { * A _______ B * mu_next | | * H| |G - * + * * Variables have been named to reflection dimensionality for * mu_positive == true, sig_positive == true, mu_next_positive == true **************************************************************************/ @@ -372,7 +372,7 @@ namespace quda { @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x @param[in/out] force_mu Accumulated force in the mu direction - @param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read) + @param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read) @details This subset of the code computes the Lepage contribution to the fermion force. 
Data traffic: READ: cb_link, id_link, pMu_at_c @@ -386,7 +386,10 @@ namespace quda { Flops: 2 multiplies, 1 add, 1 rescale */ - __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, SharedMemoryCache &Uab_cache) { + template + __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, + LinkCache &Uab_cache) + { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -414,7 +417,7 @@ namespace quda { Link Ow = mu_positive ? (conj(Ucb) * Oc) : (Ucb * Oc); { - Link Uab = Uab_cache.load(); + Link Uab = Uab_cache; Link Oy = sig_positive ? Uab * Ow : conj(Uab) * Ow; Link Ox = mu_positive ? (Oy * Uid) : (Uid * conj(Oy)); auto mycoeff_lepage = -coeff_sign(parity_a)*coeff_sign(parity_a)*arg.coeff_lepage; @@ -440,7 +443,7 @@ namespace quda { @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x - @param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read) + @param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read) Data traffic: READ: gb_link, oProd_at_h WRITE: pMu_next_at_b, p3_at_a @@ -454,7 +457,8 @@ namespace quda { Flops: 2 multiplies, 1 add, 1 rescale */ - __device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, SharedMemoryCache &Uab_cache) + template + __device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, LinkCache &Uab_cache) { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -487,7 +491,7 @@ namespace quda { arg.pMu_next(0, point_b, parity_b) = Oz; { // scoped Uab load - Link Uab = Uab_cache.load(); + Link Uab = Uab_cache; if constexpr (!sig_positive) Uab = conj(Uab); arg.p3(0, point_a, parity_a) = Uab * Oz; } @@ -535,8 +539,7 @@ namespace quda { /* * The "extra" low point corresponds to the Lepage contribution to the * force_mu term. - * - * + * * sig * F E * | | @@ -557,7 +560,7 @@ namespace quda { int point_a = e_cb; int parity_a = parity; - SharedMemoryCache Uab_cache(target::block_dim()); + ThreadLocalCache Uab_cache; // Scoped load of Uab { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); @@ -636,7 +639,7 @@ namespace quda { Link force; Link shortP; Link p5; - + const Link pMu; // double-buffer: read pNuMu, qNuMu for side 5, middle 7 @@ -688,7 +691,7 @@ namespace quda { @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x - @param[in/out] Matrix_cache Shared memory cache that maintains the accumulated P5 contribution (write) + @param[in/out] Matrix_cache Thread local cache that maintains the accumulated P5 contribution (write) the gauge link going from a to b (read), as well as force_sig when sig is positive (read/write) @details This subset of the code computes the full seven link contribution to the HISQ force. 
Data traffic: @@ -705,8 +708,9 @@ namespace quda { Flops: 4 multiplies, 2 adds, 2 rescales */ - __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, - SharedMemoryCache &Matrix_cache) { + template + __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) + { auto mycoeff_seven = parity_sign(parity_a) * coeff_sign(parity_a) * arg.coeff_seven; int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); @@ -735,19 +739,19 @@ namespace quda { UbeOeOf = Ube * OeOf; // Cache Ube to below - Matrix_cache.save_z(Ube, 1); + Matrix_cache.save(Ube, 1); } // Take care of force_sig --- contribution from the negative rho direction Link Uaf = arg.link(arg.rho, point_a, parity_a); if constexpr (sig_positive) { - Link force_sig = Matrix_cache.load_z(2); + Link force_sig = Matrix_cache[2]; force_sig = mm_add(mycoeff_seven * UbeOeOf, conj(Uaf), force_sig); - Matrix_cache.save_z(force_sig, 2); + Matrix_cache.save(force_sig, 2); } // Compute the force_rho --- contribution from the negative rho direction - Link Uab = Matrix_cache.load_z(0); + Link Uab = Matrix_cache[0]; if constexpr (!sig_positive) Uab = conj(Uab); Link force_rho = arg.force(arg.rho, point_a, parity_a); force_rho = mm_add(mycoeff_seven * conj(UbeOeOf), conj(Uab), force_rho); @@ -756,7 +760,7 @@ namespace quda { Link Ufe = arg.link(arg.sig, fe_link_nbr_idx, fe_link_nbr_parity); // Load Ube from the cache - Link Ube = Matrix_cache.load_z(1); + Link Ube = Matrix_cache[1]; // Form the product UfeUebOb Link UfeUeb = (sig_positive ? Ufe : conj(Ufe)) * conj(Ube); @@ -788,7 +792,7 @@ namespace quda { Link Oz = Ucb * Ob; Link Oy = (sig_positive ? Udc : conj(Udc)) * Oz; p5_sig = mm_add(arg.accumu_coeff_seven * conj(Uda), Oy, p5_sig); - Matrix_cache.save_z(p5_sig, 1); + Matrix_cache.save(p5_sig, 1); // When sig is positive, compute the force_sig contribution from the // positive rho direction @@ -796,11 +800,10 @@ namespace quda { Link Od = arg.qNuMu(0, point_d, parity_d); Link Oc = arg.pNuMu(0, point_c, parity_c); Link Oz = conj(Ucb) * Oc; - Link force_sig = Matrix_cache.load_z(2); + Link force_sig = Matrix_cache[2]; force_sig = mm_add(mycoeff_seven * Oz, Od * Uda, force_sig); - Matrix_cache.save_z(force_sig, 2); + Matrix_cache.save(force_sig, 2); } - } /** @@ -808,7 +811,7 @@ namespace quda { @param[in] x Local coordinate @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] parity_a Parity of the coordinate x - @param[in/out] Matrix_cache Shared memory cache that maintains the full P5 contribution + @param[in/out] Matrix_cache Thread local cache that maintains the full P5 contribution summed from the previous middle five and all seven (read), as well as force_sig when sig is positive (read/write) @details This subset of the code computes the side link five link contribution to the HISQ force. @@ -818,7 +821,9 @@ namespace quda { Flops: 2 multiplies, 2 adds, 2 rescales */ - __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, SharedMemoryCache &Matrix_cache) { + template + __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) + { int y[4] = {x[0], x[1], x[2], x[3]}; int point_h = updateCoordExtendedIndexShiftMILC(y, arg.nu, arg); int parity_h = 1 - parity_a; @@ -832,7 +837,7 @@ namespace quda { int qh_link_nbr_idx = mu_positive ? point_q : point_h; int qh_link_nbr_parity = mu_positive ? 
parity_q : parity_h; - Link P5 = Matrix_cache.load_z(1); + Link P5 = Matrix_cache[1]; Link Uah = arg.link(arg.nu, ha_link_nbr_idx, ha_link_nbr_parity); Link Ow = nu_positive ? Uah * P5 : conj(Uah) * P5; @@ -857,7 +862,7 @@ namespace quda { @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x - @param[in/out] Matrix_cache Helper shared memory cache that maintains the gauge link going + @param[in/out] Matrix_cache Thread local cache that maintains the gauge link going from a to b (read) and, when sig is positive, force_sig (read/write) @details This subset of the code computes the middle link five link contribution to the HISQ force. Data traffic: @@ -870,8 +875,9 @@ namespace quda { Flops: 1 multiply, 1 add, 1 rescale */ - __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, - SharedMemoryCache &Matrix_cache) { + template + __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) + { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -902,7 +908,7 @@ namespace quda { arg.pNuMu_next(0, point_b, parity_b) = Ow; { // scoped Uab load - Link Uab = Matrix_cache.load_z(0); + Link Uab = Matrix_cache[0]; if constexpr (!sig_positive) Uab = conj(Uab); arg.p5(0, point_a, parity_a) = Uab * Ow; } @@ -917,9 +923,9 @@ namespace quda { // compute the force in the sigma direction if sig is positive if constexpr (sig_positive) { - Link force_sig = Matrix_cache.load_z(2); + Link force_sig = Matrix_cache[2]; force_sig = mm_add(arg.coeff_five * Ow, Ox, force_sig); - Matrix_cache.save_z(force_sig, 2); + Matrix_cache.save(force_sig, 2); } } @@ -955,14 +961,14 @@ namespace quda { int point_a = e_cb; int parity_a = parity; - + // calculate p5_sig - auto block_dim = target::block_dim(); - block_dim.z = (sig_positive ? 3 : 2); - SharedMemoryCache Matrix_cache(block_dim); + constexpr int cacheLen = sig_positive ? 3 : 2; + ThreadLocalCache Matrix_cache; + if constexpr (sig_positive) { Link force_sig = arg.force(arg.sig, point_a, parity_a); - Matrix_cache.save_z(force_sig, 2); + Matrix_cache.save(force_sig, 2); } // Scoped load of Uab @@ -972,7 +978,7 @@ namespace quda { int ab_link_nbr_idx = (sig_positive) ? point_a : point_b; int ab_link_nbr_parity = (sig_positive) ? parity_a : parity_b; Link Uab = arg.link(arg.sig, ab_link_nbr_idx, ab_link_nbr_parity); - Matrix_cache.save_z(Uab, 0); + Matrix_cache.save(Uab, 0); } // accumulate into P5, force_sig @@ -987,7 +993,7 @@ namespace quda { // update the force in the sigma direction if constexpr (sig_positive) { - Link force_sig = Matrix_cache.load_z(2); + Link force_sig = Matrix_cache[2]; arg.force(arg.sig, point_a, parity_a) = force_sig; } diff --git a/include/targets/cuda/load_store.h b/include/targets/cuda/load_store.h index 4a5420b166..0550ad62dd 100644 --- a/include/targets/cuda/load_store.h +++ b/include/targets/cuda/load_store.h @@ -6,6 +6,12 @@ namespace quda { + /** + @brief Element type used for coalesced storage. 
+ */ + template + using atom_t = std::conditional_t>; + // pre-declaration of vector_load that we wish to specialize template struct vector_load_impl; diff --git a/include/targets/cuda/shared_memory_cache_helper.h b/include/targets/cuda/shared_memory_cache_helper.h index 7c7c0a1b28..73be0cd01b 100644 --- a/include/targets/cuda/shared_memory_cache_helper.h +++ b/include/targets/cuda/shared_memory_cache_helper.h @@ -1,295 +1 @@ -#pragma once - -#include -#include - -/** - @file shared_memory_cache_helper.h - - Helper functionality for aiding the use of the shared memory for - sharing data between threads in a thread block. - */ - -namespace quda -{ - - /** - @brief Class which wraps around a shared memory cache for type T, - where each thread in the thread block stores a unique value in - the cache which any other thread can access. - - This accessor supports both explicit run-time block size and - compile-time sizing. - - * For run-time block size, the constructor should be initialied - with the desired block size. - - * For compile-time block size, no arguments should be passed to - the constructor, and then the second and third template - parameters correspond to the y and z dimensions of the block, - respectively. The x dimension of the block will be set - according the maximum number of threads possible, given these - dimensions. - */ - template class SharedMemoryCache - { - public: - using value_type = T; - static constexpr int block_size_y = block_size_y_; - static constexpr int block_size_z = block_size_z_; - static constexpr bool dynamic = dynamic_; - - private: - /** maximum number of threads in x given the y and z block sizes */ - static constexpr int block_size_x = device::max_block_size(); - - using atom_t = std::conditional_t>; - static_assert(sizeof(T) % 4 == 0, "Shared memory cache does not support sub-word size types"); - - // The number of elements of type atom_t that we break T into for optimal shared-memory access - static constexpr int n_element = sizeof(T) / sizeof(atom_t); - - const dim3 block; - const int stride; - const unsigned int offset = 0; // dynamic offset in bytes - - /** - @brief This is a dummy instantiation for the host compiler - */ - template struct cache_dynamic { - atom_t *operator()(unsigned) - { - static atom_t *cache_; - return reinterpret_cast(cache_); - } - }; - - /** - @brief This is the handle to the shared memory, dynamic specialization - @return Shared memory pointer - */ - template struct cache_dynamic { - __device__ inline atom_t *operator()(unsigned int offset) - { - extern __shared__ int cache_[]; - return reinterpret_cast(reinterpret_cast(cache_) + offset); - } - }; - - /** - @brief This is a dummy instantiation for the host compiler - */ - template struct cache_static { - atom_t *operator()() - { - static atom_t *cache_; - return reinterpret_cast(cache_); - } - }; - - /** - @brief This is the handle to the shared memory, static specialization - @return Shared memory pointer - */ - template struct cache_static { - __device__ inline atom_t *operator()() - { - static __shared__ atom_t cache_[n_element * block_size_x * block_size_y * block_size_z]; - return reinterpret_cast(cache_); - } - }; - - template __device__ __host__ inline std::enable_if_t cache() const - { - return target::dispatch(offset); - } - - template __device__ __host__ inline std::enable_if_t cache() const - { - return target::dispatch(); - } - - __device__ __host__ inline void save_detail(const T &a, int x, int y, int z) const - { - atom_t tmp[n_element]; - memcpy(tmp, (void 
*)&a, sizeof(T)); - int j = (z * block.y + y) * block.x + x; -#pragma unroll - for (int i = 0; i < n_element; i++) cache()[i * stride + j] = tmp[i]; - } - - __device__ __host__ inline T load_detail(int x, int y, int z) const - { - atom_t tmp[n_element]; - int j = (z * block.y + y) * block.x + x; -#pragma unroll - for (int i = 0; i < n_element; i++) tmp[i] = cache()[i * stride + j]; - T a; - memcpy((void *)&a, tmp, sizeof(T)); - return a; - } - - /** - @brief Dummy instantiation for the host compiler - */ - template struct sync_impl { - void operator()() { } - }; - - /** - @brief Synchronize the cache when on the device - */ - template struct sync_impl { - __device__ inline void operator()() { __syncthreads(); } - }; - - public: - /** - @brief constructor for SharedMemory cache. If no arguments are - pass, then the dimensions are set according to the templates - block_size_y and block_size_z, together with the derived - block_size_x. Otherwise use the block sizes passed into the - constructor. - - @param[in] block Block dimensions for the 3-d shared memory object - @param[in] thread_offset "Perceived" offset from dynamic shared - memory base pointer (used when we have multiple caches in - scope). Need to include block size to actual offset. - */ - constexpr SharedMemoryCache(dim3 block = dim3(block_size_x, block_size_y, block_size_z), - unsigned int thread_offset = 0) : - block(block), stride(block.x * block.y * block.z), offset(stride * thread_offset) - { - } - - /** - @brief Grab the raw base address to shared memory. - */ - __device__ __host__ inline auto data() const { return reinterpret_cast(cache()); } - - /** - @brief Save the value into the 3-d shared memory cache. - @param[in] a The value to store in the shared memory cache - @param[in] x The x index to use - @param[in] y The y index to use - @param[in] z The z index to use - */ - __device__ __host__ inline void save(const T &a, int x = -1, int y = -1, int z = -1) const - { - auto tid = target::thread_idx(); - x = (x == -1) ? tid.x : x; - y = (y == -1) ? tid.y : y; - z = (z == -1) ? tid.z : z; - save_detail(a, x, y, z); - } - - /** - @brief Save the value into the 3-d shared memory cache. - @param[in] a The value to store in the shared memory cache - @param[in] x The x index to use - */ - __device__ __host__ inline void save_x(const T &a, int x = -1) const - { - auto tid = target::thread_idx(); - x = (x == -1) ? tid.x : x; - save_detail(a, x, tid.y, tid.z); - } - - /** - @brief Save the value into the 3-d shared memory cache. - @param[in] a The value to store in the shared memory cache - @param[in] y The y index to use - */ - __device__ __host__ inline void save_y(const T &a, int y = -1) const - { - auto tid = target::thread_idx(); - y = (y == -1) ? tid.y : y; - save_detail(a, tid.x, y, tid.z); - } - - /** - @brief Save the value into the 3-d shared memory cache. - @param[in] a The value to store in the shared memory cache - @param[in] z The z index to use - */ - __device__ __host__ inline void save_z(const T &a, int z = -1) const - { - auto tid = target::thread_idx(); - z = (z == -1) ? tid.z : z; - save_detail(a, tid.x, tid.y, z); - } - - /** - @brief Load a value from the shared memory cache - @param[in] x The x index to use - @param[in] y The y index to use - @param[in] z The z index to use - @return The value at coordinates (x,y,z) - */ - __device__ __host__ inline T load(int x = -1, int y = -1, int z = -1) const - { - auto tid = target::thread_idx(); - x = (x == -1) ? tid.x : x; - y = (y == -1) ? tid.y : y; - z = (z == -1) ? 
tid.z : z; - return load_detail(x, y, z); - } - - /** - @brief Load a vector from the shared memory cache - @param[in] x The x index to use - @return The value at coordinates (x,y,z) - */ - __device__ __host__ inline T load_x(int x = -1) const - { - auto tid = target::thread_idx(); - x = (x == -1) ? tid.x : x; - return load_detail(x, tid.y, tid.z); - } - - /** - @brief Load a vector from the shared memory cache - @param[in] y The y index to use - @return The value at coordinates (x,y,z) - */ - __device__ __host__ inline T load_y(int y = -1) const - { - auto tid = target::thread_idx(); - y = (y == -1) ? tid.y : y; - return load_detail(tid.x, y, tid.z); - } - - /** - @brief Load a vector from the shared memory cache - @param[in] z The z index to use - @return The value at coordinates (x,y,z) - */ - __device__ __host__ inline T load_z(int z = -1) const - { - auto tid = target::thread_idx(); - z = (z == -1) ? tid.z : z; - return load_detail(tid.x, tid.y, z); - } - - /** - @brief Synchronize the cache - */ - __device__ __host__ void sync() const { target::dispatch(); } - - /** - @brief Cast operator to allow cache objects to be used where T - is expected - */ - __device__ __host__ operator T() const { return load(); } - - /** - @brief Assignment operator to allow cache objects to be used on - the lhs where T is otherwise expected. - */ - __device__ __host__ void operator=(const T &src) const { save(src); } - }; - -} // namespace quda - -// include overloads #include "../generic/shared_memory_cache_helper.h" diff --git a/include/targets/cuda/shared_memory_helper.h b/include/targets/cuda/shared_memory_helper.h new file mode 100644 index 0000000000..bc9bd7c66b --- /dev/null +++ b/include/targets/cuda/shared_memory_helper.h @@ -0,0 +1,89 @@ +#pragma once + +#include + +/** + @file shared_memory_helper.h + + Target specific helper for allocating and accessing shared memory. + */ + +namespace quda +{ + + /** + @brief Class which is used to allocate and access shared memory. + The shared memory is treated as an array of type T, with the + number of elements given by a call to the static member + S::size(target::block_dim()). The byte offset from the beginning + of the total shared memory block is given by the static member + O::shared_mem_size(target::block_dim()), or 0 if O is void. + */ + template class SharedMemory + { + public: + using value_type = T; + + private: + T *data; + + /** + @brief This is a dummy instantiation for the host compiler + */ + template struct cache_dynamic { + T *operator()(unsigned int) + { + static T *cache_; + return cache_; + } + }; + + /** + @brief This is the handle to the dynamic shared memory + @return Shared memory pointer + */ + template struct cache_dynamic { + __device__ inline T *operator()(unsigned int offset) + { + extern __shared__ int cache_[]; + return reinterpret_cast(reinterpret_cast(cache_) + offset); + } + }; + + __device__ __host__ inline T *cache(unsigned int offset) const { return target::dispatch(offset); } + + public: + /** + @brief Byte offset for this shared memory object. + */ + static constexpr unsigned int get_offset(dim3 block) + { + unsigned int o = 0; + if constexpr (!std::is_same_v) { o = O::shared_mem_size(block); } + return o; + } + + /** + @brief Shared memory size in bytes. + */ + static constexpr unsigned int shared_mem_size(dim3 block) { return get_offset(block) + S::size(block) * sizeof(T); } + + /** + @brief Constructor for SharedMemory object. 
+ */ + constexpr SharedMemory() : data(cache(get_offset(target::block_dim()))) { } + + /** + @brief Return this SharedMemory object. + */ + constexpr auto sharedMem() const { return *this; } + + /** + @brief Subscripting operator returning a reference to element. + @param[in] i The index to use. + @return Reference to value stored at that index. + */ + __device__ __host__ T &operator[](const int i) const { return data[i]; } + }; + +} // namespace quda diff --git a/include/targets/cuda/thread_array.h b/include/targets/cuda/thread_array.h index 4fe1bb33f6..1c4d7f3244 100644 --- a/include/targets/cuda/thread_array.h +++ b/include/targets/cuda/thread_array.h @@ -1,49 +1,18 @@ #pragma once -#include "shared_memory_cache_helper.h" - -namespace quda -{ - #ifndef _NVHPC_CUDA - /** - @brief Class that provides indexable per-thread storage. On CUDA - this maps to using assigning each thread a unique window of - shared memory using the SharedMemoryCache object. - */ - template struct thread_array { - SharedMemoryCache, 1, 1, false> device_array; - int offset; - array host_array; - array &array_; - - __device__ __host__ constexpr thread_array() : - offset((target::thread_idx().z * target::block_dim().y + target::thread_idx().y) * target::block_dim().x - + target::thread_idx().x), - array_(target::is_device() ? *(device_array.data() + offset) : host_array) - { - array_ = array(); // call default constructor - } - - template - __device__ __host__ constexpr thread_array(T first, const Ts... other) : - offset((target::thread_idx().z * target::block_dim().y + target::thread_idx().y) * target::block_dim().x - + target::thread_idx().x), - array_(target::is_device() ? *(device_array.data() + offset) : host_array) - { - array_ = array {first, other...}; - } - - __device__ __host__ T &operator[](int i) { return array_[i]; } - __device__ __host__ const T &operator[](int i) const { return array_[i]; } - }; +#include "../generic/thread_array.h" #else +#include + +namespace quda +{ template struct thread_array : array { + static constexpr unsigned int shared_mem_size(dim3 block) { return 0; } }; +} // namespace quda #endif - -} // namespace quda diff --git a/include/targets/cuda/thread_local_cache.h b/include/targets/cuda/thread_local_cache.h new file mode 100644 index 0000000000..dd4cd863fc --- /dev/null +++ b/include/targets/cuda/thread_local_cache.h @@ -0,0 +1 @@ +#include "../generic/thread_local_cache.h" diff --git a/include/targets/cuda/tunable_kernel.h b/include/targets/cuda/tunable_kernel.h index d7936eb497..7306ab355c 100644 --- a/include/targets/cuda/tunable_kernel.h +++ b/include/targets/cuda/tunable_kernel.h @@ -45,6 +45,7 @@ namespace quda std::enable_if_t(), qudaError_t> launch_device(const kernel_t &kernel, const TuneParam &tp, const qudaStream_t &stream, const Arg &arg) { + checkSharedBytes(tp); #ifdef JITIFY launch_error = launch_jitify(kernel.name, tp, stream, arg); #else @@ -62,6 +63,7 @@ namespace quda std::enable_if_t(), qudaError_t> launch_device(const kernel_t &kernel, const TuneParam &tp, const qudaStream_t &stream, const Arg &arg) { + checkSharedBytes(tp); #ifdef JITIFY // note we do the copy to constant memory after the kernel has been compiled in launch_jitify launch_error = launch_jitify(kernel.name, tp, stream, arg); @@ -83,6 +85,7 @@ namespace quda template
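Illustrative usage notes (not part of the patch). The CacheDims structs added to block_transpose.cuh and color_spinor_pack.cuh move the cache geometry from a run-time constructor argument into a compile-time policy, so the SharedMemoryCache type can derive its footprint from the launch block by itself. A minimal host-side sketch of that pattern, assuming a CUDA-style dim3 stand-in and a hypothetical PadXCollapseZ policy (neither name is in the patch):

  #include <cstdio>

  struct dim3_t { unsigned x, y, z; }; // host stand-in for CUDA's dim3

  // Compile-time policy mirroring BlockTransposeKernel::CacheDims: pad x by one
  // element to avoid shared-memory bank conflicts and collapse z, which the
  // transpose cache does not use.
  struct PadXCollapseZ {
    static constexpr dim3_t dims(dim3_t block)
    {
      block.x += 1;
      block.z = 1;
      return block;
    }
  };

  // A cache consumer sizes its allocation from the policy rather than from a
  // constructor argument, so the footprint is known when the kernel is tuned.
  template <typename T, typename Dims> constexpr unsigned cache_bytes(dim3_t launch)
  {
    dim3_t b = Dims::dims(launch);
    return b.x * b.y * b.z * sizeof(T);
  }

  int main() { printf("%u bytes\n", cache_bytes<double, PadXCollapseZ>({32, 8, 2})); } // (32+1)*8*1*8 = 2112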
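The storeCoarseSharedAtomic_impl change in coarse_op_kernel.cuh replaces two fixed-size __shared__ arrays with a single SharedMemoryCache whose element type is the whole per-block array and whose static dims request two such elements, so the allocation is counted by the shared-memory size machinery instead of being invisible to it. A sketch of the resulting access pattern; the template arguments shown are assumptions rather than verbatim from the patch:

  // Inside storeCoarseSharedAtomic_impl::operator() (sketch):
  Cache<Arg> cache;            // SharedMemoryCache over CacheT<Arg> with DimsStatic<2, 1, 1>
  auto &X = cache.data()[0];   // replaces the first __shared__ complex array
  auto &Y = cache.data()[1];   // replaces the second
  // ... threads accumulate their vuv contributions into X and Y ...
  cache.sync();                // replaces the bare __syncthreads()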
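gauge_stout.cuh, gauge_wilson_flow.cuh and hisq_paths_force.cuh now obtain per-thread Link storage from ThreadLocalCache rather than carving per-thread windows out of a SharedMemoryCache by hand. The sketch below uses only operations that appear in the diff (save with a slot index, operator[], conversion to the element type); the template parameters are assumptions, in particular passing decltype of the previous cache as the offset argument is inferred from the "offset by Stap type" comment:

  // Per-thread staple/rectangle scratch in a gauge-smearing kernel (sketch):
  ThreadLocalCache<Link> Stap;                     // one Link per thread
  ThreadLocalCache<Link, 0, decltype(Stap)> Rect;  // offset by Stap so the two allocations do not overlap
  computeStapleRectangle(arg, x, arg.E, parity, dir, Stap, Rect, Arg::wflow_dim);
  Link Z = arg.coeff1x1 * static_cast<Link>(Stap) + arg.coeff2x1 * static_cast<Link>(Rect);

  // Indexed per-thread slots in the HISQ force kernel (sketch):
  ThreadLocalCache<Link, 3> Matrix_cache;          // slots: 0 = Uab, 1 = P5, 2 = force_sig
  Matrix_cache.save(Uab, 0);                       // store into a slot
  Link Uab_back = Matrix_cache[0];                 // read a slot back

The computeStapleOps and computeStapleRectangleOps aliases added in gauge_utils.cuh expose the thread_array scratch those routines use, presumably so callers can fold it into this same offset chain when sizing their own caches.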
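load_store.h gains an atom_t alias: the element type into which a larger type is decomposed for coalesced storage. The standalone reconstruction below is an assumption modeled on how the removed SharedMemoryCache chose its atom type, namely the widest int vector that evenly divides sizeof(T); int2_t and int4_t stand in for CUDA's int2 and int4:

  #include <type_traits>

  struct int2_t { int x, y; };
  struct int4_t { int x, y, z, w; };

  // Assumed rule: pick the widest int vector that divides sizeof(T), so that
  // shared-memory loads and stores stay coalesced.
  template <typename T>
  using atom_t = std::conditional_t<sizeof(T) % 16 == 0, int4_t,
                 std::conditional_t<sizeof(T) % 8 == 0, int2_t, int>>;

  static_assert(std::is_same_v<atom_t<float>, int>, "4-byte types map to int");
  static_assert(std::is_same_v<atom_t<double>, int2_t>, "8-byte types map to int2");
  static_assert(std::is_same_v<atom_t<double[2]>, int4_t>, "16-byte types map to int4");

  int main() { }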
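The new shared_memory_helper.h places each allocation using two static policies: S::size(block) gives the element count, and O::shared_mem_size(block) gives the byte offset of the allocation that precedes it (0 when O is void); this is presumably the requirement that the added checkSharedBytes(tp) call in tunable_kernel.h validates against the tuned launch. The host-side model below reproduces that arithmetic so the chaining can be checked outside a kernel; SizeA, SizeB and SharedMemModel are illustrative names, not QUDA types:

  #include <cstdio>
  #include <type_traits>

  struct dim3_t { unsigned x, y, z; };

  struct SizeA { static constexpr unsigned size(dim3_t b) { return b.x * b.y * b.z; } };
  struct SizeB { static constexpr unsigned size(dim3_t b) { return b.x; } };

  // Mirrors SharedMemory<T, S, O>::get_offset / shared_mem_size from the new header.
  template <typename T, typename S, typename O = void> struct SharedMemModel {
    static constexpr unsigned get_offset(dim3_t block)
    {
      if constexpr (std::is_same_v<O, void>) return 0;
      else return O::shared_mem_size(block);
    }
    static constexpr unsigned shared_mem_size(dim3_t block)
    {
      return get_offset(block) + S::size(block) * sizeof(T);
    }
  };

  int main()
  {
    using First = SharedMemModel<double, SizeA>;        // starts at byte 0
    using Second = SharedMemModel<float, SizeB, First>; // placed after First's allocation
    dim3_t block {64, 2, 1};
    printf("Second starts at byte %u, total %u bytes\n",
           Second::get_offset(block), Second::shared_mem_size(block)); // 1024 and 1280
  }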