Skip to content

Commit

Permalink
add missing file
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Sep 18, 2023
1 parent 2cd916a commit af8d555
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 229 deletions.
229 changes: 0 additions & 229 deletions include/targets/cuda/special_ops_target.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,137 +11,6 @@ namespace quda {
}
};

#if 0
// blockSync
template <typename ...T>
inline void blockSync(SpecialOps<T...> *ops) {
static_assert(hasBlockSync<T...>);
//if (ops->ndi == nullptr) {
// errorQuda("SpecialOps not set");
//}
#ifdef __SYCL_DEVICE_ONLY__
sycl::group_barrier(ops->ndi->get_group());
#endif
}
template <typename ...T> inline void blockSync(SpecialOps<T...> ops) { blockSync(&ops); }

template <typename ...T> static constexpr bool isOpConcurrent = false;
template <typename ...T> static constexpr bool isOpConcurrent<op_Concurrent<T...>> = true;

template <typename T, typename ...U> static constexpr int getOpIndex = 0;
template <typename T, typename ...U> static constexpr int getOpIndex<T,op_Concurrent<U...>> = getOpIndex<T,U...>;
template <typename T, typename U, typename ...V> static constexpr int getOpIndex<T, U, V...> =
std::is_same_v<T,U> ? 0 : (1 + getOpIndex<T,V...>);

// getSpecialOp
template <typename U, int n = 0, typename ...T>
inline SpecialOpsType<U,n> getSpecialOp(const SpecialOps<T...> &ops) {
if constexpr (!isOpConcurrent<U> && sizeof...(T) == 1 && isOpConcurrent<T...>) {
static constexpr int i = getOpIndex<U, T...>;
return getSpecialOp<T...,i>(ops);
} else {
static_assert(hasSpecialOpType<U,T...>);
//if (ops->ndi == nullptr || ops->smem == nullptr) {
// errorQuda("SpecialOps not set");
//}
SpecialOpsType<U,n> s;
s.ndi = ops.ndi;
//s.smem = ops->smem + sharedMemOffset<U,n>()(ops->ndi->get_local_range()); // FIXME: need to pass arg
s.smem = ops.smem + sharedMemOffset<U,n>()(getBlockDim()); // FIXME: need to pass arg
return s;
}
}
template <typename U, int n = 0, typename ...T>
inline SpecialOpsType<U,n> getSpecialOp(const SpecialOps<T...> *ops) { return getSpecialOp<U,n>(*ops); }
template <typename U, int n = 0> struct getSpecialOpF {
template <typename T> inline SpecialOpsType<U,n> operator()(const T &ops) { return getSpecialOp<U,n>(ops); }
};

// getDependentOps
template <typename U, int n = 0, typename ...T>
inline SpecialOpDependencies<SpecialOpsType<U,n>> getDependentOps(const SpecialOps<T...> &ops) {
static_assert(hasSpecialOpType<U,T...>);
//if (ops->ndi == nullptr || ops->smem == nullptr) {
//errorQuda("SpecialOps not set");
//}
SpecialOpDependencies<SpecialOpsType<U,n>> s;
s.ndi = ops.ndi;
//s.smem = ops->smem + sharedMemOffset<U,n>()(ops->ndi->get_local_range()); // FIXME: need to pass arg
s.smem = ops.smem + sharedMemOffset<U,n>()(getBlockDim()); // FIXME: need to pass arg
return s;
}

// getSharedMemPtr
#if 0
template <typename ...T>
//SpecialOpsElemType<T...> *getSharedMemPtr(SpecialOps<T...> *ops) {
sycl::local_ptr<SpecialOpsElemType<T...>> getSharedMemPtr(SpecialOps<T...> *ops) {
static_assert(!std::is_same_v<SpecialOpsElemType<T...>,void>);
//return reinterpret_cast<SpecialOpsElemType<T...>*>(ops->smem);
//return reinterpret_cast<SpecialOpsElemType<T...>*>(ops->smem.get());
//sycl::local_ptr<SpecialOpsElemType<T...>> smem = ops->smem.get();
//return smem.get();
//auto p = ops->smem.get();
sycl::local_ptr<void> v(ops->smem);
sycl::local_ptr<SpecialOpsElemType<T...>> p(v);
return p;
//sycl::local_ptr<SpecialOpsElemType<T...>> smem;
//using LT = decltype(smem.get());
//LT pt = reinterpret_cast<LT>(p);
//sycl::local_ptr<SpecialOpsElemType<T...>> smem2(pt);
//return smem2;
//return reinterpret_cast<SpecialOpsElemType<T...>*>(0);
}
template <typename ...T>
inline SpecialOpsElemType<T...> *getSharedMemPtr(SpecialOps<T...> ops) { return getSharedMemPtr(&ops); }
#endif

