
Commit

Merge pull request #1380 from lattice/feature/sycl-merge
add ThreadLocalCache
weinbe2 authored Dec 20, 2023
2 parents a24bcfa + 418824e commit 1914dc3
Showing 43 changed files with 918 additions and 795 deletions.
11 changes: 10 additions & 1 deletion include/kernels/block_transpose.cuh
@@ -47,6 +47,15 @@ namespace quda
constexpr BlockTransposeKernel(const Arg &arg) : arg(arg) { }
static constexpr const char *filename() { return KERNEL_FILE; }

struct CacheDims {
static constexpr dim3 dims(dim3 block)
{
block.x += 1;
block.z = 1;
return block;
}
};

/**
@brief Transpose between the two different orders of batched colorspinor fields:
- B: nVec -> spatial/N -> spin/color -> N, where N is for that in floatN
@@ -60,7 +69,7 @@ namespace quda
int parity = parity_color / Arg::nColor;
using color_spinor_t = ColorSpinor<typename Arg::real, 1, Arg::nSpin>;

SharedMemoryCache<color_spinor_t> cache({target::block_dim().x + 1, target::block_dim().y, 1});
SharedMemoryCache<color_spinor_t, CacheDims> cache;

int x_offset = target::block_dim().x * target::block_idx().x;
int v_offset = target::block_dim().y * target::block_idx().y;
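For context on the block_transpose.cuh change above: the cache's shape moves from a runtime dim3 passed to the constructor into a compile-time policy type whose static dims() transforms the launch block. A minimal before/after sketch, assuming the SharedMemoryCache interface exactly as it appears in this diff:

  // policy type from the diff: pad x by one element (a common bank-conflict-avoidance
  // trick for transposes) and collapse the z dimension
  struct CacheDims {
    static constexpr dim3 dims(dim3 block)
    {
      block.x += 1;
      block.z = 1;
      return block;
    }
  };

  // before: shape computed per thread at run time
  //   SharedMemoryCache<color_spinor_t> cache({target::block_dim().x + 1, target::block_dim().y, 1});
  // after: shape fixed at compile time by the policy type
  //   SharedMemoryCache<color_spinor_t, CacheDims> cache;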
15 changes: 11 additions & 4 deletions include/kernels/coarse_op_kernel.cuh
@@ -10,6 +10,7 @@
#include <matrix_tile.cuh>
#include <target_device.h>
#include <kernel.h>
#include <shared_memory_cache_helper.h>

namespace quda {

@@ -1387,14 +1388,20 @@ namespace quda {
};

template <> struct storeCoarseSharedAtomic_impl<true> {
template <typename Arg>
using CacheT = complex<storeType>[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4]
[Arg::coarseSpin][Arg::coarseSpin];
template <typename Arg> using Cache = SharedMemoryCache<CacheT<Arg>, DimsStatic<2, 1, 1>>;

template <typename VUV, typename Pack, typename Arg>
inline __device__ void operator()(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, int parity, const Pack &pack, const Arg &arg)
{
using real = typename Arg::Float;
using TileType = typename Arg::vuvTileType;
const int dim_index = arg.dim_index % arg.Y_atomic.geometry;
__shared__ complex<storeType> X[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
__shared__ complex<storeType> Y[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
Cache<Arg> cache;
auto &X = cache.data()[0];
auto &Y = cache.data()[1];

int x_ = coarse_x_cb % arg.aggregates_per_block;
int tx = virtualThreadIdx(arg);
@@ -1416,7 +1423,7 @@ namespace quda {
}
}

__syncthreads();
cache.sync();

#pragma unroll
for (int i = 0; i < TileType::M; i++) {
@@ -1445,7 +1452,7 @@ namespace quda {
}
}

__syncthreads();
cache.sync();

if (tx < Arg::coarseSpin*Arg::coarseSpin && (parity == 0 || arg.parity_flip == 1) ) {

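For the coarse_op_kernel.cuh hunks above: the two fixed-size __shared__ arrays become a single SharedMemoryCache whose element type is the whole per-block array, and whose DimsStatic<2, 1, 1> policy appears to reserve two such elements, one for X and one for Y; the explicit __syncthreads() calls become cache.sync(). Since this merge comes from feature/sycl-merge, the abstraction presumably lets non-CUDA targets supply their own backing store and barrier. Condensed, with the names taken from the diff:

  // one cache element = the complete per-block accumulation array
  template <typename Arg>
  using CacheT = complex<storeType>[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4][Arg::coarseSpin][Arg::coarseSpin];
  template <typename Arg> using Cache = SharedMemoryCache<CacheT<Arg>, DimsStatic<2, 1, 1>>;

  // inside the functor body:
  Cache<Arg> cache;
  auto &X = cache.data()[0]; // was: __shared__ complex<storeType> X[...][...][4][coarseSpin][coarseSpin]
  auto &Y = cache.data()[1]; // was: __shared__ complex<storeType> Y[...][...][4][coarseSpin][coarseSpin]
  // ... accumulate into X and Y ...
  cache.sync();              // was: __syncthreads()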
23 changes: 15 additions & 8 deletions include/kernels/color_spinor_pack.cuh
@@ -172,17 +172,24 @@ namespace quda {
};

template <> struct site_max<true> {
template <typename Arg> struct CacheDims {
static constexpr int Ms = spins_per_thread<true>(Arg::nSpin);
static constexpr int Mc = colors_per_thread<true>(Arg::nColor);
static constexpr int color_spin_threads = (Arg::nSpin / Ms) * (Arg::nColor / Mc);
static constexpr dim3 dims(dim3 block)
{
// pad the shared block size to avoid bank conflicts for native ordering
if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size();
block.y = color_spin_threads; // state the y block since we know it at compile time
return block;
}
};

template <typename Arg> __device__ inline auto operator()(typename Arg::real thread_max, Arg &)
{
using real = typename Arg::real;
constexpr int Ms = spins_per_thread<true>(Arg::nSpin);
constexpr int Mc = colors_per_thread<true>(Arg::nColor);
constexpr int color_spin_threads = (Arg::nSpin/Ms) * (Arg::nColor/Mc);
auto block = target::block_dim();
// pad the shared block size to avoid bank conflicts for native ordering
if (Arg::is_native) block.x = ((block.x + device::warp_size() - 1) / device::warp_size()) * device::warp_size();
block.y = color_spin_threads; // state the y block since we know it at compile time
SharedMemoryCache<real> cache(block);
constexpr int color_spin_threads = CacheDims<Arg>::color_spin_threads;
SharedMemoryCache<real, CacheDims<Arg>> cache;
cache.save(thread_max);
cache.sync();
real this_site_max = static_cast<real>(0);
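In color_spinor_pack.cuh, the block shape that used to be computed per thread now lives in CacheDims<Arg>: dims() rounds block.x up to a whole number of warps (the bank-conflict padding for native ordering) and fixes block.y to the compile-time color_spin_threads. The round-up arithmetic, as a standalone check (illustrative helper, not QUDA code):

  // round n up to the next multiple of warp_size: ((n + w - 1) / w) * w
  constexpr unsigned round_up(unsigned n, unsigned warp_size)
  {
    return ((n + warp_size - 1) / warp_size) * warp_size;
  }
  static_assert(round_up(48, 32) == 64, "48 threads pad up to two full warps");
  static_assert(round_up(64, 32) == 64, "an exact multiple is unchanged");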
2 changes: 1 addition & 1 deletion include/kernels/dslash_clover_helper.cuh
@@ -203,7 +203,7 @@ namespace quda {

Mat A = arg.clover(x_cb, clover_parity, chirality);

SharedMemoryCache<half_fermion> cache(target::block_dim());
SharedMemoryCache<half_fermion> cache;

half_fermion in_chi[n_flavor]; // flavor array of chirally projected fermion
#pragma unroll
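This file and several of the following make the same mechanical one-line change: the explicit target::block_dim() constructor argument is dropped, which suggests the cache's default shape is already the full thread block. A hypothetical illustration of a default policy consistent with that reading (DefaultDims is an assumed name, not taken from the diff):

  struct DefaultDims {
    static constexpr dim3 dims(dim3 block) { return block; } // identity: use the full block
  };
  // so "SharedMemoryCache<half_fermion> cache;" would behave like
  // "SharedMemoryCache<half_fermion, DefaultDims> cache;"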
2 changes: 1 addition & 1 deletion include/kernels/dslash_coarse.cuh
@@ -301,7 +301,7 @@ namespace quda {
template <> struct dim_collapse<true> {
template <typename T, typename Arg> __device__ __host__ inline void operator()(T &out, int dir, int dim, const Arg &arg)
{
SharedMemoryCache<T> cache(target::block_dim());
SharedMemoryCache<T> cache;
// only need to write to shared memory if not master thread
if (dim > 0 || dir) cache.save(out);

10 changes: 5 additions & 5 deletions include/kernels/dslash_domain_wall_m5.cuh
@@ -220,7 +220,7 @@ namespace quda
if (mobius_m5::use_half_vector()) {
// if using shared-memory caching then load spinor field for my site into cache
typedef ColorSpinor<real, Arg::nColor, 4 / 2> HalfVector;
SharedMemoryCache<HalfVector> cache(target::block_dim());
SharedMemoryCache<HalfVector> cache;

{ // forwards direction
constexpr int proj_dir = dagger ? +1 : -1;
@@ -271,7 +271,7 @@ namespace quda
} else { // use_half_vector

// if using shared-memory caching then load spinor field for my site into cache
SharedMemoryCache<Vector> cache(target::block_dim());
SharedMemoryCache<Vector> cache;
if (shared) {
if (sync) { cache.sync(); }
cache.save(in);
@@ -377,7 +377,7 @@ namespace quda
const auto inv = arg.inv;

// if using shared-memory caching then load spinor field for my site into cache
SharedMemoryCache<Vector> cache(target::block_dim());
SharedMemoryCache<Vector> cache;
if (shared) {
// cache.save(arg.in(s_ * arg.volume_4d_cb + x_cb, parity));
if (sync) { cache.sync(); }
@@ -436,7 +436,7 @@ namespace quda
Vector out;

if (mobius_m5::use_half_vector()) {
SharedMemoryCache<HalfVector> cache(target::block_dim());
SharedMemoryCache<HalfVector> cache;

{ // first do R
constexpr int proj_dir = dagger ? -1 : +1;
@@ -495,7 +495,7 @@ namespace quda
out += l.reconstruct(4, proj_dir);
}
} else { // use_half_vector
SharedMemoryCache<Vector> cache(target::block_dim());
SharedMemoryCache<Vector> cache;
if (shared) {
if (sync) { cache.sync(); }
cache.save(in);
4 changes: 2 additions & 2 deletions include/kernels/dslash_mobius_eofa.cuh
@@ -107,7 +107,7 @@ namespace quda
using real = typename Arg::real;
typedef ColorSpinor<real, Arg::nColor, 4> Vector;

SharedMemoryCache<Vector> cache(target::block_dim());
SharedMemoryCache<Vector> cache;

Vector out;
cache.save(arg.in(s * arg.volume_4d_cb + x_cb, parity));
@@ -185,7 +185,7 @@ namespace quda
typedef ColorSpinor<real, Arg::nColor, 4> Vector;

const auto sherman_morrison = arg.sherman_morrison;
SharedMemoryCache<Vector> cache(target::block_dim());
SharedMemoryCache<Vector> cache;
cache.save(arg.in(s * arg.volume_4d_cb + x_cb, parity));
cache.sync();

2 changes: 1 addition & 1 deletion include/kernels/dslash_ndeg_twisted_clover.cuh
@@ -72,7 +72,7 @@ namespace quda
// apply the chiral and flavor twists
// use consistent load order across s to ensure better cache locality
Vector x = arg.x(my_flavor_idx, my_spinor_parity);
SharedMemoryCache<Vector> cache(target::block_dim());
SharedMemoryCache<Vector> cache;
cache.save(x);

x.toRel(); // switch to chiral basis
(a new changed file begins here; its name was not captured)
@@ -91,7 +91,7 @@ namespace quda

int chirality = flavor; // relabel flavor as chirality

SharedMemoryCache<HalfVector> cache(target::block_dim());
SharedMemoryCache<HalfVector> cache;

enum swizzle_direction {
FORWARDS = 0,
(a new changed file begins here; its name was not captured)
@@ -95,7 +95,7 @@ namespace quda
}

if (!dagger || Arg::asymmetric) { // apply A^{-1} to D*in
SharedMemoryCache<Vector> cache(target::block_dim());
SharedMemoryCache<Vector> cache;
if (isComplete<mykernel_type>(arg, coord) && active) {
// to apply the preconditioner we need to put "out" in shared memory so the other flavor can access it
cache.save(out);
5 changes: 3 additions & 2 deletions include/kernels/gauge_stout.cuh
@@ -6,6 +6,7 @@
#include <su3_project.cuh>
#include <kernels/gauge_utils.cuh>
#include <kernel.h>
#include <thread_local_cache.h>

namespace quda
{
@@ -134,8 +135,8 @@ namespace quda
}

Link U, Q;
SharedMemoryCache<Link> Stap(target::block_dim());
SharedMemoryCache<Link> Rect(target::block_dim(), sizeof(Link));
ThreadLocalCache<Link, 0, computeStapleRectangleOps> Stap;
ThreadLocalCache<Link, 0, decltype(Stap)> Rect; // offset by Stap type to ensure non-overlapping allocations

// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
// and the 1x2 and 2x1 rectangles of length 5. From the following paper:
2 changes: 2 additions & 0 deletions include/kernels/gauge_utils.cuh
@@ -19,6 +19,7 @@ namespace quda
// matrix+matrix = 18 floating-point ops
// => Total number of floating point ops per function call
// dims * (2*18 + 4*198) = dims*828
using computeStapleOps = thread_array<int, 4>;
template <typename Arg, typename Staple, typename Int>
__host__ __device__ inline void computeStaple(const Arg &arg, const int *x, const Int *X, const int parity, const int nu, Staple &staple, const int dir_ignore)
{
@@ -94,6 +95,7 @@ namespace quda
// matrix+matrix = 18 floating-point ops
// => Total number of floating point ops per function call
// dims * (8*18 + 28*198) = dims*5688
using computeStapleRectangleOps = thread_array<int, 4>;
template <typename Arg, typename Staple, typename Rectangle, typename Int>
__host__ __device__ inline void computeStapleRectangle(const Arg &arg, const int *x, const Int *X, const int parity, const int nu,
Staple &staple, Rectangle &rectangle, const int dir_ignore)
5 changes: 3 additions & 2 deletions include/kernels/gauge_wilson_flow.cuh
@@ -4,6 +4,7 @@
#include <kernels/gauge_utils.cuh>
#include <su3_project.cuh>
#include <kernel.h>
#include <thread_local_cache.h>

namespace quda
{
@@ -71,8 +72,8 @@ namespace quda
// This function gets stap = S_{mu,nu} i.e., the staple of length 3,
// and the 1x2 and 2x1 rectangles of length 5. From the following paper:
// https://arxiv.org/abs/0801.1165
SharedMemoryCache<Link> Stap(target::block_dim());
SharedMemoryCache<Link> Rect(target::block_dim(), sizeof(Link)); // offset to ensure non-overlapping allocations
ThreadLocalCache<Link, 0, computeStapleRectangleOps> Stap;
ThreadLocalCache<Link, 0, decltype(Stap)> Rect; // offset by Stap type to ensure non-overlapping allocations
computeStapleRectangle(arg, x, arg.E, parity, dir, Stap, Rect, Arg::wflow_dim);
Z = arg.coeff1x1 * static_cast<const Link &>(Stap) + arg.coeff2x1 * static_cast<const Link &>(Rect);
break;
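The gauge_stout.cuh and gauge_wilson_flow.cuh hunks show the new ThreadLocalCache that this commit adds: each staple/rectangle Link now lives in its own per-thread cache, and the third template parameter names a type whose storage the cache is offset past, so caches stacked in one kernel do not overlap. computeStapleRectangleOps (an alias for thread_array<int, 4> added in gauge_utils.cuh above, with computeStapleOps as its staple-only counterpart) describes the index scratch the staple routine itself uses. A condensed sketch of how the offsets appear to stack, using only names from the diff:

  // assumed per-thread storage layout, reading the offset parameters above:
  //   computeStapleRectangleOps                             -> [ thread_array<int, 4> scratch ]
  //   ThreadLocalCache<Link, 0, computeStapleRectangleOps>  -> [ scratch ][ Stap's Link ]
  //   ThreadLocalCache<Link, 0, decltype(Stap)>             -> [ scratch ][ Stap's Link ][ Rect's Link ]
  ThreadLocalCache<Link, 0, computeStapleRectangleOps> Stap;
  ThreadLocalCache<Link, 0, decltype(Stap)> Rect; // offset by Stap's type so the two never alias
  computeStapleRectangle(arg, x, arg.E, parity, dir, Stap, Rect, Arg::wflow_dim);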
(remaining changed files in this commit are not shown)
