
MMA-izing the prolongator and restrictor kernels #1497

Merged: 88 commits merged on Jan 22, 2025
Changes from 81 commits
88a108e
Add MMA version of prolongator.
hummingtree Aug 27, 2024
c0cb6e0
Dagger'ed the equation to make better use of MMA.
hummingtree Aug 27, 2024
dd55d92
Make nColor = 3 work:
hummingtree Aug 28, 2024
57b6f95
Add from_to_non_rel to block transpose to complete the circle.
hummingtree Aug 29, 2024
efa9c15
More cleanup of the MMA code. Apply vector gmem loads when possible.
hummingtree Aug 30, 2024
82d532b
Apply more vector gmem loads when possible.
hummingtree Aug 30, 2024
e6a5d34
Add MMA version for restrictor.
hummingtree Sep 4, 2024
f0f88d5
Add expands for restrictor with MMA.
hummingtree Sep 4, 2024
92b1c39
Add shared memory caching for the restrictor kernel.
hummingtree Sep 4, 2024
e7b9361
Allow restrictor_mma to have N % bN != 0; More generic optimizations …
hummingtree Sep 6, 2024
6bb02f6
Add rescaling to prolongator.
hummingtree Sep 10, 2024
2d8867e
Modify the MMA types.
hummingtree Sep 10, 2024
db6c030
Add rescale for restrictor.
hummingtree Sep 11, 2024
ce257ef
Abstract the MMA expansions into a class.
hummingtree Sep 16, 2024
f38882a
Add more abstraction; make TF32 the default for SM80 and later.
hummingtree Sep 16, 2024
69e15d6
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
maddyscientist Sep 17, 2024
e03ca23
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
hummingtree Sep 17, 2024
0a2bd2a
Fix block transpose by having no bound checks on the kernel level; in…
hummingtree Sep 17, 2024
e68dfa9
Merge branch 'feature/prolongator-mma' of github.com:lattice/quda int…
maddyscientist Sep 17, 2024
de05ebf
Add some aggregate sizes to MMA restrictor
maddyscientist Sep 17, 2024
bcdfb86
Make `aggregate_size` a runtime variable.
hummingtree Sep 17, 2024
189f78e
Apply clang-format.
hummingtree Sep 17, 2024
8636160
Soften the restriction for nrhs from multiple of 16 to multiple of 8.
hummingtree Sep 17, 2024
f3a42bd
Set the default precision in coarse dslash mma to TF32/FP16; Fix the …
hummingtree Sep 18, 2024
82a61c6
Clean up the MMA code.
hummingtree Sep 18, 2024
6515177
Short cut the rescaling code to use scale_inv for fixed point format …
hummingtree Sep 18, 2024
a9f6d7c
Clean up code.
hummingtree Sep 18, 2024
29593b1
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
hummingtree Sep 18, 2024
1bc4df7
Merge branch 'feature/mrhs-solvers' of github.com:lattice/quda into f…
hummingtree Sep 30, 2024
16deef2
Add const and constexpr; add nrhs to prolongator tuning string.
hummingtree Oct 3, 2024
e1d8963
Add initial support for TMA in coarse dslash; optimize the shared mem…
hummingtree Oct 15, 2024
caea8f3
Add more shared memory patterns.
hummingtree Oct 24, 2024
40e3cc3
Apply TMA to the clover term; clean up code and remove duplications.
hummingtree Oct 31, 2024
9e59692
Relax the requirements in multigrid_benchmark_test.
hummingtree Oct 31, 2024
0f6f71a
Merge remote-tracking branch 'origin/develop' into feature/prolongato…
hummingtree Oct 31, 2024
c3a155d
Put tma_wait under the TMA macro.
hummingtree Oct 31, 2024
88bedc9
Add const to `include/kernels/dslash_mdw_fused.cuh`; Put `lib/prolong…
hummingtree Oct 31, 2024
013aa50
Add more const to Arg &arg; Fix compile when MMA is not available.
hummingtree Oct 31, 2024
48af0e2
Fix CI:
hummingtree Nov 1, 2024
123aa04
Remove the unwanted warning macro.
hummingtree Nov 1, 2024
c6e833f
Address the unused parameter warnings from CI.
hummingtree Nov 1, 2024
bb83b5f
Add support for staggered nSpin = 1 case for restrictor_mma.
hummingtree Nov 4, 2024
7f31418
Break TMA box sizes into smaller pieces when the sizes are larger tha…
hummingtree Nov 4, 2024
3a0f205
Restore the SMMA precision for coarse dslash MMA.
hummingtree Nov 4, 2024
6d6beb5
Use SIMT for coarse dslash MMA on SM70.
hummingtree Nov 4, 2024
43e7cf7
Add param and command line option for running transfer with MMA.
hummingtree Nov 5, 2024
53eb2ae
Add doxygen to `include/expand_list.hpp` and `include/targets/cuda/tm…
hummingtree Nov 7, 2024
d0da823
Fix typo; replace `store_type` with `store_t`.
hummingtree Nov 12, 2024
cd3bbad
Only instantiate enabled spins in `lib/block_transpose.in.cu`.
hummingtree Nov 12, 2024
1be04f5
Remove test code in `lib/restrictor.in.cpp` and `lib/prolongator.in.c…
hummingtree Nov 12, 2024
c84c376
Store the aux values as the shape values directly (instead of as the …
hummingtree Nov 12, 2024
6451e7f
Use mma::numeric_limits<mma::half>::max() in MMA code.
hummingtree Nov 12, 2024
eacc613
Add initializer to *_factors.
hummingtree Nov 12, 2024
24c925f
Use curly brace initializers.
hummingtree Nov 12, 2024
fd4a930
Initialize looping variable.
hummingtree Nov 12, 2024
abe7854
Specify MMA types in CMake.
hummingtree Dec 2, 2024
a3e2f4a
The MMA kernels will now choose smallest instantiated nVec that is la…
hummingtree Dec 3, 2024
979ba1a
Add missing file; add checks for MRHS.
hummingtree Dec 3, 2024
be3d180
Add comments to the newly added MMA macros.
hummingtree Dec 4, 2024
d94156a
Merge remote-tracking branch 'origin/develop' into feature/prolongato…
hummingtree Dec 5, 2024
52bbc43
Add nVec_actual to the color_spinor_fields/params - now the MMA kenre…
hummingtree Dec 11, 2024
e6ee340
Put the IntList into a header file.
hummingtree Dec 11, 2024
47d6f00
Some improvements to handle nVec and nVec_actual: nVec should be stor…
maddyscientist Dec 20, 2024
bdf2aba
Fix bytes computation for MRHS clover solo operator
maddyscientist Dec 20, 2024
87d332f
Add comments; remove dead code; address some of the review comments.
hummingtree Jan 2, 2025
48655db
Merge branch 'feature/prolongator-mma' of github.com:lattice/quda int…
hummingtree Jan 2, 2025
68ab140
Move the conditional macro to tma_helper.hpp
hummingtree Jan 2, 2025
85da7a3
Fix the plumbing for prolongator/restrictor-MMA for staggered.
hummingtree Jan 3, 2025
ea867a3
Use divide and conquer logic when choosing the instantiated nVec to use.
hummingtree Jan 3, 2025
53c8e31
Optimize the rescaling code:
hummingtree Jan 3, 2025
df97e1a
When using TF32 and BF16 check if the CC is less than 80.
hummingtree Jan 3, 2025
b0b4355
Apply the Arg::check_bounds to the other kernels.
hummingtree Jan 3, 2025
75b24c2
Further reduce the number of float divisions by using scale.
hummingtree Jan 6, 2025
254a79b
Allow specifying MMA types per precision.
hummingtree Jan 6, 2025
af4c9a0
Break the MMA type into half/single ones.
hummingtree Jan 7, 2025
49c0a58
Print the nVec used only in debug verbosity.
hummingtree Jan 7, 2025
e8ca869
Remove the divide and conquer code for larger than limit TMA box size…
hummingtree Jan 16, 2025
a3e3384
Change the default setup MMA types to 3xfp16.
hummingtree Jan 16, 2025
f823560
Use logQuda.
hummingtree Jan 16, 2025
30d69f5
Change from - uses: actions/checkout@v3 to - uses: actions/checkout@v4.
hummingtree Jan 16, 2025
d62842c
Update cuda_githubactions_build.yml
mathiaswagner Jan 16, 2025
6f98322
Merge remote-tracking branch 'origin/develop' into feature/prolongato…
hummingtree Jan 17, 2025
cd67249
Not generating the Nc=6 files for MMA transfers.
hummingtree Jan 17, 2025
d91a823
Disable compiling for fineColor=6 and coarseColor=6 for the transfer …
hummingtree Jan 17, 2025
dcad5c8
Apply clang-format.
hummingtree Jan 17, 2025
d8adfa2
Resolve the cmake/clang-format conflicts, hopefully.
hummingtree Jan 17, 2025
0231629
Whitelist instead of blacklisting the fineColor and coarseColor for t…
hummingtree Jan 21, 2025
fadb49a
Added apt-get install clang-14 to github actions
weinbe2 Jan 22, 2025
4 changes: 2 additions & 2 deletions .github/workflows/cuda_githubactions_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ jobs:
packages: cuda-compiler-12-1 cuda-libraries-dev-12-1 cuda-nvml-dev-12-1
execute_install_scripts: true

- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Ccache for gh actions
uses: hendrikmuhs/ccache-action@v1.2.9
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: ${{ github.job }}-${{ matrix.compiler }}
max-size: 2000M
Expand Down
5 changes: 5 additions & 0 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ namespace quda
int nColor = 0; // Number of colors of the field
int nSpin = 0; // =1 for staggered, =2 for coarse Dslash, =4 for 4d spinor
int nVec = 1; // number of packed vectors (for multigrid transfer operator)
int nVec_actual = 1; // The actual number of packed vectors (that are not zero padded)

QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID; // used by twisted mass
QudaSiteOrder siteOrder = QUDA_INVALID_SITE_ORDER; // defined for full fields
Expand Down Expand Up @@ -241,6 +242,7 @@ namespace quda
nColor(cpuParam.nColor),
nSpin(cpuParam.nSpin),
nVec(cpuParam.nVec),
nVec_actual(cpuParam.nVec_actual),
twistFlavor(cpuParam.twistFlavor),
siteOrder(QUDA_EVEN_ODD_SITE_ORDER),
fieldOrder(QUDA_INVALID_FIELD_ORDER),
Expand Down Expand Up @@ -318,6 +320,7 @@ namespace quda
int nColor = 0;
int nSpin = 0;
int nVec = 0;
mutable int nVec_actual = 0;

QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID;

Expand Down Expand Up @@ -455,6 +458,8 @@ namespace quda
int Ncolor() const { return nColor; }
int Nspin() const { return nSpin; }
int Nvec() const { return nVec; }
int Nvec_actual() const { return nVec_actual; }
void Nvec_actual(int nVec_actual) const { this->nVec_actual = nVec_actual; }
QudaTwistFlavorType TwistFlavor() const { return twistFlavor; }
int Ndim() const { return nDim; }
const int *X() const { return x.data; }
Expand Down
5 changes: 3 additions & 2 deletions include/color_spinor_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ namespace quda
*/
template <typename Float, typename storeFloat, bool block_float_, typename norm_t> struct fieldorder_wrapper {
using value_type = Float; /**< Compute type */
using store_type = storeFloat; /**< Storage type */
using store_t = storeFloat; /**< Storage type */
complex<storeFloat> *v; /**< Field memory address this wrapper encompasses */
const int idx; /**< Index into field */
private:
Expand Down Expand Up @@ -581,7 +581,6 @@ namespace quda
*/
__device__ __host__ inline auto get_scale() const
{
static_assert(block_float == false, "Orders with block_float == true should not call the get_scale method.");
return block_float ? static_cast<Float>(1) / norm[norm_idx] : scale;
}

Expand Down Expand Up @@ -858,6 +857,8 @@ namespace quda
static constexpr int nSpin = nSpin_;
static constexpr int nColor = nColor_;

using store_t = storeFloat;

field<Float, storeFloat, fixed, block_float> v;
unsigned int volumeCB = 0;

Expand Down
206 changes: 206 additions & 0 deletions include/expand_list.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#include <tune_quda.h>
#include <int_factor_array.hpp>

namespace quda {

/**
@brief This helper class instantiates the following mapping:
tp.aux.x -> Bx in x_atom_size * [factors of (x + x_atom_size - 1) / x_atom_size];
tp.aux.y -> By in y_atom_size * [factors of (y + y_atom_size - 1) / y_atom_size];
tp.aux.z -> Bz in z_atom_size * [factors of (z + z_atom_size - 1) / z_atom_size];
tp.aux.w -> Bw in w_atom_size * [factors of (w + w_atom_size - 1) / w_atom_size].
See `void expand(TuneParam &tp, const qudaStream_t &stream)`
*/
template <class Callable, int x, int x_atom_size, int y, int y_atom_size, int z, int z_atom_size, int w, int w_atom_size>
class expand_aux_t {

Callable &_callable;

static constexpr IntFactorArray<(x + x_atom_size - 1) / x_atom_size, x_atom_size> x_factors{};
static constexpr IntFactorArray<(y + y_atom_size - 1) / y_atom_size, y_atom_size> y_factors{};
static constexpr IntFactorArray<(z + z_atom_size - 1) / z_atom_size, z_atom_size> z_factors{};
static constexpr IntFactorArray<(w + w_atom_size - 1) / w_atom_size, w_atom_size> w_factors{};

template <int Bx, int By, int Bz, size_t W, size_t... Ws>
void span_w(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<W, Ws...>)
{
constexpr int Bw = w_factors[W];
if (tp.aux.w == Bw) {
_callable.template launch_mma<Bx, By, Bz, Bw>(tp, stream);
} else {
if constexpr (sizeof...(Ws) > 0) {
span_w<Bx, By, Bz>(tp, stream, std::index_sequence<Ws...>());
} else {
errorQuda("Invalid tp.aux.w(=%d)", tp.aux.w);
}
}
}

template <int Bx, int By, size_t Z, size_t... Zs>
void span_z(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<Z, Zs...>)
{
constexpr int Bz = z_factors[Z];
if (tp.aux.z == Bz) {
std::make_index_sequence<w_factors.size()> w_indices;
span_w<Bx, By, Bz>(tp, stream, w_indices);
} else {
if constexpr (sizeof...(Zs) > 0) {
span_z<Bx, By>(tp, stream, std::index_sequence<Zs...>());
} else {
errorQuda("Invalid tp.aux.z(=%d)", tp.aux.z);
}
}
}

template <int Bx, size_t Y, size_t... Ys>
void span_y(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<Y, Ys...>)
{
constexpr int By = y_factors[Y];
if (tp.aux.y == By) {
std::make_index_sequence<z_factors.size()> z_indices;
span_z<Bx, By>(tp, stream, z_indices);
} else {
if constexpr (sizeof...(Ys) > 0) {
span_y<Bx>(tp, stream, std::index_sequence<Ys...>());
} else {
errorQuda("Invalid tp.aux.y(=%d)", tp.aux.y);
}
}
}

template <size_t X, size_t... Xs>
void span_x(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<X, Xs...>)
{
constexpr int Bx = x_factors[X];
if (tp.aux.x == Bx) {
std::make_index_sequence<y_factors.size()> y_indices;
span_y<Bx>(tp, stream, y_indices);
} else {
if constexpr (sizeof...(Xs) > 0) {
span_x(tp, stream, std::index_sequence<Xs...>());
} else {
errorQuda("Invalid tp.aux.x(=%d)", tp.aux.x);
}
}
}

public:

/**
@brief invoke `_callable.template launch_mma<Bx, By, Bz, Bw>(tp, stream);` based on the tp.aux values
tp.aux.x -> Bx in x_atom_size * [factors of (x + x_atom_size - 1) / x_atom_size];
tp.aux.y -> By in y_atom_size * [factors of (y + y_atom_size - 1) / y_atom_size];
tp.aux.z -> Bz in z_atom_size * [factors of (z + z_atom_size - 1) / z_atom_size];
tp.aux.w -> Bw in w_atom_size * [factors of (w + w_atom_size - 1) / w_atom_size].
For example, if x_atom_size = 8, x = 48, then Bx can take values in [8, 16, 24, 48]; when tp.aux.x == 0,
Bx = 8; when tp.aux.x == 1, Bx = 16; when tp.aux.x == 2, Bx = 24; when tp.aux.x == 3, Bx = 48.
@param tp The TuneParam parameter
@param stream The stream parameter
*/
void expand(TuneParam &tp, const qudaStream_t &stream)
{
std::make_index_sequence<x_factors.size()> x_indices;
span_x(tp, stream, x_indices);
}

expand_aux_t(Callable &callable): _callable(callable) { }

/**
@brief Get the Bx value
@param tp The TuneParam parameter
*/
int get_x(const TuneParam &tp) const {
if (x_factors.get_index(tp.aux.x) >= x_factors.size()) {
errorQuda("Invalid tp.aux.x = %d\n", tp.aux.x);
}
return tp.aux.x;
}

/**
@brief Get the By value
@param tp The TuneParam parameter
*/
int get_y(const TuneParam &tp) const {
if (y_factors.get_index(tp.aux.y) >= y_factors.size()) {
errorQuda("Invalid tp.aux.y = %d\n", tp.aux.y);
}
return tp.aux.y;
}

/**
@brief Get the Bz value
@param tp The TuneParam parameter
*/
int get_z(const TuneParam &tp) const {
if (z_factors.get_index(tp.aux.z) >= z_factors.size()) {
errorQuda("Invalid tp.aux.z = %d\n", tp.aux.z);
}
return tp.aux.z;
}

/**
@brief Get the Bw value
@param tp The TuneParam parameter
*/
int get_w(const TuneParam &tp) const {
if (w_factors.get_index(tp.aux.w) >= w_factors.size()) {
errorQuda("Invalid tp.aux.w = %d\n", tp.aux.w);
}
return tp.aux.w;
}

template <unsigned int Int, unsigned int Multiple>
bool advancer(int &aux, TuneParam &param, const IntFactorArray<Int, Multiple> &factors) const {
if (factors.get_index(aux) < factors.size() - 1) {
aux = factors[factors.get_index(aux) + 1];
return _callable.set_mma_param(param);
} else {
return false;
}
}

/**
@brief Advance to the next possible aux value and return true; return false if we have reached the last
possible value
@return whether or not an advance is performed
@param param The TuneParam parameter
*/
bool advance_aux(TuneParam &param) const
{
if (advancer(param.aux.x, param, x_factors)) {
return true;
} else {
param.aux.x = x_atom_size;
if (advancer(param.aux.y, param, y_factors)) {
return true;
} else {
param.aux.y = y_atom_size;
if (advancer(param.aux.z, param, z_factors)) {
return true;
} else {
param.aux.z = z_atom_size;
if (advancer(param.aux.w, param, w_factors)) {
return true;
} else {
param.aux.w = w_atom_size;
return false;
}
}
}
}
}

/**
@brief Initialize aux
@param param The TuneParam parameter
*/
void init_aux(TuneParam &param) const {
param.aux.x = x_atom_size;
param.aux.y = y_atom_size;
param.aux.z = z_atom_size;
param.aux.w = w_atom_size;
}

};

}
2 changes: 1 addition & 1 deletion include/gauge_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ namespace quda {
template <typename Float, typename storeFloat>
struct fieldorder_wrapper {
using value_type = Float;
using store_type = storeFloat;
using store_t = storeFloat;
complex<storeFloat> *v;
const unsigned int idx;

Expand Down
42 changes: 14 additions & 28 deletions include/int_factor_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,6 @@
namespace quda
{

inline unsigned int numFactors(unsigned int Int)
{
unsigned int i = 0;
for (unsigned int j = 1u; j <= Int; j++) {
if (Int % j == 0) { i++; }
}
return i;
}

/**
* @brief A struct containing a compile time generated array
* containing factors of an integer.
*/
inline auto get_int_factor_array(unsigned int Int)
{
std::vector<unsigned int> _out(numFactors(Int));
unsigned int i = 0;
for (unsigned int j = 1u; j <= Int; j++) {
if (Int % j == 0) {
_out[i] = j;
i++;
}
}
return _out;
}

/**
* @brief compute number of factors of an integer
*
Expand All @@ -48,7 +22,7 @@ namespace quda
* @brief A struct containing a compile time generated array
* containing factors of an integer.
*/
template <unsigned int Int> struct IntFactorArray {
template <unsigned int Int, unsigned int Multiple> struct IntFactorArray {

array<unsigned int, numFactors<Int>()> data_;

Expand All @@ -72,7 +46,19 @@ namespace quda
* @brief read only constant index operator[]
* @param i the index to look up
*/
constexpr unsigned int operator[](int i) const noexcept { return data_[i]; }
constexpr unsigned int operator[](int i) const noexcept { return Multiple * data_[i]; }

constexpr unsigned int get_index(unsigned int value) const noexcept
{
unsigned int i = 0;
for (; i < numFactors<Int>(); i++) {
if (Multiple * data_[i] == static_cast<unsigned int>(value)) {
return i;
}
}
return i;
}

}; // end struct

} // namespace quda
10 changes: 10 additions & 0 deletions include/int_list.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#pragma once

namespace quda {

/**
@brief This is a dummy struct that wraps around a list of integers
*/
template <int... Ints> struct IntList { };

}
4 changes: 3 additions & 1 deletion include/kernel_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ namespace quda

enum class use_kernel_arg_p { FALSE, TRUE, ALWAYS };

template <use_kernel_arg_p use_kernel_arg_ = use_kernel_arg_p::TRUE> struct kernel_param {
template <use_kernel_arg_p use_kernel_arg_ = use_kernel_arg_p::TRUE, bool check_bounds_ = true>
struct kernel_param {
static constexpr use_kernel_arg_p use_kernel_arg = use_kernel_arg_;
static constexpr bool check_bounds = check_bounds_;
dim3 threads; /** number of active threads required */
int comms_rank; /** per process value of comm_rank() */
int comms_rank_global; /** per process value comm_rank_global() */
Expand Down