Skip to content

Commit

Permalink
Add param and command line option for running transfer with MMA.
Browse files Browse the repository at this point in the history
  • Loading branch information
hummingtree committed Nov 5, 2024
1 parent 6d6beb5 commit 43e7cf7
Show file tree
Hide file tree
Showing 11 changed files with 34 additions and 12 deletions.
9 changes: 7 additions & 2 deletions include/multigrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ namespace quda {
/** Whether to use tensor cores (if available) for dslash */
bool dslash_use_mma;

/** Whether to use tensor cores (if available) for transfer */
bool transfer_use_mma;

/**
This is top level instantiation done when we start creating the multigrid operator.
*/
Expand Down Expand Up @@ -203,7 +206,8 @@ namespace quda {
mg_vec_partfile(param.mg_vec_partfile[level]),
transfer_type(param.transfer_type[level]),
setup_use_mma(param.setup_use_mma[level] == QUDA_BOOLEAN_TRUE),
dslash_use_mma(param.dslash_use_mma[level] == QUDA_BOOLEAN_TRUE)
dslash_use_mma(param.dslash_use_mma[level] == QUDA_BOOLEAN_TRUE),
transfer_use_mma(param.transfer_use_mma[level] == QUDA_BOOLEAN_TRUE)
{
// set the block size
for (int i = 0; i < QUDA_MAX_DIM; i++) geoBlockSize[i] = param.geo_block_size[level][i];
Expand Down Expand Up @@ -242,7 +246,8 @@ namespace quda {
mg_vec_partfile(param.mg_global.mg_vec_partfile[level]),
transfer_type(param.mg_global.transfer_type[level]),
setup_use_mma(param.mg_global.setup_use_mma[level] == QUDA_BOOLEAN_TRUE),
dslash_use_mma(param.mg_global.dslash_use_mma[level] == QUDA_BOOLEAN_TRUE)
dslash_use_mma(param.mg_global.dslash_use_mma[level] == QUDA_BOOLEAN_TRUE),
transfer_use_mma(param.mg_global.transfer_use_mma[level] == QUDA_BOOLEAN_TRUE)
{
// set the block size
for (int i = 0; i < QUDA_MAX_DIM; i++) geoBlockSize[i] = param.mg_global.geo_block_size[level][i];
Expand Down
3 changes: 3 additions & 0 deletions include/quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,9 @@ extern "C" {
/** Dslash MMA usage on each level of the multigrid */
QudaBoolean dslash_use_mma[QUDA_MAX_MG_LEVEL];

/** Transfer MMA usage on each level of the multigrid */
QudaBoolean transfer_use_mma[QUDA_MAX_MG_LEVEL];

/** Inverter to use in the setup phase */
QudaInverterType setup_inv_type[QUDA_MAX_MG_LEVEL];

Expand Down
11 changes: 8 additions & 3 deletions include/transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,13 @@ namespace quda {
/** Whether the CPU transfer operator has been constructed */
mutable bool enable_cpu = false;

/** Whether to apply the transfer operaton the GPU (requires
/** Whether to apply the transfer operation on the GPU (requires
enable_gpu=true in the constructor) */
mutable bool use_gpu;

/** Whether to apply the transfer operation with MMA */
mutable bool _use_mma;

/** Implies whether or not the fine level is a staggered operator, in which
case we don't actually need to allocate any memory. */
mutable QudaTransferType transfer_type;
Expand Down Expand Up @@ -176,6 +179,8 @@ namespace quda {
*/
void reset();

/**
 * Select whether the transfer operation is applied with MMA.
 * Callable on a const object because _use_mma is declared mutable.
 * @param b New value stored in _use_mma
 */
void set_use_mma(bool b) const
{
  _use_mma = b;
}

/**
* Apply the prolongator
* @param out The resulting field on the fine lattice
Expand Down Expand Up @@ -319,7 +324,7 @@ namespace quda {
@param[in] parity of the output fine field (if single parity output field)
*/
void Prolongate(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *const *spin_map, int parity = QUDA_INVALID_PARITY);
const int *fine_to_coarse, const int *const *spin_map, bool use_mma, int parity = QUDA_INVALID_PARITY);

template <int coarseColor, int fineColor>
void Prolongate(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
Expand All @@ -340,7 +345,7 @@ namespace quda {
@param[in] parity of the input fine field (if single parity input field)
*/
void Restrict(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity = QUDA_INVALID_PARITY);
const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, bool use_mma, int parity = QUDA_INVALID_PARITY);

template <int coarseColor, int fineColor>
void Restrict(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
Expand Down
2 changes: 2 additions & 0 deletions lib/check_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -876,9 +876,11 @@ void printQudaMultigridParam(QudaMultigridParam *param) {
P(setup_use_mma[i], QUDA_BOOLEAN_FALSE);
#endif
P(dslash_use_mma[i], QUDA_BOOLEAN_FALSE);
P(transfer_use_mma[i], QUDA_BOOLEAN_FALSE);
#else
P(setup_use_mma[i], QUDA_BOOLEAN_INVALID);
P(dslash_use_mma[i], QUDA_BOOLEAN_INVALID);
P(transfer_use_mma[i], QUDA_BOOLEAN_INVALID);
#endif
#ifdef INIT_PARAM
P(setup_inv_type[i], QUDA_BICGSTAB_INVERTER);
Expand Down
4 changes: 3 additions & 1 deletion lib/multigrid.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,12 @@ namespace quda
}
} else {
// create transfer operator
logQuda(QUDA_VERBOSE, "Creating transfer operator\n");
logQuda(QUDA_VERBOSE, "Creating transfer operator %s\n",
param.transfer_use_mma == QUDA_BOOLEAN_TRUE ? "with MMA enabled" : "");
transfer = new Transfer(param.B, param.Nvec, param.NblockOrtho, param.blockOrthoTwoPass, param.geoBlockSize,
param.spinBlockSize, param.mg_global.precision_null[param.level],
param.mg_global.transfer_type[param.level]);
transfer->set_use_mma(param.transfer_use_mma == QUDA_BOOLEAN_TRUE);
for (int i = 0; i < QUDA_MAX_MG_LEVEL; i++)
param.mg_global.geo_block_size[param.level][i] = param.geoBlockSize[i];

Expand Down
4 changes: 2 additions & 2 deletions lib/prolongator.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ namespace quda
}

void Prolongate(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *const *spin_map, int parity)
const int *fine_to_coarse, const int *const *spin_map, bool use_mma, int parity)
{
if constexpr (is_enabled_multigrid()) {
if (v.Nspin() != 1 && in[0].GammaBasis() != v.GammaBasis())
Expand All @@ -122,7 +122,7 @@ namespace quda
// clang-format off
IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors;
// clang-format on
if (in.size() % 8 == 0) {
if (use_mma && in.size() % 8 == 0) {
// use MMA
Prolongate<true>(out, in, v, fine_to_coarse, spin_map, parity, fineColors);
} else {
Expand Down
4 changes: 2 additions & 2 deletions lib/restrictor.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ namespace quda
}

void Restrict(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity)
const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, bool use_mma, int parity)
{
if constexpr (is_enabled_multigrid()) {
if (v.Nspin() != 1 && out[0].GammaBasis() != v.GammaBasis())
Expand All @@ -122,7 +122,7 @@ namespace quda
IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors;
// clang-format on

if (in.size() % 8 == 0) {
if (use_mma && in.size() % 8 == 0) {
Restrict<true>(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors);
} else {
Restrict<false>(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors);
Expand Down
4 changes: 2 additions & 2 deletions lib/transfer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ namespace quda {
if (V.SiteSubset() == QUDA_PARITY_SITE_SUBSET && out.SiteSubset() == QUDA_FULL_SITE_SUBSET)
errorQuda("Cannot prolongate to a full field since only have single parity null-space components");

Prolongate(output, input, V, fine_to_coarse, spin_map, parity);
Prolongate(output, input, V, fine_to_coarse, spin_map, _use_mma, parity);

for (auto i = 0u; i < out.size(); i++) out[i] = output[i]; // copy result to out field (aliasing handled automatically)
} else {
Expand Down Expand Up @@ -475,7 +475,7 @@ namespace quda {
if (V.SiteSubset() == QUDA_PARITY_SITE_SUBSET && in.SiteSubset() == QUDA_FULL_SITE_SUBSET)
errorQuda("Cannot restrict a full field since only have single parity null-space components");

Restrict(output, input, V, fine_to_coarse, coarse_to_fine, spin_map, parity);
Restrict(output, input, V, fine_to_coarse, coarse_to_fine, spin_map, _use_mma, parity);

for (auto i = 0u; i < out.size(); i++) out[i] = output[i]; // copy result to out field (aliasing handled automatically)

Expand Down
3 changes: 3 additions & 0 deletions tests/utils/command_line_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ quda::mgarray<double> mu_factor = {};
quda::mgarray<QudaVerbosity> mg_verbosity = {};
quda::mgarray<bool> mg_setup_use_mma = {};
quda::mgarray<bool> mg_dslash_use_mma = {};
quda::mgarray<bool> mg_transfer_use_mma = {};
quda::mgarray<QudaInverterType> setup_inv = {};
quda::mgarray<QudaSolveType> coarse_solve_type = {};
quda::mgarray<QudaSolveType> smoother_solve_type = {};
Expand Down Expand Up @@ -1098,6 +1099,8 @@ void add_multigrid_option_group(std::shared_ptr<QUDAApp> quda_app)
"Whether multigrid setup should use mma (default to true when supported)");
quda_app->add_mgoption(opgroup, "--mg-dslash-use-mma", mg_dslash_use_mma, CLI::Validator(),
"Whether multigrid dslash should use mma (default to false)");
quda_app->add_mgoption(opgroup, "--mg-transfer-use-mma", mg_transfer_use_mma, CLI::Validator(),
"Whether multigrid transfer should use mma (default to false)");
quda_app->add_mgoption(opgroup, "--mg-verbosity", mg_verbosity, CLI::QUDACheckedTransformer(verbosity_map),
"The verbosity to use on each level of the multigrid (default summarize)");

Expand Down
1 change: 1 addition & 0 deletions tests/utils/command_line_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ extern quda::mgarray<double> mu_factor;
extern quda::mgarray<QudaVerbosity> mg_verbosity;
extern quda::mgarray<bool> mg_setup_use_mma;
extern quda::mgarray<bool> mg_dslash_use_mma;
extern quda::mgarray<bool> mg_transfer_use_mma;
extern quda::mgarray<QudaInverterType> setup_inv;
extern quda::mgarray<QudaSolveType> coarse_solve_type;
extern quda::mgarray<QudaSolveType> smoother_solve_type;
Expand Down
1 change: 1 addition & 0 deletions tests/utils/set_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,7 @@ void setMultigridParam(QudaMultigridParam &mg_param)
mg_param.verbosity[i] = mg_verbosity[i];
mg_param.setup_use_mma[i] = mg_setup_use_mma[i] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE;
mg_param.dslash_use_mma[i] = mg_dslash_use_mma[i] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE;
mg_param.transfer_use_mma[i] = mg_transfer_use_mma[i] ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE;
mg_param.setup_inv_type[i] = setup_inv[i];
mg_param.num_setup_iter[i] = num_setup_iter[i];
mg_param.setup_tol[i] = setup_tol[i];
Expand Down

0 comments on commit 43e7cf7

Please sign in to comment.