From eebfccec06a360b8efbafdd5eb733cf8a08d6480 Mon Sep 17 00:00:00 2001 From: James Osborn Date: Thu, 30 Nov 2023 16:12:11 -0600 Subject: [PATCH] sync hisq_paths_force with sycl-merge branch --- include/kernels/hisq_paths_force.cuh | 60 +++++++++------------------- 1 file changed, 19 insertions(+), 41 deletions(-) diff --git a/include/kernels/hisq_paths_force.cuh b/include/kernels/hisq_paths_force.cuh index b114eee9ec..c0156eab0b 100644 --- a/include/kernels/hisq_paths_force.cuh +++ b/include/kernels/hisq_paths_force.cuh @@ -272,7 +272,7 @@ namespace quda { * A _______ B * mu_next | | * H| |G - * + * * Variables have been named to reflection dimensionality for * mu_positive == true, sig_positive == true, mu_next_positive == true **************************************************************************/ @@ -379,7 +379,7 @@ namespace quda { @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x @param[in/out] force_mu Accumulated force in the mu direction - @param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read) + @param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read) @details This subset of the code computes the Lepage contribution to the fermion force. Data traffic: READ: cb_link, id_link, pMu_at_c @@ -393,7 +393,8 @@ namespace quda { Flops: 2 multiplies, 1 add, 1 rescale */ - __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, ThreadLocalCache &Uab_cache) { + template + __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, LinkCache &Uab_cache) { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -421,7 +422,7 @@ namespace quda { Link Ow = mu_positive ? (conj(Ucb) * Oc) : (Ucb * Oc); { - Link Uab = Uab_cache.load(); + Link Uab = Uab_cache; Link Oy = sig_positive ? Uab * Ow : conj(Uab) * Ow; Link Ox = mu_positive ? (Oy * Uid) : (Uid * conj(Oy)); auto mycoeff_lepage = -coeff_sign(parity_a)*coeff_sign(parity_a)*arg.coeff_lepage; @@ -447,7 +448,7 @@ namespace quda { @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x - @param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read) + @param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read) Data traffic: READ: gb_link, oProd_at_h WRITE: pMu_next_at_b, p3_at_a @@ -461,7 +462,8 @@ namespace quda { Flops: 2 multiplies, 1 add, 1 rescale */ - __device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, ThreadLocalCache &Uab_cache) + template + __device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, LinkCache &Uab_cache) { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -494,7 +496,7 @@ namespace quda { arg.pMu_next(0, point_b, parity_b) = Oz; { // scoped Uab load - Link Uab = Uab_cache.load(); + Link Uab = Uab_cache; if constexpr (!sig_positive) Uab = conj(Uab); arg.p3(0, point_a, parity_a) = Uab * Oz; } @@ -542,8 +544,7 @@ namespace quda { /* * The "extra" low point corresponds to the Lepage contribution to the * force_mu term. - * - * + * * sig * F E * | | @@ -704,7 +705,7 @@ namespace quda { @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x - @param[in/out] Matrix_cache Shared memory cache that maintains the accumulated P5 contribution (write) + @param[in/out] Matrix_cache Thread local cache that maintains the accumulated P5 contribution (write) the gauge link going from a to b (read), as well as force_sig when sig is positive (read/write) @details This subset of the code computes the full seven link contribution to the HISQ force. Data traffic: @@ -721,9 +722,8 @@ namespace quda { Flops: 4 multiplies, 2 adds, 2 rescales */ - //__device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) { - template - __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) { + template + __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) { auto mycoeff_seven = parity_sign(parity_a) * coeff_sign(parity_a) * arg.coeff_seven; int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); @@ -752,22 +752,18 @@ namespace quda { UbeOeOf = Ube * OeOf; // Cache Ube to below - //Matrix_cache.save_z(Ube, 1); Matrix_cache.save(Ube, 1); } // Take care of force_sig --- contribution from the negative rho direction Link Uaf = arg.link(arg.rho, point_a, parity_a); if constexpr (sig_positive) { - //Link force_sig = Matrix_cache.load_z(2); Link force_sig = Matrix_cache[2]; force_sig = mm_add(mycoeff_seven * UbeOeOf, conj(Uaf), force_sig); - //Matrix_cache.save_z(force_sig, 2); Matrix_cache.save(force_sig, 2); } // Compute the force_rho --- contribution from the negative rho direction - //Link Uab = Matrix_cache.load_z(0); Link Uab = Matrix_cache[0]; if constexpr (!sig_positive) Uab = conj(Uab); Link force_rho = arg.force(arg.rho, point_a, parity_a); @@ -777,7 +773,6 @@ namespace quda { Link Ufe = arg.link(arg.sig, fe_link_nbr_idx, fe_link_nbr_parity); // Load Ube from the cache - //Link Ube = Matrix_cache.load_z(1); Link Ube = Matrix_cache[1]; // Form the product UfeUebOb @@ -810,7 +805,6 @@ namespace quda { Link Oz = Ucb * Ob; Link Oy = (sig_positive ? Udc : conj(Udc)) * Oz; p5_sig = mm_add(arg.accumu_coeff_seven * conj(Uda), Oy, p5_sig); - //Matrix_cache.save_z(p5_sig, 1); Matrix_cache.save(p5_sig, 1); // When sig is positive, compute the force_sig contribution from the @@ -819,10 +813,8 @@ namespace quda { Link Od = arg.qNuMu(0, point_d, parity_d); Link Oc = arg.pNuMu(0, point_c, parity_c); Link Oz = conj(Ucb) * Oc; - //Link force_sig = Matrix_cache.load_z(2); Link force_sig = Matrix_cache[2]; force_sig = mm_add(mycoeff_seven * Oz, Od * Uda, force_sig); - //Matrix_cache.save_z(force_sig, 2); Matrix_cache.save(force_sig, 2); } @@ -833,7 +825,7 @@ namespace quda { @param[in] x Local coordinate @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] parity_a Parity of the coordinate x - @param[in/out] Matrix_cache Shared memory cache that maintains the full P5 contribution + @param[in/out] Matrix_cache Thread local cache that maintains the full P5 contribution summed from the previous middle five and all seven (read), as well as force_sig when sig is positive (read/write) @details This subset of the code computes the side link five link contribution to the HISQ force. @@ -843,9 +835,8 @@ namespace quda { Flops: 2 multiplies, 2 adds, 2 rescales */ - //__device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) { - template - __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) { + template + __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) { int y[4] = {x[0], x[1], x[2], x[3]}; int point_h = updateCoordExtendedIndexShiftMILC(y, arg.nu, arg); int parity_h = 1 - parity_a; @@ -859,7 +850,6 @@ namespace quda { int qh_link_nbr_idx = mu_positive ? point_q : point_h; int qh_link_nbr_parity = mu_positive ? parity_q : parity_h; - //Link P5 = Matrix_cache.load_z(1); Link P5 = Matrix_cache[1]; Link Uah = arg.link(arg.nu, ha_link_nbr_idx, ha_link_nbr_parity); Link Ow = nu_positive ? Uah * P5 : conj(Uah) * P5; @@ -885,7 +875,7 @@ namespace quda { @param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice @param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction @param[in] parity_a Parity of the coordinate x - @param[in/out] Matrix_cache Helper shared memory cache that maintains the gauge link going + @param[in/out] Matrix_cache Thread local cache that maintains the gauge link going from a to b (read) and, when sig is positive, force_sig (read/write) @details This subset of the code computes the middle link five link contribution to the HISQ force. Data traffic: @@ -898,11 +888,8 @@ namespace quda { Flops: 1 multiply, 1 add, 1 rescale */ - //__device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, - // ThreadLocalCache &Matrix_cache) { - template - __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, - ThreadLocalCache &Matrix_cache) { + template + __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) { int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg); int parity_b = 1 - parity_a; @@ -933,7 +920,6 @@ namespace quda { arg.pNuMu_next(0, point_b, parity_b) = Ow; { // scoped Uab load - //Link Uab = Matrix_cache.load_z(0); Link Uab = Matrix_cache[0]; if constexpr (!sig_positive) Uab = conj(Uab); arg.p5(0, point_a, parity_a) = Uab * Ow; @@ -949,10 +935,8 @@ namespace quda { // compute the force in the sigma direction if sig is positive if constexpr (sig_positive) { - //Link force_sig = Matrix_cache.load_z(2); Link force_sig = Matrix_cache[2]; force_sig = mm_add(arg.coeff_five * Ow, Ox, force_sig); - //Matrix_cache.save_z(force_sig, 2); Matrix_cache.save(force_sig, 2); } } @@ -991,15 +975,11 @@ namespace quda { int parity_a = parity; // calculate p5_sig - //auto block_dim = target::block_dim(); - //block_dim.z = (sig_positive ? 3 : 2); - //ThreadLocalCache Matrix_cache(block_dim); constexpr int cacheLen = sig_positive ? 3 : 2; //ThreadLocalCache> Matrix_cache{}; ThreadLocalCache Matrix_cache{*this}; if constexpr (sig_positive) { Link force_sig = arg.force(arg.sig, point_a, parity_a); - //Matrix_cache.save_z(force_sig, 2); Matrix_cache.save(force_sig, 2); } @@ -1010,7 +990,6 @@ namespace quda { int ab_link_nbr_idx = (sig_positive) ? point_a : point_b; int ab_link_nbr_parity = (sig_positive) ? parity_a : parity_b; Link Uab = arg.link(arg.sig, ab_link_nbr_idx, ab_link_nbr_parity); - //Matrix_cache.save_z(Uab, 0); Matrix_cache.save(Uab, 0); } @@ -1026,7 +1005,6 @@ namespace quda { // update the force in the sigma direction if constexpr (sig_positive) { - //Link force_sig = Matrix_cache.load_z(2); Link force_sig = Matrix_cache[2]; arg.force(arg.sig, point_a, parity_a) = force_sig; }