Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modifications to add use_smeared_gauge to InvertParam #1522

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
8 changes: 4 additions & 4 deletions include/quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,10 @@ extern "C" {
/** The t0 parameter for distance preconditioning, the timeslice where the source is located */
int distance_pc_t0;

/** Whether to use the smeared gauge field for the Dirac operator, usually
when defined as a spatial Laplacian: mainly used in computing Laplacian eigenvectors */
QudaBoolean use_smeared_gauge;

} QudaInvertParam;

// Parameter set for solving eigenvalue problems.
Expand Down Expand Up @@ -505,10 +509,6 @@ extern "C" {
false, but preserve_deflation would be true */
QudaBoolean preserve_evals;

/** Whether to use the smeared gauge field for the Dirac operator
for whose eigenvalues are are computing. */
bool use_smeared_gauge;

/** What type of Dirac operator we are using **/
/** If !(use_norm_op) && !(use_dagger) use M. **/
/** If use_dagger, use Mdag **/
Expand Down
2 changes: 1 addition & 1 deletion lib/check_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ void printQudaEigParam(QudaEigParam *param) {
P(preserve_deflation, QUDA_BOOLEAN_FALSE);
P(preserve_deflation_space, 0);
P(preserve_evals, QUDA_BOOLEAN_TRUE);
P(use_smeared_gauge, false);
P(use_dagger, QUDA_BOOLEAN_FALSE);
P(use_norm_op, QUDA_BOOLEAN_FALSE);
P(compute_svd, QUDA_BOOLEAN_FALSE);
Expand Down Expand Up @@ -373,6 +372,7 @@ void printQudaInvertParam(QudaInvertParam *param) {
P(twist_flavor, QUDA_TWIST_INVALID);
P(laplace3D, INVALID_INT);
P(covdev_mu, INVALID_INT);
P(use_smeared_gauge, QUDA_BOOLEAN_FALSE);
#else
// asqtad and domain wall use mass parameterization
if (param->dslash_type == QUDA_STAGGERED_DSLASH || param->dslash_type == QUDA_ASQTAD_DSLASH
Expand Down
44 changes: 29 additions & 15 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1436,9 +1436,10 @@ namespace quda {

void setDiracParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc)
{
GaugeField *gaugePtr = (!inv_param->use_smeared_gauge) ? gaugePrecise : gaugeSmeared;
double kappa = inv_param->kappa;
if (inv_param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) {
kappa *= gaugePrecise->Anisotropy();
kappa *= gaugePtr->Anisotropy();
}

switch (inv_param->dslash_type) {
Expand Down Expand Up @@ -1528,7 +1529,7 @@ namespace quda {

diracParam.matpcType = inv_param->matpc_type;
diracParam.dagger = inv_param->dagger;
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePrecise;
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise : gaugePtr;
diracParam.fatGauge = gaugeFatPrecise;
diracParam.longGauge = gaugeLongPrecise;
diracParam.clover = cloverPrecise;
Expand Down Expand Up @@ -1562,7 +1563,7 @@ namespace quda {
diracParam.commDim[i] = 1; // comms are always on
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_sloppy))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_sloppy);
}
Expand All @@ -1580,7 +1581,7 @@ namespace quda {
diracParam.commDim[i] = 1; // comms are always on
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_refinement_sloppy))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_refinement_sloppy);
}
Expand Down Expand Up @@ -1612,24 +1613,37 @@ namespace quda {
diracParam.gauge = gaugeFatPrecondition;
}

if (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition)
if ((!inv_param->use_smeared_gauge) && (diracParam.gauge->Precision() != inv_param->cuda_prec_precondition))
errorQuda("Gauge precision %d does not match requested precision %d\n", diracParam.gauge->Precision(),
inv_param->cuda_prec_precondition);
}