template <typename T, typename S, typename O = op_SharedMemory<T,S>>
inline sycl::local_ptr<T> getSharedMemPtr(const only_SharedMemory<T,S> &ops) {
//if (ops->ndi == nullptr || ops->smem == nullptr) {
//errorQuda("SpecialOps not set");
//}
sycl::local_ptr<void> v(ops.smem);
sycl::local_ptr<T> p(v);
return p;
}
//template <typename T, typename S>
//inline sycl::local_ptr<T> getSharedMemPtr(only_SharedMemory<T,S> ops) { return getSharedMemPtr(&ops); }
template <typename O, typename T, typename U, typename ...V>
inline auto getSharedMemPtr(const SpecialOps<T,U,V...> &ops) {
SpecialOps<O> op = getSpecialOp<O>(ops);
return getSharedMemPtr(op);
}

template <typename T, typename O>
inline auto getSharedMemory(O *ops)
{
auto s = getSpecialOp<T>(ops);
return getSharedMemPtr(s);
}

// base operation dependencies
struct depNone {};
template <> struct sharedMemSizeS<depNone> {
template <typename ...Arg>
static constexpr unsigned int size(dim3 block, Arg &...arg) { return 0; }
};

struct depFullBlock {};
template <> struct sharedMemSizeS<depFullBlock> {
template <typename ...Arg>
static constexpr unsigned int size(dim3 block, Arg &...arg) { return 0; }
};

template <typename T, typename S>
struct depSharedMem {};
template <typename T, typename S> struct sharedMemSizeS<depSharedMem<T,S>> {
template <typename ...Arg>
static constexpr unsigned int size(dim3 block, Arg &...arg) { return S().template size<T>(block, arg...); }
};

#endif

