From 52df19b97aaded26748831eac7e26ec3a0e926e1 Mon Sep 17 00:00:00 2001 From: Xavier Simmons Date: Sun, 3 Apr 2022 20:21:55 -0700 Subject: [PATCH] Optimized left_semi_join Up to 20x faster. Separated hash table lookup from copy_if because increased register usage significantly limited occupancy of this kernel. --- cpp/src/join/semi_join.cu | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/cpp/src/join/semi_join.cu b/cpp/src/join/semi_join.cu index 9e1aa27a4e7..d847a88c450 100644 --- a/cpp/src/join/semi_join.cu +++ b/cpp/src/join/semi_join.cu @@ -137,18 +137,29 @@ std::unique_ptr> left_semi_anti_join( auto gather_map = std::make_unique>(left_num_rows, stream, mr); + rmm::device_uvector flagged(left_num_rows, stream, mr); + auto flagged_d = flagged.data(); + + auto counting_iter = thrust::counting_iterator(0); + thrust::for_each( + rmm::exec_policy(stream), + counting_iter, + counting_iter + left_num_rows, + [flagged_d, hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__( + const size_type idx) { + flagged_d[idx] = + hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean; + }); + // gather_map_end will be the end of valid data in gather_map auto gather_map_end = thrust::copy_if( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(left_num_rows), - gather_map->begin(), - [hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__( - size_type const idx) { - // Look up this row. The hash function used here needs to map a (left) row index to the hash - // of the row, so it's a row hash. The equality check needs to verify - return hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean; - }); + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(left_num_rows), + gather_map->begin(), + [flagged_d]__device__(size_type const idx) { + return flagged_d[idx]; + }); auto join_size = thrust::distance(gather_map->begin(), gather_map_end); gather_map->resize(join_size, stream);