Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use thrust functions instead of custom functions #5544

Merged
merged 1 commit into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 3 additions & 158 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
#include <thrust/system_error.h>
#include <thrust/logical.h>
#include <thrust/gather.h>
#include <thrust/binary_search.h>

#include <omp.h>
#include <rabit/rabit.h>
#include <cub/cub.cuh>
#include <cub/util_allocator.cuh>

#include <algorithm>
#include <chrono>
#include <ctime>
#include <numeric>
#include <sstream>
#include <string>
Expand All @@ -28,7 +27,6 @@
#include "xgboost/span.h"

#include "common.h"
#include "timer.h"

#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
Expand Down Expand Up @@ -132,94 +130,6 @@ DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, un
static_cast<unsigned int>(b) << (ibyte % (sizeof(unsigned int)) * 8));
}

namespace internal {

// Items of size 'n' are sorted in an order determined by the Comparator
// If left is true, find the number of elements where 'comp(item, v)' returns true;
// 0 if nothing is true
// If left is false, find the number of elements where '!comp(item, v)' returns true;
// 0 if nothing is true
template <typename T, typename Comparator = thrust::greater<T>>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsImpl(bool left, const T * __restrict__ items, uint32_t n, T v,
const Comparator &comp = Comparator()) {
const T *items_begin = items;
uint32_t num_remaining = n;
const T *middle_item = nullptr;
uint32_t middle;
while (num_remaining > 0) {
middle_item = items_begin;
middle = num_remaining / 2;
middle_item += middle;
if ((left && comp(*middle_item, v)) || (!left && !comp(v, *middle_item))) {
items_begin = ++middle_item;
num_remaining -= middle + 1;
} else {
num_remaining = middle;
}
}

return left ? items_begin - items : items + n - items_begin;
}

} // namespace internal

/*!
* \brief Find the strict upper bound for an element in a sorted array
* using binary search.
* \param items pointer to the first element of the sorted array
* \param n length of the sorted array
* \param v value for which to find the upper bound
* \param comp determines how the items are sorted ascending/descending order - should conform
* to ordering semantics
* \return the smallest index i that has a value > v, or n if none is larger when sorted ascendingly
* or, an index i with a value < v, or 0 if none is smaller when sorted descendingly
*/
// Preserve existing default behavior of upper bound
template <typename T, typename Comp = thrust::less<T>>
XGBOOST_DEVICE __forceinline__ uint32_t UpperBound(const T *__restrict__ items,
uint32_t n,
T v,
const Comp &comp = Comp()) {
if (std::is_same<Comp, thrust::less<T>>::value ||
std::is_same<Comp, thrust::greater<T>>::value) {
return n - internal::CountNumItemsImpl(false, items, n, v, comp);
} else {
static_assert(std::is_same<Comp, thrust::less<T>>::value ||
std::is_same<Comp, thrust::greater<T>>::value,
"Invalid comparator used in Upperbound - can only be thrust::greater/less");
return std::numeric_limits<uint32_t>::max(); // Simply to quiesce the compiler
}
}

/*!
* \brief Find the strict lower bound for an element in a sorted array
* using binary search.
* \param items pointer to the first element of the sorted array
* \param n length of the sorted array
* \param v value for which to find the upper bound
* \param comp determines how the items are sorted ascending/descending order - should conform
* to ordering semantics
* \return the smallest index i that has a value >= v, or n if none is larger
* when sorted ascendingly
* or, an index i with a value <= v, or 0 if none is smaller when sorted descendingly
*/
template <typename T, typename Comp = thrust::less<T>>
XGBOOST_DEVICE __forceinline__ uint32_t LowerBound(const T *__restrict__ items,
uint32_t n,
T v,
const Comp &comp = Comp()) {
if (std::is_same<Comp, thrust::less<T>>::value ||
std::is_same<Comp, thrust::greater<T>>::value) {
return internal::CountNumItemsImpl(true, items, n, v, comp);
} else {
static_assert(std::is_same<Comp, thrust::less<T>>::value ||
std::is_same<Comp, thrust::greater<T>>::value,
"Invalid comparator used in LowerBound - can only be thrust::greater/less");
return std::numeric_limits<uint32_t>::max(); // Simply to quiesce the compiler
}
}

template <typename T>
__device__ xgboost::common::Range GridStrideRange(T begin, T end) {
begin += blockDim.x * blockIdx.x + threadIdx.x;
Expand Down Expand Up @@ -878,7 +788,8 @@ class SegmentSorter {
const uint32_t *dgroups = dgroups_.data().get();
uint32_t ngroups = dgroups_.size();
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
return dh::UpperBound(dgroups, ngroups, idx) - 1;
return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) -
dgroups - 1;
}; // NOLINT

thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
Expand Down Expand Up @@ -1018,70 +929,4 @@ thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}

template <typename FunctionT>
class LauncherItr {
public:
int idx { 0 };
FunctionT f;
XGBOOST_DEVICE LauncherItr() : idx(0) {} // NOLINT
XGBOOST_DEVICE LauncherItr(int idx, FunctionT f) : idx(idx), f(f) {}
XGBOOST_DEVICE LauncherItr &operator=(int output) {
f(idx, output);
return *this;
}
};

/**
* \brief Thrust compatible iterator type - discards algorithm output and launches device lambda
* with the index of the output and the algorithm output as arguments.
*
* \author Rory
* \date 7/9/2017
*
* \tparam FunctionT Type of the function t.
*/
template <typename FunctionT>
class DiscardLambdaItr {
public:
// Required iterator traits
using self_type = DiscardLambdaItr; // NOLINT
using difference_type = ptrdiff_t; // NOLINT
using value_type = void; // NOLINT
using pointer = value_type *; // NOLINT
using reference = LauncherItr<FunctionT>; // NOLINT
using iterator_category = typename thrust::detail::iterator_facade_category< // NOLINT
thrust::any_system_tag, thrust::random_access_traversal_tag, value_type,
reference>::type; // NOLINT
private:
difference_type offset_;
FunctionT f_;
public:
XGBOOST_DEVICE explicit DiscardLambdaItr(FunctionT f) : offset_(0), f_(f) {}
XGBOOST_DEVICE DiscardLambdaItr(difference_type offset, FunctionT f)
: offset_(offset), f_(f) {}
XGBOOST_DEVICE self_type operator+(const int &b) const {
return DiscardLambdaItr(offset_ + b, f_);
}
XGBOOST_DEVICE self_type operator++() {
offset_++;
return *this;
}
XGBOOST_DEVICE self_type operator++(int) {
self_type retval = *this;
offset_++;
return retval;
}
XGBOOST_DEVICE self_type &operator+=(const int &b) {
offset_ += b;
return *this;
}
XGBOOST_DEVICE reference operator*() const {
return LauncherItr<FunctionT>(offset_, f_);
}
XGBOOST_DEVICE reference operator[](int idx) {
self_type offset = (*this) + idx;
return *offset;
}
};

} // namespace dh
2 changes: 1 addition & 1 deletion src/data/device_adapter.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo {
size_t Size() const { return num_elements_; }
__device__ COOTuple GetElement(size_t idx) const {
size_t column_idx =
dh::UpperBound(column_ptr_.data(), column_ptr_.size(), idx) - 1;
thrust::upper_bound(thrust::seq,column_ptr_.begin(), column_ptr_.end(), idx) - column_ptr_.begin() - 1;
auto& column = columns_[column_idx];
size_t row_idx = idx - column_ptr_[column_idx];
float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
Expand Down
4 changes: 3 additions & 1 deletion src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ __global__ void CompressBinEllpackKernel(
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
// Assigning the bin in current entry.
// S.t.: fvalue < feature_cuts[bin]
bin = dh::UpperBound(feature_cuts, ncuts, fvalue);
bin = thrust::upper_bound(thrust::seq, feature_cuts, feature_cuts + ncuts,
fvalue) -
feature_cuts;
if (bin >= ncuts) {
bin = ncuts - 1;
}
Expand Down
17 changes: 12 additions & 5 deletions src/objective/rank_obj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,18 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {

template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheLeftOf(const T * __restrict__ items, uint32_t n, T v) {
return dh::LowerBound(items, n, v, thrust::greater<T>());
CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) {
return thrust::lower_bound(thrust::seq, items, items + n, v,
thrust::greater<T>()) -
items;
}

template <typename T>
XGBOOST_DEVICE __forceinline__ uint32_t
CountNumItemsToTheRightOf(const T * __restrict__ items, uint32_t n, T v) {
return n - dh::UpperBound(items, n, v, thrust::greater<T>());
CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) {
return n - (thrust::upper_bound(thrust::seq, items, items + n, v,
thrust::greater<T>()) -
items);
}
#endif

Expand Down Expand Up @@ -671,7 +675,10 @@ class SortedLabelList : dh::SegmentSorter<float> {
dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) {
// First, determine the group 'idx' belongs to
uint32_t item_idx = idx % total_items;
uint32_t group_idx = dh::UpperBound(dgroups.data(), ngroups, item_idx);
uint32_t group_idx =
thrust::upper_bound(thrust::seq, dgroups.begin(),
dgroups.begin() + ngroups, item_idx) -
dgroups.begin();
// Span of this group within the larger labels/predictions sorted tuple
uint32_t group_begin = dgroups[group_idx - 1];
uint32_t group_end = dgroups[group_idx];
Expand Down
108 changes: 63 additions & 45 deletions src/tree/gpu_hist/row_partitioner.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
/*!
* Copyright 2017-2019 XGBoost contributors
*/
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/sequence.h>
#include <vector>
#include "../../common/device_helpers.cuh"
Expand All @@ -11,58 +13,74 @@ namespace tree {

struct IndicateLeftTransform {
bst_node_t left_nidx;
explicit IndicateLeftTransform(bst_node_t left_nidx)
: left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ int operator()(const bst_node_t& x) const {
explicit IndicateLeftTransform(bst_node_t left_nidx) : left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ size_t
operator()(const bst_node_t& x) const {
return x == left_nidx ? 1 : 0;
}
};
/*
* position: Position of rows belonged to current split node.
*/
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
common::Span<bst_node_t> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out,
bst_node_t left_nidx,
bst_node_t right_nidx,
int64_t* d_left_count, cudaStream_t stream) {
// radix sort over 1 bit, see:
// https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html
auto d_position_out = position_out.data();
auto d_position_in = position.data();
auto d_ridx_out = ridx_out.data();
auto d_ridx_in = ridx.data();
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
// the ex_scan_result represents how many rows have been assigned to left node so far
// during scan.

struct IndexFlagTuple {
size_t idx;
size_t flag;
};

struct IndexFlagOp {
__device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
const IndexFlagTuple& b) const {
return {b.idx, a.flag + b.flag};
}
};

struct WriteResultsFunctor {
bst_node_t left_nidx;
common::Span<bst_node_t> position_in;
common::Span<bst_node_t> position_out;
common::Span<RowPartitioner::RowIndexT> ridx_in;
common::Span<RowPartitioner::RowIndexT> ridx_out;
int64_t* d_left_count;

__device__ int operator()(const IndexFlagTuple& x) {
// the ex_scan_result represents how many rows have been assigned to left
// node so far during scan.
int scatter_address;
if (d_position_in[idx] == left_nidx) {
scatter_address = ex_scan_result;
if (position_in[x.idx] == left_nidx) {
scatter_address = x.flag - 1; // -1 because inclusive scan
} else {
// current number of rows belong to right node + total number of rows belong to left
// node
scatter_address = (idx - ex_scan_result) + *d_left_count;
// current number of rows belong to right node + total number of rows
// belong to left node
scatter_address = (x.idx - x.flag) + *d_left_count;
}
// copy the node id to output
d_position_out[scatter_address] = d_position_in[idx];
d_ridx_out[scatter_address] = d_ridx_in[idx];
}; // NOLINT
position_out[scatter_address] = position_in[x.idx];
ridx_out[scatter_address] = ridx_in[x.idx];

// Discard
return 0;
}
};

IndicateLeftTransform is_left(left_nidx);
// an iterator that given a old position returns whether it belongs to left or right
// node.
cub::TransformInputIterator<bst_node_t, IndicateLeftTransform,
bst_node_t*>
in_itr(d_position_in, is_left);
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
size_t temp_storage_bytes = 0;
// position is of the same size with current split node's row segment
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
position.size(), stream);
dh::caching_device_vector<uint8_t> temp_storage(temp_storage_bytes);
cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes,
in_itr, out_itr, position.size(), stream);
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
common::Span<bst_node_t> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out,
bst_node_t left_nidx, bst_node_t right_nidx,
int64_t* d_left_count, cudaStream_t stream) {
WriteResultsFunctor write_results{left_nidx, position, position_out,
ridx, ridx_out, d_left_count};
auto discard_write_iterator = thrust::make_transform_output_iterator(
thrust::discard_iterator<int>(), write_results);
auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return IndexFlagTuple{idx, position[idx] == left_nidx};
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
input_iterator + position.size(),
discard_write_iterator,
[=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
return IndexFlagTuple{b.idx, a.flag + b.flag};
});
}

RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
Expand Down Expand Up @@ -137,7 +155,7 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
SortPosition(
// position_in
common::Span<bst_node_t>(position_.Current() + segment.begin,
segment.Size()),
segment.Size()),
// position_out
common::Span<bst_node_t>(position_.Other() + segment.begin,
segment.Size()),
Expand Down
Loading