Skip to content

Commit

Permalink
Try to also guard against underflow
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed Aug 28, 2024
1 parent e0c98be commit 2af243f
Showing 1 changed file with 45 additions and 25 deletions.
70 changes: 45 additions & 25 deletions thrust/thrust/system/cuda/detail/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

#include <cstdint>
#include <stdexcept>
#include <string>

#if defined(THRUST_FORCE_32BIT_OFFSET_TYPE) && defined(THRUST_FORCE_64BIT_OFFSET_TYPE)
# error "Only THRUST_FORCE_32BIT_OFFSET_TYPE or THRUST_FORCE_64BIT_OFFSET_TYPE may be defined!"
Expand All @@ -51,38 +52,55 @@
status = call arguments; \
}

#define _THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(_CUDA_VSTD::is_signed, decltype(count))) \
{ \
if (count < 0) \
{ \
throw ::std::runtime_error("Invalid input range. Passed size is " + std::to_string(count)); \
} \
}

#define _THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(_CUDA_VSTD::is_signed, decltype(count1))) \
{ \
if (count1 < 0 || count2 < 0) \
{ \
throw ::std::runtime_error( \
"Invalid input ranges. Passed sizes are " + std::to_string(count1) + " and " + std::to_string(count2)); \
} \
}

#if defined(THRUST_FORCE_64BIT_OFFSET_TYPE)
//! @brief Always dispatches to 64 bit offset version of an algorithm
# define THRUST_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two counts
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int64_t, status, call, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two different call implementations
# define THRUST_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call_64, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int64_t, status, call_64, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but always dispatching to uint64_t. `count` must not be negative.
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
_THRUST_INDEX_TYPE_DISPATCH(std::uint64_t, status, call_64, count, arguments)

#elif defined(THRUST_FORCE_32BIT_OFFSET_TYPE)

//! @brief Ensures that the size of the input does not overflow the offset type
# define _THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(index_type, count) \
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(_CUDA_VSTD::is_signed, index_type)) \
{ \
if (count < 0) \
{ \
throw ::std::runtime_error("Invalid input range. Passed size is " + std::to_string(count)); \
} \
} \
if (static_cast<std::uint64_t>(count) \
> static_cast<std::uint64_t>(thrust::detail::integer_traits<index_type>::const_max)) \
{ \
Expand All @@ -95,48 +113,45 @@
}

//! @brief Ensures that the sizes of the inputs do not overflow the offset type, but two counts
# define _THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW2(index_type, count1, count2) \
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(_CUDA_VSTD::is_signed, index_type)) \
{ \
if (count1 < 0 || count2 < 0) \
{ \
throw ::std::runtime_error( \
"Invalid input ranges. Passed size are " + std::to_string(count1) + " and " + std::to_string(count2)); \
} \
} \
if (static_cast<std::uint64_t>(count1) + static_cast<std::uint64_t>(count2) \
> static_cast<std::uint64_t>(thrust::detail::integer_traits<index_type>::const_max)) \
{ \
throw ::std::runtime_error( \
"Combined input sizes exceed the maximum allowable value for " #index_type " (" \
+ std::to_string(thrust::detail::integer_traits<index_type>::const_max) \
+ "). " #index_type " was used because the macro THRUST_FORCE_32BIT_OFFSET_TYPE was defined. " \
"To handle larger input sizes, either remove this macro to dynamically dispatch " \
"between 32-bit and 64-bit index types, or define THRUST_FORCE_64BIT_OFFSET_TYPE."); \
# define _THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW2(index_type, count1, count2) \
if (static_cast<std::uint64_t>(count1) + static_cast<std::uint64_t>(count2) \
> static_cast<std::uint64_t>(thrust::detail::integer_traits<index_type>::const_max)) \
{ \
throw ::std::runtime_error( \
"Combined input sizes exceed the maximum allowable value for " #index_type " (" \
+ std::to_string(thrust::detail::integer_traits<index_type>::const_max) \
+ "). " #index_type " was used because the macro THRUST_FORCE_32BIT_OFFSET_TYPE was defined. " \
"To handle larger input sizes, either remove this macro to dynamically dispatch " \
"between 32-bit and 64-bit index types, or define THRUST_FORCE_64BIT_OFFSET_TYPE."); \
}

//! @brief Always dispatches to 32 bit offset version of an algorithm but throws if count would overflow
# define THRUST_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two counts
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW2(std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but with two different call implementations
# define THRUST_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call_32, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW2(std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call_32, count1, count2, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH but always dispatching to uint64_t. `count` must not be negative.
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_OVERFLOW(std::uint32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::uint32_t, status, call_32, count, arguments)

Expand All @@ -154,6 +169,7 @@
//! assumes that callables for both branches consist of the same tokens, and is intended to be used with Thrust-style
//! dispatch interfaces, that always deduce the size type from the arguments.
# define THRUST_INDEX_TYPE_DISPATCH(status, call, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT (std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call, count, arguments) \
else \
Expand All @@ -165,6 +181,7 @@
//!
//! This version of the macro supports providing two count variables, which is necessary for set algorithms.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT2 (std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call, count1, count2, arguments) \
else \
Expand All @@ -178,13 +195,15 @@
//!
//! See reduce_n_impl to see an example of how this is meant to be used.
# define THRUST_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT (std::int32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::int32_t, status, call_32, count, arguments) \
else \
_THRUST_INDEX_TYPE_DISPATCH(std::int64_t, status, call_64, count, arguments)

//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but uses two counts.
# define THRUST_DOUBLE_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count1, count2, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW2(count1, count2) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT2 (std::int32_t, count1, count2) \
_THRUST_INDEX_TYPE_DISPATCH2(std::int32_t, status, call_32, count1, count2, arguments) \
else \
Expand All @@ -193,6 +212,7 @@
//! Like \ref THRUST_INDEX_TYPE_DISPATCH2 but dispatching to uint32_t and uint64_t, respectively, depending on the
//! `count` argument. `count` must not be negative.
# define THRUST_UNSIGNED_INDEX_TYPE_DISPATCH2(status, call_32, call_64, count, arguments) \
_THRUST_INDEX_TYPE_DISPATCH_GUARD_UNDERFLOW(count) \
if _THRUST_INDEX_TYPE_DISPATCH_SELECT (std::uint32_t, count) \
_THRUST_INDEX_TYPE_DISPATCH(std::uint32_t, status, call_32, count, arguments) \
else \
Expand Down

0 comments on commit 2af243f

Please sign in to comment.