From b4ae3ec62731c50a7bbd0b13af6dc6cafd4c9abe Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 10 Nov 2021 22:07:31 +1300 Subject: [PATCH 1/4] Updated thrust shuffle to use improved bijective function Updates the thrust shuffle to use the Variable Philox bijective function with 24 rounds. Updates the test suite to include new test statistic based on maximum mean discrepency to enable more thorough testing of larger permutations. --- testing/shuffle.cu | 2 +- testing/shuffle_mmd.cu | 250 +++++++++++++++++++++++ thrust/system/detail/generic/shuffle.inl | 73 +++---- 3 files changed, 274 insertions(+), 51 deletions(-) create mode 100644 testing/shuffle_mmd.cu diff --git a/testing/shuffle.cu b/testing/shuffle.cu index a5b1c6f29..345cc22ca 100644 --- a/testing/shuffle.cu +++ b/testing/shuffle.cu @@ -515,7 +515,7 @@ void TestShuffleEvenSpacingBetweenOccurances() { thrust::host_vector h_results; Vector sequence(shuffle_size); thrust::sequence(sequence.begin(), sequence.end(), 0); - thrust::default_random_engine g(0xD5); + thrust::default_random_engine g(0xD6); for (auto i = 0ull; i < num_samples; i++) { thrust::shuffle(sequence.begin(), sequence.end(), g); thrust::host_vector tmp(sequence.begin(), sequence.end()); diff --git a/testing/shuffle_mmd.cu b/testing/shuffle_mmd.cu new file mode 100644 index 000000000..74a773269 --- /dev/null +++ b/testing/shuffle_mmd.cu @@ -0,0 +1,250 @@ +#include + +#if THRUST_CPP_DIALECT >= 2011 +#include +#include +#include +#include +#include +#include +#include + +// Inverse error function +// https://github.com/lakshayg/erfinv +/* +MIT License +Copyright (c) 2017-2019 Lakshay Garg +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ + +long double erfinv( long double x ) +{ + + if( x < -1 || x > 1 ) + { + return NAN; + } + else if( x == 1.0 ) + { + return INFINITY; + } + else if( x == -1.0 ) + { + return -INFINITY; + } + + const long double LN2 = 6.931471805599453094172321214581e-1L; + + const long double A0 = 1.1975323115670912564578e0L; + const long double A1 = 4.7072688112383978012285e1L; + const long double A2 = 6.9706266534389598238465e2L; + const long double A3 = 4.8548868893843886794648e3L; + const long double A4 = 1.6235862515167575384252e4L; + const long double A5 = 2.3782041382114385731252e4L; + const long double A6 = 1.1819493347062294404278e4L; + const long double A7 = 8.8709406962545514830200e2L; + + const long double B0 = 1.0000000000000000000e0L; + const long double B1 = 4.2313330701600911252e1L; + const long double B2 = 6.8718700749205790830e2L; + const long double B3 = 5.3941960214247511077e3L; + const long double B4 = 2.1213794301586595867e4L; + const long double B5 = 3.9307895800092710610e4L; + const long double B6 = 2.8729085735721942674e4L; + const long double B7 = 5.2264952788528545610e3L; + + const long double C0 = 1.42343711074968357734e0L; + const long double C1 = 4.63033784615654529590e0L; + const long double C2 = 5.76949722146069140550e0L; + const long double C3 = 3.64784832476320460504e0L; + const long double C4 = 1.27045825245236838258e0L; + const long double C5 = 2.41780725177450611770e-1L; + const long double C6 = 2.27238449892691845833e-2L; + const long double C7 = 7.74545014278341407640e-4L; + + const long double D0 = 1.4142135623730950488016887e0L; + const long double D1 = 2.9036514445419946173133295e0L; + const long double D2 = 2.3707661626024532365971225e0L; + const long double D3 = 9.7547832001787427186894837e-1L; + const long double D4 = 2.0945065210512749128288442e-1L; + const long double D5 = 2.1494160384252876777097297e-2L; + const long double D6 = 7.7441459065157709165577218e-4L; + const long double D7 = 1.4859850019840355905497876e-9L; + + const long double E0 = 6.65790464350110377720e0L; + const long double E1 = 5.46378491116411436990e0L; + const long double E2 = 1.78482653991729133580e0L; + const long double E3 = 2.96560571828504891230e-1L; + const long double E4 = 2.65321895265761230930e-2L; + const long double E5 = 1.24266094738807843860e-3L; + const long double E6 = 2.71155556874348757815e-5L; + const long double E7 = 2.01033439929228813265e-7L; + + const long double F0 = 1.414213562373095048801689e0L; + const long double F1 = 8.482908416595164588112026e-1L; + const long double F2 = 1.936480946950659106176712e-1L; + const long double F3 = 2.103693768272068968719679e-2L; + const long double F4 = 1.112800997078859844711555e-3L; + const long double F5 = 2.611088405080593625138020e-5L; + const long double F6 = 2.010321207683943062279931e-7L; + const long double F7 = 2.891024605872965461538222e-15L; + + long double abs_x = fabsl( x ); + + if( abs_x <= 0.85L ) + { + long double r = 0.180625L - 0.25L * x * x; + long double num = + ( ( ( ( ( ( ( A7 * r + A6 ) * r + A5 ) * r + A4 ) * r + A3 ) * r + A2 ) * r + A1 ) * r + A0 ); + long double den = + ( ( ( ( ( ( ( B7 * r + B6 ) * r + B5 ) * r + B4 ) * r + B3 ) * r + B2 ) * r + B1 ) * r + B0 ); + return x * num / den; + } + + long double r = sqrtl( LN2 - logl( 1.0L - abs_x ) ); + + long double num, den; + if( r <= 5.0L ) + { + r = r - 1.6L; + num = ( ( ( ( ( ( ( C7 * r + C6 ) * r + C5 ) * r + C4 ) * r + C3 ) * r + C2 ) * r + C1 ) * r + C0 ); + den = ( ( ( ( ( ( ( D7 * r + D6 ) * r + D5 ) * r + D4 ) * r + D3 ) * r + D2 ) * r + D1 ) * r + D0 ); + } + else + { + r = r - 5.0L; + num = ( ( ( ( ( ( ( E7 * r + E6 ) * r + E5 ) * r + E4 ) * r + E3 ) * r + E2 ) * r + E1 ) * r + E0 ); + den = ( ( ( ( ( ( ( F7 * r + F6 ) * r + F5 ) * r + F4 ) * r + F3 ) * r + F2 ) * r + F1 ) * r + F0 ); + } + + return copysignl( num / den, x ); +} + +long double erfinv_refine( long double x, int nr_iter ) +{ + const long double k = 0.8862269254527580136490837416706L; // 0.5 * sqrt(pi) + long double y = erfinv( x ); + while( nr_iter-- > 0 ) + { + y -= k * ( erfl( y ) - x ) / expl( -y * y ); + } + return y; +} + +#define LSBIT( i ) ( ( i ) & -( i ) ) + +class FenwickTree +{ + std::vector data; + +public: + FenwickTree( size_t n ) : data( n ) + { + } + void Add( size_t i ) + { + for( ; i < data.size(); i += LSBIT( i + 1 ) ) + { + data[i]++; + } + } + int GetCount( size_t i ) + { + int sum = 0; + for( ; i > 0; i -= LSBIT( i ) ) + sum += data[i - 1]; + return sum; + } +}; + +template +size_t ConcordantPairs( const Vector& x ) +{ + size_t count = 0; + FenwickTree tree( x.size() ); + for( auto x_i : x ) + { + count += tree.GetCount( x_i ); + tree.Add( x_i ); + } + return count; +} + +template +double MallowsKernelIdentity( const Vector& x, double lambda ) +{ + auto con = ConcordantPairs( x ); + auto norm = x.size() * ( x.size() - 1 ) / 2; + double y = 1 - ( double( con ) / norm ); + return exp( -lambda * y ); +} + +double MallowsExpectedValue( size_t n, double lambda ) +{ + double norm = n * ( n - 1 ) / 2.0; + double product = 1.0; + for( size_t j = 1; j <= n; j++ ) + { + product *= ( 1.0 - exp( -lambda * j / norm ) ) / ( j * ( 1.0 - exp( -lambda / norm ) ) ); + } + return product; +} + +double HoeffdingAcceptanceThreshold( double alpha, size_t num_samples ) +{ + double w = log( 2 / alpha ) / ( 2 * num_samples ); + return sqrt( w ); +} + +double NormalAcceptanceThreshold( double alpha, size_t num_samples, size_t n, double lambda ) +{ + double var = (MallowsExpectedValue( n, 2 * lambda ) - pow( MallowsExpectedValue( n, lambda ), 2.0 )) / num_samples; + return sqrt( 2 * var ) * erfinv_refine( 1 - alpha, 10 ); +} + +template +void TestShuffleMallows() { + typedef typename Vector::value_type T; + + const uint32_t shuffle_size = std::min((uint32_t)(1u << 13) + 1, (uint32_t)std::numeric_limits::max()); + const uint32_t num_samples = 1000; + const double lambda = 5; + + thrust::default_random_engine g(0xD5); + Vector sequence(shuffle_size); + double mallows_expected = 0; + for( uint32_t i = 0; i < num_samples; i++ ) + { + thrust::sequence(sequence.begin(), sequence.end(), 0); + thrust::shuffle(sequence.begin(), sequence.end(), g); + + thrust::host_vector tmp(sequence.begin(), sequence.end()); + mallows_expected += MallowsKernelIdentity( tmp, lambda ); + } + + mallows_expected /= num_samples; + double mmd = abs( mallows_expected - MallowsExpectedValue( shuffle_size, lambda ) ); + + const double alpha = 0.01; + ASSERT_LESS(mmd, HoeffdingAcceptanceThreshold( alpha, num_samples )); + ASSERT_LESS(mmd, NormalAcceptanceThreshold( alpha, num_samples, shuffle_size, lambda )); + +} +DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleMallows); + + +#endif diff --git a/thrust/system/detail/generic/shuffle.inl b/thrust/system/detail/generic/shuffle.inl index 91b77351d..39556371a 100644 --- a/thrust/system/detail/generic/shuffle.inl +++ b/thrust/system/detail/generic/shuffle.inl @@ -48,7 +48,7 @@ class feistel_bijection { right_side_bits = total_bits - left_side_bits; right_side_mask = (1ull << right_side_bits) - 1; - for (std::uint64_t i = 0; i < num_rounds; i++) { + for (std::uint32_t i = 0; i < num_rounds; i++) { key[i] = g(); } } @@ -56,27 +56,33 @@ class feistel_bijection { __host__ __device__ std::uint64_t nearest_power_of_two() const { return 1ull << (left_side_bits + right_side_bits); } - __host__ __device__ std::uint64_t operator()(const std::uint64_t val) const { - // Extract the right and left sides of the input - auto left = static_cast(val >> right_side_bits); - auto right = static_cast(val & right_side_mask); - round_state state = {left, right}; - for (std::uint64_t i = 0; i < num_rounds; i++) { - state = do_round(state, i); + __host__ __device__ std::uint64_t operator()(const std::uint64_t val) const { + std::uint32_t state[2] = { uint32_t( val >> right_side_bits ), uint32_t( val & right_side_mask ) }; + for( std::uint32_t i = 0; i < num_rounds; i++ ) + { + std::uint32_t hi, lo; + constexpr std::uint64_t M0 = UINT64_C( 0xD2B74407B1CE6E93 ); + mulhilo( M0, state[0], hi, lo ); + lo = ( lo << ( right_side_bits - left_side_bits ) ) | state[1] >> left_side_bits; + state[0] = ( ( hi ^ key[i] ) ^ state[1] ) & left_side_mask; + state[1] = lo & right_side_mask; } - - // Check we have the correct number of bits on each side - assert((state.left >> left_side_bits) == 0); - assert((state.right >> right_side_bits) == 0); - // Combine the left and right sides together to get result - return state.left << right_side_bits | state.right; + return (std::uint64_t)state[0] << right_side_bits | (std::uint64_t)state[1]; } private: + // Perform 64 bit multiplication and save result in two 32 bit int + constexpr static __host__ __device__ void mulhilo( std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo ) + { + std::uint64_t product = a * b; + hi = std::uint32_t( product >> 32 ); + lo = std::uint32_t( product ); + } + // Find the nearest power of two - __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { + constexpr static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { if (m == 0) return 0; std::uint64_t i = 0; m--; @@ -87,45 +93,12 @@ class feistel_bijection { return i; } - // Equivalent to boost::hash_combine - __host__ __device__ - std::size_t hash_combine(std::uint64_t lhs, std::uint64_t rhs) const { - lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); - return lhs; - } - - // Round function, a 'pseudorandom function' who's output is indistinguishable - // from random for each key value input. This is not cryptographically secure - // but sufficient for generating permutations. - __host__ __device__ std::uint32_t round_function(std::uint64_t value, - const std::uint64_t key_) const { - std::uint64_t hash0 = thrust::random::taus88(static_cast(value))(); - std::uint64_t hash1 = thrust::random::ranlux48(value)(); - return static_cast( - hash_combine(hash_combine(hash0, key_), hash1) & left_side_mask); - } - - __host__ __device__ round_state do_round(const round_state state, - const std::uint64_t round) const { - const std::uint32_t new_left = state.right & left_side_mask; - const std::uint32_t round_function_res = - state.left ^ round_function(state.right, key[round]); - if (right_side_bits != left_side_bits) { - // Upper bit of the old right becomes lower bit of new right if we have - // odd length feistel - const std::uint32_t new_right = - (round_function_res << 1ull) | state.right >> left_side_bits; - return {new_left, new_right}; - } - return {new_left, round_function_res}; - } - - static constexpr std::uint64_t num_rounds = 16; + static constexpr std::uint32_t num_rounds = 24; std::uint64_t right_side_bits; std::uint64_t left_side_bits; std::uint64_t right_side_mask; std::uint64_t left_side_mask; - std::uint64_t key[num_rounds]; + std::uint32_t key[num_rounds]; }; struct key_flag_tuple { From ca86e2ebc4d7bfe5f4f2707ab478a9db8c2bfc21 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Sun, 14 Nov 2021 15:52:22 +1300 Subject: [PATCH 2/4] Addressed feedback on review for improved shuffle --- internal/benchmark/bench.cu | 2 - testing/shuffle.cu | 2 - testing/shuffle_mmd.cu | 250 ----------------------- thrust/system/detail/generic/shuffle.inl | 2 +- 4 files changed, 1 insertion(+), 255 deletions(-) delete mode 100644 testing/shuffle_mmd.cu diff --git a/internal/benchmark/bench.cu b/internal/benchmark/bench.cu index e73a0d5bd..38d1d647a 100644 --- a/internal/benchmark/bench.cu +++ b/internal/benchmark/bench.cu @@ -992,7 +992,6 @@ void run_core_primitives_experiments_for_type() , RegularTrials >::run_experiment(); -#if THRUST_CPP_DIALECT >= 2011 experiment_driver< shuffle_tester , ElementMetaType @@ -1000,7 +999,6 @@ void run_core_primitives_experiments_for_type() , BaselineTrials , RegularTrials >::run_experiment(); -#endif } /////////////////////////////////////////////////////////////////////////////// diff --git a/testing/shuffle.cu b/testing/shuffle.cu index 345cc22ca..5d2997319 100644 --- a/testing/shuffle.cu +++ b/testing/shuffle.cu @@ -1,6 +1,5 @@ #include -#if THRUST_CPP_DIALECT >= 2011 #include #include #include @@ -601,4 +600,3 @@ void TestShuffleEvenDistribution() { } } DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleEvenDistribution); -#endif diff --git a/testing/shuffle_mmd.cu b/testing/shuffle_mmd.cu deleted file mode 100644 index 74a773269..000000000 --- a/testing/shuffle_mmd.cu +++ /dev/null @@ -1,250 +0,0 @@ -#include - -#if THRUST_CPP_DIALECT >= 2011 -#include -#include -#include -#include -#include -#include -#include - -// Inverse error function -// https://github.com/lakshayg/erfinv -/* -MIT License -Copyright (c) 2017-2019 Lakshay Garg -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -*/ - -long double erfinv( long double x ) -{ - - if( x < -1 || x > 1 ) - { - return NAN; - } - else if( x == 1.0 ) - { - return INFINITY; - } - else if( x == -1.0 ) - { - return -INFINITY; - } - - const long double LN2 = 6.931471805599453094172321214581e-1L; - - const long double A0 = 1.1975323115670912564578e0L; - const long double A1 = 4.7072688112383978012285e1L; - const long double A2 = 6.9706266534389598238465e2L; - const long double A3 = 4.8548868893843886794648e3L; - const long double A4 = 1.6235862515167575384252e4L; - const long double A5 = 2.3782041382114385731252e4L; - const long double A6 = 1.1819493347062294404278e4L; - const long double A7 = 8.8709406962545514830200e2L; - - const long double B0 = 1.0000000000000000000e0L; - const long double B1 = 4.2313330701600911252e1L; - const long double B2 = 6.8718700749205790830e2L; - const long double B3 = 5.3941960214247511077e3L; - const long double B4 = 2.1213794301586595867e4L; - const long double B5 = 3.9307895800092710610e4L; - const long double B6 = 2.8729085735721942674e4L; - const long double B7 = 5.2264952788528545610e3L; - - const long double C0 = 1.42343711074968357734e0L; - const long double C1 = 4.63033784615654529590e0L; - const long double C2 = 5.76949722146069140550e0L; - const long double C3 = 3.64784832476320460504e0L; - const long double C4 = 1.27045825245236838258e0L; - const long double C5 = 2.41780725177450611770e-1L; - const long double C6 = 2.27238449892691845833e-2L; - const long double C7 = 7.74545014278341407640e-4L; - - const long double D0 = 1.4142135623730950488016887e0L; - const long double D1 = 2.9036514445419946173133295e0L; - const long double D2 = 2.3707661626024532365971225e0L; - const long double D3 = 9.7547832001787427186894837e-1L; - const long double D4 = 2.0945065210512749128288442e-1L; - const long double D5 = 2.1494160384252876777097297e-2L; - const long double D6 = 7.7441459065157709165577218e-4L; - const long double D7 = 1.4859850019840355905497876e-9L; - - const long double E0 = 6.65790464350110377720e0L; - const long double E1 = 5.46378491116411436990e0L; - const long double E2 = 1.78482653991729133580e0L; - const long double E3 = 2.96560571828504891230e-1L; - const long double E4 = 2.65321895265761230930e-2L; - const long double E5 = 1.24266094738807843860e-3L; - const long double E6 = 2.71155556874348757815e-5L; - const long double E7 = 2.01033439929228813265e-7L; - - const long double F0 = 1.414213562373095048801689e0L; - const long double F1 = 8.482908416595164588112026e-1L; - const long double F2 = 1.936480946950659106176712e-1L; - const long double F3 = 2.103693768272068968719679e-2L; - const long double F4 = 1.112800997078859844711555e-3L; - const long double F5 = 2.611088405080593625138020e-5L; - const long double F6 = 2.010321207683943062279931e-7L; - const long double F7 = 2.891024605872965461538222e-15L; - - long double abs_x = fabsl( x ); - - if( abs_x <= 0.85L ) - { - long double r = 0.180625L - 0.25L * x * x; - long double num = - ( ( ( ( ( ( ( A7 * r + A6 ) * r + A5 ) * r + A4 ) * r + A3 ) * r + A2 ) * r + A1 ) * r + A0 ); - long double den = - ( ( ( ( ( ( ( B7 * r + B6 ) * r + B5 ) * r + B4 ) * r + B3 ) * r + B2 ) * r + B1 ) * r + B0 ); - return x * num / den; - } - - long double r = sqrtl( LN2 - logl( 1.0L - abs_x ) ); - - long double num, den; - if( r <= 5.0L ) - { - r = r - 1.6L; - num = ( ( ( ( ( ( ( C7 * r + C6 ) * r + C5 ) * r + C4 ) * r + C3 ) * r + C2 ) * r + C1 ) * r + C0 ); - den = ( ( ( ( ( ( ( D7 * r + D6 ) * r + D5 ) * r + D4 ) * r + D3 ) * r + D2 ) * r + D1 ) * r + D0 ); - } - else - { - r = r - 5.0L; - num = ( ( ( ( ( ( ( E7 * r + E6 ) * r + E5 ) * r + E4 ) * r + E3 ) * r + E2 ) * r + E1 ) * r + E0 ); - den = ( ( ( ( ( ( ( F7 * r + F6 ) * r + F5 ) * r + F4 ) * r + F3 ) * r + F2 ) * r + F1 ) * r + F0 ); - } - - return copysignl( num / den, x ); -} - -long double erfinv_refine( long double x, int nr_iter ) -{ - const long double k = 0.8862269254527580136490837416706L; // 0.5 * sqrt(pi) - long double y = erfinv( x ); - while( nr_iter-- > 0 ) - { - y -= k * ( erfl( y ) - x ) / expl( -y * y ); - } - return y; -} - -#define LSBIT( i ) ( ( i ) & -( i ) ) - -class FenwickTree -{ - std::vector data; - -public: - FenwickTree( size_t n ) : data( n ) - { - } - void Add( size_t i ) - { - for( ; i < data.size(); i += LSBIT( i + 1 ) ) - { - data[i]++; - } - } - int GetCount( size_t i ) - { - int sum = 0; - for( ; i > 0; i -= LSBIT( i ) ) - sum += data[i - 1]; - return sum; - } -}; - -template -size_t ConcordantPairs( const Vector& x ) -{ - size_t count = 0; - FenwickTree tree( x.size() ); - for( auto x_i : x ) - { - count += tree.GetCount( x_i ); - tree.Add( x_i ); - } - return count; -} - -template -double MallowsKernelIdentity( const Vector& x, double lambda ) -{ - auto con = ConcordantPairs( x ); - auto norm = x.size() * ( x.size() - 1 ) / 2; - double y = 1 - ( double( con ) / norm ); - return exp( -lambda * y ); -} - -double MallowsExpectedValue( size_t n, double lambda ) -{ - double norm = n * ( n - 1 ) / 2.0; - double product = 1.0; - for( size_t j = 1; j <= n; j++ ) - { - product *= ( 1.0 - exp( -lambda * j / norm ) ) / ( j * ( 1.0 - exp( -lambda / norm ) ) ); - } - return product; -} - -double HoeffdingAcceptanceThreshold( double alpha, size_t num_samples ) -{ - double w = log( 2 / alpha ) / ( 2 * num_samples ); - return sqrt( w ); -} - -double NormalAcceptanceThreshold( double alpha, size_t num_samples, size_t n, double lambda ) -{ - double var = (MallowsExpectedValue( n, 2 * lambda ) - pow( MallowsExpectedValue( n, lambda ), 2.0 )) / num_samples; - return sqrt( 2 * var ) * erfinv_refine( 1 - alpha, 10 ); -} - -template -void TestShuffleMallows() { - typedef typename Vector::value_type T; - - const uint32_t shuffle_size = std::min((uint32_t)(1u << 13) + 1, (uint32_t)std::numeric_limits::max()); - const uint32_t num_samples = 1000; - const double lambda = 5; - - thrust::default_random_engine g(0xD5); - Vector sequence(shuffle_size); - double mallows_expected = 0; - for( uint32_t i = 0; i < num_samples; i++ ) - { - thrust::sequence(sequence.begin(), sequence.end(), 0); - thrust::shuffle(sequence.begin(), sequence.end(), g); - - thrust::host_vector tmp(sequence.begin(), sequence.end()); - mallows_expected += MallowsKernelIdentity( tmp, lambda ); - } - - mallows_expected /= num_samples; - double mmd = abs( mallows_expected - MallowsExpectedValue( shuffle_size, lambda ) ); - - const double alpha = 0.01; - ASSERT_LESS(mmd, HoeffdingAcceptanceThreshold( alpha, num_samples )); - ASSERT_LESS(mmd, NormalAcceptanceThreshold( alpha, num_samples, shuffle_size, lambda )); - -} -DECLARE_INTEGRAL_VECTOR_UNITTEST(TestShuffleMallows); - - -#endif diff --git a/thrust/system/detail/generic/shuffle.inl b/thrust/system/detail/generic/shuffle.inl index 39556371a..03cd18eec 100644 --- a/thrust/system/detail/generic/shuffle.inl +++ b/thrust/system/detail/generic/shuffle.inl @@ -83,7 +83,7 @@ class feistel_bijection { // Find the nearest power of two constexpr static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { - if (m == 0) return 0; + if (m <= 16) return 4; std::uint64_t i = 0; m--; while (m != 0) { From 9e25fe97c67a438bc9137c67b5dba763ca760a84 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Tue, 4 Jan 2022 13:28:22 +1300 Subject: [PATCH 3/4] Touch up c-style casts and test bugs --- testing/shuffle.cu | 16 ++++++++-------- thrust/system/detail/generic/shuffle.inl | 8 ++++---- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/testing/shuffle.cu b/testing/shuffle.cu index 5d2997319..77e660c00 100644 --- a/testing/shuffle.cu +++ b/testing/shuffle.cu @@ -382,7 +382,7 @@ void TestFunctionIsBijection(size_t m) { thrust::system::detail::generic::feistel_bijection host_f(m, host_g); thrust::system::detail::generic::feistel_bijection device_f(m, device_g); - if (host_f.nearest_power_of_two() >= std::numeric_limits::max() || m == 0) { + if (static_cast(host_f.nearest_power_of_two()) >= static_cast(std::numeric_limits::max()) || m == 0) { return; } @@ -409,17 +409,17 @@ DECLARE_VARIABLE_UNITTEST(TestFunctionIsBijection); void TestBijectionLength() { thrust::default_random_engine g(0xD5); - uint64_t m = 3; + uint64_t m = 31; thrust::system::detail::generic::feistel_bijection f(m, g); - ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(4)); + ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32)); - m = 2; + m = 32; f = thrust::system::detail::generic::feistel_bijection(m, g); - ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(2)); + ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(32)); - m = 0; + m = 1; f = thrust::system::detail::generic::feistel_bijection(m, g); - ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(1)); + ASSERT_EQUAL(f.nearest_power_of_two(), uint64_t(16)); } DECLARE_UNITTEST(TestBijectionLength); @@ -560,7 +560,7 @@ void TestShuffleEvenDistribution() { const uint64_t shuffle_sizes[] = {10, 100, 500}; thrust::default_random_engine g(0xD5); for (auto shuffle_size : shuffle_sizes) { - if(shuffle_size > std::numeric_limits::max()) + if(shuffle_size > (uint64_t)std::numeric_limits::max()) continue; const uint64_t num_samples = shuffle_size == 500 ? 1000 : 200; diff --git a/thrust/system/detail/generic/shuffle.inl b/thrust/system/detail/generic/shuffle.inl index 03cd18eec..603b1faf2 100644 --- a/thrust/system/detail/generic/shuffle.inl +++ b/thrust/system/detail/generic/shuffle.inl @@ -58,7 +58,7 @@ class feistel_bijection { } __host__ __device__ std::uint64_t operator()(const std::uint64_t val) const { - std::uint32_t state[2] = { uint32_t( val >> right_side_bits ), uint32_t( val & right_side_mask ) }; + std::uint32_t state[2] = { static_cast( val >> right_side_bits ), static_cast( val & right_side_mask ) }; for( std::uint32_t i = 0; i < num_rounds; i++ ) { std::uint32_t hi, lo; @@ -69,7 +69,7 @@ class feistel_bijection { state[1] = lo & right_side_mask; } // Combine the left and right sides together to get result - return (std::uint64_t)state[0] << right_side_bits | (std::uint64_t)state[1]; + return static_cast(state[0] << right_side_bits) | static_cast(state[1]); } private: @@ -77,8 +77,8 @@ class feistel_bijection { constexpr static __host__ __device__ void mulhilo( std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo ) { std::uint64_t product = a * b; - hi = std::uint32_t( product >> 32 ); - lo = std::uint32_t( product ); + hi = static_cast( product >> 32 ); + lo = static_cast( product ); } // Find the nearest power of two From 616f17d3871154e6d46f18e033ee84918b064766 Mon Sep 17 00:00:00 2001 From: djns99 <40156487+djns99@users.noreply.github.com> Date: Sun, 23 Jan 2022 20:21:46 +1300 Subject: [PATCH 4/4] Remove constexpr labels --- thrust/system/detail/generic/shuffle.inl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thrust/system/detail/generic/shuffle.inl b/thrust/system/detail/generic/shuffle.inl index 603b1faf2..baece51be 100644 --- a/thrust/system/detail/generic/shuffle.inl +++ b/thrust/system/detail/generic/shuffle.inl @@ -74,7 +74,7 @@ class feistel_bijection { private: // Perform 64 bit multiplication and save result in two 32 bit int - constexpr static __host__ __device__ void mulhilo( std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo ) + static __host__ __device__ void mulhilo( std::uint64_t a, std::uint64_t b, std::uint32_t& hi, std::uint32_t& lo ) { std::uint64_t product = a * b; hi = static_cast( product >> 32 ); @@ -82,7 +82,7 @@ class feistel_bijection { } // Find the nearest power of two - constexpr static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { + static __host__ __device__ std::uint64_t get_cipher_bits(std::uint64_t m) { if (m <= 16) return 4; std::uint64_t i = 0; m--;