void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc, bool use_smeared_gauge)
void setDiracEigParam(DiracParam &diracParam, QudaInvertParam *inv_param, bool pc)
{
setDiracParam(diracParam, inv_param, pc);

if (inv_param->overlap) {
diracParam.gauge = inv_param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatExtended : gaugeExtended;
diracParam.fatGauge = gaugeFatExtended;
diracParam.longGauge = gaugeLongExtended;
} else if (use_smeared_gauge) {
} else if (inv_param->use_smeared_gauge) {
if (!gaugeSmeared) errorQuda("No smeared gauge field present");
if (inv_param->dslash_type == QUDA_LAPLACE_DSLASH) {
if (gaugeSmeared->GhostExchange() == QUDA_GHOST_EXCHANGE_EXTENDED) {
GaugeFieldParam gauge_param(*gaugePrecise);
GaugeFieldParam gauge_param((gaugePrecise)? *gaugePrecise : *gaugeSmeared);
if (!gaugePrecise){
for (int k=0;k<gauge_param.nDim;++k){
gauge_param.x[k]-=2*gauge_param.r[k]; gauge_param.r[k]=0;} // smearedGauge is loaded as extended, so remove extensions
#ifdef MULTI_GPU
int x_face_size = gauge_param.x[1] * gauge_param.x[2] * gauge_param.x[3] / 2;
int y_face_size = gauge_param.x[0] * gauge_param.x[2] * gauge_param.x[3] / 2;
int z_face_size = gauge_param.x[0] * gauge_param.x[1] * gauge_param.x[3] / 2;
int t_face_size = gauge_param.x[0] * gauge_param.x[1] * gauge_param.x[2] / 2;
gauge_param.pad = std::max({x_face_size, y_face_size, z_face_size, t_face_size});
#endif
//gauge_param.link_type = QUDA_WILSON_LINKS;
gauge_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;}
gauge_param.ghostExchange = QUDA_GHOST_EXCHANGE_PAD;
GaugeField gaugeEig(gauge_param);
copyExtendedGauge(gaugeEig, *gaugeSmeared, QUDA_CUDA_FIELD_LOCATION);
gaugeEig.exchangeGhost();
Expand All @@ -1644,6 +1658,7 @@ namespace quda {
diracParam.fatGauge = gaugeFatEigensolver;
diracParam.longGauge = gaugeLongEigensolver;
}

diracParam.clover = cloverEigensolver;

for (int i = 0; i < 4; i++) { diracParam.commDim[i] = 1; }
Expand Down Expand Up @@ -1697,8 +1712,7 @@ namespace quda {
dRef = Dirac::create(diracRefParam);
}

void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve,
bool use_smeared_gauge)
void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve)
{
DiracParam diracParam;
DiracParam diracSloppyParam;
Expand All @@ -1709,7 +1723,7 @@ namespace quda {
setDiracSloppyParam(diracSloppyParam, &param, pc_solve);
bool pre_comms_flag = (param.schwarz_type != QUDA_INVALID_SCHWARZ) ? false : true;
setDiracPreParam(diracPreParam, &param, pc_solve, pre_comms_flag);
setDiracEigParam(diracEigParam, &param, pc_solve, use_smeared_gauge);
setDiracEigParam(diracEigParam, &param, pc_solve);

d = Dirac::create(diracParam); // create the Dirac operator
dSloppy = Dirac::create(diracSloppyParam);
Expand Down Expand Up @@ -2406,6 +2420,7 @@ void checkClover(QudaInvertParam *param) {
quda::GaugeField *checkGauge(QudaInvertParam *param)
{
quda::GaugeField *U = param->dslash_type == QUDA_ASQTAD_DSLASH ? gaugeFatPrecise :
param->use_smeared_gauge ? gaugeSmeared :
gaugePrecise;

if (U == nullptr)
Expand All @@ -2415,7 +2430,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param)
errorQuda("Solve precision %d doesn't match gauge precision %d", param->cuda_prec, U->Precision());
}

if (param->dslash_type != QUDA_ASQTAD_DSLASH) {
if (param->dslash_type != QUDA_ASQTAD_DSLASH && !param->use_smeared_gauge) {
if (param->cuda_prec_sloppy != gaugeSloppy->Precision()
|| param->cuda_prec_precondition != gaugePrecondition->Precision()
|| param->cuda_prec_refinement_sloppy != gaugeRefinement->Precision()
Expand All @@ -2433,7 +2448,7 @@ quda::GaugeField *checkGauge(QudaInvertParam *param)
if (gaugeRefinement == nullptr) errorQuda("Refinement gauge field doesn't exist");
if (gaugeEigensolver == nullptr) errorQuda("Refinement gauge field doesn't exist");
if (param->overlap && gaugeExtended == nullptr) errorQuda("Extended gauge field doesn't exist");
} else {
} else if (!param->use_smeared_gauge) {
if (gaugeLongPrecise == nullptr) errorQuda("Precise gauge long field doesn't exist");

if (param->cuda_prec_sloppy != gaugeFatSloppy->Precision()
Expand Down Expand Up @@ -2585,10 +2600,9 @@ void eigensolveQuda(void **host_evecs, double _Complex *host_evals, QudaEigParam

// Create the dirac operator with a sloppy and a precon.
bool pc_solve = (inv_param->solve_type == QUDA_DIRECT_PC_SOLVE) || (inv_param->solve_type == QUDA_NORMOP_PC_SOLVE);
createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve, eig_param->use_smeared_gauge);
createDiracWithEig(d, dSloppy, dPre, dEig, *inv_param, pc_solve);
Dirac &dirac = *dEig;
//------------------------------------------------------

// Construct vectors
//------------------------------------------------------
// Create host wrappers around application vector set
Expand Down
6 changes: 2 additions & 4 deletions lib/solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,7 @@ namespace quda
getProfile().TPSTOP(QUDA_PROFILE_EPILOGUE);
}

void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve,
bool use_smeared_gauge);
void createDiracWithEig(Dirac *&d, Dirac *&dSloppy, Dirac *&dPre, Dirac *&dEig, QudaInvertParam &param, bool pc_solve);

extern std::vector<ColorSpinorField> solutionResident;

Expand Down Expand Up @@ -349,8 +348,7 @@ namespace quda

// Create the dirac operator and operators for sloppy, precondition,
// and an eigensolver
createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve,
param.eig_param ? static_cast<QudaEigParam *>(param.eig_param)->use_smeared_gauge : false);
createDiracWithEig(dirac, diracSloppy, diracPre, diracEig, param, pc_solve);

// wrap CPU host side pointers
ColorSpinorParam cpuParam(hp_b[0], param, u.X(), pc_solution, param.input_location);
Expand Down
2 changes: 1 addition & 1 deletion tests/staggered_eigensolve_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ std::vector<double> eigensolve(test_t test_param)
eig_inv_param.solution_type = eig_param.use_pc ? QUDA_MATPC_SOLUTION : QUDA_MAT_SOLUTION;

// whether we are using the resident smeared gauge or not
eig_param.use_smeared_gauge = gauge_smear;
eig_param.invert_param->use_smeared_gauge = (gauge_smear ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE);

if (dslash_type == QUDA_LAPLACE_DSLASH) {
int dimension = laplace3D < 4 ? 3 : 4;
Expand Down
Loading