Skip to content

Commit

Permalink
Add generic reduction functions and separate reductions/warp_primitiv…
Browse files Browse the repository at this point in the history
…es (#1470)

This PR adds bunch of new device reduction functions such as: 

- Generic device reductions that takes reduction operator as argument.
- Ranked reductions to return the index/rank of the reduced value.
- Weighted random reduction to have probabilistic reduction using conditional probability.
- Binary reduction to reduce binary values more efficiently.

There are tests implemented for all device reduction operations.

This PR also separates warp primitives to the `warp_primitives.cuh`.
All reduction functions are moved to `reduction.cuh`

Authors:
  - Akif ÇÖRDÜK (https://github.com/akifcorduk)
  - Corey J. Nolet (https://github.com/cjnolet)
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #1470
  • Loading branch information
akifcorduk authored May 18, 2023
1 parent 6fdb041 commit 8e412b4
Show file tree
Hide file tree
Showing 8 changed files with 841 additions and 312 deletions.
104 changes: 104 additions & 0 deletions cpp/include/raft/random/device/sample.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <stdint.h>

#include <raft/core/cudart_utils.hpp>
#include <raft/core/operators.hpp>
#include <raft/util/cuda_dev_essentials.cuh>
#include <raft/util/warp_primitives.cuh>

namespace raft::random::device {

/**
* @brief warp-level random sampling of an index.
* It selects an index with the given discrete probability
* distribution(represented by weights of each index)
* @param rng random number generator, must have next_u32() function
* @param weight weight of the rank/index.
* @param idx index to be used as rank
* @return only the thread0 will contain valid reduced result
*/
template <typename T, typename rng_t, typename i_t = int>
DI T warp_random_sample(rng_t& rng, T& weight, i_t& idx)
{
// Todo(#1491): benchmark whether a scan and then selecting within the ranges is more efficient.
static_assert(std::is_integral<T>::value, "The type T must be an integral type.");
#pragma unroll
for (i_t offset = raft::WarpSize / 2; offset > 0; offset /= 2) {
T tmp_weight = shfl(weight, laneId() + offset);
i_t tmp_idx = shfl(idx, laneId() + offset);
T sum = (tmp_weight + weight);
weight = sum;
if (sum != 0) {
i_t rnd_number = (rng.next_u32() % sum);
if (rnd_number < tmp_weight) { idx = tmp_idx; }
}
}
}

/**
* @brief 1-D block-level random sampling of an index.
* It selects an index with the given discrete probability
* distribution(represented by weights of each index)
*
* Let w_i be the weight stored on thread i. We calculate the cumulative distribution function
* F_i = sum_{k=0..i} weight_i.
* Sequentially, we could select one of the elements with with the desired probability using the
* following method. We can consider that each element has a subinterval assigned: [F_{i-1}, F_i).
* We generate a uniform random number in the [0, F_i) range, and check which subinterval it falls.
* We return idx corresponding to the selected subinterval.
* In parallel, we do a tree reduction and make a selection at every step when we combine two
* values.
* @param rng random number generator, must have next_u32() function
* @param shbuf shared memory region needed for storing intermediate results. It
* must alteast be of size: `(sizeof(T) + sizeof(i_t)) * WarpSize`
* @param weight weight of the rank/index.
* @param idx index to be used as rank
* @return only the thread0 will contain valid reduced result
*/
template <typename T, typename rng_t, typename i_t = int>
DI i_t block_random_sample(rng_t rng, T* shbuf, T weight = 1, i_t idx = threadIdx.x)
{
T* values = shbuf;
i_t* indices = (i_t*)&shbuf[WarpSize];
i_t wid = threadIdx.x / WarpSize;
i_t nWarps = (blockDim.x + WarpSize - 1) / WarpSize;
warp_random_sample(rng, weight, idx); // Each warp performs partial reduction
i_t lane = laneId();
if (lane == 0) {
values[wid] = weight; // Write reduced value to shared memory
indices[wid] = idx; // Write reduced value to shared memory
}

__syncthreads(); // Wait for all partial reductions

// read from shared memory only if that warp existed
if (lane < nWarps) {
weight = values[lane];
idx = indices[lane];
} else {
weight = 0;
idx = -1;
}
__syncthreads();
if (wid == 0) warp_random_sample(rng, weight, idx);
return idx;
}

} // namespace raft::random::device
Loading

0 comments on commit 8e412b4

Please sign in to comment.