Skip to content

Commit

Permalink
fixes to CUDA target
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Sep 15, 2023
1 parent 93a03f3 commit 08af61e
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 12 deletions.
2 changes: 1 addition & 1 deletion include/kernels/block_orthogonalize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ namespace quda {
}

template <bool allthreads = false>
__device__ __host__ inline void operator()(dim3 block, dim3 thread, bool active = true)
__device__ __host__ inline void operator()(dim3 block, dim3 thread, bool /*active*/ = true)
{
int x_coarse = block.x;
int x_fine_offset = thread.x;
Expand Down
4 changes: 3 additions & 1 deletion include/kernels/coarse_op_kernel_mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <block_reduce_helper.h>
#include <kernel.h>
#include <kernels/coarse_op_kernel.cuh>
#include <special_ops_target.h>

namespace quda
{
Expand Down Expand Up @@ -161,7 +162,8 @@ namespace quda

if (Arg::compute_max) {
constexpr int block_dim = 3;
unsigned aggregate = BlockReduce<unsigned, block_dim>().Max(__float_as_uint(max));
SpecialOps<BlockReduce<unsigned, block_dim>> ops{};
unsigned aggregate = BlockReduce<unsigned, block_dim>{ops}.Max(__float_as_uint(max));
if (threadIdx.y == 0 && threadIdx.z == 0) atomic_fetch_abs_max(arg.max_d, __uint_as_float(aggregate));
}
}
Expand Down
3 changes: 2 additions & 1 deletion include/kernels/coarse_op_preconditioned_mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ namespace quda
}
if (Arg::compute_max) {
constexpr int block_dim = 3;
unsigned aggregate = BlockReduce<unsigned, block_dim>().Max(__float_as_uint(max));
SpecialOps<BlockReduce<unsigned, block_dim>> ops{};
unsigned aggregate = BlockReduce<unsigned, block_dim>{ops}.Max(__float_as_uint(max));
if (threadIdx.y == 0 && threadIdx.z == 0) atomic_fetch_abs_max(arg.max_d, __uint_as_float(aggregate));
}
}
Expand Down
17 changes: 16 additions & 1 deletion include/targets/cuda/block_reduce_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <target_device.h>
#include <reducer.h>
#include <helpers.h>

/**
@file block_reduce_helper.h
Expand Down Expand Up @@ -213,6 +214,17 @@ namespace quda
// pre-declaration of block_reduce that we wish to specialize
template <bool> struct block_reduce_impl;

/**
   @brief Fallback (generic) implementation of block_reduce: performs no
   cross-thread reduction and simply hands the caller's value back unchanged.
   Targets provide real specializations elsewhere.
*/
template <bool is_device> struct block_reduce_impl {
  /**
     @param[in] value thread-local value to "reduce"
     @return the input value, unmodified (no reduction is performed)
  */
  template <typename T, typename R, typename P> T operator()(const T &value, bool, int, bool, R, P)
  {
    T result = value; // identity: the dummy reducer ignores all control arguments
    return result;
  }
};

/**
@brief CUDA specialization of block_reduce, building on the warp_reduce
*/
Expand Down Expand Up @@ -290,7 +302,10 @@ namespace quda

template <typename T, int block_dim, int batch_size>
struct block_reduce {
template <typename S> HostDevice inline block_reduce(S &ops) {};
template <typename Ops> HostDevice inline block_reduce(Ops &) {};
template <typename ...Arg> static constexpr size_t shared_mem_size(dim3 block) {
return SizeBlockDivWarp::size(block) * sizeof(T);
}
template <typename reducer_t>
HostDevice inline T apply(const T &value, bool async, int batch, bool all, const reducer_t &r)
{
Expand Down
3 changes: 2 additions & 1 deletion include/targets/cuda/mdw_dslash5_tensor_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,8 @@ namespace quda
warp_max[0] = fmaxf(warp_max[0], warp_max[1]);

constexpr int block_dim = 2;
return BlockReduce<float, block_dim>().AllMax(warp_max[0]) / target_scale;
SpecialOps<BlockReduce<float, block_dim>> ops{};
return BlockReduce<float, block_dim>{ops}.AllMax(warp_max[0]) / target_scale;
}

// Actually does more than the function name suggests.
Expand Down
1 change: 1 addition & 0 deletions include/targets/cuda/shared_memory_helper.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <target_device.h>
#include <special_ops.h>

/**
@file shared_memory_helper.h
Expand Down
30 changes: 23 additions & 7 deletions include/targets/cuda/target_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,40 @@ namespace quda
#ifdef _NVHPC_CUDA

// nvc++: run-time dispatch using if target
//
// Selects the device or host instantiation of the functor template f at run
// time.  When non-type template parameters Params are supplied, they are
// forwarded explicitly to the functor's templated call operator; otherwise
// the plain call operator is invoked.
//
// NOTE: the diff rendering had interleaved the pre-change single-return
// bodies with the new if-constexpr bodies, leaving unreachable duplicate
// returns; this is the reconstructed post-change function.
template <template <bool, typename ...> class f, auto ...Params, typename ...Args>
__host__ __device__ auto dispatch(Args &&...args)
{
  if target (nv::target::is_device) {
    if constexpr (sizeof...(Params) == 0) {
      return f<true>()(args...);
    } else {
      return f<true>().template operator()<Params...>(args...);
    }
  } else {
    if constexpr (sizeof...(Params) == 0) {
      return f<false>()(args...);
    } else {
      return f<false>().template operator()<Params...>(args...);
    }
  }
}

#else

// nvcc or clang: compile-time dispatch
//
// Selects the device (__CUDA_ARCH__ defined) or host instantiation of the
// functor template f at compile time.  When non-type template parameters
// Params are supplied, they are forwarded explicitly to the functor's
// templated call operator; otherwise the plain call operator is invoked.
//
// NOTE: the diff rendering had interleaved the pre-change single-return
// lines into both preprocessor branches, producing unreachable duplicate
// returns; this is the reconstructed post-change function.
template <template <bool, typename ...> class f, auto ...Params, typename ...Args>
__host__ __device__ auto dispatch(Args &&...args)
{
#ifdef __CUDA_ARCH__
  if constexpr (sizeof...(Params) == 0) {
    return f<true>()(args...);
  } else {
    return f<true>().template operator()<Params...>(args...);
  }
#else
  if constexpr (sizeof...(Params) == 0) {
    return f<false>()(args...);
  } else {
    return f<false>().template operator()<Params...>(args...);
  }
#endif
}

Expand Down
6 changes: 6 additions & 0 deletions include/targets/generic/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ namespace quda
}
};

// Shared-memory sizing helper: number of warps needed to cover a thread
// block, i.e. ceil(total block threads / warp size).
struct SizeBlockDivWarp {
  static constexpr unsigned int size(dim3 b)
  {
    unsigned int nthreads = b.x * b.y * b.z;
    unsigned int ws = device::warp_size();
    return (nthreads + ws - 1) / ws; // round up so a partial warp still gets a slot
  }
};

template <typename D, int N = 1> struct SizeDims {
static constexpr unsigned int size(dim3 block) {
dim3 dims = D::dims(block);
Expand Down

0 comments on commit 08af61e

Please sign in to comment.