Skip to content

Commit

Permalink
update CUDA target
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Sep 14, 2023
1 parent 206c4c6 commit 93a03f3
Show file tree
Hide file tree
Showing 24 changed files with 460 additions and 385 deletions.
1 change: 1 addition & 0 deletions include/dslash_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <shmem_pack_helper.cuh>
#include <kernel_helper.h>
#include <tune_quda.h>
#include <special_ops.h>

#if defined(_NVHPC_CUDA)
#include <constant_kernel_arg.h>
Expand Down
25 changes: 14 additions & 11 deletions include/kernels/field_strength_tensor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ namespace quda
}
};

template <typename Arg>
__device__ __host__ inline void computeFmunuCore(const Arg &arg, int idx, int parity, int mu, int nu)
using computeFmunuCoreOps = SpecialOps<thread_array<int,4>>;
template <typename Ftor>
__device__ __host__ inline void computeFmunuCore(const Ftor &ftor, int idx, int parity, int mu, int nu)
{
using Link = Matrix<complex<typename Arg::Float>, 3>;
using Link = Matrix<complex<typename Ftor::Arg::Float>, 3>;
auto &arg = ftor.arg;

int x[4];
int X[4];
Expand All @@ -53,7 +55,7 @@ namespace quda
{ // U(x,mu) U(x+mu,nu) U[dagger](x+nu,mu) U[dagger](x,nu)

// load U(x)_(+mu)
thread_array<int, 4> dx = {};
thread_array<int, 4> dx{ftor};
Link U1 = arg.u(mu, linkIndexShift(x, dx, X), parity);

// load U(x+mu)_(+nu)
Expand All @@ -76,7 +78,7 @@ namespace quda
{ // U(x,nu) U[dagger](x+nu-mu,mu) U[dagger](x-mu,nu) U(x-mu, mu)

// load U(x)_(+nu)
thread_array<int, 4> dx = {};
thread_array<int, 4> dx{ftor};
Link U1 = arg.u(nu, linkIndexShift(x, dx, X), parity);

// load U(x+nu)_(-mu) = U(x+nu-mu)_(+mu)
Expand All @@ -103,7 +105,7 @@ namespace quda
{ // U[dagger](x-nu,nu) U(x-nu,mu) U(x+mu-nu,nu) U[dagger](x,mu)

// load U(x)_(-nu)
thread_array<int, 4> dx = {};
thread_array<int, 4> dx{ftor};
dx[nu]--;
Link U1 = arg.u(nu, linkIndexShift(x, dx, X), 1 - parity);
dx[nu]++;
Expand All @@ -130,7 +132,7 @@ namespace quda
{ // U[dagger](x-mu,mu) U[dagger](x-mu-nu,nu) U(x-mu-nu,mu) U(x-nu,nu)

// load U(x)_(-mu)
thread_array<int, 4> dx = {};
thread_array<int, 4> dx{ftor};
dx[mu]--;
Link U1 = arg.u(mu, linkIndexShift(x, dx, X), 1 - parity);
dx[mu]++;
Expand Down Expand Up @@ -166,15 +168,16 @@ namespace quda
// 3*18 + 12*198 = 54 + 2376 = 2430
{
F -= conj(F); // 18 real subtractions + one matrix conjugation
F *= static_cast<typename Arg::Float>(0.125); // 18 real multiplications
F *= static_cast<typename Ftor::Arg::Float>(0.125); // 18 real multiplications
// 36 floating point operations here
}

int munu_idx = (mu * (mu - 1)) / 2 + nu; // lower-triangular indexing
arg.f(munu_idx, idx, parity) = F;
}

template <typename Arg> struct ComputeFmunu {
template <typename Arg_> struct ComputeFmunu : computeFmunuCoreOps {
using Arg = Arg_;
const Arg &arg;
constexpr ComputeFmunu(const Arg &arg) : arg(arg) {}
static constexpr const char* filename() { return KERNEL_FILE; }
Expand All @@ -190,7 +193,7 @@ namespace quda
case 4: mu = 3, nu = 1; break;
case 5: mu = 3, nu = 2; break;
}
computeFmunuCore(arg, x_cb, parity, mu, nu);
computeFmunuCore(*this, x_cb, parity, mu, nu);
}
};

Expand Down
4 changes: 2 additions & 2 deletions include/kernels/gauge_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ namespace quda {
}
};

template <typename Arg> struct GaugeForce
template <typename Arg> struct GaugeForce : SpecialOps<thread_array<int,4>>
{
const Arg &arg;
constexpr GaugeForce(const Arg &arg) : arg(arg) {}
Expand All @@ -62,7 +62,7 @@ namespace quda {
// prod: current matrix product
// accum: accumulator matrix
Link link_prod, accum;
thread_array<int, 4> dx{};
thread_array<int, 4> dx{*this};

for (int i=0; i<arg.p.num_paths; i++) {
real coeff = arg.p.path_coeff[i];
Expand Down
4 changes: 2 additions & 2 deletions include/kernels/gauge_loop_trace.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ namespace quda {
}
};

template <typename Arg> struct GaugeLoop : plus<typename Arg::reduce_t>
template <typename Arg> struct GaugeLoop : plus<typename Arg::reduce_t>, SpecialOps<thread_array<int,4>>
{
using reduce_t = typename Arg::reduce_t;
using plus<reduce_t>::operator();
Expand All @@ -71,7 +71,7 @@ namespace quda {
getCoords(x, x_cb, arg.X, parity);
for (int dr=0; dr<4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates

thread_array<int, 4> dx{};
thread_array<int, 4> dx{*this};

double coeff_loop = arg.factor * arg.p.path_coeff[path_id];
if (coeff_loop == 0) return operator()(loop_trace, value);
Expand Down
4 changes: 2 additions & 2 deletions include/kernels/multi_blas_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ namespace quda
#ifndef QUDA_FAST_COMPILE_REDUCE
constexpr bool enable_warp_split() { return false; }
#else
//constexpr bool enable_warp_split() { return true; }
constexpr bool enable_warp_split() { return false; }
constexpr bool enable_warp_split() { return true; }
//constexpr bool enable_warp_split() { return false; }
#endif

/**
Expand Down
6 changes: 6 additions & 0 deletions include/targets/cuda/atomic_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ namespace quda
atomic_fetch_add(reinterpret_cast<int *>(addr) + 3, val.w);
}

template <typename T>
__device__ __host__ inline void atomic_add_local(T *addr, T val)
{
atomic_fetch_add(addr, val);
}

template <bool is_device> struct atomic_fetch_abs_max_impl {
template <typename T> inline void operator()(T *addr, T val)
{
Expand Down
29 changes: 27 additions & 2 deletions include/targets/cuda/block_reduce_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,28 @@ namespace quda

#endif

/**
@brief block_reduce_param is used as a container for passing
non-type parameters to specialize block_reduce through the
target::dispatch
@tparam block_dim Block dimension of the reduction (1, 2 or 3)
@tparam batch_size Batch size of the reduction. Threads will be
ordered such that batch size is the slowest running index. Note
that batch_size > 1 requires block_dim <= 2.
*/
template <int block_dim_, int batch_size_ = 1> struct block_reduce_param {
static constexpr int block_dim = block_dim_;
static constexpr int batch_size = batch_size_;
static_assert(batch_size == 1 || block_dim <= 2, "Batching not possible with 3-d block reduction");
};

// pre-declaration of block_reduce that we wish to specialize
template <bool> struct block_reduce;
template <bool> struct block_reduce_impl;

/**
@brief CUDA specialization of block_reduce, building on the warp_reduce
*/
template <> struct block_reduce<true> {
template <> struct block_reduce_impl<true> {

template <int width_> struct warp_reduce_param {
static constexpr int width = width_;
Expand Down Expand Up @@ -273,6 +288,16 @@ namespace quda
}
};

template <typename T, int block_dim, int batch_size>
struct block_reduce {
template <typename S> HostDevice inline block_reduce(S &ops) {};
template <typename reducer_t>
HostDevice inline T apply(const T &value, bool async, int batch, bool all, const reducer_t &r)
{
return target::dispatch<block_reduce_impl>(value, async, batch, all, r, block_reduce_param<block_dim, batch_size>());
}
};

} // namespace quda

#include "../generic/block_reduce_helper.h"
6 changes: 4 additions & 2 deletions include/targets/cuda/reduce_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ namespace quda
constexpr auto n_batch_block = std::min(Arg::max_n_batch_block, device::max_block_size());
using BlockReduce = BlockReduce<T, Reducer::reduce_block_dim, n_batch_block>;

T aggregate = BlockReduce(target::thread_idx().z).Reduce(in, r);
//T aggregate = BlockReduce(target::thread_idx().z).Reduce(in, r);
SpecialOps<BlockReduce> ops{};
T aggregate = BlockReduce(ops, target::thread_idx().z).Reduce(in, r);

if (target::grid_dim().x == 1) { // short circuit where we have a single CTA - no need to do final reduction
write_result(arg, aggregate, idx);
Expand Down Expand Up @@ -276,7 +278,7 @@ namespace quda
i += target::block_size<2>();
}

sum = BlockReduce(target::thread_idx().z).Reduce(sum, r);
sum = BlockReduce(ops, target::thread_idx().z).Reduce(sum, r);

write_result(arg, sum, idx);
}
Expand Down
Loading

0 comments on commit 93a03f3

Please sign in to comment.