Skip to content

Commit

Permalink
Fix NaN handling in drop_list_duplicates (#7662)
Browse files Browse the repository at this point in the history
This PR modifies the behavior of `drop_list_duplicates` to satisfy both Apache Spark and Pandas behavior when dealing with `NaN` value in floating-point columns data:
 * In Apache Spark, `NaNs` are treated as different values, thus no `NaN` entry should be removed after calling `drop_list_duplicates`.
 * In Pandas, `NaNs` are considered as the same value, and even `-NaN` is considered as the same as `NaN`. Thus, only one `NaN` entry per list will be kept.

New tests have also been added to verify such desired behavior.

Authors:
  - Nghia Truong (@ttnghia)

Approvers:
  - Jake Hemstad (@jrhemstad)
  - @nvdbaranec
  - Keith Kraus (@kkraus14)

URL: #7662
  • Loading branch information
ttnghia authored Mar 31, 2021
1 parent 4ee52f3 commit bd11dbe
Show file tree
Hide file tree
Showing 9 changed files with 535 additions and 197 deletions.
5 changes: 4 additions & 1 deletion cpp/include/cudf/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,13 @@ std::unique_ptr<aggregation> make_collect_list_aggregation(
* @param null_handling Indicates whether to include/exclude nulls during collection
* @param nulls_equal Flag to specify whether null entries within each list should be considered
* equal
* @param nans_equal Flag to specify whether NaN values in floating point column should be
* considered equal
*/
std::unique_ptr<aggregation> make_collect_set_aggregation(
null_policy null_handling = null_policy::INCLUDE,
null_equality null_equal = null_equality::EQUAL);
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL);

/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset);
Expand Down
18 changes: 13 additions & 5 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,24 +345,32 @@ struct collect_list_aggregation final : derived_aggregation<nunique_aggregation>
*/
struct collect_set_aggregation final : derived_aggregation<collect_set_aggregation> {
explicit collect_set_aggregation(null_policy null_handling = null_policy::INCLUDE,
null_equality null_equal = null_equality::EQUAL)
: derived_aggregation{COLLECT_SET}, _null_handling{null_handling}, _null_equal(null_equal)
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL)
: derived_aggregation{COLLECT_SET},
_null_handling{null_handling},
_nulls_equal(nulls_equal),
_nans_equal(nans_equal)
{
}
null_policy _null_handling; ///< include or exclude nulls
null_equality _null_equal; ///< whether to consider nulls as equal values
null_equality _nulls_equal; ///< whether to consider nulls as equal values
nan_equality _nans_equal; ///< whether to consider NaNs as equal value (applicable only to
///< floating point types)

protected:
friend class derived_aggregation<collect_set_aggregation>;

bool operator==(collect_set_aggregation const& other) const
{
return _null_handling == other._null_handling && _null_equal == other._null_equal;
return _null_handling == other._null_handling && _nulls_equal == other._nulls_equal &&
_nans_equal == other._nans_equal;
}

size_t hash_impl() const
{
return std::hash<int>{}(static_cast<int>(_null_handling) ^ static_cast<int>(_null_equal));
return std::hash<int>{}(static_cast<int>(_null_handling) ^ static_cast<int>(_nulls_equal) ^
static_cast<int>(_nans_equal));
}
};

Expand Down
1 change: 1 addition & 0 deletions cpp/include/cudf/lists/detail/drop_list_duplicates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace detail {
std::unique_ptr<column> drop_list_duplicates(
lists_column_view const& lists_column,
null_equality nulls_equal,
nan_equality nans_equal,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
} // namespace detail
Expand Down
3 changes: 3 additions & 0 deletions cpp/include/cudf/lists/drop_list_duplicates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ namespace lists {
*
* @param lists_column The input lists_column_view
* @param nulls_equal Flag to specify whether null entries should be considered equal
* @param nans_equal Flag to specify whether NaN entries should be considered as equal value (only
* applicable for floating point data column)
* @param mr Device resource used to allocate memory
*
* @code{.pseudo}
Expand All @@ -56,6 +58,7 @@ namespace lists {
std::unique_ptr<column> drop_list_duplicates(
lists_column_view const& lists_column,
null_equality nulls_equal = null_equality::EQUAL,
nan_equality nans_equal = nan_equality::UNEQUAL,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of group
Expand Down
9 changes: 9 additions & 0 deletions cpp/include/cudf/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,15 @@ enum class nan_policy : bool {
NAN_IS_VALID ///< treat nans as valid elements (non-null)
};

/**
* @brief Enum to consider different elements (of floating point types) holding NaN value as equal
* or unequal
*/
enum class nan_equality /*unspecified*/ {
ALL_EQUAL, ///< All NaNs compare equal, regardless of sign
UNEQUAL ///< All NaNs compare unequal (IEEE754 behavior)
};

/**
* @brief
*/
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/aggregation/aggregation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ std::unique_ptr<aggregation> make_collect_list_aggregation(null_policy null_hand
}
/// Factory to create a COLLECT_SET aggregation
std::unique_ptr<aggregation> make_collect_set_aggregation(null_policy null_handling,
null_equality null_equal)
null_equality nulls_equal,
nan_equality nans_equal)
{
return std::make_unique<detail::collect_set_aggregation>(null_handling, null_equal);
return std::make_unique<detail::collect_set_aggregation>(null_handling, nulls_equal, nans_equal);
}
/// Factory to create a LAG aggregation
std::unique_ptr<aggregation> make_lag_aggregation(size_type offset)
Expand Down
13 changes: 8 additions & 5 deletions cpp/src/groupby/sort/aggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,11 +379,14 @@ void aggregrate_result_functor::operator()<aggregation::COLLECT_SET>(aggregation
auto const collect_result = detail::group_collect(
get_grouped_values(), helper.group_offsets(stream), helper.num_groups(stream), stream, mr);
auto const nulls_equal =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._null_equal;
cache.add_result(col_idx,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, stream, mr));
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._nulls_equal;
auto const nans_equal =
static_cast<cudf::detail::collect_set_aggregation const&>(agg)._nans_equal;
cache.add_result(
col_idx,
agg,
lists::detail::drop_list_duplicates(
lists_column_view(collect_result->view()), nulls_equal, nans_equal, stream, mr));
};
} // namespace detail

Expand Down
Loading

0 comments on commit bd11dbe

Please sign in to comment.