From 88a108e255d8260bb9593ccbcfb0e13e5beff0cd Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 27 Aug 2024 12:20:25 -0700 Subject: [PATCH 01/79] Add MMA version of prolongator. --- include/kernels/prolongator_mma.cuh | 191 +++++++++++++++++++++++++ include/transfer.h | 4 + lib/CMakeLists.txt | 4 + lib/prolongator.in.cpp | 86 +++++++++++- lib/prolongator_mma.in.cu | 211 ++++++++++++++++++++++++++++ 5 files changed, 489 insertions(+), 7 deletions(-) create mode 100644 include/kernels/prolongator_mma.cuh create mode 100644 lib/prolongator_mma.in.cu diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh new file mode 100644 index 0000000000..df16df6a99 --- /dev/null +++ b/include/kernels/prolongator_mma.cuh @@ -0,0 +1,191 @@ +#include +#include +#include +#include +#include + +namespace quda +{ + + /** + Kernel argument struct + */ + template + struct ProlongateMmaArg : kernel_param<> { + + static constexpr int block_dim = block_z_ * block_y_; + static constexpr int min_blocks = 1; + + using Float = Float_; + using vFloat = vFloat_; + using real = Float; + using mma_t = mma_t_; + static constexpr int nVec = nVec_; + static constexpr int fineSpin = fineSpin_; + static constexpr int coarseSpin = coarseSpin_; + static constexpr int fineColor = fineColor_; + static constexpr int coarseColor = coarseColor_; + static constexpr bool to_non_rel = to_non_rel_; + + static constexpr int bN = bN_; + static constexpr int bM = bM_; + static constexpr int bK = bK_; + static constexpr int block_y = block_y_; + static constexpr int block_z = block_z_; + + static constexpr QudaFieldOrder csOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; + + // disable ghost to reduce arg size + using in_accessor_t = + typename colorspinor::FieldOrderCB; + using out_accessor_t = + typename colorspinor::FieldOrderCB; + using v_accessor_t = + typename colorspinor::FieldOrderCB; + + out_accessor_t out; + const in_accessor_t in; + const v_accessor_t v; + const int *geo_map; // need to make a device copy of this + const spin_mapper spin_map; + const int parity; // the parity of the output field (if single parity) + const int nParity; // number of parities of input fine field + + ProlongateMmaArg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, const int *geo_map, + const int parity) : + kernel_param(dim3(out.VolumeCB() * out.SiteSubset() * out.Nspin(), block_y, block_z)), + out(out), + in(in), + v(v), + geo_map(geo_map), + spin_map(), + parity(parity), + nParity(out.SiteSubset()) + { + if (out.Nvec() > get_max_multi_rhs()) + errorQuda("vector set size %d greater than max size %d", out.Nvec(), get_max_multi_rhs()); + if (out.Nvec() != nVec) { errorQuda("out.Nvec() (%d) != nVec (%d)", out.Nvec(), nVec); } + if (in.Nvec() != nVec) { errorQuda("in.Nvec() (%d) != nVec (%d)", in.Nvec(), nVec); } + } + }; + + /** + Applies the grid prolongation operator (coarse to fine) + */ + template + __device__ inline void prolongate_mma(int parity, int x_cb, int spin, int m_offset, int n_offset, const Arg &arg) + { + int x = parity * arg.out.VolumeCB() + x_cb; + int x_coarse = arg.geo_map[x]; + int parity_coarse = (x_coarse >= arg.in.VolumeCB()) ? 1 : 0; + int x_coarse_cb = x_coarse - parity_coarse * arg.in.VolumeCB(); + int spinor_parity = (arg.nParity == 2) ? parity : 0; + int v_parity = (arg.v.Nparity() == 2) ? 
parity : 0; + + constexpr int M = Arg::fineColor; + constexpr int N = Arg::nVec; + constexpr int K = Arg::coarseColor; + + constexpr int lda = K; + constexpr int ldb = N; + constexpr int ldc = N; + + using mma_t = typename Arg::mma_t; + using Config = mma::MmaConfig; + + static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); + static_assert(N % Arg::bN == 0, "N %% Arg::bN != 0.\n"); + static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); + + extern __shared__ typename mma_t::compute_t smem_ptr[]; + + typename Config::SmemObjA smem_obj_a_real(smem_ptr); + typename Config::SmemObjA smem_obj_a_imag(smem_obj_a_real.ptr + Config::smem_lda * Arg::bK); + typename Config::SmemObjB smem_obj_b_real(smem_obj_a_imag.ptr + Config::smem_lda * Arg::bK); + typename Config::SmemObjB smem_obj_b_imag(smem_obj_b_real.ptr + Config::smem_ldb * Arg::bK); + + using store_a_t = complex; + using store_b_t = complex; + store_a_t *smem_tmp_a = reinterpret_cast(smem_obj_b_imag.ptr + Config::smem_ldb * Arg::bK); + store_b_t *smem_tmp_b = reinterpret_cast(smem_tmp_a + (Arg::bK + 4) * (Arg::bM + 4)); + + pipeline_t pipe = make_pipeline(); + + typename Config::ALoader a_loader; + typename Config::BLoader b_loader; + + typename Config::Accumulator accumulator((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x); + + auto producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { + auto a = arg.v(v_parity, x_cb, spin, 0, 0); + auto b = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(spin, parity), 0, 0); + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + __syncthreads(); + pipe.producer_acquire(); + scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); + scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); + pipe.producer_commit(); + }; + + auto consumer = [&](float scale_inv_a, float scale_inv_b) { + constexpr bool a_dagger = false; + constexpr bool b_dagger = false; + + using a_wrapper_t = decltype(arg.v(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.in(0, 0, 0, 0, 0)); + constexpr bool a_fixed = a_wrapper_t::fixed; + constexpr bool b_fixed = b_wrapper_t::fixed; + + pipe.consumer_wait(); + __syncthreads(); + a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); + b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); + pipe.consumer_release(); + __syncthreads(); + }; + + auto compute = [&]() { accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); }; + + accumulator.zero(); + + float scale_inv_a; + float scale_inv_b; + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + producer(scale_inv_a, scale_inv_b, k_offset); + consumer(scale_inv_a, scale_inv_b); + compute(); + } + + // if constexpr (Arg::fineSpin == 4 && Arg::to_non_rel) { + // out.toNonRel(); + // out *= rsqrt(static_cast(2.0)); + // } + + auto c = arg.out(spinor_parity, x_cb, spin, 0, 0); + accumulator.template store(c, m_offset, n_offset, assign_t()); + } + + template struct ProlongatorMma { + const Arg &arg; + constexpr ProlongatorMma(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ inline void operator()() + { + int n_offset = target::block_idx().z * Arg::bN; + int m_offset = target::block_idx().y * Arg::bM; + + int parity_x_cb_spin = target::block_idx().x; + int spin = parity_x_cb_spin % Arg::fineSpin; + int parity_x_cb = parity_x_cb_spin / Arg::fineSpin; + int parity = (arg.nParity == 2) 
? parity_x_cb % 2 : arg.parity; + int x_cb = (arg.nParity == 2) ? parity_x_cb / 2 : parity_x_cb; + + prolongate_mma(parity, x_cb, spin, m_offset, n_offset, arg); + } + }; + +} // namespace quda diff --git a/include/transfer.h b/include/transfer.h index 91f388ab74..a64ae0b621 100644 --- a/include/transfer.h +++ b/include/transfer.h @@ -323,6 +323,10 @@ namespace quda { void Prolongate(cvector_ref &out, cvector_ref &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *const *spin_map, int parity = QUDA_INVALID_PARITY); + template + void ProlongateMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *const *spin_map, int parity); + /** @brief Apply the restriction operator @param[out] out Resulting coarsened field diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 0b08bb37f4..414198bd78 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -224,6 +224,10 @@ if(QUDA_MULTIGRID) list(PREPEND QUDA_CU_OBJS "prolongator_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}.cu") list(PREPEND QUDA_CU_OBJS "restrictor_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}.cu") list(PREPEND QUDA_CU_OBJS "block_orthogonalize_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}.cu") + foreach(QUDA_MULTIGRID_MRHS ${QUDA_MULTIGRID_MRHS_LIST_SEMICOLON}) + list(PREPEND QUDA_CU_OBJS "prolongator_mma_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}_nvec${QUDA_MULTIGRID_MRHS}.cu") + configure_file(prolongator_mma.in.cu "prolongator_mma_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}_nvec${QUDA_MULTIGRID_MRHS}.cu" @ONLY) + endforeach() endif() endforeach() endforeach() diff --git a/lib/prolongator.in.cpp b/lib/prolongator.in.cpp index 286c18ba11..3126ce77c4 100644 --- a/lib/prolongator.in.cpp +++ b/lib/prolongator.in.cpp @@ -1,4 +1,5 @@ #include "multigrid.h" +#include namespace quda { @@ -6,26 +7,92 @@ namespace quda template struct IntList { }; - template + template + void ProlongateMma2(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *const *spin_map, int parity, IntList) { + if (out.Nvec() == nVec) { + ProlongateMma(out, in, v, fine_to_coarse, spin_map, parity); + } else { + if constexpr (sizeof...(N) > 0) { + ProlongateMma2(out, in, v, fine_to_coarse, spin_map, parity, IntList()); + } else { + errorQuda("nVec = %d has not been instantiated", out.Nvec()); + } + } + } + + template auto create_color_spinor_copy(cvector_ref &fs, QudaFieldOrder order) + { + ColorSpinorParam param(fs[0]); + int nVec = (fs.size() + 7) / 8 * 8; // Make a multiple of 8 + param.nColor = fs[0].Ncolor() * nVec; + param.nVec = nVec; + param.create = QUDA_NULL_FIELD_CREATE; + param.fieldOrder = order; + return ColorSpinorField(param); + } + + auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order) + { + ColorSpinorParam param(f); + param.create = QUDA_NULL_FIELD_CREATE; + param.fieldOrder = order; + return ColorSpinorField(param); + } + + template void Prolongate2(cvector_ref &out, cvector_ref &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *const *spin_map, int parity, IntList) { if (in[0].Ncolor() == coarseColor) { if constexpr (coarseColor >= fineColor) { - Prolongate(out, in, v, fine_to_coarse, spin_map, parity); + if constexpr (use_mma) { + + constexpr QudaFieldOrder csOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; + ColorSpinorField v_in = create_color_spinor_copy(in, csOrder); + ColorSpinorField v_out = create_color_spinor_copy(out, 
csOrder); + ColorSpinorField V = create_color_spinor_copy(v, csOrder); + BlockTransposeForward(v_in, in); + V.copy(v); + + IntList<@QUDA_MULTIGRID_MRHS_LIST@> nvecs; + ProlongateMma2(v_out, v_in, V, fine_to_coarse, spin_map, parity, nvecs); + + BlockTransposeBackward(v_out, out); +#if 0 + std::vector v_cmp(out.size()); + for (int i = 0; i < out.size(); i++) { + ColorSpinorParam param(out[i]); + param.create = QUDA_NULL_FIELD_CREATE; + v_cmp[i] = ColorSpinorField(param); + } + auto vv_cmp = make_set(v_cmp); + Prolongate(vv_cmp, in, v, fine_to_coarse, spin_map, parity); + + blas::mxpy(out, v_cmp); + auto vn = blas::norm2(vv_cmp); + printf("n = "); + for (int i = 0; i < vn.size(); i++) { + printf("%f ", vn[i]); + } + printf("\n"); +#endif + } else { + Prolongate(out, in, v, fine_to_coarse, spin_map, parity); + } } else { errorQuda("Invalid coarseColor = %d, cannot be less than fineColor = %d", coarseColor, fineColor); } } else { if constexpr (sizeof...(N) > 0) { - Prolongate2(out, in, v, fine_to_coarse, spin_map, parity, IntList()); + Prolongate2(out, in, v, fine_to_coarse, spin_map, parity, IntList()); } else { errorQuda("Coarse Nc = %d has not been instantiated", in[0].Ncolor()); } } } - template + template void Prolongate(cvector_ref &out, cvector_ref &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *const *spin_map, int parity, IntList) { @@ -33,10 +100,10 @@ namespace quda // clang-format off IntList<@QUDA_MULTIGRID_NVEC_LIST@> coarseColors; // clang-format on - Prolongate2(out, in, v, fine_to_coarse, spin_map, parity, coarseColors); + Prolongate2(out, in, v, fine_to_coarse, spin_map, parity, coarseColors); } else { if constexpr (sizeof...(N) > 0) { - Prolongate(out, in, v, fine_to_coarse, spin_map, parity, IntList()); + Prolongate(out, in, v, fine_to_coarse, spin_map, parity, IntList()); } else { errorQuda("Fine Nc = %d has not been instantiated", out[0].Ncolor()); } @@ -54,7 +121,12 @@ namespace quda // clang-format off IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors; // clang-format on - Prolongate(out, in, v, fine_to_coarse, spin_map, parity, fineColors); + if (out[0].Ncolor() != 3) { + // use MMA + Prolongate(out, in, v, fine_to_coarse, spin_map, parity, fineColors); + } else { + Prolongate(out, in, v, fine_to_coarse, spin_map, parity, fineColors); + } } else { errorQuda("Multigrid has not been built"); } diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu new file mode 100644 index 0000000000..ba9b20cbbf --- /dev/null +++ b/lib/prolongator_mma.in.cu @@ -0,0 +1,211 @@ +#include +#include +#include +#include +#include + +namespace quda +{ + + template + class ProlongateLaunchMma : public TunableKernel + { + + ColorSpinorField &out; + const ColorSpinorField ∈ + const ColorSpinorField &V; + const int *fine_to_coarse; + int parity; + QudaFieldLocation location; + + bool checkParam(const TuneParam ¶m) const { return true; } + + unsigned int sharedBytesPerThread() const { return 0; } + + bool advanceTuneParam(TuneParam ¶m) const { return false; } + + void initTuneParam(TuneParam ¶m) const + { + param.aux.x = 0; + param.aux.y = 0; + param.aux.z = 0; + param.aux.w = 0; + set_mma_param(param); + } + + /** sets default values for when tuning is disabled */ + void defaultTuneParam(TuneParam ¶m) const + { + param.aux.x = 0; + param.aux.y = 0; + param.aux.z = 0; + param.aux.w = 0; + set_mma_param(param); + } + + public: + ProlongateLaunchMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &V, + const int *fine_to_coarse, int parity) 
: + TunableKernel(in), + out(out), + in(in), + V(V), + fine_to_coarse(fine_to_coarse), + parity(parity), + location(checkLocation(out, in, V)) + { + printf("out.Location() = %d, parity = %d\n", out.Location(), parity); + strcat(vol, ","); + strcat(vol, out.VolString().c_str()); + strcat(aux, ","); + strcat(aux, out.AuxString().c_str()); + if (out.GammaBasis() == QUDA_UKQCD_GAMMA_BASIS) strcat(aux, ",to_non_rel"); + + apply(device::get_default_stream()); + } + + using mma_t = simt::simt_t; + static constexpr int n_atom_size = nVec; + static constexpr int m_atom_size = fineColor; + static constexpr int k_atom_size = coarseColor; + + long long flops() const + { + return nVec * 8 * fineSpin * fineColor * coarseColor * out.SiteSubset() * out.VolumeCB(); + } + + long long bytes() const + { + size_t v_bytes = V.Bytes() / (V.SiteSubset() == out.SiteSubset() ? 1 : 2); + return in.Bytes() + out.Bytes() + nVec * (v_bytes + out.SiteSubset() * out.VolumeCB() * sizeof(int)); + } + + bool set_mma_param(TuneParam &tp) const + { + tp.block.x = 1; + tp.block.y = 16; + tp.block.z = 8; + + int bN = nVec; + int bM = fineColor; + int bK = coarseColor; + + tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin, fineColor / bM, nVec / bN); + tp.set_max_shared_bytes = true; + + int shared_bytes = shared_bytes_per_block(bM, bN, bK); + tp.shared_bytes = shared_bytes; + + return shared_bytes <= device::maximum_dynamic_shared_memory(); + } + + static constexpr int shared_bytes_per_block(int bM, int bN, int bK) + { + return mma::shared_memory_bytes(bM, bN, bK) + (bM + 4) * (bK + 4) * 2 * sizeof(vFloat) + + (bK + 4) * (bN + 4) * 2 * sizeof(Float); + } + + template + void launch_mma(TuneParam &tp, const qudaStream_t &stream) + { + constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); + if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { + constexpr bool to_non_rel = false; + using Arg = ProlongateMmaArg; + Arg arg(out, in, V, fine_to_coarse, parity); + tp.set_max_shared_bytes = true; + launch_cuda(tp, stream, arg); + } else { + errorQuda("Using too many shared memory bytes per block: %d", shared_bytes); + } + } + + void apply(const qudaStream_t &stream) + { + constexpr int block_y = 16; + constexpr int block_z = 8; + constexpr int bN = nVec; + constexpr int bM = fineColor; + constexpr int bK = coarseColor; + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + launch_mma(tp, stream); + } + }; + + template + void ProlongateMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *const *spin_map, int parity) + { + if (in.Nspin() != 2) errorQuda("Coarse spin %d is not supported", in.Nspin()); + constexpr int coarseSpin = 2; + + // first check that the spin_map matches the spin_mapper + spin_mapper mapper; + for (int s = 0; s < fineSpin; s++) + for (int p = 0; p < 2; p++) + if (mapper(s, p) != spin_map[s][p]) errorQuda("Spin map does not match spin_mapper"); + + if (v.Precision() == QUDA_HALF_PRECISION) { + if constexpr (is_enabled(QUDA_HALF_PRECISION)) { + ProlongateLaunchMma prolongator( + out, in, v, fine_to_coarse, parity); + } else { + errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); + } + } else if (v.Precision() == in.Precision()) { + ProlongateLaunchMma prolongator( + out, in, v, fine_to_coarse, parity); + } else { + errorQuda("Unsupported V precision %d", v.Precision()); + } + } + + template + void ProlongateMma(ColorSpinorField &out, const ColorSpinorField &in, const 
ColorSpinorField &v, + const int *fine_to_coarse, const int *const *spin_map, int parity) + { + if (!is_enabled_spin(out.Nspin())) errorQuda("nSpin %d has not been built", in.Nspin()); + + if (out.Nspin() == 2) { + ProlongateMma(out, in, v, fine_to_coarse, spin_map, parity); + } else if constexpr (fineColor == 3) { + if (out.Nspin() == 4) { + if constexpr (is_enabled_spin(4)) + ProlongateMma(out, in, v, fine_to_coarse, spin_map, parity); + } else if (out.Nspin() == 1) { + if constexpr (is_enabled_spin(1)) + ProlongateMma(out, in, v, fine_to_coarse, spin_map, parity); + } else { + errorQuda("Unsupported nSpin %d", out.Nspin()); + } + } else { + errorQuda("Unexpected spin %d and color %d combination", out.Nspin(), out.Ncolor()); + } + } + + constexpr int fineColor = @QUDA_MULTIGRID_NC_NVEC @; + constexpr int coarseColor = @QUDA_MULTIGRID_NVEC2 @; + constexpr int nVec = @QUDA_MULTIGRID_MRHS @; + + template <> + void ProlongateMma(ColorSpinorField &out, const ColorSpinorField &in, + const ColorSpinorField &v, const int *fine_to_coarse, + const int *const *spin_map, int parity) + { + if constexpr (is_enabled_multigrid() && fineColor > 3) { + QudaPrecision precision = checkPrecision(out, in); + + if (precision == QUDA_DOUBLE_PRECISION) { + errorQuda("ProlongateMma with double precision has not been enabled"); + } else if (precision == QUDA_SINGLE_PRECISION) { + ProlongateMma(out, in, v, fine_to_coarse, spin_map, parity); + } else { + errorQuda("Unsupported precision %d", out.Precision()); + } + } else { + errorQuda("Multigrid has not been built"); + } + } + +} // end namespace quda From c0cb6e082433e962d7d92e1e9600dc1b68ea20b1 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 27 Aug 2024 15:42:47 -0700 Subject: [PATCH 02/79] Dagger'ed the equation to make better use of MMA. --- include/kernels/prolongator_mma.cuh | 35 ++++++++++++++++------------- lib/prolongator_mma.in.cu | 22 ++++++++++-------- 2 files changed, 32 insertions(+), 25 deletions(-) diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh index df16df6a99..b17e973fd2 100644 --- a/include/kernels/prolongator_mma.cuh +++ b/include/kernels/prolongator_mma.cuh @@ -83,13 +83,15 @@ namespace quda int spinor_parity = (arg.nParity == 2) ? parity : 0; int v_parity = (arg.v.Nparity() == 2) ? 
parity : 0; - constexpr int M = Arg::fineColor; - constexpr int N = Arg::nVec; + // Everything is dagger'ed since coarseColor >= fineColor + + constexpr int M = Arg::nVec; + constexpr int N = Arg::fineColor; constexpr int K = Arg::coarseColor; - constexpr int lda = K; - constexpr int ldb = N; - constexpr int ldc = N; + constexpr int lda = M; + constexpr int ldb = K; + constexpr int ldc = M; using mma_t = typename Arg::mma_t; using Config = mma::MmaConfig; @@ -105,8 +107,8 @@ namespace quda typename Config::SmemObjB smem_obj_b_real(smem_obj_a_imag.ptr + Config::smem_lda * Arg::bK); typename Config::SmemObjB smem_obj_b_imag(smem_obj_b_real.ptr + Config::smem_ldb * Arg::bK); - using store_a_t = complex; - using store_b_t = complex; + using store_a_t = complex; + using store_b_t = complex; store_a_t *smem_tmp_a = reinterpret_cast(smem_obj_b_imag.ptr + Config::smem_ldb * Arg::bK); store_b_t *smem_tmp_b = reinterpret_cast(smem_tmp_a + (Arg::bK + 4) * (Arg::bM + 4)); @@ -118,10 +120,10 @@ namespace quda typename Config::Accumulator accumulator((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x); auto producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { - auto a = arg.v(v_parity, x_cb, spin, 0, 0); - auto b = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(spin, parity), 0, 0); - constexpr bool a_dagger = false; - constexpr bool b_dagger = false; + auto a = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(spin, parity), 0, 0); + auto b = arg.v(v_parity, x_cb, spin, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = true; __syncthreads(); pipe.producer_acquire(); @@ -131,11 +133,11 @@ namespace quda }; auto consumer = [&](float scale_inv_a, float scale_inv_b) { - constexpr bool a_dagger = false; - constexpr bool b_dagger = false; + constexpr bool a_dagger = true; + constexpr bool b_dagger = true; - using a_wrapper_t = decltype(arg.v(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.in(0, 0, 0, 0, 0)); + using a_wrapper_t = decltype(arg.in(0, 0, 0, 0, 0)); + using b_wrapper_t = decltype(arg.v(0, 0, 0, 0, 0)); constexpr bool a_fixed = a_wrapper_t::fixed; constexpr bool b_fixed = b_wrapper_t::fixed; @@ -165,7 +167,8 @@ namespace quda // } auto c = arg.out(spinor_parity, x_cb, spin, 0, 0); - accumulator.template store(c, m_offset, n_offset, assign_t()); + constexpr bool c_dagger = true; + accumulator.template store(c, m_offset, n_offset, assign_t()); } template struct ProlongatorMma { diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index ba9b20cbbf..883613956f 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -64,7 +64,9 @@ namespace quda apply(device::get_default_stream()); } - using mma_t = simt::simt_t; + // using mma_t = simt::simt_t; + // using mma_t = smma::smma_t; // 3xTF32 + using mma_t = typename mma::smma_dispatch::type; static constexpr int n_atom_size = nVec; static constexpr int m_atom_size = fineColor; static constexpr int k_atom_size = coarseColor; @@ -86,11 +88,11 @@ namespace quda tp.block.y = 16; tp.block.z = 8; - int bN = nVec; - int bM = fineColor; + int bN = fineColor; + int bM = nVec; int bK = coarseColor; - tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin, fineColor / bM, nVec / bN); + tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin, nVec / bM, fineColor / bN); tp.set_max_shared_bytes = true; int shared_bytes = shared_bytes_per_block(bM, bN, bK); @@ -125,8 +127,8 @@ namespace quda { constexpr int block_y = 16; constexpr int block_z = 8; - constexpr int bN = 
nVec; - constexpr int bM = fineColor; + constexpr int bN = fineColor; + constexpr int bM = nVec; constexpr int bK = coarseColor; TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); launch_mma(tp, stream); @@ -184,9 +186,11 @@ namespace quda } } - constexpr int fineColor = @QUDA_MULTIGRID_NC_NVEC @; - constexpr int coarseColor = @QUDA_MULTIGRID_NVEC2 @; - constexpr int nVec = @QUDA_MULTIGRID_MRHS @; + // clang-format on + constexpr int fineColor = @QUDA_MULTIGRID_NC_NVEC@; + constexpr int coarseColor = @QUDA_MULTIGRID_NVEC2@; + constexpr int nVec = @QUDA_MULTIGRID_MRHS@; + // clang-format off template <> void ProlongateMma(ColorSpinorField &out, const ColorSpinorField &in, From dd55d92b418cbe295a447d20a627c94456b413bd Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 28 Aug 2024 13:54:03 -0700 Subject: [PATCH 03/79] Make nColor = 3 works: - Still need to add the spin factor of 2; - Still need to cover the to_non_rel = true; --- include/kernels/prolongator_mma.cuh | 52 ++----- .../cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh | 12 +- include/targets/cuda/mma_tensor_op/simt.cuh | 2 +- .../cuda/mma_tensor_op/smma_m16n8_sm80.cuh | 2 +- lib/block_transpose.in.cu | 5 +- lib/prolongator.in.cpp | 6 +- lib/prolongator_mma.in.cu | 140 +++++++++++++++--- 7 files changed, 145 insertions(+), 74 deletions(-) diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh index b17e973fd2..397ac1d5ca 100644 --- a/include/kernels/prolongator_mma.cuh +++ b/include/kernels/prolongator_mma.cuh @@ -63,7 +63,7 @@ namespace quda parity(parity), nParity(out.SiteSubset()) { - if (out.Nvec() > get_max_multi_rhs()) + if (out.Nvec() > static_cast(get_max_multi_rhs())) errorQuda("vector set size %d greater than max size %d", out.Nvec(), get_max_multi_rhs()); if (out.Nvec() != nVec) { errorQuda("out.Nvec() (%d) != nVec (%d)", out.Nvec(), nVec); } if (in.Nvec() != nVec) { errorQuda("in.Nvec() (%d) != nVec (%d)", in.Nvec(), nVec); } @@ -97,7 +97,6 @@ namespace quda using Config = mma::MmaConfig; static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); - static_assert(N % Arg::bN == 0, "N %% Arg::bN != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); extern __shared__ typename mma_t::compute_t smem_ptr[]; @@ -119,46 +118,21 @@ namespace quda typename Config::Accumulator accumulator((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x); - auto producer = [&](float &scale_inv_a, float &scale_inv_b, int k_offset) { - auto a = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(spin, parity), 0, 0); - auto b = arg.v(v_parity, x_cb, spin, 0, 0); - constexpr bool a_dagger = true; - constexpr bool b_dagger = true; - - __syncthreads(); - pipe.producer_acquire(); - scale_inv_a = a_loader.template g2tmp(a, m_offset, k_offset, smem_tmp_a, pipe); - scale_inv_b = b_loader.template g2tmp(b, n_offset, k_offset, smem_tmp_b, pipe); - pipe.producer_commit(); - }; - - auto consumer = [&](float scale_inv_a, float scale_inv_b) { - constexpr bool a_dagger = true; - constexpr bool b_dagger = true; - - using a_wrapper_t = decltype(arg.in(0, 0, 0, 0, 0)); - using b_wrapper_t = decltype(arg.v(0, 0, 0, 0, 0)); - constexpr bool a_fixed = a_wrapper_t::fixed; - constexpr bool b_fixed = b_wrapper_t::fixed; - - pipe.consumer_wait(); - __syncthreads(); - a_loader.template tmp2s(smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - b_loader.template tmp2s(smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); - pipe.consumer_release(); - __syncthreads(); - }; - - auto compute = [&]() { 
accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); }; - accumulator.zero(); - float scale_inv_a; - float scale_inv_b; + auto a = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(spin, parity), 0, 0); + auto b = arg.v(v_parity, x_cb, spin, 0, 0); + constexpr bool a_dagger = true; + constexpr bool b_dagger = true; + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { - producer(scale_inv_a, scale_inv_b, k_offset); - consumer(scale_inv_a, scale_inv_b); - compute(); + __syncthreads(); + a_loader.template g2r(a, m_offset, k_offset); + b_loader.template g2r(b, n_offset, k_offset); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + __syncthreads(); + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); } // if constexpr (Arg::fineSpin == 4 && Arg::to_non_rel) { diff --git a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh index 39d99bc93a..bc9210b3b6 100644 --- a/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh +++ b/include/targets/cuda/mma_tensor_op/hmma_m16n16k4_sm70.cuh @@ -221,15 +221,19 @@ namespace quda auto scale = cc.get_scale(); s = {f2i_round(op_c_real.reg[i * 2 + 0] * scale), f2i_round(-op_c_imag.reg[i * 2 + 0] * scale)}; - op(&ptr[(n_index + 0) * ldc + m_index], s); + if (!check_bounds || (m_index < M && (n_index + 0) < N)) { op(&ptr[(n_index + 0) * ldc + m_index], s); } + // op(&ptr[(n_index + 0) * ldc + m_index], s); s = {f2i_round(op_c_real.reg[i * 2 + 1] * scale), f2i_round(-op_c_imag.reg[i * 2 + 1] * scale)}; - op(&ptr[(n_index + 1) * ldc + m_index], s); + if (!check_bounds || (m_index < M && (n_index + 1) < N)) { op(&ptr[(n_index + 1) * ldc + m_index], s); } + // op(&ptr[(n_index + 1) * ldc + m_index], s); } else { s = {op_c_real.reg[i * 2 + 0], -op_c_imag.reg[i * 2 + 0]}; - op(&ptr[(n_index + 0) * ldc + m_index], s); + if (!check_bounds || (m_index < M && (n_index + 0) < N)) { op(&ptr[(n_index + 0) * ldc + m_index], s); } + // op(&ptr[(n_index + 0) * ldc + m_index], s); s = {op_c_real.reg[i * 2 + 1], -op_c_imag.reg[i * 2 + 1]}; - op(&ptr[(n_index + 1) * ldc + m_index], s); + if (!check_bounds || (m_index < M && (n_index + 1) < N)) { op(&ptr[(n_index + 1) * ldc + m_index], s); } + // op(&ptr[(n_index + 1) * ldc + m_index], s); } } else { using array_t = typename VectorType::type; // array; diff --git a/include/targets/cuda/mma_tensor_op/simt.cuh b/include/targets/cuda/mma_tensor_op/simt.cuh index 1e216e07cd..7b5b854eaf 100644 --- a/include/targets/cuda/mma_tensor_op/simt.cuh +++ b/include/targets/cuda/mma_tensor_op/simt.cuh @@ -173,7 +173,7 @@ namespace quda int m = m_offset + wrm.idx_m * warp_m + wm; int n = n_offset + wrm.idx_n * warp_n + wn; if constexpr (dagger) { - if (!check_bounds || (m < N && n < M)) { + if (!check_bounds || (m < M && n < N)) { if constexpr (gmem_op_t::fixed) { auto scale = cc.get_scale(); complex_t out = {f2i_round(scale * op_c_real.reg[wn * warp_m + wm]), diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh index b08d5c0b7e..fe3183c128 100644 --- a/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh +++ b/include/targets/cuda/mma_tensor_op/smma_m16n8_sm80.cuh @@ -412,7 +412,7 @@ namespace quda for (int wm = 0; wm < warp_m; wm++) { int m = m_offset + wm * inst_m + (wrm.group_id + tm * 8); int n = n_offset + wn * inst_n + (wrm.thread_id_in_group * 2 + tn); - if 
(!check_bounds || (m < N && n < M)) { + if (!check_bounds || (m < M && n < N)) { int reg_index = (wn * warp_m + wm) * thread_count + tm * thread_n + tn; if constexpr (gmem_op_t::fixed) { auto scale = cc.get_scale(); diff --git a/lib/block_transpose.in.cu b/lib/block_transpose.in.cu index 505613fdb7..c29227b025 100644 --- a/lib/block_transpose.in.cu +++ b/lib/block_transpose.in.cu @@ -59,7 +59,6 @@ namespace quda { Arg arg(V, B, tp.block.x, tp.block.y); tp.set_max_shared_bytes = true; - resizeVector(tp.block.y * tp.grid.y); // We need a full threadblock launch_device(tp, stream, arg); } @@ -135,7 +134,7 @@ namespace quda if constexpr (sizeof...(N) > 0) { launch_span_nColor(V, B, nVecs); } else { - errorQuda("nColor = %d not instantiated", V.Ncolor()); + errorQuda("nColor = %d not instantiated", B.Ncolor()); } } } @@ -147,7 +146,7 @@ namespace quda errorQuda("V.Ncolor() / V.Nvec() (=%d) != B.Ncolor() (=%d)", V.Ncolor() / V.Nvec(), B[0].Ncolor()); } - IntList<@QUDA_MULTIGRID_NVEC_LIST@> nColors; + IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> nColors; launch_span_nColor(V, B, nColors); } diff --git a/lib/prolongator.in.cpp b/lib/prolongator.in.cpp index 3126ce77c4..0947a72e84 100644 --- a/lib/prolongator.in.cpp +++ b/lib/prolongator.in.cpp @@ -61,7 +61,7 @@ namespace quda BlockTransposeBackward(v_out, out); #if 0 std::vector v_cmp(out.size()); - for (int i = 0; i < out.size(); i++) { + for (size_t i = 0; i < out.size(); i++) { ColorSpinorParam param(out[i]); param.create = QUDA_NULL_FIELD_CREATE; v_cmp[i] = ColorSpinorField(param); @@ -72,7 +72,7 @@ namespace quda blas::mxpy(out, v_cmp); auto vn = blas::norm2(vv_cmp); printf("n = "); - for (int i = 0; i < vn.size(); i++) { + for (size_t i = 0; i < vn.size(); i++) { printf("%f ", vn[i]); } printf("\n"); @@ -121,7 +121,7 @@ namespace quda // clang-format off IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors; // clang-format on - if (out[0].Ncolor() != 3) { + if (1) { // use MMA Prolongate(out, in, v, fine_to_coarse, spin_map, parity, fineColors); } else { diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index 883613956f..1eba273ed1 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -3,6 +3,7 @@ #include #include #include +#include namespace quda { @@ -22,7 +23,39 @@ namespace quda unsigned int sharedBytesPerThread() const { return 0; } - bool advanceTuneParam(TuneParam ¶m) const { return false; } + bool advanceTuneParam(TuneParam ¶m) const + { + auto advancer = [&](int &i, int limit) -> bool { + if (i < limit) { + i++; + return set_mma_param(param); + } else { + return false; + } + }; + + if (advancer(param.aux.x, 2)) { + return true; + } else { + param.aux.x = 0; + if (advancer(param.aux.y, numFactors(n / n_atom_size) - 1)) { + return true; + } else { + param.aux.y = 0; + if (advancer(param.aux.z, numFactors(m / m_atom_size) - 1)) { + return true; + } else { + param.aux.z = 0; + if (advancer(param.aux.w, numFactors(k / k_atom_size) - 1)) { + return true; + } else { + param.aux.w = 0; + return false; + } + } + } + } + } void initTuneParam(TuneParam ¶m) const { @@ -54,7 +87,6 @@ namespace quda parity(parity), location(checkLocation(out, in, V)) { - printf("out.Location() = %d, parity = %d\n", out.Location(), parity); strcat(vol, ","); strcat(vol, out.VolString().c_str()); strcat(aux, ","); @@ -67,9 +99,14 @@ namespace quda // using mma_t = simt::simt_t; // using mma_t = smma::smma_t; // 3xTF32 using mma_t = typename mma::smma_dispatch::type; - static constexpr int n_atom_size = nVec; - static constexpr int 
m_atom_size = fineColor; - static constexpr int k_atom_size = coarseColor; + + static constexpr int m = nVec; + static constexpr int n = fineColor; + static constexpr int k = coarseColor; + + static constexpr int n_atom_size = mma_t::MMA_N; + static constexpr int m_atom_size = mma_t::MMA_M; + static constexpr int k_atom_size = k / 2; long long flops() const { @@ -82,31 +119,34 @@ namespace quda return in.Bytes() + out.Bytes() + nVec * (v_bytes + out.SiteSubset() * out.VolumeCB() * sizeof(int)); } + static constexpr int shared_bytes_per_block(int bM, int bN, int bK) + { + return mma::shared_memory_bytes(bM, bN, bK) + (bM + 4) * (bK + 4) * 2 * sizeof(vFloat) + + (bK + 4) * (bN + 4) * 2 * sizeof(Float); + } + bool set_mma_param(TuneParam &tp) const { + static_assert(m % m_atom_size == 0, "m modulo m_atom_size == 0"); + static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); + tp.block.x = 1; - tp.block.y = 16; + tp.block.y = k / (1 << tp.aux.x); tp.block.z = 8; - int bN = fineColor; - int bM = nVec; - int bK = coarseColor; + int bN = n_atom_size * get_int_factor_array((n + n_atom_size - 1) / n_atom_size)[tp.aux.y]; + int bM = m_atom_size * get_int_factor_array((m + m_atom_size - 1) / m_atom_size)[tp.aux.z]; - tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin, nVec / bM, fineColor / bN); + tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin, (m + bM - 1) / bM, (n + bN - 1) / bN); tp.set_max_shared_bytes = true; + int bK = k_atom_size * get_int_factor_array(k / k_atom_size)[tp.aux.w]; int shared_bytes = shared_bytes_per_block(bM, bN, bK); tp.shared_bytes = shared_bytes; return shared_bytes <= device::maximum_dynamic_shared_memory(); } - static constexpr int shared_bytes_per_block(int bM, int bN, int bK) - { - return mma::shared_memory_bytes(bM, bN, bK) + (bM + 4) * (bK + 4) * 2 * sizeof(vFloat) - + (bK + 4) * (bN + 4) * 2 * sizeof(Float); - } - template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { @@ -123,15 +163,69 @@ namespace quda } } + template + void launch_mma_span_k(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.w == d) { + constexpr IntFactorArray k_factors; + launch_mma(tp, stream); + } else { + if constexpr (sizeof...(Ds) > 0) { + launch_mma_span_k(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.z."); + } + } + } + + template + void launch_mma_span_m(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.z == d) { + constexpr IntFactorArray<(m + m_atom_size - 1) / m_atom_size> m_factors; + std::make_index_sequence().size()> k_indices; + launch_mma_span_k(tp, stream, k_indices); + } else { + if constexpr (sizeof...(Ds) > 0) { + launch_mma_span_m(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.z."); + } + } + } + + template + void launch_mma_span_n(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.y == d) { + constexpr IntFactorArray<(n + n_atom_size - 1) / n_atom_size> n_factors; + std::make_index_sequence().size()> m_indices; + launch_mma_span_m(tp, stream, m_indices); + } else { + if constexpr (sizeof...(Ds) > 0) { + launch_mma_span_n(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.y."); + } + } + } + + void launch_mma(TuneParam &tp, const qudaStream_t &stream) + { + std::make_index_sequence().size()> n_indices; + + switch (tp.aux.x) { + case 0: launch_mma_span_n(tp, stream, n_indices); break; + case 1: launch_mma_span_n(tp, stream, n_indices); break; + case 2: 
launch_mma_span_n(tp, stream, n_indices); break; + default: errorQuda("tp.aux.x = %d not supported", tp.aux.x); + } + } + void apply(const qudaStream_t &stream) { - constexpr int block_y = 16; - constexpr int block_z = 8; - constexpr int bN = fineColor; - constexpr int bM = nVec; - constexpr int bK = coarseColor; TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - launch_mma(tp, stream); + launch_mma(tp, stream); } }; @@ -197,7 +291,7 @@ namespace quda const ColorSpinorField &v, const int *fine_to_coarse, const int *const *spin_map, int parity) { - if constexpr (is_enabled_multigrid() && fineColor > 3) { + if constexpr (is_enabled_multigrid()) { QudaPrecision precision = checkPrecision(out, in); if (precision == QUDA_DOUBLE_PRECISION) { From 57b6f953ee0445e1378ad56e5893508b5cd48a81 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Thu, 29 Aug 2024 08:08:49 -0700 Subject: [PATCH 04/79] Add from_to_non_rel to block transpose to complete the circle. --- include/kernels/block_transpose.cuh | 8 ++- include/kernels/prolongator_mma.cuh | 26 +++------ include/multigrid_helper.cuh | 4 ++ include/transfer.h | 6 +- lib/block_transpose.in.cu | 90 ++++++++++++++--------------- lib/prolongator.in.cpp | 3 +- lib/prolongator_mma.in.cu | 12 ++-- 7 files changed, 77 insertions(+), 72 deletions(-) diff --git a/include/kernels/block_transpose.cuh b/include/kernels/block_transpose.cuh index 91aaec4d63..1a9d6ae949 100644 --- a/include/kernels/block_transpose.cuh +++ b/include/kernels/block_transpose.cuh @@ -12,7 +12,7 @@ namespace quda Kernel argument struct */ template + typename bAccessor, int nSpin_, int nColor_, int nVec_, bool from_to_non_rel_> struct BlockTransposeArg : kernel_param<> { using real = typename mapper::type; static constexpr bool is_device = is_device_; @@ -20,6 +20,8 @@ namespace quda static constexpr int nColor = nColor_; static constexpr int nVec = nVec_; + static constexpr int from_to_non_rel = from_to_non_rel_; + using v_t = v_t_; using b_t = b_t_; @@ -94,6 +96,10 @@ namespace quda int x = target::thread_idx().x; if (x_cb < arg.volume_cb && v + v_offset < arg.actual_nvec) { color_spinor_t color_spinor = cache.load(x, v); + if constexpr (Arg::from_to_non_rel && Arg::nSpin == 4) { + color_spinor.toNonRel(); + color_spinor *= rsqrt(static_cast(2.0)); + } #pragma unroll for (int spin = 0; spin < Arg::nSpin; spin++) { arg.B[v + v_offset](parity, x_cb, spin, color) = color_spinor(spin, 0); diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh index 397ac1d5ca..7a1b10c5f9 100644 --- a/include/kernels/prolongator_mma.cuh +++ b/include/kernels/prolongator_mma.cuh @@ -44,6 +44,8 @@ namespace quda using v_accessor_t = typename colorspinor::FieldOrderCB; + static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); + out_accessor_t out; const in_accessor_t in; const v_accessor_t v; @@ -86,7 +88,7 @@ namespace quda // Everything is dagger'ed since coarseColor >= fineColor constexpr int M = Arg::nVec; - constexpr int N = Arg::fineColor; + constexpr int N = Arg::fineColor * Arg::spin_block_factor; constexpr int K = Arg::coarseColor; constexpr int lda = M; @@ -106,13 +108,6 @@ namespace quda typename Config::SmemObjB smem_obj_b_real(smem_obj_a_imag.ptr + Config::smem_lda * Arg::bK); typename Config::SmemObjB smem_obj_b_imag(smem_obj_b_real.ptr + Config::smem_ldb * Arg::bK); - using store_a_t = complex; - using store_b_t = complex; - store_a_t *smem_tmp_a = reinterpret_cast(smem_obj_b_imag.ptr + Config::smem_ldb * Arg::bK); - store_b_t 
*smem_tmp_b = reinterpret_cast(smem_tmp_a + (Arg::bK + 4) * (Arg::bM + 4)); - - pipeline_t pipe = make_pipeline(); - typename Config::ALoader a_loader; typename Config::BLoader b_loader; @@ -120,8 +115,8 @@ namespace quda accumulator.zero(); - auto a = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(spin, parity), 0, 0); - auto b = arg.v(v_parity, x_cb, spin, 0, 0); + auto a = arg.in(parity_coarse, x_coarse_cb, arg.spin_map(spin * Arg::spin_block_factor, parity), 0, 0); + auto b = arg.v(v_parity, x_cb, spin * Arg::spin_block_factor, 0, 0); constexpr bool a_dagger = true; constexpr bool b_dagger = true; @@ -135,12 +130,7 @@ namespace quda accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); } - // if constexpr (Arg::fineSpin == 4 && Arg::to_non_rel) { - // out.toNonRel(); - // out *= rsqrt(static_cast(2.0)); - // } - - auto c = arg.out(spinor_parity, x_cb, spin, 0, 0); + auto c = arg.out(spinor_parity, x_cb, spin * Arg::spin_block_factor, 0, 0); constexpr bool c_dagger = true; accumulator.template store(c, m_offset, n_offset, assign_t()); } @@ -156,8 +146,8 @@ namespace quda int m_offset = target::block_idx().y * Arg::bM; int parity_x_cb_spin = target::block_idx().x; - int spin = parity_x_cb_spin % Arg::fineSpin; - int parity_x_cb = parity_x_cb_spin / Arg::fineSpin; + int spin = parity_x_cb_spin % (Arg::fineSpin / Arg::spin_block_factor); + int parity_x_cb = parity_x_cb_spin / (Arg::fineSpin / Arg::spin_block_factor); int parity = (arg.nParity == 2) ? parity_x_cb % 2 : arg.parity; int x_cb = (arg.nParity == 2) ? parity_x_cb / 2 : parity_x_cb; diff --git a/include/multigrid_helper.cuh b/include/multigrid_helper.cuh index aca885e06d..0bc8d9a30a 100644 --- a/include/multigrid_helper.cuh +++ b/include/multigrid_helper.cuh @@ -11,6 +11,10 @@ namespace quda { // fineSpin == 1, coarseSpin == 2 identifies staggered fine -> coarse w/ spin. static constexpr int spin_block_size = (fineSpin == 1 && coarseSpin == 2) ? 0 : fineSpin / coarseSpin; + static constexpr int get_spin_block_factor() { + return (spin_block_size == 0) ? 
1 : spin_block_size; + } + /** Return the coarse spin coordinate from the fine spin coordinate @param s Fine spin coordinate diff --git a/include/transfer.h b/include/transfer.h index a64ae0b621..f3fd8fcbe6 100644 --- a/include/transfer.h +++ b/include/transfer.h @@ -295,8 +295,9 @@ namespace quda { - V: spatial -> spin/color -> nVec @param[out] The output V Matrix field @param[in] B input vectors + @param[in] from_non_rel whether or not transform B from non-reletivistic basis */ - void BlockTransposeForward(ColorSpinorField &V, const cvector_ref &B); + void BlockTransposeForward(ColorSpinorField &V, const cvector_ref &B, bool from_non_rel = false); /** @brief Transpose the a composite V field into B vectors: @@ -304,8 +305,9 @@ namespace quda { - V: spatial -> spin/color -> nVec @param[in] The output V Matrix field @param[out] B input vectors + @param[in] from_non_rel whether or not transform B to non-reletivistic basis */ - void BlockTransposeBackward(const ColorSpinorField &V, const cvector_ref &B); + void BlockTransposeBackward(const ColorSpinorField &V, const cvector_ref &B, bool to_non_rel = false); /** @brief Apply the prolongation operator diff --git a/lib/block_transpose.in.cu b/lib/block_transpose.in.cu index c29227b025..65c276d0bb 100644 --- a/lib/block_transpose.in.cu +++ b/lib/block_transpose.in.cu @@ -23,14 +23,15 @@ namespace quda { using real = typename mapper::type; - template - using Arg = BlockTransposeArg; + template + using Arg = BlockTransposeArg; v_t &V; cvector_ref &B; + bool from_to_non_rel; public: - BlockTranspose(v_t &V, cvector_ref &B) : TunableKernel2D(V, B.size()), V(V), B(B) + BlockTranspose(v_t &V, cvector_ref &B, bool from_to_non_rel_) : TunableKernel2D(V, B.size()), V(V), B(B), from_to_non_rel(from_to_non_rel_) { if constexpr (std::is_const_v) { strcat(aux, ",v2b"); @@ -38,6 +39,7 @@ namespace quda strcat(aux, ",b2v"); } setRHSstring(aux, B.size()); + if (from_to_non_rel) { strcat(aux, ",from_to_non_rel"); } resizeStep(1); apply(device::get_default_stream()); } @@ -57,9 +59,19 @@ namespace quda template void launch_device_(TuneParam &tp, const qudaStream_t &stream) { - Arg arg(V, B, tp.block.x, tp.block.y); - tp.set_max_shared_bytes = true; - launch_device(tp, stream, arg); + if (from_to_non_rel) { + if (nSpin == 4) { + Arg arg(V, B, tp.block.x, tp.block.y); + tp.set_max_shared_bytes = true; + launch_device(tp, stream, arg); + } else { + errorQuda("from_to_non_rel is only defined for nSpin(=%d) == 4.", nSpin); + } + } else { + Arg arg(V, B, tp.block.x, tp.block.y); + tp.set_max_shared_bytes = true; + launch_device(tp, stream, arg); + } } void apply(const qudaStream_t &stream) @@ -103,82 +115,70 @@ namespace quda }; template - void launch_span_nVec(v_t &V, cvector_ref &B, IntList) + void launch_span_nVec(v_t &V, cvector_ref &B, bool from_to_non_rel, IntList) { if (V.Nvec() == nVec) { - impl::BlockTranspose transpose(V, B); + impl::BlockTranspose transpose(V, B, from_to_non_rel); } else { - IntList nVecs; + IntList nVecs_remaining; if constexpr (sizeof...(N) > 0) { - launch_span_nVec(V, B, nVecs); + launch_span_nVec(V, B, from_to_non_rel, nVecs_remaining); } else { errorQuda("nVec = %d not instantiated\n", V.Nvec()); } } } - template - void block_transpose(v_t &V, cvector_ref &B) - { - IntList<@QUDA_MULTIGRID_MRHS_LIST@> nVecs; - launch_span_nVec(V, B, nVecs); - } - template - void launch_span_nColor(v_t &V, cvector_ref &B, IntList) + void launch_span_nColor(v_t &V, cvector_ref &B, bool from_to_non_rel, IntList) { if (B[0].Ncolor() == nColor) { - 
block_transpose(V, B); + IntList<@QUDA_MULTIGRID_MRHS_LIST@> nVecs; + launch_span_nVec(V, B, from_to_non_rel, nVecs); } else { - IntList nVecs; + IntList nColors_remaining; if constexpr (sizeof...(N) > 0) { - launch_span_nColor(V, B, nVecs); + launch_span_nColor(V, B, from_to_non_rel, nColors_remaining); } else { errorQuda("nColor = %d not instantiated", B.Ncolor()); } } } - template - void block_transpose(v_t &V, cvector_ref &B) - { - if (V.Ncolor() / V.Nvec() != B[0].Ncolor()) { - errorQuda("V.Ncolor() / V.Nvec() (=%d) != B.Ncolor() (=%d)", V.Ncolor() / V.Nvec(), B[0].Ncolor()); - } - - IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> nColors; - launch_span_nColor(V, B, nColors); - } - - template void block_transpose(v_t &V, cvector_ref &B) + template + void launch_span_nSpin(v_t &V, cvector_ref &B, bool from_to_non_rel, IntList) { if (V.Nspin() != B[0].Nspin()) { errorQuda("V.Nspin() (=%d) != B.Nspin() (=%d)", V.Nspin(), B[0].Nspin()); } - if (V.Nspin() == 2) { - block_transpose(V, B); - } else if (V.Nspin() == 4) { - block_transpose(V, B); - } else if (V.Nspin() == 1) { - block_transpose(V, B); + if (V.Nspin() == nSpin) { + IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> nColors; + launch_span_nColor(V, B, from_to_non_rel, nColors); } else { - errorQuda("Unexpected nSpin = %d", V.Nspin()); + if constexpr (sizeof...(N) > 0) { + IntList nSpins_remaining; + launch_span_nSpin(V, B, from_to_non_rel, nSpins_remaining); + } else { + errorQuda("Unexpected nSpin = %d", V.Nspin()); + } } } - template void block_transpose(v_t &V, cvector_ref &B) + template void block_transpose(v_t &V, cvector_ref &B, bool from_to_non_rel) { if (!is_enabled(V.Precision()) || !is_enabled(B[0].Precision())) errorQuda("QUDA_PRECISION=%d does not enable required precision combination (V = %d B = %d)", QUDA_PRECISION, V.Precision(), B[0].Precision()); + IntList<1, 2, 4> nSpins; + if constexpr (is_enabled_multigrid()) { if (V.Precision() == QUDA_DOUBLE_PRECISION && B[0].Precision() == QUDA_DOUBLE_PRECISION) { if constexpr (is_enabled_multigrid_double()) - block_transpose(V, B); + launch_span_nSpin(V, B, from_to_non_rel, nSpins); else errorQuda("Double precision multigrid has not been enabled"); } else if (V.Precision() == QUDA_SINGLE_PRECISION && B[0].Precision() == QUDA_SINGLE_PRECISION) { - if constexpr (is_enabled(QUDA_SINGLE_PRECISION)) block_transpose(V, B); + if constexpr (is_enabled(QUDA_SINGLE_PRECISION)) launch_span_nSpin(V, B, from_to_non_rel, nSpins); } else { errorQuda("Unsupported precision combination V=%d B=%d", V.Precision(), B[0].Precision()); } @@ -187,8 +187,8 @@ namespace quda } } - void BlockTransposeForward(ColorSpinorField &V, cvector_ref &B) { block_transpose(V, B); } + void BlockTransposeForward(ColorSpinorField &V, cvector_ref &B, bool from_non_rel) { block_transpose(V, B, from_non_rel); } - void BlockTransposeBackward(const ColorSpinorField &V, cvector_ref &B) { block_transpose(V, B); } + void BlockTransposeBackward(const ColorSpinorField &V, cvector_ref &B, bool to_non_rel) { block_transpose(V, B, to_non_rel); } } // namespace quda diff --git a/lib/prolongator.in.cpp b/lib/prolongator.in.cpp index 0947a72e84..ff96b1da47 100644 --- a/lib/prolongator.in.cpp +++ b/lib/prolongator.in.cpp @@ -58,7 +58,8 @@ namespace quda IntList<@QUDA_MULTIGRID_MRHS_LIST@> nvecs; ProlongateMma2(v_out, v_in, V, fine_to_coarse, spin_map, parity, nvecs); - BlockTransposeBackward(v_out, out); + bool to_non_rel = (out.Nspin() == 4) && (out[0].GammaBasis() == QUDA_UKQCD_GAMMA_BASIS); + BlockTransposeBackward(v_out, out, to_non_rel); #if 0 
std::vector v_cmp(out.size()); for (size_t i = 0; i < out.size(); i++) { diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index 1eba273ed1..4ee78ca55c 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -91,7 +91,8 @@ namespace quda strcat(vol, out.VolString().c_str()); strcat(aux, ","); strcat(aux, out.AuxString().c_str()); - if (out.GammaBasis() == QUDA_UKQCD_GAMMA_BASIS) strcat(aux, ",to_non_rel"); + + strcat(aux, mma_t::get_type_name().c_str()); apply(device::get_default_stream()); } @@ -100,8 +101,10 @@ namespace quda // using mma_t = smma::smma_t; // 3xTF32 using mma_t = typename mma::smma_dispatch::type; + static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); + static constexpr int m = nVec; - static constexpr int n = fineColor; + static constexpr int n = fineColor * spin_block_factor; static constexpr int k = coarseColor; static constexpr int n_atom_size = mma_t::MMA_N; @@ -121,8 +124,7 @@ namespace quda static constexpr int shared_bytes_per_block(int bM, int bN, int bK) { - return mma::shared_memory_bytes(bM, bN, bK) + (bM + 4) * (bK + 4) * 2 * sizeof(vFloat) - + (bK + 4) * (bN + 4) * 2 * sizeof(Float); + return mma::shared_memory_bytes(bM, bN, bK); } bool set_mma_param(TuneParam &tp) const @@ -137,7 +139,7 @@ namespace quda int bN = n_atom_size * get_int_factor_array((n + n_atom_size - 1) / n_atom_size)[tp.aux.y]; int bM = m_atom_size * get_int_factor_array((m + m_atom_size - 1) / m_atom_size)[tp.aux.z]; - tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin, (m + bM - 1) / bM, (n + bN - 1) / bN); + tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin / spin_block_factor, (m + bM - 1) / bM, (n + bN - 1) / bN); tp.set_max_shared_bytes = true; int bK = k_atom_size * get_int_factor_array(k / k_atom_size)[tp.aux.w]; From efa9c15dc99fad7a4b570b4835fcde82f9f2485f Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Fri, 30 Aug 2024 07:40:16 -0700 Subject: [PATCH 05/79] More cleanup of the MMA code. Apply vector gmem loads when possible. --- include/targets/cuda/mma_tensor_op/gemm.cuh | 436 +----------- .../cuda/mma_tensor_op/gmem_loader.cuh | 621 ++++++++++++++++++ lib/prolongator_mma.in.cu | 39 +- 3 files changed, 647 insertions(+), 449 deletions(-) create mode 100644 include/targets/cuda/mma_tensor_op/gmem_loader.cuh diff --git a/include/targets/cuda/mma_tensor_op/gemm.cuh b/include/targets/cuda/mma_tensor_op/gemm.cuh index 334675d495..4662011576 100644 --- a/include/targets/cuda/mma_tensor_op/gemm.cuh +++ b/include/targets/cuda/mma_tensor_op/gemm.cuh @@ -3,11 +3,9 @@ #include #include #include -#include +#include #include -#include - namespace quda { namespace mma @@ -18,178 +16,6 @@ namespace quda return (bM + mma_t::pad_size(bM) + bN + mma_t::pad_size(bN)) * bK * 2 * sizeof(typename mma_t::compute_t); } - /** - @brief Defining how many elements/atoms are there in type T ... - */ - template struct batch_multiple { - }; - - /** - @brief ... e.g. there are 2 half's in a half2 - */ - template <> struct batch_multiple { - static constexpr int value = 2; - }; - - template <> struct batch_multiple { - static constexpr int value = 1; - }; - - inline __device__ void zero(half2 ®_real, half2 ®_imag) - { - reg_real = __half2half2(0); - reg_imag = __half2half2(0); - } - - inline __device__ void zero(float ®_real, float ®_imag) - { - reg_real = 0; - reg_imag = 0; - } - - inline __device__ float abs_max(float a, float max) { return fmaxf(fabsf(a), max); } - - /** - @brief Load from global memory and store data in registers. 
- */ - template - inline __device__ void convert_x(half2 ®_real, half2 ®_imag, complex *p, int m_idx, int n_idx, - float scale_inv) - { - if (x) { - auto xx = p[(m_idx + 0) * ld + n_idx]; - auto yy = p[(m_idx + 1) * ld + n_idx]; - - if (fixed) { - reg_real = __floats2half2_rn(scale_inv * xx.real(), scale_inv * yy.real()); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag = __floats2half2_rn(scale_inv_conj * xx.imag(), scale_inv_conj * yy.imag()); - } else { - reg_real = __floats2half2_rn(+xx.real(), +yy.real()); - reg_imag = __floats2half2_rn(dagger ? -xx.imag() : +xx.imag(), dagger ? -yy.imag() : +yy.imag()); - } - } else { - using store_type = T; - using store_array = typename VectorType::type; - store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); - - if (fixed) { - reg_real = __floats2half2_rn(scale_inv * v.x, scale_inv * v.z); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag = __floats2half2_rn(scale_inv_conj * v.y, scale_inv_conj * v.w); - } else { - reg_real = __floats2half2_rn(+v.x, +v.z); - reg_imag = __floats2half2_rn(dagger ? -v.y : +v.y, dagger ? -v.w : +v.w); - } - } - } - - /** - @brief Load from global memory and store data in registers while also applying a rescaling - */ - template - inline __device__ void convert_x_rescale(float ®_real, float ®_imag, complex *p, int m_idx, int n_idx, - float scale_inv, float rescale) - { - if (x) { - auto xx = p[m_idx * ld + n_idx]; - - if (fixed) { - reg_real = scale_inv * xx.real() * rescale; - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag = scale_inv_conj * xx.imag() * rescale; - } else { - reg_real = +xx.real() * rescale; - reg_imag = (dagger ? -xx.imag() : +xx.imag()) * rescale; - } - } else { - auto xx = p[n_idx * ld + m_idx]; - using store_type = T; - using store_array = typename VectorType::type; - store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); - - if (fixed) { - reg_real = scale_inv * xx.real() * rescale; - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag = scale_inv_conj * xx.imag() * rescale; - } else { - reg_real = xx.real() * rescale; - reg_imag = (dagger ? -xx.imag() : xx.imag()) * rescale; - } - } - } - - /** - @brief Load from global memory and store data in registers. - */ - template - inline __device__ float find_abs_max(complex *p, int m_idx, int n_idx, float scale_inv) - { - float this_max = 0.0f; - - if (x) { - auto xx = p[m_idx * ld + n_idx]; - - if (fixed) { - this_max = abs_max(scale_inv * xx.real(), this_max); - this_max = abs_max(scale_inv * xx.imag(), this_max); - } else { - this_max = abs_max(xx.real(), this_max); - this_max = abs_max(xx.imag(), this_max); - } - } else { - auto xx = p[n_idx * ld + m_idx]; - using store_type = T; - using store_array = typename VectorType::type; - store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); - - if (fixed) { - this_max = abs_max(scale_inv * xx.real(), this_max); - this_max = abs_max(scale_inv * xx.imag(), this_max); - } else { - this_max = abs_max(xx.real(), this_max); - this_max = abs_max(xx.imag(), this_max); - } - } - - return this_max; - } - - /** - @brief Load from global memory and store data in registers. - */ - template - inline __device__ void convert_x(float ®_real, float ®_imag, complex *p, int m_idx, int n_idx, - float scale_inv) - { - if (x) { - auto xx = p[m_idx * ld + n_idx]; - - if (fixed) { - reg_real = scale_inv * xx.real(); - auto scale_inv_conj = dagger ? 
-scale_inv : scale_inv; - reg_imag = scale_inv_conj * xx.imag(); - } else { - reg_real = +xx.real(); - reg_imag = dagger ? -xx.imag() : +xx.imag(); - } - } else { - auto xx = p[n_idx * ld + m_idx]; - using store_type = T; - using store_array = typename VectorType::type; - store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); - - if (fixed) { - reg_real = scale_inv * xx.real(); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag = scale_inv_conj * xx.imag(); - } else { - reg_real = xx.real(); - reg_imag = dagger ? -xx.imag() : xx.imag(); - } - } - } - // A shared memory object that bakes with it a 2-d index access method. template struct SharedMemoryObject { @@ -219,266 +45,6 @@ namespace quda return SharedMemoryObject {ptr_}; } - /** - * A loader object that loads data from global memory to registers (g2r), and then to shared memory (r2s) - * M, N: the global memory matrix size, for bound check only - * bM, bN: the shared memory matrix size - * block_y, block_z: CTA dimension in the y and z directions - * transpose: the global memory always assumes a column-major order, transpose = true if the destination - shared memory is row-major. - */ - template - struct GlobalMemoryLoader { - - static constexpr int batch = batch_multiple::value; - - static constexpr int m_stride_n = block_y * batch; - static constexpr int n_stride_n = block_z * 1; - static constexpr int m_stride_t = block_z * batch; - static constexpr int n_stride_t = block_y * 1; - - static constexpr int register_count - = std::max(((bN + n_stride_n - 1) / n_stride_n) * ((bM + m_stride_n - 1) / m_stride_n), - ((bN + n_stride_t - 1) / n_stride_t) * ((bM + m_stride_t - 1) / m_stride_t)); - - load_t reg_real[register_count]; - load_t reg_imag[register_count]; - - template - __device__ inline float g2tmp(const gmem_accessor_t &gmem, int m_offset, int n_offset, complex *smem_ptr, - pipeline_t &pipe) - { - auto p = gmem.data(); - - int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; - constexpr int element_per_thread = 16 / (sizeof(T) * 2); - while (thread_id * element_per_thread < bM * bN) { - if (transpose != dagger) { - int m = element_per_thread * (thread_id % (bM / element_per_thread)); - int n = thread_id / (bM / element_per_thread); - auto dst_ptr = reinterpret_cast(&smem_ptr[n * (bM + 4) + m]); - auto src_ptr = reinterpret_cast(&p[(n + n_offset) * ld + m + m_offset]); - memcpy_async(dst_ptr, src_ptr, sizeof(float4), pipe); - } else { - int m = thread_id / (bN / element_per_thread); - int n = element_per_thread * (thread_id % (bN / element_per_thread)); - auto dst_ptr = reinterpret_cast(&smem_ptr[m * (bN + 4) + n]); - auto src_ptr = reinterpret_cast(&p[(m + m_offset) * ld + n + n_offset]); - memcpy_async(dst_ptr, src_ptr, sizeof(float4), pipe); - } - thread_id += blockDim.x * blockDim.y * blockDim.z; - } - return gmem.get_scale_inv(); - } - - template - __device__ inline float tmp2s_rescale(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, - smem_accessor_t &smem_imag) - { - // for each iteration, each warp loads a tile - int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; - int warp_id = thread_id / 32; - int lane_id = thread_id % 32; - int thread_in_group = lane_id % 4; - int group_id = lane_id / 4; - constexpr int w_m = 8 * batch; - constexpr int w_k = 4; - static_assert(bM % w_m == 0, "bM %% w_m"); - static_assert(bN % w_k == 0, "bN %% w_k"); - - constexpr int tile_dim_m = bM / w_m; - constexpr int tile_dim_k = bN / w_k; - - 
constexpr int total_tiles = tile_dim_k * tile_dim_m; - constexpr int n_warp = block_y * block_z / 32; - constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; - - float thread_max = 0.0f; - -#pragma unroll - for (int c = 0; c < warp_cycle; c++) { - int logical_warp_index = c * n_warp + warp_id; - if (logical_warp_index < total_tiles) { - int warp_m = (c * n_warp + warp_id) % tile_dim_m; - int warp_k = (c * n_warp + warp_id) / tile_dim_m; - - int smem_m_offset = warp_m * w_m + group_id * batch; - int smem_k_offset = warp_k * w_k + thread_in_group; - - int gmem_m_offset = smem_m_offset; - int gmem_k_offset = smem_k_offset; - - constexpr bool x = (transpose == dagger); - float this_max = find_abs_max < x, fixed, dagger, - x ? bN + 4 : bM + 4 > (smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); - thread_max = fmaxf(this_max, thread_max); - } - } - - // block all-reduce thread_max - using block_reduce_t = cub::BlockReduce; - __shared__ typename block_reduce_t::TempStorage temp_storage; - float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); - - __shared__ float block_max_all; - if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { - if (block_max > 0.0f) { - block_max_all = block_max; - } else { - block_max_all = 1.0f; - } - } - __syncthreads(); - - float block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number - -#pragma unroll - for (int c = 0; c < warp_cycle; c++) { - int logical_warp_index = c * n_warp + warp_id; - if (logical_warp_index < total_tiles) { - int warp_m = (c * n_warp + warp_id) % tile_dim_m; - int warp_k = (c * n_warp + warp_id) / tile_dim_m; - - int smem_m_offset = warp_m * w_m + group_id * batch; - int smem_k_offset = warp_k * w_k + thread_in_group; - - int gmem_m_offset = smem_m_offset; - int gmem_k_offset = smem_k_offset; - - load_t real; - load_t imag; - - constexpr bool x = (transpose == dagger); - convert_x_rescale(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, - scale_inv, block_rescale_factor); - smem_real.vector_load(smem_m_offset, smem_k_offset, real); - smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); - } - } - - return 1.0f / block_rescale_factor; - } - - template - __device__ inline void tmp2s(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, - smem_accessor_t &smem_imag) - { - // for each iteration, each warp loads a tile - int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; - int warp_id = thread_id / 32; - int lane_id = thread_id % 32; - int thread_in_group = lane_id % 4; - int group_id = lane_id / 4; - constexpr int w_m = 8 * batch; - constexpr int w_k = 4; - static_assert(bM % w_m == 0, "bM %% w_m"); - static_assert(bN % w_k == 0, "bN %% w_k"); - - constexpr int tile_dim_m = bM / w_m; - constexpr int tile_dim_k = bN / w_k; - - constexpr int total_tiles = tile_dim_k * tile_dim_m; - constexpr int n_warp = block_y * block_z / 32; - constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; -#pragma unroll - for (int c = 0; c < warp_cycle; c++) { - int logical_warp_index = c * n_warp + warp_id; - if (logical_warp_index < total_tiles) { - int warp_m = (c * n_warp + warp_id) % tile_dim_m; - int warp_k = (c * n_warp + warp_id) / tile_dim_m; - - int smem_m_offset = warp_m * w_m + group_id * batch; - int smem_k_offset = warp_k * w_k + thread_in_group; - - int gmem_m_offset = smem_m_offset; - int gmem_k_offset = smem_k_offset; - - load_t real; - load_t imag; - - constexpr bool x = (transpose == dagger); - 
convert_x(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, - scale_inv); - smem_real.vector_load(smem_m_offset, smem_k_offset, real); - smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); - } - } - } - - /** - * ld: leading dimension of global memory - * dagger: if we need to store daggered (tranpose and hermision conjugate) - */ - template - __device__ inline void g2r(const GmemAccessor &gmem, int m_offset, int n_offset) - { - auto p = gmem.data(); - auto scale_inv = gmem.get_scale_inv(); - constexpr bool fixed = GmemAccessor::fixed; - - constexpr bool x = (transpose == dagger); - - constexpr int n_stride = x ? block_y * 1 : block_z * 1; - constexpr int m_stride = x ? block_z * batch : block_y * batch; - int n_thread_offset = x ? threadIdx.y * 1 : threadIdx.z * 1; - int m_thread_offset = x ? threadIdx.z * batch : threadIdx.y * batch; - - constexpr int n_dim = (bN + n_stride - 1) / n_stride; - constexpr int m_dim = (bM + m_stride - 1) / m_stride; - - constexpr bool check_global_bound = !(M % bM == 0 && N % bN == 0); - constexpr bool check_shared_bound = !(bM % m_stride == 0 && bN % n_stride == 0); - -#pragma unroll - for (int n = 0; n < n_dim; n++) { - -#pragma unroll - for (int m = 0; m < m_dim; m++) { - - int n_idx_blk = n * n_stride + n_thread_offset; - int m_idx_blk = m * m_stride + m_thread_offset; - - if (!check_shared_bound || (m_idx_blk < bM && n_idx_blk < bN)) { - - int n_idx = n_idx_blk + n_offset; - int m_idx = m_idx_blk + m_offset; - - if (!check_global_bound || (n_idx < N && m_idx < M)) { - convert_x(reg_real[m * n_dim + n], reg_imag[m * n_dim + n], p, m_idx, n_idx, - scale_inv); - } else { - zero(reg_real[m * n_dim + n], reg_imag[m * n_dim + n]); - } - } - } - } - } - - template __device__ inline void r2s(SmemObj &smem_real, SmemObj &smem_imag) - { - constexpr int n_stride = transpose == dagger ? block_y * 1 : block_z * 1; - constexpr int m_stride = transpose == dagger ? block_z * batch : block_y * batch; - int n_thread_offset = transpose == dagger ? threadIdx.y * 1 : threadIdx.z * 1; - int m_thread_offset = transpose == dagger ? threadIdx.z * batch : threadIdx.y * batch; - - constexpr int n_dim = (bN + n_stride - 1) / n_stride; - constexpr int m_dim = (bM + m_stride - 1) / m_stride; - -#pragma unroll - for (int n = 0; n < n_dim; n++) { -#pragma unroll - for (int m = 0; m < m_dim; m++) { - const int n_idx = n * n_stride + n_thread_offset; - const int m_idx = m * m_stride + m_thread_offset; - if (m_idx < bM && n_idx < bN) { - smem_real.vector_load(m_idx, n_idx, reg_real[m * n_dim + n]); - smem_imag.vector_load(m_idx, n_idx, reg_imag[m * n_dim + n]); - } - } - } - } - }; - /** * Perform the complex GEMM * @param m, n, k the corresponding offset in the M, N, and K direction diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh new file mode 100644 index 0000000000..61b9fe93d6 --- /dev/null +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -0,0 +1,621 @@ +#pragma once + +#include +#include + +namespace quda +{ + namespace mma + { + + /** + @brief Defining how many elements/atoms are there in type T ... + */ + template struct batch_multiple { + }; + + /** + @brief ... e.g. 
there are 2 half's in a half2 + */ + template <> struct batch_multiple { + static constexpr int value = 2; + }; + + template <> struct batch_multiple { + static constexpr int value = 1; + }; + + inline __device__ void zero(half2 ®_real, half2 ®_imag) + { + reg_real = __half2half2(0); + reg_imag = __half2half2(0); + } + + inline __device__ void zero(float ®_real, float ®_imag) + { + reg_real = 0; + reg_imag = 0; + } + + inline __device__ float abs_max(float a, float max) { return fmaxf(fabsf(a), max); } + + template struct batch_load_t { + }; + + template <> struct batch_load_t, 1> { + static void __device__ load(complex v[1], complex *ptr) { v[0] = *ptr; } + }; + + template <> struct batch_load_t, 2> { + static void __device__ load(complex v[2], complex *ptr) + { + float4 l = *reinterpret_cast(ptr); + v[0].real(l.x); + v[0].imag(l.y); + v[1].real(l.z); + v[1].imag(l.w); + } + }; + + template <> struct batch_load_t, 1> { + static void __device__ load(complex v[1], complex *ptr) { v[0] = *ptr; } + }; + + template <> struct batch_load_t, 2> { + static void __device__ load(complex v[2], complex *ptr) + { + short4 l = *reinterpret_cast(ptr); + v[0].real(l.x); + v[0].imag(l.y); + v[1].real(l.z); + v[1].imag(l.w); + } + }; + + template <> struct batch_load_t, 4> { + static void __device__ load(complex v[4], complex *ptr) + { + short8 l = *reinterpret_cast(ptr); + v[0].real(l.x.x); + v[0].imag(l.x.y); + v[1].real(l.x.z); + v[1].imag(l.x.w); + v[2].real(l.y.x); + v[2].imag(l.y.y); + v[3].real(l.y.z); + v[3].imag(l.y.w); + } + }; + + template struct make_vector_t { + }; + + template <> struct make_vector_t { + static auto __device__ get(float v[]) { return v[0]; } + }; + + template <> struct make_vector_t { + static auto __device__ get(float v[]) + { + float2 ret_value; + ret_value.x = v[0]; + ret_value.y = v[1]; + return ret_value; + } + }; + + template <> struct make_vector_t { + static auto __device__ get(float v[]) + { + float4 ret_value; + ret_value.x = v[0]; + ret_value.y = v[1]; + ret_value.z = v[2]; + ret_value.w = v[3]; + return ret_value; + } + }; + + template <> struct make_vector_t { + static auto __device__ get(half2 v[]) { return v[0]; } + }; + + /** + @brief Load from global memory and store data in registers. + */ + template + inline __device__ void convert_x(half2 reg_real[batch], half2 reg_imag[batch], complex *p, int m_idx, int n_idx, + float scale_inv) + { + static_assert(batch == 1, "for half2, for now, batch needs to be 1"); + if (x) { + auto xx = p[(m_idx + 0) * ld + n_idx]; + auto yy = p[(m_idx + 1) * ld + n_idx]; + + if (fixed) { + reg_real[0] = __floats2half2_rn(scale_inv * xx.real(), scale_inv * yy.real()); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[0] = __floats2half2_rn(scale_inv_conj * xx.imag(), scale_inv_conj * yy.imag()); + } else { + reg_real[0] = __floats2half2_rn(+xx.real(), +yy.real()); + reg_imag[0] = __floats2half2_rn(dagger ? -xx.imag() : +xx.imag(), dagger ? -yy.imag() : +yy.imag()); + } + } else { + using store_type = T; + using store_array = typename VectorType::type; + store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); + + if (fixed) { + reg_real[0] = __floats2half2_rn(scale_inv * v.x, scale_inv * v.z); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[0] = __floats2half2_rn(scale_inv_conj * v.y, scale_inv_conj * v.w); + } else { + reg_real[0] = __floats2half2_rn(+v.x, +v.z); + reg_imag[0] = __floats2half2_rn(dagger ? -v.y : +v.y, dagger ? 
-v.w : +v.w); + } + } + } + + /** + @brief Load from global memory and store data in registers. + */ + template + inline __device__ void convert_x(float reg_real[batch], float reg_imag[batch], complex *p, int m_idx, int n_idx, + float scale_inv) + { + complex v[batch]; + if constexpr (x) { + batch_load_t, batch>::load(v, &p[m_idx * ld + n_idx]); +#pragma unroll + for (int b = 0; b < batch; b++) { + // auto xx = p[m_idx * ld + n_idx]; + if (fixed) { + reg_real[b] = scale_inv * v[b].real(); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b] = scale_inv_conj * v[b].imag(); + } else { + reg_real[b] = v[b].real(); + reg_imag[b] = dagger ? -v[b].imag() : v[b].imag(); + } + } + } else { + complex v[batch]; + batch_load_t, batch>::load(v, &p[n_idx * ld + m_idx]); +#pragma unroll + for (int b = 0; b < batch; b++) { + // auto xx = p[n_idx * ld + m_idx]; + if (fixed) { + reg_real[b] = scale_inv * v[b].real(); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b] = scale_inv_conj * v[b].imag(); + } else { + reg_real[b] = v[b].real(); + reg_imag[b] = dagger ? -v[b].imag() : v[b].imag(); + } + } + } + } + + /** + @brief Load from global memory and store data in registers while also applying a rescaling + */ + template + inline __device__ void convert_x_rescale(float ®_real, float ®_imag, complex *p, int m_idx, int n_idx, + float scale_inv, float rescale) + { + if (x) { + auto xx = p[m_idx * ld + n_idx]; + + if (fixed) { + reg_real = scale_inv * xx.real() * rescale; + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag = scale_inv_conj * xx.imag() * rescale; + } else { + reg_real = +xx.real() * rescale; + reg_imag = (dagger ? -xx.imag() : +xx.imag()) * rescale; + } + } else { + auto xx = p[n_idx * ld + m_idx]; + + if (fixed) { + reg_real = scale_inv * xx.real() * rescale; + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag = scale_inv_conj * xx.imag() * rescale; + } else { + reg_real = xx.real() * rescale; + reg_imag = (dagger ? -xx.imag() : xx.imag()) * rescale; + } + } + } + + /** + @brief Load from global memory and store data in registers. + */ + template + inline __device__ float find_abs_max(complex *p, int m_idx, int n_idx, float scale_inv) + { + float this_max = 0.0f; + + if (x) { + auto xx = p[m_idx * ld + n_idx]; + + if (fixed) { + this_max = abs_max(scale_inv * xx.real(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); + } else { + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(xx.imag(), this_max); + } + } else { + auto xx = p[n_idx * ld + m_idx]; + + if (fixed) { + this_max = abs_max(scale_inv * xx.real(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); + } else { + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(xx.imag(), this_max); + } + } + + return this_max; + } + + /** + * A loader object that loads data from global memory to registers (g2r), and then to shared memory (r2s) + * M, N: the global memory matrix size, for bound check only + * bM, bN: the shared memory matrix size + * block_y, block_z: CTA dimension in the y and z directions + * transpose: the global memory always assumes a column-major order, transpose = true if the destination + shared memory is row-major. 
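+     *
+     * Two load paths are provided: g2r() followed by r2s() stages a tile through registers,
+     * while g2tmp() followed by tmp2s() (or tmp2s_rescale()) stages it through a scratch
+     * shared-memory buffer using memcpy_async with a pipeline, converting the staged data
+     * into the padded shared-memory tiles consumed by the MMA accumulator.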
+ */ + template + struct GlobalMemoryLoader { + + static constexpr int batch = batch_multiple::value; + + static constexpr int m_stride_n = block_y * batch; + static constexpr int n_stride_n = block_z * 1; + static constexpr int m_stride_t = block_z * batch; + static constexpr int n_stride_t = block_y * 1; + + static constexpr int register_count + = std::max(((bN + n_stride_n - 1) / n_stride_n) * ((bM + m_stride_n - 1) / m_stride_n), + ((bN + n_stride_t - 1) / n_stride_t) * ((bM + m_stride_t - 1) / m_stride_t)); + + load_t reg_real[register_count]; + load_t reg_imag[register_count]; + + template + __device__ inline float g2tmp(const gmem_accessor_t &gmem, int m_offset, int n_offset, complex *smem_ptr, + pipeline_t &pipe) + { + auto p = gmem.data(); + + constexpr bool check_bounds = !(M % bM == 0 && N % bN == 0); + + int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; + constexpr int element_per_thread = 16 / (sizeof(T) * 2); + while (thread_id * element_per_thread < bM * bN) { + if (transpose != dagger) { + int m = element_per_thread * (thread_id % (bM / element_per_thread)); + int n = thread_id / (bM / element_per_thread); + if (!check_bounds || (n + n_offset < N && m + m_offset < M)) { + auto dst_ptr = reinterpret_cast(&smem_ptr[n * (bM + 4) + m]); + auto src_ptr = reinterpret_cast(&p[(n + n_offset) * ld + m + m_offset]); + memcpy_async(dst_ptr, src_ptr, sizeof(float4), pipe); + } + } else { + int m = thread_id / (bN / element_per_thread); + int n = element_per_thread * (thread_id % (bN / element_per_thread)); + if (!check_bounds || (n + n_offset < N && m + m_offset < M)) { + auto dst_ptr = reinterpret_cast(&smem_ptr[m * (bN + 4) + n]); + auto src_ptr = reinterpret_cast(&p[(m + m_offset) * ld + n + n_offset]); + memcpy_async(dst_ptr, src_ptr, sizeof(float4), pipe); + } + } + thread_id += blockDim.x * blockDim.y * blockDim.z; + } + return gmem.get_scale_inv(); + } + + template + __device__ inline float tmp2s_rescale(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, + smem_accessor_t &smem_imag) + { + static_assert(batch == 1, "For now batch needs to be 1 for the rescale kernel."); + + // for each iteration, each warp loads a tile + int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; + int warp_id = thread_id / 32; + int lane_id = thread_id % 32; + int thread_in_group = lane_id % 4; + int group_id = lane_id / 4; + constexpr int w_m = 8 * batch; + constexpr int w_k = 4; + static_assert(bM % w_m == 0, "bM %% w_m"); + static_assert(bN % w_k == 0, "bN %% w_k"); + + constexpr int tile_dim_m = bM / w_m; + constexpr int tile_dim_k = bN / w_k; + + constexpr int total_tiles = tile_dim_k * tile_dim_m; + constexpr int n_warp = block_y * block_z / 32; + constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; + + float thread_max = 0.0f; + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + constexpr bool x = (transpose == dagger); + float this_max + = find_abs_max(smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); + thread_max = fmaxf(this_max, thread_max); + } + } + + // block all-reduce thread_max + using block_reduce_t = 
cub::BlockReduce; + __shared__ typename block_reduce_t::TempStorage temp_storage; + float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float block_max_all; + if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { + if (block_max > 0.0f) { + block_max_all = block_max; + } else { + block_max_all = 1.0f; + } + } + __syncthreads(); + + float block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number + +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + load_t real; + load_t imag; + + constexpr bool x = (transpose == dagger); + convert_x_rescale(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, + scale_inv, block_rescale_factor); + smem_real.vector_load(smem_m_offset, smem_k_offset, real); + smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); + } + } + + return 1.0f / block_rescale_factor; + } + + template + __device__ inline void tmp2s(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, + smem_accessor_t &smem_imag) + { + // for each iteration, each warp loads a tile + int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; + int warp_id = thread_id / 32; + int lane_id = thread_id % 32; + int thread_in_group = lane_id % 4; + int group_id = lane_id / 4; + constexpr int w_m = 8 * batch; + constexpr int w_k = 4; + static_assert(bM % w_m == 0, "bM %% w_m"); + static_assert(bN % w_k == 0, "bN %% w_k"); + + constexpr int tile_dim_m = bM / w_m; + constexpr int tile_dim_k = bN / w_k; + + constexpr int total_tiles = tile_dim_k * tile_dim_m; + constexpr int n_warp = block_y * block_z / 32; + constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; +#pragma unroll + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + load_t real; + load_t imag; + + constexpr bool x = (transpose == dagger); + convert_x(&real, &imag, smem_ptr, gmem_m_offset, gmem_k_offset, + scale_inv); + smem_real.vector_load(smem_m_offset, smem_k_offset, real); + smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); + } + } + } + + /** + * ld: leading dimension of global memory + * dagger: if we need to store daggered (tranpose and hermision conjugate) + */ + template + __device__ inline void g2r(const GmemAccessor &gmem, int m_offset, int n_offset) + { + auto p = gmem.data(); + auto scale_inv = gmem.get_scale_inv(); + constexpr bool fixed = GmemAccessor::fixed; + + constexpr bool x = (transpose == dagger); + + constexpr int n_stride = x ? block_y : block_z; + constexpr int m_stride = x ? block_z * batch : block_y * batch; + int n_thread_offset = x ? threadIdx.y : threadIdx.z; + int m_thread_offset = x ? 
threadIdx.z * batch : threadIdx.y * batch; + + constexpr int n_dim = (bN + n_stride - 1) / n_stride; + constexpr int m_dim = (bM + m_stride - 1) / m_stride; + + constexpr bool check_global_bound = !(M % bM == 0 && N % bN == 0); + constexpr bool check_shared_bound = !(bM % m_stride == 0 && bN % n_stride == 0); + + if constexpr (x) { + constexpr int n_batch = (n_dim % 2 == 0 && batch == 1) ? 2 : 1; + static_assert(bN % n_batch == 0, "bN % n_batch == 0"); +#pragma unroll + for (int n = 0; n < n_dim / n_batch; n++) { + +#pragma unroll + for (int m = 0; m < m_dim; m++) { + + int n_idx_blk = (n * n_stride + n_thread_offset) * n_batch; + int m_idx_blk = m * m_stride + m_thread_offset; + + if (!check_shared_bound || (m_idx_blk < bM && n_idx_blk < bN)) { + + int n_idx = n_idx_blk + n_offset; + int m_idx = m_idx_blk + m_offset; + + if (!check_global_bound || (n_idx < N && m_idx < M)) { + convert_x( + ®_real[m * n_dim + n * n_batch], ®_imag[m * n_dim + n * n_batch], p, m_idx, n_idx, scale_inv); + } else { +#pragma unroll + for (int b = 0; b < n_batch; b++) { + zero(reg_real[m * n_dim + n * n_batch + b], reg_imag[m * n_dim + n * n_batch + b]); + } + } + } + } + } + } else { + constexpr int m_batch = (m_dim % 2 == 0 && batch == 1) ? 2 : 1; + static_assert(bM % m_batch == 0, "bN % n_batch == 0"); +#pragma unroll + for (int n = 0; n < n_dim; n++) { + +#pragma unroll + for (int m = 0; m < m_dim / m_batch; m++) { + + int n_idx_blk = n * n_stride + n_thread_offset; + int m_idx_blk = (m * m_stride + m_thread_offset) * m_batch; + + if (!check_shared_bound || (m_idx_blk < bM && n_idx_blk < bN)) { + + int n_idx = n_idx_blk + n_offset; + int m_idx = m_idx_blk + m_offset; + + if (!check_global_bound || (n_idx < N && m_idx < M)) { + load_t v_real[m_batch]; + load_t v_imag[m_batch]; + convert_x(v_real, v_imag, p, m_idx, n_idx, scale_inv); +#pragma unroll + for (int b = 0; b < m_batch; b++) { + reg_real[(m * m_batch + b) * n_dim + n] = v_real[b]; + reg_imag[(m * m_batch + b) * n_dim + n] = v_imag[b]; + } + } else { +#pragma unroll + for (int b = 0; b < m_batch; b++) { + zero(reg_real[(m * m_batch + b) * n_dim + n], reg_imag[(m * m_batch + b) * n_dim + n]); + } + } + } + } + } + } + } + + template __device__ inline void r2s(SmemObj &smem_real, SmemObj &smem_imag) + { + constexpr bool x = (transpose == dagger); + + constexpr int n_stride = transpose == dagger ? block_y : block_z; + constexpr int m_stride = transpose == dagger ? block_z * batch : block_y * batch; + int n_thread_offset = transpose == dagger ? threadIdx.y : threadIdx.z; + int m_thread_offset = transpose == dagger ? threadIdx.z * batch : threadIdx.y * batch; + + constexpr int n_dim = (bN + n_stride - 1) / n_stride; + constexpr int m_dim = (bM + m_stride - 1) / m_stride; + + if constexpr (x) { + constexpr int n_batch = (n_dim % 2 == 0 && batch == 1) ? 
2 : 1; +#pragma unroll + for (int n = 0; n < n_dim / n_batch; n++) { +#pragma unroll + for (int m = 0; m < m_dim; m++) { + const int n_idx = (n * n_stride + n_thread_offset) * n_batch; + const int m_idx = m * m_stride + m_thread_offset; + if (m_idx < bM && n_idx < bN) { + if constexpr (SmemObj::ldn == 1) { + smem_real.vector_load(m_idx, n_idx, + make_vector_t::get(®_real[m * n_dim + n * n_batch])); + smem_imag.vector_load(m_idx, n_idx, + make_vector_t::get(®_imag[m * n_dim + n * n_batch])); + } else { +#pragma unroll + for (int b = 0; b < n_batch; b++) { + smem_real.vector_load(m_idx, n_idx + b, reg_real[m * n_dim + n * n_batch + b]); + smem_imag.vector_load(m_idx, n_idx + b, reg_imag[m * n_dim + n * n_batch + b]); + } + } + } + } + } + } else { + constexpr int m_batch = (m_dim % 2 == 0 && batch == 1) ? 2 : 1; +#pragma unroll + for (int n = 0; n < n_dim; n++) { +#pragma unroll + for (int m = 0; m < m_dim / m_batch; m++) { + const int n_idx = n * n_stride + n_thread_offset; + const int m_idx = (m * m_stride + m_thread_offset) * m_batch; + if (m_idx < bM && n_idx < bN) { + if constexpr (SmemObj::ldm == 1) { + static_assert(SmemObj::ldm == 1, "SmemObj::ldm == 1"); + load_t v_real[m_batch]; + load_t v_imag[m_batch]; +#pragma unroll + for (int b = 0; b < m_batch; b++) { + v_real[b] = reg_real[(m * m_batch + b) * n_dim + n]; + v_imag[b] = reg_imag[(m * m_batch + b) * n_dim + n]; + } + smem_real.vector_load(m_idx, n_idx, make_vector_t::get(v_real)); + smem_imag.vector_load(m_idx, n_idx, make_vector_t::get(v_imag)); + } else { +#pragma unroll + for (int b = 0; b < m_batch; b++) { + smem_real.vector_load(m_idx + b, n_idx, reg_real[(m * m_batch + b) * n_dim + n]); + smem_imag.vector_load(m_idx + b, n_idx, reg_imag[(m * m_batch + b) * n_dim + n]); + } + } + } + } + } + } + } + }; + + } // namespace mma + +} // namespace quda diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index 4ee78ca55c..6d82660257 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -34,19 +34,19 @@ namespace quda } }; - if (advancer(param.aux.x, 2)) { + if (advancer(param.aux.x, numFactors((k + block_atom_size - 1) / block_atom_size) - 1)) { return true; } else { param.aux.x = 0; - if (advancer(param.aux.y, numFactors(n / n_atom_size) - 1)) { + if (advancer(param.aux.y, numFactors((n + n_atom_size - 1) / n_atom_size) - 1)) { return true; } else { param.aux.y = 0; - if (advancer(param.aux.z, numFactors(m / m_atom_size) - 1)) { + if (advancer(param.aux.z, numFactors((m + m_atom_size - 1) / m_atom_size) - 1)) { return true; } else { param.aux.z = 0; - if (advancer(param.aux.w, numFactors(k / k_atom_size) - 1)) { + if (advancer(param.aux.w, numFactors((k + k_atom_size - 1) / k_atom_size) - 1)) { return true; } else { param.aux.w = 0; @@ -109,7 +109,8 @@ namespace quda static constexpr int n_atom_size = mma_t::MMA_N; static constexpr int m_atom_size = mma_t::MMA_M; - static constexpr int k_atom_size = k / 2; + static constexpr int k_atom_size = mma_t::MMA_K; + static constexpr int block_atom_size = 32 / 8; long long flops() const { @@ -133,7 +134,7 @@ namespace quda static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); tp.block.x = 1; - tp.block.y = k / (1 << tp.aux.x); + tp.block.y = block_atom_size * get_int_factor_array((k + block_atom_size - 1) / block_atom_size)[tp.aux.x]; tp.block.z = 8; int bN = n_atom_size * get_int_factor_array((n + n_atom_size - 1) / n_atom_size)[tp.aux.y]; @@ -212,18 +213,28 @@ namespace quda } } - void launch_mma(TuneParam &tp, const qudaStream_t &stream) 
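+    // Recursively peel the index_sequence: when the runtime tp.aux.x matches d, the
+    // corresponding block-size factor becomes a compile-time constant for the launch;
+    // otherwise recurse on the remaining indices Ds...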
+ template + void launch_mma_span_block(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) { - std::make_index_sequence().size()> n_indices; - - switch (tp.aux.x) { - case 0: launch_mma_span_n(tp, stream, n_indices); break; - case 1: launch_mma_span_n(tp, stream, n_indices); break; - case 2: launch_mma_span_n(tp, stream, n_indices); break; - default: errorQuda("tp.aux.x = %d not supported", tp.aux.x); + if (tp.aux.x == d) { + constexpr IntFactorArray<(k + block_atom_size - 1) / block_atom_size> block_factors; + std::make_index_sequence().size()> n_indices; + launch_mma_span_n(tp, stream, n_indices); + } else { + if constexpr (sizeof...(Ds) > 0) { + launch_mma_span_block(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.x."); + } } } + void launch_mma(TuneParam &tp, const qudaStream_t &stream) + { + std::make_index_sequence().size()> block_indices; + launch_mma_span_block(tp, stream, block_indices); + } + void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); From 82d532b040426fb230ece63572f674ec00774120 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Fri, 30 Aug 2024 08:50:34 -0700 Subject: [PATCH 06/79] Apply more vector gmem loads when possible. --- .../cuda/mma_tensor_op/gmem_loader.cuh | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index 61b9fe93d6..3e18e8bdcf 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -261,6 +261,14 @@ namespace quda return this_max; } + template constexpr int get_mn_batch(int internal_batch, int register_dim, int block) + { + return (internal_batch > 1) ? 1 : + ((register_dim % 4 == 0 && block % 4 == 0 && sizeof(T) * 8 <= 16) ? + 4 : + ((register_dim % 2 == 0 && block % 2 == 0 && sizeof(T) * 4 <= 16) ? 2 : 1)); + } + /** * A loader object that loads data from global memory to registers (g2r), and then to shared memory (r2s) * M, N: the global memory matrix size, for bound check only @@ -480,8 +488,7 @@ namespace quda constexpr bool check_shared_bound = !(bM % m_stride == 0 && bN % n_stride == 0); if constexpr (x) { - constexpr int n_batch = (n_dim % 2 == 0 && batch == 1) ? 2 : 1; - static_assert(bN % n_batch == 0, "bN % n_batch == 0"); + constexpr int n_batch = get_mn_batch(batch, n_dim, bN); #pragma unroll for (int n = 0; n < n_dim / n_batch; n++) { @@ -509,8 +516,7 @@ namespace quda } } } else { - constexpr int m_batch = (m_dim % 2 == 0 && batch == 1) ? 2 : 1; - static_assert(bM % m_batch == 0, "bN % n_batch == 0"); + constexpr int m_batch = get_mn_batch(batch, m_dim, bM); #pragma unroll for (int n = 0; n < n_dim; n++) { @@ -559,7 +565,7 @@ namespace quda constexpr int m_dim = (bM + m_stride - 1) / m_stride; if constexpr (x) { - constexpr int n_batch = (n_dim % 2 == 0 && batch == 1) ? 
2 : 1; + constexpr int n_batch = get_mn_batch(batch, n_dim, bN); #pragma unroll for (int n = 0; n < n_dim / n_batch; n++) { #pragma unroll @@ -567,7 +573,7 @@ namespace quda const int n_idx = (n * n_stride + n_thread_offset) * n_batch; const int m_idx = m * m_stride + m_thread_offset; if (m_idx < bM && n_idx < bN) { - if constexpr (SmemObj::ldn == 1) { + if constexpr (SmemObj::ldn == 1 && SmemObj::ldm % n_batch == 0) { smem_real.vector_load(m_idx, n_idx, make_vector_t::get(®_real[m * n_dim + n * n_batch])); smem_imag.vector_load(m_idx, n_idx, @@ -583,7 +589,7 @@ namespace quda } } } else { - constexpr int m_batch = (m_dim % 2 == 0 && batch == 1) ? 2 : 1; + constexpr int m_batch = get_mn_batch(batch, m_dim, bM); #pragma unroll for (int n = 0; n < n_dim; n++) { #pragma unroll @@ -591,7 +597,7 @@ namespace quda const int n_idx = n * n_stride + n_thread_offset; const int m_idx = (m * m_stride + m_thread_offset) * m_batch; if (m_idx < bM && n_idx < bN) { - if constexpr (SmemObj::ldm == 1) { + if constexpr (SmemObj::ldm == 1 && SmemObj::ldn % m_batch == 0) { static_assert(SmemObj::ldm == 1, "SmemObj::ldm == 1"); load_t v_real[m_batch]; load_t v_imag[m_batch]; From e6a5d347ded81c138628e65ee3444d697df50aa9 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 4 Sep 2024 08:44:49 -0700 Subject: [PATCH 07/79] Add MMA version for restrictor. --- include/color_spinor_field_order.h | 2 + include/kernels/block_transpose.cuh | 4 + include/kernels/restrictor_mma.cuh | 219 +++++++++++++++++++++++ include/transfer.h | 4 + lib/CMakeLists.txt | 2 + lib/prolongator.in.cpp | 2 +- lib/restrictor.in.cpp | 85 ++++++++- lib/restrictor_mma.in.cu | 264 ++++++++++++++++++++++++++++ 8 files changed, 574 insertions(+), 8 deletions(-) create mode 100644 include/kernels/restrictor_mma.cuh create mode 100644 lib/restrictor_mma.in.cu diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h index 25f5234390..0764d01650 100644 --- a/include/color_spinor_field_order.h +++ b/include/color_spinor_field_order.h @@ -858,6 +858,8 @@ namespace quda static constexpr int nSpin = nSpin_; static constexpr int nColor = nColor_; + using store_type = storeFloat; + field v; unsigned int volumeCB = 0; diff --git a/include/kernels/block_transpose.cuh b/include/kernels/block_transpose.cuh index 1a9d6ae949..bad990c80f 100644 --- a/include/kernels/block_transpose.cuh +++ b/include/kernels/block_transpose.cuh @@ -126,6 +126,10 @@ namespace quda int x_ = thread_idx / arg.block_y; if (x_ + x_offset < arg.volume_cb && v_ + v_offset < arg.actual_nvec) { color_spinor_t color_spinor = cache.load(x_, v_); + if constexpr (Arg::nSpin == 4 && Arg::from_to_non_rel) { + color_spinor.toRel(); + color_spinor *= rsqrt(static_cast(2.0)); + } #pragma unroll for (int spin = 0; spin < Arg::nSpin; spin++) { arg.V(parity, x_ + x_offset, spin, color, v_ + v_offset) = color_spinor(spin, 0); diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh new file mode 100644 index 0000000000..1f37350c48 --- /dev/null +++ b/include/kernels/restrictor_mma.cuh @@ -0,0 +1,219 @@ +#include +#include +#include +#include +#include + +namespace quda +{ + + using namespace quda::colorspinor; + + /** + Kernel argument struct + */ + template + struct RestrictMmaArg : kernel_param<> { + + static constexpr int block_dim = block_z_ * block_y_; + static constexpr int min_blocks = 1; + + using mma_t = mma_t_; + + using out_t = out_t_; + using real = out_t; + using in_t = in_t_; + using v_t = v_t_; + static constexpr int fineSpin = 
fineSpin_; + static constexpr int fineColor = fineColor_; + static constexpr int coarseSpin = coarseSpin_; + static constexpr int coarseColor = coarseColor_; + static constexpr int nVec = nVec_; + static constexpr int aggregate_size = aggregate_size_; + static constexpr int bN = bN_; + static constexpr int bM = bM_; + static constexpr int bK = bK_; + static constexpr int block_y = block_y_; + static constexpr int block_z = block_z_; + + static constexpr QudaFieldOrder csOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; + + // disable ghost to reduce arg size + using in_accessor_t = FieldOrderCB::value>; + using out_accessor_t = FieldOrderCB; + using v_accessor_t = FieldOrderCB; + + static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); + static_assert(bK % (fineColor * spin_block_factor) == 0, "K %% Arg::bK != 0.\n"); + + static constexpr int aggregate_per_block = bK / (fineColor * spin_block_factor); + static_assert(aggregate_size % aggregate_per_block == 0, "aggregate_size %% aggregate_per_block"); + + out_accessor_t out; + in_accessor_t in; + const v_accessor_t v; + const int_fastdiv aggregate_size_cb; // number of checkerboard sites that form a single aggregate + const int *fine_to_coarse; + const int *coarse_to_fine; + const spin_mapper spin_map; + const int parity; // the parity of the input field (if single parity) + const int nParity; // number of parities of input fine field + + RestrictMmaArg(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, int parity) : + kernel_param(dim3(out.Volume() * coarseSpin, block_y, block_z)), + out(out), + in(in), + v(v), + aggregate_size_cb(in.VolumeCB() / out.Volume()), + fine_to_coarse(fine_to_coarse), + coarse_to_fine(coarse_to_fine), + spin_map(), + parity(parity), + nParity(in.SiteSubset()) + { + if (out.Nvec() > static_cast(get_max_multi_rhs())) + errorQuda("vector set size %d greater than max size %d", out.Nvec(), get_max_multi_rhs()); + if (out.Nvec() != nVec) { errorQuda("out.Nvec() (%d) != nVec (%d)", out.Nvec(), nVec); } + if (in.Nvec() != nVec) { errorQuda("in.Nvec() (%d) != nVec (%d)", in.Nvec(), nVec); } + } + }; + + template + inline void __device__ load_g2s(smem_obj_t &smem_real, smem_obj_t &smem_imag, const gmem_obj_t &gmem, int x_coarse, + int coarse_spin, int contiguous_dim_offset, int aggregate_k_offset, const Arg &arg) + { // v as a + constexpr int elements_per_thread = 16 / (sizeof(typename gmem_obj_t::store_type) * 2); + static_assert(contiguous_dim % elements_per_thread == 0, "contiguous_dim %% elements_per_thread == 0"); + int thread = target::thread_idx().y + Arg::block_y * target::thread_idx().z; + while (thread < (contiguous_dim / elements_per_thread) * Arg::spin_block_factor * Arg::fineColor + * Arg::aggregate_per_block) { + int thread_idx = thread; + int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; + thread_idx /= (contiguous_dim / elements_per_thread); + int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin + thread_idx /= Arg::spin_block_factor; + int fine_color = thread_idx % Arg::fineColor; + thread_idx /= Arg::fineColor; + int x_fine_offset = thread_idx + aggregate_k_offset; + + const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; + const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; + const int parity = arg.nParity == 2 ? 
parity_offset : arg.parity; + + // look-up map is ordered as (coarse-block-id + fine-point-id), + // with fine-point-id parity ordered + const int x_fine_site_id = (x_coarse * 2 + parity) * arg.aggregate_size_cb + x_fine_cb_offset; + const int x_fine = arg.coarse_to_fine[x_fine_site_id]; + const int x_fine_cb = x_fine - parity * arg.in.VolumeCB(); + + const int v_parity = (gmem.Nparity() == 2) ? parity : 0; + + int fine_spin = fine_spin_block + coarse_spin * Arg::spin_block_factor; + auto a_gmem = gmem(v_parity, x_fine_cb, fine_spin, fine_color, contiguous + contiguous_dim_offset); + complex a[elements_per_thread]; + mma::batch_load_t, elements_per_thread>::load(a, a_gmem.data()); + + int smem_m = contiguous; + int smem_k = (thread_idx * Arg::spin_block_factor + fine_spin_block) * Arg::fineColor + fine_color; + + typename Arg::real a_real[elements_per_thread]; + typename Arg::real a_imag[elements_per_thread]; + if constexpr (decltype(a_gmem)::fixed) { + auto scale_inv = a_gmem.get_scale_inv(); +#pragma unroll + for (int e = 0; e < elements_per_thread; e++) { + a_real[e] = +a[e].real() * scale_inv; + a_imag[e] = dagger ? -a[e].imag() * scale_inv : +a[e].imag() * scale_inv; + } + } else { +#pragma unroll + for (int e = 0; e < elements_per_thread; e++) { + a_real[e] = +a[e].real(); + a_imag[e] = dagger ? -a[e].imag() : +a[e].imag(); + } + } + + static_assert(smem_obj_t::ldm == 1, "smem_obj_t::ldm == 1"); + smem_real.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_real)); + smem_imag.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_imag)); + + thread += Arg::block_y * Arg::block_z; + } + } + + template + void __device__ inline restrict_mma(int x_coarse, int coarse_spin, int m_offset, int n_offset, const Arg &arg) + { + + constexpr int M = Arg::nVec; + constexpr int N = Arg::coarseColor; + constexpr int K = Arg::fineColor * Arg::spin_block_factor * Arg::aggregate_size; + + constexpr int ldc = M; + + using mma_t = typename Arg::mma_t; + // The first two ldc's are dummy + using Config = mma::MmaConfig; + + static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); + static_assert(N % Arg::bN == 0, "N %% Arg::bN != 0.\n"); + static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); + + extern __shared__ typename mma_t::compute_t smem_ptr[]; + + typename Config::SmemObjA smem_obj_a_real(smem_ptr); + typename Config::SmemObjA smem_obj_a_imag(smem_obj_a_real.ptr + Config::smem_lda * Arg::bK); + typename Config::SmemObjB smem_obj_b_real(smem_obj_a_imag.ptr + Config::smem_lda * Arg::bK); + typename Config::SmemObjB smem_obj_b_imag(smem_obj_b_real.ptr + Config::smem_ldb * Arg::bK); + + typename Config::ALoader a_loader; + typename Config::BLoader b_loader; + + typename Config::Accumulator accumulator((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x); + + accumulator.zero(); + + for (int aggregate_k_offset = 0; aggregate_k_offset < Arg::aggregate_size; + aggregate_k_offset += Arg::aggregate_per_block) { + __syncthreads(); + + constexpr bool a_dagger = true; + load_g2s(smem_obj_a_real, smem_obj_a_imag, arg.in, x_coarse, coarse_spin, m_offset, + aggregate_k_offset, arg); + + constexpr bool b_dagger = false; + load_g2s(smem_obj_b_real, smem_obj_b_imag, arg.v, x_coarse, coarse_spin, n_offset, + aggregate_k_offset, arg); + + __syncthreads(); + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + } + + const int parity_coarse = x_coarse >= arg.out.VolumeCB() ? 
1 : 0; + const int x_coarse_cb = x_coarse - parity_coarse * arg.out.VolumeCB(); + + auto c_gmem = arg.out(parity_coarse, x_coarse_cb, coarse_spin, 0, 0); + constexpr bool c_dagger = true; + accumulator.template store(c_gmem, m_offset, n_offset, assign_t()); + } + + template struct RestrictorMma { + const Arg &arg; + constexpr RestrictorMma(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ inline void operator()() + { + int coarse_spin = target::block_idx().x % Arg::coarseSpin; + int x_coarse = target::block_idx().x / Arg::coarseSpin; + + int m_offset = Arg::bM * target::block_idx().y; + int n_offset = Arg::bN * target::block_idx().z; + + restrict_mma(x_coarse, coarse_spin, m_offset, n_offset, arg); + } + }; + +} // namespace quda diff --git a/include/transfer.h b/include/transfer.h index f3fd8fcbe6..b816360f13 100644 --- a/include/transfer.h +++ b/include/transfer.h @@ -346,6 +346,10 @@ namespace quda { void Restrict(cvector_ref &out, cvector_ref &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity = QUDA_INVALID_PARITY); + template + void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity = QUDA_INVALID_PARITY); + /** @brief Apply the unitary "prolongation" operator for Kahler-Dirac preconditioning @param[out] out Resulting fine grid field diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 414198bd78..ca70075d48 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -227,6 +227,8 @@ if(QUDA_MULTIGRID) foreach(QUDA_MULTIGRID_MRHS ${QUDA_MULTIGRID_MRHS_LIST_SEMICOLON}) list(PREPEND QUDA_CU_OBJS "prolongator_mma_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}_nvec${QUDA_MULTIGRID_MRHS}.cu") configure_file(prolongator_mma.in.cu "prolongator_mma_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}_nvec${QUDA_MULTIGRID_MRHS}.cu" @ONLY) + list(PREPEND QUDA_CU_OBJS "restrictor_mma_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}_nvec${QUDA_MULTIGRID_MRHS}.cu") + configure_file(restrictor_mma.in.cu "restrictor_mma_${QUDA_MULTIGRID_NC_NVEC}_${QUDA_MULTIGRID_NVEC2}_nvec${QUDA_MULTIGRID_MRHS}.cu" @ONLY) endforeach() endif() endforeach() diff --git a/lib/prolongator.in.cpp b/lib/prolongator.in.cpp index ff96b1da47..47ace8a831 100644 --- a/lib/prolongator.in.cpp +++ b/lib/prolongator.in.cpp @@ -32,7 +32,7 @@ namespace quda return ColorSpinorField(param); } - auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order) + static auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order) { ColorSpinorParam param(f); param.create = QUDA_NULL_FIELD_CREATE; diff --git a/lib/restrictor.in.cpp b/lib/restrictor.in.cpp index fd8375ddc8..2d9062398d 100644 --- a/lib/restrictor.in.cpp +++ b/lib/restrictor.in.cpp @@ -6,26 +6,93 @@ namespace quda template struct IntList { }; - template + template + void RestrictMma2(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity, IntList) { + if (out.Nvec() == nVec) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + } else { + if constexpr (sizeof...(N) > 0) { + RestrictMma2(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, IntList()); + } else { + errorQuda("nVec = %d has not been instantiated", 
out.Nvec()); + } + } + } + + template auto create_color_spinor_copy(cvector_ref &fs, QudaFieldOrder order) + { + ColorSpinorParam param(fs[0]); + int nVec = (fs.size() + 7) / 8 * 8; // Make a multiple of 8 + param.nColor = fs[0].Ncolor() * nVec; + param.nVec = nVec; + param.create = QUDA_NULL_FIELD_CREATE; + param.fieldOrder = order; + return ColorSpinorField(param); + } + + static auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order) + { + ColorSpinorParam param(f); + param.create = QUDA_NULL_FIELD_CREATE; + param.fieldOrder = order; + return ColorSpinorField(param); + } + + template void Restrict2(cvector_ref &out, cvector_ref &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity, IntList) { if (out[0].Ncolor() == coarseColor) { if constexpr (coarseColor >= fineColor) { - Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + if constexpr (use_mma) { + constexpr QudaFieldOrder csOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; + ColorSpinorField v_in = create_color_spinor_copy(in, csOrder); + ColorSpinorField v_out = create_color_spinor_copy(out, csOrder); + ColorSpinorField V = create_color_spinor_copy(v, csOrder); + + bool from_non_rel = (in.Nspin() == 4) && (in[0].GammaBasis() == QUDA_UKQCD_GAMMA_BASIS); + BlockTransposeForward(v_in, in, from_non_rel); + + V.copy(v); + IntList<@QUDA_MULTIGRID_MRHS_LIST@> nvecs; + RestrictMma2(v_out, v_in, V, fine_to_coarse, coarse_to_fine, spin_map, parity, nvecs); + + BlockTransposeBackward(v_out, out); +#if 1 + std::vector v_cmp(out.size()); + for (size_t i = 0; i < out.size(); i++) { + ColorSpinorParam param(out[i]); + param.create = QUDA_NULL_FIELD_CREATE; + v_cmp[i] = ColorSpinorField(param); + } + auto vv_cmp = make_set(v_cmp); + Restrict(vv_cmp, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + + blas::mxpy(out, v_cmp); + auto vn = blas::norm2(vv_cmp); + printf("n = "); + for (size_t i = 0; i < vn.size(); i++) { + printf("%f ", vn[i]); + } + printf("\n"); +#endif + } else { + Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + } } else { errorQuda("Invalid coarseColor = %d, cannot be less than fineColor = %d", coarseColor, fineColor); } } else { if constexpr (sizeof...(N) > 0) { - Restrict2(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, IntList()); + Restrict2(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, IntList()); } else { errorQuda("Coarse Nc = %d has not been instantiated", out[0].Ncolor()); } } } - template + template void Restrict(cvector_ref &out, cvector_ref &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity, IntList) { @@ -33,10 +100,10 @@ namespace quda // clang-format off IntList<@QUDA_MULTIGRID_NVEC_LIST@> coarseColors; // clang-format on - Restrict2(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, coarseColors); + Restrict2(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, coarseColors); } else { if constexpr (sizeof...(N) > 0) { - Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, IntList()); + Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, IntList()); } else { errorQuda("Fine Nc = %d has not been instantiated", in[0].Ncolor()); } @@ -55,7 +122,11 @@ namespace quda IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors; // clang-format on - Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors); + 
if (in[0].Ncolor() == 3) { + Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors); + } else { + Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors); + } } else { errorQuda("Multigrid has not been built"); } diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu new file mode 100644 index 0000000000..faa554f666 --- /dev/null +++ b/lib/restrictor_mma.in.cu @@ -0,0 +1,264 @@ +#include +#include +#include +#include +#include +#include + +namespace quda +{ + + template + class RestrictMmaLaunch : public TunableKernel + { + ColorSpinorField &out; + const ColorSpinorField ∈ + const ColorSpinorField &v; + const int *fine_to_coarse; + const int *coarse_to_fine; + const int parity; + + bool checkParam(const TuneParam ¶m) const { return true; } + + unsigned int sharedBytesPerThread() const { return 0; } + + bool advanceTuneParam(TuneParam &) const { return false; } + + void initTuneParam(TuneParam ¶m) const + { + param.aux.x = 0; + param.aux.y = 0; + param.aux.z = 0; + param.aux.w = 0; + set_mma_param(param); + } + + /** sets default values for when tuning is disabled */ + void defaultTuneParam(TuneParam ¶m) const + { + param.aux.x = 0; + param.aux.y = 0; + param.aux.z = 0; + param.aux.w = 0; + set_mma_param(param); + } + + public: + RestrictMmaLaunch(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, int parity) : + TunableKernel(in), + out(out), + in(in), + v(v), + fine_to_coarse(fine_to_coarse), + coarse_to_fine(coarse_to_fine), + parity(parity) + { + strcat(vol, ","); + strcat(vol, out.VolString().c_str()); + strcat(aux, ","); + strcat(aux, out.AuxString().c_str()); + setRHSstring(aux, in.Nvec()); + + apply(device::get_default_stream()); + } + + using mma_t = typename mma::smma_dispatch::type; + // using mma_t = simt::simt_t; + + static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); + + static constexpr int m = nVec; + static constexpr int n = coarseColor; + static constexpr int k = fineColor * spin_block_factor * aggregate_size; + + static constexpr int n_atom_size = mma_t::MMA_N; + static constexpr int m_atom_size = mma_t::MMA_M; + static constexpr int k_atom_size = fineColor * spin_block_factor * 4; + static constexpr int block_atom_size = 32 / 8; + + long long flops() const { return nVec * 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * in.VolumeCB(); } + + long long bytes() const + { + size_t v_bytes = v.Bytes() / (v.SiteSubset() == in.SiteSubset() ? 
1 : 2); + return nVec * (in.Bytes() + out.Bytes() + v_bytes + in.SiteSubset() * in.VolumeCB() * sizeof(int)); + } + + static constexpr int shared_bytes_per_block(int bM, int bN, int bK) + { + return mma::shared_memory_bytes(bM, bN, bK); + } + + bool set_mma_param(TuneParam &tp) const + { + tp.block.x = 1; + tp.block.y = 16; + tp.block.z = 8; + + int bN = n; // n_atom_size; + int bM = m; + + tp.grid = dim3(out.Volume() * coarseSpin, (m + bM - 1) / bM, (n + bN - 1) / bN); + tp.set_max_shared_bytes = true; + + int bK = k_atom_size; + int shared_bytes = shared_bytes_per_block(bM, bN, bK); + tp.shared_bytes = shared_bytes; + + return shared_bytes <= device::maximum_dynamic_shared_memory(); + } + + void launch_mma(TuneParam &tp, const qudaStream_t &stream) + { + constexpr int bN = n; // n_atom_size; + constexpr int bM = m; + constexpr int bK = k_atom_size; + constexpr int block_y = 16; + constexpr int block_z = 8; + constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); + if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { + using Arg = RestrictMmaArg; + Arg arg(out, in, v, fine_to_coarse, coarse_to_fine, parity); + tp.set_max_shared_bytes = true; + launch_cuda(tp, stream, arg); + } else { + errorQuda("Using too many shared memory bytes per block: %d", shared_bytes); + } + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + launch_mma(tp, stream); + } + }; + + template + void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) + { + if (out.Nspin() != 2) errorQuda("Unsupported nSpin %d", out.Nspin()); + constexpr int coarseSpin = 2; + + // first check that the spin_map matches the spin_mapper + spin_mapper mapper; + for (int s = 0; s < fineSpin; s++) + for (int p = 0; p < 2; p++) + if (mapper(s, p) != spin_map[s][p]) errorQuda("Spin map does not match spin_mapper"); + + if (v.Precision() == QUDA_HALF_PRECISION) { + if constexpr (is_enabled(QUDA_HALF_PRECISION)) { + RestrictMmaLaunch restrictor( + out, in, v, fine_to_coarse, coarse_to_fine, parity); + } else { + errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); + } + } else if (v.Precision() == in.Precision()) { + RestrictMmaLaunch restrictor( + out, in, v, fine_to_coarse, coarse_to_fine, parity); + } else { + errorQuda("Unsupported V precision %d", v.Precision()); + } + } + + template + void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) + { + if (!is_enabled_spin(in.Nspin())) errorQuda("nSpin %d has not been built", in.Nspin()); + + if (in.Nspin() == 2) { + RestrictMma(out, in, v, fine_to_coarse, + coarse_to_fine, spin_map, parity); + } else if constexpr (fineColor == 3) { + if (in.Nspin() == 4) { + if constexpr (is_enabled_spin(4)) { + if (in.Precision() == out.Precision()) { + RestrictMma( + out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + } else if (in.Precision() == QUDA_HALF_PRECISION) { + // if constexpr (is_enabled(QUDA_HALF_PRECISION)) { + // RestrictMma(out, in, v, + // fine_to_coarse, coarse_to_fine, spin_map, + // parity); + // } else { + errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); + // } + } else { + errorQuda("Unsupported precision %d", in.Precision()); + } + } + } else if (in.Nspin() == 1) { 
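+      // The nSpin == 1 instantiation below is currently compiled out (#if 0),
+      // so this branch falls through to the errorQuda in the #else block.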
+#if 0 + if constexpr (is_enabled_spin(1)) { + if (in.Precision() == out.Precision()) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else if (in.Precision() == QUDA_HALF_PRECISION) { + if constexpr (is_enabled(QUDA_HALF_PRECISION)) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else { + errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); + } + } else { + errorQuda("Unsupported precision %d", in.Precision()); + } + } +#else + errorQuda("Unexpected nSpin = %d", in.Nspin()); +#endif + } else { + errorQuda("Unexpected nSpin = %d", in.Nspin()); + } + } else { + errorQuda("Unexpected spin %d and color %d combination", in.Nspin(), in.Ncolor()); + } + } + + // clang-format off + constexpr int fineColor = @QUDA_MULTIGRID_NC_NVEC@; + constexpr int coarseColor = @QUDA_MULTIGRID_NVEC2@; + constexpr int nVec = @QUDA_MULTIGRID_MRHS@; + // clang-format on + + template + void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) + { + int aggregate_size = in.Volume() / out.Volume(); + if (aggregate_size == 128) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + } else { + errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); + } + } + + template <> + void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, + const int *fine_to_coarse, const int *coarse_to_fine, const int * const * spin_map, int parity) + { + if constexpr (is_enabled_multigrid()) { + + checkLocation(out, in, v); + if (in.Nspin() == 2) checkPrecision(in, out); + QudaPrecision precision = out.Precision(); + + if (precision == QUDA_DOUBLE_PRECISION) { + if constexpr (is_enabled_multigrid_double()) + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + else errorQuda("Double precision multigrid has not been enabled"); + } else if (precision == QUDA_SINGLE_PRECISION) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + } else { + errorQuda("Unsupported precision %d", precision); + } + } else { + errorQuda("Multigrid has not been built"); + } + } + +} // namespace quda From f0f88d5b5787e619db44714a7d6b7da68d071a13 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 4 Sep 2024 09:49:19 -0700 Subject: [PATCH 08/79] Add expands for restrictor with MMA. 
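
This patch lets the autotuner sweep the MMA tile shape: advanceTuneParam now walks tp.aux.{x,y,z,w} over the factor lists for the block size and the bM/bN/bK tiles, and the launch_mma_span_{block,n,m,k} helpers turn each runtime aux index back into compile-time template arguments via std::index_sequence recursion before calling launch_mma. Below is a minimal standalone sketch of that dispatch pattern, for illustration only; dispatch_span_k, launch_kernel, k_atom_size and k_factors are made-up stand-ins for the QUDA pieces (IntFactorArray, launch_mma), not actual QUDA API.

  // Illustrative sketch: map a runtime tuning index to a compile-time tile
  // size via std::index_sequence recursion, so every candidate bK gets its
  // own kernel instantiation.
  #include <cstdio>
  #include <utility>

  constexpr int k_atom_size = 8;          // hypothetical atom size
  constexpr int k_factors[] = {1, 2, 4};  // stand-in for IntFactorArray

  template <int bK> void launch_kernel(int aux)
  {
    std::printf("aux = %d -> launch with bK = %d\n", aux, bK);
  }

  template <std::size_t d, std::size_t... Ds>
  void dispatch_span_k(int aux, std::index_sequence<d, Ds...>)
  {
    if (aux == static_cast<int>(d)) {
      launch_kernel<k_atom_size * k_factors[d]>(aux); // bK fixed at compile time
    } else if constexpr (sizeof...(Ds) > 0) {
      dispatch_span_k(aux, std::index_sequence<Ds...>()); // try the next index
    } else {
      std::printf("invalid aux value %d\n", aux);
    }
  }

  int main()
  {
    for (int aux = 0; aux < 3; aux++) dispatch_span_k(aux, std::make_index_sequence<3>());
  }

In the real code the recursion bottoms out in launch_mma<bN, bM, bK, block_y, block_z>, and the factor arrays come from IntFactorArray over (dim + atom_size - 1) / atom_size.
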
--- lib/restrictor_mma.in.cu | 139 ++++++++++++++++++++++++++++++++++----- 1 file changed, 124 insertions(+), 15 deletions(-) diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index faa554f666..c680d94d0f 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -4,6 +4,7 @@ #include #include #include +#include namespace quda { @@ -23,7 +24,39 @@ namespace quda unsigned int sharedBytesPerThread() const { return 0; } - bool advanceTuneParam(TuneParam &) const { return false; } + bool advanceTuneParam(TuneParam ¶m) const + { + auto advancer = [&](int &i, int limit) -> bool { + if (i < limit) { + i++; + return set_mma_param(param); + } else { + return false; + } + }; + + if (advancer(param.aux.x, numFactors((block_limit + block_atom_size - 1) / block_atom_size) - 1)) { + return true; + } else { + param.aux.x = 0; + if (advancer(param.aux.y, numFactors((n + n_atom_size - 1) / n_atom_size) - 1)) { + return true; + } else { + param.aux.y = 0; + if (advancer(param.aux.z, numFactors((m + m_atom_size - 1) / m_atom_size) - 1)) { + return true; + } else { + param.aux.z = 0; + if (advancer(param.aux.w, numFactors((k + k_atom_size - 1) / k_atom_size) - 1)) { + return true; + } else { + param.aux.w = 0; + return false; + } + } + } + } + } void initTuneParam(TuneParam ¶m) const { @@ -61,6 +94,8 @@ namespace quda strcat(aux, out.AuxString().c_str()); setRHSstring(aux, in.Nvec()); + strcat(aux, mma_t::get_type_name().c_str()); + apply(device::get_default_stream()); } @@ -77,6 +112,7 @@ namespace quda static constexpr int m_atom_size = mma_t::MMA_M; static constexpr int k_atom_size = fineColor * spin_block_factor * 4; static constexpr int block_atom_size = 32 / 8; + static constexpr int block_limit = 32; long long flops() const { return nVec * 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * in.VolumeCB(); } @@ -93,30 +129,31 @@ namespace quda bool set_mma_param(TuneParam &tp) const { + static_assert(m % m_atom_size == 0, "m modulo m_atom_size == 0"); + static_assert(n % n_atom_size == 0, "n modulo n_atom_size == 0"); + static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); + tp.block.x = 1; - tp.block.y = 16; + tp.block.y + = block_atom_size * get_int_factor_array((block_limit + block_atom_size - 1) / block_atom_size)[tp.aux.x]; tp.block.z = 8; - int bN = n; // n_atom_size; - int bM = m; + int bN = n_atom_size * get_int_factor_array((n + n_atom_size - 1) / n_atom_size)[tp.aux.y]; + int bM = m_atom_size * get_int_factor_array((m + m_atom_size - 1) / m_atom_size)[tp.aux.z]; tp.grid = dim3(out.Volume() * coarseSpin, (m + bM - 1) / bM, (n + bN - 1) / bN); tp.set_max_shared_bytes = true; - int bK = k_atom_size; + int bK = k_atom_size * get_int_factor_array(k / k_atom_size)[tp.aux.w]; int shared_bytes = shared_bytes_per_block(bM, bN, bK); tp.shared_bytes = shared_bytes; return shared_bytes <= device::maximum_dynamic_shared_memory(); } + template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { - constexpr int bN = n; // n_atom_size; - constexpr int bM = m; - constexpr int bK = k_atom_size; - constexpr int block_y = 16; - constexpr int block_z = 8; constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { using Arg = RestrictMmaArg + void launch_mma_span_k(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.w == d) { + constexpr IntFactorArray k_factors; + launch_mma(tp, stream); + } else { + if constexpr (sizeof...(Ds) > 0) { + 
launch_mma_span_k(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.z."); + } + } + } + + template + void launch_mma_span_m(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.z == d) { + constexpr IntFactorArray<(m + m_atom_size - 1) / m_atom_size> m_factors; + std::make_index_sequence().size()> k_indices; + launch_mma_span_k(tp, stream, k_indices); + } else { + if constexpr (sizeof...(Ds) > 0) { + launch_mma_span_m(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.z."); + } + } + } + + template + void launch_mma_span_n(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.y == d) { + constexpr IntFactorArray<(n + n_atom_size - 1) / n_atom_size> n_factors; + std::make_index_sequence().size()> m_indices; + launch_mma_span_m(tp, stream, m_indices); + } else { + if constexpr (sizeof...(Ds) > 0) { + launch_mma_span_n(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.y."); + } + } + } + + template + void launch_mma_span_block(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.x == d) { + constexpr IntFactorArray<(k + block_atom_size - 1) / block_atom_size> block_factors; + std::make_index_sequence().size()> n_indices; + launch_mma_span_n(tp, stream, n_indices); + } else { + if constexpr (sizeof...(Ds) > 0) { + launch_mma_span_block(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.x."); + } + } + } + + void launch_mma(TuneParam &tp, const qudaStream_t &stream) + { + std::make_index_sequence().size()> block_indices; + launch_mma_span_block(tp, stream, block_indices); + } + void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); @@ -227,19 +333,21 @@ namespace quda template void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, - const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) + const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) { int aggregate_size = in.Volume() / out.Volume(); if (aggregate_size == 128) { - RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); } else { errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); } } template <> - void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, - const int *fine_to_coarse, const int *coarse_to_fine, const int * const * spin_map, int parity) + void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, + const ColorSpinorField &v, const int *fine_to_coarse, + const int *coarse_to_fine, const int *const *spin_map, int parity) { if constexpr (is_enabled_multigrid()) { @@ -250,7 +358,8 @@ namespace quda if (precision == QUDA_DOUBLE_PRECISION) { if constexpr (is_enabled_multigrid_double()) RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); - else errorQuda("Double precision multigrid has not been enabled"); + else + errorQuda("Double precision multigrid has not been enabled"); } else if (precision == QUDA_SINGLE_PRECISION) { RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); } else { From 92b1c39d855a2ea171db8fa4ecda400a4d3d6186 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 4 Sep 2024 12:07:42 -0700 Subject: [PATCH 09/79] Add shared memory caching for the the restrictor kernel. 
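
The restrictor kernel dereferences the coarse_to_fine look-up map repeatedly, so this patch stages the block's slice of that map into shared memory up front: a __shared__ int coarse_to_fine[Arg::aggregate_size] array is filled cooperatively by the block's threads and published with a __syncthreads(), after which load_g2s reads the cached copy instead of global memory. The patch also gates the MMA path in the prolongator/restrictor dispatch on the number of right-hand sides being a multiple of 16. The following CUDA kernel is only an illustration of the same staging idea; stage_lookup, table_size and the surrounding names are hypothetical, not the QUDA kernel.

  // Illustrative CUDA kernel (not the QUDA restrictor): each block copies
  // its slice of a global look-up table into shared memory once, then all
  // threads read the cached copy instead of going back to global memory.
  __global__ void stage_lookup(const int *global_map, const float *in, float *out, int table_size)
  {
    extern __shared__ int cached_map[]; // table_size ints, sized at launch

    // strided cooperative load of this block's slice of the table
    for (int i = threadIdx.x; i < table_size; i += blockDim.x) {
      cached_map[i] = global_map[blockIdx.x * table_size + i];
    }
    __syncthreads();

    // every subsequent look-up hits shared memory; map entries are assumed
    // to be valid indices into the input array
    int idx = cached_map[threadIdx.x % table_size];
    out[blockIdx.x * blockDim.x + threadIdx.x] = in[idx];
  }

  // launch sketch: stage_lookup<<<blocks, threads, table_size * sizeof(int)>>>(map, in, out, table_size);
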
--- include/kernels/restrictor_mma.cuh | 18 +++++++++++---- lib/prolongator.in.cpp | 2 +- lib/restrictor.in.cpp | 2 +- lib/restrictor_mma.in.cu | 37 +++++++++++++++++++++--------- 4 files changed, 41 insertions(+), 18 deletions(-) diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index 1f37350c48..bb17e63c4d 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -82,7 +82,8 @@ namespace quda template inline void __device__ load_g2s(smem_obj_t &smem_real, smem_obj_t &smem_imag, const gmem_obj_t &gmem, int x_coarse, - int coarse_spin, int contiguous_dim_offset, int aggregate_k_offset, const Arg &arg) + int coarse_spin, int contiguous_dim_offset, int aggregate_k_offset, + int *coarse_to_fine, const Arg &arg) { // v as a constexpr int elements_per_thread = 16 / (sizeof(typename gmem_obj_t::store_type) * 2); static_assert(contiguous_dim % elements_per_thread == 0, "contiguous_dim %% elements_per_thread == 0"); @@ -104,8 +105,7 @@ namespace quda // look-up map is ordered as (coarse-block-id + fine-point-id), // with fine-point-id parity ordered - const int x_fine_site_id = (x_coarse * 2 + parity) * arg.aggregate_size_cb + x_fine_cb_offset; - const int x_fine = arg.coarse_to_fine[x_fine_site_id]; + const int x_fine = coarse_to_fine[parity * arg.aggregate_size_cb + x_fine_cb_offset]; const int x_fine_cb = x_fine - parity * arg.in.VolumeCB(); const int v_parity = (gmem.Nparity() == 2) ? parity : 0; @@ -161,6 +161,14 @@ namespace quda static_assert(N % Arg::bN == 0, "N %% Arg::bN != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); + __shared__ int coarse_to_fine[Arg::aggregate_size]; + int index = target::thread_idx().y + Arg::block_y * target::thread_idx().z; + while (index < Arg::aggregate_size) { + coarse_to_fine[index] = arg.coarse_to_fine[x_coarse * 2 * arg.aggregate_size_cb + index]; + index += Arg::block_y * Arg::block_z; + } + __syncthreads(); + extern __shared__ typename mma_t::compute_t smem_ptr[]; typename Config::SmemObjA smem_obj_a_real(smem_ptr); @@ -181,11 +189,11 @@ namespace quda constexpr bool a_dagger = true; load_g2s(smem_obj_a_real, smem_obj_a_imag, arg.in, x_coarse, coarse_spin, m_offset, - aggregate_k_offset, arg); + aggregate_k_offset, coarse_to_fine, arg); constexpr bool b_dagger = false; load_g2s(smem_obj_b_real, smem_obj_b_imag, arg.v, x_coarse, coarse_spin, n_offset, - aggregate_k_offset, arg); + aggregate_k_offset, coarse_to_fine, arg); __syncthreads(); accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); diff --git a/lib/prolongator.in.cpp b/lib/prolongator.in.cpp index 47ace8a831..812748b419 100644 --- a/lib/prolongator.in.cpp +++ b/lib/prolongator.in.cpp @@ -122,7 +122,7 @@ namespace quda // clang-format off IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors; // clang-format on - if (1) { + if (in.size() % 16 == 0) { // use MMA Prolongate(out, in, v, fine_to_coarse, spin_map, parity, fineColors); } else { diff --git a/lib/restrictor.in.cpp b/lib/restrictor.in.cpp index 2d9062398d..9afd03ca4a 100644 --- a/lib/restrictor.in.cpp +++ b/lib/restrictor.in.cpp @@ -122,7 +122,7 @@ namespace quda IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors; // clang-format on - if (in[0].Ncolor() == 3) { + if (in.size() % 16 == 0) { Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors); } else { Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors); diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu 
index c680d94d0f..1cd0cd3afa 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -217,7 +217,7 @@ namespace quda void launch_mma_span_block(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) { if (tp.aux.x == d) { - constexpr IntFactorArray<(k + block_atom_size - 1) / block_atom_size> block_factors; + constexpr IntFactorArray<(block_limit + block_atom_size - 1) / block_atom_size> block_factors; std::make_index_sequence().size()> n_indices; launch_mma_span_n(tp, stream, n_indices); } else { @@ -286,13 +286,17 @@ namespace quda RestrictMma( out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); } else if (in.Precision() == QUDA_HALF_PRECISION) { - // if constexpr (is_enabled(QUDA_HALF_PRECISION)) { - // RestrictMma(out, in, v, - // fine_to_coarse, coarse_to_fine, spin_map, - // parity); - // } else { - errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); - // } +#if 0 + if constexpr (is_enabled(QUDA_HALF_PRECISION)) { + RestrictMma(out, in, v, + fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else { +#endif + errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); +#if 0 + } +#endif } else { errorQuda("Unsupported precision %d", in.Precision()); } @@ -337,9 +341,20 @@ namespace quda { int aggregate_size = in.Volume() / out.Volume(); if (aggregate_size == 128) { - RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, - parity); - } else { + if constexpr (fineColor == 3 && coarseColor == 24) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else { + errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); + } + } else if (aggregate_size == 16) { + if constexpr (fineColor == 24 && coarseColor == 32) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else { + errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); + } + } else{ errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); } } From e7b9361c76739ff0fccff9ff054b76775c020821 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Fri, 6 Sep 2024 11:16:30 -0700 Subject: [PATCH 10/79] Allow restrictor_mma to have N % bN != 0; More generic optimizations for loading from gmem. 
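
With this patch the restrictor no longer requires N (the number of right-hand sides) to be a multiple of the tile size bN: the corresponding static_assert is dropped and the global-memory loaders gain bounds checks (check_contiguous_bound, check_global_bound) that are only compiled in when the divisibility fails. The half2 register loads now go through batch_load_t so larger batches vectorize, get_mn_batch accounts for the internal batch, k_atom_size scales with mma_t::MMA_K, and the dispatch additionally accepts aggregate_size == 16 for the 24 -> 32 color level. A minimal CUDA sketch of the compile-time-guarded tail handling follows; copy_tiles and its parameters are illustrative only, not the QUDA loader.

  // Illustrative CUDA kernel: when the tile size bN does not divide N, the
  // tail tile is partial and needs a bounds check; when it does divide, the
  // predicate folds away at compile time.
  template <int N, int bN>
  __global__ void copy_tiles(const float *src, float *dst)
  {
    constexpr bool check_bound = (N % bN != 0); // only pay for the guard when needed
    int n_offset = blockIdx.x * bN;
    for (int i = threadIdx.x; i < bN; i += blockDim.x) {
      int n = n_offset + i;
      if (!check_bound || n < N) dst[n] = src[n]; // guard elided when N % bN == 0
    }
  }

  // launch sketch: copy_tiles<100, 32><<<(100 + 31) / 32, 64>>>(d_src, d_dst);
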
--- include/kernels/coarse_op_kernel_mma.cuh | 4 +- include/kernels/prolongator_mma.cuh | 4 +- include/kernels/restrictor_mma.cuh | 10 ++- include/targets/cuda/mma_tensor_op/gemm.cuh | 18 ++--- .../cuda/mma_tensor_op/gmem_loader.cuh | 76 ++++++++++--------- include/targets/cuda/mma_tensor_op/simt.cuh | 5 ++ lib/prolongator.in.cpp | 6 +- lib/restrictor.in.cpp | 8 +- lib/restrictor_mma.in.cu | 3 +- 9 files changed, 74 insertions(+), 60 deletions(-) diff --git a/include/kernels/coarse_op_kernel_mma.cuh b/include/kernels/coarse_op_kernel_mma.cuh index 75c819fdb9..3aca52cdd5 100644 --- a/include/kernels/coarse_op_kernel_mma.cuh +++ b/include/kernels/coarse_op_kernel_mma.cuh @@ -246,7 +246,7 @@ namespace quda __syncthreads(); a_loader.template g2r(a, m_offset, 0); - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); __syncthreads(); for (int s_col = 0; s_col < fineSpin; s_col++) { // which chiral block @@ -255,7 +255,7 @@ namespace quda __syncthreads(); b_loader.template g2r(b, n_offset, 0); - b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); __syncthreads(); #pragma unroll 1 diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh index 7a1b10c5f9..06f4998ccd 100644 --- a/include/kernels/prolongator_mma.cuh +++ b/include/kernels/prolongator_mma.cuh @@ -124,8 +124,8 @@ namespace quda __syncthreads(); a_loader.template g2r(a, m_offset, k_offset); b_loader.template g2r(b, n_offset, k_offset); - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); - b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); __syncthreads(); accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); } diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index bb17e63c4d..c05222fa30 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -80,7 +80,7 @@ namespace quda } }; - template + template inline void __device__ load_g2s(smem_obj_t &smem_real, smem_obj_t &smem_imag, const gmem_obj_t &gmem, int x_coarse, int coarse_spin, int contiguous_dim_offset, int aggregate_k_offset, int *coarse_to_fine, const Arg &arg) @@ -92,6 +92,8 @@ namespace quda * Arg::aggregate_per_block) { int thread_idx = thread; int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; + constexpr bool check_contiguous_bound = !(contiguous_limit % contiguous_dim == 0); + if (!check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit) { thread_idx /= (contiguous_dim / elements_per_thread); int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin thread_idx /= Arg::spin_block_factor; @@ -138,6 +140,7 @@ namespace quda static_assert(smem_obj_t::ldm == 1, "smem_obj_t::ldm == 1"); smem_real.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_real)); smem_imag.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_imag)); + } thread += Arg::block_y * Arg::block_z; } @@ -158,7 +161,6 @@ namespace quda using Config = mma::MmaConfig; static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); - static_assert(N % Arg::bN == 0, "N %% Arg::bN != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); __shared__ int coarse_to_fine[Arg::aggregate_size]; @@ -188,11 +190,11 @@ namespace quda __syncthreads(); 
constexpr bool a_dagger = true; - load_g2s(smem_obj_a_real, smem_obj_a_imag, arg.in, x_coarse, coarse_spin, m_offset, + load_g2s(smem_obj_a_real, smem_obj_a_imag, arg.in, x_coarse, coarse_spin, m_offset, aggregate_k_offset, coarse_to_fine, arg); constexpr bool b_dagger = false; - load_g2s(smem_obj_b_real, smem_obj_b_imag, arg.v, x_coarse, coarse_spin, n_offset, + load_g2s(smem_obj_b_real, smem_obj_b_imag, arg.v, x_coarse, coarse_spin, n_offset, aggregate_k_offset, coarse_to_fine, arg); __syncthreads(); diff --git a/include/targets/cuda/mma_tensor_op/gemm.cuh b/include/targets/cuda/mma_tensor_op/gemm.cuh index 4662011576..eb0a0993aa 100644 --- a/include/targets/cuda/mma_tensor_op/gemm.cuh +++ b/include/targets/cuda/mma_tensor_op/gemm.cuh @@ -259,10 +259,10 @@ namespace quda __syncthreads(); a_loader.template g2r(a, m_offset, 0); // bk = 0 - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); b_loader.template g2r(b, n_offset, 0); // bk = 0 - b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); __syncthreads(); #pragma unroll 1 @@ -279,8 +279,8 @@ namespace quda // to smem. __syncthreads(); - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); - b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); __syncthreads(); } @@ -322,10 +322,10 @@ namespace quda __syncthreads(); a_loader.template g2r(a, m_offset, 0); - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); b_loader.template g2r(b, n_offset, 0); - b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); __syncthreads(); #pragma unroll 1 @@ -390,10 +390,10 @@ namespace quda __syncthreads(); a_loader.template g2r(a, 0, 0); - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); b_loader.template g2r(b, 0, 0); - b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); __syncthreads(); #pragma unroll 1 @@ -435,7 +435,7 @@ namespace quda if (a_m + bM < M) { __syncthreads(); - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); __syncthreads(); } } diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index 3e18e8bdcf..4288bcea2a 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -127,31 +127,37 @@ namespace quda inline __device__ void convert_x(half2 reg_real[batch], half2 reg_imag[batch], complex *p, int m_idx, int n_idx, float scale_inv) { - static_assert(batch == 1, "for half2, for now, batch needs to be 1"); - if (x) { - auto xx = p[(m_idx + 0) * ld + n_idx]; - auto yy = p[(m_idx + 1) * ld + n_idx]; + if constexpr (x) { + complex vx[batch]; + complex vy[batch]; + batch_load_t, batch>::load(vx, &p[(m_idx + 0) * ld + n_idx]); + batch_load_t, batch>::load(vy, &p[(m_idx + 1) * ld + n_idx]); - if (fixed) { - reg_real[0] = __floats2half2_rn(scale_inv * xx.real(), scale_inv * yy.real()); - auto scale_inv_conj = dagger ? 
-scale_inv : scale_inv; - reg_imag[0] = __floats2half2_rn(scale_inv_conj * xx.imag(), scale_inv_conj * yy.imag()); - } else { - reg_real[0] = __floats2half2_rn(+xx.real(), +yy.real()); - reg_imag[0] = __floats2half2_rn(dagger ? -xx.imag() : +xx.imag(), dagger ? -yy.imag() : +yy.imag()); +#pragma unroll + for (int b = 0; b < batch; b++) { + if (fixed) { + reg_real[b] = __floats2half2_rn(scale_inv * vx[b].real(), scale_inv * vy[b].real()); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b] = __floats2half2_rn(scale_inv_conj * vx[b].imag(), scale_inv_conj * vy[b].imag()); + } else { + reg_real[b] = __floats2half2_rn(+vx[b].real(), +vy[b].real()); + reg_imag[b] = __floats2half2_rn(dagger ? -vx[b].imag() : +vx[b].imag(), dagger ? -vy[b].imag() : +vy[b].imag()); + } } } else { - using store_type = T; - using store_array = typename VectorType::type; - store_array v = *reinterpret_cast(&p[n_idx * ld + m_idx]); + complex v[batch * 2]; + batch_load_t, batch * 2>::load(v, &p[n_idx * ld + m_idx]); - if (fixed) { - reg_real[0] = __floats2half2_rn(scale_inv * v.x, scale_inv * v.z); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag[0] = __floats2half2_rn(scale_inv_conj * v.y, scale_inv_conj * v.w); - } else { - reg_real[0] = __floats2half2_rn(+v.x, +v.z); - reg_imag[0] = __floats2half2_rn(dagger ? -v.y : +v.y, dagger ? -v.w : +v.w); +#pragma unroll + for (int b = 0; b < batch; b++) { + if (fixed) { + reg_real[b] = __floats2half2_rn(scale_inv * v[b * 2].real(), scale_inv * v[b * 2 + 1].real()); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b] = __floats2half2_rn(scale_inv_conj * v[b * 2].imag(), scale_inv_conj * v[b * 2 + 1].imag()); + } else { + reg_real[b] = __floats2half2_rn(+v[b * 2].real(), +v[b * 2 + 1].real()); + reg_imag[b] = __floats2half2_rn(dagger ? -v[b * 2].imag() : +v[b * 2].imag(), dagger ? -v[b * 2 + 1].imag() : +v[b * 2 + 1].imag()); + } } } } @@ -263,10 +269,9 @@ namespace quda template constexpr int get_mn_batch(int internal_batch, int register_dim, int block) { - return (internal_batch > 1) ? 1 : - ((register_dim % 4 == 0 && block % 4 == 0 && sizeof(T) * 8 <= 16) ? - 4 : - ((register_dim % 2 == 0 && block % 2 == 0 && sizeof(T) * 4 <= 16) ? 2 : 1)); + return ((register_dim % (internal_batch * 4) == 0 && block % (internal_batch * 4) == 0 && sizeof(T) * internal_batch * 8 <= 16) ? + 4 : + ((register_dim % (internal_batch * 2) == 0 && block % (internal_batch * 2) == 0 && sizeof(T) * internal_batch * 4 <= 16) ? 2 : 1)); } /** @@ -474,6 +479,8 @@ namespace quda auto scale_inv = gmem.get_scale_inv(); constexpr bool fixed = GmemAccessor::fixed; + using store_t = typename GmemAccessor::store_type; + constexpr bool x = (transpose == dagger); constexpr int n_stride = x ? 
block_y : block_z; @@ -488,7 +495,7 @@ namespace quda constexpr bool check_shared_bound = !(bM % m_stride == 0 && bN % n_stride == 0); if constexpr (x) { - constexpr int n_batch = get_mn_batch(batch, n_dim, bN); + constexpr int n_batch = get_mn_batch(1, n_dim, bN); #pragma unroll for (int n = 0; n < n_dim / n_batch; n++) { @@ -516,7 +523,7 @@ namespace quda } } } else { - constexpr int m_batch = get_mn_batch(batch, m_dim, bM); + constexpr int m_batch = get_mn_batch(batch, m_dim, bM); #pragma unroll for (int n = 0; n < n_dim; n++) { @@ -552,10 +559,12 @@ namespace quda } } - template __device__ inline void r2s(SmemObj &smem_real, SmemObj &smem_imag) + template __device__ inline void r2s(SmemObj &smem_real, SmemObj &smem_imag) { constexpr bool x = (transpose == dagger); + using store_t = typename GmemAccessor::store_type; + constexpr int n_stride = transpose == dagger ? block_y : block_z; constexpr int m_stride = transpose == dagger ? block_z * batch : block_y * batch; int n_thread_offset = transpose == dagger ? threadIdx.y : threadIdx.z; @@ -565,7 +574,7 @@ namespace quda constexpr int m_dim = (bM + m_stride - 1) / m_stride; if constexpr (x) { - constexpr int n_batch = get_mn_batch(batch, n_dim, bN); + constexpr int n_batch = get_mn_batch(1, n_dim, bN); #pragma unroll for (int n = 0; n < n_dim / n_batch; n++) { #pragma unroll @@ -589,7 +598,7 @@ namespace quda } } } else { - constexpr int m_batch = get_mn_batch(batch, m_dim, bM); + constexpr int m_batch = get_mn_batch(batch, m_dim, bM); #pragma unroll for (int n = 0; n < n_dim; n++) { #pragma unroll @@ -597,8 +606,7 @@ namespace quda const int n_idx = n * n_stride + n_thread_offset; const int m_idx = (m * m_stride + m_thread_offset) * m_batch; if (m_idx < bM && n_idx < bN) { - if constexpr (SmemObj::ldm == 1 && SmemObj::ldn % m_batch == 0) { - static_assert(SmemObj::ldm == 1, "SmemObj::ldm == 1"); + if constexpr (SmemObj::ldm == 1 && SmemObj::ldn % (batch * m_batch) == 0) { load_t v_real[m_batch]; load_t v_imag[m_batch]; #pragma unroll @@ -611,8 +619,8 @@ namespace quda } else { #pragma unroll for (int b = 0; b < m_batch; b++) { - smem_real.vector_load(m_idx + b, n_idx, reg_real[(m * m_batch + b) * n_dim + n]); - smem_imag.vector_load(m_idx + b, n_idx, reg_imag[(m * m_batch + b) * n_dim + n]); + smem_real.vector_load(m_idx + b * batch, n_idx, reg_real[(m * m_batch + b) * n_dim + n]); + smem_imag.vector_load(m_idx + b * batch, n_idx, reg_imag[(m * m_batch + b) * n_dim + n]); } } } diff --git a/include/targets/cuda/mma_tensor_op/simt.cuh b/include/targets/cuda/mma_tensor_op/simt.cuh index 7b5b854eaf..2a3fd766d5 100644 --- a/include/targets/cuda/mma_tensor_op/simt.cuh +++ b/include/targets/cuda/mma_tensor_op/simt.cuh @@ -36,6 +36,11 @@ namespace quda using compute_t = T; using load_t = T; + static constexpr bool do_rescale() + { + return false; + } + static std::string get_type_name() { char s[TuneKey::aux_n] = ",simt,m"; diff --git a/lib/prolongator.in.cpp b/lib/prolongator.in.cpp index 812748b419..75afd1002d 100644 --- a/lib/prolongator.in.cpp +++ b/lib/prolongator.in.cpp @@ -71,10 +71,10 @@ namespace quda Prolongate(vv_cmp, in, v, fine_to_coarse, spin_map, parity); blas::mxpy(out, v_cmp); - auto vn = blas::norm2(vv_cmp); - printf("n = "); + auto vn = blas::max(vv_cmp); + printf("prolongator %d->%d = ", coarseColor, fineColor); for (size_t i = 0; i < vn.size(); i++) { - printf("%f ", vn[i]); + printf("%4.2e ", vn[i]); } printf("\n"); #endif diff --git a/lib/restrictor.in.cpp b/lib/restrictor.in.cpp index 9afd03ca4a..9c3315add2 100644 --- 
a/lib/restrictor.in.cpp +++ b/lib/restrictor.in.cpp @@ -59,7 +59,7 @@ namespace quda RestrictMma2(v_out, v_in, V, fine_to_coarse, coarse_to_fine, spin_map, parity, nvecs); BlockTransposeBackward(v_out, out); -#if 1 +#if 0 std::vector v_cmp(out.size()); for (size_t i = 0; i < out.size(); i++) { ColorSpinorParam param(out[i]); @@ -70,10 +70,10 @@ namespace quda Restrict(vv_cmp, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); blas::mxpy(out, v_cmp); - auto vn = blas::norm2(vv_cmp); - printf("n = "); + auto vn = blas::max(vv_cmp); + printf("restrictor %d->%d = ", fineColor, coarseColor); for (size_t i = 0; i < vn.size(); i++) { - printf("%f ", vn[i]); + printf("%4.2e ", vn[i]); } printf("\n"); #endif diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index 1cd0cd3afa..405b71b660 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -110,7 +110,7 @@ namespace quda static constexpr int n_atom_size = mma_t::MMA_N; static constexpr int m_atom_size = mma_t::MMA_M; - static constexpr int k_atom_size = fineColor * spin_block_factor * 4; + static constexpr int k_atom_size = fineColor * spin_block_factor * mma_t::MMA_K; static constexpr int block_atom_size = 32 / 8; static constexpr int block_limit = 32; @@ -130,7 +130,6 @@ namespace quda bool set_mma_param(TuneParam &tp) const { static_assert(m % m_atom_size == 0, "m modulo m_atom_size == 0"); - static_assert(n % n_atom_size == 0, "n modulo n_atom_size == 0"); static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); tp.block.x = 1; From 6bb02f6ae6a94ebdbdea9973125cc539c000bf96 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Mon, 9 Sep 2024 21:43:29 -0700 Subject: [PATCH 11/79] Add rescaling to prolongator. --- include/kernels/prolongator_mma.cuh | 29 +- .../cuda/mma_tensor_op/gmem_loader.cuh | 267 +++++++++++++++--- 2 files changed, 246 insertions(+), 50 deletions(-) diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh index 06f4998ccd..a3f884298f 100644 --- a/include/kernels/prolongator_mma.cuh +++ b/include/kernels/prolongator_mma.cuh @@ -120,14 +120,27 @@ namespace quda constexpr bool a_dagger = true; constexpr bool b_dagger = true; - for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { - __syncthreads(); - a_loader.template g2r(a, m_offset, k_offset); - b_loader.template g2r(b, n_offset, k_offset); - a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); - b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); - __syncthreads(); - accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + if constexpr (Arg::mma_t::do_rescale()) { + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + __syncthreads(); + constexpr bool rescale = true; + float a_rescale = a_loader.template g2r_rescale(a, m_offset, k_offset); + float b_rescale = b_loader.template g2r_rescale(b, n_offset, k_offset); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + __syncthreads(); + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, a_rescale * b_rescale); + } + } else { + for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { + __syncthreads(); + a_loader.template g2r(a, m_offset, k_offset); + b_loader.template g2r(b, n_offset, k_offset); + a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); + b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); + __syncthreads(); + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, 
smem_obj_b_imag); + } } auto c = arg.out(spinor_parity, x_cb, spin * Arg::spin_block_factor, 0, 0); diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index 4288bcea2a..1cb7acc7c1 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -25,6 +25,14 @@ namespace quda static constexpr int value = 1; }; + inline __device__ void zero(float2 ®_real, float2 ®_imag) + { + reg_real.x = 0; + reg_real.y = 0; + reg_imag.x = 0; + reg_imag.y = 0; + } + inline __device__ void zero(half2 ®_real, half2 ®_imag) { reg_real = __half2half2(0); @@ -39,6 +47,8 @@ namespace quda inline __device__ float abs_max(float a, float max) { return fmaxf(fabsf(a), max); } + inline __device__ float abs_max(float2 a, float max) { return fmaxf(fabsf(a.y), fmaxf(fabsf(a.x), max)); } + template struct batch_load_t { }; @@ -120,6 +130,56 @@ namespace quda static auto __device__ get(half2 v[]) { return v[0]; } }; + /** + @brief Load from global memory and store data in registers. + */ + template + inline __device__ void convert_x(float2 reg_real[batch], float2 reg_imag[batch], complex *p, int m_idx, + int n_idx, float scale_inv) + { + if constexpr (x) { + complex vx[batch]; + complex vy[batch]; + batch_load_t, batch>::load(vx, &p[(m_idx + 0) * ld + n_idx]); + batch_load_t, batch>::load(vy, &p[(m_idx + 1) * ld + n_idx]); + +#pragma unroll + for (int b = 0; b < batch; b++) { + if constexpr (fixed) { + reg_real[b].x = scale_inv * vx[b].real(); + reg_real[b].y = scale_inv * vy[b].real(); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b].x = scale_inv_conj * vx[b].imag(); + reg_imag[b].y = scale_inv_conj * vy[b].imag(); + } else { + reg_real[b].x = +vx[b].real(); + reg_real[b].y = +vy[b].real(); + reg_imag[b].x = dagger ? -vx[b].imag() : +vx[b].imag(); + reg_imag[b].y = dagger ? -vy[b].imag() : +vy[b].imag(); + } + } + } else { + complex v[batch * 2]; + batch_load_t, batch * 2>::load(v, &p[n_idx * ld + m_idx]); + +#pragma unroll + for (int b = 0; b < batch; b++) { + if constexpr (fixed) { + reg_real[b].x = scale_inv * v[b * 2 + 0].real(); + reg_real[b].y = scale_inv * v[b * 2 + 1].real(); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b].x = scale_inv_conj * v[b * 2 + 0].imag(); + reg_imag[b].y = scale_inv_conj * v[b * 2 + 1].imag(); + } else { + reg_real[b].x = +v[b * 2 + 0].real(); + reg_real[b].y = +v[b * 2 + 1].real(); + reg_imag[b].x = dagger ? -v[b * 2 + 0].imag() : +v[b * 2 + 0].imag(); + reg_imag[b].y = dagger ? -v[b * 2 + 1].imag() : +v[b * 2 + 1].imag(); + } + } + } + } + /** @brief Load from global memory and store data in registers. */ @@ -135,13 +195,14 @@ namespace quda #pragma unroll for (int b = 0; b < batch; b++) { - if (fixed) { + if constexpr (fixed) { reg_real[b] = __floats2half2_rn(scale_inv * vx[b].real(), scale_inv * vy[b].real()); auto scale_inv_conj = dagger ? -scale_inv : scale_inv; reg_imag[b] = __floats2half2_rn(scale_inv_conj * vx[b].imag(), scale_inv_conj * vy[b].imag()); } else { reg_real[b] = __floats2half2_rn(+vx[b].real(), +vy[b].real()); - reg_imag[b] = __floats2half2_rn(dagger ? -vx[b].imag() : +vx[b].imag(), dagger ? -vy[b].imag() : +vy[b].imag()); + reg_imag[b] + = __floats2half2_rn(dagger ? -vx[b].imag() : +vx[b].imag(), dagger ? 
-vy[b].imag() : +vy[b].imag()); } } } else { @@ -150,13 +211,60 @@ namespace quda #pragma unroll for (int b = 0; b < batch; b++) { - if (fixed) { + if constexpr (fixed) { reg_real[b] = __floats2half2_rn(scale_inv * v[b * 2].real(), scale_inv * v[b * 2 + 1].real()); auto scale_inv_conj = dagger ? -scale_inv : scale_inv; reg_imag[b] = __floats2half2_rn(scale_inv_conj * v[b * 2].imag(), scale_inv_conj * v[b * 2 + 1].imag()); } else { reg_real[b] = __floats2half2_rn(+v[b * 2].real(), +v[b * 2 + 1].real()); - reg_imag[b] = __floats2half2_rn(dagger ? -v[b * 2].imag() : +v[b * 2].imag(), dagger ? -v[b * 2 + 1].imag() : +v[b * 2 + 1].imag()); + reg_imag[b] = __floats2half2_rn(dagger ? -v[b * 2].imag() : +v[b * 2].imag(), + dagger ? -v[b * 2 + 1].imag() : +v[b * 2 + 1].imag()); + } + } + } + } + + /** + @brief Load from global memory and store data in registers. + */ + template + inline __device__ void convert_x_rescale(half2 reg_real[batch], half2 reg_imag[batch], complex *p, int m_idx, + int n_idx, float scale_inv, float rescale) + { + if constexpr (x) { + complex vx[batch]; + complex vy[batch]; + batch_load_t, batch>::load(vx, &p[(m_idx + 0) * ld + n_idx]); + batch_load_t, batch>::load(vy, &p[(m_idx + 1) * ld + n_idx]); + +#pragma unroll + for (int b = 0; b < batch; b++) { + if constexpr (fixed) { + float scale_inv_rescale = scale_inv * rescale; + reg_real[b] = __floats2half2_rn(scale_inv_rescale * vx[b].real(), scale_inv_rescale * vy[b].real()); + auto scale_inv_conj = dagger ? -scale_inv_rescale : scale_inv_rescale; + reg_imag[b] = __floats2half2_rn(scale_inv_conj * vx[b].imag(), scale_inv_conj * vy[b].imag()); + } else { + reg_real[b] = __floats2half2_rn(+vx[b].real() * rescale, +vy[b].real() * rescale); + reg_imag[b] = __floats2half2_rn((dagger ? -vx[b].imag() : +vx[b].imag()) * rescale, + (dagger ? -vy[b].imag() : +vy[b].imag()) * rescale); + } + } + } else { + complex v[batch * 2]; + batch_load_t, batch * 2>::load(v, &p[n_idx * ld + m_idx]); + +#pragma unroll + for (int b = 0; b < batch; b++) { + if constexpr (fixed) { + float scale_inv_rescale = scale_inv * rescale; + reg_real[b] = __floats2half2_rn(scale_inv_rescale * v[b * 2].real(), scale_inv_rescale * v[b * 2 + 1].real()); + auto scale_inv_conj = dagger ? -scale_inv_rescale : scale_inv_rescale; + reg_imag[b] = __floats2half2_rn(scale_inv_conj * v[b * 2].imag(), scale_inv_conj * v[b * 2 + 1].imag()); + } else { + reg_real[b] = __floats2half2_rn(+v[b * 2].real() * rescale, +v[b * 2 + 1].real() * rescale); + reg_imag[b] = __floats2half2_rn((dagger ? -v[b * 2].imag() : +v[b * 2].imag()) * rescale, + (dagger ? -v[b * 2 + 1].imag() : +v[b * 2 + 1].imag()) * rescale); } } } @@ -175,7 +283,7 @@ namespace quda #pragma unroll for (int b = 0; b < batch; b++) { // auto xx = p[m_idx * ld + n_idx]; - if (fixed) { + if constexpr (fixed) { reg_real[b] = scale_inv * v[b].real(); auto scale_inv_conj = dagger ? -scale_inv : scale_inv; reg_imag[b] = scale_inv_conj * v[b].imag(); @@ -190,7 +298,7 @@ namespace quda #pragma unroll for (int b = 0; b < batch; b++) { // auto xx = p[n_idx * ld + m_idx]; - if (fixed) { + if constexpr (fixed) { reg_real[b] = scale_inv * v[b].real(); auto scale_inv_conj = dagger ? -scale_inv : scale_inv; reg_imag[b] = scale_inv_conj * v[b].imag(); @@ -212,7 +320,7 @@ namespace quda if (x) { auto xx = p[m_idx * ld + n_idx]; - if (fixed) { + if constexpr (fixed) { reg_real = scale_inv * xx.real() * rescale; auto scale_inv_conj = dagger ? 
-scale_inv : scale_inv; reg_imag = scale_inv_conj * xx.imag() * rescale; @@ -223,7 +331,7 @@ namespace quda } else { auto xx = p[n_idx * ld + m_idx]; - if (fixed) { + if constexpr (fixed) { reg_real = scale_inv * xx.real() * rescale; auto scale_inv_conj = dagger ? -scale_inv : scale_inv; reg_imag = scale_inv_conj * xx.imag() * rescale; @@ -242,10 +350,10 @@ namespace quda { float this_max = 0.0f; - if (x) { + if constexpr (x) { auto xx = p[m_idx * ld + n_idx]; - if (fixed) { + if constexpr (fixed) { this_max = abs_max(scale_inv * xx.real(), this_max); this_max = abs_max(scale_inv * xx.imag(), this_max); } else { @@ -255,7 +363,7 @@ namespace quda } else { auto xx = p[n_idx * ld + m_idx]; - if (fixed) { + if constexpr (fixed) { this_max = abs_max(scale_inv * xx.real(), this_max); this_max = abs_max(scale_inv * xx.imag(), this_max); } else { @@ -269,9 +377,13 @@ namespace quda template constexpr int get_mn_batch(int internal_batch, int register_dim, int block) { - return ((register_dim % (internal_batch * 4) == 0 && block % (internal_batch * 4) == 0 && sizeof(T) * internal_batch * 8 <= 16) ? - 4 : - ((register_dim % (internal_batch * 2) == 0 && block % (internal_batch * 2) == 0 && sizeof(T) * internal_batch * 4 <= 16) ? 2 : 1)); + return ((register_dim % (internal_batch * 4) == 0 && block % (internal_batch * 4) == 0 + && sizeof(T) * internal_batch * 8 <= 16) ? + 4 : + ((register_dim % (internal_batch * 2) == 0 && block % (internal_batch * 2) == 0 + && sizeof(T) * internal_batch * 4 <= 16) ? + 2 : + 1)); } /** @@ -472,8 +584,8 @@ namespace quda * ld: leading dimension of global memory * dagger: if we need to store daggered (tranpose and hermision conjugate) */ - template - __device__ inline void g2r(const GmemAccessor &gmem, int m_offset, int n_offset) + template + __device__ inline float g2r_rescale(const GmemAccessor &gmem, int m_offset, int n_offset) { auto p = gmem.data(); auto scale_inv = gmem.get_scale_inv(); @@ -494,6 +606,11 @@ namespace quda constexpr bool check_global_bound = !(M % bM == 0 && N % bN == 0); constexpr bool check_shared_bound = !(bM % m_stride == 0 && bN % n_stride == 0); + using store_array_t = typename VectorType::type; + + store_array_t f_real[register_count]; + store_array_t f_imag[register_count]; + if constexpr (x) { constexpr int n_batch = get_mn_batch(1, n_dim, bN); #pragma unroll @@ -505,19 +622,19 @@ namespace quda int n_idx_blk = (n * n_stride + n_thread_offset) * n_batch; int m_idx_blk = m * m_stride + m_thread_offset; - if (!check_shared_bound || (m_idx_blk < bM && n_idx_blk < bN)) { + int n_idx = n_idx_blk + n_offset; + int m_idx = m_idx_blk + m_offset; - int n_idx = n_idx_blk + n_offset; - int m_idx = m_idx_blk + m_offset; + bool b1 = !check_shared_bound || (m_idx_blk < bM && n_idx_blk < bN); + bool b2 = !check_global_bound || (n_idx < N && m_idx < M); - if (!check_global_bound || (n_idx < N && m_idx < M)) { - convert_x( - ®_real[m * n_dim + n * n_batch], ®_imag[m * n_dim + n * n_batch], p, m_idx, n_idx, scale_inv); - } else { + if (b1 && b2) { + convert_x(&f_real[m * n_dim + n * n_batch], + &f_imag[m * n_dim + n * n_batch], p, m_idx, n_idx, scale_inv); + } else { #pragma unroll - for (int b = 0; b < n_batch; b++) { - zero(reg_real[m * n_dim + n * n_batch + b], reg_imag[m * n_dim + n * n_batch + b]); - } + for (int b = 0; b < n_batch; b++) { + zero(f_real[m * n_dim + n * n_batch + b], f_imag[m * n_dim + n * n_batch + b]); } } } @@ -533,33 +650,99 @@ namespace quda int n_idx_blk = n * n_stride + n_thread_offset; int m_idx_blk = (m * m_stride + 
m_thread_offset) * m_batch; - if (!check_shared_bound || (m_idx_blk < bM && n_idx_blk < bN)) { + int n_idx = n_idx_blk + n_offset; + int m_idx = m_idx_blk + m_offset; - int n_idx = n_idx_blk + n_offset; - int m_idx = m_idx_blk + m_offset; + bool b1 = !check_shared_bound || (m_idx_blk < bM && n_idx_blk < bN); + bool b2 = !check_global_bound || (n_idx < N && m_idx < M); - if (!check_global_bound || (n_idx < N && m_idx < M)) { - load_t v_real[m_batch]; - load_t v_imag[m_batch]; - convert_x(v_real, v_imag, p, m_idx, n_idx, scale_inv); + if (b1 && b2) { + store_array_t v_real[m_batch]; + store_array_t v_imag[m_batch]; + convert_x(v_real, v_imag, p, m_idx, n_idx, scale_inv); #pragma unroll - for (int b = 0; b < m_batch; b++) { - reg_real[(m * m_batch + b) * n_dim + n] = v_real[b]; - reg_imag[(m * m_batch + b) * n_dim + n] = v_imag[b]; - } - } else { + for (int b = 0; b < m_batch; b++) { + f_real[(m * m_batch + b) * n_dim + n] = v_real[b]; + f_imag[(m * m_batch + b) * n_dim + n] = v_imag[b]; + } + } else { #pragma unroll - for (int b = 0; b < m_batch; b++) { - zero(reg_real[(m * m_batch + b) * n_dim + n], reg_imag[(m * m_batch + b) * n_dim + n]); - } + for (int b = 0; b < m_batch; b++) { + zero(f_real[(m * m_batch + b) * n_dim + n], f_imag[(m * m_batch + b) * n_dim + n]); } } } } } + + float block_rescale_factor = 1.0f; + if constexpr (rescale) { + float thread_max = 0; +#pragma unroll + for (int n = 0; n < n_dim; n++) { +#pragma unroll + for (int m = 0; m < m_dim; m++) { + thread_max = abs_max(f_real[m * n_dim + n], thread_max); + thread_max = abs_max(f_imag[m * n_dim + n], thread_max); + } + } + + // block all-reduce thread_max + using block_reduce_t = cub::BlockReduce; + __shared__ typename block_reduce_t::TempStorage temp_storage; + float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float block_max_all; + if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { + if (block_max > 0.0f) { + block_max_all = block_max; + } else { + block_max_all = 1.0f; + } + } + __syncthreads(); + + block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number + } + + if constexpr (std::is_same_v) { +#pragma unroll + for (int n = 0; n < n_dim; n++) { +#pragma unroll + for (int m = 0; m < m_dim; m++) { + reg_real[m * n_dim + n] = __floats2half2_rn(f_real[m * n_dim + n].x * block_rescale_factor, + f_real[m * n_dim + n].y * block_rescale_factor); + reg_imag[m * n_dim + n] = __floats2half2_rn(f_imag[m * n_dim + n].x * block_rescale_factor, + f_imag[m * n_dim + n].y * block_rescale_factor); + } + } + } else { +#pragma unroll + for (int n = 0; n < n_dim; n++) { +#pragma unroll + for (int m = 0; m < m_dim; m++) { + reg_real[m * n_dim + n] = f_real[m * n_dim + n] * block_rescale_factor; + reg_imag[m * n_dim + n] = f_imag[m * n_dim + n] * block_rescale_factor; + } + } + } + + return 1.0f / block_rescale_factor; + } + + /** + * ld: leading dimension of global memory + * dagger: if we need to store daggered (tranpose and hermision conjugate) + */ + template + __device__ inline void g2r(const GmemAccessor &gmem, int m_offset, int n_offset) + { + constexpr bool rescale = false; + g2r_rescale(gmem, m_offset, n_offset); } - template __device__ inline void r2s(SmemObj &smem_real, SmemObj &smem_imag) + template + __device__ inline void r2s(SmemObj &smem_real, SmemObj &smem_imag) { constexpr bool x = (transpose == dagger); From 2d8867e267eab7559018e3d758b02ae5eb525991 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 10 Sep 2024 13:02:52 
-0700 Subject: [PATCH 12/79] Modify the MMA types. --- .../cuda/mma_tensor_op/hmma_m16n8k8_sm70.cuh | 277 ++++++++++++++++++ .../cuda/mma_tensor_op/smma_m16n8k8_sm70.cuh | 150 ++++++++++ lib/dslash_coarse_mma.in.hpp | 6 +- lib/prolongator_mma.in.cu | 7 +- lib/restrictor_mma.in.cu | 7 +- 5 files changed, 441 insertions(+), 6 deletions(-) create mode 100644 include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm70.cuh create mode 100644 include/targets/cuda/mma_tensor_op/smma_m16n8k8_sm70.cuh diff --git a/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm70.cuh b/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm70.cuh new file mode 100644 index 0000000000..6ddb7def90 --- /dev/null +++ b/include/targets/cuda/mma_tensor_op/hmma_m16n8k8_sm70.cuh @@ -0,0 +1,277 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +// Here we implement the architecture dependent part of MMA for Volta (sm70, the mma.sync.m8n8k4 instruction). + +namespace quda +{ + namespace hmma + { + + using half = mma::half; + using half2 = mma::half2; + + template struct hmma_x_t { + }; + + template <> struct hmma_x_t<16, 8, 8, half, half2> { + + static __device__ __host__ constexpr int inline pad_size(int m) { return m == 48 ? 2 : 10; } + + static constexpr bool do_rescale() + { + return true; // true because we use FP16 + } + + static constexpr int MMA_M = 16; + static constexpr int MMA_N = 8; + static constexpr int MMA_K = 8; + + static constexpr int warp_size = 32; + + using compute_t = half; + using load_t = half2; + + struct WarpRegisterMapping { + + int warp_id; + int row_offset; // quad_row * 8 + quad_hilo * 4 + int quad_hilo; + int quad_col; + int quad_thread; // 0,1,2,3 + + __device__ inline WarpRegisterMapping(int thread_id) + { + warp_id = thread_id >> 5; + int lane_id = thread_id & 31; + int octl_id = lane_id >> 2; + int quad_id = octl_id & 3; + int quad_row = quad_id & 1; + quad_hilo = (octl_id >> 2) & 1; + quad_col = quad_id >> 1; + quad_thread = lane_id & 3; + row_offset = quad_row * 8 + quad_hilo * 4; + } + }; + + static std::string get_type_name() + { + char s[TuneKey::aux_n] = ",1xfp16,m"; + i32toa(s + strlen(s), MMA_M); + strcat(s, "n"); + i32toa(s + strlen(s), MMA_N); + strcat(s, "k"); + i32toa(s + strlen(s), MMA_K); + return s; + } + + struct OperandA { + + unsigned reg[2]; + + template + __device__ inline void load(const void *smem, int k, int warp_row, const WarpRegisterMapping &wrm) + { + const unsigned *A = reinterpret_cast(smem); + int idx_strided = k * MMA_K + wrm.quad_thread + wrm.quad_col * 4; + int idx_contiguous = (warp_row * MMA_M + wrm.row_offset) / 2; + int thread_offset_a = idx_strided * (lda / 2) + idx_contiguous; + reg[0] = A[thread_offset_a + 0]; + reg[1] = A[thread_offset_a + 1]; + } + + template + __device__ inline void load(const SmemObj &smem_obj, int k, int warp_row, const WarpRegisterMapping &wrm) + { + const unsigned *A = reinterpret_cast(smem_obj.ptr); + int idx_strided = k * MMA_K + wrm.quad_thread + wrm.quad_col * 4; + int idx_contiguous = (warp_row * MMA_M + wrm.row_offset) / 2; + const int thread_offset_a = idx_strided * (SmemObj::ldn / 2) + idx_contiguous; + reg[0] = A[thread_offset_a]; + reg[1] = A[thread_offset_a + 1]; + } + + __device__ inline void negate() + { + asm volatile("neg.f16x2 %0, %0;" : "+r"(reg[0])); + asm volatile("neg.f16x2 %0, %0;" : "+r"(reg[1])); + } + }; + + struct OperandB { + + unsigned reg[2]; + + template + __device__ inline void load(const void *smem, int k, int warp_col, const WarpRegisterMapping &wrm) + { + const 
unsigned *B = reinterpret_cast(smem); + int idx_strided = k * MMA_K + wrm.quad_thread + wrm.quad_col * 4; + int idx_contiguous = (warp_col * MMA_N + wrm.quad_hilo * 4) / 2; + int thread_offset_b = idx_strided * (ldb / 2) + idx_contiguous; + reg[0] = B[thread_offset_b + 0]; + reg[1] = B[thread_offset_b + 1]; + } + + template + __device__ inline void load(const SmemObj &smem_obj, int k, int warp_col, const WarpRegisterMapping &wrm) + { + const unsigned *B = reinterpret_cast(smem_obj.ptr); + int idx_strided = k * MMA_K + wrm.quad_thread + wrm.quad_col * 4; + int idx_contiguous = (warp_col * MMA_N + wrm.quad_hilo * 4) / 2; + const int thread_offset_b = idx_strided * (SmemObj::ldn / 2) + idx_contiguous; + reg[0] = B[thread_offset_b]; + reg[1] = B[thread_offset_b + 1]; + } + }; + + template struct Structure { + real v[length]; + __device__ inline const real &operator[](int i) const { return v[i]; } + __device__ inline real &operator[](int i) { return v[i]; } + }; + + struct OperandC { + + using reg_type = float; + reg_type reg[4]; + + __device__ inline OperandC() { zero(); } + + __device__ inline void zero() + { +#pragma unroll + for (int i = 0; i < 4; i++) { reg[i] = 0; } + } + + __device__ inline void ax(float alpha) + { +#pragma unroll + for (int i = 0; i < 4; i++) { reg[i] *= alpha; } + } + + __device__ inline void axpy(float alpha, OperandC x) + { +#pragma unroll + for (int i = 0; i < 4; i++) { reg[i] += alpha * x.reg[i]; } + } + + template + __device__ inline void store(void *smem, int warp_row, int warp_col, const WarpRegisterMapping &wrm) + { + half2 *C = reinterpret_cast(smem); + + const int idx_strided = warp_row * 16 + wrm.row_offset + (wrm.quad_thread % 2); + const int idx_contiguous = warp_col * 4 + wrm.quad_col * 2 + (wrm.quad_thread / 2); + + int thread_offset_c = idx_strided * (ldc / 2) + idx_contiguous; + C[thread_offset_c] = __floats2half2_rn(reg[0], reg[1]); + + thread_offset_c = (idx_strided + 2) * (ldc / 2) + idx_contiguous; + C[thread_offset_c] = __floats2half2_rn(reg[2], reg[3]); + } + + template __device__ inline void abs_max(F &max) + { +#pragma unroll + for (int i = 0; i < 4; i++) { max = fmax(max, fabsf(reg[i])); } + } + }; + + static __device__ inline void mma(const OperandA &op_a, const OperandB &op_b, OperandC &op_c) + { + float op_c_tmp[8]; +#pragma unroll + for (int i = 0; i < 8; i++) { op_c_tmp[i] = 0; } + mma::mma_instruction_t mma_instruction; + mma_instruction(op_c_tmp, op_a.reg, op_b.reg); + float other_op_c_tmp[8]; +#pragma unroll + for (int x = 0; x < 8; x++) { + other_op_c_tmp[x] = __shfl_xor_sync(0xffffffff, op_c_tmp[x], 0x8); + } + int lane_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) % 32; + if (lane_id / 8 % 2 == 0) { +#pragma unroll + for (int x = 0; x < 4; x++) { + op_c.reg[x] += op_c_tmp[x] + other_op_c_tmp[x]; + } + } else { +#pragma unroll + for (int x = 0; x < 4; x++) { + op_c.reg[x] += op_c_tmp[x + 4] + other_op_c_tmp[x + 4]; + } + } + } + + template + static inline __device__ void store_complex(int warp_row, int warp_col, const WarpRegisterMapping &wrm, + GmemOperandC &cc, const OperandC &op_c_real, + const OperandC &op_c_imag, op_t op) + { + using store_t = typename GmemOperandC::store_type; + + const int row = warp_row + wrm.row_offset + (wrm.quad_thread % 2); + const int col = warp_col + wrm.quad_col * 4 + (wrm.quad_thread / 2) * 2; + + constexpr bool fixed = GmemOperandC::fixed; + constexpr bool check_bounds = !((M % MMA_M == 0) && (N % MMA_N == 0)); + +#pragma unroll + for (int i = 0; i < 2; i++) { + int 
m_index = row + (i % 2) * 2; + int n_index = col; + + if constexpr (dagger) { + using complex_t = complex; + auto ptr = reinterpret_cast(cc.data()); + complex_t s; + if constexpr (fixed) { + auto scale = cc.get_scale(); + s = {f2i_round(op_c_real.reg[i * 2 + 0] * scale), + f2i_round(-op_c_imag.reg[i * 2 + 0] * scale)}; + if (!check_bounds || (m_index < M && (n_index + 0) < N)) { op(&ptr[(n_index + 0) * ldc + m_index], s); } + // op(&ptr[(n_index + 0) * ldc + m_index], s); + s = {f2i_round(op_c_real.reg[i * 2 + 1] * scale), + f2i_round(-op_c_imag.reg[i * 2 + 1] * scale)}; + if (!check_bounds || (m_index < M && (n_index + 1) < N)) { op(&ptr[(n_index + 1) * ldc + m_index], s); } + // op(&ptr[(n_index + 1) * ldc + m_index], s); + } else { + s = {op_c_real.reg[i * 2 + 0], -op_c_imag.reg[i * 2 + 0]}; + if (!check_bounds || (m_index < M && (n_index + 0) < N)) { op(&ptr[(n_index + 0) * ldc + m_index], s); } + // op(&ptr[(n_index + 0) * ldc + m_index], s); + s = {op_c_real.reg[i * 2 + 1], -op_c_imag.reg[i * 2 + 1]}; + if (!check_bounds || (m_index < M && (n_index + 1) < N)) { op(&ptr[(n_index + 1) * ldc + m_index], s); } + // op(&ptr[(n_index + 1) * ldc + m_index], s); + } + } else { + using array_t = typename VectorType::type; // array; + array_t *ptr = reinterpret_cast(cc.data()); + array_t s; + if constexpr (fixed) { + auto scale = cc.get_scale(); + s.x = f2i_round(op_c_real.reg[i * 2 + 0] * scale); + s.y = f2i_round(op_c_imag.reg[i * 2 + 0] * scale); + s.z = f2i_round(op_c_real.reg[i * 2 + 1] * scale); + s.w = f2i_round(op_c_imag.reg[i * 2 + 1] * scale); + } else { + s.x = op_c_real.reg[i * 2 + 0]; + s.y = op_c_imag.reg[i * 2 + 0]; + s.z = op_c_real.reg[i * 2 + 1]; + s.w = op_c_imag.reg[i * 2 + 1]; + } + if (!check_bounds || (m_index < M && n_index < N)) { op(&ptr[(m_index * ldc + n_index) / 2], s); } + } + } + } + }; + + } // namespace hmma +} // namespace quda diff --git a/include/targets/cuda/mma_tensor_op/smma_m16n8k8_sm70.cuh b/include/targets/cuda/mma_tensor_op/smma_m16n8k8_sm70.cuh new file mode 100644 index 0000000000..f59354649d --- /dev/null +++ b/include/targets/cuda/mma_tensor_op/smma_m16n8k8_sm70.cuh @@ -0,0 +1,150 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace quda +{ + namespace smma + { + + template struct smma_x_t { + }; + + template <> struct smma_x_t { + + static constexpr bool use_intermediate_accumulator() { return true; }; + + static __device__ __host__ constexpr int inline pad_size(int) { return 0; } + + static constexpr bool do_rescale() + { + return true; // false because we use FP16 + } + + static constexpr int MMA_M = 16; + static constexpr int MMA_N = 8; + static constexpr int MMA_K = 8; + + static constexpr int warp_size = 32; + + using compute_t = float; + using load_t = float; + + using base_t = hmma::hmma_x_t<16, 8, 8, half, half2>; + + using WarpRegisterMapping = typename base_t::WarpRegisterMapping; + + static std::string get_type_name() + { + char s[TuneKey::aux_n] = ",3xfp16,m"; + i32toa(s + strlen(s), MMA_M); + strcat(s, "n"); + i32toa(s + strlen(s), MMA_N); + strcat(s, "k"); + i32toa(s + strlen(s), MMA_K); + return s; + } + + struct OperandA { + + unsigned big[2]; + unsigned small[2]; + + template + __device__ inline void load(const SmemObj &smem_obj, int k, int warp_row, const WarpRegisterMapping &wrm) + { + const float *A = reinterpret_cast(smem_obj.ptr); + int idx_strided = k * MMA_K + wrm.quad_thread + wrm.quad_col * 4; + int idx_contiguous = warp_row * MMA_M + wrm.row_offset; + const int thread_offset_a = idx_strided 
* SmemObj::ldn + idx_contiguous; + +#pragma unroll + for (int v = 0; v < 2; v++) { + float f[2]; + f[0] = A[thread_offset_a + 2 * v + 0]; + f[1] = A[thread_offset_a + 2 * v + 1]; + Shuffle s; + s(big[v], small[v], f); + } + } + + __device__ inline void negate() + { +#pragma unroll + for (int v = 0; v < 2; v++) { + asm volatile("neg.f16x2 %0, %0;" : "+r"(big[v])); + asm volatile("neg.f16x2 %0, %0;" : "+r"(small[v])); + } + } + }; + + struct OperandB { + + unsigned big[2]; + unsigned small[2]; + + template + __device__ inline void load(const SmemObj &smem_obj, int k, int warp_col, const WarpRegisterMapping &wrm) + { + const float *B = reinterpret_cast(smem_obj.ptr); + int idx_strided = k * MMA_K + wrm.quad_thread + wrm.quad_col * 4; + int idx_contiguous = warp_col * MMA_N + wrm.quad_hilo * 4; + const int thread_offset_b = idx_strided * SmemObj::ldn + idx_contiguous; + +#pragma unroll + for (int v = 0; v < 2; v++) { + float f[2]; + f[0] = B[thread_offset_b + 2 * v + 0]; + f[1] = B[thread_offset_b + 2 * v + 1]; + Shuffle s; + s(big[v], small[v], f); + } + } + }; + + using OperandC = typename base_t::OperandC; + + static __device__ void mma(const OperandA &op_a, const OperandB &op_b, OperandC &op_c) + { + mma::mma_instruction_t mma_instruction; + float acc[8]; +#pragma unroll + for (int c = 0; c < 8; c++) { acc[c] = 0; } + + mma_instruction(acc, op_a.big, op_b.big); + mma_instruction(acc, op_a.big, op_b.small); + mma_instruction(acc, op_a.small, op_b.big); + float other_acc[8]; +#pragma unroll + for (int x = 0; x < 8; x++) { + other_acc[x] = __shfl_xor_sync(0xffffffff, acc[x], 0x8); + } + int lane_id = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) % 32; + if (lane_id / 8 % 2 == 0) { +#pragma unroll + for (int x = 0; x < 4; x++) { + op_c.reg[x] += acc[x] + other_acc[x]; + } + } else { +#pragma unroll + for (int x = 0; x < 4; x++) { + op_c.reg[x] += acc[x + 4] + other_acc[x + 4]; + } + } + } + + template + static inline __device__ void store_complex(int warp_row, int warp_col, const WarpRegisterMapping &wrm, + GmemOperandC &cc, const OperandC &op_c_real, + const OperandC &op_c_imag, op_t op) + { + base_t::template store_complex(warp_row, warp_col, wrm, cc, op_c_real, op_c_imag, op); + } + }; + + } // namespace smma +} // namespace quda diff --git a/lib/dslash_coarse_mma.in.hpp b/lib/dslash_coarse_mma.in.hpp index 56668e4f1b..564e7d2f9a 100644 --- a/lib/dslash_coarse_mma.in.hpp +++ b/lib/dslash_coarse_mma.in.hpp @@ -15,6 +15,7 @@ #include #include +#include namespace quda { @@ -175,11 +176,12 @@ namespace quda // using mma_t = smma::smma_t; // 3xBF16 // using mma_t = smma::smma_t; // 3xTF32 - // using mma_t = simt::simt_t; // SIMT + using mma_t = simt::simt_t; // SIMT + // using mma_t = smma::smma_x_t; // 1xFP16 - m16n8k8 variant for sm70 // using mma_t = hmma::hmma_tfloat32_t<4, 1, 1>; // 1xTF32 // using mma_t = mma::smma_half_t; // 3xFP16 // using mma_t = mma::hmma_t; // 1xFP16 - using mma_t = typename mma::smma_dispatch::type; + // using mma_t = typename mma::smma_dispatch::type; static constexpr int n_atom_size = mma_t::MMA_N; static constexpr int m_atom_size = mma_t::MMA_M; static constexpr int k_atom_size = Ns * Nc / 2; diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index 6d82660257..c9ad874f04 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -4,6 +4,7 @@ #include #include #include +#include namespace quda { @@ -97,9 +98,11 @@ namespace quda apply(device::get_default_stream()); } + // using mma_t = typename 
mma::smma_dispatch::type; // using mma_t = simt::simt_t; - // using mma_t = smma::smma_t; // 3xTF32 - using mma_t = typename mma::smma_dispatch::type; + // using mma_t = smma::smma_x_t; + using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; + // using mma_t = hmma::hmma_t<16, 16, 4, mma::half, mma::half2>; static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index 405b71b660..6047753488 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -5,6 +5,7 @@ #include #include #include +#include namespace quda { @@ -99,8 +100,10 @@ namespace quda apply(device::get_default_stream()); } - using mma_t = typename mma::smma_dispatch::type; - // using mma_t = simt::simt_t; + // using mma_t = typename mma::smma_dispatch::type; + using mma_t = simt::simt_t; + // using mma_t = smma::smma_x_t; + // using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); From db6c03016dc13e7437295ebbf84b6c03b1edee3a Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 10 Sep 2024 22:19:15 -0700 Subject: [PATCH 13/79] Add rescale for restrictor. --- include/kernels/restrictor_mma.cuh | 212 +++++++++++++----- .../cuda/mma_tensor_op/gmem_loader.cuh | 10 + lib/restrictor_mma.in.cu | 4 +- 3 files changed, 170 insertions(+), 56 deletions(-) diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index c05222fa30..1d7f8979cc 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -80,70 +80,164 @@ namespace quda } }; - template - inline void __device__ load_g2s(smem_obj_t &smem_real, smem_obj_t &smem_imag, const gmem_obj_t &gmem, int x_coarse, - int coarse_spin, int contiguous_dim_offset, int aggregate_k_offset, - int *coarse_to_fine, const Arg &arg) - { // v as a + template + inline float __device__ load_g2s(smem_obj_t &smem_real, smem_obj_t &smem_imag, const gmem_obj_t &gmem, int x_coarse, + int coarse_spin, int contiguous_dim_offset, int aggregate_k_offset, + int *coarse_to_fine, const Arg &arg) + { constexpr int elements_per_thread = 16 / (sizeof(typename gmem_obj_t::store_type) * 2); static_assert(contiguous_dim % elements_per_thread == 0, "contiguous_dim %% elements_per_thread == 0"); + float block_rescale_factor = 1.0f; + + if constexpr (rescale) { + float thread_max = 0; + int thread = target::thread_idx().y + Arg::block_y * target::thread_idx().z; + while (thread < (contiguous_dim / elements_per_thread) * Arg::spin_block_factor * Arg::fineColor + * Arg::aggregate_per_block) { + int thread_idx = thread; + int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; + constexpr bool check_contiguous_bound = !(contiguous_limit % contiguous_dim == 0); + if (!check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit) { + thread_idx /= (contiguous_dim / elements_per_thread); + int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin + thread_idx /= Arg::spin_block_factor; + int fine_color = thread_idx % Arg::fineColor; + thread_idx /= Arg::fineColor; + int x_fine_offset = thread_idx + aggregate_k_offset; + + const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; + const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; + const int parity = arg.nParity == 2 ? 
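/*
   Sketch of the rescale scheme this patch adds (the constants are the ones
   used in the code that follows): when mma_t::do_rescale() is true the load
   is done in two passes.

     pass 1: each thread scans the elements it is about to load and keeps a
             running abs-max; a block-wide reduction turns that into
             block_max, and the whole tile is scaled by

               s = 65504.0f / block_max;   // 65504 = largest finite FP16 value

             so the data fills the half-precision range without overflowing;

     pass 2: the same elements are re-read, multiplied by s, converted and
             stored to shared memory for the MMA.

   Because the accumulator is kept in a wider type, the scaling is undone once
   per tile: the function returns 1.0f / block_rescale_factor and the caller
   feeds a_rescale * b_rescale into mma_rescale().
*/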
parity_offset : arg.parity; + + // look-up map is ordered as (coarse-block-id + fine-point-id), + // with fine-point-id parity ordered + const int x_fine = coarse_to_fine[parity * arg.aggregate_size_cb + x_fine_cb_offset]; + const int x_fine_cb = x_fine - parity * arg.in.VolumeCB(); + + const int v_parity = (gmem.Nparity() == 2) ? parity : 0; + + int fine_spin = fine_spin_block + coarse_spin * Arg::spin_block_factor; + auto a_gmem = gmem(v_parity, x_fine_cb, fine_spin, fine_color, contiguous + contiguous_dim_offset); + complex a[elements_per_thread]; + mma::batch_load_t, elements_per_thread>::load(a, a_gmem.data()); + + if constexpr (decltype(a_gmem)::fixed) { + auto scale_inv = a_gmem.get_scale_inv(); +#pragma unroll + for (int e = 0; e < elements_per_thread; e++) { + thread_max = mma::abs_max(a[e].real() * scale_inv, thread_max); + thread_max = mma::abs_max(a[e].imag() * scale_inv, thread_max); + } + } else { +#pragma unroll + for (int e = 0; e < elements_per_thread; e++) { + thread_max = mma::abs_max(a[e].real(), thread_max); + thread_max = mma::abs_max(a[e].imag(), thread_max); + } + } + } + + thread += Arg::block_y * Arg::block_z; + } + + // block all-reduce thread_max + using block_reduce_t = cub::BlockReduce; + __shared__ typename block_reduce_t::TempStorage temp_storage; + float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float block_max_all; + if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { + if (block_max > 0.0f) { + block_max_all = block_max; + } else { + block_max_all = 1.0f; + } + } + __syncthreads(); + + block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number + } + int thread = target::thread_idx().y + Arg::block_y * target::thread_idx().z; while (thread < (contiguous_dim / elements_per_thread) * Arg::spin_block_factor * Arg::fineColor * Arg::aggregate_per_block) { int thread_idx = thread; int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; - constexpr bool check_contiguous_bound = !(contiguous_limit % contiguous_dim == 0); - if (!check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit) { - thread_idx /= (contiguous_dim / elements_per_thread); - int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin - thread_idx /= Arg::spin_block_factor; - int fine_color = thread_idx % Arg::fineColor; - thread_idx /= Arg::fineColor; - int x_fine_offset = thread_idx + aggregate_k_offset; - - const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; - const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; - const int parity = arg.nParity == 2 ? parity_offset : arg.parity; - - // look-up map is ordered as (coarse-block-id + fine-point-id), - // with fine-point-id parity ordered - const int x_fine = coarse_to_fine[parity * arg.aggregate_size_cb + x_fine_cb_offset]; - const int x_fine_cb = x_fine - parity * arg.in.VolumeCB(); - - const int v_parity = (gmem.Nparity() == 2) ? 
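/*
   Note for the second pass just below: when the source field is stored in a
   fixed-point format (half or quarter precision), its accessor already
   carries a scale_inv, and that factor is simply folded together with the
   block rescale factor,

     auto scale_inv = a_gmem.get_scale_inv() * block_rescale_factor;

   so the rescale adds no extra per-element arithmetic on that path; only the
   non-fixed path pays one additional multiply per component.
*/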
parity : 0; - - int fine_spin = fine_spin_block + coarse_spin * Arg::spin_block_factor; - auto a_gmem = gmem(v_parity, x_fine_cb, fine_spin, fine_color, contiguous + contiguous_dim_offset); - complex a[elements_per_thread]; - mma::batch_load_t, elements_per_thread>::load(a, a_gmem.data()); - - int smem_m = contiguous; - int smem_k = (thread_idx * Arg::spin_block_factor + fine_spin_block) * Arg::fineColor + fine_color; - - typename Arg::real a_real[elements_per_thread]; - typename Arg::real a_imag[elements_per_thread]; - if constexpr (decltype(a_gmem)::fixed) { - auto scale_inv = a_gmem.get_scale_inv(); + constexpr bool check_contiguous_bound = !(contiguous_limit % contiguous_dim == 0); + if (!check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit) { + thread_idx /= (contiguous_dim / elements_per_thread); + int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin + thread_idx /= Arg::spin_block_factor; + int fine_color = thread_idx % Arg::fineColor; + thread_idx /= Arg::fineColor; + int x_fine_offset = thread_idx + aggregate_k_offset; + + const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; + const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; + const int parity = arg.nParity == 2 ? parity_offset : arg.parity; + + // look-up map is ordered as (coarse-block-id + fine-point-id), + // with fine-point-id parity ordered + const int x_fine = coarse_to_fine[parity * arg.aggregate_size_cb + x_fine_cb_offset]; + const int x_fine_cb = x_fine - parity * arg.in.VolumeCB(); + + const int v_parity = (gmem.Nparity() == 2) ? parity : 0; + + int fine_spin = fine_spin_block + coarse_spin * Arg::spin_block_factor; + auto a_gmem = gmem(v_parity, x_fine_cb, fine_spin, fine_color, contiguous + contiguous_dim_offset); + complex a[elements_per_thread]; + mma::batch_load_t, elements_per_thread>::load(a, a_gmem.data()); + + int smem_m = contiguous; + int smem_k = (thread_idx * Arg::spin_block_factor + fine_spin_block) * Arg::fineColor + fine_color; + + typename Arg::real a_real[elements_per_thread]; + typename Arg::real a_imag[elements_per_thread]; + if constexpr (decltype(a_gmem)::fixed) { + auto scale_inv = a_gmem.get_scale_inv() * block_rescale_factor; +#pragma unroll + for (int e = 0; e < elements_per_thread; e++) { + a_real[e] = +a[e].real() * scale_inv; + a_imag[e] = dagger ? -a[e].imag() * scale_inv : +a[e].imag() * scale_inv; + } + } else { #pragma unroll - for (int e = 0; e < elements_per_thread; e++) { - a_real[e] = +a[e].real() * scale_inv; - a_imag[e] = dagger ? -a[e].imag() * scale_inv : +a[e].imag() * scale_inv; + for (int e = 0; e < elements_per_thread; e++) { + a_real[e] = +a[e].real() * block_rescale_factor; + a_imag[e] = (dagger ? 
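/*
   Just below, when the MMA compute type is half, pairs of rescaled floats are
   packed with __floats2half2_rn and, if the shared-memory leading dimension
   allows, written as a single vector; the make_vector_t<uint2> specialization
   added to gmem_loader.cuh in this patch exists for exactly that case,
   bundling two half2 values into one 8-byte store. A minimal sketch of the
   packing step, assuming an even elements_per_thread:

     half2 h2[elements_per_thread / 2];
     for (int b = 0; b < elements_per_thread / 2; b++)
       h2[b] = __floats2half2_rn(a_real[2 * b + 0], a_real[2 * b + 1]);
*/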
-a[e].imag() : +a[e].imag()) * block_rescale_factor; + } } - } else { + + static_assert(smem_obj_t::ldm == 1, "smem_obj_t::ldm == 1"); + if constexpr (std::is_same_v) { + static_assert(elements_per_thread % 2 == 0, "elements_per_thread %% 2 == 0"); + typename Arg::mma_t::load_t h2_real[elements_per_thread / 2]; + typename Arg::mma_t::load_t h2_imag[elements_per_thread / 2]; +#pragma unroll + for (int b = 0; b < elements_per_thread / 2; b++) { + h2_real[b] = __floats2half2_rn(a_real[2 * b + 0], a_real[2 * b + 1]); + h2_imag[b] = __floats2half2_rn(a_imag[2 * b + 0], a_imag[2 * b + 1]); + } + if constexpr (smem_obj_t::ldn % elements_per_thread == 0) { + smem_real.vector_load(smem_m, smem_k, mma::make_vector_t::get(h2_real)); + smem_imag.vector_load(smem_m, smem_k, mma::make_vector_t::get(h2_imag)); + } else { #pragma unroll - for (int e = 0; e < elements_per_thread; e++) { - a_real[e] = +a[e].real(); - a_imag[e] = dagger ? -a[e].imag() : +a[e].imag(); + for (int b = 0; b < elements_per_thread / 2; b++) { + smem_real.vector_load(smem_m + b * 2, smem_k, h2_real[b]); + smem_imag.vector_load(smem_m + b * 2, smem_k, h2_imag[b]); + } + } + } else { + smem_real.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_real)); + smem_imag.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_imag)); } } - static_assert(smem_obj_t::ldm == 1, "smem_obj_t::ldm == 1"); - smem_real.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_real)); - smem_imag.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_imag)); - } - thread += Arg::block_y * Arg::block_z; } + + return 1.0f / block_rescale_factor; } template @@ -185,20 +279,30 @@ namespace quda accumulator.zero(); + constexpr bool rescale = mma_t::do_rescale(); + for (int aggregate_k_offset = 0; aggregate_k_offset < Arg::aggregate_size; aggregate_k_offset += Arg::aggregate_per_block) { __syncthreads(); constexpr bool a_dagger = true; - load_g2s(smem_obj_a_real, smem_obj_a_imag, arg.in, x_coarse, coarse_spin, m_offset, - aggregate_k_offset, coarse_to_fine, arg); + float a_rescale + = load_g2s(smem_obj_a_real, smem_obj_a_imag, arg.in, x_coarse, coarse_spin, + m_offset, aggregate_k_offset, coarse_to_fine, arg); constexpr bool b_dagger = false; - load_g2s(smem_obj_b_real, smem_obj_b_imag, arg.v, x_coarse, coarse_spin, n_offset, - aggregate_k_offset, coarse_to_fine, arg); + float b_rescale + = load_g2s(smem_obj_b_real, smem_obj_b_imag, arg.v, x_coarse, coarse_spin, + n_offset, aggregate_k_offset, coarse_to_fine, arg); __syncthreads(); - accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + + if constexpr (rescale) { + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, + a_rescale * b_rescale); + } else { + accumulator.mma(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag); + } } const int parity_coarse = x_coarse >= arg.out.VolumeCB() ? 1 : 0; diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index 1cb7acc7c1..f330d66cb2 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -130,6 +130,16 @@ namespace quda static auto __device__ get(half2 v[]) { return v[0]; } }; + template <> struct make_vector_t { + static auto __device__ get(half2 v[]) + { + uint2 ret_value; + ret_value.x = reinterpret_cast(v)[0]; + ret_value.y = reinterpret_cast(v)[1]; + return ret_value; + } + }; + /** @brief Load from global memory and store data in registers. 
*/ diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index 6047753488..6544a17e54 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -101,9 +101,9 @@ namespace quda } // using mma_t = typename mma::smma_dispatch::type; - using mma_t = simt::simt_t; + // using mma_t = simt::simt_t; // using mma_t = smma::smma_x_t; - // using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; + using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); From ce257ef0a54a29d045b7049ed6651189ebabaece Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Mon, 16 Sep 2024 10:24:10 -0700 Subject: [PATCH 14/79] Abstract the MMA expansions into a class. --- include/expand_list.hpp | 155 ++++++++++++++++++++++++++++++++++++ lib/prolongator_mma.in.cu | 160 ++++++++------------------------------ 2 files changed, 188 insertions(+), 127 deletions(-) create mode 100644 include/expand_list.hpp diff --git a/include/expand_list.hpp b/include/expand_list.hpp new file mode 100644 index 0000000000..7181fb6a74 --- /dev/null +++ b/include/expand_list.hpp @@ -0,0 +1,155 @@ +#include +#include + +namespace quda { + +template +class expand_aux_t { + + Callable &_callable; + + template + void span_w(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.w == W) { + constexpr IntFactorArray<(w + w_atom_size - 1) / w_atom_size> w_factors; + _callable.template launch_mma(tp, stream); + } else { + if constexpr (sizeof...(Ws) > 0) { + span_w(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.w"); + } + } + } + + template + void span_z(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.z == Z) { + constexpr IntFactorArray<(z + z_atom_size - 1) / z_atom_size> z_factors; + std::make_index_sequence().size()> w_indices; + span_w(tp, stream, w_indices); + } else { + if constexpr (sizeof...(Zs) > 0) { + span_z(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.z"); + } + } + } + + template + void span_y(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.y == Y) { + constexpr IntFactorArray<(y + y_atom_size - 1) / y_atom_size> y_factors; + std::make_index_sequence().size()> z_indices; + span_z(tp, stream, z_indices); + } else { + if constexpr (sizeof...(Ys) > 0) { + span_y(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.y"); + } + } + } + + template + void span_x(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) + { + if (tp.aux.x == X) { + constexpr IntFactorArray<(x + x_atom_size - 1) / x_atom_size> x_factors; + std::make_index_sequence().size()> y_indices; + span_y(tp, stream, y_indices); + } else { + if constexpr (sizeof...(Xs) > 0) { + span_x(tp, stream, std::index_sequence()); + } else { + errorQuda("Invalid tp.aux.x"); + } + } + } + + public: + + void expand(TuneParam &tp, const qudaStream_t &stream) + { + std::make_index_sequence().size()> x_indices; + span_x(tp, stream, x_indices); + } + + expand_aux_t(Callable &callable): _callable(callable) { } + + int get_x(const TuneParam &tp) const { + if (static_cast(tp.aux.x) >= IntFactorArray<(x + x_atom_size - 1) / x_atom_size>().size()) { + errorQuda("Invalid tp.aux.x = %d\n", tp.aux.x); + } + return x_atom_size * get_int_factor_array((x + x_atom_size - 1) / x_atom_size)[tp.aux.x]; + } + + int get_y(const TuneParam &tp) const { + if (static_cast(tp.aux.y) >= IntFactorArray<(y + y_atom_size - 1) / 
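/*
   How the get_x() .. get_w() accessors around this point are meant to be
   read: each maps a runtime tuning index tp.aux.{x,y,z,w} onto a concrete
   size by picking the i-th integer factor of the rounded-up ratio
   dimension / atom_size and scaling it back by atom_size, while the span_*
   members above resolve the same indices at compile time so that a matching
   launch_mma<...> gets instantiated. A stand-alone sketch of that
   runtime-to-compile-time dispatch pattern (dispatch and Work are
   illustrative names, not part of this class):

     template <template <int> class Work, std::size_t... Is>
     void dispatch(int value, std::index_sequence<Is...>)
     {
       // invoke Work<I>::run() for the compile-time I equal to the runtime value
       ((value == static_cast<int>(Is) ? Work<static_cast<int>(Is)>::run() : void()), ...);
     }

   Each aux component is resolved this way in turn, which is what replaces the
   hand-written launch_mma_span_* ladders in the launchers below.
*/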
y_atom_size>().size()) { + errorQuda("Invalid tp.aux.y = %d\n", tp.aux.y); + } + return y_atom_size * get_int_factor_array((y + y_atom_size - 1) / y_atom_size)[tp.aux.y]; + } + + int get_z(const TuneParam &tp) const { + if (static_cast(tp.aux.z) >= IntFactorArray<(z + z_atom_size - 1) / z_atom_size>().size()) { + errorQuda("Invalid tp.aux.z = %d\n", tp.aux.z); + } + return z_atom_size * get_int_factor_array((z + z_atom_size - 1) / z_atom_size)[tp.aux.z]; + } + + int get_w(const TuneParam &tp) const { + if (static_cast(tp.aux.w) >= IntFactorArray<(w + w_atom_size - 1) / w_atom_size>().size()) { + errorQuda("Invalid tp.aux.w = %d\n", tp.aux.w); + } + return w_atom_size * get_int_factor_array((w + w_atom_size - 1) / w_atom_size)[tp.aux.w]; + } + + bool advance_aux(TuneParam ¶m) const + { + auto advancer = [&](int &i, int limit) -> bool { + if (i < limit) { + i++; + return _callable.set_mma_param(param); + } else { + return false; + } + }; + + if (advancer(param.aux.x, numFactors((x + x_atom_size - 1) / x_atom_size) - 1)) { + return true; + } else { + param.aux.x = 0; + if (advancer(param.aux.y, numFactors((y + y_atom_size - 1) / y_atom_size) - 1)) { + return true; + } else { + param.aux.y = 0; + if (advancer(param.aux.z, numFactors((z + z_atom_size - 1) / z_atom_size) - 1)) { + return true; + } else { + param.aux.z = 0; + if (advancer(param.aux.w, numFactors((w + w_atom_size - 1) / w_atom_size) - 1)) { + return true; + } else { + param.aux.w = 0; + return false; + } + } + } + } + } + + void init_aux(TuneParam ¶m) const { + param.aux.x = 0; + param.aux.y = 0; + param.aux.z = 0; + param.aux.w = 0; + } + +}; + +} diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index c9ad874f04..4440ab4c25 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include namespace quda @@ -20,60 +20,45 @@ namespace quda int parity; QudaFieldLocation location; + // using mma_t = typename mma::smma_dispatch::type; + // using mma_t = simt::simt_t; + // using mma_t = smma::smma_x_t; + using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; + // using mma_t = hmma::hmma_t<16, 16, 4, mma::half, mma::half2>; + + static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); + + static constexpr int m = nVec; + static constexpr int n = fineColor * spin_block_factor; + static constexpr int k = coarseColor; + + static constexpr int n_atom_size = mma_t::MMA_N; + static constexpr int m_atom_size = mma_t::MMA_M; + static constexpr int k_atom_size = mma_t::MMA_K; + static constexpr int block_atom_size = 32 / 8; + + using this_t = ProlongateLaunchMma; + expand_aux_t expand; + bool checkParam(const TuneParam ¶m) const { return true; } unsigned int sharedBytesPerThread() const { return 0; } bool advanceTuneParam(TuneParam ¶m) const { - auto advancer = [&](int &i, int limit) -> bool { - if (i < limit) { - i++; - return set_mma_param(param); - } else { - return false; - } - }; - - if (advancer(param.aux.x, numFactors((k + block_atom_size - 1) / block_atom_size) - 1)) { - return true; - } else { - param.aux.x = 0; - if (advancer(param.aux.y, numFactors((n + n_atom_size - 1) / n_atom_size) - 1)) { - return true; - } else { - param.aux.y = 0; - if (advancer(param.aux.z, numFactors((m + m_atom_size - 1) / m_atom_size) - 1)) { - return true; - } else { - param.aux.z = 0; - if (advancer(param.aux.w, numFactors((k + k_atom_size - 1) / k_atom_size) - 1)) { - return true; - } else { - param.aux.w = 0; - return false; - } - 
} - } - } + return expand.advance_aux(param); } void initTuneParam(TuneParam ¶m) const { - param.aux.x = 0; - param.aux.y = 0; - param.aux.z = 0; - param.aux.w = 0; + expand.init_aux(param); set_mma_param(param); } /** sets default values for when tuning is disabled */ void defaultTuneParam(TuneParam ¶m) const { - param.aux.x = 0; - param.aux.y = 0; - param.aux.z = 0; - param.aux.w = 0; + expand.init_aux(param); set_mma_param(param); } @@ -86,7 +71,8 @@ namespace quda V(V), fine_to_coarse(fine_to_coarse), parity(parity), - location(checkLocation(out, in, V)) + location(checkLocation(out, in, V)), + expand(*this) { strcat(vol, ","); strcat(vol, out.VolString().c_str()); @@ -98,23 +84,6 @@ namespace quda apply(device::get_default_stream()); } - // using mma_t = typename mma::smma_dispatch::type; - // using mma_t = simt::simt_t; - // using mma_t = smma::smma_x_t; - using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; - // using mma_t = hmma::hmma_t<16, 16, 4, mma::half, mma::half2>; - - static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); - - static constexpr int m = nVec; - static constexpr int n = fineColor * spin_block_factor; - static constexpr int k = coarseColor; - - static constexpr int n_atom_size = mma_t::MMA_N; - static constexpr int m_atom_size = mma_t::MMA_M; - static constexpr int k_atom_size = mma_t::MMA_K; - static constexpr int block_atom_size = 32 / 8; - long long flops() const { return nVec * 8 * fineSpin * fineColor * coarseColor * out.SiteSubset() * out.VolumeCB(); @@ -137,28 +106,29 @@ namespace quda static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); tp.block.x = 1; - tp.block.y = block_atom_size * get_int_factor_array((k + block_atom_size - 1) / block_atom_size)[tp.aux.x]; + tp.block.y = expand.get_x(tp); tp.block.z = 8; - int bN = n_atom_size * get_int_factor_array((n + n_atom_size - 1) / n_atom_size)[tp.aux.y]; - int bM = m_atom_size * get_int_factor_array((m + m_atom_size - 1) / m_atom_size)[tp.aux.z]; + int bN = expand.get_y(tp); + int bM = expand.get_z(tp); tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin / spin_block_factor, (m + bM - 1) / bM, (n + bN - 1) / bN); tp.set_max_shared_bytes = true; - int bK = k_atom_size * get_int_factor_array(k / k_atom_size)[tp.aux.w]; + int bK = expand.get_w(tp); int shared_bytes = shared_bytes_per_block(bM, bN, bK); tp.shared_bytes = shared_bytes; return shared_bytes <= device::maximum_dynamic_shared_memory(); } - template + template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { constexpr bool to_non_rel = false; + constexpr int block_z = 8; using Arg = ProlongateMmaArg; Arg arg(out, in, V, fine_to_coarse, parity); @@ -169,73 +139,9 @@ namespace quda } } - template - void launch_mma_span_k(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.w == d) { - constexpr IntFactorArray k_factors; - launch_mma(tp, stream); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_k(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.z."); - } - } - } - - template - void launch_mma_span_m(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.z == d) { - constexpr IntFactorArray<(m + m_atom_size - 1) / m_atom_size> m_factors; - std::make_index_sequence().size()> k_indices; - launch_mma_span_k(tp, stream, k_indices); - } else { - if constexpr 
(sizeof...(Ds) > 0) { - launch_mma_span_m(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.z."); - } - } - } - - template - void launch_mma_span_n(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.y == d) { - constexpr IntFactorArray<(n + n_atom_size - 1) / n_atom_size> n_factors; - std::make_index_sequence().size()> m_indices; - launch_mma_span_m(tp, stream, m_indices); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_n(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.y."); - } - } - } - - template - void launch_mma_span_block(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.x == d) { - constexpr IntFactorArray<(k + block_atom_size - 1) / block_atom_size> block_factors; - std::make_index_sequence().size()> n_indices; - launch_mma_span_n(tp, stream, n_indices); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_block(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.x."); - } - } - } - void launch_mma(TuneParam &tp, const qudaStream_t &stream) { - std::make_index_sequence().size()> block_indices; - launch_mma_span_block(tp, stream, block_indices); + expand.expand(tp, stream); } void apply(const qudaStream_t &stream) From f38882a5e0cee8f0bb54ec2ca8ffd723e8868199 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Mon, 16 Sep 2024 13:42:46 -0700 Subject: [PATCH 15/79] Add more abstraction; make TF32 the default for SM80 and later. --- lib/dslash_coarse_mma.in.hpp | 150 ++++++++----------------------- lib/prolongator_mma.in.cu | 6 +- lib/restrictor_mma.in.cu | 165 ++++++++--------------------------- 3 files changed, 79 insertions(+), 242 deletions(-) diff --git a/lib/dslash_coarse_mma.in.hpp b/lib/dslash_coarse_mma.in.hpp index 564e7d2f9a..b0f8a43545 100644 --- a/lib/dslash_coarse_mma.in.hpp +++ b/lib/dslash_coarse_mma.in.hpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include @@ -40,6 +40,30 @@ namespace quda mutable int color_col_stride; mutable int dim_threads; + // using mma_t = smma::smma_t; // 3xBF16 + // using mma_t = smma::smma_t; // 3xTF32 + // using mma_t = smma::smma_x_t; // 1xFP16 - m16n8k8 variant for sm70 + // using mma_t = hmma::hmma_tfloat32_t<4, 1, 1>; // 1xTF32 + // using mma_t = mma::smma_half_t; // 3xFP16 + // using mma_t = mma::hmma_t; // 1xFP16 +#if (__COMPUTE_CAPABILITY__ >= 800) + using mma_t = typename mma::smma_dispatch::type; +#else + using mma_t = typename simt::simt_t; +#endif + static constexpr int n_atom_size = mma_t::MMA_N; + static constexpr int m_atom_size = mma_t::MMA_M; + static constexpr int k_atom_size = Ns * Nc / 2; + + static constexpr int n = nVec; + static constexpr int m = Ns * Nc; + static constexpr int k = Ns * Nc; + static constexpr int block_atom_size = Ns * Nc / (Nc > 64 ? 8 : 4); + static constexpr int block_limit = Ns * Nc / (Nc > 64 ? 
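/*
   Reading the constants above: per lattice site the coarse dslash applies an
   (Ns*Nc) x (Ns*Nc) matrix to a block of nVec right-hand sides, so the MMA
   path treats it as a complex GEMM with M = Ns*Nc, N = nVec and K = Ns*Nc,
   which is what m, n and k encode; the bM/bN/bK tiles of those extents are
   then tuned over the factorizations handled by expand_aux_t. The mma_t
   selection follows the commit message, with the template arguments elided
   in this sketch:

     #if (__COMPUTE_CAPABILITY__ >= 800)
       using mma_t = typename mma::smma_dispatch< ... >::type;  // TF32-based tensor-core path
     #else
       using mma_t = typename simt::simt_t< ... >;              // plain SIMT fallback
     #endif
*/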
2 : 1); + + using this_t = DslashCoarseMma; + expand_aux_t expand; + long long flops() const { return ((dslash * 2 * nDim + clover * 1) * (8 * Ns * Nc * Ns * Nc) - 2 * Ns * Nc) * nParity @@ -64,55 +88,19 @@ namespace quda bool advanceTuneParam(TuneParam ¶m) const { - - auto advancer = [&](int &i, int limit) -> bool { - if (i < limit) { - i++; - return set_mma_param(param); - } else { - return false; - } - }; - - if (advancer(param.aux.x, 2)) { - return true; - } else { - param.aux.x = 0; - if (advancer(param.aux.y, numFactors(out[0].Nvec() / n_atom_size) - 1)) { - return true; - } else { - param.aux.y = 0; - if (advancer(param.aux.z, numFactors((Ns * Nc) / m_atom_size) - 1)) { - return true; - } else { - param.aux.z = 0; - if (advancer(param.aux.w, numFactors((Ns * Nc) / k_atom_size) - 1)) { - return true; - } else { - param.aux.w = 0; - return false; - } - } - } - } + return expand.advance_aux(param); } void initTuneParam(TuneParam ¶m) const { - param.aux.x = 0; - param.aux.y = 0; - param.aux.z = 0; - param.aux.w = 0; + expand.init_aux(param); set_mma_param(param); } /** sets default values for when tuning is disabled */ void defaultTuneParam(TuneParam ¶m) const { - param.aux.x = 0; - param.aux.y = 0; - param.aux.z = 0; - param.aux.w = 0; + expand.init_aux(param); set_mma_param(param); } @@ -130,7 +118,8 @@ namespace quda parity(parity), nParity(out.SiteSubset()), halo(halo), - color_col_stride(-1) + color_col_stride(-1), + expand(*this) { strcpy(vol, out.VolString().c_str()); strcpy(aux, (std::string("policy_kernel,") + vol).c_str()); @@ -174,18 +163,6 @@ namespace quda apply(device::get_default_stream()); } - // using mma_t = smma::smma_t; // 3xBF16 - // using mma_t = smma::smma_t; // 3xTF32 - using mma_t = simt::simt_t; // SIMT - // using mma_t = smma::smma_x_t; // 1xFP16 - m16n8k8 variant for sm70 - // using mma_t = hmma::hmma_tfloat32_t<4, 1, 1>; // 1xTF32 - // using mma_t = mma::smma_half_t; // 3xFP16 - // using mma_t = mma::hmma_t; // 1xFP16 - // using mma_t = typename mma::smma_dispatch::type; - static constexpr int n_atom_size = mma_t::MMA_N; - static constexpr int m_atom_size = mma_t::MMA_M; - static constexpr int k_atom_size = Ns * Nc / 2; - static constexpr int shared_bytes_per_block(int bM, int bN, int bK) { return mma::shared_memory_bytes(bM, bN, bK) + (bM + 4) * (bK + 4) * 2 * sizeof(yFloat) @@ -195,22 +172,22 @@ namespace quda bool set_mma_param(TuneParam &tp) const { tp.block.x = 1; - tp.block.y = Ns * Nc / ((Nc > 64 ? 
2 : 1) << tp.aux.x); + tp.block.y = expand.get_x(tp); tp.block.z = 8; if (out[0].Nvec() % n_atom_size != 0) { errorQuda("out[0].Nvec() %% n_atom_size != 0"); } - int bN = n_atom_size * get_int_factor_array(out[0].Nvec() / n_atom_size)[tp.aux.y]; + int bN = expand.get_y(tp); if (out[0].Nvec() % bN != 0) { errorQuda("Invalid bN."); } if ((Ns * Nc) % m_atom_size != 0) { errorQuda("(Ns * Nc) %% m_atom_size != 0"); } - int bM = m_atom_size * get_int_factor_array((Ns * Nc) / m_atom_size)[tp.aux.z]; + int bM = expand.get_z(tp); if ((Ns * Nc) % bM != 0) { errorQuda("Invalid bM"); } tp.grid = dim3(out.SiteSubset() * out.VolumeCB(), (Ns * Nc) / bM, out[0].Nvec() / bN); tp.set_max_shared_bytes = true; if ((Ns * Nc) % k_atom_size != 0) { errorQuda("(Ns * Nc) %% k_atom_size != 0"); } - int bK = k_atom_size * get_int_factor_array((Ns * Nc) / k_atom_size)[tp.aux.w]; + int bK = expand.get_w(tp); if ((Ns * Nc) % bK != 0) { errorQuda("Invalid bK"); } int shared_bytes = shared_bytes_per_block(bM, bN, bK); tp.shared_bytes = shared_bytes; @@ -218,11 +195,12 @@ namespace quda return shared_bytes <= device::maximum_dynamic_shared_memory(); } - template + template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { + constexpr int block_z = 8; using Arg = DslashCoarseMmaArg; Arg arg(out[0], inA[0], inB[0], Y, X, (Float)kappa, parity, halo); @@ -233,63 +211,9 @@ namespace quda } } - template - void launch_mma_span_k(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.w == d) { - constexpr IntFactorArray<(Ns * Nc) / k_atom_size> a; - launch_mma(tp, stream); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_k(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.z."); - } - } - } - - template - void launch_mma_span_m(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.z == d) { - constexpr IntFactorArray<(Ns * Nc) / m_atom_size> a; - std::make_index_sequence().size()> xt; - launch_mma_span_k(tp, stream, xt); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_m(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.z."); - } - } - } - - template - void launch_mma_span_n(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.y == d) { - constexpr IntFactorArray a; - std::make_index_sequence().size()> xt; - launch_mma_span_m(tp, stream, xt); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_n(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.y."); - } - } - } - void launch_mma(TuneParam &tp, const qudaStream_t &stream) { - std::make_index_sequence().size()> xt; - - switch (tp.aux.x) { - case 0: launch_mma_span_n 64 ? 2 : 1), 8>(tp, stream, xt); break; - case 1: launch_mma_span_n 64 ? 4 : 2), 8>(tp, stream, xt); break; - case 2: launch_mma_span_n 64 ? 
8 : 4), 8>(tp, stream, xt); break; - default: errorQuda("tp.aux.x = %d not supported", tp.aux.x); - } + expand.expand(tp, stream); } void apply(const qudaStream_t &stream) diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index 4440ab4c25..9bf770b9c8 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -23,8 +23,12 @@ namespace quda // using mma_t = typename mma::smma_dispatch::type; // using mma_t = simt::simt_t; // using mma_t = smma::smma_x_t; - using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; // using mma_t = hmma::hmma_t<16, 16, 4, mma::half, mma::half2>; +#if (__COMPUTE_CAPABILITY__ >= 800) + using mma_t = hmma::hmma_tfloat32_t<4, 1, 1>; +#else + using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; +#endif static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index 6544a17e54..d84346a7ec 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include namespace quda @@ -21,60 +21,49 @@ namespace quda const int *coarse_to_fine; const int parity; + // using mma_t = typename mma::smma_dispatch::type; + // using mma_t = simt::simt_t; + // using mma_t = smma::smma_x_t; +#if (__COMPUTE_CAPABILITY__ >= 800) + using mma_t = hmma::hmma_tfloat32_t<4, 1, 1>; +#else + using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; +#endif + + static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); + + static constexpr int m = nVec; + static constexpr int n = coarseColor; + static constexpr int k = fineColor * spin_block_factor * aggregate_size; + + static constexpr int n_atom_size = mma_t::MMA_N; + static constexpr int m_atom_size = mma_t::MMA_M; + static constexpr int k_atom_size = fineColor * spin_block_factor * mma_t::MMA_K; + static constexpr int block_atom_size = 32 / 8; + static constexpr int block_limit = 32; + + using this_t = RestrictMmaLaunch; + expand_aux_t expand; + bool checkParam(const TuneParam ¶m) const { return true; } unsigned int sharedBytesPerThread() const { return 0; } bool advanceTuneParam(TuneParam ¶m) const { - auto advancer = [&](int &i, int limit) -> bool { - if (i < limit) { - i++; - return set_mma_param(param); - } else { - return false; - } - }; - - if (advancer(param.aux.x, numFactors((block_limit + block_atom_size - 1) / block_atom_size) - 1)) { - return true; - } else { - param.aux.x = 0; - if (advancer(param.aux.y, numFactors((n + n_atom_size - 1) / n_atom_size) - 1)) { - return true; - } else { - param.aux.y = 0; - if (advancer(param.aux.z, numFactors((m + m_atom_size - 1) / m_atom_size) - 1)) { - return true; - } else { - param.aux.z = 0; - if (advancer(param.aux.w, numFactors((k + k_atom_size - 1) / k_atom_size) - 1)) { - return true; - } else { - param.aux.w = 0; - return false; - } - } - } - } + return expand.advance_aux(param); } void initTuneParam(TuneParam ¶m) const { - param.aux.x = 0; - param.aux.y = 0; - param.aux.z = 0; - param.aux.w = 0; + expand.init_aux(param); set_mma_param(param); } /** sets default values for when tuning is disabled */ void defaultTuneParam(TuneParam ¶m) const { - param.aux.x = 0; - param.aux.y = 0; - param.aux.z = 0; - param.aux.w = 0; + expand.init_aux(param); set_mma_param(param); } @@ -87,7 +76,8 @@ namespace quda v(v), fine_to_coarse(fine_to_coarse), coarse_to_fine(coarse_to_fine), - parity(parity) + parity(parity), + expand(*this) { strcat(vol, ","); strcat(vol, 
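/*
   The mma_t chosen above mirrors the prolongator: hmma_tfloat32_t on sm_80
   and newer, and the m16n8k8 FP16 variant on older parts. The distinction
   matters for the rescale machinery in restrictor_mma.cuh, since FP16 only
   represents magnitudes up to 65504 while TF32 keeps the FP32 exponent range,
   so presumably only the half-precision types report do_rescale() as true.
*/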
out.VolString().c_str()); @@ -100,23 +90,6 @@ namespace quda apply(device::get_default_stream()); } - // using mma_t = typename mma::smma_dispatch::type; - // using mma_t = simt::simt_t; - // using mma_t = smma::smma_x_t; - using mma_t = hmma::hmma_x_t<16, 8, 8, mma::half, mma::half2>; - - static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); - - static constexpr int m = nVec; - static constexpr int n = coarseColor; - static constexpr int k = fineColor * spin_block_factor * aggregate_size; - - static constexpr int n_atom_size = mma_t::MMA_N; - static constexpr int m_atom_size = mma_t::MMA_M; - static constexpr int k_atom_size = fineColor * spin_block_factor * mma_t::MMA_K; - static constexpr int block_atom_size = 32 / 8; - static constexpr int block_limit = 32; - long long flops() const { return nVec * 8 * fineSpin * fineColor * coarseColor * in.SiteSubset() * in.VolumeCB(); } long long bytes() const @@ -136,28 +109,28 @@ namespace quda static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); tp.block.x = 1; - tp.block.y - = block_atom_size * get_int_factor_array((block_limit + block_atom_size - 1) / block_atom_size)[tp.aux.x]; + tp.block.y = expand.get_x(tp); tp.block.z = 8; - int bN = n_atom_size * get_int_factor_array((n + n_atom_size - 1) / n_atom_size)[tp.aux.y]; - int bM = m_atom_size * get_int_factor_array((m + m_atom_size - 1) / m_atom_size)[tp.aux.z]; + int bN = expand.get_y(tp); + int bM = expand.get_z(tp); tp.grid = dim3(out.Volume() * coarseSpin, (m + bM - 1) / bM, (n + bN - 1) / bN); tp.set_max_shared_bytes = true; - int bK = k_atom_size * get_int_factor_array(k / k_atom_size)[tp.aux.w]; + int bK = expand.get_w(tp); int shared_bytes = shared_bytes_per_block(bM, bN, bK); tp.shared_bytes = shared_bytes; return shared_bytes <= device::maximum_dynamic_shared_memory(); } - template + template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { + constexpr int block_z = 8; using Arg = RestrictMmaArg; Arg arg(out, in, v, fine_to_coarse, coarse_to_fine, parity); @@ -168,73 +141,9 @@ namespace quda } } - template - void launch_mma_span_k(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.w == d) { - constexpr IntFactorArray k_factors; - launch_mma(tp, stream); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_k(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.z."); - } - } - } - - template - void launch_mma_span_m(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.z == d) { - constexpr IntFactorArray<(m + m_atom_size - 1) / m_atom_size> m_factors; - std::make_index_sequence().size()> k_indices; - launch_mma_span_k(tp, stream, k_indices); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_m(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.z."); - } - } - } - - template - void launch_mma_span_n(TuneParam &tp, const qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.y == d) { - constexpr IntFactorArray<(n + n_atom_size - 1) / n_atom_size> n_factors; - std::make_index_sequence().size()> m_indices; - launch_mma_span_m(tp, stream, m_indices); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_n(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.y."); - } - } - } - - template - void launch_mma_span_block(TuneParam &tp, const 
qudaStream_t &stream, std::index_sequence) - { - if (tp.aux.x == d) { - constexpr IntFactorArray<(block_limit + block_atom_size - 1) / block_atom_size> block_factors; - std::make_index_sequence().size()> n_indices; - launch_mma_span_n(tp, stream, n_indices); - } else { - if constexpr (sizeof...(Ds) > 0) { - launch_mma_span_block(tp, stream, std::index_sequence()); - } else { - errorQuda("Invalid tp.aux.x."); - } - } - } - void launch_mma(TuneParam &tp, const qudaStream_t &stream) { - std::make_index_sequence().size()> block_indices; - launch_mma_span_block(tp, stream, block_indices); + expand.expand(tp, stream); } void apply(const qudaStream_t &stream) From 0a2bd2a626bcc3217d34bc40c1883d6f332bbb82 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 17 Sep 2024 12:24:11 -0700 Subject: [PATCH 16/79] Fix block transpose by having no bound checks on the kernel level; in the process add an additional template to `kernel_param`. --- include/kernel_helper.h | 4 +++- include/kernels/block_transpose.cuh | 2 +- include/targets/cuda/kernel.h | 17 +++++++++++------ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/include/kernel_helper.h b/include/kernel_helper.h index dcb33baba0..884ae6ff73 100644 --- a/include/kernel_helper.h +++ b/include/kernel_helper.h @@ -14,8 +14,10 @@ namespace quda enum class use_kernel_arg_p { FALSE, TRUE, ALWAYS }; - template struct kernel_param { + template + struct kernel_param { static constexpr use_kernel_arg_p use_kernel_arg = use_kernel_arg_; + static constexpr bool check_bounds = check_bounds_; dim3 threads; /** number of active threads required */ int comms_rank; /** per process value of comm_rank() */ int comms_rank_global; /** per process value comm_rank_global() */ diff --git a/include/kernels/block_transpose.cuh b/include/kernels/block_transpose.cuh index bad990c80f..d4325bdf3b 100644 --- a/include/kernels/block_transpose.cuh +++ b/include/kernels/block_transpose.cuh @@ -13,7 +13,7 @@ namespace quda */ template - struct BlockTransposeArg : kernel_param<> { + struct BlockTransposeArg : kernel_param { // no bound checks using real = typename mapper::type; static constexpr bool is_device = is_device_; static constexpr int nSpin = nSpin_; diff --git a/include/targets/cuda/kernel.h b/include/targets/cuda/kernel.h index 313457d1d4..b85e6bcc72 100644 --- a/include/targets/cuda/kernel.h +++ b/include/targets/cuda/kernel.h @@ -89,14 +89,19 @@ namespace quda auto i = threadIdx.x + blockIdx.x * blockDim.x; auto j = threadIdx.y + blockIdx.y * blockDim.y; - if (j >= arg.threads.y) return; - while (i < arg.threads.x) { + if constexpr (Arg::check_bounds) { + if (j >= arg.threads.y) return; + + while (i < arg.threads.x) { + f(i, j); + if (grid_stride) + i += gridDim.x * blockDim.x; + else + break; + } + } else { f(i, j); - if (grid_stride) - i += gridDim.x * blockDim.x; - else - break; } } From de05ebf7725fd1bc14019619cf8cfdf1135a4c27 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 17 Sep 2024 12:33:23 -0700 Subject: [PATCH 17/79] Add some aggregate sizes to MMA restrictor --- lib/restrictor_mma.in.cu | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index d84346a7ec..d80751db15 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -251,21 +251,16 @@ namespace quda const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) { int aggregate_size = in.Volume() / out.Volume(); - if (aggregate_size == 128) { - if 
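/*
   For reference, aggregate_size is the geometric blocking volume
   in.Volume() / out.Volume(), so the cases handled here correspond to common
   multigrid blockings: 2x2x2x2 gives 16, 4x4x4x2 gives 128 and 4x4x4x8 gives
   512; any other blocking still ends up in the errorQuda branch. The next
   patch in the series removes this compile-time dispatch entirely by turning
   aggregate_size into a runtime variable.
*/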
constexpr (fineColor == 3 && coarseColor == 24) { - RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, - parity); - } else { - errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); - } - } else if (aggregate_size == 16) { - if constexpr (fineColor == 24 && coarseColor == 32) { - RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, - parity); - } else { - errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); - } - } else{ + if (aggregate_size == 16) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else if (aggregate_size == 128) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else if (aggregate_size == 512) { + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, + parity); + } else { errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); } } From bcdfb86e47540fefd0e58fbeaef1c7d530aa05bf Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 17 Sep 2024 14:11:41 -0700 Subject: [PATCH 18/79] Make `aggregate_size` a runtime variable. --- include/kernels/restrictor_mma.cuh | 35 ++++++++++---------- lib/restrictor_mma.in.cu | 52 +++++++++++------------------- 2 files changed, 37 insertions(+), 50 deletions(-) diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index 1d7f8979cc..35bcd27923 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -13,7 +13,7 @@ namespace quda Kernel argument struct */ template + int coarseColor_, int nVec_, int bN_, int bM_, int bK_, int block_y_, int block_z_> struct RestrictMmaArg : kernel_param<> { static constexpr int block_dim = block_z_ * block_y_; @@ -30,7 +30,7 @@ namespace quda static constexpr int coarseSpin = coarseSpin_; static constexpr int coarseColor = coarseColor_; static constexpr int nVec = nVec_; - static constexpr int aggregate_size = aggregate_size_; + // static constexpr int aggregate_size = aggregate_size_; static constexpr int bN = bN_; static constexpr int bM = bM_; static constexpr int bK = bK_; @@ -48,11 +48,11 @@ namespace quda static_assert(bK % (fineColor * spin_block_factor) == 0, "K %% Arg::bK != 0.\n"); static constexpr int aggregate_per_block = bK / (fineColor * spin_block_factor); - static_assert(aggregate_size % aggregate_per_block == 0, "aggregate_size %% aggregate_per_block"); out_accessor_t out; in_accessor_t in; const v_accessor_t v; + const int aggregate_size; const int_fastdiv aggregate_size_cb; // number of checkerboard sites that form a single aggregate const int *fine_to_coarse; const int *coarse_to_fine; @@ -66,6 +66,7 @@ namespace quda out(out), in(in), v(v), + aggregate_size(in.Volume() / out.Volume()), aggregate_size_cb(in.VolumeCB() / out.Volume()), fine_to_coarse(fine_to_coarse), coarse_to_fine(coarse_to_fine), @@ -97,13 +98,14 @@ namespace quda int thread_idx = thread; int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; constexpr bool check_contiguous_bound = !(contiguous_limit % contiguous_dim == 0); - if (!check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit) { + bool b = !check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit; thread_idx /= (contiguous_dim / elements_per_thread); int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin thread_idx /= Arg::spin_block_factor; int fine_color = thread_idx % Arg::fineColor; thread_idx /= Arg::fineColor; int x_fine_offset = thread_idx + 
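/*
   With aggregate_size demoted to a runtime member, one kernel instantiation
   covers every blocking. Two consequences visible below: the loads gain an
   x_fine_offset < arg.aggregate_size guard so threads mapped past the end of
   the current aggregate simply skip their element, and the coarse_to_fine
   lookup table moves out of a statically sized __shared__ array into the
   dynamic shared-memory allocation, which the launcher therefore grows to

     mma::shared_memory_bytes(bM, bN, bK) + aggregate_size * sizeof(int);
*/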
aggregate_k_offset; + if (x_fine_offset < arg.aggregate_size && b) { const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; @@ -164,13 +166,14 @@ namespace quda int thread_idx = thread; int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; constexpr bool check_contiguous_bound = !(contiguous_limit % contiguous_dim == 0); - if (!check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit) { + bool b = !check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit; thread_idx /= (contiguous_dim / elements_per_thread); int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin thread_idx /= Arg::spin_block_factor; int fine_color = thread_idx % Arg::fineColor; thread_idx /= Arg::fineColor; int x_fine_offset = thread_idx + aggregate_k_offset; + if (x_fine_offset < arg.aggregate_size && b) { const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; const int x_fine_cb_offset = x_fine_offset % arg.aggregate_size_cb; @@ -229,6 +232,7 @@ namespace quda } } } else { + static_assert(smem_obj_t::ldn % elements_per_thread == 0); smem_real.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_real)); smem_imag.vector_load(smem_m, smem_k, mma::make_vector_t::get(a_imag)); } @@ -246,7 +250,7 @@ namespace quda constexpr int M = Arg::nVec; constexpr int N = Arg::coarseColor; - constexpr int K = Arg::fineColor * Arg::spin_block_factor * Arg::aggregate_size; + constexpr int K = 0; // K is dummy here since it is a runtime variable; constexpr int ldc = M; @@ -257,14 +261,6 @@ namespace quda static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); - __shared__ int coarse_to_fine[Arg::aggregate_size]; - int index = target::thread_idx().y + Arg::block_y * target::thread_idx().z; - while (index < Arg::aggregate_size) { - coarse_to_fine[index] = arg.coarse_to_fine[x_coarse * 2 * arg.aggregate_size_cb + index]; - index += Arg::block_y * Arg::block_z; - } - __syncthreads(); - extern __shared__ typename mma_t::compute_t smem_ptr[]; typename Config::SmemObjA smem_obj_a_real(smem_ptr); @@ -272,8 +268,13 @@ namespace quda typename Config::SmemObjB smem_obj_b_real(smem_obj_a_imag.ptr + Config::smem_lda * Arg::bK); typename Config::SmemObjB smem_obj_b_imag(smem_obj_b_real.ptr + Config::smem_ldb * Arg::bK); - typename Config::ALoader a_loader; - typename Config::BLoader b_loader; + int *coarse_to_fine = reinterpret_cast(smem_obj_b_imag.ptr + Config::smem_ldb * Arg::bK); + int index = target::thread_idx().y + Arg::block_y * target::thread_idx().z; + while (index < arg.aggregate_size) { + coarse_to_fine[index] = arg.coarse_to_fine[x_coarse * 2 * arg.aggregate_size_cb + index]; + index += Arg::block_y * Arg::block_z; + } + __syncthreads(); typename Config::Accumulator accumulator((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x); @@ -281,7 +282,7 @@ namespace quda constexpr bool rescale = mma_t::do_rescale(); - for (int aggregate_k_offset = 0; aggregate_k_offset < Arg::aggregate_size; + for (int aggregate_k_offset = 0; aggregate_k_offset < arg.aggregate_size; aggregate_k_offset += Arg::aggregate_per_block) { __syncthreads(); diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index d80751db15..2f51a8d9e5 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -11,7 +11,7 @@ namespace quda { template + int nVec> class 
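/*
   In the launcher that follows, aggregate_size is no longer a template
   parameter; the tunable K extent is instead capped through
   aggregate_size_block_max = 16, i.e. at most 16 fine sites are folded into
   one bK tile, and larger aggregates are consumed by the aggregate_k_offset
   loop in the kernel one tile at a time. The cap is also appended to the
   tuning aux string, so a retune is triggered if it ever changes.
*/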
RestrictMmaLaunch : public TunableKernel { ColorSpinorField &out; @@ -20,6 +20,7 @@ namespace quda const int *fine_to_coarse; const int *coarse_to_fine; const int parity; + int aggregate_size; // using mma_t = typename mma::smma_dispatch::type; // using mma_t = simt::simt_t; @@ -31,10 +32,11 @@ namespace quda #endif static constexpr int spin_block_factor = spin_mapper::get_spin_block_factor(); + static constexpr int aggregate_size_block_max = 16; static constexpr int m = nVec; static constexpr int n = coarseColor; - static constexpr int k = fineColor * spin_block_factor * aggregate_size; + static constexpr int k = fineColor * spin_block_factor * aggregate_size_block_max; static constexpr int n_atom_size = mma_t::MMA_N; static constexpr int m_atom_size = mma_t::MMA_M; @@ -42,7 +44,7 @@ namespace quda static constexpr int block_atom_size = 32 / 8; static constexpr int block_limit = 32; - using this_t = RestrictMmaLaunch; + using this_t = RestrictMmaLaunch; expand_aux_t expand; bool checkParam(const TuneParam ¶m) const { return true; } @@ -77,6 +79,7 @@ namespace quda fine_to_coarse(fine_to_coarse), coarse_to_fine(coarse_to_fine), parity(parity), + aggregate_size(in.Volume() / out.Volume()), expand(*this) { strcat(vol, ","); @@ -86,6 +89,8 @@ namespace quda setRHSstring(aux, in.Nvec()); strcat(aux, mma_t::get_type_name().c_str()); + strcat(aux, ",aggregate_size_block_max="); + i32toa(aux + strlen(aux), aggregate_size_block_max); apply(device::get_default_stream()); } @@ -98,9 +103,9 @@ namespace quda return nVec * (in.Bytes() + out.Bytes() + v_bytes + in.SiteSubset() * in.VolumeCB() * sizeof(int)); } - static constexpr int shared_bytes_per_block(int bM, int bN, int bK) + int shared_bytes_per_block(int bM, int bN, int bK) const { - return mma::shared_memory_bytes(bM, bN, bK); + return mma::shared_memory_bytes(bM, bN, bK) + aggregate_size * sizeof(int); } bool set_mma_param(TuneParam &tp) const @@ -128,11 +133,11 @@ namespace quda template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { - constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); + constexpr int shared_bytes = mma::shared_memory_bytes(bM, bN, bK); if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { constexpr int block_z = 8; using Arg = RestrictMmaArg; + bN, bM, bK, block_y, block_z>; Arg arg(out, in, v, fine_to_coarse, coarse_to_fine, parity); tp.set_max_shared_bytes = true; launch_cuda(tp, stream, arg); @@ -153,7 +158,7 @@ namespace quda } }; - template + template void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) { @@ -168,38 +173,38 @@ namespace quda if (v.Precision() == QUDA_HALF_PRECISION) { if constexpr (is_enabled(QUDA_HALF_PRECISION)) { - RestrictMmaLaunch restrictor( + RestrictMmaLaunch restrictor( out, in, v, fine_to_coarse, coarse_to_fine, parity); } else { errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); } } else if (v.Precision() == in.Precision()) { - RestrictMmaLaunch restrictor( + RestrictMmaLaunch restrictor( out, in, v, fine_to_coarse, coarse_to_fine, parity); } else { errorQuda("Unsupported V precision %d", v.Precision()); } } - template + template void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) { if (!is_enabled_spin(in.Nspin())) errorQuda("nSpin %d has not been 
built", in.Nspin()); if (in.Nspin() == 2) { - RestrictMma(out, in, v, fine_to_coarse, + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); } else if constexpr (fineColor == 3) { if (in.Nspin() == 4) { if constexpr (is_enabled_spin(4)) { if (in.Precision() == out.Precision()) { - RestrictMma( + RestrictMma( out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); } else if (in.Precision() == QUDA_HALF_PRECISION) { #if 0 if constexpr (is_enabled(QUDA_HALF_PRECISION)) { - RestrictMma(out, in, v, + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); } else { @@ -246,25 +251,6 @@ namespace quda constexpr int nVec = @QUDA_MULTIGRID_MRHS@; // clang-format on - template - void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, - const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity) - { - int aggregate_size = in.Volume() / out.Volume(); - if (aggregate_size == 16) { - RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, - parity); - } else if (aggregate_size == 128) { - RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, - parity); - } else if (aggregate_size == 512) { - RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, - parity); - } else { - errorQuda("Unexpected aggregate_size = %d\n", aggregate_size); - } - } - template <> void RestrictMma(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v, const int *fine_to_coarse, From 189f78e2c528fda799eeb82a02cf3d24cf1e2de8 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 17 Sep 2024 14:14:46 -0700 Subject: [PATCH 19/79] Apply clang-format. --- include/kernels/prolongator_mma.cuh | 5 +++-- include/kernels/restrictor_mma.cuh | 28 +++++++++++++-------------- lib/prolongator_mma.in.cu | 20 +++++++------------ lib/restrictor_mma.in.cu | 30 +++++++++++------------------ 4 files changed, 35 insertions(+), 48 deletions(-) diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh index a3f884298f..9a22cbaebf 100644 --- a/include/kernels/prolongator_mma.cuh +++ b/include/kernels/prolongator_mma.cuh @@ -98,7 +98,7 @@ namespace quda using mma_t = typename Arg::mma_t; using Config = mma::MmaConfig; - static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); + // static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); extern __shared__ typename mma_t::compute_t smem_ptr[]; @@ -129,7 +129,8 @@ namespace quda a_loader.template r2s(smem_obj_a_real, smem_obj_a_imag); b_loader.template r2s(smem_obj_b_real, smem_obj_b_imag); __syncthreads(); - accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, a_rescale * b_rescale); + accumulator.mma_rescale(smem_obj_a_real, smem_obj_a_imag, smem_obj_b_real, smem_obj_b_imag, + a_rescale * b_rescale); } } else { for (int k_offset = 0; k_offset < K; k_offset += Arg::bK) { diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index 35bcd27923..1b1da7f1e6 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -12,8 +12,8 @@ namespace quda /** Kernel argument struct */ - template + template struct RestrictMmaArg : kernel_param<> { static constexpr int block_dim = block_z_ * block_y_; @@ -99,12 +99,12 @@ namespace quda int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; constexpr bool check_contiguous_bound = 
!(contiguous_limit % contiguous_dim == 0); bool b = !check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit; - thread_idx /= (contiguous_dim / elements_per_thread); - int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin - thread_idx /= Arg::spin_block_factor; - int fine_color = thread_idx % Arg::fineColor; - thread_idx /= Arg::fineColor; - int x_fine_offset = thread_idx + aggregate_k_offset; + thread_idx /= (contiguous_dim / elements_per_thread); + int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin + thread_idx /= Arg::spin_block_factor; + int fine_color = thread_idx % Arg::fineColor; + thread_idx /= Arg::fineColor; + int x_fine_offset = thread_idx + aggregate_k_offset; if (x_fine_offset < arg.aggregate_size && b) { const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 1 : 0; @@ -167,12 +167,12 @@ namespace quda int contiguous = thread_idx % (contiguous_dim / elements_per_thread) * elements_per_thread; constexpr bool check_contiguous_bound = !(contiguous_limit % contiguous_dim == 0); bool b = !check_contiguous_bound || contiguous + contiguous_dim_offset < contiguous_limit; - thread_idx /= (contiguous_dim / elements_per_thread); - int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin - thread_idx /= Arg::spin_block_factor; - int fine_color = thread_idx % Arg::fineColor; - thread_idx /= Arg::fineColor; - int x_fine_offset = thread_idx + aggregate_k_offset; + thread_idx /= (contiguous_dim / elements_per_thread); + int fine_spin_block = thread_idx % Arg::spin_block_factor; // fineSpin / coarseSpin + thread_idx /= Arg::spin_block_factor; + int fine_color = thread_idx % Arg::fineColor; + thread_idx /= Arg::fineColor; + int x_fine_offset = thread_idx + aggregate_k_offset; if (x_fine_offset < arg.aggregate_size && b) { const int parity_offset = x_fine_offset >= arg.aggregate_size_cb ? 
1 : 0; diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index 9bf770b9c8..d21c5c3995 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -48,10 +48,7 @@ namespace quda unsigned int sharedBytesPerThread() const { return 0; } - bool advanceTuneParam(TuneParam ¶m) const - { - return expand.advance_aux(param); - } + bool advanceTuneParam(TuneParam ¶m) const { return expand.advance_aux(param); } void initTuneParam(TuneParam ¶m) const { @@ -116,7 +113,8 @@ namespace quda int bN = expand.get_y(tp); int bM = expand.get_z(tp); - tp.grid = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin / spin_block_factor, (m + bM - 1) / bM, (n + bN - 1) / bN); + tp.grid + = dim3(out.SiteSubset() * out.VolumeCB() * fineSpin / spin_block_factor, (m + bM - 1) / bM, (n + bN - 1) / bN); tp.set_max_shared_bytes = true; int bK = expand.get_w(tp); @@ -126,8 +124,7 @@ namespace quda return shared_bytes <= device::maximum_dynamic_shared_memory(); } - template - void launch_mma(TuneParam &tp, const qudaStream_t &stream) + template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { constexpr int shared_bytes = shared_bytes_per_block(bM, bN, bK); if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { @@ -143,10 +140,7 @@ namespace quda } } - void launch_mma(TuneParam &tp, const qudaStream_t &stream) - { - expand.expand(tp, stream); - } + void launch_mma(TuneParam &tp, const qudaStream_t &stream) { expand.expand(tp, stream); } void apply(const qudaStream_t &stream) { @@ -206,11 +200,11 @@ namespace quda } } - // clang-format on + // clang-format off constexpr int fineColor = @QUDA_MULTIGRID_NC_NVEC@; constexpr int coarseColor = @QUDA_MULTIGRID_NVEC2@; constexpr int nVec = @QUDA_MULTIGRID_MRHS@; - // clang-format off + // clang-format on template <> void ProlongateMma(ColorSpinorField &out, const ColorSpinorField &in, diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index 2f51a8d9e5..6b54226408 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -10,8 +10,7 @@ namespace quda { - template + template class RestrictMmaLaunch : public TunableKernel { ColorSpinorField &out; @@ -51,10 +50,7 @@ namespace quda unsigned int sharedBytesPerThread() const { return 0; } - bool advanceTuneParam(TuneParam ¶m) const - { - return expand.advance_aux(param); - } + bool advanceTuneParam(TuneParam ¶m) const { return expand.advance_aux(param); } void initTuneParam(TuneParam ¶m) const { @@ -130,14 +126,13 @@ namespace quda return shared_bytes <= device::maximum_dynamic_shared_memory(); } - template - void launch_mma(TuneParam &tp, const qudaStream_t &stream) + template void launch_mma(TuneParam &tp, const qudaStream_t &stream) { constexpr int shared_bytes = mma::shared_memory_bytes(bM, bN, bK); if constexpr (shared_bytes <= device::maximum_dynamic_shared_memory()) { constexpr int block_z = 8; - using Arg = RestrictMmaArg; + using Arg = RestrictMmaArg; Arg arg(out, in, v, fine_to_coarse, coarse_to_fine, parity); tp.set_max_shared_bytes = true; launch_cuda(tp, stream, arg); @@ -146,10 +141,7 @@ namespace quda } } - void launch_mma(TuneParam &tp, const qudaStream_t &stream) - { - expand.expand(tp, stream); - } + void launch_mma(TuneParam &tp, const qudaStream_t &stream) { expand.expand(tp, stream); } void apply(const qudaStream_t &stream) { @@ -193,14 +185,14 @@ namespace quda if (!is_enabled_spin(in.Nspin())) errorQuda("nSpin %d has not been built", in.Nspin()); if (in.Nspin() == 2) { - RestrictMma(out, in, v, fine_to_coarse, - coarse_to_fine, 
spin_map, parity); + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, + spin_map, parity); } else if constexpr (fineColor == 3) { if (in.Nspin() == 4) { if constexpr (is_enabled_spin(4)) { if (in.Precision() == out.Precision()) { - RestrictMma( - out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity); + RestrictMma(out, in, v, fine_to_coarse, coarse_to_fine, + spin_map, parity); } else if (in.Precision() == QUDA_HALF_PRECISION) { #if 0 if constexpr (is_enabled(QUDA_HALF_PRECISION)) { @@ -209,7 +201,7 @@ namespace quda parity); } else { #endif - errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); + errorQuda("QUDA_PRECISION=%d does not enable half precision", QUDA_PRECISION); #if 0 } #endif From 86361602774933e269cda655b1b05ed8b9bb360c Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 17 Sep 2024 14:50:47 -0700 Subject: [PATCH 20/79] Soften the restriction for nrhs from multiple of 16 to multiple of 8. --- include/kernels/prolongator_mma.cuh | 1 - include/kernels/restrictor_mma.cuh | 2 -- lib/prolongator.in.cpp | 2 +- lib/prolongator_mma.in.cu | 1 - lib/restrictor.in.cpp | 2 +- lib/restrictor_mma.in.cu | 1 - 6 files changed, 2 insertions(+), 7 deletions(-) diff --git a/include/kernels/prolongator_mma.cuh b/include/kernels/prolongator_mma.cuh index 9a22cbaebf..a129e00b98 100644 --- a/include/kernels/prolongator_mma.cuh +++ b/include/kernels/prolongator_mma.cuh @@ -98,7 +98,6 @@ namespace quda using mma_t = typename Arg::mma_t; using Config = mma::MmaConfig; - // static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); extern __shared__ typename mma_t::compute_t smem_ptr[]; diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index 1b1da7f1e6..4b075b31ea 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -30,7 +30,6 @@ namespace quda static constexpr int coarseSpin = coarseSpin_; static constexpr int coarseColor = coarseColor_; static constexpr int nVec = nVec_; - // static constexpr int aggregate_size = aggregate_size_; static constexpr int bN = bN_; static constexpr int bM = bM_; static constexpr int bK = bK_; @@ -258,7 +257,6 @@ namespace quda // The first two ldc's are dummy using Config = mma::MmaConfig; - static_assert(M % Arg::bM == 0, "M %% Arg::bM != 0.\n"); static_assert(K % Arg::bK == 0, "K %% Arg::bK != 0.\n"); extern __shared__ typename mma_t::compute_t smem_ptr[]; diff --git a/lib/prolongator.in.cpp b/lib/prolongator.in.cpp index 75afd1002d..0ed9142400 100644 --- a/lib/prolongator.in.cpp +++ b/lib/prolongator.in.cpp @@ -122,7 +122,7 @@ namespace quda // clang-format off IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors; // clang-format on - if (in.size() % 16 == 0) { + if (in.size() % 8 == 0) { // use MMA Prolongate(out, in, v, fine_to_coarse, spin_map, parity, fineColors); } else { diff --git a/lib/prolongator_mma.in.cu b/lib/prolongator_mma.in.cu index d21c5c3995..925f73fbd5 100644 --- a/lib/prolongator_mma.in.cu +++ b/lib/prolongator_mma.in.cu @@ -103,7 +103,6 @@ namespace quda bool set_mma_param(TuneParam &tp) const { - static_assert(m % m_atom_size == 0, "m modulo m_atom_size == 0"); static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); tp.block.x = 1; diff --git a/lib/restrictor.in.cpp b/lib/restrictor.in.cpp index 9c3315add2..4a053a4fac 100644 --- a/lib/restrictor.in.cpp +++ b/lib/restrictor.in.cpp @@ -122,7 +122,7 @@ namespace quda IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors; // 
clang-format on - if (in.size() % 16 == 0) { + if (in.size() % 8 == 0) { Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors); } else { Restrict(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors); diff --git a/lib/restrictor_mma.in.cu b/lib/restrictor_mma.in.cu index 6b54226408..57853d570d 100644 --- a/lib/restrictor_mma.in.cu +++ b/lib/restrictor_mma.in.cu @@ -106,7 +106,6 @@ namespace quda bool set_mma_param(TuneParam &tp) const { - static_assert(m % m_atom_size == 0, "m modulo m_atom_size == 0"); static_assert(k % k_atom_size == 0, "k modulo k_atom_size == 0"); tp.block.x = 1; From f3a42bd14e430a96b93d5a6921cc666650bcc28a Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 18 Sep 2024 09:48:56 -0700 Subject: [PATCH 21/79] Set the default precision in coarse dslash mma to TF32/FP16; Fix the underlying code such that FP16 works with rescaling. --- .../cuda/mma_tensor_op/gmem_loader.cuh | 93 +++++++++++++++---- lib/dslash_coarse_mma.in.hpp | 6 +- 2 files changed, 76 insertions(+), 23 deletions(-) diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index f330d66cb2..3e034f7a64 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -321,42 +321,92 @@ namespace quda } /** - @brief Load from global memory and store data in registers while also applying a rescaling + @brief Load from global memory and store data in registers. */ - template - inline __device__ void convert_x_rescale(float ®_real, float ®_imag, complex *p, int m_idx, int n_idx, + template + inline __device__ void convert_x_rescale(float reg_real[batch], float reg_imag[batch], complex *p, int m_idx, int n_idx, float scale_inv, float rescale) { - if (x) { + complex v[batch]; + scale_inv *= rescale; + if constexpr (x) { + batch_load_t, batch>::load(v, &p[m_idx * ld + n_idx]); +#pragma unroll + for (int b = 0; b < batch; b++) { + if constexpr (fixed) { + reg_real[b] = scale_inv * v[b].real(); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b] = scale_inv_conj * v[b].imag(); + } else { + reg_real[b] = v[b].real() * rescale; + reg_imag[b] = (dagger ? -v[b].imag() : v[b].imag()) * rescale; + } + } + } else { + complex v[batch]; + batch_load_t, batch>::load(v, &p[n_idx * ld + m_idx]); +#pragma unroll + for (int b = 0; b < batch; b++) { + if constexpr (fixed) { + reg_real[b] = scale_inv * v[b].real(); + auto scale_inv_conj = dagger ? -scale_inv : scale_inv; + reg_imag[b] = scale_inv_conj * v[b].imag(); + } else { + reg_real[b] = v[b].real() * rescale; + reg_imag[b] = (dagger ? -v[b].imag() : v[b].imag()) * rescale; + } + } + } + } + + /** + @brief Load from global memory and store data in registers. + */ + template + inline __device__ float find_abs_max(half2, complex *p, int m_idx, int n_idx, float scale_inv) + { + float this_max = 0.0f; + + if constexpr (x) { auto xx = p[m_idx * ld + n_idx]; + auto yy = p[(m_idx + 1) * ld + n_idx]; if constexpr (fixed) { - reg_real = scale_inv * xx.real() * rescale; - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag = scale_inv_conj * xx.imag() * rescale; + this_max = abs_max(scale_inv * xx.real(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); + this_max = abs_max(scale_inv * yy.real(), this_max); + this_max = abs_max(scale_inv * yy.imag(), this_max); } else { - reg_real = +xx.real() * rescale; - reg_imag = (dagger ? 
-xx.imag() : +xx.imag()) * rescale; + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(xx.imag(), this_max); + this_max = abs_max(yy.real(), this_max); + this_max = abs_max(yy.imag(), this_max); } } else { auto xx = p[n_idx * ld + m_idx]; + auto yy = p[n_idx * ld + m_idx + 1]; if constexpr (fixed) { - reg_real = scale_inv * xx.real() * rescale; - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag = scale_inv_conj * xx.imag() * rescale; + this_max = abs_max(scale_inv * xx.real(), this_max); + this_max = abs_max(scale_inv * xx.imag(), this_max); + this_max = abs_max(scale_inv * yy.real(), this_max); + this_max = abs_max(scale_inv * yy.imag(), this_max); } else { - reg_real = xx.real() * rescale; - reg_imag = (dagger ? -xx.imag() : xx.imag()) * rescale; + this_max = abs_max(xx.real(), this_max); + this_max = abs_max(xx.imag(), this_max); + this_max = abs_max(yy.real(), this_max); + this_max = abs_max(yy.imag(), this_max); } } + + return this_max; } /** @brief Load from global memory and store data in registers. */ template - inline __device__ float find_abs_max(complex *p, int m_idx, int n_idx, float scale_inv) + inline __device__ float find_abs_max(float, complex *p, int m_idx, int n_idx, float scale_inv) { float this_max = 0.0f; @@ -458,8 +508,6 @@ namespace quda __device__ inline float tmp2s_rescale(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, smem_accessor_t &smem_imag) { - static_assert(batch == 1, "For now batch needs to be 1 for the rescale kernel."); - // for each iteration, each warp loads a tile int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; int warp_id = thread_id / 32; @@ -495,7 +543,7 @@ namespace quda constexpr bool x = (transpose == dagger); float this_max - = find_abs_max(smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); + = find_abs_max(load_t{}, smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); thread_max = fmaxf(this_max, thread_max); } } @@ -534,8 +582,13 @@ namespace quda load_t imag; constexpr bool x = (transpose == dagger); - convert_x_rescale(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, - scale_inv, block_rescale_factor); + // if constexpr (std::is_same_v) { + // convert_x_rescale(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, + // scale_inv, block_rescale_factor); + // } else { + convert_x_rescale(&real, &imag, smem_ptr, gmem_m_offset, gmem_k_offset, + scale_inv, block_rescale_factor); + // } smem_real.vector_load(smem_m_offset, smem_k_offset, real); smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); } diff --git a/lib/dslash_coarse_mma.in.hpp b/lib/dslash_coarse_mma.in.hpp index b0f8a43545..6fa3b58a2f 100644 --- a/lib/dslash_coarse_mma.in.hpp +++ b/lib/dslash_coarse_mma.in.hpp @@ -42,14 +42,14 @@ namespace quda // using mma_t = smma::smma_t; // 3xBF16 // using mma_t = smma::smma_t; // 3xTF32 - // using mma_t = smma::smma_x_t; // 1xFP16 - m16n8k8 variant for sm70 + // using mma_t = smma::smma_x_t; // 3xFP16 - m16n8k8 variant for sm70 // using mma_t = hmma::hmma_tfloat32_t<4, 1, 1>; // 1xTF32 // using mma_t = mma::smma_half_t; // 3xFP16 // using mma_t = mma::hmma_t; // 1xFP16 #if (__COMPUTE_CAPABILITY__ >= 800) - using mma_t = typename mma::smma_dispatch::type; + using mma_t = hmma::hmma_tfloat32_t<4, 1, 1>; #else - using mma_t = typename simt::simt_t; + using mma_t = mma::hmma_t; #endif static constexpr int n_atom_size = mma_t::MMA_N; static constexpr int m_atom_size = mma_t::MMA_M; From 82a61c67c3f838ded5bfef0bf96a08b6ea6d1795 Mon Sep 17 00:00:00 2001 From: 
Jiqun Tu Date: Wed, 18 Sep 2024 10:13:43 -0700 Subject: [PATCH 22/79] Clean up the MMA code. --- .../cuda/mma_tensor_op/gmem_loader.cuh | 99 ++++--------------- 1 file changed, 19 insertions(+), 80 deletions(-) diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index 3e034f7a64..f1851b175d 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -141,7 +141,7 @@ namespace quda }; /** - @brief Load from global memory and store data in registers. + @brief Load from global memory and store data in registers: specialized for float2. */ template inline __device__ void convert_x(float2 reg_real[batch], float2 reg_imag[batch], complex *p, int m_idx, @@ -191,51 +191,7 @@ namespace quda } /** - @brief Load from global memory and store data in registers. - */ - template - inline __device__ void convert_x(half2 reg_real[batch], half2 reg_imag[batch], complex *p, int m_idx, int n_idx, - float scale_inv) - { - if constexpr (x) { - complex vx[batch]; - complex vy[batch]; - batch_load_t, batch>::load(vx, &p[(m_idx + 0) * ld + n_idx]); - batch_load_t, batch>::load(vy, &p[(m_idx + 1) * ld + n_idx]); - -#pragma unroll - for (int b = 0; b < batch; b++) { - if constexpr (fixed) { - reg_real[b] = __floats2half2_rn(scale_inv * vx[b].real(), scale_inv * vy[b].real()); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag[b] = __floats2half2_rn(scale_inv_conj * vx[b].imag(), scale_inv_conj * vy[b].imag()); - } else { - reg_real[b] = __floats2half2_rn(+vx[b].real(), +vy[b].real()); - reg_imag[b] - = __floats2half2_rn(dagger ? -vx[b].imag() : +vx[b].imag(), dagger ? -vy[b].imag() : +vy[b].imag()); - } - } - } else { - complex v[batch * 2]; - batch_load_t, batch * 2>::load(v, &p[n_idx * ld + m_idx]); - -#pragma unroll - for (int b = 0; b < batch; b++) { - if constexpr (fixed) { - reg_real[b] = __floats2half2_rn(scale_inv * v[b * 2].real(), scale_inv * v[b * 2 + 1].real()); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag[b] = __floats2half2_rn(scale_inv_conj * v[b * 2].imag(), scale_inv_conj * v[b * 2 + 1].imag()); - } else { - reg_real[b] = __floats2half2_rn(+v[b * 2].real(), +v[b * 2 + 1].real()); - reg_imag[b] = __floats2half2_rn(dagger ? -v[b * 2].imag() : +v[b * 2].imag(), - dagger ? -v[b * 2 + 1].imag() : +v[b * 2 + 1].imag()); - } - } - } - } - - /** - @brief Load from global memory and store data in registers. + @brief Load from global memory and store data in registers: specialized for half2. */ template inline __device__ void convert_x_rescale(half2 reg_real[batch], half2 reg_imag[batch], complex *p, int m_idx, @@ -284,44 +240,15 @@ namespace quda @brief Load from global memory and store data in registers. */ template - inline __device__ void convert_x(float reg_real[batch], float reg_imag[batch], complex *p, int m_idx, int n_idx, + inline __device__ void convert_x(half2 reg_real[batch], half2 reg_imag[batch], complex *p, int m_idx, int n_idx, float scale_inv) { - complex v[batch]; - if constexpr (x) { - batch_load_t, batch>::load(v, &p[m_idx * ld + n_idx]); -#pragma unroll - for (int b = 0; b < batch; b++) { - // auto xx = p[m_idx * ld + n_idx]; - if constexpr (fixed) { - reg_real[b] = scale_inv * v[b].real(); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag[b] = scale_inv_conj * v[b].imag(); - } else { - reg_real[b] = v[b].real(); - reg_imag[b] = dagger ? 
-v[b].imag() : v[b].imag(); - } - } - } else { - complex v[batch]; - batch_load_t, batch>::load(v, &p[n_idx * ld + m_idx]); -#pragma unroll - for (int b = 0; b < batch; b++) { - // auto xx = p[n_idx * ld + m_idx]; - if constexpr (fixed) { - reg_real[b] = scale_inv * v[b].real(); - auto scale_inv_conj = dagger ? -scale_inv : scale_inv; - reg_imag[b] = scale_inv_conj * v[b].imag(); - } else { - reg_real[b] = v[b].real(); - reg_imag[b] = dagger ? -v[b].imag() : v[b].imag(); - } - } - } + constexpr float rescale = 1.0f; + convert_x_rescale(reg_real, reg_imag, p, m_idx, n_idx, scale_inv, rescale); } /** - @brief Load from global memory and store data in registers. + @brief Load from global memory and store data in registers: specialized for float. */ template inline __device__ void convert_x_rescale(float reg_real[batch], float reg_imag[batch], complex *p, int m_idx, int n_idx, @@ -362,6 +289,18 @@ namespace quda /** @brief Load from global memory and store data in registers. */ + template + inline __device__ void convert_x(float reg_real[batch], float reg_imag[batch], complex *p, int m_idx, int n_idx, + float scale_inv) + { + constexpr float rescale = 1.0f; + convert_x_rescale(reg_real, reg_imag, p, m_idx, n_idx, scale_inv, rescale); + } + + + /** + @brief Load from global memory and find the absolute maximum value: specialized for half2. + */ template inline __device__ float find_abs_max(half2, complex *p, int m_idx, int n_idx, float scale_inv) { @@ -403,7 +342,7 @@ namespace quda } /** - @brief Load from global memory and store data in registers. + @brief Load from global memory and find the absolute maximum value: specialized for float. */ template inline __device__ float find_abs_max(float, complex *p, int m_idx, int n_idx, float scale_inv) From 651517749a77569e48a98ef58c6f2f234e69e123 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 18 Sep 2024 11:33:08 -0700 Subject: [PATCH 23/79] Short cut the rescaling code to use scale_inv for fixed point format fields. 
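For context on the shortcut below: a fixed-point field already carries an upper bound on every dequantized value, namely scale_inv * fixedMaxValue<store_t>::value, so the FP16-safe rescale factor can be derived from that metadata alone and the block-wide abs-max reduction can be skipped. A minimal sketch of the idea, with rescale_factor_fixed as a placeholder name and the bound passed in explicitly; fixedMaxValue, scale_inv and the 65504.0f FP16 limit are the quantities used in the patch itself:

// Sketch only, not part of the patch: derive the FP16 rescale factor for a
// fixed-point field directly from its scale_inv and the a-priori magnitude
// bound, instead of scanning the data for its absolute maximum.
inline float rescale_factor_fixed(float scale_inv, float fixed_max_value)
{
  constexpr float fp16_max = 65504.0f;               // largest finite FP16 value
  float bound = scale_inv * fixed_max_value;         // upper bound on |real| and |imag|
  return scale_inv > 0.0f ? fp16_max / bound : 1.0f; // guard against an unset scale
}

The data-dependent reduction is still needed for floating-point storage, which is why the patch keeps both cases as branches of if constexpr (fixed).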
--- include/kernels/restrictor_mma.cuh | 12 +- .../cuda/mma_tensor_op/gmem_loader.cuh | 125 +++++++++--------- 2 files changed, 69 insertions(+), 68 deletions(-) diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index 4b075b31ea..e28b1ef506 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -119,17 +119,13 @@ namespace quda int fine_spin = fine_spin_block + coarse_spin * Arg::spin_block_factor; auto a_gmem = gmem(v_parity, x_fine_cb, fine_spin, fine_color, contiguous + contiguous_dim_offset); - complex a[elements_per_thread]; - mma::batch_load_t, elements_per_thread>::load(a, a_gmem.data()); + using store_t = typename gmem_obj_t::store_type; if constexpr (decltype(a_gmem)::fixed) { - auto scale_inv = a_gmem.get_scale_inv(); -#pragma unroll - for (int e = 0; e < elements_per_thread; e++) { - thread_max = mma::abs_max(a[e].real() * scale_inv, thread_max); - thread_max = mma::abs_max(a[e].imag() * scale_inv, thread_max); - } + thread_max = fmaxf(fixedMaxValue::value * a_gmem.get_scale_inv(), thread_max); } else { + complex a[elements_per_thread]; + mma::batch_load_t, elements_per_thread>::load(a, a_gmem.data()); #pragma unroll for (int e = 0; e < elements_per_thread; e++) { thread_max = mma::abs_max(a[e].real(), thread_max); diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index f1851b175d..7c14a89171 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -443,7 +443,7 @@ namespace quda return gmem.get_scale_inv(); } - template + template __device__ inline float tmp2s_rescale(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, smem_accessor_t &smem_imag) { @@ -465,44 +465,50 @@ namespace quda constexpr int n_warp = block_y * block_z / 32; constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; - float thread_max = 0.0f; - + float block_rescale_factor = 1.0f; + if constexpr (rescale) { + if constexpr (fixed) { + block_rescale_factor = scale_inv > 0 ? 
65504.0f / (scale_inv * fixedMaxValue::value) : 1.0f; + } else { + float thread_max = 0.0f; #pragma unroll - for (int c = 0; c < warp_cycle; c++) { - int logical_warp_index = c * n_warp + warp_id; - if (logical_warp_index < total_tiles) { - int warp_m = (c * n_warp + warp_id) % tile_dim_m; - int warp_k = (c * n_warp + warp_id) / tile_dim_m; - - int smem_m_offset = warp_m * w_m + group_id * batch; - int smem_k_offset = warp_k * w_k + thread_in_group; - - int gmem_m_offset = smem_m_offset; - int gmem_k_offset = smem_k_offset; + for (int c = 0; c < warp_cycle; c++) { + int logical_warp_index = c * n_warp + warp_id; + if (logical_warp_index < total_tiles) { + int warp_m = (c * n_warp + warp_id) % tile_dim_m; + int warp_k = (c * n_warp + warp_id) / tile_dim_m; + + int smem_m_offset = warp_m * w_m + group_id * batch; + int smem_k_offset = warp_k * w_k + thread_in_group; + + int gmem_m_offset = smem_m_offset; + int gmem_k_offset = smem_k_offset; + + constexpr bool x = (transpose == dagger); + float this_max + = find_abs_max(load_t{}, smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); + thread_max = fmaxf(this_max, thread_max); + } + } - constexpr bool x = (transpose == dagger); - float this_max - = find_abs_max(load_t{}, smem_ptr, gmem_m_offset, gmem_k_offset, scale_inv); - thread_max = fmaxf(this_max, thread_max); - } - } + // block all-reduce thread_max + using block_reduce_t = cub::BlockReduce; + __shared__ typename block_reduce_t::TempStorage temp_storage; + float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); - // block all-reduce thread_max - using block_reduce_t = cub::BlockReduce; - __shared__ typename block_reduce_t::TempStorage temp_storage; - float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); + __shared__ float block_max_all; + if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { + if (block_max > 0.0f) { + block_max_all = block_max; + } else { + block_max_all = 1.0f; + } + } + __syncthreads(); - __shared__ float block_max_all; - if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { - if (block_max > 0.0f) { - block_max_all = block_max; - } else { - block_max_all = 1.0f; + block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number } } - __syncthreads(); - - float block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number #pragma unroll for (int c = 0; c < warp_cycle; c++) { @@ -521,13 +527,8 @@ namespace quda load_t imag; constexpr bool x = (transpose == dagger); - // if constexpr (std::is_same_v) { - // convert_x_rescale(real, imag, smem_ptr, gmem_m_offset, gmem_k_offset, - // scale_inv, block_rescale_factor); - // } else { - convert_x_rescale(&real, &imag, smem_ptr, gmem_m_offset, gmem_k_offset, - scale_inv, block_rescale_factor); - // } + convert_x_rescale(&real, &imag, smem_ptr, gmem_m_offset, gmem_k_offset, + scale_inv, block_rescale_factor); smem_real.vector_load(smem_m_offset, smem_k_offset, real); smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); } @@ -679,32 +680,36 @@ namespace quda float block_rescale_factor = 1.0f; if constexpr (rescale) { - float thread_max = 0; + if constexpr (fixed) { + block_rescale_factor = scale_inv > 0 ? 
65504.0f / (scale_inv * fixedMaxValue::value) : 1.0f; + } else { + float thread_max = 0; #pragma unroll - for (int n = 0; n < n_dim; n++) { + for (int n = 0; n < n_dim; n++) { #pragma unroll - for (int m = 0; m < m_dim; m++) { - thread_max = abs_max(f_real[m * n_dim + n], thread_max); - thread_max = abs_max(f_imag[m * n_dim + n], thread_max); + for (int m = 0; m < m_dim; m++) { + thread_max = abs_max(f_real[m * n_dim + n], thread_max); + thread_max = abs_max(f_imag[m * n_dim + n], thread_max); + } } - } - // block all-reduce thread_max - using block_reduce_t = cub::BlockReduce; - __shared__ typename block_reduce_t::TempStorage temp_storage; - float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); - - __shared__ float block_max_all; - if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { - if (block_max > 0.0f) { - block_max_all = block_max; - } else { - block_max_all = 1.0f; + // block all-reduce thread_max + using block_reduce_t = cub::BlockReduce; + __shared__ typename block_reduce_t::TempStorage temp_storage; + float block_max = block_reduce_t(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float block_max_all; + if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { + if (block_max > 0.0f) { + block_max_all = block_max; + } else { + block_max_all = 1.0f; + } } - } - __syncthreads(); + __syncthreads(); - block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number + block_rescale_factor = 65504.0f / block_max_all; // 65504 = the maximum FP16 number + } } if constexpr (std::is_same_v) { From a9f6d7ca7bd6f11b77a126e5739c32976393bdc1 Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Wed, 18 Sep 2024 14:27:35 -0700 Subject: [PATCH 24/79] Clean up code. 
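The cleanup below folds the plain staging path (tmp2s) into the rescaling one (tmp2s_rescale) behind a compile-time flag, so a single copy of the tile loop remains and the abs-max scan is compiled out when no rescaling is requested. A schematic of the pattern under simplified types, with stage_tile as a stand-in name for illustration only:

#include <cmath>

// Schematic only: one implementation parameterized on a constexpr flag, plus a
// thin wrapper for the non-rescaling path. With if constexpr the max scan and
// the 65504.0f / max computation disappear entirely when rescale == false.
template <bool rescale>
inline float stage_tile(const float *src, float *dst, int n)
{
  float factor = 1.0f;
  if constexpr (rescale) {
    float max_abs = 0.0f;
    for (int i = 0; i < n; i++) max_abs = std::fmax(std::fabs(src[i]), max_abs);
    factor = max_abs > 0.0f ? 65504.0f / max_abs : 1.0f; // map data into FP16 range
  }
  for (int i = 0; i < n; i++) dst[i] = src[i] * factor;  // staging step, scaled
  return factor;
}

inline void stage_tile_plain(const float *src, float *dst, int n)
{
  stage_tile<false>(src, dst, n); // rescale path compiled out; factor stays 1.0f
}

Keeping one body prevents the two staging paths from drifting apart, which is the point of this cleanup.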
--- include/kernels/dslash_coarse_mma.cuh | 20 ++++----- .../cuda/mma_tensor_op/gmem_loader.cuh | 42 +------------------ 2 files changed, 12 insertions(+), 50 deletions(-) diff --git a/include/kernels/dslash_coarse_mma.cuh b/include/kernels/dslash_coarse_mma.cuh index 3c36631842..f834fa16f9 100644 --- a/include/kernels/dslash_coarse_mma.cuh +++ b/include/kernels/dslash_coarse_mma.cuh @@ -217,9 +217,9 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale( + float rescale_factor_a = a_loader.template tmp2s_rescale( smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale( + float rescale_factor_b = b_loader.template tmp2s_rescale( smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); @@ -237,9 +237,9 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale( + float rescale_factor_a = a_loader.template tmp2s_rescale( smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale( + float rescale_factor_b = b_loader.template tmp2s_rescale( smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); @@ -307,9 +307,9 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale( + float rescale_factor_a = a_loader.template tmp2s_rescale( smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale( + float rescale_factor_b = b_loader.template tmp2s_rescale( smem_tmp_b_ghost, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); @@ -326,9 +326,9 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale( + float rescale_factor_a = a_loader.template tmp2s_rescale( smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale( + float rescale_factor_b = b_loader.template tmp2s_rescale( smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); @@ -369,9 +369,9 @@ namespace quda pipe.consumer_wait(); __syncthreads(); - float rescale_factor_a = a_loader.template tmp2s_rescale( + float rescale_factor_a = a_loader.template tmp2s_rescale( smem_tmp_a, scale_inv_a, smem_obj_a_real, smem_obj_a_imag); - float rescale_factor_b = b_loader.template tmp2s_rescale( + float rescale_factor_b = b_loader.template tmp2s_rescale( smem_tmp_b, scale_inv_b, smem_obj_b_real, smem_obj_b_imag); pipe.consumer_release(); __syncthreads(); diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh index 7c14a89171..0b4f6d5dfe 100644 --- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh +++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh @@ -541,46 +541,8 @@ namespace quda __device__ inline void tmp2s(complex *smem_ptr, float scale_inv, smem_accessor_t &smem_real, smem_accessor_t &smem_imag) { - // for each iteration, each warp loads a tile - int thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x; - int warp_id = thread_id / 32; - int lane_id = thread_id % 32; - int thread_in_group = lane_id % 4; - int group_id = lane_id / 4; - constexpr int w_m = 8 * batch; - constexpr int w_k = 4; - static_assert(bM % w_m 
== 0, "bM %% w_m"); - static_assert(bN % w_k == 0, "bN %% w_k"); - - constexpr int tile_dim_m = bM / w_m; - constexpr int tile_dim_k = bN / w_k; - - constexpr int total_tiles = tile_dim_k * tile_dim_m; - constexpr int n_warp = block_y * block_z / 32; - constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp; -#pragma unroll - for (int c = 0; c < warp_cycle; c++) { - int logical_warp_index = c * n_warp + warp_id; - if (logical_warp_index < total_tiles) { - int warp_m = (c * n_warp + warp_id) % tile_dim_m; - int warp_k = (c * n_warp + warp_id) / tile_dim_m; - - int smem_m_offset = warp_m * w_m + group_id * batch; - int smem_k_offset = warp_k * w_k + thread_in_group; - - int gmem_m_offset = smem_m_offset; - int gmem_k_offset = smem_k_offset; - - load_t real; - load_t imag; - - constexpr bool x = (transpose == dagger); - convert_x(&real, &imag, smem_ptr, gmem_m_offset, gmem_k_offset, - scale_inv); - smem_real.vector_load(smem_m_offset, smem_k_offset, real); - smem_imag.vector_load(smem_m_offset, smem_k_offset, imag); - } - } + constexpr bool rescale = false; + tmp2s_rescale(smem_ptr, scale_inv, smem_real, smem_imag); } /** From 16deef203236f458234d68cd24400045ec3e729f Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Thu, 3 Oct 2024 11:21:20 -0700 Subject: [PATCH 25/79] Add const and constexpr; add nrhs to prolongator tuning string. --- include/kernels/coarse_op_kernel_mma.cuh | 12 ++++++------ include/kernels/coarse_op_preconditioned_mma.cuh | 6 +++--- include/targets/cuda/kernel.h | 2 +- include/targets/cuda/mma_tensor_op/gemm.cuh | 8 ++++---- lib/prolongator_mma.in.cu | 1 + 5 files changed, 15 insertions(+), 14 deletions(-) diff --git a/include/kernels/coarse_op_kernel_mma.cuh b/include/kernels/coarse_op_kernel_mma.cuh index 3aca52cdd5..151db16768 100644 --- a/include/kernels/coarse_op_kernel_mma.cuh +++ b/include/kernels/coarse_op_kernel_mma.cuh @@ -47,7 +47,7 @@ namespace quda Where: mu = dir, s = fine spin, c' = coarse color, c = fine color */ template - __device__ __host__ inline auto computeUV(Arg &arg, const Wtype &Wacc, int parity, int x_cb, int m_offset, + __device__ __host__ inline auto computeUV(const Arg &arg, const Wtype &Wacc, int parity, int x_cb, int m_offset, int n_offset) { using real = typename Arg::Float; @@ -126,8 +126,8 @@ namespace quda } // namespace impl template struct ComputeUVMMA { - Arg &arg; - constexpr ComputeUVMMA(Arg &arg) : arg(arg) {} + const Arg &arg; + constexpr ComputeUVMMA(const Arg &arg) : arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __forceinline__ void operator()() @@ -170,7 +170,7 @@ namespace quda namespace impl { - template __device__ void computeVUV(Arg &arg, int parity, int x_cb, int m_offset, int n_offset) + template __device__ void computeVUV(const Arg &arg, int parity, int x_cb, int m_offset, int n_offset) { constexpr int fineSpin = Arg::fineSpin; constexpr int coarseSpin = Arg::coarseSpin; @@ -331,8 +331,8 @@ namespace quda } // namespace impl template struct ComputeVUVMMA { - Arg &arg; - constexpr ComputeVUVMMA(Arg &arg) : arg(arg) {} + const Arg &arg; + constexpr ComputeVUVMMA(const Arg &arg) : arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __forceinline__ void operator()() diff --git a/include/kernels/coarse_op_preconditioned_mma.cuh b/include/kernels/coarse_op_preconditioned_mma.cuh index 98753718ec..624d337cce 100644 --- a/include/kernels/coarse_op_preconditioned_mma.cuh +++ b/include/kernels/coarse_op_preconditioned_mma.cuh @@ -29,7 +29,7 @@ namespace 
quda }; template - inline __device__ auto computeYhatMMA(Arg &arg, int d, int x_cb, int parity, int m, int n) + inline __device__ auto computeYhatMMA(const Arg &arg, int d, int x_cb, int parity, int m, int n) { using real = typename Arg::Float; constexpr int nDim = 4; @@ -84,8 +84,8 @@ namespace quda } template struct CalculateYhatMMA { - Arg &arg; - constexpr CalculateYhatMMA(Arg &arg) : arg(arg) {} + const Arg &arg; + constexpr CalculateYhatMMA(const Arg &arg) : arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } __device__ __forceinline__ void operator()() diff --git a/include/targets/cuda/kernel.h b/include/targets/cuda/kernel.h index b85e6bcc72..d9903caa58 100644 --- a/include/targets/cuda/kernel.h +++ b/include/targets/cuda/kernel.h @@ -224,7 +224,7 @@ namespace quda @param[in] arg Kernel argument */ template