Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix NaN handling in drop_list_duplicates #7662

Merged
merged 38 commits into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
81e0d79
Add tests for drop_list_duplicates in case of input containing floati…
ttnghia Mar 17, 2021
e431549
Add negative NaN into the tests
ttnghia Mar 19, 2021
84b06a3
Rewrite tests: split tests into smaller tests with some improvements
ttnghia Mar 19, 2021
34305f2
Some improvement to floating point tests with NaNs
ttnghia Mar 19, 2021
fa46446
Add customized comparators for drop_list_duplicates, still need to up…
ttnghia Mar 20, 2021
452835b
Some cleanup
ttnghia Mar 20, 2021
42535a0
Rewrite doc for element_comparator and element_comparator_fn
ttnghia Mar 20, 2021
2b5b8e4
Using type_dispatcher only for host code
ttnghia Mar 22, 2021
dfa1c8a
Fix memory access violation bug when using reference to column_device…
ttnghia Mar 22, 2021
f2c4d5d
Merge remote-tracking branch 'origin/branch-0.19' into fix_nan_drop_l…
ttnghia Mar 22, 2021
9d07cc7
Add test case when the list contains both -0.0 and 0.0
ttnghia Mar 22, 2021
ef6d7e2
Rename constants
ttnghia Mar 23, 2021
725155b
Change has_null from template parameter to runtime parameter
ttnghia Mar 23, 2021
97815de
Merge branch 'branch-0.19' into fix_nan_drop_list_duplicates
ttnghia Mar 23, 2021
6f76f8e
Remove redundant qualifiers from class constructor
ttnghia Mar 23, 2021
c66b88f
AddAdd `nan_equality` enum to specify whether NaN elements should be …
ttnghia Mar 25, 2021
806b900
Rewrite `drop_list_duplicate`, adding `nans_equal` parameter, allowin…
ttnghia Mar 25, 2021
d96bc79
Rewrite tests for `drop_list_duplicates`
ttnghia Mar 25, 2021
0246c78
Rewrite `collect_set_aggregation`, adding `nans_equal` parameter
ttnghia Mar 25, 2021
2185d8a
Fix typo
ttnghia Mar 25, 2021
919e859
Change `nan_equality` enum names
ttnghia Mar 25, 2021
a002e62
Fix enum in unit tests for `drop_list_duplicates`
ttnghia Mar 25, 2021
26cae09
Add an option to specify NaNs are compared equal only if they have th…
ttnghia Mar 25, 2021
4356bf6
Rework `drop_list_duplicates` for the new `nan_equality` option
ttnghia Mar 25, 2021
e4cfa11
Rewrite unit tests for `drop_list_duplicates` that can test for all c…
ttnghia Mar 25, 2021
98ec9b1
Revert "Rewrite unit tests for `drop_list_duplicates` that can test f…
ttnghia Mar 25, 2021
7a9f850
Revert "Rework `drop_list_duplicates` for the new `nan_equality` option"
ttnghia Mar 25, 2021
04b120d
Revert "Add an option to specify NaNs are compared equal only if they…
ttnghia Mar 25, 2021
0973be9
Fix typo
ttnghia Mar 25, 2021
fa035a6
Avoid initialize-then-assign
ttnghia Mar 30, 2021
8520f0d
Replace `thrust::any_of` by `thrust::count_if`
ttnghia Mar 30, 2021
49bfe13
Copy column by constructor
ttnghia Mar 30, 2021
3d50d8e
Merge remote-tracking branch 'origin/branch-0.19' into fix_nan_drop_l…
ttnghia Mar 30, 2021
27b5beb
Minor cleanup
ttnghia Mar 30, 2021
1406a25
Replace `is_null` by `is_null_nocheck`
ttnghia Mar 30, 2021
32b4393
Replace `make_numeric_column` by `device_uvector`
ttnghia Mar 30, 2021
b5af91e
Rewrite comments, and add a condition check for nans_equal == ALL_EQU…
ttnghia Mar 30, 2021
7812e05
Merge branch 'branch-0.19' into fix_nan_drop_list_duplicates
ttnghia Mar 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,13 @@ std::unique_ptr<aggregation> make_collect_list_aggregation(
* @param null_handling Indicates whether to include/exclude nulls during collection
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal
*/
std::unique_ptr<aggregation> make_collect_set_aggregation(
null_policy null_handling = null_policy::INCLUDE,
null_equality null_equal = null_equality::EQUAL);
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);
kkraus14 marked this conversation as resolved.
Show resolved Hide resolved