// op implementations
struct op_blockSync : op_BaseT<void> {
//using dependencies = depFullBlock;
Expand All @@ -157,102 +26,4 @@ namespace quda {
static constexpr unsigned int shared_mem_size(dim3, Arg &...arg) { return 0; }
};

#if 0

template <typename T, int N>
struct op_thread_array : op_BaseT<T,N> {
//using dependencies = depNone;
using dependencies = op_SharedMemory<array<T,N>,opSizeBlock>;
};

template <typename T>
struct op_BlockReduce : op_BaseT<T> {
using concurrentOps = op_Concurrent<op_blockSync,op_SharedMemory<T,opSizeBlockDivWarp>>;
using opBlockSync = getSpecialOpF<concurrentOps,0>;
using opSharedMem = getSpecialOpF<concurrentOps,1>;
//using specialOps = SpecialOps<concurrentOps>;
using dependencies = concurrentOps;
};

template <typename T, typename D>
struct op_SharedMemoryCache : op_BaseT<T> {
template <typename ...Arg> static constexpr dim3 dims(dim3 block, Arg &...arg) { return D::dims(block, arg...); }
using dependencies = op_Sequential<op_blockSync,op_SharedMemory<T,opSizeDims<D>>>;
};

template <typename T, typename S>
struct op_SharedMemory : op_BaseT<T> {
using dependencies = depSharedMem<T,S>;
template <typename ...Arg>
static constexpr unsigned int shared_mem_size(dim3 block, Arg &...arg) { return S::template size<T>(block, arg...); }
};

// needsFullWarp?

// needsFullBlock
#if 0
template <typename T> static constexpr bool needsFullBlock = needsFullBlock<getSpecialOps<T>>;
template <typename ...T> static constexpr bool needsFullBlockImpl = (needsFullBlockImpl<T> || ...);
template <> static constexpr bool needsFullBlockImpl<depNone> = false;
template <> static constexpr bool needsFullBlockImpl<depFullBlock> = true;
template <typename T, typename S> static constexpr bool needsFullBlockImpl<depSharedMem<T,S>> = false;
template <typename ...T> static constexpr bool needsFullBlockImpl<op_Concurrent<T...>> = needsFullBlockImpl<T...>;
template <typename ...T> static constexpr bool needsFullBlockImpl<op_Sequential<T...>> = needsFullBlockImpl<T...>;
template <typename T> static constexpr bool needsFullBlockF() {
if constexpr (std::is_base_of<op_Base,T>::value) {
return needsFullBlockImpl<typename T::dependencies>;
} else {
//if constexpr (hasSpecialOps<T>) {
//return needsFullBlock<getSpecialOps<T>>;
//} else {
//return false;
return needsFullBlock<typename T::dependentOps>;
//}
}
}
template <typename T> static constexpr bool needsFullBlockImpl<T> = needsFullBlockF<T>();
template <> static constexpr bool needsFullBlock<NoSpecialOps> = false;
template <typename ...T> static constexpr bool needsFullBlock<SpecialOps<T...>> = needsFullBlockImpl<T...>;
#else
template <typename T> static constexpr bool needsFullBlock = needsFullBlock<getSpecialOps<T>>;
template <typename ...T> static constexpr bool needsFullBlock<SpecialOps<T...>> = (needsFullBlock<T> || ...);
template <> static constexpr bool needsFullBlock<NoSpecialOps> = false;
#endif


// needsSharedMem
#if 0
template <typename T> static constexpr bool needsSharedMem = needsSharedMem<getSpecialOps<T>>;
template <typename ...T> static constexpr bool needsSharedMemImpl = (needsSharedMemImpl<T> || ...);
template <> static constexpr bool needsSharedMemImpl<depNone> = false;
template <> static constexpr bool needsSharedMemImpl<depFullBlock> = false;
template <typename T, typename S> static constexpr bool needsSharedMemImpl<depSharedMem<T,S>> = true;
template <typename ...T> static constexpr bool needsSharedMemImpl<op_Concurrent<T...>> = needsSharedMemImpl<T...>;
template <typename ...T> static constexpr bool needsSharedMemImpl<op_Sequential<T...>> = needsSharedMemImpl<T...>;
template <typename T> static constexpr bool needsSharedMemF() {
if constexpr (std::is_base_of<op_Base,T>::value) {
//if constexpr (is_instance<T,op_Base>) {
return needsSharedMemImpl<typename T::dependencies>;
} else {
//if constexpr (hasSpecialOps<T>) {
//return needsSharedMem<getSpecialOps<T>>;
//} else {
//return false;
return needsSharedMem<typename T::dependentOps>;
//}
}
}
template <typename T> static constexpr bool needsSharedMemImpl<T> = needsSharedMemF<T>();
template <> static constexpr bool needsSharedMem<NoSpecialOps> = false;
template <typename ...T> static constexpr bool needsSharedMem<SpecialOps<T...>> = needsSharedMemImpl<T...>;
#else
//template <typename ...T> static constexpr bool needsSharedMemImpl = (needsSharedMemImpl<T> || ...);
template <typename T> static constexpr bool needsSharedMemImpl = (T::shared_mem_size(dim3{8,8,8}) > 0);
template <typename... T> static constexpr bool needsSharedMemImpl<SpecialOps<T...>> = (needsSharedMemImpl<T> || ...);
template <typename T> static constexpr bool needsSharedMem = needsSharedMem<getSpecialOps<T>>;
template <typename... T> static constexpr bool needsSharedMem<SpecialOps<T...>> = (needsSharedMemImpl<T> || ...);
template <> static constexpr bool needsSharedMem<NoSpecialOps> = false;
#endif

#endif
}
29 changes: 29 additions & 0 deletions include/targets/hip/special_ops_target.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once
#include <special_ops.h>

namespace quda {

// SpecialOps
template <typename ...T>
struct SpecialOps : SpecialOps_Base<T...> {
template <typename ...U> constexpr void setSpecialOps(const SpecialOps<U...> &ops) {
static_assert(std::is_same_v<SpecialOps<T...>,SpecialOps<U...>>);
}
};

// op implementations
struct op_blockSync : op_BaseT<void> {
//using dependencies = depFullBlock;
template <typename ...Arg>
static constexpr unsigned int shared_mem_size(dim3 block, Arg &...arg) { return 0; }
};

template <typename T>
struct op_warp_combine : op_BaseT<T> {
//using dependencies = depNone;
//using dependencies = depFullBlock;
template <typename ...Arg>
static constexpr unsigned int shared_mem_size(dim3, Arg &...arg) { return 0; }
};

}

0 comments on commit af8d555

Please sign in to comment.