Skip to content

Commit

Permalink
Remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 11, 2022
1 parent 1d2863d commit 9861601
Showing 1 changed file with 0 additions and 147 deletions.
147 changes: 0 additions & 147 deletions cpp/include/raft/spatial/knn/detail/ivf_flat/radix_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,6 @@

namespace raft::spatial::knn::detail::ivf_flat {

/**
 * Upper bound on the number of bytes needed to pack a list of chunks into one
 * buffer when every chunk must start on a 256-byte boundary.
 *
 * Each entry of `sizes` is rounded up to a multiple of 256; an extra 255 bytes
 * are added on top so that the base pointer of the buffer can itself be
 * rounded up to the alignment without running out of space.
 *
 * @param sizes sizes (in bytes) of the individual chunks
 * @return total number of bytes to allocate
 */
inline size_t calc_aligned_size(const std::vector<size_t>& sizes)
{
  constexpr size_t kAlignment = 256;

  // Worst-case slack consumed by aligning the start of the buffer.
  size_t required = kAlignment - 1;
  for (size_t i = 0; i < sizes.size(); ++i) {
    // Round the chunk size up to the next multiple of the alignment.
    const size_t padded = (sizes[i] + kAlignment - 1) / kAlignment * kAlignment;
    required += padded;
  }
  return required;
}

/**
 * Carve a raw buffer into consecutive chunks that each start on a 256-byte
 * boundary.
 *
 * The base address `p` is first rounded up to a multiple of 256; each chunk
 * then occupies its size rounded up to a multiple of 256.
 *
 * @param p     start of the raw buffer (does not need to be aligned)
 * @param sizes sizes (in bytes) of the chunks, in the order they are laid out
 * @return one pointer per entry of `sizes`, pointing at the start of that chunk
 */
inline std::vector<void*> calc_aligned_pointers(const void* p, const std::vector<size_t>& sizes)
{
  constexpr size_t kAlignment = 256;

  // Round the base address up to the alignment.
  const size_t aligned_base = (reinterpret_cast<size_t>(p) + kAlignment - 1) / kAlignment * kAlignment;
  char* cursor              = reinterpret_cast<char*>(aligned_base);

  std::vector<void*> chunks(sizes.size());
  for (size_t i = 0; i < sizes.size(); ++i) {
    chunks[i] = cursor;
    // Advance by the padded chunk size so that the next chunk stays aligned.
    cursor += (sizes[i] + kAlignment - 1) / kAlignment * kAlignment;
  }
  return chunks;
}

// Compile-time block dimension.
// NOTE(review): presumably the default number of threads per block for the
// radix kernels; its use is not visible in this part of the file - confirm.
constexpr int BLOCK_DIM = 512;
// Number of input items attributed to each thread when sizing a launch grid:
// the grid's x-dimension is computed as ceil(len / (threads * ITEM_PER_THREAD)).
constexpr int ITEM_PER_THREAD = 32;

Expand Down Expand Up @@ -438,125 +410,6 @@ __global__ void radix_kernel(const T* in_buf,
}
}

/**
 * Write the top-k values (and their indices) of each batch row to the output,
 * given the k-th value that has already been stored in the per-row counter.
 *
 * The batch row is selected by `blockIdx.y`. For every processed element:
 *   - values strictly better than the k-th value (greater if `greater`, smaller
 *     otherwise) are appended from the front of the output row using the
 *     `out_cnt` counter;
 *   - values equal to the k-th value are placed from the back of the output row
 *     using the `out_back_cnt` counter, but only for the first
 *     `counters[batch_id].k` of them.
 *
 * NOTE(review): the traversal of the row is delegated to `vectorized_process`;
 * presumably it calls the functor once per element with (value, index) - confirm
 * against its definition.
 *
 * @param in       input values, `len` elements per batch row
 * @param len      number of elements in one input row
 * @param k        number of elements in one output row
 * @param counters per-row state; `kth_value` and `k` are read, `out_cnt` and
 *                 `out_back_cnt` are advanced atomically
 * @param out      output values, `k` elements per batch row
 * @param out_idx  output indices, `k` elements per batch row
 * @param greater  select the largest values when true, the smallest otherwise
 */
