-
Notifications
You must be signed in to change notification settings - Fork 915
/
contains_table.cu
300 lines (260 loc) · 12 KB
/
contains_table.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <join/join_common_utils.cuh>
#include <cudf/detail/join.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/table/experimental/row_operators.cuh>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <thrust/iterator/counting_iterator.h>
#include <cuco/static_map.cuh>
#include <type_traits>
namespace cudf::detail {
namespace {
using cudf::experimental::row::lhs_index_type;
using cudf::experimental::row::rhs_index_type;
// Hash map type used by `contains` below. Keys are strongly-typed row indices
// (`lhs_index_type`), mapped values are `size_type`, the thread scope is
// `cuda::thread_scope_device`, and the allocator type is an rmm `stream_allocator_adaptor`
// wrapping `default_allocator<char>`.
// NOTE(review): `contains` constructs the allocator as `detail::hash_table_allocator_type`;
// presumably that names the same allocator type as the one spelled out here - confirm.
using static_map = cuco::static_map<lhs_index_type,
size_type,
cuda::thread_scope_device,
rmm::mr::stream_allocator_adaptor<default_allocator<char>>>;
/**
 * @brief Tell whether `T` is one of the strong row-index types, i.e. exactly `lhs_index_type`
 * or exactly `rhs_index_type`.
 *
 * @tparam T The type to inspect
 * @return `true` if `T` is `lhs_index_type` or `rhs_index_type`, `false` otherwise
 */
template <typename T>
constexpr auto is_strong_index_type()
{
  using is_lhs_index = std::is_same<T, lhs_index_type>;
  using is_rhs_index = std::is_same<T, rhs_index_type>;
  return std::disjunction_v<is_lhs_index, is_rhs_index>;
}
/**
 * @brief Wrapper that lets a row hasher which operates on `cudf::size_type` be invoked with
 * strong index types.
 *
 * @tparam Hasher Type of the wrapped hasher; it is called with a `size_type` row index
 */
template <typename Hasher>
struct strong_index_hasher_adapter {
  strong_index_hasher_adapter(Hasher const& hasher) : _wrapped{hasher} {}

  /**
   * @brief Hash the row identified by a strong index.
   *
   * Only participates in overload resolution when `IndexType` is a strong index type.
   *
   * @param strong_idx The row index, as `lhs_index_type` or `rhs_index_type`
   * @return Whatever the wrapped hasher returns for the plain `size_type` index
   */
  template <typename IndexType, CUDF_ENABLE_IF(is_strong_index_type<IndexType>())>
  __device__ constexpr auto operator()(IndexType const strong_idx) const noexcept
  {
    // Strip the strong type: the wrapped hasher only understands plain row indices.
    auto const row_idx = static_cast<size_type>(strong_idx);
    return _wrapped(row_idx);
  }

 private:
  Hasher const _wrapped;
};
/**
 * @brief Wrapper that lets a table row comparator which operates on `cudf::size_type` be
 * invoked with strong index types.
 *
 * @tparam Comparator Type of the wrapped comparator; it is called with two `size_type` indices
 */
template <typename Comparator>
struct strong_index_comparator_adapter {
  strong_index_comparator_adapter(Comparator const& comparator) : _wrapped{comparator} {}

  /**
   * @brief Compare the two rows identified by a pair of strong indices.
   *
   * Only participates in overload resolution when both `T` and `U` are strong index types.
   * When the pair arrives as (`rhs_index_type`, `lhs_index_type`), the indices are swapped
   * before being handed to the wrapped comparator; in every other case they are passed through
   * in the given order.
   *
   * @param lhs_index The first strong row index
   * @param rhs_index The second strong row index
   * @return Whatever the wrapped comparator returns for the plain `size_type` indices
   */
  template <typename T,
            typename U,
            CUDF_ENABLE_IF(is_strong_index_type<T>() && is_strong_index_type<U>())>
  __device__ constexpr auto operator()(T const lhs_index, U const rhs_index) const noexcept
  {
    auto const first  = static_cast<size_type>(lhs_index);
    auto const second = static_cast<size_type>(rhs_index);

    // The types differ and the first one is not `lhs_index_type`, hence T == rhs_index_type and
    // U == lhs_index_type: the indices were supplied in reversed order.
    constexpr bool reversed_order = !std::is_same_v<T, U> && !std::is_same_v<T, lhs_index_type>;

    if constexpr (reversed_order) {
      // Restore the (lhs, rhs) order expected by the underlying comparator.
      return _wrapped(second, first);
    } else {
      return _wrapped(first, second);
    }
  }

