Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge pull request #1507 from bjude/sequence_no_size_t_conversion
Browse files Browse the repository at this point in the history
Sequence: Specialise compute_sequence_values for builtin arithmetic types
  • Loading branch information
alliepiper authored Oct 6, 2021
2 parents 04f3895 + 1782f28 commit e99e10b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
33 changes: 33 additions & 0 deletions testing/sequence.cu
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,36 @@ void TestSequenceComplex()
thrust::sequence(m.begin(), m.end());
}
DECLARE_UNITTEST(TestSequenceComplex);

// A class that doesnt accept conversion from size_t but can be multiplied by a scalar
struct Vector
{
Vector() = default;
// Explicitly disable construction from size_t
Vector(std::size_t) = delete;
__host__ __device__ Vector(int x_, int y_) : x{x_}, y{y_} {}
Vector(const Vector&) = default;
Vector &operator=(const Vector&) = default;

int x, y;
};

// Vector-Vector addition
__host__ __device__ Vector operator+(const Vector a, const Vector b) { return Vector{a.x + b.x, a.y + b.y}; }
// Vector-Scalar Multiplication
__host__ __device__ Vector operator*(const int a, const Vector b) { return Vector{a * b.x, a * b.y}; }
__host__ __device__ Vector operator*(const Vector b, const int a) { return Vector{a * b.x, a * b.y}; }

void TestSequenceNoSizeTConversion()
{
thrust::device_vector<Vector> m(64);
thrust::sequence(m.begin(), m.end(), ::Vector{0, 0}, ::Vector{1, 2});

for (std::size_t i = 0; i < m.size(); ++i)
{
const ::Vector v = m[i];
ASSERT_EQUAL(static_cast<std::size_t>(v.x), i);
ASSERT_EQUAL(static_cast<std::size_t>(v.y), 2 * i);
}
}
DECLARE_UNITTEST(TestSequenceNoSizeTConversion);
15 changes: 14 additions & 1 deletion thrust/system/detail/generic/sequence.inl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,25 @@ __host__ __device__

namespace detail
{
template <typename T>
template <typename T, typename = void>
struct compute_sequence_value
{
T init;
T step;

__thrust_exec_check_disable__
__host__ __device__
T operator()(std::size_t i) const
{
return init + step * i;
}
};
template <typename T>
struct compute_sequence_value<T, typename std::enable_if<std::is_arithmetic<T>::value>::type>
{
T init;
T step;

__thrust_exec_check_disable__
__host__ __device__
T operator()(std::size_t i) const
Expand Down

0 comments on commit e99e10b

Please sign in to comment.