template <typename T, typename IdxT>
__global__ void final_filter(const T* in,
                             const IdxT len,
                             const IdxT k,
                             Counter<T, IdxT>* counters,
                             T* out,
                             IdxT* out_idx,
                             bool greater)
{
  const int batch_id           = blockIdx.y;
  const T kth_value            = counters[batch_id].kth_value;
  const IdxT needed_num_of_kth = counters[batch_id].k;
  IdxT& out_cnt                = counters[batch_id].out_cnt;
  IdxT& out_back_cnt           = counters[batch_id].out_back_cnt;

  // Widen the row index before multiplying: `batch_id * len` evaluated in a
  // 32-bit type overflows once batch_size * len exceeds the 32-bit range.
  const size_t row = static_cast<size_t>(batch_id);
  in               = in + row * len;
  out              = out + row * k;
  out_idx          = out_idx + row * k;

  auto f = [k, greater, kth_value, needed_num_of_kth, &out_cnt, &out_back_cnt, out, out_idx](
             T val, IdxT i) {
    if ((greater && val > kth_value) || (!greater && val < kth_value)) {
      // Strictly better than the k-th value: always part of the result.
      IdxT pos     = atomicAdd(&out_cnt, 1);
      out[pos]     = val;
      out_idx[pos] = i;
    } else if (val == kth_value) {
      // Ties with the k-th value fill the output from the back; surplus ties
      // beyond the needed amount are dropped.
      IdxT back_pos = atomicAdd(&out_back_cnt, 1);
      if (back_pos < needed_num_of_kth) {
        IdxT pos     = k - 1 - back_pos;
        out[pos]     = val;
        out_idx[pos] = i;
      }
    }
  };
  vectorized_process(in, len, f);
}

template <typename T, typename IdxT, int BITS_PER_PASS, int NUM_THREAD>
void radix_select_topk(void* buf,
size_t& buf_size,
const T* in,
IdxT batch_size,
IdxT len,
IdxT k,
T* out,
IdxT* out_idx,
bool greater,
cudaStream_t stream)
{
// TODO: is it possible to relax this restriction?
static_assert(calc_num_passes<T, BITS_PER_PASS>() > 1);
constexpr int num_buckets = calc_num_buckets<BITS_PER_PASS>();

Counter<T, IdxT>* counters = nullptr;
IdxT* histograms = nullptr;
T* buf1 = nullptr;
T* buf2 = nullptr;
{
std::vector<size_t> sizes = {sizeof(*counters) * batch_size,
sizeof(*histograms) * num_buckets * batch_size,
sizeof(*buf1) * len * batch_size,
sizeof(*buf2) * len * batch_size};
size_t total_size = calc_aligned_size(sizes);
if (!buf) {
buf_size = total_size;
return;
}

std::vector<void*> aligned_pointers = calc_aligned_pointers(buf, sizes);
counters = static_cast<decltype(counters)>(aligned_pointers[0]);
histograms = static_cast<decltype(histograms)>(aligned_pointers[1]);
buf1 = static_cast<decltype(buf1)>(aligned_pointers[2]);
buf2 = static_cast<decltype(buf2)>(aligned_pointers[3]);

RAFT_CUDA_TRY(cudaMemsetAsync(
buf,
0,
static_cast<char*>(aligned_pointers[2]) - static_cast<char*>(aligned_pointers[0]),
stream));
}

const T* in_buf = nullptr;
T* out_buf = nullptr;

dim3 blocks((len - 1) / (NUM_THREAD * ITEM_PER_THREAD) + 1, batch_size);

constexpr int num_passes = calc_num_passes<T, BITS_PER_PASS>();
for (int pass = 0; pass < num_passes; ++pass) {
if (pass == 0) {
in_buf = in;
out_buf = nullptr;
} else if (pass == 1) {
in_buf = in;
out_buf = buf1;
} else {
in_buf = (pass % 2 == 0) ? buf1 : buf2;
out_buf = (pass % 2 == 0) ? buf2 : buf1;
}
radix_kernel<T, IdxT, BITS_PER_PASS, NUM_THREAD><<<blocks, NUM_THREAD, 0, stream>>>(in_buf,
nullptr,
out_buf,
nullptr,
nullptr,
nullptr,
counters,
histograms,
len,
k,
greater,
pass);
}

constexpr int FILTER_BLOCK_DIM = 256;
constexpr int FILTER_ITEM_PER_THREAD = 32;
dim3 filter_blocks((len - 1) / (FILTER_BLOCK_DIM * FILTER_ITEM_PER_THREAD) + 1, batch_size);
final_filter<<<filter_blocks, FILTER_BLOCK_DIM, 0, stream>>>(
in, len, k, counters, out, out_idx, greater);
}

template <typename T, typename IdxT, int BITS_PER_PASS, int NUM_THREAD>
void radix_topk(const T* in,
const IdxT* in_idx,
Expand Down

0 comments on commit 9861601

Please sign in to comment.