diff --git a/.github/workflows/cuda_githubactions_build.yml b/.github/workflows/cuda_githubactions_build.yml index 8248e5242e..85e9aaea10 100644 --- a/.github/workflows/cuda_githubactions_build.yml +++ b/.github/workflows/cuda_githubactions_build.yml @@ -48,6 +48,7 @@ jobs: -DCMAKE_CUDA_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -DQUDA_GPU_ARCH=sm_80 -DQUDA_GPU_ARCH_SUFFIX=virtual -DQUDA_JITIFY=ON + -DQUDA_COVDEV=ON -DQUDA_MULTIGRID=ON -DQUDA_MULTIGRID_NVEC_LIST=24 -DQUDA_MDW_FUSED_LS_LIST=4 diff --git a/.github/workflows/rocm-build-ci.yml b/.github/workflows/rocm-build-ci.yml index c8812aa8c9..ff2f877f0c 100644 --- a/.github/workflows/rocm-build-ci.yml +++ b/.github/workflows/rocm-build-ci.yml @@ -30,6 +30,7 @@ jobs: -DQUDA_DIRAC_WILSON=ON \ -DQUDA_DIRAC_LAPLACE=ON \ -DQUDA_CLOVER_DYNAMIC=ON \ + -DQUDA_COVDEV=ON \ -DQUDA_QDPJIT=OFF \ -DQUDA_INTERFACE_QDPJIT=OFF \ -DQUDA_INTERFACE_MILC=ON \ diff --git a/.gitignore b/.gitignore index 58f2516546..a60ac7da48 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,6 @@ *.a *~ tests/*_test -make.inc milc_interface/* *#* *.pyc diff --git a/CMakeLists.txt b/CMakeLists.txt index 5529bfad0d..f4e09a30ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,7 +154,6 @@ option(QUDA_DIRAC_LAPLACE "build laplace operator" ${QUDA_DIRAC_DEFAULT}) option(QUDA_DIRAC_DISTANCE_PRECONDITIONING "build code for distance preconditioned Wilson/clover Dirac operators" OFF) -option(QUDA_CONTRACT "build code for bilinear contraction" OFF) option(QUDA_COVDEV "build code for covariant derivative" OFF) option(QUDA_QIO "build QIO code for binary I/O" OFF) diff --git a/README.md b/README.md index 435677c4be..72daf0e721 100644 --- a/README.md +++ b/README.md @@ -283,6 +283,7 @@ Advanced Scientific Computing (PASC21) [arXiv:2104.05615[hep-lat]]. 
* Mario Schröck (INFN) * Aniket Sen (HISKP, University of Bonn) * Guochun Shi (NCSA) +* James Simone (Fermi National Accelerator Laboratory) * Alexei Strelchenko (Fermi National Accelerator Laboratory) * Jiqun Tu (NVIDIA) * Carsten Urbach (HISKP, University of Bonn) diff --git a/ci/docker/Dockerfile.build b/ci/docker/Dockerfile.build index 26adb18840..5ca1d7a3dc 100644 --- a/ci/docker/Dockerfile.build +++ b/ci/docker/Dockerfile.build @@ -45,6 +45,7 @@ RUN QUDA_TEST_GRID_SIZE="1 1 1 2" cmake -S /quda/src \ -DQUDA_DIRAC_TWISTED_CLOVER=ON \ -DQUDA_DIRAC_STAGGERED=ON \ -DQUDA_DIRAC_LAPLACE=ON \ + -DQUDA_COVDEV=ON \ -GNinja \ -B /quda/build diff --git a/include/contract_quda.h b/include/contract_quda.h index 3cdc2ee843..557fea6252 100644 --- a/include/contract_quda.h +++ b/include/contract_quda.h @@ -4,6 +4,32 @@ namespace quda { + /** + * Interface function that launches contraction compute kernels, + * used in interface_quda.cpp + * @param[in] x input source field + * @param[in] y input source field + * @param[out] result container of complex contraction results for + * all decay slices and spins + * @param[in] cType contraction type as defined in QudaContractType enum + * @param[in] source_position 4d array of source position + * @param[in] mom_mode 4d array of momentum + * @param[in] fft_type Fourier phase factor type + * as defined in QudaFFTSymmType enum + * @param[in] s1 spin component index (0 for staggered) + * @param[in] b1 spin component index (0 for staggered) + */ + + void contractSummedQuda(const ColorSpinorField &x, const ColorSpinorField &y, std::vector &result, + QudaContractType cType, const int *const source_position, const int *const mom_mode, + const QudaFFTSymmType *const fft_type, const size_t s1, const size_t b1); + /** + * @param[in] x input color spinor + * @param[in] y input color spinor + * @param[out] result pointer to the spinxspin projections per lattice site + * @param[in] cType contraction type + */ + void contractQuda(const 
ColorSpinorField &x, const ColorSpinorField &y, void *result, QudaContractType cType); /** diff --git a/include/dslash_quda.h b/include/dslash_quda.h index 4296486448..38efe07e6c 100644 --- a/include/dslash_quda.h +++ b/include/dslash_quda.h @@ -927,4 +927,26 @@ namespace quda */ void gamma5(ColorSpinorField &out, const ColorSpinorField &in); + /** + @brief Applies a (1 \pm gamma5)/2 projection matrix to a spinor + @param[out] out Output field + @param[in] in Input field + @param[in] proj Sign of \pm projection + */ + void ApplyChiralProj(ColorSpinorField &out, const ColorSpinorField &in, const int proj); + + /** + @brief Constructs the mid-point 4D propagator from a 5D domain wall propagator + @param[out] out Output field + @param[in] in Input field + */ + void make4DMidPointProp(ColorSpinorField &out, ColorSpinorField &in); + + /** + @brief Constructs the chiral 4D propagator from a 5D domain wall propagator + @param[out] out Output field + @param[in] in Input field + */ + void make4DChiralProp(ColorSpinorField &out, ColorSpinorField &in); + } // namespace quda diff --git a/include/enum_quda.h b/include/enum_quda.h index 869499e018..9676823288 100644 --- a/include/enum_quda.h +++ b/include/enum_quda.h @@ -29,6 +29,7 @@ typedef enum QudaLinkType_s { QUDA_MOMENTUM_LINKS, QUDA_COARSE_LINKS, // used for coarse-gauge field with multigrid QUDA_SMEARED_LINKS, // used for loading and saving gaugeSmeared in the interface + QUDA_TWOLINK_LINKS, // used for staggered fermion smearing. 
QUDA_WILSON_LINKS = QUDA_SU3_LINKS, // used by wilson, clover, twisted mass, and domain wall QUDA_ASQTAD_FAT_LINKS = QUDA_GENERAL_LINKS, QUDA_ASQTAD_LONG_LINKS = QUDA_THREE_LINKS, @@ -544,12 +545,47 @@ typedef enum QudaStaggeredPhase_s { QUDA_STAGGERED_PHASE_INVALID = QUDA_INVALID_ENUM } QudaStaggeredPhase; +typedef enum QudaSpinTasteGamma_s { + QUDA_SPIN_TASTE_G1 = 0, + QUDA_SPIN_TASTE_GX = 1, + QUDA_SPIN_TASTE_GY = 2, + QUDA_SPIN_TASTE_GZ = 4, + QUDA_SPIN_TASTE_GT = 8, + QUDA_SPIN_TASTE_G5 = 15, + QUDA_SPIN_TASTE_GYGZ = 6, + QUDA_SPIN_TASTE_GZGX = 5, + QUDA_SPIN_TASTE_GXGY = 3, + QUDA_SPIN_TASTE_GXGT = 9, + QUDA_SPIN_TASTE_GYGT = 10, + QUDA_SPIN_TASTE_GZGT = 12, + QUDA_SPIN_TASTE_G5GX = 14, + QUDA_SPIN_TASTE_G5GY = 13, + QUDA_SPIN_TASTE_G5GZ = 11, + QUDA_SPIN_TASTE_G5GT = 7, + QUDA_SPIN_TASTE_INVALID = QUDA_INVALID_ENUM +} QudaSpinTasteGamma; + typedef enum QudaContractType_s { - QUDA_CONTRACT_TYPE_OPEN, // Open spin elementals - QUDA_CONTRACT_TYPE_DR, // DegrandRossi + QUDA_CONTRACT_TYPE_STAGGERED_FT_T, // Staggered, FT in tdim + QUDA_CONTRACT_TYPE_DR_FT_T, // DegrandRossi insertion, FT in tdim + QUDA_CONTRACT_TYPE_DR_FT_Z, // DegrandRossi insertion, FT in zdim + QUDA_CONTRACT_TYPE_STAGGERED, // Staggered, no summation (TODO: remove line) + QUDA_CONTRACT_TYPE_DR, // DegrandRossi insertion, no summation + QUDA_CONTRACT_TYPE_OPEN, // Open spin elementals, no summation + QUDA_CONTRACT_TYPE_OPEN_SUM_T, // Open spin elementals, spatially summed over tdim + QUDA_CONTRACT_TYPE_OPEN_SUM_Z, // Open spin elementals, spatially summed over zdim + QUDA_CONTRACT_TYPE_OPEN_FT_T, // Open spin elementals, FT in tdim + QUDA_CONTRACT_TYPE_OPEN_FT_Z, // Open spin elementals, FT in zdim QUDA_CONTRACT_TYPE_INVALID = QUDA_INVALID_ENUM } QudaContractType; +typedef enum QudaFFTSymmType_t { + QUDA_FFT_SYMM_ODD = 1, // sin(phase) + QUDA_FFT_SYMM_EVEN = 2, // cos(phase) + QUDA_FFT_SYMM_EO = 3, // exp(-i phase) + QUDA_FFT_SYMM_INVALID = QUDA_INVALID_ENUM +} QudaFFTSymmType; + typedef enum 
QudaContractGamma_s { QUDA_CONTRACT_GAMMA_I = 0, QUDA_CONTRACT_GAMMA_G1 = 1, @@ -580,6 +616,12 @@ typedef enum QudaGaugeSmearType_s { QUDA_GAUGE_SMEAR_INVALID = QUDA_INVALID_ENUM } QudaGaugeSmearType; +typedef enum QudaWFlowType_s { + QUDA_WFLOW_TYPE_WILSON, + QUDA_WFLOW_TYPE_SYMANZIK, + QUDA_WFLOW_TYPE_INVALID = QUDA_INVALID_ENUM +} QudaWFlowType; + typedef enum QudaFermionSmearType_s { QUDA_FERMION_SMEAR_TYPE_GAUSSIAN, QUDA_FERMION_SMEAR_TYPE_WUPPERTAL, diff --git a/include/enum_quda_fortran.h b/include/enum_quda_fortran.h index ef98d0c9a2..596232a7b3 100644 --- a/include/enum_quda_fortran.h +++ b/include/enum_quda_fortran.h @@ -23,6 +23,7 @@ #define QUDA_MOMENTUM_LINKS 3 #define QUDA_COARSE_LINKS 4 #define QUDA_SMEARED_LINKS 5 +#define QUDA_TWOLINK_LINKS 6 #define QUDA_WILSON_LINKS QUDA_SU3_LINKS #define QUDA_ASQTAD_FAT_LINKS QUDA_GENERAL_LINKS @@ -477,10 +478,37 @@ #define QUDA_STAGGERED_PHASE_TIFR 3 #define QUDA_STAGGERED_PHASE_INVALID QUDA_INVALID_ENUM +#define QudaSpinTasteGamma integer(4) +#define QUDA_SPIN_TASTE_G1 0 +#define QUDA_SPIN_TASTE_GX 1 +#define QUDA_SPIN_TASTE_GY 2 +#define QUDA_SPIN_TASTE_GZ 4 +#define QUDA_SPIN_TASTE_GT 8 +#define QUDA_SPIN_TASTE_G5 15 +#define QUDA_SPIN_TASTE_GYGZ 6 +#define QUDA_SPIN_TASTE_GZGX 5 +#define QUDA_SPIN_TASTE_GXGY 3 +#define QUDA_SPIN_TASTE_GXGT 9 +#define QUDA_SPIN_TASTE_GYGT 10 +#define QUDA_SPIN_TASTE_GZGT 12 +#define QUDA_SPIN_TASTE_G5GX 14 +#define QUDA_SPIN_TASTE_G5GY 13 +#define QUDA_SPIN_TASTE_G5GZ 11 +#define QUDA_SPIN_TASTE_G5GT 7 +#define QUDA_SPIN_TASTE_INVALID QUDA_INVALID_ENUM + #define QudaContractType integer(4) -#define QUDA_CONTRACT_TYPE_OPEN , -#define QUDA_CONTRACT_TYPE_DR , -#define QUDA_CONTRACT_TYPE_INVALID = QUDA_INVALID_ENUM +#define QUDA_CONTRACT_TYPE_STAGGERED_FT_T 0 +#define QUDA_CONTRACT_TYPE_DR_FT_T 1 +#define QUDA_CONTRACT_TYPE_DR_FT_Z 2 +#define QUDA_CONTRACT_TYPE_STAGGERED 3 +#define QUDA_CONTRACT_TYPE_DR 4 +#define QUDA_CONTRACT_TYPE_OPEN 5 +#define QUDA_CONTRACT_TYPE_OPEN_SUM_T 
6 +#define QUDA_CONTRACT_TYPE_OPEN_SUM_Z 7 +#define QUDA_CONTRACT_TYPE_OPEN_FT_T 8 +#define QUDA_CONTRACT_TYPE_OPEN_FT_Z 9 +#define QUDA_CONTRACT_TYPE_INVALID QUDA_INVALID_ENUM #define QudaContractGamma integer(4) #define QUDA_CONTRACT_GAMMA_I 0 diff --git a/include/gamma.cuh b/include/gamma.cuh index c108ee9725..ed5474711f 100644 --- a/include/gamma.cuh +++ b/include/gamma.cuh @@ -291,4 +291,123 @@ namespace quda { inline constexpr int Dir() const { return dir; } }; + // list of specialized structures used in the contraction kernels: + + constexpr array, 16> get_dr_gm_i() + { + return {{// VECTORS + // G_idx = 1: \gamma_1 + {3, 2, 1, 0}, + + // G_idx = 2: \gamma_2 + {3, 2, 1, 0}, + + // G_idx = 3: \gamma_3 + {2, 3, 0, 1}, + + // G_idx = 4: \gamma_4 + {2, 3, 0, 1}, + + // PSEUDO-VECTORS + // G_idx = 6: \gamma_5\gamma_1 + {3, 2, 1, 0}, + + // G_idx = 7: \gamma_5\gamma_2 + {3, 2, 1, 0}, + + // G_idx = 8: \gamma_5\gamma_3 + {2, 3, 0, 1}, + + // G_idx = 9: \gamma_5\gamma_4 + {2, 3, 0, 1}, + + // SCALAR + // G_idx = 0: I + {0, 1, 2, 3}, + + // PSEUDO-SCALAR + // G_idx = 5: \gamma_5 + {0, 1, 2, 3}, + + // TENSORS + // G_idx = 10: (i/2) * [\gamma_1, \gamma_2] + {0, 1, 2, 3}, + + // G_idx = 11: (i/2) * [\gamma_1, \gamma_3]. this matrix was corrected + {1, 0, 3, 2}, + + // G_idx = 12: (i/2) * [\gamma_1, \gamma_4] + {1, 0, 3, 2}, + + // G_idx = 13: (i/2) * [\gamma_2, \gamma_3] + {1, 0, 3, 2}, + + // G_idx = 14: (i/2) * [\gamma_2, \gamma_4] + {1, 0, 3, 2}, + + // G_idx = 15: (i/2) * [\gamma_3, \gamma_4]. 
this matrix was corrected + {0, 1, 2, 3}}}; + } + + template constexpr array, 4>, 16> get_dr_g5gm_z() + { + + constexpr complex p_i = complex(0., +1.); + constexpr complex m_i = complex(0., -1.); + constexpr complex p_1 = complex(+1., 0.); + constexpr complex m_1 = complex(-1., 0.); + + return {{// VECTORS + // G_idx = 1: \gamma_1 + {p_i, p_i, p_i, p_i}, + + // G_idx = 2: \gamma_2 + {m_1, p_1, m_1, p_1}, + + // G_idx = 3: \gamma_3 + {p_i, m_i, p_i, m_i}, + + // G_idx = 4: \gamma_4 + {p_1, p_1, m_1, m_1}, + + // PSEUDO-VECTORS + // G_idx = 6: \gamma_5\gamma_1 + {p_i, p_i, m_i, m_i}, + + // G_idx = 7: \gamma_5\gamma_2 + {m_1, p_1, p_1, m_1}, + + // G_idx = 8: \gamma_5\gamma_3 + {p_i, m_i, m_i, p_i}, + + // G_idx = 9: \gamma_5\gamma_4 + {p_1, p_1, p_1, p_1}, + + // SCALAR + // G_idx = 0: I + {p_1, p_1, m_1, m_1}, + + // PSEUDO-SCALAR + // G_idx = 5: \gamma_5 + {p_1, p_1, p_1, p_1}, + + // TENSORS + // G_idx = 10: (i/2) * [\gamma_1, \gamma_2] + {p_1, m_1, m_1, p_1}, + + // G_idx = 11: (i/2) * [\gamma_1, \gamma_3]. this matrix was corrected + {m_i, p_i, p_i, m_i}, + + // G_idx = 12: (i/2) * [\gamma_1, \gamma_4] + {m_1, m_1, m_1, m_1}, + + // G_idx = 13: (i/2) * [\gamma_2, \gamma_3] + {p_1, p_1, m_1, m_1}, + + // G_idx = 14: (i/2) * [\gamma_2, \gamma_4] + {m_i, p_i, m_i, p_i}, + + // G_idx = 15: (i/2) * [\gamma_3, \gamma_4]. 
this matrix was corrected + {m_1, p_1, m_1, p_1}}}; + } } // namespace quda diff --git a/include/kernels/contraction.cuh b/include/kernels/contraction.cuh index b1f6e7f052..96079ed041 100644 --- a/include/kernels/contraction.cuh +++ b/include/kernels/contraction.cuh @@ -5,17 +5,230 @@ #include #include #include +#include namespace quda { + static constexpr int max_contract_results = 16; // sized for nSpin**2 = 16 - template struct ContractionArg : kernel_param<> { + using spinor_array = array, max_contract_results>; + using staggered_spinor_array = array; + + template __device__ void sink_from_t_xyz(int sink[4], int t, int xyz, T X[4]) + { +#pragma unroll + for (int d = 0; d < 4; d++) { + if (d != reduction_dim) { + sink[d] = xyz % X[d]; + xyz /= X[d]; + } + } + sink[reduction_dim] = t; + return; + } + + template __device__ int idx_from_sink(T X[4], int *sink) + { + return ((sink[3] * X[2] + sink[2]) * X[1] + sink[1]) * X[0] + sink[0]; + } + + template __device__ int idx_from_t_xyz(int t, int xyz, T X[4]) + { + int x[4]; +#pragma unroll + for (int d = 0; d < 4; d++) { + if (d != reduction_dim) { + x[d] = xyz % X[d]; + xyz /= X[d]; + } + } + x[reduction_dim] = t; + return (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]); + } + + template + struct ContractionSummedArg : public ReduceArg { + using reduce_t = contract_t; + // This the direction we are performing reduction on. default to 3. + static constexpr int reduction_dim = reduction_dim_; + + using real = typename mapper::type; + static constexpr int nColor = nColor_; + static constexpr int nSpin = nSpin_; + static constexpr bool spin_project = nSpin_ == 1 ? 
false : true; + static constexpr bool spinor_direct_load = false; // false means texture load + + typedef typename colorspinor_mapper::type F; + + F x; + F y; + int s1, b1; + int mom_mode[4]; + QudaFFTSymmType fft_type[4]; + int source_position[4]; + int NxNyNzNt[4]; + int t_offset; + int offsets[4]; + + int_fastdiv X[4]; // grid dimensions + + ContractionSummedArg(const ColorSpinorField &x, const ColorSpinorField &y, const int source_position_in[4], + const int mom_mode_in[4], const QudaFFTSymmType fft_type_in[4], const int s1, const int b1) : + ReduceArg(dim3(x.Volume() / x.X()[reduction_dim], 1, x.X()[reduction_dim]), x.X()[reduction_dim]), + x(x), + y(y), + s1(s1), + b1(b1) + // Launch xyz threads per t, t times. + { + for (int i = 0; i < 4; i++) { + X[i] = x.X()[i]; + source_position[i] = source_position_in[i]; + mom_mode[i] = mom_mode_in[i]; + fft_type[i] = fft_type_in[i]; + offsets[i] = comm_coord(i) * x.X()[i]; + NxNyNzNt[i] = comm_dim(i) * x.X()[i]; + } + } + }; + + template struct DegrandRossiContractFT : plus { + + using reduce_t = spinor_array; + using plus::operator(); + static constexpr int reduce_block_dim = 1; // + + const Arg &arg; + constexpr DegrandRossiContractFT(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + // overload comm_reduce to defer until the entire "tile" is complete + template static inline void comm_reduce(U &) { } + + // Final param is unused in the MultiReduce functor in this use case. 
+ __device__ __host__ inline reduce_t operator()(reduce_t &result, int xyz, int, int t) + { + constexpr int nSpin = Arg::nSpin; + constexpr int nColor = Arg::nColor; + + using real = typename Arg::real; + using Vector = ColorSpinor; + + constexpr array, nSpin *nSpin> gm_i = get_dr_gm_i(); + constexpr array, nSpin>, nSpin *nSpin> g5gm_z = get_dr_g5gm_z(); + + int s1 = arg.s1; + int b1 = arg.b1; + + // The coordinate of the sink + int sink[4]; + sink_from_t_xyz(sink, t, xyz, arg.X); + + // Calculate exp(-i * [x dot p]) + double Sum_dXi_dot_Pi = 0.0; + for (int i = 0; i < 4; i++) + Sum_dXi_dot_Pi += (arg.source_position[i] - sink[i] - arg.offsets[i]) * arg.mom_mode[i] * 1. / arg.NxNyNzNt[i]; + + complex phase = {cospi(Sum_dXi_dot_Pi * 2.), -sinpi(Sum_dXi_dot_Pi * 2.)}; + + // Collect vector data + int parity = 0; + int idx = idx_from_t_xyz(t, xyz, arg.X); + int idx_cb = getParityCBFromFull(parity, arg.X, idx); + Vector x = arg.x(idx_cb, parity); + Vector y = arg.y(idx_cb, parity); + + // loop over channels + reduce_t result_all_channels = {}; + for (int G_idx = 0; G_idx < 16; G_idx++) { + for (int s2 = 0; s2 < nSpin; s2++) { + + // We compute the contribution from s1,b1 and s2,b2 from props x and y respectively. + int b2 = gm_i[G_idx][s2]; + // get non-zero column index for current s1 + int b1_tmp = gm_i[G_idx][s1]; + + // only contributes if we're at the correct b1 from the outer loop FIXME + if (b1_tmp == b1) { + // use tr[ Gamma * Prop * Gamma * g5 * conj(Prop) * g5] = tr[g5*Gamma*Prop*g5*Gamma*(-1)^{?}*conj(Prop)]. 
+ // gamma_5 * gamma_i gamma_5 * gamma_idx + auto prop_product = g5gm_z[G_idx][b2] * innerProduct(x, y, b2, s2) * g5gm_z[G_idx][b1]; + result_all_channels[G_idx][0] += prop_product.real() * phase.real() - prop_product.imag() * phase.imag(); + result_all_channels[G_idx][1] += prop_product.imag() * phase.real() + prop_product.real() * phase.imag(); + } + } + } + + return operator()(result_all_channels, result); + } + }; + + template struct StaggeredContractFT : plus { + using reduce_t = typename Arg::reduce_t; + using plus::operator(); + + static constexpr int reduce_block_dim = 1; + + const Arg &arg; + constexpr StaggeredContractFT(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + // overload comm_reduce to defer until the entire "tile" is complete + template static inline void comm_reduce(U &) { } + + // y index param is unused in the MultiReduce functor in this use case. + __device__ __host__ inline reduce_t operator()(reduce_t &result, int xyz, int, int t) + { + using real = typename Arg::real; + using Vector = ColorSpinor; + + // The coordinate of the sink + int sink[4]; + sink_from_t_xyz(sink, t, xyz, arg.X); + + // Collect vector data + int parity = 0; + int idx = idx_from_t_xyz(t, xyz, arg.X); + int idx_cb = getParityCBFromFull(parity, arg.X, idx); + Vector x = arg.x(idx_cb, parity); + Vector y = arg.y(idx_cb, parity); + + // Color inner product: <\phi(x)_{\mu} | \phi(y)_{\nu}> ; The Bra is conjugated + complex prop_prod = innerProduct(x, y); + + // Fourier phase + complex ph; + complex phase(1.0, 0.0); + // Phase factor for each direction is either the cos, sin, or exp Fourier phase +#pragma unroll + for (int dir = 0; dir < 4; dir++) { + auto dXi_dot_Pi + = 2.0 * (sink[dir] + arg.offsets[dir] - arg.source_position[dir]) * arg.mom_mode[dir] / arg.NxNyNzNt[dir]; + if (arg.fft_type[dir] == QUDA_FFT_SYMM_EO) { + // exp(+i k.x) case + ph = {cospi(dXi_dot_Pi), sinpi(dXi_dot_Pi)}; + } else if (arg.fft_type[dir] == 
QUDA_FFT_SYMM_EVEN) { + // cos(k.x) case + ph = {cospi(dXi_dot_Pi), 0.0}; + } else if (arg.fft_type[dir] == QUDA_FFT_SYMM_ODD) { + // sin(k.x) case + ph = {0.0, sinpi(dXi_dot_Pi)}; + } + phase *= ph; + } + + complex result_all_channels = phase * complex {prop_prod.real(), prop_prod.imag()}; + return operator()({result_all_channels.real(), result_all_channels.imag()}, result); + } + }; + + template struct ContractionArg : kernel_param<> { using real = typename mapper::type; int X[4]; // grid dimensions - static constexpr int nSpin = 4; + static constexpr int nSpin = nSpin_; static constexpr int nColor = nColor_; - static constexpr bool spin_project = true; + static constexpr bool spin_project = spin_project_; static constexpr bool spinor_direct_load = false; // false means texture load // Create a typename F for the ColorSpinorField (F for fermion) @@ -236,4 +449,26 @@ namespace quda arg.s.save(A_, x_cb, parity); } }; + + template struct StaggeredContract { + const Arg &arg; + constexpr StaggeredContract(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ inline void operator()(int x_cb, int parity) + { + constexpr int nSpin = Arg::nSpin; + using real = typename Arg::real; + using Vector = ColorSpinor; + + Vector x = arg.x(x_cb, parity); + Vector y = arg.y(x_cb, parity); + + Matrix, nSpin> A; + // Color inner product: <\phi(x)_{\mu} | \phi(y)_{\nu}> ; The Bra is conjugated + A(0, 0) = innerProduct(x, y); + + arg.s.save(A, x_cb, parity); + } + }; } // namespace quda diff --git a/include/kernels/covariant_derivative.cuh b/include/kernels/covariant_derivative.cuh index 0e4e9213ed..8a9a18f30a 100644 --- a/include/kernels/covariant_derivative.cuh +++ b/include/kernels/covariant_derivative.cuh @@ -14,10 +14,10 @@ namespace quda /** @brief Parameter structure for driving the covariant derivative operator */ - template + template struct CovDevArg : DslashArg { static constexpr int nColor = nColor_; - 
static constexpr int nSpin = 4; + static constexpr int nSpin = nSpin_; static constexpr bool spin_project = false; static constexpr bool spinor_direct_load = false; // false means texture load typedef typename colorspinor_mapper::type F; @@ -131,7 +131,7 @@ namespace quda __device__ __host__ inline void operator()(int idx, int src_idx, int parity) { using real = typename mapper::type; - using Vector = ColorSpinor; + using Vector = ColorSpinor; // is thread active (non-trival for fused kernel only) bool active = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; diff --git a/include/kernels/dslash_gamma_helper.cuh b/include/kernels/dslash_gamma_helper.cuh index a19ac449f7..0d18a416ad 100644 --- a/include/kernels/dslash_gamma_helper.cuh +++ b/include/kernels/dslash_gamma_helper.cuh @@ -19,7 +19,7 @@ namespace quda { F out[MAX_MULTI_RHS]; // output vector field F in[MAX_MULTI_RHS]; // input vector field const int d; // which gamma matrix are we applying - const int nParity; // number of parities we're working on + const int proj = 0; // which gamma projection are we applying const bool doublet; // whether we applying the operator to a doublet const int n_flavor; // number of flavors const int volumeCB; // checkerboarded volume @@ -27,11 +27,12 @@ namespace quda { real b; // chiral twist real c; // flavor twist - GammaArg(cvector_ref &out, cvector_ref &in, int d, real kappa = 0.0, - real mu = 0.0, real epsilon = 0.0, bool dagger = false, + GammaArg(cvector_ref &out, cvector_ref &in, int d, int proj = 0, + real kappa = 0.0, real mu = 0.0, real epsilon = 0.0, bool dagger = false, QudaTwistGamma5Type twist = QUDA_TWIST_GAMMA5_INVALID) : + kernel_param(dim3(in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET ? in.VolumeCB() / 2 : in.VolumeCB(), in.size(), in.SiteSubset())), d(d), - nParity(in.SiteSubset()), + proj(proj), doublet(in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET), n_flavor(doublet ? 2 : 1), volumeCB(doublet ? 
in.VolumeCB() / 2 : in.VolumeCB()), @@ -47,6 +48,7 @@ namespace quda { checkPrecision(out, in); checkLocation(out, in); if (d < 0 || d > 4) errorQuda("Undefined gamma matrix %d", d); + if (proj < -1 || proj > 1) errorQuda("Undefined gamma projection %d", proj); if (in.Nspin() != 4) errorQuda("Cannot apply gamma5 to nSpin=%d field", in.Nspin()); if (!in.isNative() || !out.isNative()) errorQuda("Unsupported field order out=%d in=%d\n", out.FieldOrder(), in.FieldOrder()); @@ -73,7 +75,6 @@ namespace quda { } if (dagger) b *= -1.0; } - this->threads = dim3(doublet ? in.VolumeCB()/2 : in.VolumeCB(), in.SiteSubset(), 1); } }; @@ -100,6 +101,38 @@ namespace quda { } }; + /** + @brief Application of chiral projection to a color spinor field + */ + template struct ChiralProject { + const Arg &arg; + constexpr ChiralProject(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int src_idx, int parity) + { + ColorSpinor in = arg.in[src_idx](x_cb, parity); + ColorSpinor chi; + + // arg.proj is either +1 or -1, or 0. + // chiral_project/reconstruct(int p) expects 0 (+ve proj) or 1 (-ve proj) + // chiral_reconstruct(int p) returns the projected spinor with the + // opposite projection zeroed out. 
+ switch (arg.proj) { + case -1: + chi = in.chiral_project(1); + arg.out[src_idx](x_cb, parity) = chi.chiral_reconstruct(1); + break; + + case 1: + chi = in.chiral_project(0); + arg.out[src_idx](x_cb, parity) = chi.chiral_reconstruct(0); + break; + case 0: break; + } + } + }; + /** @brief Application of twist to a color spinor field */ @@ -135,13 +168,12 @@ namespace quda { F out[MAX_MULTI_RHS]; // output vector field F in[MAX_MULTI_RHS]; // input vector field const int d; // which gamma matrix are we applying - const int nParity; // number of parities we're working on bool doublet; // whether we applying the operator to a doublet const int volumeCB; // checkerboarded volume TauArg(cvector_ref &out, cvector_ref &in, int d) : + kernel_param(dim3(in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET ? in.VolumeCB() / 2 : in.VolumeCB(), in.size(), in.SiteSubset())), d(d), - nParity(in.SiteSubset()), doublet(in.TwistFlavor() == QUDA_TWIST_NONDEG_DOUBLET), volumeCB(doublet ? in.VolumeCB() / 2 : in.VolumeCB()) { @@ -157,8 +189,6 @@ namespace quda { if (!in.isNative() || !out.isNative()) errorQuda("Unsupported field order out=%d in=%d\n", out.FieldOrder(), in.FieldOrder()); if (!doublet) errorQuda("tau matrix can be applyed only to spinor doublet"); - - this->threads = dim3(doublet ? 
in.VolumeCB() / 2 : in.VolumeCB(), in.SiteSubset(), 1); } }; /** diff --git a/include/kernels/spin_taste.cuh b/include/kernels/spin_taste.cuh new file mode 100644 index 0000000000..bcd2460e2f --- /dev/null +++ b/include/kernels/spin_taste.cuh @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include + +namespace quda +{ + + template struct SpinTasteArg : kernel_param<> { + using Float = Float_; + static constexpr int nColor = nColor_; + static_assert(nColor == 3, "Only nColor=3 enabled at this time"); + static constexpr QudaSpinTasteGamma gamma = gamma_; + using F = typename colorspinor_mapper::type; + + int X[4]; + + F out; /** output vector field */ + const F in; /** input vector field */ + + SpinTasteArg(ColorSpinorField &out_, const ColorSpinorField &in_) : + kernel_param(dim3(in_.VolumeCB(), in_.SiteSubset(), 1)), out(out_), in(in_) + { + checkOrder(out_, in_); // check all orders match + checkPrecision(out_, in_); // check all precisions match + checkLocation(out_, in_); // check all locations match + if (!in_.isNative()) errorQuda("Unsupported field order colorspinor= %d \n", in_.FieldOrder()); + if (!out_.isNative()) errorQuda("Unsupported field order colorspinor= %d \n", out_.FieldOrder()); +#pragma unroll + for (int i = 0; i < 4; i++) { X[i] = in_.X()[i]; } + } + }; + + // FIXME only works with even local volumes + template struct SpinTastePhase { + const Arg &arg; + constexpr SpinTastePhase(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity) + { + using real = typename mapper::type; + using Vector = ColorSpinor; + + int x[4]; + + getCoords(x, x_cb, arg.X, parity); + + real sign = 1.0; + + if (Arg::gamma == QUDA_SPIN_TASTE_GX) { + sign = 1.0 - 2.0 * ((x[1] + x[2] + x[3]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GY) { + sign = 1.0 - 2.0 * ((x[0] + x[2] + x[3]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GZ) { + sign = 1.0 - 2.0 * 
((x[0] + x[1] + x[3]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GT) { + sign = 1.0 - 2.0 * ((x[0] + x[1] + x[2]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_G5) { + sign = 1.0 - 2.0 * ((x[0] + x[1] + x[2] + x[3]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GYGZ) { + sign = 1.0 - 2.0 * ((x[1] + x[2]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GZGX) { + sign = 1.0 - 2.0 * ((x[2] + x[0]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GXGY) { + sign = 1.0 - 2.0 * ((x[0] + x[1]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GXGT) { + sign = 1.0 - 2.0 * ((x[0] + x[3]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GYGT) { + sign = 1.0 - 2.0 * ((x[1] + x[3]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_GZGT) { + sign = 1.0 - 2.0 * ((x[2] + x[3]) % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_G5GX) { + sign = 1.0 - 2.0 * (x[0] % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_G5GY) { + sign = 1.0 - 2.0 * (x[1] % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_G5GZ) { + sign = 1.0 - 2.0 * (x[2] % 2); + } else if (Arg::gamma == QUDA_SPIN_TASTE_G5GT) { + sign = 1.0 - 2.0 * (x[3] % 2); + } + + Vector out = arg.in(x_cb, parity); + + arg.out(x_cb, parity) = sign * out; + } + }; + +} // namespace quda diff --git a/include/quda.h b/include/quda.h index 992b4d4459..ed974dbe9f 100644 --- a/include/quda.h +++ b/include/quda.h @@ -90,6 +90,7 @@ extern "C" { size_t gauge_offset; /**< Offset into MILC site struct to the gauge field (only if gauge_order=MILC_SITE_GAUGE_ORDER) */ size_t mom_offset; /**< Offset into MILC site struct to the momentum field (only if gauge_order=MILC_SITE_GAUGE_ORDER) */ size_t site_size; /**< Size of MILC site struct (only if gauge_order=MILC_SITE_GAUGE_ORDER) */ + } QudaGaugeParam; @@ -1146,7 +1147,12 @@ extern "C" { * Free QUDA's internal smeared gauge field. */ void freeGaugeSmearedQuda(void); - + + /** + * Free QUDA's internal two-link gauge field. 
+ */ + void freeGaugeTwoLinkQuda(void); + /** * Save the gauge field to the host. * @param h_gauge Base pointer to host gauge field (regardless of dimensionality) @@ -1263,14 +1269,43 @@ extern "C" { /** * Apply the Dslash operator (D_{eo} or D_{oe}). - * @param h_out Result spinor field - * @param h_in Input spinor field - * @param param Contains all metadata regarding host and device + * @param[out] h_out Result spinor field + * @param[in] h_in Input spinor field + * @param[in] inv_param Contains all metadata regarding host and device * storage - * @param parity The destination parity of the field + * @param[in] parity The destination parity of the field */ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity parity); + /** + * Apply the covariant derivative. + * @param[out] h_out Result spinor field + * @param[in] h_in Input spinor field + * @param[in] dir Direction of application + * @param[in] param Metadata for host and device storage + */ + void covDevQuda(void *h_out, void *h_in, int dir, QudaInvertParam *param); + + /** + * Apply the shift operator. + * @param[out] h_out Result spinor field + * @param[in] h_in Input spinor field + * @param[in] dir Direction of application + * @param[in] sym Apply forward=1, backward=2 or symmetric=3 shift (NOTE(review): forward was listed as 2, duplicating backward; 1 is assumed -- confirm against the shiftQuda implementation) + * @param[in] param Metadata for host and device storage + */ + void shiftQuda(void *h_out, void *h_in, int dir, int sym, QudaInvertParam *param); + + /** + * Apply the spin-taste operator. + * @param[out] h_out Result spinor field + * @param[in] h_in Input spinor field + * @param[in] spin Spin gamma structure + * @param[in] taste Taste gamma structure + * @param[in] param Metadata for host and device storage + */ + void spinTasteQuda(void *h_out, void *h_in, int spin, int taste, QudaInvertParam *param); + /** * @brief Perform the solve like @dslashQuda but for multiple rhs by spliting the comm grid into * sub-partitions: each sub-partition does one or more rhs'. 
@@ -1597,6 +1632,21 @@ extern "C" { void copyExtendedResidentGaugeQuda(void *resident_gauge); /** + * Performs gaussian/Wuppertal smearing on a given spinor using the gauge field + * gaugeSmeared, if it exist, or gaugePrecise if no smeared field is present. + * @param h_in Input spinor field + * @param h_out Output spinor field + * @param param Contains all metadata regarding host and device + * storage and operator which will be applied to the spinor + * @param n_steps Number of steps to apply. + * @param coeff Width of the Gaussian distribution + * @param smear_type Gaussian/Wuppertal smearing + */ + void performFermionSmearQuda(void *h_out, void *h_in, QudaInvertParam *param, const int n_steps, const double coeff, + const QudaFermionSmearType smear_type); + + /** + * LEGACY * Performs Wuppertal smearing on a given spinor using the gauge field * gaugeSmeared, if it exist, or gaugePrecise if no smeared field is present. * @param h_out Result spinor field @@ -1608,6 +1658,19 @@ extern "C" { */ void performWuppertalnStep(void *h_out, void *h_in, QudaInvertParam *param, unsigned int n_steps, double alpha); + /** + * LEGACY + * Performs gaussian smearing on a given spinor using the gauge field + * gaugeSmeared, if it exist, or gaugePrecise if no smeared field is present. + * @param h_in Input spinor field + * @param h_out Output spinor field + * @param param Contains all metadata regarding host and device + * storage and operator which will be applied to the spinor + * @param n_steps Number of steps to apply. 
+ * @param omega Width of the Gaussian distribution + */ + void performGaussianSmearNStep(void *h_out, void *h_in, QudaInvertParam *param, const int n_steps, const double omega); + /** * Performs APE, Stout, or Over Imroved STOUT smearing on gaugePrecise and stores it in gaugeSmeared * @param[in] smear_param Parameter struct that defines the computation parameters @@ -1646,6 +1709,23 @@ extern "C" { void contractQuda(const void *x, const void *y, void *result, const QudaContractType cType, QudaInvertParam *param, const int *X); + /** + * @param[in] x pointer to host data array + * @param[in] y pointer to host data array + * @param[out] result pointer to the spin*spin projections per lattice slice site + * @param[in] cType Which type of contraction (open, degrand-rossi, etc) + * @param[in] cs_param_ptr meta data for construction of ColorSpinorFields. + * @param[in] src_colors color dilution parameter + * @param[in] X local lattice dimensions + * @param[in] source_position source position array + * @param[in] n_mom number of momentum modes + * @param[in] mom_modes momentum modes + * @param[in] fft_type Fourier phase factor type (cos, sin or exp{ikx}) + */ + void contractFTQuda(void **x, void **y, void **result, const QudaContractType cType, void *cs_param_ptr, + const int src_colors, const int *X, const int *const source_position, const int n_mom, + const int *const mom_modes, const QudaFFTSymmType *const fft_type); + /** * @brief Gauge fixing with overrelaxation with support for single and multi GPU. 
* @param[in,out] gauge, gauge field to be fixed diff --git a/include/quda_define.h.in b/include/quda_define.h.in index 1272ef6404..ea6657120f 100644 --- a/include/quda_define.h.in +++ b/include/quda_define.h.in @@ -129,15 +129,6 @@ #define GPU_COVDEV #endif -#cmakedefine QUDA_CONTRACT -#ifdef QUDA_CONTRACT -/** - * @def GPU_CONTRACT - * @brief This macro is set when we have contractions enabled - */ -#define GPU_CONTRACT -#endif - #cmakedefine QUDA_MULTIGRID #ifdef QUDA_MULTIGRID /** diff --git a/include/quda_milc_interface.h b/include/quda_milc_interface.h index c3207cfaa2..bda2aa7f6b 100644 --- a/include/quda_milc_interface.h +++ b/include/quda_milc_interface.h @@ -118,6 +118,18 @@ extern "C" { int use_pinned_memory; /** use page-locked memory in QUDA */ } QudaFatLinkArgs_t; + /** + * Parameters for propagator contractions with FT + */ + typedef struct { + int n_mom; /* Number of sink momenta */ + int *mom_modes; /* List of 4-component momenta as integers. Dimension 4*n_mom */ + QudaFFTSymmType *fft_type; /* The "parity" of the FT component */ + int *source_position; /* The coordinate origin for the Fourier phases */ + double flops; /* Return value */ + double dtime; /* Return value */ + } QudaContractArgs_t; + /** * Parameters for two-link Gaussian quark smearing. */ @@ -129,25 +141,25 @@ extern "C" { int t0; /** Set if the input spinor is on a time slice **/ int laplaceDim; /** Dimension of Laplacian **/ } QudaTwoLinkQuarkSmearArgs_t; - + /** * Optional: Set the MPI Comm Handle if it is not MPI_COMM_WORLD * - * @param input Pointer to an MPI_Comm handle, static cast as a void * + * @param[in] input Pointer to an MPI_Comm handle, static cast as a void * */ void qudaSetMPICommHandle(void *mycomm); /** * Initialize the QUDA context. 
* - * @param input Meta data for the QUDA context + * @param[in] input Meta data for the QUDA context */ void qudaInit(QudaInitArgs_t input); /** * Set set the local dimensions and machine topology for QUDA to use * - * @param layout Struct defining local dimensions and machine topology + * @param[in] layout Struct defining local dimensions and machine topology */ void qudaSetLayout(QudaLayout_t layout); @@ -158,27 +170,27 @@ extern "C" { /** * Allocate pinned memory suitable for CPU-GPU transfers - * @param bytes The size of the requested allocation + * @param[in] bytes The size of the requested allocation * @return Pointer to allocated memory - */ + */ void* qudaAllocatePinned(size_t bytes); /** * Free pinned memory - * @param ptr Pointer to memory to be free + * @param[in] ptr Pointer to memory to be free */ void qudaFreePinned(void *ptr); /** * Allocate managed memory to reduce CPU-GPU transfers - * @param bytes The size of the requested allocation + * @param[in] bytes The size of the requested allocation * @return Pointer to allocated memory */ void *qudaAllocateManaged(size_t bytes); /** * Free managed memory - * @param ptr Pointer to memory to be free + * @param[in] ptr Pointer to memory to be free */ void qudaFreeManaged(void *ptr); @@ -186,7 +198,7 @@ extern "C" { * Set the algorithms to use for HISQ fermion calculations, e.g., * SVD parameters for reunitarization. * - * @param hisq_params Meta data desribing the algorithms to use for HISQ fermions + * @param[in] hisq_params Meta data desribing the algorithms to use for HISQ fermions */ void qudaHisqParamsInit(QudaHisqParams_t hisq_params); @@ -195,12 +207,12 @@ extern "C" { * fields passed here are host fields, that must be preallocated. * The precision of all fields must match. 
* - * @param precision The precision of the fields - * @param fatlink_args Meta data for the algorithms to deploy - * @param act_path_coeff Array of coefficients for each path in the action - * @param inlink Host gauge field used for input - * @param fatlink Host fat-link field that is computed - * @param longlink Host long-link field that is computed + * @param[in] precision The precision of the fields + * @param[in] fatlink_args Meta data for the algorithms to deploy + * @param[in] act_path_coeff Array of coefficients for each path in the action + * @param[in] inlink Host gauge field used for input + * @param[out] fatlink Host fat-link field that is computed + * @param[out] longlink Host long-link field that is computed */ void qudaLoadKSLink(int precision, QudaFatLinkArgs_t fatlink_args, @@ -214,12 +226,12 @@ extern "C" { * All fields passed here are host fields, that must be * preallocated. The precision of all fields must match. * - * @param precision The precision of the fields - * @param fatlink_args Meta data for the algorithms to deploy - * @param path_coeff Array of coefficients for each path in the action - * @param inlink Host gauge field used for input - * @param fatlink Host fat-link field that is computed - * @param ulink Host unitarized field that is computed + * @param[in] precision The precision of the fields + * @param[in] fatlink_args Meta data for the algorithms to deploy + * @param[in] path_coeff Array of coefficients for each path in the action + * @param[in] inlink Host gauge field used for input + * @param[out] fatlink Host fat-link field that is computed + * @param[out] ulink Host unitarized field that is computed */ void qudaLoadUnitarizedLink(int precision, QudaFatLinkArgs_t fatlink_args, @@ -228,18 +240,48 @@ extern "C" { void* fatlink, void* ulink); + /** + * Apply the forward/backward/symmetric shift for the spin-taste operator. All fields + * passed and returned are host (CPU) field in MILC order. 
+ * + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] links Gauge field on the host + * @param[in] source Input spinor field + * @param[out] solution Output spinor field + * @param[in] dir Direction of application of the shift + * @param[in] sym Kind of shift (1 forward, 2 backward, 3 symmetric) + * @param[in] reloadGaugeField Should we transfer again the gauge field from the CPU to the GPU? (0 = false, anything else = true) + */ + void qudaShift(int external_precision, int quda_precision, const void *const links, void *source, void *solution, + int dir, int sym, int reloadGaugeField); + /** + * Apply the spin-taste operator. All fields + * passed and returned are host (CPU) field in MILC order. + * + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] links Gauge field on the host + * @param[in] src Input spinor field + * @param[out] dst Output spinor field + * @param[in] spin Spin gamma structure using MILC numbering + * @param[in] taste Taste gamma structure using MILC numbering + * @param[in] reloadGaugeField Should we transfer again the gauge field from the CPU to the GPU? (0 = false, anything else = true) + */ + void qudaSpinTaste(int external_precision, int quda_precision, const void *const links, void *src, void *dst, + int spin, int taste, int reloadGaugeField); /** * Apply the improved staggered operator to a field. All fields * passed and returned are host (CPU) field in MILC order. 
* - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param inv_args Struct setting some solver metadata - * @param milc_fatlink Fat-link field on the host - * @param milc_longlink Long-link field on the host - * @param source Right-hand side source field - * @param solution Solution spinor field + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] inv_args Struct setting some solver metadata + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] milc_longlink Long-link field on the host + * @param[in] source Right-hand side source field + * @param[out] solution Solution spinor field */ void qudaDslash(int external_precision, int quda_precision, @@ -257,20 +299,20 @@ extern "C" { * function requires that persistent gauge and clover fields have * been created prior. This interface is experimental. 
* - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param precision Precision for QUDA to use (2 - double, 1 - single) - * @param mass Fermion mass parameter - * @param inv_args Struct setting some solver metadata - * @param target_residual Target residual - * @param target_relative_residual Target Fermilab residual - * @param domain_overlap Array specifying the overlap of the domains in each dimension - * @param fatlink Fat-link field on the host - * @param longlink Long-link field on the host - * @param source Right-hand side source field - * @param solution Solution spinor field - * @param final_residual True residual - * @param final_relative_residual True Fermilab residual - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] mass Fermion mass parameter + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Target residual + * @param[in] target_relative_residual Target Fermilab residual + * @param[in] domain_overlap Array specifying the overlap of the domains in each dimension + * @param[in] fatlink Fat-link field on the host + * @param[in] longlink Long-link field on the host + * @param[in] source Right-hand side source field + * @param[out] solution Solution spinor field + * @param[out] final_residual True residual + * @param[out] final_relative_residual True Fermilab residual + * @param[out] num_iters Number of iterations taken */ void qudaDDInvert(int external_precision, int quda_precision, @@ -287,27 +329,25 @@ extern "C" { double* const final_fermilab_residual, int* num_iters); - - /** * Solve Ax=b for an improved staggered operator. All fields are fields * passed and returned are host (CPU) field in MILC order. 
This * function requires that persistent gauge and clover fields have * been created prior. This interface is experimental. * - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param mass Fermion mass parameter - * @param inv_args Struct setting some solver metadata - * @param target_residual Target residual - * @param target_relative_residual Target Fermilab residual - * @param milc_fatlink Fat-link field on the host - * @param milc_longlink Long-link field on the host - * @param source Right-hand side source field - * @param solution Solution spinor field - * @param final_residual True residual - * @param final_relative_residual True Fermilab residual - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] mass Fermion mass parameter + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Target residual + * @param[in] target_relative_residual Target Fermilab residual + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] milc_longlink Long-link field on the host + * @param[in] source Right-hand side source field + * @param[out] solution Solution spinor field + * @param[in] final_residual True residual + * @param[in] final_relative_residual True Fermilab residual + * @param[in] num_iters Number of iterations taken */ void qudaInvert(int external_precision, int quda_precision, @@ -329,13 +369,13 @@ extern "C" { * in MILC order. This function requires persistent gauge fields. * This interface is experimental. 
* - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param mass Fermion mass parameter - * @param inv_args Struct setting some solver metadata; required for tadpole, naik coeff - * @param milc_fatlink Fat-link field on the host - * @param milc_longlink Long-link field on the host - * @param mg_param_file Path to an input text file describing the MG solve, to be documented on QUDA wiki + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] mass Fermion mass parameter + * @param[in] inv_args Struct setting some solver metadata; required for tadpole, naik coeff + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] milc_longlink Long-link field on the host + * @param[in] mg_param_file Path to an input text file describing the MG solve, to be documented on QUDA wiki * @return Void pointer wrapping a pack of multigrid-related structures */ void *qudaMultigridCreate(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, @@ -350,21 +390,21 @@ extern "C" { * requires a multigrid parameter built from qudaSetupMultigrid * This interface is experimental. 
* - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param mass Fermion mass parameter - * @param inv_args Struct setting some solver metadata - * @param target_residual Target residual - * @param target_relative_residual Target Fermilab residual - * @param milc_fatlink Fat-link field on the host - * @param milc_longlink Long-link field on the host - * @param mg_pack_ptr MG preconditioner structure created by qudaSetupMultigrid - * @param mg_rebuild_type whether to do a full (1) or thin (0) MG rebuild - * @param source Right-hand side source field - * @param solution Solution spinor field - * @param final_residual True residual - * @param final_relative_residual True Fermilab residual - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] mass Fermion mass parameter + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Target residual + * @param[in] target_relative_residual Target Fermilab residual + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] milc_longlink Long-link field on the host + * @param[in] mg_pack_ptr MG preconditioner structure created by qudaSetupMultigrid + * @param[in] mg_rebuild_type whether to do a full (1) or thin (0) MG rebuild + * @param[in] source Right-hand side source field + * @param[out] solution Solution spinor field + * @param[in] final_residual True residual + * @param[in] final_relative_residual True Fermilab residual + * @param[in] num_iters Number of iterations taken */ void qudaInvertMG(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *const milc_fatlink, @@ -375,30 
+415,30 @@ extern "C" { * Clean up a staggered/HISQ multigrid object, freeing all internal * fields and otherwise allocated memory. * - * @param mg_pack_ptr Void pointer mapping to the multigrid structure returned by qudaSetupMultigrid + * @param[in] mg_pack_ptr Void pointer mapping to the multigrid structure returned by qudaSetupMultigrid */ void qudaMultigridDestroy(void *mg_pack_ptr); /** - * Solve Ax=b for an improved staggered operator with many right hand sides. + * Solve Ax=b for an improved staggered operator with many right hand sides. * All fields are fields passed and returned are host (CPU) field in MILC order. * This function requires that persistent gauge and clover fields have * been created prior. This interface is experimental. * - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param mass Fermion mass parameter - * @param inv_args Struct setting some solver metadata - * @param target_residual Target residual - * @param target_relative_residual Target Fermilab residual - * @param milc_fatlink Fat-link field on the host - * @param milc_longlink Long-link field on the host - * @param source array of right-hand side source fields - * @param solution array of solution spinor fields - * @param final_residual True residual - * @param final_relative_residual True Fermilab residual - * @param num_iters Number of iterations taken - * @param num_src Number of source fields + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] mass Fermion mass parameter + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Target residual + * @param[in] target_relative_residual Target Fermilab residual + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] 
milc_longlink Long-link field on the host + * @param[in] source array of right-hand side source fields + * @param[out] solution array of solution spinor fields + * @param[in] final_residual True residual + * @param[in] final_relative_residual True Fermilab residual + * @param[in] num_iters Number of iterations taken + * @param[in] num_src Number of source fields */ void qudaInvertMsrc(int external_precision, int quda_precision, @@ -424,20 +464,20 @@ extern "C" { * are used, else reliable updates are used with a reliable_delta * parameter of 0.1. * - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param precision Precision for QUDA to use (2 - double, 1 - single) - * @param num_offsets Number of shifts to solve for - * @param offset Array of shift offset values - * @param inv_args Struct setting some solver metadata - * @param target_residual Array of target residuals per shift - * @param target_relative_residual Array of target Fermilab residuals per shift - * @param milc_fatlink Fat-link field on the host - * @param milc_longlink Long-link field on the host - * @param source Right-hand side source field - * @param solutionArray Array of solution spinor fields - * @param final_residual Array of true residuals - * @param final_relative_residual Array of true Fermilab residuals - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] num_offsets Number of shifts to solve for + * @param[in] offset Array of shift offset values + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Array of target residuals per shift + * @param[in] target_relative_residual Array of target Fermilab residuals per shift + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] milc_longlink Long-link field on the 
host + * @param[in] source Right-hand side source field + * @param[out] solutionArray Array of solution spinor fields + * @param[in] final_residual Array of true residuals + * @param[in] final_relative_residual Array of true Fermilab residuals + * @param[in] num_iters Number of iterations taken */ void qudaMultishiftInvert( int external_precision, @@ -455,7 +495,7 @@ extern "C" { double* const final_fermilab_residual, int* num_iters); - /** + /** * Solve for a system with many RHS using an improved * staggered operator. * The solving procedure consists of two computation phases : @@ -465,23 +505,23 @@ extern "C" { * are host (CPU) field in MILC order. This function requires that * persistent gauge and clover fields have been created prior. * - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param precision Precision for QUDA to use (2 - double, 1 - single) - * @param num_offsets Number of shifts to solve for - * @param offset Array of shift offset values - * @param inv_args Struct setting some solver metadata - * @param target_residual Array of target residuals per shift - * @param target_relative_residual Array of target Fermilab residuals per shift - * @param milc_fatlink Fat-link field on the host - * @param milc_longlink Long-link field on the host - * @param source Right-hand side source field - * @param solution Array of solution spinor fields - * @param eig_args contains info about deflation space - * @param rhs_idx bookkeep current rhs - * @param last_rhs_flag is this the last rhs to solve? 
- * @param final_residual Array of true residuals - * @param final_relative_residual Array of true Fermilab residuals - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] num_offsets Number of shifts to solve for + * @param[in] offset Array of shift offset values + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Array of target residuals per shift + * @param[in] target_relative_residual Array of target Fermilab residuals per shift + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] milc_longlink Long-link field on the host + * @param[in] source Right-hand side source field + * @param[out] solution Array of solution spinor fields + * @param[in] eig_args contains info about deflation space + * @param[in] rhs_idx bookkeep current rhs + * @param[in] last_rhs_flag is this the last rhs to solve? + * @param[in] final_residual Array of true residuals + * @param[in] final_relative_residual Array of true Fermilab residuals + * @param[in] num_iters Number of iterations taken */ void qudaEigCGInvert( int external_precision, @@ -507,21 +547,21 @@ extern "C" { * function creates the gauge and clover field from the host fields. * Reliable updates are used with a reliable_delta parameter of 0.1. 
* - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param kappa Kappa value - * @param clover_coeff Clover coefficient - * @param inv_args Struct setting some solver metadata - * @param target_residual Target residual - * @param milc_link Gauge field on the host - * @param milc_clover Clover field on the host - * @param milc_clover_inv Inverse clover on the host - * @param clover_coeff Clover coefficient - * @param source Right-hand side source field - * @param solution Solution spinor field - * @param final_residual True residual returned by the solver - * @param final_residual True Fermilab residual returned by the solver - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] kappa Kappa value + * @param[in] clover_coeff Clover coefficient + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Target residual + * @param[in] milc_link Gauge field on the host + * @param[in] milc_clover Clover field on the host + * @param[in] milc_clover_inv Inverse clover on the host + * @param[in] clover_coeff Clover coefficient + * @param[in] source Right-hand side source field + * @param[out] solution Solution spinor field + * @param[in] final_residual True residual returned by the solver + * @param[in] final_residual True Fermilab residual returned by the solver + * @param[in] num_iters Number of iterations taken */ void qudaCloverInvert(int external_precision, int quda_precision, @@ -548,24 +588,24 @@ extern "C" { * are host (CPU) field in MILC order. This function requires that * persistent gauge and clover fields have been created prior. 
* - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param kappa Kappa value - * @param clover_coeff Clover coefficient - * @param inv_args Struct setting some solver metadata - * @param target_residual Target residual - * @param milc_link Gauge field on the host - * @param milc_clover Clover field on the host - * @param milc_clover_inv Inverse clover on the host - * @param clover_coeff Clover coefficient - * @param source Right-hand side source field - * @param solution Solution spinor field - * @param eig_args contains info about deflation space - * @param rhs_idx bookkeep current rhs - * @param last_rhs_flag is this the last rhs to solve? - * @param final_residual Array of true residuals - * @param final_relative_residual Array of true Fermilab residuals - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] kappa Kappa value + * @param[in] clover_coeff Clover coefficient + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Target residual + * @param[in] milc_link Gauge field on the host + * @param[in] milc_clover Clover field on the host + * @param[in] milc_clover_inv Inverse clover on the host + * @param[in] clover_coeff Clover coefficient + * @param[in] source Right-hand side source field + * @param[out] solution Solution spinor field + * @param[in] eig_args contains info about deflation space + * @param[in] rhs_idx bookkeep current rhs + * @param[in] last_rhs_flag is this the last rhs to solve? 
+ * @param[in] final_residual Array of true residuals + * @param[in] final_relative_residual Array of true Fermilab residuals + * @param[in] num_iters Number of iterations taken */ void qudaEigCGCloverInvert( int external_precision, @@ -587,14 +627,13 @@ extern "C" { double* const final_fermilab_residual, int *num_iters); - /** * Load the gauge field from the host. * - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param inv_args Meta data - * @param milc_link Base pointer to host gauge field (regardless of dimensionality) + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] inv_args Meta data + * @param[in] milc_link Base pointer to host gauge field (regardless of dimensionality) */ void qudaLoadGaugeField(int external_precision, int quda_precision, @@ -611,24 +650,24 @@ extern "C" { Free the two-link field allocated in QUDA. */ void qudaFreeTwoLink(); - + /** * Load the clover field and its inverse from the host. If null * pointers are passed, the clover field and / or its inverse will * be computed dynamically from the resident gauge field. * - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param inv_args Meta data - * @param milc_clover Pointer to host clover field. If 0 then the + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] inv_args Meta data + * @param[in] milc_clover Pointer to host clover field. If 0 then the * clover field is computed dynamically within QUDA. 
- * @param milc_clover_inv Pointer to host inverse clover field. If + * @param[in] milc_clover_inv Pointer to host inverse clover field. If * 0 then the inverse if computed dynamically within QUDA. - * @param solution_type The type of solution required (mat, matpc) - * @param solve_type The solve type to use (normal/direct/preconditioning) - * @param clover_coeff Clover coefficient - * @param compute_trlog Whether to compute the trlog of the clover field when inverting - * @param Array for storing the trlog (length two, one for each parity) + * @param[in] solution_type The type of solution required (mat, matpc) + * @param[in] solve_type The solve type to use (normal/direct/preconditioning) + * @param[in] clover_coeff Clover coefficient + * @param[in] compute_trlog Whether to compute the trlog of the clover field when inverting + * @param[in] Array for storing the trlog (length two, one for each parity) */ void qudaLoadCloverField(int external_precision, int quda_precision, @@ -655,19 +694,19 @@ extern "C" { * no reliable updates are used, else reliable updates are used with * a reliable_delta parameter of 0.1. 
* - * @param external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) - * @param quda_precision Precision for QUDA to use (2 - double, 1 - single) - * @param num_offsets Number of shifts to solve for - * @param offset Array of shift offset values - * @param kappa Kappa value - * @param clover_coeff Clover coefficient - * @param inv_args Struct setting some solver metadata - * @param target_residual Array of target residuals per shift - * @param clover_coeff Clover coefficient - * @param source Right-hand side source field - * @param solutionArray Array of solution spinor fields - * @param final_residual Array of true residuals - * @param num_iters Number of iterations taken + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] num_offsets Number of shifts to solve for + * @param[in] offset Array of shift offset values + * @param[in] kappa Kappa value + * @param[in] clover_coeff Clover coefficient + * @param[in] inv_args Struct setting some solver metadata + * @param[in] target_residual Array of target residuals per shift + * @param[in] clover_coeff Clover coefficient + * @param[in] source Right-hand side source field + * @param[out] solutionArray Array of solution spinor fields + * @param[in] final_residual Array of true residuals + * @param[in] num_iters Number of iterations taken */ void qudaCloverMultishiftInvert(int external_precision, int quda_precision, @@ -688,18 +727,19 @@ extern "C" { * are host fields in MILC order, and the precision of these fields * must match. * - * @param precision The precision of the fields - * @param num_terms The number of quark fields - * @param num_naik_terms The number of naik contributions - * @param dt Integrating step size - * @param coeff The coefficients multiplying the fermion fields in the outer product - * @param quark_field The input fermion field. 
- * @param level2_coeff The coefficients for the second level of smearing in the quark action. - * @param fat7_coeff The coefficients for the first level of smearing (fat7) in the quark action. - * @param w_link Unitarized link variables obtained by applying fat7 smearing and unitarization to the original links. - * @param v_link Fat7 link variables. - * @param u_link SU(3) think link variables. - * @param milc_momentum The momentum contribution from the quark action. + * @param[in] precision The precision of the fields + * @param[in] num_terms The number of quark fields + * @param[in] num_naik_terms The number of naik contributions + * @param[in] dt Integrating step size + * @param[in] coeff The coefficients multiplying the fermion fields in the outer product + * @param[in] quark_field The input fermion field. + * @param[in] level2_coeff The coefficients for the second level of smearing in the quark action. + * @param[in] fat7_coeff The coefficients for the first level of smearing (fat7) in the quark action. + * @param[in] w_link Unitarized link variables obtained by applying fat7 smearing and unitarization to the + * original links. + * @param[in] v_link Fat7 link variables. + * @param[in] u_link SU(3) think link variables. + * @param[in] milc_momentum The momentum contribution from the quark action. */ void qudaHisqForce(int precision, int num_terms, @@ -719,11 +759,11 @@ extern "C" { * here are CPU fields in MILC order, and their precisions should * match. 
* - * @param precision The precision of the field (2 - double, 1 - single) - * @param num_loop_types 1, 2 or 3 - * @param milc_loop_coeff Coefficients of the different loops in the Symanzik action - * @param eb3 The integration step size (for MILC this is dt*beta/3) - * @param arg Metadata for MILC's internal site struct array + * @param[in] precision The precision of the field (2 - double, 1 - single) + * @param[in] num_loop_types 1, 2 or 3 + * @param[in] milc_loop_coeff Coefficients of the different loops in the Symanzik action + * @param[in] eb3 The integration step size (for MILC this is dt*beta/3) + * @param[in] arg Metadata for MILC's internal site struct array */ void qudaGaugeForce(int precision, int num_loop_types, @@ -736,12 +776,12 @@ extern "C" { * here are CPU fields in MILC order, and their precisions should * match. * - * @param precision The precision of the field (2 - double, 1 - single) - * @param num_loop_types 1, 2 or 3 - * @param milc_loop_coeff Coefficients of the different loops in the Symanzik action - * @param eb3 The integration step size (for MILC this is dt*beta/3) - * @param arg Metadata for MILC's internal site struct array - * @param phase_in whether staggered phases are applied + * @param[in] precision The precision of the field (2 - double, 1 - single) + * @param[in] num_loop_types 1, 2 or 3 + * @param[in] milc_loop_coeff Coefficients of the different loops in the Symanzik action + * @param[in] eb3 The integration step size (for MILC this is dt*beta/3) + * @param[in] arg Metadata for MILC's internal site struct array + * @param[in] phase_in whether staggered phases are applied */ void qudaGaugeForcePhased(int precision, int num_loop_types, double milc_loop_coeff[3], double eb3, QudaMILCSiteArg_t *arg, int phase_in); @@ -815,9 +855,9 @@ extern "C" { * Evolve the gauge field by step size dt, using the momentum field * I.e., Evalulate U(t+dt) = e(dt pi) U(t). All fields are CPU fields in MILC order. 
 * - * @param precision Precision of the field (2 - double, 1 - single) - * @param dt The integration step size step - * @param arg Metadata for MILC's internal site struct array + * @param[in] precision Precision of the field (2 - double, 1 - single) + * @param[in] eps The integration step size dt + * @param[in] arg Metadata for MILC's internal site struct array */ void qudaUpdateU(int precision, double eps, @@ -827,10 +867,10 @@ extern "C" { * Evolve the gauge field by step size dt, using the momentum field * I.e., Evalulate U(t+dt) = e(dt pi) U(t). All fields are CPU fields in MILC order. * - * @param precision Precision of the field (2 - double, 1 - single) - * @param dt The integration step size step - * @param arg Metadata for MILC's internal site struct array - * @param phase_in whether staggered phases are applied + * @param[in] precision Precision of the field (2 - double, 1 - single) + * @param[in] eps The integration step size dt + * @param[in] arg Metadata for MILC's internal site struct array + * @param[in] phase_in whether staggered phases are applied */ void qudaUpdateUPhased(int precision, double eps, QudaMILCSiteArg_t *arg, int phase_in); @@ -838,11 +878,11 @@ extern "C" { * Evolve the gauge field by step size dt, using the momentum field * I.e., Evalulate U(t+dt) = e(dt pi) U(t). All fields are CPU fields in MILC order. 
 * - * @param precision Precision of the field (2 - double, 1 - single) - * @param dt The integration step size step - * @param arg Metadata for MILC's internal site struct array - * @param phase_in whether staggered phases are applied - * @param want_gaugepipe whether to enabled QUDA gaugepipe for HMC + * @param[in] precision Precision of the field (2 - double, 1 - single) + * @param[in] eps The integration step size dt + * @param[in] arg Metadata for MILC's internal site struct array + * @param[in] phase_in whether staggered phases are applied + * @param[in] want_gaugepipe whether to enable QUDA gaugepipe for HMC */ void qudaUpdateUPhasedPipeline(int precision, double eps, QudaMILCSiteArg_t *arg, int phase_in, int want_gaugepipe); @@ -863,8 +903,8 @@ extern "C" { * struct (QUDA_MILC_SITE_GAUGE_ORDER) or as a separate field * (QUDA_MILC_GAUGE_ORDER). * - * @param precision Precision of the field (2 - double, 1 - single) - * @param arg Metadata for MILC's internal site struct array + * @param[in] precision Precision of the field (2 - double, 1 - single) + * @param[in] arg Metadata for MILC's internal site struct array */ void qudaMomSave(int precision, QudaMILCSiteArg_t *arg); @@ -873,8 +913,8 @@ extern "C" { * action. MILC convention is applied, subtracting 4.0 from each * momentum matrix to increase stability. * - * @param precision Precision of the field (2 - double, 1 - single) - * @param arg Metadata for MILC's internal site struct array + * @param[in] precision Precision of the field (2 - double, 1 - single) + * @param[in] arg Metadata for MILC's internal site struct array * @return momentum action */ double qudaMomAction(int precision, QudaMILCSiteArg_t *arg); @@ -885,10 +925,10 @@ extern "C" { * exp(imu/T) will be applied to the links in the temporal * direction. 
 * - * @param prec Precision of the gauge field - * @param gauge_h The gauge field - * @param flag Whether to apply to remove the staggered phase - * @param i_mu Imaginary chemical potential + * @param[in] prec Precision of the gauge field + * @param[in,out] gauge The gauge field + * @param[in] flag Whether to apply or remove the staggered phase + * @param[in] i_mu Imaginary chemical potential */ void qudaRephase(int prec, void *gauge, int flag, double i_mu); @@ -896,9 +936,9 @@ extern "C" { * Project the input field on the SU(3) group. If the target * tolerance is not met, this routine will give a runtime error. * - * @param prec Precision of the gauge field - * @param tol The tolerance to which we iterate - * @param arg Metadata for MILC's internal site struct array + * @param[in] prec Precision of the gauge field + * @param[in] tol The tolerance to which we iterate + * @param[in] arg Metadata for MILC's internal site struct array */ void qudaUnitarizeSU3(int prec, double tol, QudaMILCSiteArg_t *arg); @@ -906,10 +946,10 @@ extern "C" { * Project the input field on the SU(3) group. If the target * tolerance is not met, this routine will give a runtime error. * - * @param prec Precision of the gauge field - * @param tol The tolerance to which we iterate - * @param arg Metadata for MILC's internal site struct array - * @param phase_in whether staggered phases are applied + * @param[in] prec Precision of the gauge field + * @param[in] tol The tolerance to which we iterate + * @param[in] arg Metadata for MILC's internal site struct array + * @param[in] phase_in whether staggered phases are applied */ void qudaUnitarizeSU3Phased(int prec, double tol, QudaMILCSiteArg_t *arg, int phase_in); @@ -918,18 +958,18 @@ extern "C" { * the array solution fields, and compute the resulting momentum * field. 
* - * @param mom Momentum matrix - * @param dt Integrating step size - * @param x Array of solution vectors - * @param p Array of intermediate vectors - * @param coeff Array of residues for each contribution - * @param kappa kappa parameter - * @param ck -clover_coefficient * kappa / 8 - * @param nvec Number of vectors - * @param multiplicity Number of fermions represented by this bilinear - * @param gauge Gauge Field - * @param precision Precision of the fields - * @param inv_args Struct setting some solver metadata + * @param[in] mom Momentum matrix + * @param[in] dt Integrating step size + * @param[out] x Array of solution vectors + * @param[in] p Array of intermediate vectors + * @param[in] coeff Array of residues for each contribution + * @param[in] kappa kappa parameter + * @param[in] ck -clover_coefficient * kappa / 8 + * @param[in] nvec Number of vectors + * @param[in] multiplicity Number of fermions represented by this bilinear + * @param[in] gauge Gauge Field + * @param[in] precision Precision of the fields + * @param[in] inv_args Struct setting some solver metadata */ void qudaCloverForce(void *mom, double dt, void **x, void **p, double *coeff, double kappa, double ck, int nvec, double multiplicity, void *gauge, int precision, @@ -941,31 +981,30 @@ extern "C" { * precisions of all fields must match. This function requires that * there is a persistent clover field. * - * @param out Sigma trace field (QUDA device field, geometry = 1) - * @param dummy (not used) - * @param mu mu direction - * @param nu nu direction + * @param[out] out Sigma trace field (QUDA device field, geometry = 1) + * @param[in] dummy (not used) + * @param[in] mu mu direction + * @param[in] nu nu direction */ void qudaCloverTrace(void* out, void* dummy, int mu, int nu); - /** * Compute the derivative of the clover term (part of clover force * computation). All the pointers here are for QUDA native device * objects. The precisions of all fields must match. 
 * - * @param out Clover derivative field (QUDA device field, geometry = 1) - * @param gauge Gauge field (extended QUDA device field, gemoetry = 4) - * @param oprod Matrix field (outer product) which is multiplied by the derivative - * @param mu mu direction - * @param nu nu direction - * @param coeff Coefficient of the clover derviative (including stepsize and clover coefficient) - * @param precision Precision of the fields (2 = double, 1 = single) - * @param parity Parity for which we are computing - * @param conjugate Whether to make the oprod field anti-hermitian prior to multiplication + * @param[out] out Clover derivative field (QUDA device field, geometry = 1) + * @param[in] gauge Gauge field (extended QUDA device field, geometry = 4) + * @param[in] oprod Matrix field (outer product) which is multiplied by the derivative + * @param[in] mu mu direction + * @param[in] nu nu direction + * @param[in] coeff Coefficient of the clover derivative (including stepsize and clover coefficient) + * @param[in] precision Precision of the fields (2 = double, 1 = single) + * @param[in] parity Parity for which we are computing + * @param[in] conjugate Whether to make the oprod field anti-hermitian prior to multiplication */ void qudaCloverDerivative(void* out, void* gauge, @@ -977,14 +1016,13 @@ extern "C" { int parity, int conjugate); - /** * Take a gauge field on the host, load it onto the device and extend it. * Return a pointer to the extended gauge field object. 
 * - * @param gauge The CPU gauge field (optional - if set to 0 then the gauge field zeroed) - * @param geometry The geometry of the matrix field to create (1 - scaler, 4 - vector, 6 - tensor) - * @param precision The precision of the fields (2 - double, 1 - single) + * @param[in] gauge The CPU gauge field (optional - if set to 0 then the gauge field is zeroed) + * @param[in] geometry The geometry of the matrix field to create (1 - scalar, 4 - vector, 6 - tensor) + * @param[in] precision The precision of the fields (2 - double, 1 - single) * @return Pointer to the gauge field (cast as a void*) */ void* qudaCreateExtendedGaugeField(void* gauge, @@ -995,9 +1033,9 @@ extern "C" { * Take the QUDA resident gauge field and extend it. * Return a pointer to the extended gauge field object. * - * @param gauge The CPU gauge field (optional - if set to 0 then the gauge field zeroed) - * @param geometry The geometry of the matrix field to create (1 - scaler, 4 - vector, 6 - tensor) - * @param precision The precision of the fields (2 - double, 1 - single) + * @param[in] gauge The CPU gauge field (optional - if set to 0 then the gauge field is zeroed) + * @param[in] geometry The geometry of the matrix field to create (1 - scalar, 4 - vector, 6 - tensor) + * @param[in] precision The precision of the fields (2 - double, 1 - single) * @return Pointer to the gauge field (cast as a void*) */ void* qudaResidentExtendedGaugeField(void* gauge, @@ -1007,9 +1045,9 @@ extern "C" { /** * Allocate a gauge (matrix) field on the device and optionally download a host gauge field. 
* - * @param gauge The host gauge field (optional - if set to 0 then the gauge field zeroed) - * @param geometry The geometry of the matrix field to create (1 - scaler, 4 - vector, 6 - tensor) - * @param precision The precision of the field to be created (2 - double, 1 - single) + * @param[in] gauge The host gauge field (optional - if set to 0 then the gauge field zeroed) + * @param[in] geometry The geometry of the matrix field to create (1 - scaler, 4 - vector, 6 - tensor) + * @param[in] precision The precision of the field to be created (2 - double, 1 - single) * @return Pointer to the gauge field (cast as a void*) */ void* qudaCreateGaugeField(void* gauge, @@ -1019,8 +1057,8 @@ extern "C" { /** * Copy the QUDA gauge (matrix) field on the device to the CPU * - * @param outGauge Pointer to the host gauge field - * @param inGauge Pointer to the device gauge field (QUDA device field) + * @param[out] outGauge Pointer to the host gauge field + * @param[in] inGauge Pointer to the device gauge field (QUDA device field) */ void qudaSaveGaugeField(void* gauge, void* inGauge); @@ -1028,7 +1066,7 @@ extern "C" { /** * Reinterpret gauge as a pointer to a GaugeField and call destructor. * - * @param gauge Gauge field to be freed + * @param[in] gauge Gauge field to be freed */ void qudaDestroyGaugeField(void* gauge); @@ -1073,6 +1111,22 @@ extern "C" { void* milc_sitelink ); + /** + * @brief Tie together two staggered propagators including spatial Fourier phases. + * The result is summed separately over each time slice and across all MPI ranks. 
 + * The FT is defined by a list of momentum indices (three-component integer vectors) + * Included with the FT is a parity (symmetry) parameter for each momentum + * component that selects an exp, cos, or sin factor for each direction + * + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in,out] cont_args parameters for the contraction, including FT specification + * @param[in] quark1 local storage of color spinor field. three complex values * number of sites on node + * @param[in] quark2 local storage of color spinor field. three complex values * number of sites on node + * @param[out] corr hadron correlator. Flattened double array as though [n_mom][nt][2] for 2 = re,im. + */ + void qudaContractFT(int external_precision, QudaContractArgs_t *cont_args, void *const quark1, void *const quark2, + double *corr); + /** * @brief Perform two-link Gaussian smearing on a given spinor (for staggered fermions). * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) @@ -1083,7 +1137,7 @@ extern "C" { */ void qudaTwoLinkGaussianSmear(int external_precision, int quda_precision, void * h_gauge, void * source, QudaTwoLinkQuarkSmearArgs_t qsmear_args); - + /* The below declarations are for removed functions from prior versions of QUDA. */ /** diff --git a/include/spin_taste.h b/include/spin_taste.h new file mode 100644 index 0000000000..009c3c0456 --- /dev/null +++ b/include/spin_taste.h @@ -0,0 +1,25 @@ +#pragma once +#include + +namespace quda +{ + + /** + @brief Compute the outer-product field between the staggered quark + field's one and (for HISQ and ASQTAD) three hop sites. 
E.g., + + out[0][d](x) = (in(x+1_d) x conj(in(x))) + out[1][d](x) = (in(x+3_d) x conj(in(x))) + + where 1_d and 3_d represent a relative shift of magnitude 1 and 3 in dimension d, respectively + + Note out[1] is only computed if nFace=3 + + @param[out] out Array of nFace outer-product matrix fields + @param[in] in Input quark field + @param[in] coeff Coefficient + @param[in] nFace Number of faces (1 or 3) + */ + void applySpinTaste(ColorSpinorField &out, const ColorSpinorField &in, QudaSpinTasteGamma gamma); + +} // namespace quda diff --git a/include/vector_io.h b/include/vector_io.h index 8e3a1e6ee9..8ed287f9d6 100644 --- a/include/vector_io.h +++ b/include/vector_io.h @@ -40,7 +40,6 @@ namespace quda @param[in] size Optional cap to number of vectors saved */ void save(cvector_ref &vecs, QudaPrecision prec = QUDA_INVALID_PRECISION, uint32_t size = 0); - }; } // namespace quda diff --git a/jenkins/bqcd.config.cmake b/jenkins/bqcd.config.cmake index 5f3d29da4a..9c3df527a2 100644 --- a/jenkins/bqcd.config.cmake +++ b/jenkins/bqcd.config.cmake @@ -2,20 +2,12 @@ # MILC - turns on staggered dirac and all HISQ and gauge features for MILC RHMC - set(QUDA_DIRAC_WILSON ON CACHE BOOL "build Wilson Dirac operators") set(QUDA_DIRAC_CLOVER ON CACHE BOOL "build clover Dirac operators") set(QUDA_DIRAC_DOMAIN_WALL OFF CACHE BOOL "build domain wall Dirac operators") set(QUDA_DIRAC_STAGGERED OFF CACHE BOOL "build staggered Dirac operators") set(QUDA_DIRAC_TWISTED_MASS OFF CACHE BOOL "build twisted mass Dirac operators") set(QUDA_DIRAC_TWISTED_CLOVER OFF CACHE BOOL "build twisted clover Dirac operators") -set(QUDA_DIRAC_NDEG_TWISTED_MASS OFF CACHE BOOL "build non-degenerate twisted mass Dirac operators") -set(QUDA_FORCE_GAUGE OFF CACHE BOOL "build code for (1-loop Symanzik) gauge force") -set(QUDA_FORCE_ASQTAD OFF CACHE BOOL "build code for asqtad fermion force") -set(QUDA_FORCE_HISQ OFF CACHE BOOL "build code for hisq fermion force") -set(QUDA_GAUGE_TOOLS OFF CACHE BOOL "build 
auxiliary gauge-field tools") -set(QUDA_GAUGE_ALG OFF CACHE BOOL "build gauge-fixing and pure-gauge algorithms") -set(QUDA_CONTRACT OFF CACHE BOOL "build code for bilinear contraction") set(QUDA_DYNAMIC_CLOVER OFF CACHE BOOL "Dynamically invert the clover term for twisted-clover") set(QUDA_QIO OFF CACHE BOOL "build QIO code for binary I/O") @@ -28,4 +20,4 @@ set(QUDA_INTERFACE_MILC OFF CACHE BOOL "build milc interface") set(QUDA_INTERFACE_CPS OFF CACHE BOOL "build cps interface") set(QUDA_INTERFACE_QDPJIT OFF CACHE BOOL "build qdpjit interface") set(QUDA_INTERFACE_BQCD ON CACHE BOOL "build bqcd interface") -set(QUDA_INTERFACE_TIFR OFF CACHE BOOL "build tifr interface") \ No newline at end of file +set(QUDA_INTERFACE_TIFR OFF CACHE BOOL "build tifr interface") diff --git a/jenkins/milc.config.cmake b/jenkins/milc.config.cmake index aae720c0ab..7d9d3c9ce0 100644 --- a/jenkins/milc.config.cmake +++ b/jenkins/milc.config.cmake @@ -2,22 +2,13 @@ # MILC - turns on staggered dirac and all HISQ and gauge features for MILC RHMC - set(QUDA_DIRAC_WILSON OFF CACHE BOOL "build Wilson Dirac operators") set(QUDA_DIRAC_CLOVER OFF CACHE BOOL "build clover Dirac operators") set(QUDA_DIRAC_DOMAIN_WALL OFF CACHE BOOL "build domain wall Dirac operators") set(QUDA_DIRAC_STAGGERED ON CACHE BOOL "build staggered Dirac operators") set(QUDA_DIRAC_TWISTED_MASS OFF CACHE BOOL "build twisted mass Dirac operators") set(QUDA_DIRAC_TWISTED_CLOVER OFF CACHE BOOL "build twisted clover Dirac operators") -set(QUDA_DIRAC_NDEG_TWISTED_MASS OFF CACHE BOOL "build non-degenerate twisted mass Dirac operators") -set(QUDA_FORCE_GAUGE ON CACHE BOOL "build code for (1-loop Symanzik) gauge force") -set(QUDA_FORCE_ASQTAD OFF CACHE BOOL "build code for asqtad fermion force") -set(QUDA_FORCE_HISQ ON CACHE BOOL "build code for hisq fermion force") -set(QUDA_GAUGE_TOOLS OFF CACHE BOOL "build auxiliary gauge-field tools") -set(QUDA_GAUGE_ALG OFF CACHE BOOL "build gauge-fixing and pure-gauge algorithms") 
-set(QUDA_CONTRACT OFF CACHE BOOL "build code for bilinear contraction") set(QUDA_DYNAMIC_CLOVER OFF CACHE BOOL "Dynamically invert the clover term for twisted-clover") set(QUDA_QIO OFF CACHE BOOL "build QIO code for binary I/O") - -set(QUDA_MULTIGRID OFF CACHE BOOL "build multigrid solvers") \ No newline at end of file +set(QUDA_MULTIGRID OFF CACHE BOOL "build multigrid solvers") diff --git a/jenkins/twistedmass.config.cmake b/jenkins/twistedmass.config.cmake index 5f3d29da4a..9c3df527a2 100644 --- a/jenkins/twistedmass.config.cmake +++ b/jenkins/twistedmass.config.cmake @@ -2,20 +2,12 @@ # MILC - turns on staggered dirac and all HISQ and gauge features for MILC RHMC - set(QUDA_DIRAC_WILSON ON CACHE BOOL "build Wilson Dirac operators") set(QUDA_DIRAC_CLOVER ON CACHE BOOL "build clover Dirac operators") set(QUDA_DIRAC_DOMAIN_WALL OFF CACHE BOOL "build domain wall Dirac operators") set(QUDA_DIRAC_STAGGERED OFF CACHE BOOL "build staggered Dirac operators") set(QUDA_DIRAC_TWISTED_MASS OFF CACHE BOOL "build twisted mass Dirac operators") set(QUDA_DIRAC_TWISTED_CLOVER OFF CACHE BOOL "build twisted clover Dirac operators") -set(QUDA_DIRAC_NDEG_TWISTED_MASS OFF CACHE BOOL "build non-degenerate twisted mass Dirac operators") -set(QUDA_FORCE_GAUGE OFF CACHE BOOL "build code for (1-loop Symanzik) gauge force") -set(QUDA_FORCE_ASQTAD OFF CACHE BOOL "build code for asqtad fermion force") -set(QUDA_FORCE_HISQ OFF CACHE BOOL "build code for hisq fermion force") -set(QUDA_GAUGE_TOOLS OFF CACHE BOOL "build auxiliary gauge-field tools") -set(QUDA_GAUGE_ALG OFF CACHE BOOL "build gauge-fixing and pure-gauge algorithms") -set(QUDA_CONTRACT OFF CACHE BOOL "build code for bilinear contraction") set(QUDA_DYNAMIC_CLOVER OFF CACHE BOOL "Dynamically invert the clover term for twisted-clover") set(QUDA_QIO OFF CACHE BOOL "build QIO code for binary I/O") @@ -28,4 +20,4 @@ set(QUDA_INTERFACE_MILC OFF CACHE BOOL "build milc interface") set(QUDA_INTERFACE_CPS OFF CACHE BOOL "build cps 
interface") set(QUDA_INTERFACE_QDPJIT OFF CACHE BOOL "build qdpjit interface") set(QUDA_INTERFACE_BQCD ON CACHE BOOL "build bqcd interface") -set(QUDA_INTERFACE_TIFR OFF CACHE BOOL "build tifr interface") \ No newline at end of file +set(QUDA_INTERFACE_TIFR OFF CACHE BOOL "build tifr interface") diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index d3c1fece0e..a4525bbd7f 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -52,7 +52,7 @@ set (QUDA_OBJS madwf_transfer.cu madwf_tensor.cu blas_quda.cu multi_blas_quda.cu reduce_quda.cu multi_reduce_quda.cu reduce_helper.cu - contract.cu comm_common.cpp communicator_stack.cpp + contract.cu spin_taste.cu comm_common.cpp communicator_stack.cpp clover_force.cpp clover_deriv_quda.cu clover_invert.cu copy_gauge_extended.cu extract_gauge_ghost_extended.cu copy_color_spinor.cpp diff --git a/lib/contract.cu b/lib/contract.cu index 74206419c6..15a9072caa 100644 --- a/lib/contract.cu +++ b/lib/contract.cu @@ -1,11 +1,132 @@ #include #include + #include +#include #include #include namespace quda { + // Summed contraction type kernels. + template class ContractionSummed : TunableMultiReduction + { + protected: + const ColorSpinorField &x; + const ColorSpinorField &y; + std::vector &result_global; + const QudaContractType cType; + const int *const source_position; + const int *const mom_mode; + const QudaFFTSymmType *const fft_type; + const size_t s1; + const size_t b1; + + public: + ContractionSummed(const ColorSpinorField &x, const ColorSpinorField &y, std::vector &result_global, + const QudaContractType cType, const int *const source_position, const int *const mom_mode, + const QudaFFTSymmType *const fft_type, const size_t s1, const size_t b1) : + TunableMultiReduction( + x, 1u, x.X()[cType == QUDA_CONTRACT_TYPE_DR_FT_Z || cType == QUDA_CONTRACT_TYPE_OPEN_SUM_Z ? 
2 : 3]), + x(x), + y(y), + result_global(result_global), + cType(cType), + source_position(source_position), + mom_mode(mom_mode), + fft_type(fft_type), + s1(s1), + b1(b1) + { + switch (cType) { + case QUDA_CONTRACT_TYPE_OPEN_SUM_T: strcat(aux, "open-sum-t,"); break; + case QUDA_CONTRACT_TYPE_OPEN_SUM_Z: strcat(aux, "open-sum-z,"); break; + case QUDA_CONTRACT_TYPE_DR_FT_T: strcat(aux, "degrand-rossi-ft-t,"); break; + case QUDA_CONTRACT_TYPE_DR_FT_Z: strcat(aux, "degrand-rossi-ft-z,"); break; + case QUDA_CONTRACT_TYPE_STAGGERED_FT_T: strcat(aux, "staggered-ft-t,"); break; + default: errorQuda("Unexpected contraction type %d", cType); + } + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + + int reduction_dim = 3; + const int nSpinSq = x.Nspin() * x.Nspin(); + + if (cType == QUDA_CONTRACT_TYPE_DR_FT_Z) reduction_dim = 2; + std::vector result_local(2 * nSpinSq * x.X()[reduction_dim], 0.0); + + // Pass the integer value of the redection dim as a template arg + switch (cType) { + case QUDA_CONTRACT_TYPE_DR_FT_T: { + constexpr int nSpin = 4; + constexpr int ft_dir = 3; + ContractionSummedArg arg(x, y, source_position, mom_mode, fft_type, s1, b1); + launch(result_local, tp, stream, arg); + } break; + case QUDA_CONTRACT_TYPE_DR_FT_Z: { + constexpr int nSpin = 4; + constexpr int ft_dir = 2; + ContractionSummedArg arg(x, y, source_position, mom_mode, fft_type, s1, b1); + launch(result_local, tp, stream, arg); + } break; + case QUDA_CONTRACT_TYPE_STAGGERED_FT_T: { + constexpr int nSpin = 1; + constexpr int ft_dir = 3; + ContractionSummedArg arg(x, y, source_position, mom_mode, + fft_type, s1, b1); + launch(result_local, tp, stream, arg); + } break; + default: errorQuda("Unexpected contraction type %d", cType); + } + + // Copy results back to host array + if (!activeTuning()) { + for (int i = 0; i < nSpinSq * x.X()[reduction_dim]; i++) { + result_global[nSpinSq * 
x.X()[reduction_dim] * comm_coord(reduction_dim) + i].real(result_local[2 * i]); + result_global[nSpinSq * x.X()[reduction_dim] * comm_coord(reduction_dim) + i].imag(result_local[2 * i + 1]); + } + } + } + + long long flops() const // DMH: Restore const qualifier for warning suppression + { + return ((x.Nspin() * x.Nspin() * x.Ncolor() * 6ll) + (x.Nspin() * x.Nspin() * (x.Nspin() + x.Nspin() * x.Ncolor()))) + * x.Volume(); + } + + long long bytes() const + { + return x.Bytes() + y.Bytes() + x.Nspin() * x.Nspin() * x.Volume() * sizeof(complex); + } + }; + + void contractSummedQuda(const ColorSpinorField &x, const ColorSpinorField &y, std::vector &result_global, + const QudaContractType cType, const int *const source_position, const int *const mom_mode, + const QudaFFTSymmType *const fft_type, const size_t s1, const size_t b1) + { + checkPrecision(x, y); + if (x.Nspin() != y.Nspin()) + errorQuda("Contraction between unequal number of spins x=%d y=%d", x.Nspin(), y.Nspin()); + if (x.Ncolor() != y.Ncolor()) + errorQuda("Contraction between unequal number of colors x=%d y=%d", x.Ncolor(), y.Ncolor()); + if (cType != QUDA_CONTRACT_TYPE_STAGGERED_FT_T) { + if (x.Nspin() != 4 || y.Nspin() != 4) errorQuda("Expected four-spinors x=%d y=%d", x.Nspin(), y.Nspin()); + if (x.GammaBasis() != y.GammaBasis()) + errorQuda("Contracting spinors in different gamma bases x=%d y=%d", x.GammaBasis(), y.GammaBasis()); + } + if (cType == QUDA_CONTRACT_TYPE_DR_FT_T || cType == QUDA_CONTRACT_TYPE_DR_FT_Z) { + if (x.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS || y.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS) + errorQuda("Unexpected gamma basis x=%d y=%d", x.GammaBasis(), y.GammaBasis()); + } + if (x.Ncolor() != 3 || y.Ncolor() != 3) errorQuda("Unexpected number of colors x=%d y=%d", x.Ncolor(), y.Ncolor()); + + instantiate(x, y, result_global, cType, source_position, mom_mode, fft_type, s1, b1); + } + template class Contraction : TunableKernel2D { complex *result; @@ -25,6 +146,7 @@ 
public: switch (cType) { case QUDA_CONTRACT_TYPE_OPEN: strcat(aux, "open,"); break; case QUDA_CONTRACT_TYPE_DR: strcat(aux, "degrand-rossi,"); break; + case QUDA_CONTRACT_TYPE_STAGGERED: strcat(aux, "staggered,"); break; default: errorQuda("Unexpected contraction type %d", cType); } apply(device::get_default_stream()); @@ -33,10 +155,26 @@ public: void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - ContractionArg arg(x, y, result); + switch (cType) { - case QUDA_CONTRACT_TYPE_OPEN: launch(tp, stream, arg); break; - case QUDA_CONTRACT_TYPE_DR: launch(tp, stream, arg); break; + case QUDA_CONTRACT_TYPE_OPEN: + case QUDA_CONTRACT_TYPE_DR: { + constexpr int nSpin = 4; + constexpr bool spin_project = true; + ContractionArg arg(x, y, result); + switch (cType) { + case QUDA_CONTRACT_TYPE_OPEN: launch(tp, stream, arg); break; + case QUDA_CONTRACT_TYPE_DR: launch(tp, stream, arg); break; + default: errorQuda("Unexpected contraction type %d", cType); + } + }; break; + case QUDA_CONTRACT_TYPE_STAGGERED: { + constexpr int nSpin = 1; + constexpr bool spin_project = false; + ContractionArg arg(x, y, result); + + launch(tp, stream, arg); + }; break; default: errorQuda("Unexpected contraction type %d", cType); } } @@ -44,9 +182,11 @@ public: long long flops() const { if (cType == QUDA_CONTRACT_TYPE_OPEN) - return 16 * 3 * 6ll * x.Volume(); + return x.Nspin() * x.Nspin() * x.Ncolor() * 6ll * x.Volume(); else - return ((16 * 3 * 6ll) + (16 * (4 + 12))) * x.Volume(); + return ((x.Nspin() * x.Nspin() * x.Ncolor() * 6ll) + + (x.Nspin() * x.Nspin() * (x.Nspin() + x.Nspin() * x.Ncolor()))) + * x.Volume(); } long long bytes() const @@ -55,23 +195,26 @@ public: } }; -#ifdef GPU_CONTRACT void contractQuda(const ColorSpinorField &x, const ColorSpinorField &y, void *result, const QudaContractType cType) { getProfile().TPSTART(QUDA_PROFILE_COMPUTE); checkPrecision(x, y); - if (x.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS || 
y.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS) - errorQuda("Unexpected gamma basis x=%d y=%d", x.GammaBasis(), y.GammaBasis()); - if (x.Nspin() != 4 || y.Nspin() != 4) errorQuda("Unexpected number of spins x=%d y=%d", x.Nspin(), y.Nspin()); + if (x.Nspin() != y.Nspin()) + errorQuda("Contraction between unequal number of spins x=%d y=%d", x.Nspin(), y.Nspin()); + if (x.Ncolor() != y.Ncolor()) + errorQuda("Contraction between unequal number of colors x=%d y=%d", x.Ncolor(), y.Ncolor()); + if (cType == QUDA_CONTRACT_TYPE_OPEN || cType == QUDA_CONTRACT_TYPE_DR) { + if (x.Nspin() != 4 || y.Nspin() != 4) errorQuda("Expected four-spinors x=%d y=%d", x.Nspin(), y.Nspin()); + if (x.GammaBasis() != y.GammaBasis()) + errorQuda("Contracting spinors in different gamma bases x=%d y=%d", x.GammaBasis(), y.GammaBasis()); + } + if (cType == QUDA_CONTRACT_TYPE_DR) { + if (x.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS || y.GammaBasis() != QUDA_DEGRAND_ROSSI_GAMMA_BASIS) + errorQuda("Unexpected gamma basis x=%d y=%d", x.GammaBasis(), y.GammaBasis()); + } instantiate(x, y, result, cType); getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); } -#else - void contractQuda(const ColorSpinorField &, const ColorSpinorField &, void *, const QudaContractType) - { - errorQuda("Contraction code has not been built"); - } -#endif } // namespace quda diff --git a/lib/covariant_derivative.cu b/lib/covariant_derivative.cu index 000c70945c..501448789f 100644 --- a/lib/covariant_derivative.cu +++ b/lib/covariant_derivative.cu @@ -42,7 +42,7 @@ namespace quda constexpr bool xpay = false; constexpr int nParity = 2; - Dslash::template instantiate(tp, stream); + Dslash::template instantiate(tp, stream); } long long flops() const override @@ -141,9 +141,17 @@ namespace quda { constexpr int nDim = 4; auto halo = ColorSpinorField::create_comms_batch(in); - CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); - CovDev covDev(arg, out, in, halo); - dslash::DslashPolicyTune policy(covDev, in, halo, 
profile); + if (in.Nspin() == 4) { + CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); + CovDev covDev(arg, out, in, halo); + dslash::DslashPolicyTune policy(covDev, in, halo, profile); + } else if (in.Nspin() == 1) { + CovDevArg arg(out, in, halo, U, mu, parity, dagger, comm_override); + CovDev covDev(arg, out, in, halo); + dslash::DslashPolicyTune policy(covDev, in, halo, profile); + } else { + errorQuda("Spin not supported"); + } } }; diff --git a/lib/dslash_gamma_helper.cu b/lib/dslash_gamma_helper.cu index 214a293613..2824776915 100644 --- a/lib/dslash_gamma_helper.cu +++ b/lib/dslash_gamma_helper.cu @@ -11,11 +11,12 @@ namespace quda { cvector_ref &out; cvector_ref ∈ const int d; + const int proj; unsigned int minThreads() const { return in.VolumeCB() / (in.Ndim() == 5 ? in.X(4) : 1); } public: - GammaApply(cvector_ref &out, cvector_ref &in, int d) : - TunableKernel3D(in[0], in.size(), in.SiteSubset()), out(out), in(in), d(d) + GammaApply(cvector_ref &out, cvector_ref &in, int d, int proj = 0) : + TunableKernel3D(in[0], in.size(), in.SiteSubset()), out(out), in(in), d(d), proj(proj) { setRHSstring(aux, in.size()); apply(device::get_default_stream()); @@ -23,7 +24,10 @@ namespace quda { void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - launch(tp, stream, GammaArg(out, in, d)); + if (proj == 0) + launch(tp, stream, GammaArg(out, in, d)); + else + launch(tp, stream, GammaArg(out, in, d, proj)); } void preTune() { out.backup(); } @@ -39,6 +43,16 @@ namespace quda { instantiate_recurse2(out, in, d); } + // Applies out(x) = 1/2 * [(1 +/- gamma5) * in] + out + void ApplyChiralProj(cvector_ref &out, cvector_ref &in, const int proj) + { + checkPrecision(out, in); // check all precisions match + checkLocation(out, in); // check all locations match + // Launch with 4 as the gamma matrix arg to stop the constructor from erroring out, + // but this parameter is not used for chiral projection. 
+ instantiate(out, in, 4, proj); + } + template class TwistGammaApply : public TunableKernel3D { cvector_ref &out; diff --git a/lib/dslash_pack2.cu b/lib/dslash_pack2.cu index 2e38ce4417..1db387df98 100644 --- a/lib/dslash_pack2.cu +++ b/lib/dslash_pack2.cu @@ -1,7 +1,7 @@ #include // STRIPED - spread the blocks throughout the workload to ensure we -// work on all directions/dimensions simultanesouly to maximize NVLink saturation +// work on all directions/dimensions simultaneously to maximize NVLink saturation // if not STRIPED then this means we assign one thread block per direction / dimension // currently does not work with NVSHMEM #ifndef NVSHMEM_COMMS diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index dc9f4f6887..0b1bfe5473 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -96,6 +97,9 @@ GaugeField *gaugeLongExtended = nullptr; GaugeField *gaugeSmeared = nullptr; +// Holds the Two Link gauge +GaugeField *gaugeTwoLink = nullptr; + CloverField *cloverPrecise = nullptr; CloverField *cloverSloppy = nullptr; CloverField *cloverPrecondition = nullptr; @@ -199,6 +203,9 @@ static TimeProfile profilePhase("staggeredPhaseQuda"); //!< Profiler for contractions static TimeProfile profileContract("contractQuda"); +//!< Profiler for FT contractions +static TimeProfile profileContractFT("contractFTQuda"); + //!< Profiler for GEMM and other BLAS static TimeProfile profileBLAS("blasQuda"); TimeProfile &getProfileBLAS() { return profileBLAS; } @@ -992,6 +999,7 @@ void freeGaugeQuda(void) freeUniqueGaugeQuda(QUDA_ASQTAD_FAT_LINKS); freeUniqueGaugeQuda(QUDA_ASQTAD_LONG_LINKS); freeUniqueGaugeQuda(QUDA_SMEARED_LINKS); + freeUniqueGaugeQuda(QUDA_TWOLINK_LINKS); // Need to merge extendedGaugeResident and gaugeFatPrecise/gaugePrecise if (extendedGaugeResident) { @@ -1063,6 +1071,10 @@ void freeUniqueGaugeQuda(QudaLinkType link_type) if (gaugeSmeared) delete 
gaugeSmeared; gaugeSmeared = nullptr; break; + case QUDA_TWOLINK_LINKS: + if (gaugeTwoLink) delete gaugeTwoLink; + gaugeTwoLink = nullptr; + break; default: errorQuda("Invalid gauge type %d", link_type); } } @@ -1073,6 +1085,12 @@ void freeGaugeSmearedQuda() freeUniqueGaugeQuda(QUDA_SMEARED_LINKS); } +void freeGaugeTwoLinkQuda() +{ + // thin wrapper + freeUniqueGaugeQuda(QUDA_TWOLINK_LINKS); +} + void loadSloppyGaugeQuda(const QudaPrecision *prec, const QudaReconstructType *recon) { // first do SU3 links (if they exist) @@ -1384,6 +1402,7 @@ void endQuda(void) profileStaggeredForce.Print(); profileHISQForce.Print(); profileContract.Print(); + profileContractFT.Print(); profileBLAS.Print(); profileCovDev.Print(); profilePlaq.Print(); @@ -1874,6 +1893,446 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, QudaParity popVerbosity(); } +void shiftQuda(void *h_out, void *h_in, int dir, int sym, QudaInvertParam *param) +{ + auto profile = pushProfile(profileCovDev, param->secs, param->gflops); + + const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? 
*gaugePrecise : *gaugeFatPrecise; + + QudaInvertParam &inv_param = *param; + + inv_param.solution_type = QUDA_MAT_SOLUTION; + inv_param.dirac_order = QUDA_DIRAC_ORDER; + + if (!gaugePrecise) errorQuda("Gauge field not allocated"); + + pushVerbosity(inv_param.verbosity); + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(&inv_param); + + ColorSpinorParam cpuParam(h_in, inv_param, gauge.X(), false, inv_param.input_location); + ColorSpinorField in_h(cpuParam); + ColorSpinorParam cudaParam(cpuParam, inv_param, QUDA_CUDA_FIELD_LOCATION); + + cpuParam.v = h_out; + cpuParam.location = inv_param.output_location; + ColorSpinorField out_h(cpuParam); + + cudaParam.create = QUDA_NULL_FIELD_CREATE; + ColorSpinorField in(cudaParam); + in = in_h; + ColorSpinorField out(cudaParam); + out = in; + ColorSpinorField tmp(cudaParam); + tmp = in; + + profileCovDev.TPSTART(QUDA_PROFILE_COMPUTE); + + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + double cpu = blas::norm2(in_h); + double gpu = blas::norm2(in); + printfQuda("In CPU %e CUDA %e\n", cpu, gpu); + } + + inv_param.dslash_type = QUDA_COVDEV_DSLASH; // ensure we use the correct dslash + DiracParam diracParam; + setDiracParam(diracParam, &inv_param, false); + + GaugeCovDev myCovDev(diracParam); // create the Dirac operator + + if (sym & 1) { + myCovDev.MCD(out, in, dir); // apply the operator + } + if (sym & 2) { + myCovDev.MCD(tmp, in, dir + 4); // apply the operator + } + + quda::blas::xpy(tmp, out); + + if (sym == 3) quda::blas::ax(0.5, out); + + profileCovDev.TPSTOP(QUDA_PROFILE_COMPUTE); + + out_h = out; + + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + double cpu = blas::norm2(out_h); + double gpu = blas::norm2(out); + printfQuda("Out CPU %e CUDA %e\n", cpu, gpu); + } + + popVerbosity(); +} + +void spinTasteQuda(void *h_out, void *h_in, int spin_, int taste, QudaInvertParam *param) +{ + auto profile = pushProfile(profileCovDev, param->secs, param->gflops); + + const auto &gauge = *gaugePrecise; 
//(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; + + QudaInvertParam &inv_param = *param; + + inv_param.solution_type = QUDA_MAT_SOLUTION; + inv_param.dirac_order = QUDA_DIRAC_ORDER; + + if (!gaugePrecise) errorQuda("Gauge field not allocated"); + + pushVerbosity(inv_param.verbosity); + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(&inv_param); + + ColorSpinorParam cpuParam(h_in, inv_param, gauge.X(), false, inv_param.input_location); + ColorSpinorField in_h(cpuParam); + ColorSpinorParam cudaParam(cpuParam, inv_param, QUDA_CUDA_FIELD_LOCATION); + + cpuParam.v = h_out; + cpuParam.location = inv_param.output_location; + ColorSpinorField out_h(cpuParam); + + cudaParam.create = QUDA_NULL_FIELD_CREATE; + ColorSpinorField in(cudaParam); // cudaColorSpinorField + in = in_h; + cudaParam.create = QUDA_ZERO_FIELD_CREATE; // create new field and zero it + ColorSpinorField out(cudaParam); // cudaColorSpinorField = 0 + ColorSpinorField tmp(cudaParam); // cudaColorSpinorField = 0 + + profileCovDev.TPSTART(QUDA_PROFILE_COMPUTE); + + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + double cpu = blas::norm2(in_h); + double gpu = blas::norm2(in); + printfQuda("In CPU %e CUDA %e\n", cpu, gpu); + } + + inv_param.dslash_type = QUDA_COVDEV_DSLASH; // ensure we use the correct dslash + DiracParam diracParam; + setDiracParam(diracParam, &inv_param, false); + + GaugeCovDev myCovDev(diracParam); // create the Dirac operator + + int offset = spin_ ^ taste; + QudaSpinTasteGamma spin = (QudaSpinTasteGamma)spin_; + + constexpr QudaSpinTasteGamma gDirs[4] + = {QUDA_SPIN_TASTE_GX, QUDA_SPIN_TASTE_GY, QUDA_SPIN_TASTE_GZ, QUDA_SPIN_TASTE_GT}; + + switch (offset) { + + case 0: // local + { + applySpinTaste(tmp, in, spin); + applySpinTaste(out, tmp, QUDA_SPIN_TASTE_G5); // antiquark + break; + } + + case 1: // one-link X + case 2: // one-link Y + case 4: // one-link Z + case 8: // one-link T + { + int cDir = 0; + + if (offset == 1) { + cDir = 0; + } 
else if (offset == 2) { + cDir = 1; + } else if (offset == 4) { + cDir = 2; + } else if (offset == 8) { + cDir = 3; + } + + ColorSpinorField pr1(cudaParam); // cudaColorSpinorField = 0 + applySpinTaste(out, in, spin); + myCovDev.MCD(tmp, out, cDir); + myCovDev.MCD(pr1, out, cDir + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[cDir]); + applySpinTaste(out, pr1, QUDA_SPIN_TASTE_G5); + quda::blas::ax(0.5, out); + break; + } + + case 3: // two-link XY + case 6: // two-link YZ + case 5: // two-link ZX + case 9: // two-link XT + case 10: // two-link YT + case 12: // two-link ZT + { + int dirs[2]; + + { + if (offset == 3) { + dirs[0] = 0; + dirs[1] = 1; + } + if (offset == 6) { + dirs[0] = 1; + dirs[1] = 2; + } + if (offset == 5) { + dirs[0] = 2; + dirs[1] = 0; + } + if (offset == 9) { + dirs[0] = 0; + dirs[1] = 3; + } + if (offset == 10) { + dirs[0] = 1; + dirs[1] = 3; + } + if (offset == 12) { + dirs[0] = 2; + dirs[1] = 3; + } + } + + ColorSpinorField pr1(cudaParam); // cudaColorSpinorField = 0 + ColorSpinorField acc(cudaParam); // cudaColorSpinorField = 0 + + applySpinTaste(out, in, spin); + // YX result in acc + myCovDev.MCD(tmp, out, dirs[1]); + myCovDev.MCD(pr1, out, dirs[1] + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[dirs[1]]); + myCovDev.MCD(tmp, pr1, dirs[0]); + myCovDev.MCD(acc, pr1, dirs[0] + 4); + quda::blas::xpy(acc, tmp); + applySpinTaste(acc, tmp, gDirs[dirs[0]]); + // XY result in tmp + myCovDev.MCD(tmp, out, dirs[0]); + myCovDev.MCD(pr1, out, dirs[0] + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[dirs[0]]); + myCovDev.MCD(tmp, pr1, dirs[1]); + myCovDev.MCD(out, pr1, dirs[1] + 4); + quda::blas::xpy(tmp, out); + applySpinTaste(tmp, out, gDirs[dirs[1]]); + + quda::blas::mxpy(tmp, acc); + applySpinTaste(out, acc, QUDA_SPIN_TASTE_G5); + quda::blas::ax(0.125, out); + break; + } + + case 14: // three-link 5X + case 13: // three-link 5Y + case 11: // three-link 5Z + case 7: // three-link 5T + { + 
ColorSpinorField pr1(cudaParam); // cudaColorSpinorField = 0 + ColorSpinorField pr2(cudaParam); // cudaColorSpinorField = 0 + ColorSpinorField acc(cudaParam); // cudaColorSpinorField = 0 + + applySpinTaste(out, in, spin); + + int noDir = 0; + int dirs[3]; + + // quda::blas::ax(0.0, acc); + if (offset == 14) { + noDir = 0; + } else if (offset == 13) { + noDir = 1; + } else if (offset == 11) { + noDir = 2; + } else if (offset == 7) { + noDir = 3; + } + { + int j = 0; + for (int i = 0; i < 4; i++) { + if (i == noDir) continue; + dirs[j++] = i; + } + } + + for (int i = 0; i < 3; i++) { + + const int d1 = dirs[(i + 0) % 3]; + const int d2 = dirs[(i + 1) % 3]; + const int d3 = dirs[(i + 2) % 3]; + + // Accumulate result in acc + myCovDev.MCD(tmp, out, d1); + myCovDev.MCD(pr1, out, d1 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[d1]); + myCovDev.MCD(tmp, pr1, d2); + myCovDev.MCD(pr2, pr1, d2 + 4); + quda::blas::xpy(pr2, tmp); + applySpinTaste(pr2, tmp, gDirs[d2]); + myCovDev.MCD(tmp, pr2, d3); + myCovDev.MCD(pr1, pr2, d3 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[d3]); + quda::blas::xpy(pr1, acc); + + // Accumulate result in acc + myCovDev.MCD(tmp, out, d3); + myCovDev.MCD(pr1, out, d3 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[d3]); + myCovDev.MCD(tmp, pr1, d2); + myCovDev.MCD(pr2, pr1, d2 + 4); + quda::blas::xpy(pr2, tmp); + applySpinTaste(pr2, tmp, gDirs[d2]); + myCovDev.MCD(tmp, pr2, d1); + myCovDev.MCD(pr1, pr2, d1 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[d1]); + quda::blas::mxpy(pr1, acc); + } + + applySpinTaste(out, acc, QUDA_SPIN_TASTE_G5); + quda::blas::ax(0.125 / 6., out); + break; + } + + case 15: // four-link 5 + { + const int dPlus[12][4] = {{0, 1, 2, 3}, {1, 2, 0, 3}, {2, 0, 1, 3}, {0, 3, 1, 2}, {1, 3, 2, 0}, {2, 3, 0, 1}, + {3, 2, 1, 0}, {3, 0, 2, 1}, {3, 1, 0, 2}, {2, 1, 3, 0}, {0, 2, 3, 1}, {1, 0, 3, 2}}; + const int dMnus[12][4] = {{0, 2, 1, 3}, {1, 0, 2, 
3}, {2, 1, 0, 3}, {0, 3, 2, 1}, {1, 3, 0, 2}, {2, 3, 1, 0}, + {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 0, 1, 2}, {1, 2, 3, 0}, {2, 0, 3, 1}, {0, 1, 3, 2}}; + + ColorSpinorField pr1(cudaParam); // cudaColorSpinorField = 0 + ColorSpinorField pr2(cudaParam); // cudaColorSpinorField = 0 + ColorSpinorField acc(cudaParam); // cudaColorSpinorField = 0 + + applySpinTaste(out, in, spin); + + for (int i = 0; i < 12; i++) { + + const int d1 = dPlus[i][0]; + const int d2 = dPlus[i][1]; + const int d3 = dPlus[i][2]; + const int d4 = dPlus[i][3]; + + // Accumulate result in acc + myCovDev.MCD(tmp, out, d1); + myCovDev.MCD(pr1, out, d1 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[d1]); + myCovDev.MCD(tmp, pr1, d2); + myCovDev.MCD(pr2, pr1, d2 + 4); + quda::blas::xpy(pr2, tmp); + applySpinTaste(pr2, tmp, gDirs[d2]); + myCovDev.MCD(tmp, pr2, d3); + myCovDev.MCD(pr1, pr2, d3 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[d3]); + myCovDev.MCD(tmp, pr1, d4); + myCovDev.MCD(pr2, pr1, d4 + 4); + quda::blas::xpy(pr2, tmp); + applySpinTaste(pr2, tmp, gDirs[d4]); + quda::blas::xpy(pr2, acc); + + const int m1 = dMnus[i][0]; + const int m2 = dMnus[i][1]; + const int m3 = dMnus[i][2]; + const int m4 = dMnus[i][3]; + + // Accumulate result in acc + myCovDev.MCD(tmp, out, m1); + myCovDev.MCD(pr1, out, m1 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[m1]); + myCovDev.MCD(tmp, pr1, m2); + myCovDev.MCD(pr2, pr1, m2 + 4); + quda::blas::xpy(pr2, tmp); + applySpinTaste(pr2, tmp, gDirs[m2]); + myCovDev.MCD(tmp, pr2, m3); + myCovDev.MCD(pr1, pr2, m3 + 4); + quda::blas::xpy(pr1, tmp); + applySpinTaste(pr1, tmp, gDirs[m3]); + myCovDev.MCD(tmp, pr1, m4); + myCovDev.MCD(pr2, pr1, m4 + 4); + quda::blas::xpy(pr2, tmp); + applySpinTaste(pr2, tmp, gDirs[m4]); + quda::blas::mxpy(pr2, acc); + } + + applySpinTaste(out, acc, QUDA_SPIN_TASTE_G5); + quda::blas::ax(0.0625 / 24., out); + break; + } + } + + // FIXME: This is not exactly all covDev + 
profileCovDev.TPSTOP(QUDA_PROFILE_COMPUTE); + + out_h = out; + + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + double cpu = blas::norm2(out_h); + double gpu = blas::norm2(out); + printfQuda("Out CPU %e CUDA %e\n", cpu, gpu); + } + + popVerbosity(); +} + +void covDevQuda(void *h_out, void *h_in, int dir, QudaInvertParam *param) +{ + auto profile = pushProfile(profileCovDev, param->secs, param->gflops); + + QudaInvertParam &inv_param = *param; + + const auto &gauge = *gaugePrecise; //(inv_param->dslash_type != QUDA_ASQTAD_DSLASH) ? *gaugePrecise : *gaugeFatPrecise; + + inv_param.solution_type = QUDA_MAT_SOLUTION; + inv_param.dirac_order = QUDA_DIRAC_ORDER; + + // if ((!gaugePrecise && inv_param->dslash_type != QUDA_ASQTAD_DSLASH) + // || ((!gaugeFatPrecise || !gaugeLongPrecise) && inv_param->dslash_type == QUDA_ASQTAD_DSLASH)) + if (!gaugePrecise) errorQuda("Gauge field not allocated"); + + pushVerbosity(inv_param.verbosity); + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(&inv_param); + + ColorSpinorParam cpuParam(h_in, inv_param, gauge.X(), false, inv_param.input_location); + ColorSpinorField in_h(cpuParam); + ColorSpinorParam cudaParam(cpuParam, inv_param, QUDA_CUDA_FIELD_LOCATION); + + cpuParam.v = h_out; + cpuParam.location = inv_param.output_location; + ColorSpinorField out_h(cpuParam); + + cudaParam.create = QUDA_NULL_FIELD_CREATE; + ColorSpinorField in(cudaParam); // cudaColorSpinorField + in = in_h; + ColorSpinorField out(cudaParam); // cudaColorSpinorField + out = in; + + profileCovDev.TPSTART(QUDA_PROFILE_COMPUTE); + + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + double cpu = blas::norm2(in_h); + double gpu = blas::norm2(in); + printfQuda("In CPU %e CUDA %e\n", cpu, gpu); + } + + inv_param.dslash_type = QUDA_COVDEV_DSLASH; // ensure we use the correct dslash + DiracParam diracParam; + setDiracParam(diracParam, &inv_param, false); + + GaugeCovDev myCovDev(diracParam); // create the Dirac operator + myCovDev.MCD(out, in, dir); // apply the 
operator + profileCovDev.TPSTOP(QUDA_PROFILE_COMPUTE); + + out_h = out; + + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) { + double cpu = blas::norm2(out_h); + double gpu = blas::norm2(out); + printfQuda("Out CPU %e CUDA %e\n", cpu, gpu); + } + + popVerbosity(); +} + void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) { pushVerbosity(inv_param->verbosity); @@ -2353,7 +2812,7 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam multigrid_solver::multigrid_solver(QudaMultigridParam &mg_param) { QudaInvertParam *param = mg_param.invert_param; - // set whether we are going use native or generic blas + // set whether we are going to use native or generic blas blas_lapack::set_native(param->native_blas_lapack); checkMultigridParam(&mg_param); @@ -5021,14 +5480,13 @@ void performTwoLinkGaussianSmearNStep(void *h_in, QudaQuarkSmearParam *smear_par saveTuneCache(); } - void performGaugeSmearQuda(QudaGaugeSmearParam *smear_param, QudaGaugeObservableParam *obs_param) { auto profile = pushProfile(profileGaugeSmear); pushOutputPrefix("performGaugeSmearQuda: "); checkGaugeSmearParam(smear_param); - if (gaugePrecise == nullptr) errorQuda("Gauge field must be loaded"); + if (gaugePrecise == nullptr) errorQuda("Precise gauge field must be loaded"); freeUniqueGaugeQuda(QUDA_SMEARED_LINKS); gaugeSmeared = createExtendedGauge(*gaugePrecise, R, profileGaugeSmear); @@ -5202,6 +5660,101 @@ int computeGaugeFixingFFTQuda(void *gauge, const unsigned int gauge_dir, const u return 0; } +void contractFTQuda(void **prop_array_flavor_1, void **prop_array_flavor_2, void **result, const QudaContractType cType, + void *cs_param_ptr, const int src_colors, const int *X, const int *const source_position, + const int n_mom, const int *const mom_modes, const QudaFFTSymmType *const fft_type) +{ + auto profile = pushProfile(profileContractFT); + + // create ColorSpinorFields from void** and parameter + auto cs_param = (ColorSpinorParam *)cs_param_ptr; + const 
size_t nSpin = cs_param->nSpin; + const size_t src_nColor = src_colors; + cs_param->location = QUDA_CPU_FIELD_LOCATION; + cs_param->create = QUDA_REFERENCE_FIELD_CREATE; + + // The number of complex contraction results expected in the output + size_t num_out_results = nSpin * nSpin; + + // FIXME can we merge the two propagators if they are the same to save mem? + // wrap CPU host side pointers + std::vector h_prop1, h_prop2; + h_prop1.reserve(nSpin * src_nColor); + h_prop2.reserve(nSpin * src_nColor); + for (size_t i = 0; i < nSpin * src_nColor; i++) { + cs_param->v = prop_array_flavor_1[i]; + h_prop1.push_back(ColorSpinorField(*cs_param)); + cs_param->v = prop_array_flavor_2[i]; + h_prop2.push_back(ColorSpinorField(*cs_param)); + } + + // Create device spinor fields + ColorSpinorParam cudaParam(*cs_param); + cudaParam.create = QUDA_NULL_FIELD_CREATE; + cudaParam.location = QUDA_CUDA_FIELD_LOCATION; + cudaParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; // not relevant for staggered + cudaParam.setPrecision(cs_param->Precision(), cs_param->Precision(), true); + + std::vector d_prop1, d_prop2; + d_prop1.reserve(nSpin * src_nColor); + d_prop2.reserve(nSpin * src_nColor); + for (size_t i = 0; i < nSpin * src_nColor; i++) { + d_prop1.push_back(ColorSpinorField(cudaParam)); + d_prop2.push_back(ColorSpinorField(cudaParam)); + } + + // temporal or spatial correlator? + size_t corr_dim = 0, local_decay_dim_slices = 0; + if (cType == QUDA_CONTRACT_TYPE_DR_FT_Z) + corr_dim = 2; + else if (cType == QUDA_CONTRACT_TYPE_DR_FT_T || cType == QUDA_CONTRACT_TYPE_STAGGERED_FT_T) + corr_dim = 3; + else + errorQuda("Unsupported contraction type %d given", cType); + + // The number of slices in the decay dimension on this MPI rank. + local_decay_dim_slices = X[corr_dim]; + + // The number of slices in the decay dimension globally. 
+ size_t global_decay_dim_slices = local_decay_dim_slices * comm_dim(corr_dim); + + // Transfer data from host to device + for (size_t i = 0; i < nSpin * src_nColor; i++) { + d_prop1[i] = h_prop1[i]; + d_prop2[i] = h_prop2[i]; + } + + // Array for all decay slices and spins, is zeroed prior to kernel launch + std::vector result_global(global_decay_dim_slices * num_out_results); + + profileContractFT.TPSTART(QUDA_PROFILE_COMPUTE); + for (int mom_idx = 0; mom_idx < n_mom; ++mom_idx) { + + for (size_t s1 = 0; s1 < nSpin; s1++) { + for (size_t b1 = 0; b1 < nSpin; b1++) { + for (size_t c1 = 0; c1 < src_nColor; c1++) { + + std::fill(result_global.begin(), result_global.end(), 0.0); + contractSummedQuda(d_prop1[s1 * src_nColor + c1], d_prop2[b1 * src_nColor + c1], result_global, cType, + source_position, &mom_modes[4 * mom_idx], &fft_type[4 * mom_idx], s1, b1); + + comm_allreduce_sum(result_global); + for (size_t t = 0; t < global_decay_dim_slices; t++) { + for (size_t G_idx = 0; G_idx < num_out_results; G_idx++) { + int index = 2 * (global_decay_dim_slices * num_out_results * mom_idx + num_out_results * t + G_idx); + ((double *)*result)[index + 0] += result_global[num_out_results * t + G_idx].real(); + ((double *)*result)[index + 1] += result_global[num_out_results * t + G_idx].imag(); + } + } + } + } + } + } + profileContractFT.TPSTOP(QUDA_PROFILE_COMPUTE); + + saveTuneCache(); +} + void contractQuda(const void *hp_x, const void *hp_y, void *h_result, const QudaContractType cType, QudaInvertParam *param, const int *X) { diff --git a/lib/laplace.cu b/lib/laplace.cu index 0644626e18..aab6dd2d8e 100644 --- a/lib/laplace.cu +++ b/lib/laplace.cu @@ -160,6 +160,11 @@ namespace quda LaplaceArg arg(out, in, halo, U, dir, a, b, x, parity, dagger, comm_override); Laplace laplace(arg, out, in, halo); dslash::DslashPolicyTune policy(laplace, in, halo, profile); + } else if (in.Nspin() == 4) { + constexpr int nSpin = 4; + LaplaceArg arg(out, in, halo, U, dir, a, b, x, parity, 
dagger, comm_override); + Laplace laplace(arg, out, in, halo); + dslash::DslashPolicyTune policy(laplace, in, halo, profile); } else { errorQuda("Unsupported nSpin= %d", in.Nspin()); } diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp index 43fe9e865d..131c57fde0 100644 --- a/lib/milc_interface.cpp +++ b/lib/milc_interface.cpp @@ -1259,15 +1259,109 @@ void qudaDslash(int external_precision, int quda_precision, QudaInvertArgs_t inv int src_offset = getColorVectorOffset(other_parity, false, localDim); int dst_offset = getColorVectorOffset(local_parity, false, localDim); - dslashQuda(static_cast(dst) + dst_offset*host_precision, - static_cast(src) + src_offset*host_precision, - &invertParam, local_parity); + dslashQuda(static_cast(dst) + dst_offset * host_precision, + static_cast(src) + src_offset * host_precision, &invertParam, local_parity); if (!create_quda_gauge) invalidateGaugeQuda(); qudamilc_called(__func__, verbosity); } // qudaDslash +void qudaShift(int external_precision, int quda_precision, const void *const links, void *src, void *dst, int dir, + int sym, int reloadGaugeField) +{ + static const QudaVerbosity verbosity = getVerbosity(); + qudamilc_called(__func__, verbosity); + + // static const QudaVerbosity verbosity = getVerbosity(); + QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + QudaPrecision device_precision = (quda_precision == 2) ? 
QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + QudaPrecision device_precision_sloppy = device_precision; + + QudaGaugeParam gparam = newQudaGaugeParam(); + QudaGaugeParam dparam = newQudaGaugeParam(); + + setGaugeParams(gparam, dparam, nullptr, localDim, host_precision, device_precision, device_precision_sloppy, 1.0, 0.0); + gparam.type = QUDA_WILSON_LINKS; + gparam.make_resident_gauge = true; + QudaInvertParam invertParam = newQudaInvertParam(); + setInvertParams(host_precision, device_precision, device_precision_sloppy, 0.0, 0, 0, 0, 0.0, QUDA_EVEN_PARITY, + verbosity, QUDA_CG_INVERTER, &invertParam); + invertParam.solution_type = QUDA_MAT_SOLUTION; + + ColorSpinorParam csParam; + setColorSpinorParams(localDim, host_precision, &csParam); + csParam.siteSubset = QUDA_FULL_SITE_SUBSET; + csParam.x[0] *= 2; + QudaDslashType saveDslash = invertParam.dslash_type; + invertParam.dslash_type = QUDA_COVDEV_DSLASH; + + // dirty hack to invalidate the cached gauge field without breaking interface compatability + if (reloadGaugeField || !canReuseResidentGauge(&invertParam)) { + if (links == nullptr) { + errorQuda("Can't offload a null gauge field\n"); + exit(1); + } + loadGaugeQuda(const_cast(links), &gparam); + // Assume the caller resets reloadGaugeField + // invalidate_quda_gauge = false; + } + invertParam.dslash_type = saveDslash; + + if ((sym < 1) || (sym > 3)) { + errorQuda("Wrong shift. Select forward (1), backward (2) or symmetric (3).\n"); + } else { + shiftQuda(dst, src, dir, sym, &invertParam); + } + + qudamilc_called(__func__, verbosity); +} // qudaShift + +void qudaSpinTaste(int external_precision, int quda_precision, const void *const links, void *src, void *dst, int spin, + int taste, int reloadGaugeField) +{ + static const QudaVerbosity verbosity = getVerbosity(); + qudamilc_called(__func__, verbosity); + + // static const QudaVerbosity verbosity = getVerbosity(); + QudaPrecision host_precision = (external_precision == 2) ? 
QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + QudaPrecision device_precision_sloppy = device_precision; + + QudaGaugeParam gparam = newQudaGaugeParam(); + QudaGaugeParam dparam = newQudaGaugeParam(); + + setGaugeParams(gparam, dparam, nullptr, localDim, host_precision, device_precision, device_precision_sloppy, 1.0, 0.0); + gparam.type = QUDA_WILSON_LINKS; + gparam.make_resident_gauge = true; + QudaInvertParam invertParam = newQudaInvertParam(); + setInvertParams(host_precision, device_precision, device_precision_sloppy, 0.0, 0, 0, 0, 0.0, QUDA_EVEN_PARITY, + verbosity, QUDA_CG_INVERTER, &invertParam); + invertParam.solution_type = QUDA_MAT_SOLUTION; + + ColorSpinorParam csParam; + setColorSpinorParams(localDim, host_precision, &csParam); + csParam.siteSubset = QUDA_FULL_SITE_SUBSET; + csParam.x[0] *= 2; + QudaDslashType saveDslash = invertParam.dslash_type; + invertParam.dslash_type = QUDA_COVDEV_DSLASH; + + // dirty hack to invalidate the cached gauge field without breaking interface compatability + if (reloadGaugeField || !canReuseResidentGauge(&invertParam)) { + if (links == nullptr) { + errorQuda("Can't offload a null gauge field\n"); + exit(1); + } + loadGaugeQuda(const_cast(links), &gparam); + // Assume the caller resets reloadGaugeField + } + invertParam.dslash_type = saveDslash; + + spinTasteQuda(dst, src, spin, taste, &invertParam); + + qudamilc_called(__func__, verbosity); +} // qudaSpinTaste + void qudaInvertMsrc(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *const fatlink, const void *const longlink, void **sourceArray, void **solutionArray, double *const final_residual, @@ -1977,6 +2071,47 @@ struct mgInputStruct { } }; +void qudaContractFT(int external_precision, QudaContractArgs_t *cont_args, void *const quark1, void *const quark2, + 
double *corr) +{ + static const QudaVerbosity verbosity = getVerbosity(); + qudamilc_called(__func__, verbosity); + QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + ColorSpinorParam csParam; + { // set ColorSpinorParam block + csParam.nColor = 3; + csParam.nSpin = 1; // Support only staggered color fields for now + for (int dir = 0; dir < 4; ++dir) csParam.x[dir] = localDim[dir]; + csParam.x[4] = 1; + csParam.setPrecision(host_precision); + csParam.pad = 0; + csParam.siteSubset = QUDA_FULL_SITE_SUBSET; + csParam.siteOrder = QUDA_EVEN_ODD_SITE_ORDER; + csParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; + csParam.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; // meaningless for staggered, but required by the code. + csParam.create = QUDA_ZERO_FIELD_CREATE; + csParam.location = QUDA_CPU_FIELD_LOCATION; + csParam.pc_type = QUDA_4D_PC; // must be set + } + + int const n_mom = cont_args->n_mom; + int *const mom_modes = cont_args->mom_modes; + const QudaFFTSymmType *const fft_type = cont_args->fft_type; + int const *source_position = cont_args->source_position; + + QudaContractType cType = QUDA_CONTRACT_TYPE_STAGGERED_FT_T; + int const src_colors = 1; + // Only one pair of color fields and one result, so only one element in the arrays + void *prop_array_flavor_1[1] = {quark1}; + void *prop_array_flavor_2[1] = {quark2}; + void *result[1] = {corr}; + + contractFTQuda(prop_array_flavor_1, prop_array_flavor_2, result, cType, &csParam, src_colors, localDim, + source_position, n_mom, mom_modes, fft_type); + + qudamilc_called(__func__, verbosity); +} // qudaContractFT + // Internal structure that maintains `QudaMultigridParam`, // `QudaInvertParam`, `QudaEigParam`s, and the traditional // void* returned by `newMultigridQuda`. 
@@ -2766,7 +2901,7 @@ void qudaFreeGaugeField() { void qudaFreeTwoLink() { qudamilc_called(__func__); - freeGaugeSmearedQuda(); + freeGaugeTwoLinkQuda(); qudamilc_called(__func__); } // qudaFreeTwoLink diff --git a/lib/spin_taste.cu b/lib/spin_taste.cu new file mode 100644 index 0000000000..824665f51e --- /dev/null +++ b/lib/spin_taste.cu @@ -0,0 +1,88 @@ +#include +#include +#include +#include + +namespace quda +{ + + template class SpinTastePhase_ : TunableKernel2D + { + const ColorSpinorField &in; + ColorSpinorField &out; + QudaSpinTasteGamma gamma; // used for meta data only + unsigned int minThreads() const { return in.VolumeCB(); } + + public: + template using Arg = SpinTasteArg; + + SpinTastePhase_(ColorSpinorField &out, const ColorSpinorField &in, QudaSpinTasteGamma gamma) : + TunableKernel2D(in, in.SiteSubset()), in(in), out(out), gamma(gamma) + { + strcat(aux, "gamma="); + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + + if (gamma == QUDA_SPIN_TASTE_G1) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GX) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GY) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GZ) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GT) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_G5) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GYGZ) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GZGX) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GXGY) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GXGT) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GYGT) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_GZGT) { + launch(tp, stream,
Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_G5GX) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_G5GY) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_G5GZ) { + launch(tp, stream, Arg(out, in)); + } else if (gamma == QUDA_SPIN_TASTE_G5GT) { + launch(tp, stream, Arg(out, in)); + } else { + errorQuda("Undefined gamma type"); + } + } + + void preTune() { out.backup(); } + void postTune() { out.restore(); } + + long long flops() const { return 0; } + long long bytes() const { return 2 * in.Bytes(); } + }; + +#ifdef GPU_STAGGERED_DIRAC + void applySpinTaste(ColorSpinorField &out, const ColorSpinorField &in, QudaSpinTasteGamma gamma) + { + instantiate(out, in, gamma); + //// ensure that ghosts are updated if needed + // if (u.GhostExchange() == QUDA_GHOST_EXCHANGE_PAD) u.exchangeGhost(); + } +#else + void applySpinTaste(ColorSpinorField &out, const ColorSpinorField &in, QudaSpinTasteGamma gamma) + { + errorQuda("Gauge tools are not build"); + } +#endif + +} // namespace quda diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 21cd442e61..4d18cec933 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -130,6 +130,10 @@ if(QUDA_DIRAC_STAGGERED) quda_checkbuildtest(staggered_eigensolve_test QUDA_BUILD_ALL_TESTS) install(TARGETS staggered_eigensolve_test ${QUDA_EXCLUDE_FROM_INSTALL} DESTINATION ${CMAKE_INSTALL_BINDIR}) + add_executable(contract_ft_test contract_ft_test.cpp) + target_link_libraries(contract_ft_test ${TEST_LIBS}) + quda_checkbuildtest(contract_ft_test QUDA_BUILD_ALL_TESTS) + install(TARGETS contract_ft_test ${QUDA_EXCLUDE_FROM_INSTALL} DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() if(QUDA_DIRAC_WILSON @@ -207,13 +211,6 @@ if(QUDA_COVDEV) install(TARGETS covdev_test ${QUDA_EXCLUDE_FROM_INSTALL} DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() -if(QUDA_CONTRACT) - add_executable(contract_test contract_test.cpp) - target_link_libraries(contract_test ${TEST_LIBS}) - 
quda_checkbuildtest(contract_test QUDA_BUILD_ALL_TESTS) - install(TARGETS contract_test ${QUDA_EXCLUDE_FROM_INSTALL} DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - if(QUDA_DIRAC_STAGGERED) add_executable(llfat_test llfat_test.cpp) target_link_libraries(llfat_test ${TEST_LIBS}) @@ -355,12 +352,12 @@ if(QUDA_BUILD_NATIVE_LAPACK) endif() #Contraction test -if(QUDA_CONTRACT) - add_test(NAME contract_test - COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} - --dim 2 4 6 8 - --gtest_output=xml:contract_test.xml) -endif() +if(QUDA_DIRAC_STAGGERED) + add_test(NAME contract_ft_test + COMMAND ${QUDA_CTEST_LAUNCH} $ ${MPIEXEC_POSTFLAGS} + --dim 2 4 6 8 --enable-testing true + --gtest_output=xml:contract_ft_test.xml) +endif() # loop over Dslash policies if(QUDA_CTEST_SEP_DSLASH_POLICIES) diff --git a/tests/contract_ft_test.cpp b/tests/contract_ft_test.cpp new file mode 100644 index 0000000000..552a0527db --- /dev/null +++ b/tests/contract_ft_test.cpp @@ -0,0 +1,263 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include "misc.h" + +// google test +#include + +// In a typical application, quda.h is the only QUDA header required. 
+#include +#include + +void display_test_info() +{ + printfQuda("running the following test:\n"); + + printfQuda("contraction_type prec S_dimension T_dimension Ls_dimension\n"); + printfQuda("%s %s %d/%d/%d %d %d\n", get_contract_str(contract_type), + get_prec_str(prec), xdim, ydim, zdim, tdim, Lsdim); + + printfQuda("contractFTQuda test"); + printfQuda("Grid partition info: X Y Z T\n"); + printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2), + dimPartitioned(3)); + return; +} + +int main(int argc, char **argv) +{ + // Start Google Test Suite + ::testing::InitGoogleTest(&argc, argv); + + // QUDA initialise + // command line options: + auto app = make_app(); + + add_propagator_option_group(app); + add_contraction_option_group(app); + add_testing_option_group(app); + + try { + app->parse(argc, argv); + } catch (const CLI::ParseError &e) { + return app->exit(e); + } + + // Set values for precisions via the command line. + setQudaPrecisions(); + + // initialize QMP/MPI, QUDA comms grid and RNG (host_utils.cpp) + initComms(argc, argv, gridsize_from_cmdline); + + // All parameters have been set. 
Display the parameters via stdout + display_test_info(); + + // Initialize the QUDA library + initQuda(device_ordinal); + + // Ensure gtest prints only from rank 0 + ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); + if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + + // call srand() with a rank-dependent seed + initRand(); + + std::array X = {xdim, ydim, zdim, tdim}; // local dims + + setDims(X.data()); + + // Check for correctness: + int result = 0; + + if (enable_testing) { // tests are defined in invert_test_gtest.hpp + ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); + if (quda::comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + result = RUN_ALL_TESTS(); + } else { // + contract(test_t {contract_type, prec}); + } + + // finalize the QUDA library + endQuda(); + + // finalize the communications layer + finalizeComms(); + + return result; +} + +template +inline void fill_buffers(std::array, N> &buffs, const std::array &X, const int dofs) +{ + + const std::array X0 = {X[0] * comm_coord(0), X[1] * comm_coord(1), X[2] * comm_coord(2), X[3] * comm_coord(3)}; + const std::array XN = {X[0] * comm_dim(0), X[1] * comm_dim(1), X[2] * comm_dim(2), X[3] * comm_dim(3)}; + + for (int ix = 0; ix < X[0]; ix++) { + for (int iy = 0; iy < X[1]; iy++) { + for (int iz = 0; iz < X[2]; iz++) { + for (int it = 0; it < X[3]; it++) { + + int l + = (ix + X0[0]) + (iy + X0[1]) * XN[0] + (iz + X0[2]) * XN[0] * XN[1] + (it + X0[3]) * XN[0] * XN[1] * XN[2]; + int ll = ix + iy * X[0] + iz * X[0] * X[1] + it * X[0] * X[1] * X[2]; + + srand(l); + for (int i = 0; i < dofs; i++) { +#pragma unroll + for (int n = 0; n < N; n++) { buffs[n][ll * dofs + i] = 2. 
* (rand() / (Float)RAND_MAX) - 1.; } + } + } + } + } + } +} + +template +inline int launch_contract_test(const QudaContractType cType, const std::array &X, const int red_size, + const std::array &source_position, const std::array &mom, + const std::array &fft_type) +{ + ColorSpinorParam cs_param; + + cs_param.nColor = 3; + cs_param.nSpin = nSpin; + cs_param.nDim = 4; + + for (int i = 0; i < 4; i++) cs_param.x[i] = X[i]; + + cs_param.x[4] = 1; + cs_param.siteSubset = QUDA_FULL_SITE_SUBSET; + cs_param.setPrecision(sizeof(Float) == sizeof(float) ? QUDA_SINGLE_PRECISION : QUDA_DOUBLE_PRECISION); + + cs_param.pad = 0; + cs_param.siteOrder = QUDA_EVEN_ODD_SITE_ORDER; + cs_param.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; + cs_param.gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; // meaningless for staggered, but required by the code. + cs_param.create = QUDA_ZERO_FIELD_CREATE; + cs_param.location = QUDA_CPU_FIELD_LOCATION; + cs_param.pc_type = QUDA_4D_PC; + + const int my_spinor_site_size = nSpin * 3; // nSpin X nColor + + const int spinor_field_floats = V * my_spinor_site_size * 2; // DMH: Vol * spinor elems * 2(re,im) + + const int n_contract_results = red_size * n_mom * nSpin * nSpin * 2; + + std::vector d_result(n_contract_results, 0.0); + + constexpr int nprops = nSpin * src_colors; + + const int dof = my_spinor_site_size * 2 * nprops; + + // array of spinor field for each source spin and color + size_t off = 0; + + // array of spinor field for each source spin and color + std::array spinorX; + std::array spinorY; + + std::array, 2> buffs {std::vector(nprops * spinor_field_floats, 0), + std::vector(nprops * spinor_field_floats, 0)}; + + fill_buffers(buffs, X, dof); + + for (int s = 0; s < nprops; ++s, off += spinor_field_floats * sizeof(Float)) { + spinorX[s] = (void *)((uintptr_t)buffs[0].data() + off); + spinorY[s] = (void *)((uintptr_t)buffs[1].data() + off); + } + // Perform GPU contraction: + void *d_result_ = static_cast(d_result.data()); + + 
contractFTQuda(spinorX.data(), spinorY.data(), &d_result_, cType, (void *)(&cs_param), src_colors, X.data(), + source_position.data(), n_mom, mom.data(), fft_type.data()); + // Check results: + int faults + = contractionFT_reference((Float **)spinorX.data(), (Float **)spinorY.data(), d_result.data(), cType, + src_colors, X.data(), source_position.data(), n_mom, mom.data(), fft_type.data()); + + return faults; +} + +template +int launch_contract_test(const QudaContractType cType, const std::array &X, const int nspin, const int red_size, + const std::array &source_position, const std::array &mom, + const std::array &fft_type) +{ + int faults = 0; + + if (nspin == 1) { + faults = launch_contract_test(cType, X, red_size, source_position, mom, fft_type); + //} else if ( nspin == 4 ){ //TODO : must be enabled when spin=4 case will be re-activated + // faults = launch_contract_test(cType, X, red_size, source_position, mom, fft_type ); + } else { + errorQuda("Unsupported spin.\n"); + } + + return faults; +} + +// Functions used for Google testing +// Performs the CPU GPU comparison with the given parameters +int contract(test_t param) +{ + if (xdim % 2) errorQuda("odd local x-dimension is not supported"); + + const std::array X = {xdim, ydim, zdim, tdim}; + + QudaContractType cType = ::testing::get<0>(param); + QudaPrecision test_prec = ::testing::get<1>(param); + + const int nSpin = cType == QUDA_CONTRACT_TYPE_STAGGERED_FT_T ? 1 : 4; + const int red_size = cType == QUDA_CONTRACT_TYPE_STAGGERED_FT_T || cType == QUDA_CONTRACT_TYPE_DR_FT_T ? 
+ comm_dim(3) * X[3] : + comm_dim(2) * X[2]; // WARNING : check if needed + + const QudaFFTSymmType eo = QUDA_FFT_SYMM_EO; + const QudaFFTSymmType ev = QUDA_FFT_SYMM_EVEN; + const QudaFFTSymmType od = QUDA_FFT_SYMM_ODD; + + const std::array source_position = prop_source_position[0]; // make command option + + constexpr int n_mom = 18; + + const std::array mom + = {0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 1, 0, 0, + 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, -1, -1, 0, 0, 1, 1, 0, 0, 1, 1, 0}; + + const std::array fft_type + = {eo, eo, eo, eo, // (0,0,0) + ev, ev, ev, eo, eo, eo, eo, eo, // (1,0,0) + eo, eo, eo, eo, ev, ev, ev, eo, od, ev, ev, eo, eo, eo, eo, eo, // (0,1,0) + eo, eo, eo, eo, ev, ev, ev, eo, ev, od, ev, eo, eo, eo, eo, eo, // (0,0,1) + eo, eo, eo, eo, ev, ev, ev, eo, ev, ev, od, eo, eo, eo, eo, eo, // (0,1,1) + eo, eo, eo, eo, ev, ev, ev, eo, ev, od, od, eo}; + + int faults = 0; + + constexpr int src_colors = 1; + + if (test_prec == QUDA_SINGLE_PRECISION) { + faults = launch_contract_test(cType, X, nSpin, red_size, source_position, mom, fft_type); + } else if (test_prec == QUDA_DOUBLE_PRECISION) { + faults = launch_contract_test(cType, X, nSpin, red_size, source_position, mom, fft_type); + } else { + errorQuda("Unsupported precision.\n"); + } + + const int n_contract_results = red_size * n_mom * nSpin * nSpin * 2; + + printfQuda("Contraction comparison for contraction type %s complete with %d/%d faults\n", get_contract_str(cType), + faults, n_contract_results); + + return faults; +} diff --git a/tests/contract_ft_test_gtest.hpp b/tests/contract_ft_test_gtest.hpp new file mode 100644 index 0000000000..55a427fdae --- /dev/null +++ b/tests/contract_ft_test_gtest.hpp @@ -0,0 +1,55 @@ +#include +#include +#include + +using test_t = ::testing::tuple; + +class ContractFTTest : public ::testing::TestWithParam +{ + test_t param; + +public: + ContractFTTest() : 
param(GetParam()) { } +}; + +bool skip_test(test_t param) +{ + auto contract_type = ::testing::get<0>(param); + auto prec = ::testing::get<1>(param); + + // skip spin 4 cases + if (contract_type == QUDA_CONTRACT_TYPE_DR_FT_T or contract_type == QUDA_CONTRACT_TYPE_DR_FT_Z) return true; + if (prec < QUDA_SINGLE_PRECISION) return true; // outer precision >= sloppy precision + if (!(QUDA_PRECISION & prec)) return true; // precision not enabled so skip it + + return false; +} + +int contract(test_t param); + +TEST_P(ContractFTTest, verify) +{ + if (skip_test(GetParam())) GTEST_SKIP(); + + auto faults = contract(GetParam()); + EXPECT_EQ(faults, 0) << "CPU and GPU implementations do not agree"; +} + +std::string gettestname(::testing::TestParamInfo param) +{ + std::string str("contract_"); + + str += get_contract_str(::testing::get<0>(param.param)); + str += std::string("_") + get_prec_str(::testing::get<1>(param.param)); + + return str; +} + +using ::testing::Combine; +using ::testing::Values; + +auto contract_types = Values(QUDA_CONTRACT_TYPE_STAGGERED_FT_T); // FIXME : extend if needed + +auto precisions = Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION); + +INSTANTIATE_TEST_SUITE_P(contraction_ft, ContractFTTest, Combine(contract_types, precisions), gettestname); diff --git a/tests/contract_test.cpp b/tests/contract_test.cpp deleted file mode 100644 index 50c91611fe..0000000000 --- a/tests/contract_test.cpp +++ /dev/null @@ -1,208 +0,0 @@ -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include "misc.h" - -// google test -#include - -// In a typical application, quda.h is the only QUDA header required. -#include -#include - -// If you add a new contraction type, this must be updated++ -constexpr int NcontractType = 2; -// For googletest, names must be non-empty, unique, and may only contain ASCII -// alphanumeric characters or underscore. 
-const char *names[] = {"OpenSpin", "DegrandRossi"}; -const char *prec_str[] = {"quarter", "half", "single", "double"}; - -namespace quda -{ - extern void setTransferGPU(bool); -} - -void display_test_info() -{ - printfQuda("running the following test:\n"); - - printfQuda("prec sloppy_prec S_dimension T_dimension Ls_dimension\n"); - printfQuda("%s %s %d/%d/%d %d %d\n", get_prec_str(prec), get_prec_str(prec_sloppy), - xdim, ydim, zdim, tdim, Lsdim); - - printfQuda("Contraction test"); - printfQuda("Grid partition info: X Y Z T\n"); - printfQuda(" %d %d %d %d\n", dimPartitioned(0), dimPartitioned(1), dimPartitioned(2), - dimPartitioned(3)); - return; -} - -int main(int argc, char **argv) -{ - // Start Google Test Suite - //----------------------------------------------------------------------------- - ::testing::InitGoogleTest(&argc, argv); - - // QUDA initialise - //----------------------------------------------------------------------------- - // command line options - auto app = make_app(); - try { - app->parse(argc, argv); - } catch (const CLI::ParseError &e) { - return app->exit(e); - } - - // initialize QMP/MPI, QUDA comms grid and RNG (host_utils.cpp) - initComms(argc, argv, gridsize_from_cmdline); - - // Ensure gtest prints only from rank 0 - ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); - if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } - - // call srand() with a rank-dependent seed - initRand(); - display_test_info(); - - // initialize the QUDA library - initQuda(device_ordinal); - int X[4] = {xdim, ydim, zdim, tdim}; - setDims(X); - //----------------------------------------------------------------------------- - - prec = QUDA_INVALID_PRECISION; - - // Check for correctness - int result = 0; - if (verify_results) { - ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); - if (comm_rank() != 0) { delete 
listeners.Release(listeners.default_result_printer()); } - result = RUN_ALL_TESTS(); - if (result) warningQuda("Google tests for QUDA contraction failed!"); - } - //----------------------------------------------------------------------------- - - // finalize the QUDA library - endQuda(); - - // finalize the communications layer - finalizeComms(); - - return result; -} - -// Functions used for Google testing -//----------------------------------------------------------------------------- - -// Performs the CPU GPU comparison with the given parameters -int test(int contractionType, QudaPrecision test_prec) -{ - int X[4] = {xdim, ydim, zdim, tdim}; - - QudaInvertParam inv_param = newQudaInvertParam(); - setContractInvertParam(inv_param); - inv_param.cpu_prec = test_prec; - inv_param.cuda_prec = test_prec; - inv_param.cuda_prec_sloppy = test_prec; - inv_param.cuda_prec_precondition = test_prec; - - size_t data_size = (test_prec == QUDA_DOUBLE_PRECISION) ? sizeof(double) : sizeof(float); - void *spinorX = safe_malloc(V * spinor_site_size * data_size); - void *spinorY = safe_malloc(V * spinor_site_size * data_size); - void *d_result = safe_malloc(2 * V * 16 * data_size); - - if (test_prec == QUDA_SINGLE_PRECISION) { - for (auto i = 0lu; i < V * spinor_site_size; i++) { - ((float *)spinorX)[i] = rand() / (float)RAND_MAX; - ((float *)spinorY)[i] = rand() / (float)RAND_MAX; - } - } else { - for (auto i = 0lu; i < V * spinor_site_size; i++) { - ((double *)spinorX)[i] = rand() / (double)RAND_MAX; - ((double *)spinorY)[i] = rand() / (double)RAND_MAX; - } - } - - // Host side spinor data and result passed to QUDA. 
- // QUDA will allocate GPU memory, transfer the data, - // perform the requested contraction, and return the - // result in the array 'result' - // We then compare the GPU result with a CPU refernce code - - QudaContractType cType = QUDA_CONTRACT_TYPE_INVALID; - switch (contractionType) { - case 0: cType = QUDA_CONTRACT_TYPE_OPEN; break; - case 1: cType = QUDA_CONTRACT_TYPE_DR; break; - default: errorQuda("Undefined contraction type %d\n", contractionType); - } - - // Perform GPU contraction. - contractQuda(spinorX, spinorY, d_result, cType, &inv_param, X); - - // Compare each site contraction from the host and device. - // It returns the number of faults it detects. - int faults = 0; - if (test_prec == QUDA_DOUBLE_PRECISION) { - faults = contraction_reference((double *)spinorX, (double *)spinorY, (double *)d_result, cType); - } else { - faults = contraction_reference((float *)spinorX, (float *)spinorY, (float *)d_result, cType); - } - - printfQuda("Contraction comparison for contraction type %s complete with %d/%d faults\n", get_contract_str(cType), - faults, V * 16 * 2); - - host_free(spinorX); - host_free(spinorY); - host_free(d_result); - - return faults; -} - -// The following tests gets each contraction type and precision using google testing framework -using ::testing::Bool; -using ::testing::Combine; -using ::testing::Range; -using ::testing::TestWithParam; -using ::testing::Values; - -class ContractionTest : public ::testing::TestWithParam<::testing::tuple> -{ -protected: - ::testing::tuple param; - -public: - virtual ~ContractionTest() { } - virtual void SetUp() { param = GetParam(); } -}; - -// Sets up the Google test -TEST_P(ContractionTest, verify) -{ - QudaPrecision prec = getPrecision(::testing::get<0>(GetParam())); - int contractionType = ::testing::get<1>(GetParam()); - if ((QUDA_PRECISION & prec) == 0) GTEST_SKIP(); - auto faults = test(contractionType, prec); - EXPECT_EQ(faults, 0) << "CPU and GPU implementations do not agree"; -} - -// Helper 
function to construct the test name -std::string getContractName(testing::TestParamInfo<::testing::tuple> param) -{ - int prec = ::testing::get<0>(param.param); - int contractType = ::testing::get<1>(param.param); - std::string str(names[contractType]); - str += std::string("_"); - str += std::string(prec_str[prec]); - return str; -} - -// Instantiate all test cases -INSTANTIATE_TEST_SUITE_P(QUDA, ContractionTest, Combine(Range(2, 4), Range(0, NcontractType)), getContractName); diff --git a/tests/covdev_test.cpp b/tests/covdev_test.cpp index d296a553af..9e7898fafc 100644 --- a/tests/covdev_test.cpp +++ b/tests/covdev_test.cpp @@ -18,7 +18,11 @@ #include #include -#include + +#include +#include + +#include using namespace quda; @@ -42,6 +46,8 @@ const int nColor = 3; void init(int argc, char **argv) { + if (test_type != 0 and test_type != 1) errorQuda("Test type %d is not supported", test_type); + initQuda(device_ordinal); setVerbosity(QUDA_VERBOSE); @@ -60,10 +66,9 @@ void init(int argc, char **argv) ColorSpinorParam csParam; csParam.nColor = nColor; - csParam.nSpin = 4; + csParam.nSpin = test_type == 0 ? 4 : 1; // use --test 1 for staggered case csParam.nDim = 4; for (int d = 0; d < 4; d++) { csParam.x[d] = gauge_param.X[d]; } - // csParam.x[4] = Nsrc; // number of sources becomes the fifth dimension csParam.setPrecision(inv_param.cpu_prec); csParam.pad = 0; @@ -168,15 +173,6 @@ void covdevRef(int mu) printfQuda("done.\n"); } -TEST(dslash, verify) -{ - double deviation = pow(10, -(double)(ColorSpinorField::Compare(*spinorRef, *spinorOut))); - double tol - = (inv_param.cuda_prec == QUDA_DOUBLE_PRECISION ? 1e-12 : - (inv_param.cuda_prec == QUDA_SINGLE_PRECISION ? 
1e-3 : 1e-1)); - ASSERT_LE(deviation, tol) << "CPU and CUDA implementations do not agree"; -} - void display_test_info() { printfQuda("running the following test:\n"); @@ -193,10 +189,11 @@ int main(int argc, char **argv) { // initalize google test ::testing::InitGoogleTest(&argc, argv); - // return code for google test - int test_rc = 0; + // command line options auto app = make_app(); + add_covdev_option_group(app); + try { app->parse(argc, argv); } catch (const CLI::ParseError &e) { @@ -213,56 +210,71 @@ int main(int argc, char **argv) init(argc, argv); - int attempts = 1; - for (int i = 0; i < attempts; i++) { + int result = 0; + + if (enable_testing) { // tests are defined in invert_test_gtest.hpp + ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); + if (quda::comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + result = RUN_ALL_TESTS(); + } else { // + covdev_test(test_t {prec, dagger ? QUDA_DAG_YES : QUDA_DAG_NO}); + } + + end(); + finalizeComms(); - // Test forward directions, then backward - for (int dag = 0; dag < 2; dag++) { - dag == 0 ? dagger = QUDA_DAG_NO : dagger = QUDA_DAG_YES; + return result; +} - for (int mu = 0; mu < 4; mu++) { // We test all directions in one go - int muCuda = mu + (dagger ? 4 : 0); - int muCpu = mu * 2 + (dagger ? 
1 : 0); +std::array covdev_test(test_t param) +{ - // Reference computation - covdevRef(muCpu); - printfQuda("\n\nChecking muQuda = %d\n", muCuda); + // QudaPrecision test_prec = ::testing::get<0>(param); + QudaDagType test_dagger = ::testing::get<1>(param); - { // warm-up run - printfQuda("Tuning...\n"); - dslashCUDA(1, muCuda); - } + std::array mu_flags {covdev_mu}; - printfQuda("Executing %d kernel loop(s)...", niter); + if (std::all_of(mu_flags.begin(), mu_flags.end(), [](int x) { return x == 0; })) { + errorQuda("No direction was chosen, exiting...\n"); + } - double secs = dslashCUDA(niter, muCuda); - *spinorOut = *cudaSpinorOut; - printfQuda("\n%fms per loop\n", 1000 * secs); + // Test forward directions, then backward + for (int mu = 0; mu < 4; mu++) { // We test all directions in one go + if (mu_flags[mu] == 0) continue; // skip direction + int muCuda = mu + (test_dagger ? 4 : 0); + int muCpu = mu * 2 + (test_dagger ? 1 : 0); - unsigned long long flops - = niter * cudaSpinor->Nspin() * (8 * nColor - 2) * nColor * (long long)cudaSpinor->Volume(); - printfQuda("GFLOPS = %f\n", 1.0e-9 * flops / secs); + // Reference computation + covdevRef(muCpu); + printfQuda("\n\nChecking muQuda = %d\n", muCuda); - double spinor_ref_norm2 = blas::norm2(*spinorRef); - double spinor_out_norm2 = blas::norm2(*spinorOut); + { // warm-up run + printfQuda("Tuning...\n"); + dslashCUDA(1, muCuda); + } + printfQuda("Executing %d kernel loop(s)...", niter); - double cuda_spinor_out_norm2 = blas::norm2(*cudaSpinorOut); - printfQuda("Results mu = %d: CPU=%f, CUDA=%f, CPU-CUDA=%f\n", muCuda, spinor_ref_norm2, cuda_spinor_out_norm2, - spinor_out_norm2); + double secs = dslashCUDA(niter, muCuda); - if (verify_results) { - ::testing::TestEventListeners &listeners = ::testing::UnitTest::GetInstance()->listeners(); - if (comm_rank() != 0) { delete listeners.Release(listeners.default_result_printer()); } + *spinorOut = *cudaSpinorOut; + printfQuda("\n%fms per loop\n", 1000 * secs); - test_rc = 
RUN_ALL_TESTS(); - if (test_rc != 0) warningQuda("Tests failed"); - } - } // Directions - } // Dagger - } + unsigned long long flops = niter * cudaSpinor->Nspin() * (8 * nColor - 2) * nColor * (long long)cudaSpinor->Volume(); + printfQuda("GFLOPS = %f\n", 1.0e-9 * flops / secs); - end(); + double spinor_ref_norm2 = blas::norm2(*spinorRef); + double spinor_out_norm2 = blas::norm2(*spinorOut); - finalizeComms(); - return test_rc; + double cuda_spinor_out_norm2 = blas::norm2(*cudaSpinorOut); + printfQuda("Results mu = %d: CPU=%f, CUDA=%f, CPU-CUDA=%f\n", muCuda, spinor_ref_norm2, cuda_spinor_out_norm2, + spinor_out_norm2); + + } // Directions + + double deviation = pow(10, -(double)(ColorSpinorField::Compare(*spinorRef, *spinorOut))); + double tol + = (inv_param.cuda_prec == QUDA_DOUBLE_PRECISION ? 1e-12 : + (inv_param.cuda_prec == QUDA_SINGLE_PRECISION ? 1e-3 : 1e-1)); + + return std::array {deviation, tol}; } diff --git a/tests/covdev_test_gtest.hpp b/tests/covdev_test_gtest.hpp new file mode 100644 index 0000000000..3b65adeab6 --- /dev/null +++ b/tests/covdev_test_gtest.hpp @@ -0,0 +1,56 @@ +#include +#include +#include + +using test_t = ::testing::tuple; + +class CovDevTest : public ::testing::TestWithParam +{ +protected: + test_t param; + +public: + CovDevTest() : param(GetParam()) { } +}; + +bool skip_test(test_t param) +{ + auto prec = ::testing::get<0>(param); + // auto dag = ::testing::get<1>(param); + // should we keep for all options? 
+ if (!(QUDA_PRECISION & prec)) return true; // precision not enabled so skip it + + return false; +} + +std::array covdev_test(test_t param); + +TEST_P(CovDevTest, verify) +{ + if (skip_test(GetParam())) GTEST_SKIP(); + + std::array test_results = covdev_test(param); + + double deviation = test_results[0]; + double tol = test_results[1]; + + ASSERT_LE(deviation, tol) << "CPU and CUDA implementations do not agree"; +} + +std::string gettestname(::testing::TestParamInfo param) +{ + std::string str("covdev_"); + + str += get_prec_str(::testing::get<0>(param.param)); + str += std::string("_") + get_dag_str(::testing::get<1>(param.param)); + + return str; +} + +using ::testing::Combine; +using ::testing::Values; + +auto precisions = Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION, QUDA_HALF_PRECISION); +auto dagger_opt = Values(QUDA_DAG_YES, QUDA_DAG_NO); + +INSTANTIATE_TEST_SUITE_P(covdevtst, CovDevTest, Combine(precisions, dagger_opt), gettestname); diff --git a/tests/eigensolve_test.cpp b/tests/eigensolve_test.cpp index a5e265bf60..5b063755cc 100644 --- a/tests/eigensolve_test.cpp +++ b/tests/eigensolve_test.cpp @@ -20,6 +20,7 @@ QudaGaugeParam gauge_param; QudaInvertParam eig_inv_param; QudaEigParam eig_param; +QudaGaugeSmearParam smear_param; std::vector gauge_; std::array gauge; @@ -76,6 +77,10 @@ void init(int argc, char **argv) //------------------------------------------------------ gauge_param = newQudaGaugeParam(); setWilsonGaugeParam(gauge_param); + if (gauge_smear) { + smear_param = newQudaGaugeSmearParam(); + setGaugeSmearParam(smear_param); + } // Though no inversions are performed, the inv_param // structure contains all the information we need to @@ -245,6 +250,7 @@ int main(int argc, char **argv) // Parse command line options auto app = make_app(); add_eigen_option_group(app); + add_su3_option_group(app); add_madwf_option_group(app); add_comms_option_group(app); add_testing_option_group(app); diff --git
a/tests/host_reference/contract_ft_reference.h b/tests/host_reference/contract_ft_reference.h new file mode 100644 index 0000000000..7f77ab9754 --- /dev/null +++ b/tests/host_reference/contract_ft_reference.h @@ -0,0 +1,207 @@ +#pragma once + +#include +#include +#include "color_spinor_field.h" +#include +#include +#include +#include + +extern int Z[4]; +extern int Vh; +extern int V; + +using namespace quda; +template using complex = std::complex; + +// Color contract two ColorSpinors at a site returning a [nSpin x nSpin] matrix. +template +inline void contractColors(const Float *const spinorX, const Float *const spinorY, const int nSpin, Float M[]) +{ + for (int s1 = 0; s1 < nSpin; s1++) { + for (int s2 = 0; s2 < nSpin; s2++) { + Float re = 0.0; + Float im = 0.0; + for (int c = 0; c < 3; c++) { + re += (spinorX[6 * s1 + 2 * c + 0] * spinorY[6 * s2 + 2 * c + 0] + + spinorX[6 * s1 + 2 * c + 1] * spinorY[6 * s2 + 2 * c + 1]); + + im += (spinorX[6 * s1 + 2 * c + 0] * spinorY[6 * s2 + 2 * c + 1] + - spinorX[6 * s1 + 2 * c + 1] * spinorY[6 * s2 + 2 * c + 0]); + } + + M[2 * (nSpin * s1 + s2) + 0] = re; + M[2 * (nSpin * s1 + s2) + 1] = im; + } + } +}; + +// accumulate Fourier phase +template inline void FourierPhase(Float z[2], const Float theta, const QudaFFTSymmType fft_type) +{ + Float w[2] {z[0], z[1]}; + if (fft_type == QUDA_FFT_SYMM_EVEN) { + Float costh = cos(theta); + z[0] = w[0] * costh; + z[1] = w[1] * costh; + } else if (fft_type == QUDA_FFT_SYMM_ODD) { + Float sinth = sin(theta); + z[0] = -w[1] * sinth; + z[1] = w[0] * sinth; + } else if (fft_type == QUDA_FFT_SYMM_EO) { + Float costh = cos(theta); + Float sinth = sin(theta); + z[0] = w[0] * costh - w[1] * sinth; + z[1] = w[1] * costh + w[0] * sinth; + } +}; + +template +void contractFTHost(Float **h_prop_array_flavor_1, Float **h_prop_array_flavor_2, double *h_result, + const QudaContractType cType, const int src_colors, const int *X, const int *const source_position, + const int n_mom, const int *const 
mom_modes, const QudaFFTSymmType *const fft_type) +{ + int nSpin = 4; + if (cType == QUDA_CONTRACT_TYPE_STAGGERED_FT_T) nSpin = 1; + + // The number of contraction results expected in the output + size_t num_out_results = nSpin * nSpin; + + int reduct_dim = 3; // t-dir is default + if (cType == QUDA_CONTRACT_TYPE_DR_FT_Z) reduct_dim = 2; + + // The number of slices in the decay dimension on this MPI rank. + size_t local_reduct_slices = X[reduct_dim]; + + // The number of slices in the decay dimension globally. + size_t global_reduct_slices = local_reduct_slices * comm_dim(reduct_dim); + + // Array for all momenta, reduction slices, and channels. It is zeroed prior to kernel launch. + std::vector> result_global(n_mom * global_reduct_slices * num_out_results); + std::fill(result_global.begin(), result_global.end(), Complex {0.0, 0.0}); + + // Strides for computing local coordinates + int strides[4] {1, X[0], X[1] * X[0], X[2] * X[1] * X[0]}; + + // Global lattice dimensions + int L[4]; + for (int dir = 0; dir < 4; ++dir) L[dir] = X[dir] * comm_dim(dir); + + double phase[n_mom * 2]; + Float M[num_out_results * 2]; + // size_t x ; + int sink[4]; + int red_coord = -1; + for (int sindx = 0; sindx < V; ++sindx) { + // compute local coordinates; lexicographical with x fastest + int parity = 0; + int rem = sindx; + for (int dir = 3; dir >= 0; --dir) { + sink[dir] = rem / strides[dir]; + rem -= sink[dir] * strides[dir]; + parity += sink[dir]; + } + parity &= 1; + int cb_idx = sindx / 2; + // global coords + for (int dir = 0; dir < 4; ++dir) { + sink[dir] += comm_coord(dir) * X[dir]; + if (reduct_dim == dir) red_coord = sink[dir]; // project to this coord + } + + // compute Fourier phases + for (int mom_idx = 0; mom_idx < n_mom; ++mom_idx) { + phase[2 * mom_idx + 0] = 1.; + phase[2 * mom_idx + 1] = 0.; + for (int dir = 0; dir < 4; ++dir) { + double theta = 2. 
* M_PI / L[dir]; + theta *= (sink[dir] - source_position[dir]) * mom_modes[4 * mom_idx + dir]; + FourierPhase(phase + 2 * mom_idx, theta, fft_type[4 * mom_idx + dir]); + } + } + + for (int s1 = 0; s1 < nSpin; s1++) { + for (int s2 = 0; s2 < nSpin; s2++) { + for (int c1 = 0; c1 < src_colors; c1++) { + // color contraction + size_t off = nSpin * 3 * 2 * (Vh * parity + cb_idx); + contractColors(h_prop_array_flavor_1[s1 * src_colors + c1] + off, + h_prop_array_flavor_2[s2 * src_colors + c1] + off, nSpin, M); + + // apply gamma matrices here + + // mutiply by Fourier phases and accumulate + for (int mom_idx = 0; mom_idx < n_mom; ++mom_idx) { + for (size_t m_idx = 0; m_idx < num_out_results; ++m_idx) { + Float prod[2]; + prod[0] = phase[2 * mom_idx + 0] * M[2 * m_idx + 0] - phase[2 * mom_idx + 1] * M[2 * m_idx + 1]; + prod[1] = phase[2 * mom_idx + 1] * M[2 * m_idx + 0] + phase[2 * mom_idx + 0] * M[2 * m_idx + 1]; + // result[mom_idx][red_coord][m_idx] + size_t g_idx = global_reduct_slices * num_out_results * mom_idx + num_out_results * red_coord + m_idx; + result_global[g_idx] += std::complex {prod[0], prod[1]}; + } + } + } + } + } + } // sites + + // global reduction + quda::comm_allreduce_sum(result_global); + + // copy to output array + for (size_t idx = 0; idx < n_mom * global_reduct_slices * num_out_results; ++idx) { + h_result[2 * idx + 0] = result_global[idx].real(); + h_result[2 * idx + 1] = result_global[idx].imag(); + } +}; + +template +int contractionFT_reference(Float **spinorX, Float **spinorY, const double *const d_result, const QudaContractType cType, + const int src_colors, const int *X, const int *const source_position, const int n_mom, + const int *const mom_modes, const QudaFFTSymmType *const fft_type) +{ + int nSpin = 4; + if (cType == QUDA_CONTRACT_TYPE_STAGGERED_FT_T) nSpin = 1; + + size_t reduct_dim = 3; // t-dir is default + if (cType == QUDA_CONTRACT_TYPE_DR_FT_Z) reduct_dim = 2; + + // The number of slices in the reduction dimension. 
+ size_t reduction_slices = X[reduct_dim] * comm_dim(reduct_dim); + + // space for the host result + const size_t n_floats = n_mom * reduction_slices * nSpin * nSpin * 2; + double *h_result = static_cast(safe_malloc(n_floats * sizeof(double))); + + // compute contractions on the host + contractFTHost(spinorX, spinorY, h_result, cType, src_colors, X, source_position, n_mom, mom_modes, fft_type); + + const int ntol = 7; + auto epsilon = std::numeric_limits::epsilon(); + auto fact = epsilon; + fact *= sqrt((double)nSpin * 6 * V * 2 / reduction_slices); // account for repeated roundoff in float ops + fact *= 10; // account for variation in phase computation + std::array tolerance {1.0e-5 * fact, 1.0e-4 * fact, 1.0e-3 * fact, 1.0e-2 * fact, + 1.0e-1 * fact, 1.0e+0 * fact, 1.0e+1 * fact}; + int check_tol = 5; + std::array fails = {}; + + for (size_t idx = 0; idx < n_floats; ++idx) { + double rel = abs(d_result[idx] - h_result[idx]) / (abs(h_result[idx]) + epsilon); + // printfQuda("%5ld: %10.3e %10.3e: %10.3e\n", idx, d_result[idx], h_result[idx], rel); + for (int d = 0; d < ntol; ++d) + if (rel > tolerance[d]) ++fails[d]; + } + + printfQuda("tolerance n_diffs\n"); + printfQuda("---------- --------\n"); + for (int j = 0; j < ntol; ++j) { printfQuda("%9.1e: %8d\n", tolerance[j], fails[j]); } + printfQuda("---------- --------\n"); + printfQuda("check tolerance is %9.1e\n", tolerance[check_tol]); + + host_free(h_result); + + return fails[check_tol]; +}; diff --git a/tests/host_reference/covdev_reference.cpp b/tests/host_reference/covdev_reference.cpp index 66aaf85fa8..fb922fca82 100644 --- a/tests/host_reference/covdev_reference.cpp +++ b/tests/host_reference/covdev_reference.cpp @@ -161,7 +161,9 @@ void covdevReference_mg4dir(sFloat *res, gFloat **link, gFloat **ghostLink, cons auto fwd_nbr_spinor = reinterpret_cast(in.fwdGhostFaceBuffer); auto back_nbr_spinor = reinterpret_cast(in.backGhostFaceBuffer); - for (auto i = 0lu; i < Vh * spinor_site_size; i++) res[i] = 0.0; + 
const int my_spinor_site_size = in.Nspin() == 1 ? stag_spinor_site_size : spinor_site_size; + + for (int i = 0; i < Vh * my_spinor_site_size; i++) res[i] = 0.0; gFloat *linkEven[4], *linkOdd[4]; gFloat *ghostLinkEven[4], *ghostLinkOdd[4]; @@ -175,18 +177,18 @@ void covdevReference_mg4dir(sFloat *res, gFloat **link, gFloat **ghostLink, cons } for (int sid = 0; sid < Vh; sid++) { - int offset = spinor_site_size * sid; + int offset = my_spinor_site_size * sid; gFloat *lnk = gaugeLink_mg4dir(sid, mu, oddBit, linkEven, linkOdd, ghostLinkEven, ghostLinkOdd, 1, 1); const sFloat *spinor = spinorNeighbor_mg4dir(sid, mu, oddBit, static_cast(in.data()), - fwd_nbr_spinor, back_nbr_spinor, 1, 1); + fwd_nbr_spinor, back_nbr_spinor, 1, 1, my_spinor_site_size); - sFloat gaugedSpinor[spinor_site_size]; + sFloat gaugedSpinor[my_spinor_site_size]; if (daggerBit) { - for (int s = 0; s < 4; s++) su3Tmul(&gaugedSpinor[s * 6], lnk, &spinor[s * 6]); + for (int s = 0; s < in.Nspin(); s++) su3Tmul(&gaugedSpinor[s * 6], lnk, &spinor[s * 6]); } else { - for (int s = 0; s < 4; s++) su3Mul(&gaugedSpinor[s * 6], lnk, &spinor[s * 6]); + for (int s = 0; s < in.Nspin(); s++) su3Mul(&gaugedSpinor[s * 6], lnk, &spinor[s * 6]); } - sum(&res[offset], &res[offset], gaugedSpinor, spinor_site_size); + sum(&res[offset], &res[offset], gaugedSpinor, my_spinor_site_size); } // 4-d volume diff --git a/tests/staggered_eigensolve_test.cpp b/tests/staggered_eigensolve_test.cpp index 6e717437fe..335367b478 100644 --- a/tests/staggered_eigensolve_test.cpp +++ b/tests/staggered_eigensolve_test.cpp @@ -74,6 +74,12 @@ void init() gauge_param = newQudaGaugeParam(); setStaggeredGaugeParam(gauge_param); + QudaGaugeSmearParam smear_param; + if (gauge_smear) { + smear_param = newQudaGaugeSmearParam(); + setGaugeSmearParam(smear_param); + } + // Though no inversions are performed, the inv_param // structure contains all the information we need to // construct the dirac operator. 
@@ -256,6 +262,7 @@ int main(int argc, char **argv) auto app = make_app(); add_eigen_option_group(app); + add_su3_option_group(app); add_testing_option_group(app); try { app->parse(argc, argv); diff --git a/tests/staggered_gsmear_test.cpp b/tests/staggered_gsmear_test.cpp index 925ebfeff5..e5ce29c346 100644 --- a/tests/staggered_gsmear_test.cpp +++ b/tests/staggered_gsmear_test.cpp @@ -61,6 +61,7 @@ int main(int argc, char **argv) auto app = make_app(); app->add_option("--test", gtest_type, "Test method")->transform(CLI::CheckedTransformer(gtest_type_map)); add_quark_smear_option_group(app); + add_su3_option_group(app); add_comms_option_group(app); try { app->parse(argc, argv); diff --git a/tests/staggered_gsmear_test_utils.h b/tests/staggered_gsmear_test_utils.h index 7dc02bc46e..bb31a1c02c 100644 --- a/tests/staggered_gsmear_test_utils.h +++ b/tests/staggered_gsmear_test_utils.h @@ -100,6 +100,7 @@ struct StaggeredGSmearTestWrapper { // GaugeField *cpuTwoLink = nullptr; QudaGaugeParam gauge_param; + QudaGaugeSmearParam smear_param; QudaInvertParam inv_param; ColorSpinorField spinor; @@ -167,6 +168,11 @@ struct StaggeredGSmearTestWrapper { // inv_param = newQudaInvertParam(); setStaggeredGaugeParam(gauge_param); + if (gauge_smear) { + smear_param = newQudaGaugeSmearParam(); + setGaugeSmearParam(smear_param); + } + setStaggeredInvertParam(inv_param); auto prec = getPrecision(precision); @@ -190,6 +196,10 @@ struct StaggeredGSmearTestWrapper { // inv_param = newQudaInvertParam(); setStaggeredGaugeParam(gauge_param); + if (gauge_smear) { + smear_param = newQudaGaugeSmearParam(); + setGaugeSmearParam(smear_param); + } setStaggeredInvertParam(inv_param); init(argc, argv); @@ -363,9 +373,8 @@ struct StaggeredGSmearTestWrapper { // printfQuda("GFLOPS = %f\n", gflops); ::testing::Test::RecordProperty("Gflops", std::to_string(gflops)); - size_t ghost_bytes = gtest_type == gsmear_test_type::GaussianSmear ? 
spinor.GhostBytes() : 0; - - if (gtest_type == gsmear_test_type::GaussianSmear) { + size_t ghost_bytes = spinor.GhostBytes(); + if (gtest_type == gsmear_test_type::GaussianSmear && ghost_bytes > 0) { ::testing::Test::RecordProperty("Halo_bidirectitonal_BW_GPU", 1.0e-9 * 2 * ghost_bytes * niter / gsmear_time.event_time); ::testing::Test::RecordProperty("Halo_bidirectitonal_BW_CPU", @@ -374,12 +383,12 @@ struct StaggeredGSmearTestWrapper { // ::testing::Test::RecordProperty("Halo_bidirectitonal_BW_CPU_max", 1.0e-9 * 2 * ghost_bytes / gsmear_time.cpu_min); ::testing::Test::RecordProperty("Halo_message_size_bytes", 2 * ghost_bytes); - printfQuda( - "Effective halo bi-directional bandwidth (GB/s) GPU = %f ( CPU = %f, min = %f , max = %f ) for aggregate " - "message size %lu bytes\n", - 1.0e-9 * 2 * ghost_bytes * niter / gsmear_time.event_time, - 1.0e-9 * 2 * ghost_bytes * niter / gsmear_time.cpu_time, 1.0e-9 * 2 * ghost_bytes / gsmear_time.cpu_max, - 1.0e-9 * 2 * ghost_bytes / gsmear_time.cpu_min, 2 * ghost_bytes); + printfQuda("Effective halo bi-directional bandwidth (GB/s) GPU = %f ( CPU = %f, min = %f , max = %f ) for " + "aggregate message size %lu bytes\n", + 1.0e-9 * 2 * ghost_bytes * niter / gsmear_time.event_time, + 1.0e-9 * 2 * ghost_bytes * niter / gsmear_time.cpu_time, + 1.0e-9 * 2 * ghost_bytes / gsmear_time.cpu_max, 1.0e-9 * 2 * ghost_bytes / gsmear_time.cpu_min, + 2 * ghost_bytes); } } } diff --git a/tests/staggered_invert_test.cpp b/tests/staggered_invert_test.cpp index 767961fbbe..9a00b3d875 100644 --- a/tests/staggered_invert_test.cpp +++ b/tests/staggered_invert_test.cpp @@ -138,6 +138,11 @@ void init() // Set QUDA internal parameters gauge_param = newQudaGaugeParam(); setStaggeredGaugeParam(gauge_param); + QudaGaugeSmearParam smear_param; + if (gauge_smear) { + smear_param = newQudaGaugeSmearParam(); + setGaugeSmearParam(smear_param); + } inv_param = newQudaInvertParam(); mg_inv_param = newQudaInvertParam(); diff --git 
a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index e3b604a444..204c336ac2 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -72,6 +72,7 @@ quda::mgarray mg_vec_partfile = {}; QudaInverterType inv_type; bool inv_deflate = false; bool inv_multigrid = false; +bool gauge_smear = false; QudaInverterType precon_type = QUDA_INVALID_INVERTER; QudaSchwarzType precon_schwarz_type = QUDA_INVALID_SCHWARZ; QudaAcceleratorType precon_accelerator_type = QUDA_INVALID_ACCELERATOR; @@ -97,6 +98,7 @@ bool low_mode_check = false; bool oblique_proj_check = false; double mass = 0.1; double kappa = -1.0; +quda::mass_array kappa_array = {}; double mu = 0.1; double epsilon = 0.01; double evmax = 0.1; @@ -279,7 +281,7 @@ int heatbath_num_steps = 10; int heatbath_num_heatbath_per_step = 5; int heatbath_num_overrelax_per_step = 5; bool heatbath_coldstart = false; - +// GF Options int gf_gauge_dir = 4; int gf_maxiter = 10000; int gf_verbosity_interval = 100; @@ -296,7 +298,38 @@ double eofa_mq1 = 1.0; double eofa_mq2 = 0.85; double eofa_mq3 = 1.0; -QudaContractType contract_type = QUDA_CONTRACT_TYPE_OPEN; +// SU(3) smearing options +double gauge_smear_rho = 0.1; +double gauge_smear_epsilon = 1.0; +double gauge_smear_alpha = 0.6; +int gauge_smear_steps = 5; +QudaWFlowType wflow_type = QUDA_WFLOW_TYPE_WILSON; +int measurement_interval = 5; +QudaGaugeSmearType gauge_smear_type = QUDA_GAUGE_SMEAR_STOUT; + +// contract options +QudaContractType contract_type = QUDA_CONTRACT_TYPE_STAGGERED_FT_T; +std::array momentum = {0, 0, 0, 0}; +char correlator_file_affix[256] = ""; +char correlator_save_dir[256] = "."; +bool open_flavor = false; + +// Propagator options +quda::file_array prop_source_infile; +quda::file_array prop_source_outfile; +quda::file_array prop_sink_infile; +quda::file_array prop_sink_outfile; +quda::source_array> prop_source_position = {{{0, 0, 0, 0}}}; + +int prop_source_smear_steps = 0; +int 
prop_sink_smear_steps = 0; +double prop_source_smear_coeff = 2.0; +double prop_sink_smear_coeff = 2.0; +bool prop_read_sources = false; +int prop_n_sources = 1; +QudaPrecision prop_save_prec = QUDA_SINGLE_PRECISION; + +std::array covdev_mu = {1, 1, 1, 1}; // Parameters for the (gaussian) quark smearing operator int smear_n_steps = 50; @@ -313,9 +346,6 @@ namespace { CLI::TransformPairs ca_basis_map {{"power", QUDA_POWER_BASIS}, {"chebyshev", QUDA_CHEBYSHEV_BASIS}}; - CLI::TransformPairs contract_type_map {{"open", QUDA_CONTRACT_TYPE_OPEN}, - {"dr", QUDA_CONTRACT_TYPE_DR}}; - CLI::TransformPairs dslash_type_map {{"wilson", QUDA_WILSON_DSLASH}, {"clover", QUDA_CLOVER_WILSON_DSLASH}, {"twisted-mass", QUDA_TWISTED_MASS_DSLASH}, @@ -420,10 +450,20 @@ namespace {"SR", QUDA_SPECTRUM_SR_EIG}, {"LR", QUDA_SPECTRUM_LR_EIG}, {"SM", QUDA_SPECTRUM_SM_EIG}, {"LM", QUDA_SPECTRUM_LM_EIG}, {"SI", QUDA_SPECTRUM_SI_EIG}, {"LI", QUDA_SPECTRUM_LI_EIG}}; + CLI::TransformPairs wflow_type_map {{"wilson", QUDA_WFLOW_TYPE_WILSON}, + {"symanzik", QUDA_WFLOW_TYPE_SYMANZIK}}; + + CLI::TransformPairs gauge_smear_type_map { + {"ape", QUDA_GAUGE_SMEAR_APE}, {"stout", QUDA_GAUGE_SMEAR_STOUT}, {"ovr-imp-stout", QUDA_GAUGE_SMEAR_OVRIMP_STOUT}}; + CLI::TransformPairs setup_type_map {{"test", QUDA_TEST_VECTOR_SETUP}, {"null", QUDA_TEST_VECTOR_SETUP}}; CLI::TransformPairs extlib_map {{"eigen", QUDA_EIGEN_EXTLIB}}; + CLI::TransformPairs contract_type_map {{"dr-ft-t", QUDA_CONTRACT_TYPE_DR_FT_T}, + {"dr-ft-z", QUDA_CONTRACT_TYPE_DR_FT_Z}, + {"stag-ft-t", QUDA_CONTRACT_TYPE_STAGGERED_FT_T}}; + } // namespace std::shared_ptr make_app(std::string app_description, std::string app_name) @@ -469,12 +509,6 @@ std::shared_ptr make_app(std::string app_description, std::string app_n quda_app->add_option("--compute-fat-long", compute_fatlong, "Compute the fat/long field or use random numbers (default false)"); - quda_app - ->add_option("--contraction-type", contract_type, - "Whether to leave spin elemental open, or 
use a gamma basis and contract on " - "spin (default open)") - ->transform(CLI::QUDACheckedTransformer(contract_type_map)); - quda_app->add_flag("--dagger", dagger, "Set the dagger to 1 (default 0)"); quda_app->add_option("--device", device_ordinal, "Set the CUDA device to use (default 0, single GPU only)") ->check(CLI::Range(0, 16)); @@ -502,8 +536,10 @@ std::shared_ptr make_app(std::string app_description, std::string app_n quda_app->add_option("--inv-type", inv_type, "The type of solver to use (default cg)") ->transform(CLI::QUDACheckedTransformer(inverter_type_map)); - quda_app->add_option("--inv-deflate", inv_deflate, "Deflate the inverter using the eigensolver"); - quda_app->add_option("--inv-multigrid", inv_multigrid, "Precondition the inverter using multigrid"); + quda_app->add_option("--inv-deflate", inv_deflate, "Deflate the inverter using the eigensolver (default false)"); + quda_app->add_option("--inv-multigrid", inv_multigrid, "Precondition the inverter using multigrid (default false)"); + quda_app->add_option("--gauge-smear", gauge_smear, + "Smear the gauge prior to dirac operator construction (default false)"); quda_app->add_option("--kappa", kappa, "Kappa of Dirac operator (default 0.12195122... 
[equiv to mass])"); quda_app->add_option( "--laplace3D", laplace3D, @@ -1056,6 +1092,32 @@ void add_eofa_option_group(std::shared_ptr quda_app) - opgroup->add_option("--eofa-mq3", eofa_mq1, "Set mq3 for EOFA operator (default 1.0)"); + opgroup->add_option("--eofa-mq3", eofa_mq3, "Set mq3 for EOFA operator (default 1.0)"); } +void add_su3_option_group(std::shared_ptr quda_app) +{ + + // Option group for SU(3) related options + auto opgroup = quda_app->add_option_group("SU(3)", "Options controlling SU(3) tests"); + opgroup->add_option("--su3-smear-alpha", gauge_smear_alpha, "alpha coefficient for APE smearing (default 0.6)"); + + opgroup->add_option("--su3-smear-rho", gauge_smear_rho, + "rho coefficient for Stout and Over-Improved Stout smearing (default 0.1)"); + + opgroup->add_option( + "--su3-smear-epsilon", gauge_smear_epsilon, + "epsilon coefficient for Over-Improved Stout smearing and step size for Wilson flow (default 1.0)"); + + opgroup->add_option("--su3-smear-steps", gauge_smear_steps, "The number of smearing steps to perform (default 5)"); + + opgroup->add_option("--su3-wflow-type", wflow_type, "The type of action to use in the wilson flow (default wilson)") + ->transform(CLI::QUDACheckedTransformer(wflow_type_map)); + + opgroup->add_option("--su3-smear-type", gauge_smear_type, "The type of smearing to use (default stout)") + ->transform(CLI::QUDACheckedTransformer(gauge_smear_type_map)); + + opgroup->add_option("--su3-measurement-interval", measurement_interval, + "Measure the field energy and topological charge every Nth step (default 5) "); +} + void add_madwf_option_group(std::shared_ptr quda_app) { auto opgroup = quda_app->add_option_group("MADWF", "Options controlling MADWF parameteres"); @@ -1091,6 +1153,79 @@ void add_heatbath_option_group(std::shared_ptr quda_app) "Number of measurement steps in heatbath test (default 10)"); opgroup->add_option("--heatbath-warmup-steps", heatbath_warmup_steps, "Number of warmup steps in heatbath test (default 10)"); + // DMH + // opgroup->add_option("--heatbath-checkpoint", heatbath_checkpoint, + 
//"Number of measurement steps in heatbath before checkpointing (default 5)"); +} + +void add_propagator_option_group(std::shared_ptr quda_app) +{ + // Option group for propagator related options + auto opgroup = quda_app->add_option_group("Propagator", "Options controlling propagator construction"); + + opgroup->add_option("--prop-read-sources", prop_read_sources, + "Read all sources from file. There will be one propagator for each source (default false)"); + + opgroup->add_option("--prop-n-sources", prop_n_sources, "The number of point sources to construct (default 1)"); + + quda_app->add_fileoption(opgroup, "--prop-save-sink-file", prop_sink_outfile, CLI::Validator(), + "Save propagators to (requires QIO)"); + + quda_app + ->add_fileoption(opgroup, "--prop-load-sink-file", prop_sink_infile, CLI::Validator(), + "Load propagators from (requires QIO)") + ->check(CLI::ExistingFile); + + quda_app->add_fileoption(opgroup, "--prop-save-source-file", prop_source_outfile, CLI::Validator(), + "Save source to (requires QIO)"); + + // Do not check for an existing file as QUDA will append any + // string with a dilution index: "string_" + quda_app->add_fileoption(opgroup, "--prop-load-source-file", prop_source_infile, CLI::Validator(), + "Load source to (requires QIO)"); + + opgroup->add_option("--prop-source-smear-coeff", prop_source_smear_coeff, + "Set the alpha(Wuppertal) or omega(Gaussian) source smearing value (default 0.2)"); + + opgroup->add_option("--prop-source-smear-steps", prop_source_smear_steps, + "Set the number of source smearing steps (default 0)"); + + opgroup->add_option("--prop-sink-smear-coeff", prop_sink_smear_coeff, + "Set the alpha(Wuppertal) or omega(Gaussian) sink smearing value (default 0.2)"); + + opgroup->add_option("--prop-sink-smear-steps", prop_sink_smear_steps, + "Set the number of sink smearing steps (default 0)"); + + quda_app->add_psoption(opgroup, "--prop-source-position", prop_source_position, CLI::Validator(), + "Set the position of the 
nth point source (X Y Z T) (default(0,0,0,0))"); + + CLI::QUDACheckedTransformer prec_transform(precision_map); + opgroup->add_option("--prop-save-prec", prop_save_prec, "Precision with which to save propagators (default single)") + ->transform(prec_transform); +} + +void add_contraction_option_group(std::shared_ptr quda_app) +{ + // Option group for contraction related options + auto opgroup = quda_app->add_option_group("Contraction", "Options controlling contraction"); + + opgroup + ->add_option("--contraction-type", contract_type, + "Whether to leave spin elemental open or insert a gamma basis, " + "and whether to sum in t,z, or not at all (default stag-ft-t)") + ->transform(CLI::QUDACheckedTransformer(contract_type_map)); + + opgroup->add_option("--correlator-save-dir", correlator_save_dir, "Save propagators in directory "); + opgroup->add_option("--momentum", momentum, "Set momentum for correlators (px py pz pt) (default(0,0,0,0))")->expected(4); + opgroup->add_option("--open-flavor", open_flavor, "Compute the open flavor correlators (default false)"); + opgroup->add_option("--correlator-file-affix", correlator_file_affix, + "Additional string to put into the correlator file name"); + + quda_app->add_massoption(opgroup, "--kappa-array", kappa_array, CLI::Validator(), + "set the Nth kappa value of the Dirac operator)"); + + quda_app->add_massoption(opgroup, "--mass-array", kappa_array, CLI::Validator(), + "set the Nth mass value of the Dirac operator)"); } void add_gaugefix_option_group(std::shared_ptr quda_app) @@ -1145,3 +1280,9 @@ void add_clover_force_option_group(std::shared_ptr quda_app) auto opgroup = quda_app->add_option_group("Clover force", "Options controlling clover force testing"); opgroup->add_option("--determinant-ratio", detratio, "Test a ratio of determinants. 
Default is false"); } + +void add_covdev_option_group(std::shared_ptr quda_app) +{ + auto opgroup = quda_app->add_option_group("Covdev", "Options controlling cov derivative parameteres"); + opgroup->add_option("--covdev-mu", covdev_mu, "Set the direction(s) (default 1 1 1 1 - all directions)")->expected(4); +} diff --git a/tests/utils/command_line_params.h b/tests/utils/command_line_params.h index ddf8d966af..5b038e4cb0 100644 --- a/tests/utils/command_line_params.h +++ b/tests/utils/command_line_params.h @@ -7,9 +7,18 @@ // for compatibility while porting - remove later extern void usage(char **); +// Put this in quda_constants.h? +#define QUDA_MAX_SOURCES 128 + +// Put this in quda_constants.h? +#define QUDA_MAX_MASSES 128 + namespace quda { template using mgarray = std::array; + template using file_array = std::array; + template using source_array = std::array; + template using mass_array = std::array; } class QUDAApp : public CLI::App @@ -126,6 +135,115 @@ class QUDAApp : public CLI::App group->add_option(opt); return opt; } + + // Add option to parse multiple point source locations + template + CLI::Option *add_psoption(CLI::Option_group *group, std::string option_name, + std::array, QUDA_MAX_SOURCES> &variable, CLI::Validator trans, + std::string option_description = "") + { + + CLI::callback_t f = [&variable, option_name, trans](CLI::results_t vals) { + size_t l = 0; + T j; // results_t is just a vector of strings + bool worked = true; + + CLI::Range validsource(0, QUDA_MAX_SOURCES - 1); + for (size_t i {0}; i < vals.size() / (4 + 1); ++i) { + auto sourceok = validsource(vals.at((4 + 1) * i)); + + if (!sourceok.empty()) throw CLI::ValidationError(option_name, sourceok); + worked = worked and CLI::detail::lexical_cast(vals.at((4 + 1) * i), l); + + for (int k = 0; k < 4; k++) { + auto transformok = trans(vals.at((4 + 1) * i + k + 1)); + if (!transformok.empty()) throw CLI::ValidationError(option_name, transformok); + worked = worked and 
CLI::detail::lexical_cast(vals.at((4 + 1) * i + k + 1), j); + if (worked) variable[l][k] = j; + } + } + return worked; + }; + CLI::Option *opt = add_option(option_name, f, option_description); + auto valuename = std::string("SOURCE ") + std::string(CLI::detail::type_name()); + opt->type_name(valuename)->type_size(-4 - 1); + opt->expected(-1); + opt->check(CLI::Validator(trans.get_description())); + + group->add_option(opt); + return opt; + } + + // Add option to parse multiple files. + template + CLI::Option *add_fileoption(CLI::Option_group *group, std::string option_name, + std::array &variable, CLI::Validator trans, + std::string option_description = "") + { + + CLI::callback_t f = [&variable, option_name, trans](CLI::results_t vals) { + size_t l = 0; + // T j; // results_t is just a vector of strings + bool worked = true; + + CLI::Range validsource(0, QUDA_MAX_SOURCES - 1); + for (size_t i {0}; i < vals.size() / 2; ++i) { // will always be a multiple of 2 + auto sourceok = validsource(vals.at(2 * i)); + auto transformok = trans(vals.at(2 * i + 1)); + if (!sourceok.empty()) throw CLI::ValidationError(option_name, sourceok); + if (!transformok.empty()) throw CLI::ValidationError(option_name, transformok); + worked = worked and CLI::detail::lexical_cast(vals.at(2 * i), l); + auto &j = variable[l]; + worked = worked and CLI::detail::lexical_cast(vals.at(2 * i + 1), j); + + // if (worked) variable[l] = j; + } + return worked; + }; + CLI::Option *opt = add_option(option_name, f, option_description); + auto valuename = std::string("SOURCE ") + std::string(CLI::detail::type_name()); + opt->type_name(valuename)->type_size(-2); + opt->expected(-1); + opt->check(CLI::Validator(trans.get_description())); + + group->add_option(opt); + return opt; + } + + template + CLI::Option *add_massoption(CLI::Option_group *group, std::string option_name, std::array &variable, + CLI::Validator trans, std::string option_description = "", bool = false) + { + + CLI::callback_t f = [&variable, 
option_name, trans](CLI::results_t vals) { + size_t l = 0; + // T j; // results_t is just a vector of strings + bool worked = true; + + CLI::Range validlevel(0, QUDA_MAX_MASSES - 1); + for (size_t i {0}; i < vals.size() / 2; ++i) { // will always be a multiple of 2 + auto levelok = validlevel(vals.at(2 * i)); + auto transformok = trans(vals.at(2 * i + 1)); + if (!levelok.empty()) throw CLI::ValidationError(option_name, levelok); + if (!transformok.empty()) throw CLI::ValidationError(option_name, transformok); + worked = worked and CLI::detail::lexical_cast(vals.at(2 * i), l); + auto &j = variable[l]; + worked = worked and CLI::detail::lexical_cast(vals.at(2 * i + 1), j); + + // if (worked) variable[l] = j; + } + return worked; + }; + CLI::Option *opt = add_option(option_name, f, option_description); + auto valuename = std::string("FLAVOR ") + std::string(CLI::detail::type_name()); + opt->type_name(valuename)->type_size(-2); + opt->expected(-1); + opt->check(CLI::Validator(trans.get_description())); + // opt->transform(trans); + // opt->default_str(""); + group->add_option(opt); + return opt; + } }; std::shared_ptr make_app(std::string app_description = "QUDA internal test", std::string app_name = ""); @@ -136,11 +254,14 @@ void add_eofa_option_group(std::shared_ptr quda_app); void add_madwf_option_group(std::shared_ptr quda_app); void add_su3_option_group(std::shared_ptr quda_app); void add_heatbath_option_group(std::shared_ptr quda_app); +void add_propagator_option_group(std::shared_ptr quda_app); +void add_contraction_option_group(std::shared_ptr quda_app); void add_gaugefix_option_group(std::shared_ptr quda_app); void add_comms_option_group(std::shared_ptr quda_app); void add_testing_option_group(std::shared_ptr quda_app); void add_quark_smear_option_group(std::shared_ptr quda_app); void add_clover_force_option_group(std::shared_ptr quda_app); +void add_covdev_option_group(std::shared_ptr quda_app); template std::string inline get_string(CLI::TransformPairs &map, T 
val) { @@ -209,6 +330,7 @@ extern quda::mgarray mg_vec_partfile; extern QudaInverterType inv_type; extern bool inv_deflate; extern bool inv_multigrid; +extern bool gauge_smear; extern QudaInverterType precon_type; extern QudaSchwarzType precon_schwarz_type; extern QudaAcceleratorType precon_accelerator_type; @@ -234,6 +356,7 @@ extern bool low_mode_check; extern bool oblique_proj_check; extern double mass; extern double kappa; +extern quda::mass_array kappa_array; extern double mu; extern double epsilon; extern double evmax; @@ -423,7 +546,14 @@ extern double eofa_mq1; extern double eofa_mq2; extern double eofa_mq3; -extern QudaContractType contract_type; +// SU(3) smearing options +extern double gauge_smear_rho; +extern double gauge_smear_epsilon; +extern double gauge_smear_alpha; +extern int gauge_smear_steps; +extern QudaWFlowType wflow_type; +extern int measurement_interval; +extern QudaGaugeSmearType gauge_smear_type; extern double smear_coeff; extern int smear_n_steps; @@ -433,6 +563,28 @@ extern bool smear_delete_two_link; extern std::array grid_partition; +// contract options +extern QudaContractType contract_type; +extern char correlator_save_dir[256]; +extern char correlator_file_affix[256]; +extern std::array momentum; +extern bool open_flavor; + +extern quda::file_array prop_source_infile; +extern quda::file_array prop_source_outfile; +extern quda::file_array prop_sink_infile; +extern quda::file_array prop_sink_outfile; +extern quda::source_array> prop_source_position; +extern int prop_source_smear_steps; +extern int prop_sink_smear_steps; +extern double prop_source_smear_coeff; +extern double prop_sink_smear_coeff; +extern bool prop_read_sources; +extern int prop_n_sources; +extern QudaPrecision prop_save_prec; + extern bool enable_testing; extern bool detratio; + +extern std::array covdev_mu; diff --git a/tests/utils/host_utils.h b/tests/utils/host_utils.h index d001fd6eea..dd2b87b1f9 100644 --- a/tests/utils/host_utils.h +++ 
b/tests/utils/host_utils.h @@ -115,6 +115,9 @@ template void applyGaugeFieldScaling(Float **gauge, int Vh, Qud //------------------------------------------------------ void constructWilsonTestSpinorParam(quda::ColorSpinorParam *csParam, const QudaInvertParam *inv_param, const QudaGaugeParam *gauge_param); +void constructPointSpinorSource(void *v, QudaPrecision precision, const int *const x, const int dil, + const int *const src); +void constructWallSpinorSource(void *v, int nSpin, int nColor, QudaPrecision precision, const int dil); void constructRandomSpinorSource(void *v, int nSpin, int nColor, QudaPrecision precision, QudaSolutionType sol_type, const int *const x, int nDim, quda::RNG &rng); //------------------------------------------------------ @@ -343,3 +346,7 @@ void setStaggeredMGInvertParam(QudaInvertParam &inv_param); void setGaugeParam(QudaGaugeParam &gauge_param); void setWilsonGaugeParam(QudaGaugeParam &gauge_param); void setStaggeredGaugeParam(QudaGaugeParam &gauge_param); + +// Smear param types +void setGaugeSmearParam(QudaGaugeSmearParam &smear_param); +void setFermionSmearParam(QudaInvertParam &inv_param, double omega, int steps); diff --git a/tests/utils/misc.cpp b/tests/utils/misc.cpp index 29a448d7a2..808ad0ca85 100644 --- a/tests/utils/misc.cpp +++ b/tests/utils/misc.cpp @@ -125,7 +125,28 @@ const char *get_contract_str(QudaContractType type) switch (type) { case QUDA_CONTRACT_TYPE_OPEN: ret = "open"; break; - case QUDA_CONTRACT_TYPE_DR: ret = "Degrand_Rossi"; break; + case QUDA_CONTRACT_TYPE_OPEN_SUM_T: ret = "open_sum_t"; break; + case QUDA_CONTRACT_TYPE_OPEN_SUM_Z: ret = "open_sum_z"; break; + case QUDA_CONTRACT_TYPE_OPEN_FT_T: ret = "open_ft_t"; break; + case QUDA_CONTRACT_TYPE_OPEN_FT_Z: ret = "open_ft_z"; break; + case QUDA_CONTRACT_TYPE_DR: ret = "dr"; break; + case QUDA_CONTRACT_TYPE_DR_FT_T: ret = "dr_ft_t"; break; + case QUDA_CONTRACT_TYPE_DR_FT_Z: ret = "dr_ft_z"; break; + case QUDA_CONTRACT_TYPE_STAGGERED: ret = "stag"; break; + 
case QUDA_CONTRACT_TYPE_STAGGERED_FT_T: ret = "stag_ft_t"; break; + default: ret = "unknown"; break; + } + + return ret; +} + +const char *get_dag_str(QudaDagType type) +{ + const char *ret; + + switch (type) { + case QUDA_DAG_YES: ret = "dag"; break; + case QUDA_DAG_NO: ret = "nodag"; break; default: ret = "unknown"; break; } diff --git a/tests/utils/misc.h b/tests/utils/misc.h index 625cb2ca01..626a0c24da 100644 --- a/tests/utils/misc.h +++ b/tests/utils/misc.h @@ -22,6 +22,7 @@ const char *get_eig_type_str(QudaEigType type); const char *get_ritz_location_str(QudaFieldLocation type); const char *get_memory_type_str(QudaMemoryType type); const char *get_contract_str(QudaContractType type); +const char *get_dag_str(QudaDagType type); const char *get_gauge_smear_str(QudaGaugeSmearType type); std::string get_dilution_type_str(QudaDilutionType type); const char *get_blas_type_str(QudaBLASType type); diff --git a/tests/utils/set_params.cpp b/tests/utils/set_params.cpp index 404401c2d6..3ba1900cf0 100644 --- a/tests/utils/set_params.cpp +++ b/tests/utils/set_params.cpp @@ -3,6 +3,17 @@ #include #include "misc.h" +void setGaugeSmearParam(QudaGaugeSmearParam &smear_param) +{ + smear_param.alpha = gauge_smear_alpha; + smear_param.rho = gauge_smear_rho; + smear_param.epsilon = gauge_smear_epsilon; + smear_param.n_steps = gauge_smear_steps; + smear_param.meas_interval = measurement_interval; + smear_param.smear_type = gauge_smear_type; + smear_param.struct_size = sizeof(smear_param); +} + void setGaugeParam(QudaGaugeParam &gauge_param) { gauge_param.type = QUDA_SU3_LINKS; @@ -277,6 +288,22 @@ void setInvertParam(QudaInvertParam &inv_param) inv_param.struct_size = sizeof(inv_param); } +void setFermionSmearParam(QudaInvertParam &smear_param, double omega, int steps) +{ + // Construct a copy of the current invert parameters + setInvertParam(smear_param); + + // Construct 4D smearing parameters. 
+ smear_param.dslash_type = QUDA_LAPLACE_DSLASH; + double smear_coeff = -1.0 * omega * omega / (4 * steps); + smear_param.mass_normalization = QUDA_KAPPA_NORMALIZATION; // Enforce kappa normalisation + smear_param.mass = 1.0; + smear_param.kappa = smear_coeff; + smear_param.laplace3D = laplace3D; // Omit this dim + smear_param.solution_type = QUDA_MAT_SOLUTION; + smear_param.solve_type = QUDA_DIRECT_SOLVE; +} + // Parameters defining the eigensolver void setEigParam(QudaEigParam &eig_param) {