diff --git a/cpp/src/join/semi_join.cu b/cpp/src/join/semi_join.cu index 9e1aa27a4e7..007233a0766 100644 --- a/cpp/src/join/semi_join.cu +++ b/cpp/src/join/semi_join.cu @@ -137,19 +137,28 @@ std::unique_ptr> left_semi_anti_join( auto gather_map = std::make_unique>(left_num_rows, stream, mr); - // gather_map_end will be the end of valid data in gather_map - auto gather_map_end = thrust::copy_if( + rmm::device_uvector flagged(left_num_rows, stream, mr); + auto flagged_d = flagged.data(); + + auto counting_iter = thrust::counting_iterator(0); + thrust::for_each( rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(left_num_rows), - gather_map->begin(), - [hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__( - size_type const idx) { - // Look up this row. The hash function used here needs to map a (left) row index to the hash - // of the row, so it's a row hash. The equality check needs to verify - return hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean; + counting_iter, + counting_iter + left_num_rows, + [flagged_d, hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__( + const size_type idx) { + flagged_d[idx] = + hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean; }); + // gather_map_end will be the end of valid data in gather_map + auto gather_map_end = + thrust::copy_if(rmm::exec_policy(stream), + counting_iter, + counting_iter + left_num_rows, + gather_map->begin(), + [flagged_d] __device__(size_type const idx) { return flagged_d[idx]; }); + auto join_size = thrust::distance(gather_map->begin(), gather_map_end); gather_map->resize(join_size, stream); return gather_map;