diff --git a/include/kernels/hisq_paths_force.cuh b/include/kernels/hisq_paths_force.cuh
index b114eee9ec..c0156eab0b 100644
--- a/include/kernels/hisq_paths_force.cuh
+++ b/include/kernels/hisq_paths_force.cuh
@@ -272,7 +272,7 @@ namespace quda {
* A _______ B
* mu_next | |
* H| |G
- *
+ *
* Variables have been named to reflection dimensionality for
* mu_positive == true, sig_positive == true, mu_next_positive == true
**************************************************************************/
@@ -379,7 +379,7 @@ namespace quda {
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
@param[in/out] force_mu Accumulated force in the mu direction
- @param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read)
+ @param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read)
@details This subset of the code computes the Lepage contribution to the fermion force.
Data traffic:
READ: cb_link, id_link, pMu_at_c
@@ -393,7 +393,8 @@ namespace quda {
Flops:
2 multiplies, 1 add, 1 rescale
*/
- __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, ThreadLocalCache &Uab_cache) {
+ template
+ __device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, LinkCache &Uab_cache) {
int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg);
int parity_b = 1 - parity_a;
@@ -421,7 +422,7 @@ namespace quda {
Link Ow = mu_positive ? (conj(Ucb) * Oc) : (Ucb * Oc);
{
- Link Uab = Uab_cache.load();
+ Link Uab = Uab_cache;
Link Oy = sig_positive ? Uab * Ow : conj(Uab) * Ow;
Link Ox = mu_positive ? (Oy * Uid) : (Uid * conj(Oy));
auto mycoeff_lepage = -coeff_sign(parity_a)*coeff_sign(parity_a)*arg.coeff_lepage;
@@ -447,7 +448,7 @@ namespace quda {
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
- @param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read)
+ @param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read)
Data traffic:
READ: gb_link, oProd_at_h
WRITE: pMu_next_at_b, p3_at_a
@@ -461,7 +462,8 @@ namespace quda {
Flops:
2 multiplies, 1 add, 1 rescale
*/
- __device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, ThreadLocalCache &Uab_cache)
+ template
+ __device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, LinkCache &Uab_cache)
{
int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg);
int parity_b = 1 - parity_a;
@@ -494,7 +496,7 @@ namespace quda {
arg.pMu_next(0, point_b, parity_b) = Oz;
{
// scoped Uab load
- Link Uab = Uab_cache.load();
+ Link Uab = Uab_cache;
if constexpr (!sig_positive) Uab = conj(Uab);
arg.p3(0, point_a, parity_a) = Uab * Oz;
}
@@ -542,8 +544,7 @@ namespace quda {
/*
* The "extra" low point corresponds to the Lepage contribution to the
* force_mu term.
- *
- *
+ *
* sig
* F E
* | |
@@ -704,7 +705,7 @@ namespace quda {
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
- @param[in/out] Matrix_cache Shared memory cache that maintains the accumulated P5 contribution (write)
+ @param[in/out] Matrix_cache Thread local cache that maintains the accumulated P5 contribution (write)
the gauge link going from a to b (read), as well as force_sig when sig is positive (read/write)
@details This subset of the code computes the full seven link contribution to the HISQ force.
Data traffic:
@@ -721,9 +722,8 @@ namespace quda {
Flops:
4 multiplies, 2 adds, 2 rescales
*/
- //__device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) {
- template
- __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) {
+ template
+ __device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) {
auto mycoeff_seven = parity_sign(parity_a) * coeff_sign(parity_a) * arg.coeff_seven;
int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg);
@@ -752,22 +752,18 @@ namespace quda {
UbeOeOf = Ube * OeOf;
// Cache Ube to below
- //Matrix_cache.save_z(Ube, 1);
Matrix_cache.save(Ube, 1);
}
// Take care of force_sig --- contribution from the negative rho direction
Link Uaf = arg.link(arg.rho, point_a, parity_a);
if constexpr (sig_positive) {
- //Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
force_sig = mm_add(mycoeff_seven * UbeOeOf, conj(Uaf), force_sig);
- //Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}
// Compute the force_rho --- contribution from the negative rho direction
- //Link Uab = Matrix_cache.load_z(0);
Link Uab = Matrix_cache[0];
if constexpr (!sig_positive) Uab = conj(Uab);
Link force_rho = arg.force(arg.rho, point_a, parity_a);
@@ -777,7 +773,6 @@ namespace quda {
Link Ufe = arg.link(arg.sig, fe_link_nbr_idx, fe_link_nbr_parity);
// Load Ube from the cache
- //Link Ube = Matrix_cache.load_z(1);
Link Ube = Matrix_cache[1];
// Form the product UfeUebOb
@@ -810,7 +805,6 @@ namespace quda {
Link Oz = Ucb * Ob;
Link Oy = (sig_positive ? Udc : conj(Udc)) * Oz;
p5_sig = mm_add(arg.accumu_coeff_seven * conj(Uda), Oy, p5_sig);
- //Matrix_cache.save_z(p5_sig, 1);
Matrix_cache.save(p5_sig, 1);
// When sig is positive, compute the force_sig contribution from the
@@ -819,10 +813,8 @@ namespace quda {
Link Od = arg.qNuMu(0, point_d, parity_d);
Link Oc = arg.pNuMu(0, point_c, parity_c);
Link Oz = conj(Ucb) * Oc;
- //Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
force_sig = mm_add(mycoeff_seven * Oz, Od * Uda, force_sig);
- //Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}
@@ -833,7 +825,7 @@ namespace quda {
@param[in] x Local coordinate
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] parity_a Parity of the coordinate x
- @param[in/out] Matrix_cache Shared memory cache that maintains the full P5 contribution
+ @param[in/out] Matrix_cache Thread local cache that maintains the full P5 contribution
summed from the previous middle five and all seven (read), as well as force_sig
when sig is positive (read/write)
@details This subset of the code computes the side link five link contribution to the HISQ force.
@@ -843,9 +835,8 @@ namespace quda {
Flops:
2 multiplies, 2 adds, 2 rescales
*/
- //__device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) {
- template
- __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, ThreadLocalCache &Matrix_cache) {
+ template
+ __device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) {
int y[4] = {x[0], x[1], x[2], x[3]};
int point_h = updateCoordExtendedIndexShiftMILC(y, arg.nu, arg);
int parity_h = 1 - parity_a;
@@ -859,7 +850,6 @@ namespace quda {
int qh_link_nbr_idx = mu_positive ? point_q : point_h;
int qh_link_nbr_parity = mu_positive ? parity_q : parity_h;
- //Link P5 = Matrix_cache.load_z(1);
Link P5 = Matrix_cache[1];
Link Uah = arg.link(arg.nu, ha_link_nbr_idx, ha_link_nbr_parity);
Link Ow = nu_positive ? Uah * P5 : conj(Uah) * P5;
@@ -885,7 +875,7 @@ namespace quda {
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
- @param[in/out] Matrix_cache Helper shared memory cache that maintains the gauge link going
+ @param[in/out] Matrix_cache Thread local cache that maintains the gauge link going
from a to b (read) and, when sig is positive, force_sig (read/write)
@details This subset of the code computes the middle link five link contribution to the HISQ force.
Data traffic:
@@ -898,11 +888,8 @@ namespace quda {
Flops:
1 multiply, 1 add, 1 rescale
*/
- //__device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a,
- // ThreadLocalCache &Matrix_cache) {
- template
- __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a,
- ThreadLocalCache &Matrix_cache) {
+ template
+ __device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) {
int point_b = linkExtendedIndexShiftMILC(x, arg.sig, arg);
int parity_b = 1 - parity_a;
@@ -933,7 +920,6 @@ namespace quda {
arg.pNuMu_next(0, point_b, parity_b) = Ow;
{
// scoped Uab load
- //Link Uab = Matrix_cache.load_z(0);
Link Uab = Matrix_cache[0];
if constexpr (!sig_positive) Uab = conj(Uab);
arg.p5(0, point_a, parity_a) = Uab * Ow;
@@ -949,10 +935,8 @@ namespace quda {
// compute the force in the sigma direction if sig is positive
if constexpr (sig_positive) {
- //Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
force_sig = mm_add(arg.coeff_five * Ow, Ox, force_sig);
- //Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}
}
@@ -991,15 +975,11 @@ namespace quda {
int parity_a = parity;
// calculate p5_sig
- //auto block_dim = target::block_dim();
- //block_dim.z = (sig_positive ? 3 : 2);
- //ThreadLocalCache Matrix_cache(block_dim);
constexpr int cacheLen = sig_positive ? 3 : 2;
//ThreadLocalCache> Matrix_cache{};
ThreadLocalCache Matrix_cache{*this};
if constexpr (sig_positive) {
Link force_sig = arg.force(arg.sig, point_a, parity_a);
- //Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}
@@ -1010,7 +990,6 @@ namespace quda {
int ab_link_nbr_idx = (sig_positive) ? point_a : point_b;
int ab_link_nbr_parity = (sig_positive) ? parity_a : parity_b;
Link Uab = arg.link(arg.sig, ab_link_nbr_idx, ab_link_nbr_parity);
- //Matrix_cache.save_z(Uab, 0);
Matrix_cache.save(Uab, 0);
}
@@ -1026,7 +1005,6 @@ namespace quda {
// update the force in the sigma direction
if constexpr (sig_positive) {
- //Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
arg.force(arg.sig, point_a, parity_a) = force_sig;
}