diff --git a/internal/benchmark/bench.cu b/internal/benchmark/bench.cu index e73a0d5bd..38d1d647a 100644 --- a/internal/benchmark/bench.cu +++ b/internal/benchmark/bench.cu @@ -992,7 +992,6 @@ void run_core_primitives_experiments_for_type() , RegularTrials >::run_experiment(); -#if THRUST_CPP_DIALECT >= 2011 experiment_driver< shuffle_tester , ElementMetaType @@ -1000,7 +999,6 @@ void run_core_primitives_experiments_for_type() , BaselineTrials , RegularTrials >::run_experiment(); -#endif } /////////////////////////////////////////////////////////////////////////////// diff --git a/testing/shuffle.cu b/testing/shuffle.cu index a5b1c6f29..77e660c00 100644 --- a/testing/shuffle.cu +++ b/testing/shuffle.cu @@ -1,6 +1,5 @@ #include -#if THRUST_CPP_DIALECT >= 2011 #include #include #include @@ -383,7 +382,7 @@ void TestFunctionIsBijection(size_t m) { thrust::system::detail::generic::feistel_bijection host_f(m, host_g); thrust::system::detail::generic::feistel_bijection device_f(m, device_g); - if (host_f.nearest_power_of_two() >= std::numeric_limits::max() || m == 0) { + if (static_cast(host_f.nearest_power_of_two()) >= static_cast(std::numeric_limits::max()) || m == 0) { return; } @@ -410,17 +409,17 @@ DECLARE_VARIABLE_UNITTEST(TestFunctionIsBijection); void TestBijectionLength() { thrust::default_random_engine g(0xD5); - uint64_t m = 3; + uint64_t m = 31; thrust::system::detail::generic::feistel_bijection f(m, g); - ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(4)); + ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32)); - m = 2; + m = 32; f = thrust::system::detail::generic::feistel_bijection(m, g); - ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(2)); + ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32)); - m = 0; + m = 1; f = thrust::system::detail::generic::feistel_bijection(m, g); - ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(1)); + ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(16)); } DECLARE_UNITTEST(TestBijectionLength); @@ -515,7 +514,7 @@ void TestShuffleEvenSpacingBetweenOccurances() { thrust::host_vector h_results; Vector sequence(shuffle_size); thrust::sequence(sequence.begin(), sequence.end(), 0); - thrust::default_random_engine g(0xD5); + thrust::default_random_engine g(0xD6); for (auto i = 0ull; i < num_samples; i++) { thrust::shuffle(sequence.begin(), sequence.end(), g); thrust::host_vector tmp(sequence.begin(), sequence.end()); @@ -561,7 +560,7 @@ void TestShuffleEvenDistribution() { const uint64_t shuffle_sizes[] = {10, 100, 500}; thrust::default_random_engine g(0xD5); for (auto shuffle_size : shuffle_sizes) { - if(shuffle_size > std::numeric_limits::max()) + if(shuffle_size > (uint64_t)std::numeric_limits::max()) continue; const uint64_t num_samples = shuffle_size == 500 ? 1000 : 200; @@ -601,4 +600,3 @@ void TestShuffleEvenDistribution() { } } DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleEvenDistribution); -#endif diff --git a/thrust/system/detail/generic/shuffle.inl b/thrust/system/detail/generic/shuffle.inl index 91b77351d..baece51be 100644 --- a/thrust/system/detail/generic/shuffle.inl +++ b/thrust/system/detail/generic/shuffle.inl @@ -48,7 +48,7 @@ class feistel_bijection { right_side_bits = total_bits - left_side_bits; right_side_mask = (1ull << right_side_bits) - 1; - for (std::uint64_t i = 0; i < num_rounds; i++) { + for (std::uint32_t i = 0; i < num_rounds; i++) { key[i] = g(); } } @@ -56,28 +56,34 @@ class feistel_bijection { __host__ __device__ std::uint64_t nearest_power_of_two() const { return 1ull << (left_side_bits + right_side_bits); } - __host__ __device__ std::uint64_t operator()(const std::uint64_t val) const { - // Extract the right and left sides of the input - auto left = static_cast(val >> right_side_bits); - auto right = static_cast(val & right_side_mask); - round_state state = {left, right}; - for (std::uint64_t i = 0; i < num_rounds; i++) { - state = do_round(state, i); + __host__ __device__ std::uint64_t operator()(const std::uint64_t val) const { + std::uint32_t state[2] = { static_cast( val >> right_side_bits ), static_cast( val & right_side_mask ) }; + for( std::uint32_t i = 0; i < num_rounds; i++ ) + { + std::uint32_t hi, lo; + constexpr std::uint64_t M0 = UINT64_C( 0xD2B74407B1CE6E93 ); + mulhilo( M0, state[0], hi, lo ); + lo = ( lo << ( right_side_bits - left_side_bits ) ) | state[1] >> left_side_bits; + state[0] = ( ( hi ^ key[i] ) ^ state[1] ) & left_side_mask; + state[1] = lo & right_side_mask; } - - // Check we have the correct number of bits on each side - assert((state.left >> left_side_bits) == 0); - assert((state.right >> right_side_bits) == 0); - // Combine the left and right sides together to get result - return state.left << right_side_bits | state.right; + return static_cast(state[0] << right_side_bits) | static_cast(state[1]); } private: + // Perform 64 bit multiplication and save result in two 32 bit int + static __host__ __device__ void mulhilo( std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo ) + { + std::uint64_t product = a * b; + hi = static_cast( product >> 32 ); + lo = static_cast( product ); + } + // Find the nearest power of two - __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { - if (m == 0) return 0; + static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { + if (m <= 16) return 4; std::uint64_t i = 0; m--; while (m != 0) { @@ -87,45 +93,12 @@ class feistel_bijection { return i; } - // Equivalent to boost::hash_combine - __host__ __device__ - std::size_t hash_combine(std::uint64_t lhs, std::uint64_t rhs) const { - lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); - return lhs; - } - - // Round function, a 'pseudorandom function' who's output is indistinguishable - // from random for each key value input. This is not cryptographically secure - // but sufficient for generating permutations. - __host__ __device__ std::uint32_t round_function(std::uint64_t value, - const std::uint64_t key_) const { - std::uint64_t hash0 = thrust::random::taus88(static_cast(value))(); - std::uint64_t hash1 = thrust::random::ranlux48(value)(); - return static_cast( - hash_combine(hash_combine(hash0, key_), hash1) & left_side_mask); - } - - __host__ __device__ round_state do_round(const round_state state, - const std::uint64_t round) const { - const std::uint32_t new_left = state.right & left_side_mask; - const std::uint32_t round_function_res = - state.left ^ round_function(state.right, key[round]); - if (right_side_bits != left_side_bits) { - // Upper bit of the old right becomes lower bit of new right if we have - // odd length feistel - const std::uint32_t new_right = - (round_function_res << 1ull) | state.right >> left_side_bits; - return {new_left, new_right}; - } - return {new_left, round_function_res}; - } - - static constexpr std::uint64_t num_rounds = 16; + static constexpr std::uint32_t num_rounds = 24; std::uint64_t right_side_bits; std::uint64_t left_side_bits; std::uint64_t right_side_mask; std::uint64_t left_side_mask; - std::uint64_t key[num_rounds]; + std::uint32_t key[num_rounds]; }; struct key_flag_tuple {