Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Adding floating point specialization to comparators for NaNs #3239

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
- PR #3205 Move transform files to legacy
- PR #3202 Rename and move error.hpp to public headers
- PR #2878 Use upstream merge code in dask_cudf
- PR #3239 Adding floating point specialization to comparators for NaNs

## Bug Fixes

Expand Down
90 changes: 71 additions & 19 deletions cpp/include/cudf/table/row_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,63 @@
namespace cudf {
namespace experimental {

/**---------------------------------------------------------------------------*
* @brief Result type of the `element_relational_comparator` function object.
*
* Indicates how two elements `a` and `b` compare with one and another.
*
* Equivalence is defined as `not (a<b) and not (b<a)`. Elements that are are
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
* EQUIVALENT may not necessarily be *equal*.
*
*---------------------------------------------------------------------------**/
enum class weak_ordering {
LESS, ///< Indicates `a` is less than (ordered before) `b`
EQUIVALENT, ///< Indicates `a` is ordered neither before nor after `b`
GREATER ///< Indicates `a` is greater than (ordered after) `b`
};

/**---------------------------------------------------------------------------*
* @brief Evaluates elements `lhs` and `rhs` for nan and expected ordering, and
* this will be available for only floating point `Element` types.
*
* @param[in] lhs first element
* @param[in] rhs second element
* @param[in] expected_ordering expected relation between elements
* @returns bool true if elements are in order as per `expected_ordering` else false
*---------------------------------------------------------------------------**/
template <typename Element,
std::enable_if_t<std::is_floating_point<Element>::value>* = nullptr>
__device__ bool evaluate_nan_ordering(Element const lhs, Element const rhs, weak_ordering expected_ordering) {
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved

bool result = false;

switch(expected_ordering) {
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
case weak_ordering::EQUIVALENT: result = std::isnan(lhs) and std::isnan(rhs);
break;
case weak_ordering::LESS: result = std::isnan(rhs) and not std::isnan(lhs);
harrism marked this conversation as resolved.
Show resolved Hide resolved
break;
case weak_ordering::GREATER: result = std::isnan(lhs) and not std::isnan(rhs);
break;
}

return result;
}

/**---------------------------------------------------------------------------*
* @brief This funtion is to handle non-floating `Element` types and it will
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
* always return false.
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
*
* @param[in] lhs first element
* @param[in] rhs second element
* @param[in] expected_ordering expected relation between elements
* @returns bool always returns `false`
*---------------------------------------------------------------------------**/
template <typename Element,
std::enable_if_t<not std::is_floating_point<Element>::value>* = nullptr>
__device__ bool evaluate_nan_ordering(Element lhs, Element rhs, weak_ordering expected_ordering) {
return false;
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
}

/**---------------------------------------------------------------------------*
* @brief Performs an equality comparison between two elements in two columns.
*
Expand Down Expand Up @@ -71,8 +128,15 @@ class element_equality_comparator {
return false;
}
}
return lhs.element<Element>(lhs_element_index) ==
rhs.element<Element>(rhs_element_index);
// NaNs are equal
rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
Element const lhs_element = lhs.element<Element>(lhs_element_index);
Element const rhs_element = rhs.element<Element>(rhs_element_index);

if(evaluate_nan_ordering(lhs_element, rhs_element, weak_ordering::EQUIVALENT)) {
return true;
}

return lhs_element == rhs_element;
}

private:
Expand Down Expand Up @@ -109,21 +173,6 @@ class row_equality_comparator {
bool nulls_are_equal;
};

/**---------------------------------------------------------------------------*
* @brief Result type of the `element_relational_comparator` function object.
*
* Indicates how two elements `a` and `b` compare with one and another.
*
* Equivalence is defined as `not (a<b) and not (b<a)`. Elements that are are
* EQUIVALENT may not necessarily be *equal*.
*
*---------------------------------------------------------------------------**/
enum class weak_ordering {
LESS, ///< Indicates `a` is less than (ordered before) `b`
EQUIVALENT, ///< Indicates `a` is ordered neither before nor after `b`
GREATER ///< Indicates `a` is greater than (ordered after) `b`
};

/**---------------------------------------------------------------------------*
* @brief Performs a relational comparison between two elements in two columns.
*
Expand Down Expand Up @@ -181,11 +230,14 @@ class element_relational_comparator {
Element const lhs_element = lhs.element<Element>(lhs_element_index);
Element const rhs_element = rhs.element<Element>(rhs_element_index);

rgsl888prabhu marked this conversation as resolved.
Show resolved Hide resolved
if (lhs_element < rhs_element) {
if ((lhs_element < rhs_element) or
evaluate_nan_ordering(lhs_element, rhs_element, weak_ordering::LESS)) {
return weak_ordering::LESS;
} else if (rhs_element < lhs_element) {
} else if ((rhs_element < lhs_element) or
evaluate_nan_ordering(lhs_element, rhs_element, weak_ordering::GREATER)) {
return weak_ordering::GREATER;
}

return weak_ordering::EQUIVALENT;
}

Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ ConfigureTest(LEGACY_TRANSPOSE_TEST "${LEGACY_TRANSPOSE_TEST_SRC}")

set(TABLE_TEST_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/table/table_tests.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/table/table_view_tests.cu")
"${CMAKE_CURRENT_SOURCE_DIR}/table/table_view_tests.cu"
"${CMAKE_CURRENT_SOURCE_DIR}/table/row_operators_tests.cu")

ConfigureTest(TABLE_TEST "${TABLE_TEST_SRC}")
###################################################################################################
Expand Down
59 changes: 59 additions & 0 deletions cpp/tests/table/row_operators_tests.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/column/column_view.hpp>
#include <cudf/table/row_operators.cuh>
#include <cudf/table/table_view.hpp>
#include <cudf/sorting.hpp>
#include <tests/utilities/base_fixture.hpp>
#include <tests/utilities/type_lists.hpp>
#include <tests/utilities/column_wrapper.hpp>
#include <tests/utilities/column_utilities.hpp>

#include <vector>
#include <gmock/gmock.h>

struct RowOperatorTestForNAN : public cudf::test::BaseFixture {};

TEST_F(RowOperatorTestForNAN, NANEquality)
{
cudf::test::fixed_width_column_wrapper<double> col1 {{1, NAN, 3, 4}, {1, 1, 0, 1}};
cudf::test::fixed_width_column_wrapper<double> col2 {{1, NAN, 3, 4}, {1, 1, 0, 1}};

cudf::test::expect_columns_equal(col1, col2);
}


TEST_F(RowOperatorTestForNAN, NANSorting)
{
// NULL Before
cudf::test::fixed_width_column_wrapper<double> input {{0, NAN, -1, 7, std::numeric_limits<double>::infinity(), 1, -1*std::numeric_limits<double>::infinity()}, {1, 1, 1, 0, 1, 1, 1, 1}};
cudf::test::fixed_width_column_wrapper<int32_t> expected1 {{3, 6, 2, 0, 5, 4, 1}};
std::vector<cudf::order> column_order {cudf::order::ASCENDING};
cudf::table_view input_table {{input}};

auto got1 = cudf::experimental::sorted_order(input_table, column_order, cudf::null_order::BEFORE);

cudf::test::expect_columns_equal(expected1, got1->view());

// NULL After

cudf::test::fixed_width_column_wrapper<int32_t> expected2 {{6, 2, 0, 5, 4, 1, 3}};

auto got2 = cudf::experimental::sorted_order(input_table, column_order, cudf::null_order::AFTER);

cudf::test::expect_columns_equal(expected2, got2->view());
}