From 52bbc434f69d92a31e56d1aaf0ecd875de072b62 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 11 Dec 2024 18:32:02 +0000 Subject: [PATCH] Add nVec_actual to the color_spinor_fields/params - now the MMA kernels' flops are the 'real' ones. --- include/color_spinor_field.h | 4 ++++ lib/color_spinor_field.cpp | 3 +++ lib/dslash_coarse_mma.in.hpp | 6 ++---- lib/multigrid.in.hpp | 1 + lib/prolongator_mma.in.cu | 4 ++-- lib/restrictor_mma.in.cu | 4 ++-- 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h index 52106e24fa..7e98828c06 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -125,6 +125,7 @@ namespace quda int nColor = 0; // Number of colors of the field int nSpin = 0; // =1 for staggered, =2 for coarse Dslash, =4 for 4d spinor int nVec = 1; // number of packed vectors (for multigrid transfer operator) + int nVec_actual = 1; // The actual number of packed vectors (that are not zero padded) QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID; // used by twisted mass QudaSiteOrder siteOrder = QUDA_INVALID_SITE_ORDER; // defined for full fields @@ -241,6 +242,7 @@ namespace quda nColor(cpuParam.nColor), nSpin(cpuParam.nSpin), nVec(cpuParam.nVec), + nVec_actual(cpuParam.nVec_actual), twistFlavor(cpuParam.twistFlavor), siteOrder(QUDA_EVEN_ODD_SITE_ORDER), fieldOrder(QUDA_INVALID_FIELD_ORDER), @@ -318,6 +320,7 @@ namespace quda int nColor = 0; int nSpin = 0; int nVec = 0; + int nVec_actual = 0; QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID; @@ -455,6 +458,7 @@ namespace quda int Ncolor() const { return nColor; } int Nspin() const { return nSpin; } int Nvec() const { return nVec; } + int Nvec_actual() const { return nVec_actual; } QudaTwistFlavorType TwistFlavor() const { return twistFlavor; } int Ndim() const { return nDim; } const int *X() const { return x.data; } diff --git a/lib/color_spinor_field.cpp b/lib/color_spinor_field.cpp index ebd69a2ec4..4544b81fbc 
100644 --- a/lib/color_spinor_field.cpp +++ b/lib/color_spinor_field.cpp @@ -97,6 +97,7 @@ namespace quda nColor = param.nColor; nSpin = param.nSpin; nVec = param.nVec; + nVec_actual = param.nVec_actual; twistFlavor = param.twistFlavor; if (param.pc_type != QUDA_5D_PC && param.pc_type != QUDA_4D_PC) errorQuda("Unexpected pc_type %d", param.pc_type); @@ -229,6 +230,7 @@ namespace quda nColor = std::exchange(src.nColor, 0); nSpin = std::exchange(src.nSpin, 0); nVec = std::exchange(src.nVec, 0); + nVec_actual = std::exchange(src.nVec_actual, 0); twistFlavor = std::exchange(src.twistFlavor, QUDA_TWIST_INVALID); pc_type = std::exchange(src.pc_type, QUDA_PC_INVALID); suggested_parity = std::exchange(src.suggested_parity, QUDA_INVALID_PARITY); @@ -519,6 +521,7 @@ namespace quda param.nColor = nColor; param.nSpin = nSpin; param.nVec = nVec; + param.nVec_actual = nVec_actual; param.twistFlavor = twistFlavor; param.fieldOrder = fieldOrder; param.setPrecision(precision, ghost_precision); // intentionally called here and not in LatticeField diff --git a/lib/dslash_coarse_mma.in.hpp b/lib/dslash_coarse_mma.in.hpp index 09e29f704a..5409880b0e 100644 --- a/lib/dslash_coarse_mma.in.hpp +++ b/lib/dslash_coarse_mma.in.hpp @@ -58,7 +58,7 @@ namespace quda long long flops() const { return ((dslash * 2 * nDim + clover * 1) * (8 * Ns * Nc * Ns * Nc) - 2 * Ns * Nc) * nParity - * static_cast(out.VolumeCB()) * out.size() * out[0].Nvec(); + * static_cast(out.VolumeCB()) * out.size() * out[0].Nvec_actual(); } long long bytes() const @@ -146,9 +146,7 @@ namespace quda if (dslash) { strcat(aux, ",dslash"); } if (clover) { strcat(aux, ",clover"); } - strcat(aux, ",n_rhs="); - char rhs_str[16]; - i32toa(rhs_str, out[0].Nvec()); + setRHSstring(aux, out[0].Nvec_actual()); strcat(aux, rhs_str); #ifdef USE_TENSOR_MEMORY_ACCELERATOR strcat(aux, ",use_tma"); diff --git a/lib/multigrid.in.hpp b/lib/multigrid.in.hpp index 55ac01280c..0b7fcc09f3 100644 --- a/lib/multigrid.in.hpp +++ 
b/lib/multigrid.in.hpp @@ -33,6 +33,7 @@ namespace quda { template auto create_color_spinor_copy(cvector_ref &fs, QudaFieldOrder order) { ColorSpinorParam param(fs[0]); + param.nVec_actual = fs.size(); int nVec = round_to_nearest_instantiated_nVec(fs.size()); param.nColor = fs[0].Ncolor() * nVec; param.nVec = nVec; diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index 9f0c52efb6..bbdfc76875 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -76,7 +76,7 @@ namespace quda strcat(vol, out.VolString().c_str()); strcat(aux, ","); strcat(aux, out.AuxString().c_str()); - setRHSstring(aux, in.Nvec()); + setRHSstring(aux, out.Nvec_actual()); strcat(aux, mma_t::get_type_name().c_str()); @@ -85,7 +85,7 @@ namespace quda long long flops() const { - return nVec * 8 * fineSpin * fineColor * coarseColor * out.SiteSubset() * out.VolumeCB(); + return out.Nvec_actual() * 8 * fineSpin * fineColor * coarseColor * out.SiteSubset() * out.VolumeCB(); } long long bytes() const diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index b01333d64a..45671e6f19 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -80,7 +80,7 @@ namespace quda strcat(vol, out.VolString().c_str()); strcat(aux, ","); strcat(aux, out.AuxString().c_str()); - setRHSstring(aux, in.Nvec()); + setRHSstring(aux, out.Nvec_actual()); strcat(aux, mma_t::get_type_name().c_str()); strcat(aux, ",aggregate_size_block_max="); @@ -89,7 +89,7 @@ namespace quda apply(device::get_default_stream()); } - long long flops() const { return nVec * 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * in.VolumeCB(); } + long long flops() const { return out.Nvec_actual() * 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * in.VolumeCB(); } long long bytes() const {