Skip to content

Commit

Permalink
Add nVec_actual to the color_spinor_fields/params - now the MMA kerne…
Browse files Browse the repository at this point in the history
…ls' flops are the 'real' ones.
  • Loading branch information
hummingtree committed Dec 11, 2024
1 parent d94156a commit 52bbc43
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 8 deletions.
4 changes: 4 additions & 0 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ namespace quda
int nColor = 0; // Number of colors of the field
int nSpin = 0; // =1 for staggered, =2 for coarse Dslash, =4 for 4d spinor
int nVec = 1; // number of packed vectors (for multigrid transfer operator)
int nVec_actual = 1; // The actual number of packed vectors (that are not zero padded)

QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID; // used by twisted mass
QudaSiteOrder siteOrder = QUDA_INVALID_SITE_ORDER; // defined for full fields
Expand Down Expand Up @@ -241,6 +242,7 @@ namespace quda
nColor(cpuParam.nColor),
nSpin(cpuParam.nSpin),
nVec(cpuParam.nVec),
nVec_actual(cpuParam.nVec_actual),
twistFlavor(cpuParam.twistFlavor),
siteOrder(QUDA_EVEN_ODD_SITE_ORDER),
fieldOrder(QUDA_INVALID_FIELD_ORDER),
Expand Down Expand Up @@ -318,6 +320,7 @@ namespace quda
int nColor = 0;
int nSpin = 0;
int nVec = 0;
int nVec_actual = 0;

QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID;

Expand Down Expand Up @@ -455,6 +458,7 @@ namespace quda
int Ncolor() const { return nColor; }
int Nspin() const { return nSpin; }
int Nvec() const { return nVec; }
int Nvec_actual() const { return nVec_actual; }
QudaTwistFlavorType TwistFlavor() const { return twistFlavor; }
int Ndim() const { return nDim; }
const int *X() const { return x.data; }
Expand Down
3 changes: 3 additions & 0 deletions lib/color_spinor_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ namespace quda
nColor = param.nColor;
nSpin = param.nSpin;
nVec = param.nVec;
nVec_actual = param.nVec_actual;
twistFlavor = param.twistFlavor;

if (param.pc_type != QUDA_5D_PC && param.pc_type != QUDA_4D_PC) errorQuda("Unexpected pc_type %d", param.pc_type);
Expand Down Expand Up @@ -229,6 +230,7 @@ namespace quda
nColor = std::exchange(src.nColor, 0);
nSpin = std::exchange(src.nSpin, 0);
nVec = std::exchange(src.nVec, 0);
nVec_actual = std::exchange(src.nVec_actual, 0);
twistFlavor = std::exchange(src.twistFlavor, QUDA_TWIST_INVALID);
pc_type = std::exchange(src.pc_type, QUDA_PC_INVALID);
suggested_parity = std::exchange(src.suggested_parity, QUDA_INVALID_PARITY);
Expand Down Expand Up @@ -519,6 +521,7 @@ namespace quda
param.nColor = nColor;
param.nSpin = nSpin;
param.nVec = nVec;
param.nVec_actual = nVec_actual;
param.twistFlavor = twistFlavor;
param.fieldOrder = fieldOrder;
param.setPrecision(precision, ghost_precision); // intentionally called here and not in LatticeField
Expand Down
6 changes: 2 additions & 4 deletions lib/dslash_coarse_mma.in.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace quda
long long flops() const
{
return ((dslash * 2 * nDim + clover * 1) * (8 * Ns * Nc * Ns * Nc) - 2 * Ns * Nc) * nParity
* static_cast<long long>(out.VolumeCB()) * out.size() * out[0].Nvec();
* static_cast<long long>(out.VolumeCB()) * out.size() * out[0].Nvec_actual();
}

long long bytes() const
Expand Down Expand Up @@ -146,9 +146,7 @@ namespace quda
if (dslash) { strcat(aux, ",dslash"); }
if (clover) { strcat(aux, ",clover"); }

strcat(aux, ",n_rhs=");
char rhs_str[16];
i32toa(rhs_str, out[0].Nvec());
setRHSstring(aux, out[0].Nvec_actual());
strcat(aux, rhs_str);
#ifdef USE_TENSOR_MEMORY_ACCELERATOR
strcat(aux, ",use_tma");
Expand Down
1 change: 1 addition & 0 deletions lib/multigrid.in.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace quda {
template <class F> auto create_color_spinor_copy(cvector_ref<F> &fs, QudaFieldOrder order)
{
ColorSpinorParam param(fs[0]);
param.nVec_actual = fs.size();
int nVec = round_to_nearest_instantiated_nVec(fs.size());
param.nColor = fs[0].Ncolor() * nVec;
param.nVec = nVec;
Expand Down
4 changes: 2 additions & 2 deletions lib/prolongator_mma.in.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ namespace quda
strcat(vol, out.VolString().c_str());
strcat(aux, ",");
strcat(aux, out.AuxString().c_str());
setRHSstring(aux, in.Nvec());
setRHSstring(aux, out.Nvec_actual());

strcat(aux, mma_t::get_type_name().c_str());

Expand All @@ -85,7 +85,7 @@ namespace quda

long long flops() const
{
return nVec * 8 * fineSpin * fineColor * coarseColor * out.SiteSubset() * out.VolumeCB();
return out.Nvec_actual() * 8 * fineSpin * fineColor * coarseColor * out.SiteSubset() * out.VolumeCB();
}

long long bytes() const
Expand Down
4 changes: 2 additions & 2 deletions lib/restrictor_mma.in.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ namespace quda
strcat(vol, out.VolString().c_str());
strcat(aux, ",");
strcat(aux, out.AuxString().c_str());
setRHSstring(aux, in.Nvec());
setRHSstring(aux, out.Nvec_actual());

strcat(aux, mma_t::get_type_name().c_str());
strcat(aux, ",aggregate_size_block_max=");
Expand All @@ -89,7 +89,7 @@ namespace quda
apply(device::get_default_stream());
}

long long flops() const { return nVec * 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * in.VolumeCB(); }
long long flops() const { return out.Nvec_actual() * 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * in.VolumeCB(); }

long long bytes() const
{
Expand Down

0 comments on commit 52bbc43

Please sign in to comment.