/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset);
Expand Down
18 changes: 13 additions & 5 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,24 +345,32 @@ struct collect_list_aggregation final : derived_aggregation<nunique_aggregation>
*/
struct collect_set_aggregation final : derived_aggregation<collect_set_aggregation> {
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality null_equal = null_equality::EQUAL)
: derived_aggregation{COLLECT_SET}, _null_handling{null_handling}, _null_equal(null_equal)
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL)
: derived_aggregation{COLLECT_SET},
_null_handling{null_handling},
_nulls_equal(nulls_equal),
_nans_equal(nans_equal)
{
}
null_policy _null_handling; ///< include or exclude nulls
null_equality _null_equal; ///< whether to consider nulls as equal values
null_equality _nulls_equal; ///< whether to consider nulls as equal values
nan_equality _nans_equal; ///< whether to consider NaNs as equal value (applicable only to
///< floating point types)

protected:
friend class derived_aggregation<collect_set_aggregation>;

bool operator==(collect_set_aggregation const& other) const
{
return _null_handling == other._null_handling && _null_equal == other._null_equal;
return _null_handling == other._null_handling && _nulls_equal == other._nulls_equal &&
_nans_equal == other._nans_equal;
}

size_t hash_impl() const
{
return std::hash<int>{}(static_cast<int>(_null_handling) ^ static_cast<int>(_null_equal));
return std::hash<int>{}(static_cast<int>(_null_handling) ^ static_cast<int>(_nulls_equal) ^
static_cast<int>(_nans_equal));
}
};

Expand Down
1 change: 1 addition & 0 deletions cpp/include/cudf/lists/detail/drop_list_duplicates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace detail {
std::unique_ptr<column> drop_list_duplicates(
lists_column_view const& lists_column,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
} // namespace detail
Expand Down
3 changes: 3 additions & 0 deletions cpp/include/cudf/lists/drop_list_duplicates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ namespace lists {
*
* @param lists_column The input lists_column_view
* @param nulls_equal Flag to specify whether null entries should be considered equal
* @param nans_equal Flag to specify whether NaN entries should be considered as equal value (only
* applicable for floating point data column)
* @param mr Device resource used to allocate memory
*
* @code{.pseudo}
Expand All @@ -56,6 +58,7 @@ namespace lists {
std::unique_ptr<column> drop_list_duplicates(
lists_column_view const& lists_column,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
Expand Down
9 changes: 9 additions & 0 deletions cpp/include/cudf/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ enum class nan_policy : bool {
NAN_IS_VALID ///< treat nans as valid elements (non-null)
};

/**
* @brief Enum to consider different elements (of floating point types) holding NaN value as equal
* or unequal
*/
enum class nan_equality /*unspecified*/ {
isVoid marked this conversation as resolved.
Show resolved Hide resolved
ALL_EQUAL, ///< All NaNs compare equal, regardless of sign
UNEQUAL ///< All NaNs compare unequal (IEE754 behavior)
ttnghia marked this conversation as resolved.
Show resolved Hide resolved
};

/**
* @brief
*/
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ std::unique_ptr<aggregation> make_collect_list_aggregation(null_policy null_hand
}
/// Factory to create a COLLECT_SET aggregation
std::unique_ptr<aggregation> make_collect_set_aggregation(null_policy null_handling,
null_equality null_equal)
null_equality nulls_equal,
nan_equality nans_equal)
{
return std::make_unique<detail::collect_set_aggregation>(null_handling, null_equal);
return std::make_unique<detail::collect_set_aggregation>(null_handling, nulls_equal, nans_equal);
}
/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset)
Expand Down
13 changes: 8 additions & 5 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,11 +375,14 @@ void aggregrate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(), helper.num_groups(), stream, mr);
auto const nulls_equal =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_equal;
cache.add_result(col_idx,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, stream, mr));
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._nans_equal;
cache.add_result(
col_idx,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr));
};
} // namespace detail

Expand Down
Loading