Skip to content

Commit

Permalink
Fix Laplace solver tests (reference Laplace operator uses kappa norma…
Browse files Browse the repository at this point in the history
…lization, but staggered invert test was hard coded to use mass normalization). Fix split grid ctest when using single GPU
  • Loading branch information
maddyscientist committed Nov 18, 2024
1 parent f49f168 commit a66dbf0
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 12 deletions.
2 changes: 2 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ if(QUDA_MPI OR QUDA_QMP)
message(STATUS "ctest will run on ${QUDA_TEST_NUM_PROCS} processes")
set(QUDA_CTEST_LAUNCH ${MPIEXEC_EXECUTABLE};${MPIEXEC_NUMPROC_FLAG};${QUDA_TEST_NUM_PROCS};${MPIEXEC_PREFLAGS}
CACHE STRING "CTest Launcher command for QUDA's tests")
else()
set(QUDA_TEST_NUM_PROCS 1)
endif()

# BLAS tests
Expand Down
9 changes: 3 additions & 6 deletions tests/host_reference/dslash_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,17 +746,17 @@ double verifyWilsonTypeSingularVector(void *spinor_left, void *spinor_right, dou

std::array<double, 2> verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out,
quda::GaugeField &fat_link, quda::GaugeField &long_link,
QudaInvertParam &inv_param, int src_idx)
QudaInvertParam &inv_param, int laplace3D, int src_idx)
{
std::vector<quda::ColorSpinorField> out_vector(1);
out_vector[0] = out;
return verifyStaggeredInversion(in, out_vector, fat_link, long_link, inv_param, src_idx);
return verifyStaggeredInversion(in, out_vector, fat_link, long_link, inv_param, laplace3D, src_idx);
}

std::array<double, 2> verifyStaggeredInversion(quda::ColorSpinorField &in,
std::vector<quda::ColorSpinorField> &out_vector,
quda::GaugeField &fat_link, quda::GaugeField &long_link,
QudaInvertParam &inv_param, int src_idx)
QudaInvertParam &inv_param, int laplace3D, int src_idx)
{
int dagger = inv_param.dagger == QUDA_DAG_YES ? 1 : 0;
double l2r_max = 0.0;
Expand Down Expand Up @@ -810,9 +810,6 @@ std::array<double, 2> verifyStaggeredInversion(quda::ColorSpinorField &in,
double mass = inv_param.mass;
if (inv_param.solution_type == QUDA_MAT_SOLUTION) {
stag_mat(ref, fat_link, long_link, out, mass, dagger, dslash_type, laplace3D);

// correct for the massRescale function inside invertQuda
if (is_laplace(dslash_type)) ax(0.5 / kappa, ref.data(), ref.Length(), ref.Precision());
} else if (inv_param.solution_type == QUDA_MATPC_SOLUTION) {
QudaParity parity = QUDA_INVALID_PARITY;
switch (inv_param.matpc_type) {
Expand Down
8 changes: 6 additions & 2 deletions tests/host_reference/dslash_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,11 +217,13 @@ std::array<double, 2> verifyWilsonTypeInversion(void *spinorOut, void **spinorOu
* @param fat_link The fat links in the context of an ASQTAD solve; otherwise the base gauge links with phases applied
* @param long_link The long links; null for naive staggered and Laplace
* @param inv_param Invert params, used to query the solve type, etc
* @param laplace3D Whether we are working on the 3-d Laplace operator
* @param src_idx The source index we working on (when doing mutil-RHS)
* @return The residual and HQ residual (if requested)
*/
std::array<double, 2> verifyStaggeredInversion(quda::ColorSpinorField &in, quda::ColorSpinorField &out,
quda::GaugeField &fat_link, quda::GaugeField &long_link,
QudaInvertParam &inv_param, int src_idx);
QudaInvertParam &inv_param, int laplace3D, int src_idx);

/**
* @brief Verify a single- or multi-shift staggered inversion on the host
Expand All @@ -231,12 +233,14 @@ std::array<double, 2> verifyStaggeredInversion(quda::ColorSpinorField &in, quda:
* @param fat_link The fat links in the context of an ASQTAD solve; otherwise the base gauge links with phases applied
* @param long_link The long links; null for naive staggered and Laplace
* @param inv_param Invert params, used to query the solve type, etc, also includes the shifts
* @param laplace3D Whether we are working on the 3-d Laplace operator
* @param src_idx The source index we working on (when doing mutil-RHS)
* @return The residual and HQ residual (if requested)
*/
std::array<double, 2> verifyStaggeredInversion(quda::ColorSpinorField &in,
std::vector<quda::ColorSpinorField> &out_vector,
quda::GaugeField &fat_link, quda::GaugeField &long_link,
QudaInvertParam &inv_param, int src_idx = 0);
QudaInvertParam &inv_param, int laplace3D, int src_idx = 0);

/**
* @brief Verify a staggered-type eigenvector
Expand Down
6 changes: 3 additions & 3 deletions tests/staggered_invert_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,9 @@ std::vector<std::array<double, 2>> solve(test_t param)
// Create an appropriate subset of the full out_multishift vector
std::vector<quda::ColorSpinorField> out_subset
= {out_multishift.begin() + n * multishift, out_multishift.begin() + (n + 1) * multishift};
res[n] = verifyStaggeredInversion(in[n], out_subset, cpuFatQDP, cpuLongQDP, inv_param);
res[n] = verifyStaggeredInversion(in[n], out_subset, cpuFatQDP, cpuLongQDP, inv_param, laplace3D);
} else {
res[n] = verifyStaggeredInversion(in[n], out[n], cpuFatQDP, cpuLongQDP, inv_param, n);
res[n] = verifyStaggeredInversion(in[n], out[n], cpuFatQDP, cpuLongQDP, inv_param, laplace3D, n);
}
}
}
Expand Down Expand Up @@ -510,7 +510,7 @@ int main(int argc, char **argv)
if (!is_staggered(dslash_type) && !is_laplace(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
} else {
if (is_laplace(dslash_type)) errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON");
if (is_laplace(dslash_type)) errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_DIRAC_LAPLACE=ON");
if (!is_staggered(dslash_type)) errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
}

Expand Down
2 changes: 1 addition & 1 deletion tests/utils/set_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,7 @@ void setStaggeredInvertParam(QudaInvertParam &inv_param)
inv_param.solve_type = solve_type;
inv_param.matpc_type = matpc_type;
inv_param.dagger = QUDA_DAG_NO;
inv_param.mass_normalization = QUDA_MASS_NORMALIZATION;
inv_param.mass_normalization = dslash_type == QUDA_LAPLACE_DSLASH ? QUDA_KAPPA_NORMALIZATION : QUDA_MASS_NORMALIZATION;

inv_param.cpu_prec = cpu_prec;
inv_param.cuda_prec = prec;
Expand Down

0 comments on commit a66dbf0

Please sign in to comment.