-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Moving remaining stats prims from cuml (#507)
Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Divye Gala (https://github.com/divyegala) URL: #507
- Loading branch information
Showing
23 changed files
with
3,027 additions
and
148 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright (c) 2019-2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <raft/cuda_utils.cuh> | ||
#include <vector> | ||
|
||
// Taken from: | ||
// https://github.com/teju85/programming/blob/master/euler/include/seive.h | ||
|
||
namespace raft { | ||
namespace common { | ||
|
||
/**
 * @brief Implementation of the 'Sieve of Eratosthenes'.
 *
 * Only 2 and the odd numbers are tracked: bit id 0 represents 2 and bit
 * id k (k >= 1) represents the odd number 2*k + 1. Even numbers other
 * than 2 are rejected up-front, which halves the memory footprint.
 */
class Seive {
 public:
  /**
   * @param _num number of integers for which seive is needed (inclusive
   *             upper bound of the queryable range)
   */
  Seive(unsigned _num)
  {
    N = _num;
    generateSeive();
  }

  /**
   * @brief Check whether a number is prime or not
   * @param num number to be checked
   * @return true if the 'num' is prime, else false. Numbers greater than
   *         the bound given at construction are reported as not prime
   *         (instead of reading past the end of the sieve).
   */
  bool isPrime(unsigned num) const
  {
    if (num <= 1) { return false; }
    if (num == 2) { return true; }
    if (!(num & 1)) { return false; }
    // Out-of-range queries previously indexed past the end of `seive`
    // (undefined behavior); treat them as "not prime" instead.
    if (num > N) { return false; }
    unsigned mask, pos;
    getMaskPos(num, mask, pos);
    return (seive[pos] & mask) != 0;
  }

 private:
  /** number of sieve bits held by each word of the bit vector */
  static constexpr unsigned BitsPerWord = sizeof(unsigned) * 8;

  void generateSeive()
  {
    auto sqN = fastIntSqrt(N);
    // ceil(N / BitsPerWord) words; one bit per tracked number
    auto size = (N + BitsPerWord - 1) / BitsPerWord;
    seive.resize(size);
    // assume all to be primes initially
    for (auto& itr : seive) {
      itr = 0xffffffffu;
    }
    unsigned cid  = 0;
    unsigned cnum = getNum(cid);
    while (cnum <= sqN) {
      // advance to the next candidate still marked as prime
      do {
        ++cid;
        cnum = getNum(cid);
        if (isPrime(cnum)) { break; }
      } while (cnum <= sqN);
      auto cnum2 = cnum << 1;
      // 'unmark' all the 'odd' multiples of the current prime (even
      // multiples are never stored in the sieve to begin with)
      for (unsigned i = 3, num = i * cnum; num <= N; i += 2, num += cnum2) {
        unmark(num);
      }
    }
  }

  /** bit id of a tracked number: 2 -> 0, odd 2k+1 -> k */
  unsigned getId(unsigned num) const { return (num >> 1); }

  /** inverse of getId(): 0 -> 2, k -> 2k+1 */
  unsigned getNum(unsigned id) const
  {
    if (id == 0) { return 2; }
    return ((id << 1) + 1);
  }

  /** word index (`pos`) and bit mask (`mask`) of the sieve entry for `num` */
  void getMaskPos(unsigned num, unsigned& mask, unsigned& pos) const
  {
    pos  = getId(num);
    mask = 1 << (pos & 0x1f);
    pos >>= 5;
  }

  /** clear the "is prime" bit of `num` */
  void unmark(unsigned num)
  {
    unsigned mask, pos;
    getMaskPos(num, mask, pos);
    seive[pos] &= ~mask;
  }

  // Bit-by-bit integer square root: returns floor(sqrt(val)).
  // REF: http://www.azillionmonkeys.com/qed/ulerysqroot.pdf
  unsigned fastIntSqrt(unsigned val)
  {
    unsigned g = 0;
    auto bshft = 15u, b = 1u << bshft;
    do {
      unsigned temp = ((g << 1) + b) << bshft--;
      if (val >= temp) {
        g += b;
        val -= temp;
      }
    } while (b >>= 1);
    return g;
  }

  /** find all primes till this number */
  unsigned N;
  /** the seive */
  std::vector<unsigned> seive;
};
}; // namespace common | ||
}; // namespace raft |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/* | ||
* Copyright (c) 2021-2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <raft/cuda_utils.cuh> | ||
#include <utility> // pair | ||
|
||
namespace raft { | ||
|
||
// TODO move to raft https://github.com/rapidsai/raft/issues/90
/**
 * @brief Query the compute capability of the currently active CUDA device.
 * @return pair of (major, minor) compute capability version numbers
 */
inline std::pair<int, int> getDeviceCapability()
{
  int dev = 0;
  RAFT_CUDA_TRY(cudaGetDevice(&dev));
  int cc_major = 0;
  int cc_minor = 0;
  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, dev));
  RAFT_CUDA_TRY(cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, dev));
  return {cc_major, cc_minor};
}
|
||
/** | ||
* @brief Batched warp-level sum reduction | ||
* | ||
* @tparam T data type | ||
* @tparam NThreads Number of threads in the warp doing independent reductions | ||
* | ||
* @param[in] val input value | ||
* @return for the first "group" of threads, the reduced value. All | ||
* others will contain unusable values! | ||
* | ||
* @note Why not cub? Because cub doesn't seem to allow working with arbitrary | ||
* number of warps in a block and also doesn't support this kind of | ||
* batched reduction operation | ||
* @note All threads in the warp must enter this function together | ||
* | ||
* @todo Expand this to support arbitrary reduction ops | ||
*/ | ||
template <typename T, int NThreads>
DI T batchedWarpReduce(T val)
{
  // Tree reduction across the warp: at step k each lane folds in the value
  // held NThreads * 2^k lanes away, so after log2(WarpSize / NThreads) steps
  // the first NThreads lanes carry the reduced sums (per the doc comment
  // above, only the first "group" of lanes holds usable results).
  // NOTE(review): assumes raft::shfl handles source lane indices
  // laneId() + i >= WarpSize benignly (e.g. wraps modulo the warp width) and
  // uses a full-warp sync mask internally — confirm in raft/cuda_utils.cuh.
#pragma unroll
  for (int i = NThreads; i < raft::WarpSize; i <<= 1) {
    val += raft::shfl(val, raft::laneId() + i);
  }
  return val;
}
|
||
/** | ||
* @brief 1-D block-level batched sum reduction | ||
* | ||
* @tparam T data type | ||
* @tparam NThreads Number of threads in the warp doing independent reductions | ||
* | ||
* @param val input value | ||
* @param smem shared memory region needed for storing intermediate results. It | ||
* must alteast be of size: `sizeof(T) * nWarps * NThreads` | ||
* @return for the first "group" of threads in the block, the reduced value. | ||
* All others will contain unusable values! | ||
* | ||
* @note Why not cub? Because cub doesn't seem to allow working with arbitrary | ||
* number of warps in a block and also doesn't support this kind of | ||
* batched reduction operation | ||
* @note All threads in the block must enter this function together | ||
* | ||
* @todo Expand this to support arbitrary reduction ops | ||
*/ | ||
template <typename T, int NThreads>
DI T batchedBlockReduce(T val, char* smem)
{
  // Scratch buffer for exchanging per-warp partials between rounds; the
  // caller guarantees it is at least sizeof(T) * nWarps * NThreads bytes.
  auto* sTemp = reinterpret_cast<T*>(smem);
  // number of independent NThreads-wide groups packed into one warp
  constexpr int nGroupsPerWarp = raft::WarpSize / NThreads;
  static_assert(raft::isPo2(nGroupsPerWarp), "nGroupsPerWarp must be a PO2!");
  // total number of groups in the block (ceil-div of block size by NThreads)
  const int nGroups = (blockDim.x + NThreads - 1) / NThreads;
  const int lid = raft::laneId();
  // this thread's slot within its group
  const int lgid = lid % NThreads;
  // index of the group this thread belongs to
  const int gid = threadIdx.x / NThreads;
  // shared-memory slot this thread writes its warp-reduced partial to
  // (one NThreads-wide slot per warp)
  const auto wrIdx = (gid / nGroupsPerWarp) * NThreads + lgid;
  // shared-memory slot this thread reads a partial back from next round
  const auto rdIdx = gid * NThreads + lgid;
  // Multi-round reduction: each round warp-reduces the surviving groups and
  // funnels the partials through shared memory until one group remains.
  for (int i = nGroups; i > 0;) {
    // participant count rounded up to a whole number of warps, so every
    // thread of a participating warp enters the warp reduction together
    auto iAligned = ((i + nGroupsPerWarp - 1) / nGroupsPerWarp) * nGroupsPerWarp;
    if (gid < iAligned) {
      val = batchedWarpReduce<T, NThreads>(val);
      // only the first group of lanes in each warp holds a usable result
      if (lid < NThreads) sTemp[wrIdx] = val;
    }
    // publish partials before any thread reads them back (all threads reach
    // this barrier — the branch above does not skip it)
    __syncthreads();
    i /= nGroupsPerWarp;
    // reload partials for the next round; retired groups contribute zero
    if (i > 0) { val = gid < i ? sTemp[rdIdx] : T(0); }
    // ensure all reads complete before sTemp is overwritten next round
    __syncthreads();
  }
  return val;
}
|
||
} // namespace raft |
Oops, something went wrong.