Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Allow segmented problems to have different types for offset iterators. #291

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 58 additions & 42 deletions cub/device/device_segmented_radix_sort.cuh

Large diffs are not rendered by default.

180 changes: 96 additions & 84 deletions cub/device/device_segmented_reduce.cuh

Large diffs are not rendered by default.

26 changes: 14 additions & 12 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ template <
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
typename KeyT, ///< Key type
typename ValueT, ///< Value type
typename OffsetIteratorT, ///< Random-access input iterator type for reading segment offsets \iterator
typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator
typename EndOffsetIteratorT, ///< Random-access input iterator type for reading segment ending offsets \iterator
typename OffsetT> ///< Signed integer type for global offsets
__launch_bounds__ (int((ALT_DIGIT_BITS) ?
ChainedPolicyT::ActivePolicy::AltSegmentedPolicy::BLOCK_THREADS :
Expand All @@ -366,8 +367,8 @@ __global__ void DeviceSegmentedRadixSortKernel(
KeyT *d_keys_out, ///< [in] Output keys buffer
const ValueT *d_values_in, ///< [in] Input values buffer
ValueT *d_values_out, ///< [in] Output values buffer
OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>. If <tt>d_end_offsets[i]-1</tt> <= <tt>d_begin_offsets[i]</tt>, the <em>i</em><sup>th</sup> is considered empty.
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>. If <tt>d_end_offsets[i]-1</tt> <= <tt>d_begin_offsets[i]</tt>, the <em>i</em><sup>th</sup> is considered empty.
int /*num_segments*/, ///< [in] The number of segments that comprise the sorting data
int current_bit, ///< [in] Bit position of current radix digit
int pass_bits) ///< [in] Number of bits of current radix digit
Expand Down Expand Up @@ -1627,7 +1628,8 @@ template <
bool IS_DESCENDING, ///< Whether or not the sorted-order is high-to-low
typename KeyT, ///< Key type
typename ValueT, ///< Value type
typename OffsetIteratorT, ///< Random-access input iterator type for reading segment offsets \iterator
typename BeginOffsetIteratorT, ///< Random-access input iterator type for reading segment beginning offsets \iterator
typename EndOffsetIteratorT, ///< Random-access input iterator type for reading segment ending offsets \iterator
typename OffsetT, ///< Signed integer type for global offsets
typename SelectedPolicy = DeviceRadixSortPolicy<KeyT, ValueT, OffsetT> >
struct DispatchSegmentedRadixSort :
Expand All @@ -1654,8 +1656,8 @@ struct DispatchSegmentedRadixSort :
DoubleBuffer<ValueT> &d_values; ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values
OffsetT num_items; ///< [in] Number of items to sort
OffsetT num_segments; ///< [in] The number of segments that comprise the sorting data
OffsetIteratorT d_begin_offsets; ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
OffsetIteratorT d_end_offsets; ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>. If <tt>d_end_offsets[i]-1</tt> <= <tt>d_begin_offsets[i]</tt>, the <em>i</em><sup>th</sup> is considered empty.
BeginOffsetIteratorT d_begin_offsets; ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
EndOffsetIteratorT d_end_offsets; ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>. If <tt>d_end_offsets[i]-1</tt> <= <tt>d_begin_offsets[i]</tt>, the <em>i</em><sup>th</sup> is considered empty.
int begin_bit; ///< [in] The beginning (least-significant) bit index needed for key comparison
int end_bit; ///< [in] The past-the-end (most-significant) bit index needed for key comparison
cudaStream_t stream; ///< [in] CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
Expand All @@ -1677,8 +1679,8 @@ struct DispatchSegmentedRadixSort :
DoubleBuffer<ValueT> &d_values,
OffsetT num_items,
OffsetT num_segments,
OffsetIteratorT d_begin_offsets,
OffsetIteratorT d_end_offsets,
BeginOffsetIteratorT d_begin_offsets,
EndOffsetIteratorT d_end_offsets,
int begin_bit,
int end_bit,
bool is_overwrite_okay,
Expand Down Expand Up @@ -1898,8 +1900,8 @@ struct DispatchSegmentedRadixSort :

// Force kernel code-generation in all compiler passes
return InvokePasses<ActivePolicyT>(
DeviceSegmentedRadixSortKernel<MaxPolicyT, false, IS_DESCENDING, KeyT, ValueT, OffsetIteratorT, OffsetT>,
DeviceSegmentedRadixSortKernel<MaxPolicyT, true, IS_DESCENDING, KeyT, ValueT, OffsetIteratorT, OffsetT>);
DeviceSegmentedRadixSortKernel<MaxPolicyT, false, IS_DESCENDING, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>,
DeviceSegmentedRadixSortKernel<MaxPolicyT, true, IS_DESCENDING, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT, OffsetT>);
}


Expand All @@ -1917,8 +1919,8 @@ struct DispatchSegmentedRadixSort :
DoubleBuffer<ValueT> &d_values, ///< [in,out] Double-buffer whose current buffer contains the unsorted input values and, upon return, is updated to point to the sorted output values
int num_items, ///< [in] Number of items to sort
int num_segments, ///< [in] The number of segments that comprise the sorting data
OffsetIteratorT d_begin_offsets, ///< [in] Pointer to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
OffsetIteratorT d_end_offsets, ///< [in] Pointer to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>. If <tt>d_end_offsets[i]-1</tt> <= <tt>d_begin_offsets[i]</tt>, the <em>i</em><sup>th</sup> is considered empty.
BeginOffsetIteratorT d_begin_offsets, ///< [in] Random-access input iterator to the sequence of beginning offsets of length \p num_segments, such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
EndOffsetIteratorT d_end_offsets, ///< [in] Random-access input iterator to the sequence of ending offsets of length \p num_segments, such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>. If <tt>d_end_offsets[i]-1</tt> <= <tt>d_begin_offsets[i]</tt>, the <em>i</em><sup>th</sup> is considered empty.
int begin_bit, ///< [in] The beginning (least-significant) bit index needed for key comparison
int end_bit, ///< [in] The past-the-end (most-significant) bit index needed for key comparison
bool is_overwrite_okay, ///< [in] Whether is okay to overwrite source buffers
Expand Down
Loading