Skip to content

Commit

Permalink
sync hisq_paths_force with sycl-merge branch
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Nov 30, 2023
1 parent 92ca04d commit eebfcce
Showing 1 changed file with 19 additions and 41 deletions.
60 changes: 19 additions & 41 deletions include/kernels/hisq_paths_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ namespace quda {
* A _______ B
* mu_next | |
* H| |G
*
*
* Variables have been named to reflect dimensionality for
* mu_positive == true, sig_positive == true, mu_next_positive == true
**************************************************************************/
Expand Down Expand Up @@ -379,7 +379,7 @@ namespace quda {
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
@param[in/out] force_mu Accumulated force in the mu direction
@param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read)
@param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read)
@details This subset of the code computes the Lepage contribution to the fermion force.
Data traffic:
READ: cb_link, id_link, pMu_at_c
Expand All @@ -393,7 +393,8 @@ namespace quda {
Flops:
2 multiplies, 1 add, 1 rescale
*/
__device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, ThreadLocalCache<Link> &Uab_cache) {
template <typename LinkCache>
__device__ __host__ inline void lepage_force(int x[4], int point_a, int parity_a, Link &force_mu, LinkCache &Uab_cache) {
int point_b = linkExtendedIndexShiftMILC<sig_positive>(x, arg.sig, arg);
int parity_b = 1 - parity_a;

Expand Down Expand Up @@ -421,7 +422,7 @@ namespace quda {

Link Ow = mu_positive ? (conj(Ucb) * Oc) : (Ucb * Oc);
{
Link Uab = Uab_cache.load();
Link Uab = Uab_cache;
Link Oy = sig_positive ? Uab * Ow : conj(Uab) * Ow;
Link Ox = mu_positive ? (Oy * Uid) : (Uid * conj(Oy));
auto mycoeff_lepage = -coeff_sign<sig_positive, typename Arg::real>(parity_a)*coeff_sign<mu_positive, typename Arg::real>(parity_a)*arg.coeff_lepage;
Expand All @@ -447,7 +448,7 @@ namespace quda {
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
@param[in] Uab_cache Shared memory cache that stores the gauge link going from a to b (read)
@param[in] Uab_cache Thread local cache that stores the gauge link going from a to b (read)
Data traffic:
READ: gb_link, oProd_at_h
WRITE: pMu_next_at_b, p3_at_a
Expand All @@ -461,7 +462,8 @@ namespace quda {
Flops:
2 multiplies, 1 add, 1 rescale
*/
__device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, ThreadLocalCache<Link> &Uab_cache)
template <typename LinkCache>
__device__ __host__ inline void middle_three(int x[4], int point_a, int parity_a, LinkCache &Uab_cache)
{
int point_b = linkExtendedIndexShiftMILC<sig_positive>(x, arg.sig, arg);
int parity_b = 1 - parity_a;
Expand Down Expand Up @@ -494,7 +496,7 @@ namespace quda {
arg.pMu_next(0, point_b, parity_b) = Oz;
{
// scoped Uab load
Link Uab = Uab_cache.load();
Link Uab = Uab_cache;
if constexpr (!sig_positive) Uab = conj(Uab);
arg.p3(0, point_a, parity_a) = Uab * Oz;
}
Expand Down Expand Up @@ -542,8 +544,7 @@ namespace quda {
/*
* The "extra" low point corresponds to the Lepage contribution to the
* force_mu term.
*
*
*
* sig
* F E
* | |
Expand Down Expand Up @@ -704,7 +705,7 @@ namespace quda {
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
@param[in/out] Matrix_cache Shared memory cache that maintains the accumulated P5 contribution (write)
@param[in/out] Matrix_cache Thread local cache that maintains the accumulated P5 contribution (write),
the gauge link going from a to b (read), as well as force_sig when sig is positive (read/write)
@details This subset of the code computes the full seven link contribution to the HISQ force.
Data traffic:
Expand All @@ -721,9 +722,8 @@ namespace quda {
Flops:
4 multiplies, 2 adds, 2 rescales
*/
//__device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, ThreadLocalCache<Link> &Matrix_cache) {
template <int N>
__device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, ThreadLocalCache<Link,N> &Matrix_cache) {
template <typename LinkCache>
__device__ __host__ inline void all_link(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) {
auto mycoeff_seven = parity_sign<typename Arg::real>(parity_a) * coeff_sign<sig_positive, typename Arg::real>(parity_a) * arg.coeff_seven;

int point_b = linkExtendedIndexShiftMILC<sig_positive>(x, arg.sig, arg);
Expand Down Expand Up @@ -752,22 +752,18 @@ namespace quda {
UbeOeOf = Ube * OeOf;

// Cache Ube for use below
//Matrix_cache.save_z(Ube, 1);
Matrix_cache.save(Ube, 1);
}

// Take care of force_sig --- contribution from the negative rho direction
Link Uaf = arg.link(arg.rho, point_a, parity_a);
if constexpr (sig_positive) {
//Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
force_sig = mm_add(mycoeff_seven * UbeOeOf, conj(Uaf), force_sig);
//Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}

// Compute the force_rho --- contribution from the negative rho direction
//Link Uab = Matrix_cache.load_z(0);
Link Uab = Matrix_cache[0];
if constexpr (!sig_positive) Uab = conj(Uab);
Link force_rho = arg.force(arg.rho, point_a, parity_a);
Expand All @@ -777,7 +773,6 @@ namespace quda {
Link Ufe = arg.link(arg.sig, fe_link_nbr_idx, fe_link_nbr_parity);

// Load Ube from the cache
//Link Ube = Matrix_cache.load_z(1);
Link Ube = Matrix_cache[1];

// Form the product UfeUebOb
Expand Down Expand Up @@ -810,7 +805,6 @@ namespace quda {
Link Oz = Ucb * Ob;
Link Oy = (sig_positive ? Udc : conj(Udc)) * Oz;
p5_sig = mm_add(arg.accumu_coeff_seven * conj(Uda), Oy, p5_sig);
//Matrix_cache.save_z(p5_sig, 1);
Matrix_cache.save(p5_sig, 1);

// When sig is positive, compute the force_sig contribution from the
Expand All @@ -819,10 +813,8 @@ namespace quda {
Link Od = arg.qNuMu(0, point_d, parity_d);
Link Oc = arg.pNuMu(0, point_c, parity_c);
Link Oz = conj(Ucb) * Oc;
//Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
force_sig = mm_add(mycoeff_seven * Oz, Od * Uda, force_sig);
//Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}

Expand All @@ -833,7 +825,7 @@ namespace quda {
@param[in] x Local coordinate
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] parity_a Parity of the coordinate x
@param[in/out] Matrix_cache Shared memory cache that maintains the full P5 contribution
@param[in/out] Matrix_cache Thread local cache that maintains the full P5 contribution
summed from the previous middle five and all seven (read), as well as force_sig
when sig is positive (read/write)
@details This subset of the code computes the side link five link contribution to the HISQ force.
Expand All @@ -843,9 +835,8 @@ namespace quda {
Flops:
2 multiplies, 2 adds, 2 rescales
*/
//__device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, ThreadLocalCache<Link> &Matrix_cache) {
template <int N>
__device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, ThreadLocalCache<Link,N> &Matrix_cache) {
template <typename LinkCache>
__device__ __host__ inline void side_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) {
int y[4] = {x[0], x[1], x[2], x[3]};
int point_h = updateCoordExtendedIndexShiftMILC<flip_dir(nu_positive)>(y, arg.nu, arg);
int parity_h = 1 - parity_a;
Expand All @@ -859,7 +850,6 @@ namespace quda {
int qh_link_nbr_idx = mu_positive ? point_q : point_h;
int qh_link_nbr_parity = mu_positive ? parity_q : parity_h;

//Link P5 = Matrix_cache.load_z(1);
Link P5 = Matrix_cache[1];
Link Uah = arg.link(arg.nu, ha_link_nbr_idx, ha_link_nbr_parity);
Link Ow = nu_positive ? Uah * P5 : conj(Uah) * P5;
Expand All @@ -885,7 +875,7 @@ namespace quda {
@param[in] point_a 1-d checkerboard index for the unit site in the full extended lattice
@param[in] point_b 1-d checkerboard index for the unit site shifted in the sig direction
@param[in] parity_a Parity of the coordinate x
@param[in/out] Matrix_cache Helper shared memory cache that maintains the gauge link going
@param[in/out] Matrix_cache Thread local cache that maintains the gauge link going
from a to b (read) and, when sig is positive, force_sig (read/write)
@details This subset of the code computes the middle link five link contribution to the HISQ force.
Data traffic:
Expand All @@ -898,11 +888,8 @@ namespace quda {
Flops:
1 multiply, 1 add, 1 rescale
*/
//__device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a,
// ThreadLocalCache<Link> &Matrix_cache) {
template <int N>
__device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a,
ThreadLocalCache<Link,N> &Matrix_cache) {
template <typename LinkCache>
__device__ __host__ inline void middle_five(int x[4], int point_a, int parity_a, LinkCache &Matrix_cache) {
int point_b = linkExtendedIndexShiftMILC<sig_positive>(x, arg.sig, arg);
int parity_b = 1 - parity_a;

Expand Down Expand Up @@ -933,7 +920,6 @@ namespace quda {
arg.pNuMu_next(0, point_b, parity_b) = Ow;
{
// scoped Uab load
//Link Uab = Matrix_cache.load_z(0);
Link Uab = Matrix_cache[0];
if constexpr (!sig_positive) Uab = conj(Uab);
arg.p5(0, point_a, parity_a) = Uab * Ow;
Expand All @@ -949,10 +935,8 @@ namespace quda {

// compute the force in the sigma direction if sig is positive
if constexpr (sig_positive) {
//Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
force_sig = mm_add(arg.coeff_five * Ow, Ox, force_sig);
//Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}
}
Expand Down Expand Up @@ -991,15 +975,11 @@ namespace quda {
int parity_a = parity;

// calculate p5_sig
//auto block_dim = target::block_dim();
//block_dim.z = (sig_positive ? 3 : 2);
//ThreadLocalCache<Link> Matrix_cache(block_dim);
constexpr int cacheLen = sig_positive ? 3 : 2;
//ThreadLocalCache<array<Link,cacheLen>> Matrix_cache{};
ThreadLocalCache<Link,cacheLen> Matrix_cache{*this};
if constexpr (sig_positive) {
Link force_sig = arg.force(arg.sig, point_a, parity_a);
//Matrix_cache.save_z(force_sig, 2);
Matrix_cache.save(force_sig, 2);
}

Expand All @@ -1010,7 +990,6 @@ namespace quda {
int ab_link_nbr_idx = (sig_positive) ? point_a : point_b;
int ab_link_nbr_parity = (sig_positive) ? parity_a : parity_b;
Link Uab = arg.link(arg.sig, ab_link_nbr_idx, ab_link_nbr_parity);
//Matrix_cache.save_z(Uab, 0);
Matrix_cache.save(Uab, 0);
}

Expand All @@ -1026,7 +1005,6 @@ namespace quda {

// update the force in the sigma direction
if constexpr (sig_positive) {
//Link force_sig = Matrix_cache.load_z(2);
Link force_sig = Matrix_cache[2];
arg.force(arg.sig, point_a, parity_a) = force_sig;
}
Expand Down

0 comments on commit eebfcce

Please sign in to comment.