 private:
  Comparator const _wrapped;
};
/**
 * @brief Build a row bitmask for the input table.
 *
 * The output bitmask is derived from the null masks of the columns returned by
 * `get_nullable_columns(input)`: when more than one mask is involved they are combined with
 * `bitmask_and`, so a row is marked invalid if it is null in any of those columns.
 *
 * The returned pointer is always indexed from row 0 of `input`. A column's own null mask is
 * therefore reused only if that column is the sole nullable column and has a zero offset;
 * `column_view::null_mask()` returns the raw allocation pointer without applying `offset()`,
 * so reusing it for a sliced column would make bit `i` describe the wrong row.
 *
 * @throw cudf::logic_error (via `CUDF_EXPECTS`) if the table has no nullable columns
 *
 * @param input The input table
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @return A pair of the buffer owning the bitmask (empty when a column's null mask is used
 * directly) and a pointer to the bitmask. The pointer is only valid while the returned buffer
 * and the input table are alive.
 */
std::pair<rmm::device_buffer, bitmask_type const*> build_row_bitmask(table_view const& input,
                                                                     rmm::cuda_stream_view stream)
{
  auto const nullable_columns = get_nullable_columns(input);
  CUDF_EXPECTS(nullable_columns.size() > 0,
               "The input table has nulls thus it should have nullable columns.");

  // Fast path: exactly one nullable column whose mask already starts at row 0. No allocation is
  // needed, so an empty buffer accompanies the borrowed pointer.
  auto const can_use_mask_directly =
    nullable_columns.size() == 1 && nullable_columns.front().offset() == 0;
  if (can_use_mask_directly) {
    return std::pair(rmm::device_buffer{0, stream}, nullable_columns.front().null_mask());
  }

  // Otherwise materialize a new mask with `bitmask_and`. This covers both the case of several
  // nullable columns and the case of a single sliced (non-zero offset) column, for which the raw
  // mask pointer cannot be indexed by row number.
  auto row_bitmask =
    cudf::detail::bitmask_and(
      table_view{nullable_columns}, stream, rmm::mr::get_current_device_resource())
      .first;
  // Grab the pointer before the buffer is moved into the result.
  auto const row_bitmask_ptr = static_cast<bitmask_type const*>(row_bitmask.data());
  return std::pair(std::move(row_bitmask), row_bitmask_ptr);
}
/**
 * @brief Call `func` with a default-constructed physical equality comparator selected by
 * `compare_nans`.
 *
 * `func` receives a `nan_equal_physical_equality_comparator` when `compare_nans` is
 * `nan_equality::ALL_EQUAL`, and a `physical_equality_comparator` otherwise. Because the two
 * comparators have different types, `func` must be callable with either of them (e.g. a generic
 * lambda).
 *
 * @param compare_nans The flag to specify whether NaNs should be compared equal or not
 * @param func The input functor to invoke
 */
template <typename Func>
void dispatch_nan_comparator(nan_equality compare_nans, Func&& func)
{
  namespace row_equality = cudf::experimental::row::equality;

  if (compare_nans != nan_equality::ALL_EQUAL) {
    func(row_equality::physical_equality_comparator{});
    return;
  }
  func(row_equality::nan_equal_physical_equality_comparator{});
}
} // namespace
/**
 * @brief Check if rows in the given `needles` table exist in the `haystack` table.
 *
 * The work is done in two phases:
 *  1. Row indices of `haystack` are inserted as keys into a hash map. If the haystack has nulls
 *     and `compare_nulls` is `null_equality::UNEQUAL`, only rows accepted by `row_is_valid` on
 *     the row bitmask are inserted.
 *  2. Each row index of `needles` is looked up in that map, and the result of the lookup is
 *     written to the corresponding position of the output vector.
 *
 * @param haystack The table containing the search space
 * @param needles A table of rows whose existence to check in the search space
 * @param compare_nulls Control whether nulls should be compared as equal or not
 * @param compare_nans Control whether floating-point NaNs values should be compared as equal or not
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned vector
 * @return A vector of bools indicating if each row in `needles` has matching rows in `haystack`
 */
rmm::device_uvector<bool> contains(table_view const& haystack,
table_view const& needles,
null_equality compare_nulls,
nan_equality compare_nans,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
// Hash map keyed by haystack row indices. The empty-key sentinel is the maximum `size_type`
// value wrapped as `lhs_index_type`; the empty-value sentinel is `detail::JoinNoneValue`.
auto map = static_map(compute_hash_table_size(haystack.num_rows()),
cuco::empty_key{lhs_index_type{std::numeric_limits<size_type>::max()}},
cuco::empty_value{detail::JoinNoneValue},
detail::hash_table_allocator_type{default_allocator<char>{}, stream},
stream.value());
// `has_any_nulls` is passed to the hashers of both tables (haystack on insertion, needles on
// lookup) and to the two-table comparator; the haystack self-comparator only receives
// `haystack_has_nulls`.
auto const haystack_has_nulls = has_nested_nulls(haystack);
auto const needles_has_nulls = has_nested_nulls(needles);
auto const has_any_nulls = haystack_has_nulls || needles_has_nulls;
// Created once and shared by the insertion phase and the lookup phase below.
auto const preprocessed_haystack =
cudf::experimental::row::equality::preprocessed_table::create(haystack, stream);
// Insert row indices of the haystack table as map keys.
{
// Yields (lhs_index_type{i}, 0) pairs: the row index is the key. The mapped value 0 is a
// placeholder; this function only ever queries the map for key existence.
auto const haystack_it = cudf::detail::make_counting_transform_iterator(
size_type{0},
[] __device__(auto const idx) { return cuco::make_pair(lhs_index_type{idx}, 0); });
auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_haystack);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};
// Insertion only compares haystack rows against other haystack rows.
auto const comparator =
cudf::experimental::row::equality::self_comparator(preprocessed_haystack);
// If the haystack table has nulls but they are compared unequal, don't insert them.
// Otherwise, it was known to cause performance issue:
// - https://github.com/rapidsai/cudf/pull/6943
// - https://github.com/rapidsai/cudf/pull/8277
if (haystack_has_nulls && compare_nulls == null_equality::UNEQUAL) {
// Keep the whole pair alive in this scope: its first member owns the mask storage whenever
// `build_row_bitmask` had to allocate one, and the pointer below refers into it.
auto const bitmask_buffer_and_ptr = build_row_bitmask(haystack, stream);
auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second;
// Generic lambda: `value_comp` is whichever comparator type `dispatch_nan_comparator` picks.
auto const insert_map = [&](auto const value_comp) {
// The boolean template argument of `equal_to` is chosen by whether the haystack has nested
// columns; apart from that the two branches are identical.
if (cudf::detail::has_nested_columns(haystack)) {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<true>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
d_hasher,
d_eqcomp,
stream.value());
} else {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<false>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert_if(haystack_it,
haystack_it + haystack.num_rows(),
thrust::counting_iterator<size_type>(0), // stencil
row_is_valid{row_bitmask_ptr},
d_hasher,
d_eqcomp,
stream.value());
}
};
// Insert only rows that do not have any null at any level.
dispatch_nan_comparator(compare_nans, insert_map);
} else { // haystack_doesn't_have_nulls || compare_nulls == null_equality::EQUAL
// Every haystack row is inserted unconditionally in this case.
auto const insert_map = [&](auto const value_comp) {
// Same nested / non-nested split as above, selecting the `equal_to` template argument.
if (cudf::detail::has_nested_columns(haystack)) {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<true>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert(
haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value());
} else {
auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to<false>(
nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)};
map.insert(
haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value());
}
};
dispatch_nan_comparator(compare_nans, insert_map);
}
}
// The output vector.
auto contained = rmm::device_uvector<bool>(needles.num_rows(), stream, mr);
auto const preprocessed_needles =
cudf::experimental::row::equality::preprocessed_table::create(needles, stream);
// Check existence for each row of the needles table in the haystack table.
{
// Needle row indices are wrapped as `rhs_index_type`, while the map keys are
// `lhs_index_type`.
auto const needles_it = cudf::detail::make_counting_transform_iterator(
size_type{0}, [] __device__(auto const idx) { return rhs_index_type{idx}; });
// Probes are hashed with a hasher built on the needles table.
auto const hasher = cudf::experimental::row::hash::row_hasher(preprocessed_needles);
auto const d_hasher =
strong_index_hasher_adapter{hasher.device_hasher(nullate::DYNAMIC{has_any_nulls})};
// Lookup compares haystack rows against needle rows. The resulting device comparator is
// handed to the map as-is, without `strong_index_comparator_adapter`.
auto const comparator = cudf::experimental::row::equality::two_table_comparator(
preprocessed_haystack, preprocessed_needles);
auto const check_contains = [&](auto const value_comp) {
// Here the `equal_to` template argument is `true` if either table has nested columns.
if (cudf::detail::has_nested_columns(haystack) or cudf::detail::has_nested_columns(needles)) {
auto const d_eqcomp =
comparator.equal_to<true>(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp);
map.contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
d_hasher,
d_eqcomp,
stream.value());
} else {
auto const d_eqcomp =
comparator.equal_to<false>(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp);
map.contains(needles_it,
needles_it + needles.num_rows(),
contained.begin(),
d_hasher,
d_eqcomp,
stream.value());
}
};
dispatch_nan_comparator(compare_nans, check_contains);
}
return contained;
}
} // namespace cudf::detail