diff --git a/cpp/include/cudf/ast/detail/expression_parser.hpp b/cpp/include/cudf/ast/detail/expression_parser.hpp index db0abe435b0..a36a831a7aa 100644 --- a/cpp/include/cudf/ast/detail/expression_parser.hpp +++ b/cpp/include/cudf/ast/detail/expression_parser.hpp @@ -67,8 +67,8 @@ struct alignas(8) device_data_reference { bool operator==(device_data_reference const& rhs) const { - return std::tie(data_index, reference_type, table_source) == - std::tie(rhs.data_index, rhs.reference_type, rhs.table_source); + return std::tie(data_index, data_type, reference_type, table_source) == + std::tie(rhs.data_index, rhs.data_type, rhs.reference_type, rhs.table_source); } }; diff --git a/cpp/tests/ast/transform_tests.cpp b/cpp/tests/ast/transform_tests.cpp index c0109a40cec..624a781c5b9 100644 --- a/cpp/tests/ast/transform_tests.cpp +++ b/cpp/tests/ast/transform_tests.cpp @@ -316,6 +316,33 @@ TEST_F(TransformTest, ImbalancedTreeArithmetic) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); } +TEST_F(TransformTest, ImbalancedTreeArithmeticDeep) +{ + auto c_0 = column_wrapper{4, 5, 6}; + auto table = cudf::table_view{{c_0}}; + + auto col_ref_0 = cudf::ast::column_reference(0); + + // expression: (c0 < c0) == (c0 < (c0 + c0)) + // {false, false, false} == (c0 < {8, 10, 12}) + // {false, false, false} == {true, true, true} + // {false, false, false} + auto expression_left_subtree = + cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, col_ref_0); + auto expression_right_inner_subtree = + cudf::ast::operation(cudf::ast::ast_operator::ADD, col_ref_0, col_ref_0); + auto expression_right_subtree = + cudf::ast::operation(cudf::ast::ast_operator::LESS, col_ref_0, expression_right_inner_subtree); + + auto expression_tree = cudf::ast::operation( + cudf::ast::ast_operator::EQUAL, expression_left_subtree, expression_right_subtree); + + auto result = cudf::compute_column(table, expression_tree); + auto expected = column_wrapper{false, false, false}; + + CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, result->view(), verbosity); +} + TEST_F(TransformTest, MultiLevelTreeComparator) { auto c_0 = column_wrapper{3, 20, 1, 50};