From f9ea205ed8f0c08279a1a2bafff3c76b292e0c05 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 19 Aug 2020 16:58:28 +1200 Subject: [PATCH] Updated shuffle implementation to use better bijective function --- testing/shuffle.cu | 459 +++++++++++++++++++++-- thrust/system/detail/generic/shuffle.inl | 165 ++++---- 2 files changed, 516 insertions(+), 108 deletions(-) diff --git a/testing/shuffle.cu b/testing/shuffle.cu index 2d9094b421..41cfde44dd 100644 --- a/testing/shuffle.cu +++ b/testing/shuffle.cu @@ -1,15 +1,309 @@ #include #if THRUST_CPP_DIALECT >= 2011 +#include +#include #include #include #include -#include #include -#include -template -void TestShuffleSimple() { +class CephesFunctions { +public: + static double cephes_igamc(double a, double x) { + double ans, ax, c, yc, r, t, y, z; + double pk, pkm1, pkm2, qk, qkm1, qkm2; + + if ((x <= 0) || (a <= 0)) + return (1.0); + + if ((x < 1.0) || (x < a)) + return (1.e0 - cephes_igam(a, x)); + + ax = a * log(x) - x - cephes_lgam(a); + + if (ax < -MAXLOG) { + printf("igamc: UNDERFLOW\n"); + return 0.0; + } + ax = exp(ax); + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + do { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = fabs((ans - r) / r); + ans = r; + } else + t = 1.0; + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (fabs(pk) > big) { + pkm2 *= biginv; + pkm1 *= biginv; + qkm2 *= biginv; + qkm1 *= biginv; + } + } while (t > MACHEP); + + return ans * ax; + } + +private: + static constexpr double rel_error = 1E-12; + + static constexpr double MACHEP = 1.11022302462515654042E-16; // 2**-53 + static constexpr double MAXLOG = 7.09782712893383996732224E2; // log(MAXNUM) + static constexpr double MAXNUM = 1.7976931348623158E308; // 2**1024*(1-MACHEP) + static constexpr double PI = 3.14159265358979323846; + + static constexpr double big = 4.503599627370496e15; + static constexpr double biginv = 2.22044604925031308085e-16; + + static int sgngam; + + static double cephes_igam(double a, double x) { + double ans, ax, c, r; + + if ((x <= 0) || (a <= 0)) + return 0.0; + + if ((x > 1.0) && (x > a)) + return 1.e0 - cephes_igamc(a, x); + + /* Compute x**a * exp(-x) / gamma(a) */ + ax = a * log(x) - x - cephes_lgam(a); + if (ax < -MAXLOG) { + printf("igam: UNDERFLOW\n"); + return 0.0; + } + ax = exp(ax); + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + do { + r += 1.0; + c *= x / r; + ans += c; + } while (c / ans > MACHEP); + + return ans * ax / a; + } + + /* A[]: Stirling's formula expansion of log gamma + * B[], C[]: log gamma function between 2 and 3 + */ + static constexpr unsigned short A[] = { + 0x6661, 0x2733, 0x9850, 0x3f4a, 0xe943, 0xb580, 0x7fbd, + 0xbf43, 0x5ebb, 0x20dc, 0x019f, 0x3f4a, 0xa5a1, 0x16b0, + 0xc16c, 0xbf66, 0x554b, 0x5555, 0x5555, 0x3fb5}; + static constexpr unsigned short B[] = { + 0x6761, 0x8ff3, 0x8901, 0xc095, 0xb93e, 0x355b, 0xf234, 0xc0e2, + 0x89e5, 0xf890, 0x3d73, 0xc114, 0xdb51, 0xf994, 0xbc82, 0xc131, + 0xf20b, 0x0219, 0x4589, 0xc13a, 0x055e, 0x5418, 0x0c67, 0xc12a}; + static constexpr unsigned short C[] = { + /*0x0000,0x0000,0x0000,0x3ff0,*/ + 0x12b2, 0x1cf3, 0xfd0d, 0xc075, 0xd757, 0x7b89, 0xaa0d, 0xc0d0, + 0x4c9b, 0xb974, 0xeb84, 0xc10a, 0x0043, 0x7195, 0x6286, 0xc131, + 0xf34c, 0x892f, 0x5255, 0xc143, 0xe14a, 0x6a11, 0xce4b, 0xc13e}; + + static constexpr double MAXLGM = 2.556348e305; + + /* Logarithm of gamma function */ + static double cephes_lgam(double x) { + double p, q, u, w, z; + int i; + + sgngam = 1; + + if (x < -34.0) { + q = -x; + w = cephes_lgam(q); /* note this modifies sgngam! */ + p = floor(q); + if (p == q) { + lgsing: + goto loverf; + } + i = (int)p; + if ((i & 1) == 0) + sgngam = -1; + else + sgngam = 1; + z = q - p; + if (z > 0.5) { + p += 1.0; + z = p - q; + } + z = q * sin(PI * z); + if (z == 0.0) + goto lgsing; + /* z = log(PI) - log( z ) - w;*/ + z = log(PI) - log(z) - w; + return z; + } + + if (x < 13.0) { + z = 1.0; + p = 0.0; + u = x; + while (u >= 3.0) { + p -= 1.0; + u = x + p; + z *= u; + } + while (u < 2.0) { + if (u == 0.0) + goto lgsing; + z /= u; + p += 1.0; + u = x + p; + } + if (z < 0.0) { + sgngam = -1; + z = -z; + } else + sgngam = 1; + if (u == 2.0) + return (log(z)); + p -= 2.0; + x = x + p; + p = x * cephes_polevl(x, (const double *)B, 5) / + cephes_p1evl(x, (const double *)C, 6); + + return log(z) + p; + } + + if (x > MAXLGM) { + loverf: + printf("lgam: OVERFLOW\n"); + + return sgngam * MAXNUM; + } + + q = (x - 0.5) * log(x) - x + log(sqrt(2 * PI)); + if (x > 1.0e8) + return q; + + p = 1.0 / (x * x); + if (x >= 1000.0) + q += + ((7.9365079365079365079365e-4 * p - 2.7777777777777777777778e-3) * p + + 0.0833333333333333333333) / + x; + else + q += cephes_polevl(p, (const double *)A, 4) / x; + + return q; + } + + static double cephes_polevl(double x, const double *coef, int N) { + const double *p = coef; + double ans = *p++; + int i = N; + do + ans = ans * x + *p++; + while (--i); + + return ans; + } + + static double cephes_p1evl(double x, const double *coef, int N) { + const double *p = coef; + double ans = x + *p++; + int i = N - 1; + + do + ans = ans * x + *p++; + while (--i); + + return ans; + } + + static double cephes_erf(double x) { + static const double two_sqrtpi = 1.128379167095512574; + double sum = x, term = x, xsqr = x * x; + int j = 1; + + if (fabs(x) > 2.2) + return 1.0 - cephes_erfc(x); + + do { + term *= xsqr / j; + sum -= term / (2 * j + 1); + j++; + term *= xsqr / j; + sum += term / (2 * j + 1); + j++; + } while (fabs(term) / sum > rel_error); + + return two_sqrtpi * sum; + } + + static double cephes_erfc(double x) { + static const double one_sqrtpi = 0.564189583547756287; + double a = 1, b = x, c = x, d = x * x + 0.5; + double q1, q2 = b / d, n = 1.0, t; + + if (fabs(x) < 2.2) + return 1.0 - cephes_erf(x); + if (x < 0) + return 2.0 - cephes_erfc(-x); + + do { + t = a * n + b * x; + a = b; + b = t; + t = c * n + d * x; + c = d; + d = t; + n += 0.5; + q1 = q2; + q2 = b / d; + } while (fabs(q1 - q2) / q2 > rel_error); + + return one_sqrtpi * exp(-x * x) * q2; + } + + static double cephes_normal(double x) { + double arg, result, sqrt2 = 1.414213562373095048801688724209698078569672; + + if (x > 0) { + arg = x / sqrt2; + result = 0.5 * (1 + erf(arg)); + } else { + arg = -x / sqrt2; + result = 0.5 * (1 - erf(arg)); + } + + return (result); + } +}; +int CephesFunctions::sgngam = 0; +constexpr unsigned short CephesFunctions::A[]; +constexpr unsigned short CephesFunctions::B[]; +constexpr unsigned short CephesFunctions::C[]; + +template void TestShuffleSimple() { Vector data(5); data[0] = 0; data[1] = 1; @@ -26,8 +320,7 @@ void TestShuffleSimple() { } DECLARE_VECTOR_UNITTEST(TestShuffleSimple); -template -void TestShuffleCopySimple() { +template void TestShuffleCopySimple() { Vector data(5); data[0] = 0; data[1] = 1; @@ -43,8 +336,7 @@ void TestShuffleCopySimple() { } DECLARE_VECTOR_UNITTEST(TestShuffleCopySimple); -template -void TestHostDeviceIdentical(size_t m) { +template void TestHostDeviceIdentical(size_t m) { thrust::host_vector host_result(m); thrust::host_vector device_result(m); thrust::sequence(host_result.begin(), host_result.end(), 0llu); @@ -60,10 +352,40 @@ void TestHostDeviceIdentical(size_t m) { } DECLARE_VARIABLE_UNITTEST(TestHostDeviceIdentical); +template void TestFunctionIsBijection(size_t m) { + thrust::default_random_engine host_g(0xD5); + thrust::default_random_engine device_g(0xD5); + + thrust::system::detail::generic::rc5_bijection<> host_f(m, host_g); + thrust::system::detail::generic::rc5_bijection<> device_f(m, device_g); + + if (host_f.bijection_width() >= std::numeric_limits::max() || m == 0) { + return; + } + + thrust::host_vector host_result(host_f.bijection_width()); + thrust::host_vector device_result(device_f.bijection_width()); + thrust::sequence(host_result.begin(), host_result.end(), 0llu); + thrust::sequence(device_result.begin(), device_result.end(), 0llu); + + thrust::transform(host_result.begin(), host_result.end(), host_result.begin(), + host_f); + thrust::transform(device_result.begin(), device_result.end(), + device_result.begin(), device_f); + + ASSERT_EQUAL(host_result, device_result); + + thrust::sort(host_result.begin(), host_result.end()); + // Assert all values were generated exactly once + for (uint64_t i = 0; i < m; i++) { + ASSERT_EQUAL((uint64_t)host_result[i], i); + } +} +DECLARE_VARIABLE_UNITTEST(TestFunctionIsBijection); + // Individual input keys should be permuted to output locations with uniform // probability. Perform chi-squared test with confidence 99.9%. -template -void TestShuffleKeyPosition() { +template void TestShuffleKeyPosition() { typedef typename Vector::value_type T; size_t m = 20; size_t num_samples = 100; @@ -71,9 +393,9 @@ void TestShuffleKeyPosition() { thrust::host_vector sequence(m); thrust::sequence(sequence.begin(), sequence.end(), T(0)); + thrust::default_random_engine g(0xD5); for (size_t i = 0; i < num_samples; i++) { Vector shuffled(sequence.begin(), sequence.end()); - thrust::default_random_engine g(i); thrust::shuffle(shuffled.begin(), shuffled.end(), g); thrust::host_vector tmp(shuffled.begin(), shuffled.end()); @@ -81,6 +403,7 @@ void TestShuffleKeyPosition() { index_sum[tmp[j]] += j; } } + double expected_average_position = static_cast(m - 1) / 2; double chi_squared = 0.0; for (auto j = 0ull; j < m; j++) { @@ -97,10 +420,12 @@ DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleKeyPosition); struct vector_compare { template - bool operator()(const VectorT& a, const VectorT& b) const { + bool operator()(const VectorT &a, const VectorT &b) const { for (auto i = 0ull; i < a.size(); i++) { - if (a[i] < b[i]) return true; - if (a[i] > b[i]) return false; + if (a[i] < b[i]) + return true; + if (a[i] > b[i]) + return false; } return false; } @@ -109,8 +434,7 @@ struct vector_compare { // Brute force check permutations are uniformly distributed on small input // Uses a chi-squared test indicating 99% confidence the output is uniformly // random -template -void TestShuffleUniformPermutation() { +template void TestShuffleUniformPermutation() { typedef typename Vector::value_type T; size_t m = 5; @@ -119,7 +443,7 @@ void TestShuffleUniformPermutation() { std::map, size_t, vector_compare> permutation_counts; Vector sequence(m); thrust::sequence(sequence.begin(), sequence.end(), T(0)); - thrust::default_random_engine g(17); + thrust::default_random_engine g(0xD5); for (auto i = 0ull; i < num_samples; i++) { thrust::shuffle(sequence.begin(), sequence.end(), g); thrust::host_vector tmp(sequence.begin(), sequence.end()); @@ -133,10 +457,103 @@ void TestShuffleUniformPermutation() { for (auto kv : permutation_counts) { chi_squared += std::pow(expected_count - kv.second, 2) / expected_count; } - // Tabulated chi-squared critical value for 119 degrees of freedom (5! - 1) - // and 99% confidence - double confidence_threshold = 157.8; - ASSERT_LESS(chi_squared, confidence_threshold); + double p_score = CephesFunctions::cephes_igamc( + (double)(total_permutations - 1) / 2.0, chi_squared / 2.0); + ASSERT_GREATER(p_score, 0.01); } DECLARE_VECTOR_UNITTEST(TestShuffleUniformPermutation); + +template void TestShuffleEvenSpacingBetweenOccurances() { + typedef typename Vector::value_type T; + const uint64_t shuffle_size = 10; + const uint64_t num_samples = 1000; + + thrust::host_vector h_results; + Vector sequence(shuffle_size); + thrust::sequence(sequence.begin(), sequence.end(), 0); + thrust::default_random_engine g(0xD5); + for (auto i = 0ull; i < num_samples; i++) { + thrust::shuffle(sequence.begin(), sequence.end(), g); + thrust::host_vector tmp(sequence.begin(), sequence.end()); + h_results.insert(h_results.end(), sequence.begin(), sequence.end()); + } + + std::vector>> distance_between( + num_samples, std::vector>( + num_samples, std::vector(shuffle_size, 0))); + + for (uint64_t sample = 0; sample < num_samples; sample++) { + for (uint64_t i = 0; i < shuffle_size - 1; i++) { + for (uint64_t j = 1; j < shuffle_size - i; j++) { + T val_1 = h_results[sample * shuffle_size + i]; + T val_2 = h_results[sample * shuffle_size + i + j]; + distance_between[val_1][val_2][j]++; + distance_between[val_2][val_1][shuffle_size - j]++; + } + } + } + + const double expected_occurances = (double)num_samples / (shuffle_size - 1); + for (uint64_t val_1 = 0; val_1 < shuffle_size; val_1++) { + for (uint64_t val_2 = val_1 + 1; val_2 < shuffle_size; val_2++) { + double chi_squared = 0.0; + auto &distances = distance_between[val_1][val_2]; + for (uint64_t i = 1; i < shuffle_size; i++) { + chi_squared += std::pow((double)distances[i] - expected_occurances, 2) / + expected_occurances; + } + + double p_score = CephesFunctions::cephes_igamc( + (double)(shuffle_size - 2) / 2.0, chi_squared / 2.0); + ASSERT_GREATER(p_score, 0.01); + } + } +} +DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleEvenSpacingBetweenOccurances); + +template void TestShuffleEvenDistribution() { + typedef typename Vector::value_type T; + const uint64_t shuffle_sizes[] = {10, 100, 500}; + thrust::default_random_engine g(0xD5); + for (auto shuffle_size : shuffle_sizes) { + if(shuffle_size > std::numeric_limits::max()) + continue; + const uint64_t num_samples = shuffle_size == 500 ? 1000 : 200; + + std::vector counts(shuffle_size * shuffle_size, 0); + Vector sequence(shuffle_size); + for (auto i = 0ull; i < num_samples; i++) { + thrust::sequence(sequence.begin(), sequence.end(), 0); + thrust::shuffle(sequence.begin(), sequence.end(), g); + thrust::host_vector tmp(sequence.begin(), sequence.end()); + for (uint64_t j = 0; j < shuffle_size; j++) { + assert(j < tmp.size()); + counts.at(j * shuffle_size + tmp[j])++; + } + } + + const double expected_occurances = (double)num_samples / shuffle_size; + for (uint64_t i = 0; i < shuffle_size; i++) { + double chi_squared_pos = 0.0; + double chi_squared_num = 0.0; + for (uint64_t j = 0; j < shuffle_size; j++) { + auto count_pos = counts.at(i * shuffle_size + j); + auto count_num = counts.at(j * shuffle_size + i); + chi_squared_pos += + pow((double)count_pos - expected_occurances, 2) / expected_occurances; + chi_squared_num += + pow((double)count_num - expected_occurances, 2) / expected_occurances; + } + + double p_score_pos = CephesFunctions::cephes_igamc( + (double)(shuffle_size - 1) / 2.0, chi_squared_pos / 2.0); + ASSERT_GREATER(p_score_pos, 0.001 / (double)shuffle_size); + + double p_score_num = CephesFunctions::cephes_igamc( + (double)(shuffle_size - 1) / 2.0, chi_squared_num / 2.0); + ASSERT_GREATER(p_score_num, 0.001 / (double)shuffle_size); + } + } +} +DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleEvenDistribution); #endif diff --git a/thrust/system/detail/generic/shuffle.inl b/thrust/system/detail/generic/shuffle.inl index 80b45dc024..d9a7906be8 100644 --- a/thrust/system/detail/generic/shuffle.inl +++ b/thrust/system/detail/generic/shuffle.inl @@ -32,94 +32,85 @@ namespace system { namespace detail { namespace generic { -// An implementation of a Feistel cipher for operating on 64 bit keys -class feistel_bijection { - private: +// An implementation of RC5 +template class rc5_bijection { +private: struct round_state { - uint32_t left; - uint32_t right; + uint32_t A; + uint32_t B; }; - public: +public: template - __host__ __device__ feistel_bijection(uint64_t m, URBG&& g) { - uint64_t total_bits = get_cipher_bits(m); - // Half bits rounded down - left_side_bits = total_bits / 2; - left_side_mask = (1ull << left_side_bits) - 1; - // Half the bits rounded up - right_side_bits = total_bits - left_side_bits; - right_side_mask = (1ull << right_side_bits) - 1; - - for (uint64_t i = 0; i < num_rounds; i++) { - key[i] = g(); - } + __host__ __device__ rc5_bijection(uint64_t m, URBG &&g) + : w(get_cipher_bits(m)) + { + init_state(std::forward(g)); } - __host__ __device__ uint64_t nearest_power_of_two() const { - return 1ull << (left_side_bits + right_side_bits); + __host__ __device__ uint64_t bijection_width() const { + return 1ull << (2*w); } - __host__ __device__ uint64_t operator()(const uint64_t val) const { - // Extract the right and left sides of the input - uint32_t left = (uint32_t)(val >> right_side_bits); - uint32_t right = (uint32_t)(val & right_side_mask); - round_state state = {left, right}; - for (uint64_t i = 0; i < num_rounds; i++) { + __host__ __device__ uint64_t operator()(const uint64_t val) const { + if(w == 0) + return val; + round_state state = { (uint32_t)val & get_mask(), (uint32_t)(val >> w) }; + state.A = (state.A + S[0]) & get_mask(); + state.B = (state.B + S[1]) & get_mask(); + for(uint32_t i = 0; i < num_rounds; i++) state = do_round(state, i); - } - - // Check we have the correct number of bits on each side - assert((state.left >> left_side_bits) == 0); - assert((state.right >> right_side_bits) == 0); + uint64_t res = state.B << w | state.A; + return res; + } - // Combine the left and right sides together to get result - return state.left << right_side_bits | state.right; +private: + template + __host__ __device__ void init_state(URBG&& g) + { + thrust::uniform_int_distribution dist(0, get_mask()); + for( uint32_t i = 0; i < state_size; i++ ) + S[i] = dist(g); } - private: // Find the nearest power of two __host__ __device__ uint64_t get_cipher_bits(uint64_t m) { + if(m == 0) + return 0; uint64_t i = 0; + m--; while (m != 0) { i++; - m >>= 1; + m >>= 2u; } return i; } - // Round function, a 'pseudorandom function' whos output is indistinguishable - // from random for each key value input. This is not cryptographically secure - // but sufficient for generating permutations. We hash the value with the - // tau88 engine and combine it with the random bits of the key (provided by - // the user-defined engine). - __host__ __device__ uint32_t round_function(uint64_t value, - const uint64_t key) const { - uint64_t value_hash = thrust::random::taus88(value)(); - return (value_hash ^ key) & left_side_mask; + __host__ __device__ uint32_t get_mask() const + { + return (uint32_t)((1ull << (uint64_t)w) - 1ull); + } + + __host__ __device__ uint32_t rotl( uint32_t val, uint32_t amount ) const + { + const uint32_t amount_mod = amount % w; + return val << amount_mod | val >> (w-amount_mod); } __host__ __device__ round_state do_round(const round_state state, const uint64_t round) const { - const uint32_t new_left = state.right & left_side_mask; - const uint32_t round_function_res = - state.left ^ round_function(state.right, key[round]); - if (right_side_bits != left_side_bits) { - // Upper bit of the old right becomes lower bit of new right if we have - // odd length feistel - const uint32_t new_right = - (round_function_res << 1ull) | state.right >> left_side_bits; - return {new_left, new_right}; - } - return {new_left, round_function_res}; + uint32_t A = state.A; + uint32_t B = state.B; + + A = ( rotl( A ^ B, B ) + S[ 2 * round + 2 ] ) & get_mask(); + B = ( rotl( A ^ B, A ) + S[ 2 * round + 3 ] ) & get_mask(); + + return { A, B }; } - static const uint64_t num_rounds = 8; - uint64_t right_side_bits; - uint64_t left_side_bits; - uint64_t right_side_mask; - uint64_t left_side_mask; - uint64_t key[num_rounds]; + static constexpr uint64_t state_size = 2 * num_rounds + 3; + const uint32_t w = 0; + uint32_t S[state_size]; }; struct key_flag_tuple { @@ -129,17 +120,16 @@ struct key_flag_tuple { // scan only flags struct key_flag_scan_op { - __host__ __device__ key_flag_tuple operator()(const key_flag_tuple& a, - const key_flag_tuple& b) { + __host__ __device__ key_flag_tuple operator()(const key_flag_tuple &a, + const key_flag_tuple &b) { return {b.key, a.flag + b.flag}; } }; -struct construct_key_flag_op { +template struct construct_key_flag_op { uint64_t m; - feistel_bijection bijection; - __host__ __device__ construct_key_flag_op(uint64_t m, - feistel_bijection bijection) + bijection_op bijection; + __host__ __device__ construct_key_flag_op(uint64_t m, bijection_op bijection) : m(m), bijection(bijection) {} __host__ __device__ key_flag_tuple operator()(uint64_t idx) { auto gather_key = bijection(idx); @@ -147,27 +137,26 @@ struct construct_key_flag_op { } }; -template -struct write_output_op { +template struct write_output_op { uint64_t m; InputIterT in; OutputIterT out; // flag contains inclusive scan of valid keys // perform gather using valid keys - __thrust_exec_check_disable__ - __host__ __device__ size_t operator()(key_flag_tuple x) { + __thrust_exec_check_disable__ __host__ __device__ size_t + operator()(key_flag_tuple x) { if (x.key < m) { // -1 because inclusive scan out[x.flag - 1] = in[x.key]; } - return 0; // Discarded + return 0; // Discarded } }; template -__host__ __device__ void shuffle( - thrust::execution_policy& exec, RandomIterator first, - RandomIterator last, URBG&& g) { +__host__ __device__ void +shuffle(thrust::execution_policy &exec, RandomIterator first, + RandomIterator last, URBG &&g) { typedef typename thrust::iterator_traits::value_type InputType; @@ -179,21 +168,23 @@ __host__ __device__ void shuffle( template -__host__ __device__ void shuffle_copy( - thrust::execution_policy& exec, RandomIterator first, - RandomIterator last, OutputIterator result, URBG&& g) { +__host__ __device__ void +shuffle_copy(thrust::execution_policy &exec, + RandomIterator first, RandomIterator last, OutputIterator result, + URBG &&g) { // m is the length of the input // we have an available bijection of length n via a feistel cipher size_t m = last - first; - feistel_bijection bijection(m, g); - uint64_t n = bijection.nearest_power_of_two(); + using bijection_op = rc5_bijection<>; + bijection_op bijection(m, g); + uint64_t n = bijection.bijection_width(); // perform stream compaction over length n bijection to get length m // pseudorandom bijection over the original input thrust::counting_iterator indices(0); - thrust::transform_iterator - key_flag_it(indices, construct_key_flag_op(m, bijection)); + thrust::transform_iterator, + decltype(indices), key_flag_tuple> + key_flag_it(indices, construct_key_flag_op(m, bijection)); write_output_op write_functor{m, first, result}; auto gather_output_it = thrust::make_transform_output_iterator( @@ -206,8 +197,8 @@ __host__ __device__ void shuffle_copy( key_flag_scan_op()); } -} // end namespace generic -} // end namespace detail -} // end namespace system -} // end namespace thrust +} // end namespace generic +} // end namespace detail +} // end namespace system +} // end namespace thrust #endif